causal-treatment-effect

Causal Inference, scikit-learn, rigorous codebase

Description

Causal Treatment Effect Estimation

Research Question

Design a novel estimator for Conditional Average Treatment Effects (CATE) from observational data that is accurate, robust to confounding, and generalizes across datasets with different data-generating processes.

Background

Estimating heterogeneous treatment effects (how the causal effect of a treatment varies across individuals) is a core problem in causal inference. Given observational data with covariates X, binary treatment T, and outcome Y, the goal is to estimate the conditional average treatment effect (CATE), tau(x) = E[Y(1) - Y(0) | X=x], where Y(1) and Y(0) are the potential outcomes under treatment and control.
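To make the definition concrete, the sketch below simulates a toy observational dataset in which both potential outcomes are generated, so tau(x) is known exactly, while only one of them is revealed as the observed Y. The function `simulate_confounded` and its data-generating process (tau(x) = 1 + x0, a logistic propensity in x0) are hypothetical illustrations, not one of the benchmark's generators.

```python
import numpy as np

def simulate_confounded(n=2000, p=5, seed=0):
    """Hypothetical toy DGP with a known CATE tau(x) = 1 + x0 (not a benchmark generator)."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, p))
    e = 1.0 / (1.0 + np.exp(-X[:, 0]))          # propensity P(T=1 | X) depends on x0
    T = rng.binomial(1, e)
    mu0 = X[:, 0] + 0.5 * X[:, 1] ** 2           # E[Y(0) | X]
    tau = 1.0 + X[:, 0]                          # true CATE
    Y0 = mu0 + rng.normal(scale=0.5, size=n)     # potential outcome under control
    Y1 = Y0 + tau                                # potential outcome under treatment
    Y = np.where(T == 1, Y1, Y0)                 # only one potential outcome is observed
    return X, T, Y, tau, (Y0, Y1)

X, T, Y, tau, (Y0, Y1) = simulate_confounded()
# Here Y(1) - Y(0) = tau(x) by construction, so the ground-truth CATE is known
print(f"true ATE (sample mean of tau): {tau.mean():.3f}")
```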

Key challenges include:

  • Confounding: Treatment assignment depends on covariates that also affect the outcome, so naive comparisons of treated and control outcomes are biased
  • Heterogeneity: Treatment effects vary across the covariate space in complex, nonlinear ways
  • Model misspecification: The true response surfaces may not match parametric assumptions
  • Double robustness: Ideally, the estimator remains consistent if either the outcome model or the propensity model is correctly specified
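The confounding and double-robustness points can be checked numerically. In the hypothetical setup below, x0 drives both treatment and outcome, so the naive difference in means is biased; the augmented IPW (AIPW) estimate of the ATE stays close to the truth even though its outcome models are deliberately misspecified (linear, while the true outcome has an x0**2 term), because its propensity model is correctly specified. The data-generating process and variable names are illustrative assumptions, and the ATE is used instead of the CATE only to keep the example short.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression

# Hypothetical confounded DGP: x0 drives treatment and the outcome
rng = np.random.default_rng(1)
n = 10000
X = rng.normal(size=(n, 3))
T = rng.binomial(1, 1.0 / (1.0 + np.exp(-0.7 * X[:, 0])))
tau = 1.0 + X[:, 0]
Y = X[:, 0] + X[:, 0] ** 2 + T * tau + rng.normal(scale=0.5, size=n)

# Propensity model: logistic regression, correctly specified for this DGP
e_hat = np.clip(LogisticRegression().fit(X, T).predict_proba(X)[:, 1], 0.01, 0.99)
# Outcome models: linear per arm, deliberately misspecified (no x0**2 term)
mu1 = LinearRegression().fit(X[T == 1], Y[T == 1]).predict(X)
mu0 = LinearRegression().fit(X[T == 0], Y[T == 0]).predict(X)

naive = Y[T == 1].mean() - Y[T == 0].mean()
# AIPW: outcome-model contrast plus an inverse-propensity-weighted residual correction
psi = mu1 - mu0 + T * (Y - mu1) / e_hat - (1 - T) * (Y - mu0) / (1 - e_hat)
aipw = psi.mean()
print(f"true ATE {tau.mean():.3f} | naive {naive:.3f} | AIPW {aipw:.3f}")
```

The same pseudo-outcome psi, regressed on X instead of averaged, is the starting point of the DR-Learner.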

Classical approaches include the S-Learner (a single outcome model with the treatment as an input feature), the T-Learner (separate outcome models for treated and control units), and IPW (inverse propensity weighting). Modern state-of-the-art methods include Causal Forests (Athey & Wager, 2018), the DR-Learner (Kennedy, 2023), and the R-Learner (Nie & Wager, 2021), which use orthogonalization or debiasing to achieve better convergence rates.
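A minimal sketch of the two classical meta-learners, assuming scikit-learn's `GradientBoostingRegressor` as the base model (an arbitrary choice; any regressor would do). The function names `s_learner` and `t_learner` and the toy data are hypothetical, and PEHE is computed on a held-out split.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

def s_learner(X, T, Y, X_new):
    """S-Learner: one model mu(x, t) with t as a feature; tau(x) = mu(x, 1) - mu(x, 0)."""
    model = GradientBoostingRegressor(random_state=0).fit(np.column_stack([X, T]), Y)
    n_new = len(X_new)
    treated = np.column_stack([X_new, np.ones(n_new)])
    control = np.column_stack([X_new, np.zeros(n_new)])
    return model.predict(treated) - model.predict(control)

def t_learner(X, T, Y, X_new):
    """T-Learner: separate outcome models per arm; tau(x) = mu1(x) - mu0(x)."""
    mu1 = GradientBoostingRegressor(random_state=0).fit(X[T == 1], Y[T == 1])
    mu0 = GradientBoostingRegressor(random_state=0).fit(X[T == 0], Y[T == 0])
    return mu1.predict(X_new) - mu0.predict(X_new)

# Hypothetical confounded toy data with known tau(x) = 1 + x0
rng = np.random.default_rng(0)
X = rng.normal(size=(3000, 5))
T = rng.binomial(1, 1.0 / (1.0 + np.exp(-X[:, 0])))
tau = 1.0 + X[:, 0]
Y = X[:, 0] + 0.5 * X[:, 1] ** 2 + T * tau + rng.normal(scale=0.5, size=3000)

train, test = slice(0, 2000), slice(2000, 3000)
pehes = {}
for name, learner in [("S", s_learner), ("T", t_learner)]:
    tau_hat = learner(X[train], T[train], Y[train], X[test])
    pehes[name] = float(np.sqrt(np.mean((tau_hat - tau[test]) ** 2)))
    print(f"{name}-Learner held-out PEHE: {pehes[name]:.3f}")
```

The S-Learner leaves it to the base model to decide how much to use T, so regularization can shrink estimated effects toward zero; the T-Learner avoids that but fits each arm on fewer samples.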

Task

Modify the CATEEstimator class in custom_cate.py. Your estimator must implement:

  • fit(X, T, Y) -> self: Learn from observational data
  • predict(X) -> tau_hat: Predict individual treatment effects

You have access to scikit-learn and numpy/scipy.
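One possible starting point for this `fit`/`predict` interface is a DR-Learner in the spirit of Kennedy (2023): fit the outcome and propensity models on held-out folds (cross-fitting), form the AIPW pseudo-outcome, and regress it on X. This is a hedged sketch using only scikit-learn and NumPy. It is standalone (it does not subclass anything from custom_cate.py), and the constructor arguments (`n_folds`, `clip`, `random_state`), the gradient-boosting nuisance models, and the random-forest final stage are all assumptions, not necessarily how the `dr_learner` baseline is implemented.

```python
import numpy as np
from sklearn.base import clone
from sklearn.ensemble import (GradientBoostingClassifier, GradientBoostingRegressor,
                              RandomForestRegressor)
from sklearn.model_selection import KFold


class CATEEstimator:
    """Hypothetical DR-learner sketch: cross-fitted nuisances -> AIPW pseudo-outcome -> regression."""

    def __init__(self, n_folds=5, clip=0.05, random_state=0):
        self.n_folds = n_folds            # folds used to cross-fit the nuisance models
        self.clip = clip                  # propensity clipping bounds the 1/e weights
        self.random_state = random_state

    def fit(self, X, T, Y):
        X = np.asarray(X, dtype=float)
        T = np.asarray(T).astype(int)
        Y = np.asarray(Y, dtype=float)
        n = len(Y)
        mu0, mu1, e = np.empty(n), np.empty(n), np.empty(n)
        reg = GradientBoostingRegressor(random_state=self.random_state)
        clf = GradientBoostingClassifier(random_state=self.random_state)
        kf = KFold(n_splits=self.n_folds, shuffle=True, random_state=self.random_state)
        for train, test in kf.split(X):
            Xtr, Ttr, Ytr = X[train], T[train], Y[train]
            # Each nuisance is fit on the other folds and evaluated on the held-out fold
            mu0[test] = clone(reg).fit(Xtr[Ttr == 0], Ytr[Ttr == 0]).predict(X[test])
            mu1[test] = clone(reg).fit(Xtr[Ttr == 1], Ytr[Ttr == 1]).predict(X[test])
            e[test] = clone(clf).fit(Xtr, Ttr).predict_proba(X[test])[:, 1]
        e = np.clip(e, self.clip, 1.0 - self.clip)
        # Doubly robust (AIPW) pseudo-outcome; regressing it on X targets tau(x)
        psi = mu1 - mu0 + T * (Y - mu1) / e - (1 - T) * (Y - mu0) / (1 - e)
        self.final_model_ = RandomForestRegressor(
            n_estimators=200, min_samples_leaf=20, random_state=self.random_state
        ).fit(X, psi)
        return self

    def predict(self, X):
        return self.final_model_.predict(np.asarray(X, dtype=float))


# Usage on a hypothetical confounded toy dataset with tau(x) = 1 + x0
rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 5))
T = rng.binomial(1, 1.0 / (1.0 + np.exp(-X[:, 0])))
tau = 1.0 + X[:, 0]
Y = X[:, 0] + 0.5 * X[:, 1] ** 2 + T * tau + rng.normal(scale=0.5, size=2000)
est = CATEEstimator().fit(X, T, Y)
tau_hat = est.predict(X)
pehe_in_sample = np.sqrt(np.mean((tau_hat - tau) ** 2))
print(f"in-sample PEHE: {pehe_in_sample:.3f}")
```

Clipping the propensity bounds the inverse-propensity weights, trading a little bias for much lower variance in the pseudo-outcome; the leaf-size floor on the final forest serves the same purpose of averaging out noisy pseudo-outcomes.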

Evaluation

Evaluated on three semi-synthetic benchmarks with known ground-truth treatment effects:

  • IHDP: Infant Health and Development Program (n=747, p=25, nonlinear effects)
  • Jobs: Job training program evaluation (n=2000, p=10, economic outcomes)
  • ACIC: Atlantic Causal Inference Conference simulation (n=4000, p=50, complex confounding)

Metrics (lower is better for both):

  • PEHE: Precision in Estimation of Heterogeneous Effects = sqrt(mean((tau_hat - tau_true)^2))
  • ATE error: |mean(tau_hat) - ATE_true|
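Both metrics are one-liners. The sketch below assumes ATE_true is the mean of the true individual effects, which the task text does not state explicitly; `pehe` and `ate_error` are hypothetical helper names. The example also shows why both metrics are reported: a constant prediction equal to the true ATE scores zero ATE error but still has a large PEHE.

```python
import numpy as np

def pehe(tau_hat, tau_true):
    """PEHE = sqrt(mean((tau_hat - tau_true)^2))."""
    tau_hat, tau_true = np.asarray(tau_hat, dtype=float), np.asarray(tau_true, dtype=float)
    return float(np.sqrt(np.mean((tau_hat - tau_true) ** 2)))

def ate_error(tau_hat, ate_true):
    """ATE error = |mean(tau_hat) - ATE_true|."""
    return float(abs(np.mean(tau_hat) - ate_true))

tau_true = np.array([1.0, 2.0, 3.0])
tau_hat = np.array([2.0, 2.0, 2.0])          # constant prediction equal to the true ATE
print(pehe(tau_hat, tau_true))               # ≈ 0.8165, i.e. sqrt(2/3)
print(ate_error(tau_hat, tau_true.mean()))   # 0.0
```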

Each dataset is evaluated with 5-fold cross-fitting over 10 repetitions with different random seeds.
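The evaluation harness is fixed and its code is not shown here, so the loop below is only a hypothetical reading of "5-fold cross-fitting over 10 repetitions": for each seed, the estimator is fit on four folds and predicts the held-out fold, and the metrics are computed on the pooled out-of-fold predictions. `evaluate` and the `DiffInMeans` placeholder are illustrative names, not part of custom_cate.py.

```python
import numpy as np
from sklearn.model_selection import KFold

def evaluate(make_estimator, X, T, Y, tau_true, n_folds=5, n_reps=10):
    """Hypothetical loop: out-of-fold CATE predictions, scored once per repetition."""
    pehes, ate_errs = [], []
    for seed in range(n_reps):
        tau_hat = np.empty(len(Y))
        kf = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
        for train, test in kf.split(X):
            est = make_estimator().fit(X[train], T[train], Y[train])
            tau_hat[test] = est.predict(X[test])        # predictions only on unseen rows
        pehes.append(np.sqrt(np.mean((tau_hat - tau_true) ** 2)))
        ate_errs.append(abs(tau_hat.mean() - tau_true.mean()))
    return float(np.mean(pehes)), float(np.mean(ate_errs))


class DiffInMeans:
    """Placeholder estimator: predicts one constant effect for everyone."""
    def fit(self, X, T, Y):
        self.ate_ = Y[T == 1].mean() - Y[T == 0].mean()
        return self

    def predict(self, X):
        return np.full(len(X), self.ate_)


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
T = rng.binomial(1, 0.5, size=1000)                     # randomized, so no confounding here
tau_true = 1.0 + X[:, 0]
Y = X[:, 0] + T * tau_true + rng.normal(scale=0.5, size=1000)
pehe_mean, ate_err_mean = evaluate(DiffInMeans, X, T, Y, tau_true)
print(f"mean PEHE {pehe_mean:.3f}, mean ATE error {ate_err_mean:.3f}")
```

A constant predictor recovers the ATE in randomized data, but its PEHE stays near the standard deviation of tau(x) (about 1 here), which is the gap a CATE estimator has to close.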

Code

custom_cate.py
# Custom CATE Estimator for MLS-Bench
#
# EDITABLE section: CATEEstimator class (the treatment effect estimator).
# FIXED sections: everything else (data generation, evaluation, CLI).
#
# Research question: Design a novel estimator for Conditional Average Treatment
# Effects (CATE) from observational data that generalizes across datasets.

import os
import argparse
import json
import time
import warnings
from abc import ABC, abstractmethod

Results

Model                          Type      PEHE IHDP  ATE error IHDP  PEHE Jobs  ATE error Jobs  PEHE ACIC  ATE error ACIC
causal_forest                  baseline  0.771      0.075           358.597    48.547          0.499      0.017
dr_learner                     baseline  1.380      0.081           552.065    71.890          0.491      0.025
ipw                            baseline  2.102      0.426           5221.204   926.530         1.289      0.038
r_learner                      baseline  0.870      0.071           476.031    50.099          0.428      0.021
s_learner                      baseline  0.803      0.109           391.241    59.365          0.524      0.074
t_learner                      baseline  1.191      0.128           648.911    35.308          0.772      0.029
anthropic/claude-opus-4.6      vanilla   -          -               -          -               -          -
google/gemini-3.1-pro-preview  vanilla   -          -               -          -               -          -
gpt-5.4-pro                    vanilla   -          -               -          -               -          -
anthropic/claude-opus-4.6      agent     -          -               -          -               -          -
google/gemini-3.1-pro-preview  agent     -          -               -          -               -          -
google/gemini-3.1-pro-preview  agent     -          -               -          -               -          -
gpt-5.4-pro                    agent     -          -               -          -               -          -
gpt-5.4-pro                    agent     -          -               -          -               -          -

Agent Conversations