causal-treatment-effect
Description
Causal Treatment Effect Estimation
Research Question
Design a novel estimator for Conditional Average Treatment Effects (CATE) from observational data that is accurate, robust to confounding, and generalizes across datasets with different data generating processes.
Background
Estimating heterogeneous treatment effects -- how the causal effect of a treatment varies across individuals -- is a core problem in causal inference. Given observational data with covariates X, binary treatment T, and outcome Y, the goal is to estimate tau(x) = E[Y(1) - Y(0) | X=x], the conditional average treatment effect (CATE).
Key challenges include:
- Confounding: Treatment assignment depends on covariates, so naive comparisons are biased
- Heterogeneity: Treatment effects vary across the covariate space in complex, nonlinear ways
- Model misspecification: The true response surfaces may not match parametric assumptions
- Double robustness: Ideally, the estimator is consistent if either the outcome model or propensity model is correct
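The confounding challenge can be made concrete with a small simulation. The sketch below is illustrative only: the data generating process (one confounder, a logistic propensity, a constant effect of 1.0) is an assumption chosen for clarity and is not one of the benchmark datasets.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000

# A single confounder X drives both treatment assignment and the outcome.
X = rng.normal(size=n)
propensity = 1.0 / (1.0 + np.exp(-2.0 * X))  # P(T=1 | X), assumed logistic
T = rng.binomial(1, propensity)

true_ate = 1.0
Y = X + true_ate * T + rng.normal(size=n)  # constant treatment effect of 1.0

# Naive estimate: difference in mean outcomes between treated and control units.
naive_ate = Y[T == 1].mean() - Y[T == 0].mean()

# Adjusted estimate: OLS of Y on [1, T, X]; the coefficient on T is the effect.
design = np.column_stack([np.ones(n), T, X])
coef, *_ = np.linalg.lstsq(design, Y, rcond=None)
adjusted_ate = coef[1]

print(f"naive: {naive_ate:.2f}  adjusted: {adjusted_ate:.2f}  truth: {true_ate:.2f}")
```

Because treated units tend to have larger X, the naive comparison attributes part of the effect of X to the treatment. The linear adjustment works here only because the outcome model is correctly specified, which is exactly what the misspecification bullet warns cannot be assumed in general.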
Classical approaches include S-Learner (single model), T-Learner (separate models), and IPW (propensity reweighting). Modern SOTA methods include Causal Forests (Wager & Athey, 2018), DR-Learner (Kennedy, 2023), and R-Learner (Nie & Wager, 2021), which use orthogonalization/debiasing to achieve better convergence rates.
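As an illustration of the debiasing idea, the doubly robust (AIPW) pseudo-outcome that a DR-Learner-style second stage regresses on X can be sketched as follows. The function name, argument names, and the `clip` safeguard are assumptions made for this example; they are not taken from the benchmark code.

```python
import numpy as np


def dr_pseudo_outcome(Y, T, mu0, mu1, e, clip=0.01):
    """Doubly robust pseudo-outcome for each unit.

    mu0, mu1: predicted outcomes under control / treatment.
    e:        predicted propensity P(T=1 | X).
    clip:     hypothetical safeguard that bounds e away from 0 and 1.
    """
    Y, T = np.asarray(Y, dtype=float), np.asarray(T, dtype=float)
    mu0, mu1 = np.asarray(mu0, dtype=float), np.asarray(mu1, dtype=float)
    e = np.clip(np.asarray(e, dtype=float), clip, 1.0 - clip)
    # Plug-in effect plus inverse-propensity-weighted residual corrections.
    return mu1 - mu0 + T * (Y - mu1) / e - (1.0 - T) * (Y - mu0) / (1.0 - e)


# Toy example: one treated and one control unit with perfect outcome models.
phi = dr_pseudo_outcome(
    Y=[3.0, 1.0], T=[1, 0], mu0=[1.0, 1.0], mu1=[3.0, 3.0], e=[0.5, 0.5]
)
```

When the outcome models are exact the residual terms vanish and the pseudo-outcome equals `mu1 - mu0`. When the outcome models are set to zero it reduces to the IPW transformation, which is where the "either model correct" property comes from.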
Task
Modify the CATEEstimator class in custom_cate.py. Your estimator must implement:
- fit(X, T, Y) -> self: Learn from observational data
- predict(X) -> tau_hat: Predict individual treatment effects
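A minimal sketch of a class satisfying this interface is shown below. It is a plain T-Learner, not a novel estimator, and it is written without knowledge of the actual base class in custom_cate.py, so the constructor signature and the attribute names `model0_` / `model1_` are hypothetical.

```python
import numpy as np
from sklearn.base import clone
from sklearn.ensemble import GradientBoostingRegressor


class CATEEstimator:
    """Hypothetical T-Learner: one outcome model per treatment arm."""

    def __init__(self, base_model=None):
        # Explicit None check; the default regressor is an assumption.
        if base_model is None:
            base_model = GradientBoostingRegressor(random_state=0)
        self.base_model = base_model

    def fit(self, X, T, Y):
        X = np.asarray(X)
        T = np.asarray(T).ravel()
        Y = np.asarray(Y).ravel()
        # Fit separate response surfaces on control and treated units.
        self.model0_ = clone(self.base_model).fit(X[T == 0], Y[T == 0])
        self.model1_ = clone(self.base_model).fit(X[T == 1], Y[T == 1])
        return self

    def predict(self, X):
        X = np.asarray(X)
        # tau_hat(x) = mu1_hat(x) - mu0_hat(x)
        return self.model1_.predict(X) - self.model0_.predict(X)


# Usage on synthetic data with a known effect tau(x) = 1 + x_0.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
T = rng.binomial(1, 0.5, size=500)
tau = 1.0 + X[:, 0]
Y = X.sum(axis=1) + T * tau + 0.1 * rng.normal(size=500)
tau_hat = CATEEstimator().fit(X, T, Y).predict(X)
```

Returning `self` from `fit` allows the chained call in the usage lines, matching the `fit(X, T, Y) -> self` contract.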
You have access to scikit-learn and numpy/scipy.
Evaluation
Estimators are evaluated on three semi-synthetic benchmarks with known ground-truth treatment effects:
- IHDP: Infant Health and Development Program (n=747, p=25, nonlinear effects)
- Jobs: Job training program evaluation (n=2000, p=10, economic outcomes)
- ACIC: Atlantic Causal Inference Conference simulation (n=4000, p=50, complex confounding)
Metrics (lower is better for both):
- PEHE: Precision in Estimation of Heterogeneous Effects = sqrt(mean((tau_hat - tau_true)^2))
- ATE error: |mean(tau_hat) - ATE_true|
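Both metrics follow directly from the formulas above. The helper names `pehe` and `ate_error` are chosen for this sketch and may differ from those used in the benchmark's evaluation code.

```python
import numpy as np


def pehe(tau_hat, tau_true):
    """sqrt(mean((tau_hat - tau_true)^2)) over all evaluated units."""
    tau_hat = np.asarray(tau_hat, dtype=float)
    tau_true = np.asarray(tau_true, dtype=float)
    return float(np.sqrt(np.mean((tau_hat - tau_true) ** 2)))


def ate_error(tau_hat, ate_true):
    """|mean(tau_hat) - ATE_true|."""
    return float(abs(np.mean(np.asarray(tau_hat, dtype=float)) - ate_true))


# Toy example: only the third unit is mispredicted (by 2.0).
tau_true = np.array([1.0, 2.0, 3.0])
tau_hat = np.array([1.0, 2.0, 5.0])
print(pehe(tau_hat, tau_true))              # sqrt(4/3), about 1.155
print(ate_error(tau_hat, tau_true.mean()))  # |8/3 - 2|, about 0.667
```

PEHE penalizes per-unit errors, whereas ATE error only measures bias in the average, so an estimator can score well on ATE error while still having a large PEHE.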
Each dataset is evaluated with 5-fold cross-fitting over 10 repetitions with different random seeds.
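One plausible reading of this protocol is sketched below: for each seed, the data is split into 5 folds, the estimator is fit on 4 folds and predicts on the held-out fold, and PEHE is computed on the assembled out-of-fold predictions. The exact procedure used by the benchmark is not shown in this page, so the function `repeated_cross_fit_pehe` and the placeholder estimator `DifferenceInMeans` are assumptions for illustration.

```python
import numpy as np
from sklearn.model_selection import KFold


class DifferenceInMeans:
    """Placeholder estimator: predicts a constant effect for every unit."""

    def fit(self, X, T, Y):
        self.effect_ = Y[T == 1].mean() - Y[T == 0].mean()
        return self

    def predict(self, X):
        return np.full(len(X), self.effect_)


def repeated_cross_fit_pehe(make_estimator, X, T, Y, tau_true,
                            n_splits=5, n_reps=10):
    """Mean and std of out-of-fold PEHE over n_reps seeded K-fold splits."""
    X, T, Y = np.asarray(X), np.asarray(T), np.asarray(Y)
    tau_true = np.asarray(tau_true, dtype=float)
    scores = []
    for seed in range(n_reps):
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
        tau_hat = np.empty(len(Y), dtype=float)
        for train_idx, test_idx in kf.split(X):
            # A fresh estimator per fold, so no held-out data leaks into fit.
            est = make_estimator().fit(X[train_idx], T[train_idx], Y[train_idx])
            tau_hat[test_idx] = est.predict(X[test_idx])
        scores.append(np.sqrt(np.mean((tau_hat - tau_true) ** 2)))
    return float(np.mean(scores)), float(np.std(scores))


# Usage: noiseless data with a constant effect of 2.0.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
T = rng.binomial(1, 0.5, size=200)
Y = 2.0 * T
tau_true = np.full(200, 2.0)
mean_pehe, std_pehe = repeated_cross_fit_pehe(DifferenceInMeans, X, T, Y, tau_true)
```

Passing a factory (`make_estimator`) rather than a fitted instance keeps each fold independent of the others.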
Code
    # Custom CATE Estimator for MLS-Bench
    #
    # EDITABLE section: CATEEstimator class (the treatment effect estimator).
    # FIXED sections: everything else (data generation, evaluation, CLI).
    #
    # Research question: Design a novel estimator for Conditional Average Treatment
    # Effects (CATE) from observational data that generalizes across datasets.

    import os
    import argparse
    import json
    import time
    import warnings
    from abc import ABC, abstractmethod
Results
| Model | Type | PEHE ihdp ↓ | ATE error ihdp ↓ | PEHE jobs ↓ | ATE error jobs ↓ | PEHE acic ↓ | ATE error acic ↓ |
|---|---|---|---|---|---|---|---|
| causal_forest | baseline | 0.771 | 0.075 | 358.597 | 48.547 | 0.499 | 0.017 |
| dr_learner | baseline | 1.380 | 0.081 | 552.065 | 71.890 | 0.491 | 0.025 |
| ipw | baseline | 2.102 | 0.426 | 5221.204 | 926.530 | 1.289 | 0.038 |
| r_learner | baseline | 0.870 | 0.071 | 476.031 | 50.099 | 0.428 | 0.021 |
| s_learner | baseline | 0.803 | 0.109 | 391.241 | 59.365 | 0.524 | 0.074 |
| t_learner | baseline | 1.191 | 0.128 | 648.911 | 35.308 | 0.772 | 0.029 |
| anthropic/claude-opus-4.6 | vanilla | - | - | - | - | - | - |
| google/gemini-3.1-pro-preview | vanilla | - | - | - | - | - | - |
| gpt-5.4-pro | vanilla | - | - | - | - | - | - |
| anthropic/claude-opus-4.6 | agent | - | - | - | - | - | - |
| google/gemini-3.1-pro-preview | agent | - | - | - | - | - | - |
| gpt-5.4-pro | agent | - | - | - | - | - | - |