Average Treatment Effect (ATE)
- $ATE=E[Y(1)]−E[Y(0)]$
- 전체 인구에서 처치의 평균 효과를 측정하는 지표입니다.
- ATE는 처치를 받은 경우와 받지 않은 경우의 결과 차이를 평균적으로 나타냅니다.
- 장점
- 단순하고 직관적: 전체 인구에 대한 처치의 평균 효과를 쉽게 이해할 수 있습니다.
- 정책 결정에 유용: 전체 인구에 대한 정책 효과를 평가하는 데 유용합니다.
- 단점
- 세부 정보 부족: 특정 하위 그룹에 대한 정보를 제공하지 못합니다.
- 내생성 문제: 처치와 결과 간의 상관관계를 정확히 측정하기 어려울 수 있습니다.
Conditional Average Treatment Effect (CATE)
- 특정 하위 그룹에서 처치의 평균 효과를 측정하는 지표입니다.
- CATE는 주어진 조건에서 처치를 받은 경우와 받지 않은 경우의 결과 차이를 평균적으로 나타냅니다.
- $CATE(X)=E[Y(1)∣X]−E[Y(0)∣X]$
- 여기서 $X$는 조건 또는 특성을 나타냅니다.
- 장점
- 세부 정보 제공: 특정 하위 그룹에 대한 처치 효과를 분석할 수 있습니다.
- 개별화된 정책 결정: 특정 그룹에 맞춘 정책 결정을 지원합니다.
- 단점
- 복잡성 증가: 조건에 따른 처치 효과를 측정하기 위해 더 많은 데이터와 복잡한 모델이 필요합니다.
- 해석 어려움: 여러 조건에 따른 효과를 해석하는 것이 어려울 수 있습니다.
Alternatives
- Instrumental Variables (IV): 도구 변수를 사용하여 내생성을 해결하는 방법입니다.
- Propensity Score Matching (PSM): 유사한 특성을 가진 처치 그룹과 통제 그룹을 매칭하여 처치 효과를 추정합니다.
Sample Code
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
# 데이터 생성
np.random.seed(42)
n = 1000
X = np.random.normal(size=(n, 5))
T = np.random.binomial(1, 0.5, size=n)
Y = 2 * T + X[:, 0] + np.random.normal(size=n)
# 교차 검증을 통한 ATE 추정
kf = KFold(n_splits=5)
ATEs = []
for train_index, test_index in kf.split(X):
X_train, X_test = X[train_index], X[test_index]
T_train, T_test = T[train_index], T[test_index]
Y_train, Y_test = Y[train_index], Y[test_index]
model = LinearRegression().fit(T_train.reshape(-1, 1), Y_train)
ATEs.append(model.coef_[0])
ATE = np.mean(ATEs)
print(f"Cross-validated ATE: {ATE}")
# 교차 검증을 통한 CATE 추정
def cate_cv(X, T, Y, kf):
CATEs = []
for train_index, test_index in kf.split(X):
X_train, X_test = X[train_index], X[test_index]
T_train, T_test = T[train_index], T[test_index]
Y_train, Y_test = Y[train_index], Y[test_index]
model_treated = LinearRegression().fit(X_train[T_train == 1], Y_train[T_train == 1])
model_control = LinearRegression().fit(X_train[T_train == 0], Y_train[T_train == 0])
CATEs.append(model_treated.predict(X_test) - model_control.predict(X_test))
return np.concatenate(CATEs)
CATE_values = cate_cv(X, T, Y, kf)
print(f"Cross-validated CATE (first 10 values): {CATE_values[:10]}")
# 시각화
plt.figure(figsize=(12, 6))
# ATE 시각화
plt.subplot(1, 2, 1)
plt.hist(ATEs, bins=10, edgecolor='black')
plt.title('ATE Distribution')
plt.xlabel('ATE')
plt.ylabel('Frequency')
# CATE 시각화
plt.subplot(1, 2, 2)
plt.hist(CATE_values, bins=30, edgecolor='black')
plt.title('CATE Distribution')
plt.xlabel('CATE')
plt.ylabel('Frequency')
plt.tight_layout()
plt.show()