causal-observational-nonlinear

Causal Inference · causal-learn · rigorous codebase

Description

Causal Discovery: Observational Nonlinear Data

Objective

Implement a causal discovery algorithm that recovers the DAG structure from purely observational data generated by a nonlinear Additive Noise Model (ANM). Your code goes in bench/custom_algorithm.py.

Background

Additive Noise Models (ANMs) assume structural equations of the form x_j = f_j(parents(j)) + e_j, where the f_j are nonlinear functions and the e_j are mutually independent noise terms. Under mild conditions on the f_j and the noise distributions, the full causal DAG (not just its Markov equivalence class) is identifiable from observational data alone (Hoyer et al., 2008; Peters et al., 2014). This goes beyond methods that rely on linearity, such as LiNGAM, or PC and GES when run with linear-Gaussian independence tests and scores.
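For two variables, identifiability can be illustrated with the regression-plus-independence idea of Hoyer et al. (2008): fit a nonlinear regression in each direction and prefer the direction whose residual is closer to independent of its input. The sketch below is illustrative only; it assumes kernel ridge regression with an RBF kernel as the regressor and a biased HSIC estimate as the dependence measure, and the function names, regularization strength and cubic toy model are hypothetical choices, not part of the benchmark.

```python
import numpy as np


def rbf_gram(z: np.ndarray) -> np.ndarray:
    """RBF Gram matrix for a 1-D sample; bandwidth from the median squared distance."""
    z = z.reshape(-1, 1)
    sq = (z - z.T) ** 2
    return np.exp(-sq / np.median(sq[sq > 0]))


def hsic(a: np.ndarray, b: np.ndarray) -> float:
    """Biased HSIC estimate trace(K H L H) / n^2; larger means more dependent."""
    n = len(a)
    H = np.eye(n) - np.ones((n, n)) / n  # centring matrix
    return float(np.trace(rbf_gram(a) @ H @ rbf_gram(b) @ H)) / n**2


def kernel_ridge_residuals(x: np.ndarray, y: np.ndarray, lam: float = 1e-3) -> np.ndarray:
    """Residuals y - f_hat(x) of an RBF kernel ridge regression (hypothetical regressor choice)."""
    K = rbf_gram(x)
    alpha = np.linalg.solve(K + lam * len(x) * np.eye(len(x)), y)
    return y - K @ alpha


def anm_direction(x: np.ndarray, y: np.ndarray) -> str:
    """Return the direction whose regression residual is less dependent on its input."""
    x = (x - x.mean()) / x.std()
    y = (y - y.mean()) / y.std()
    forward = hsic(x, kernel_ridge_residuals(x, y))   # model y = f(x) + e
    backward = hsic(y, kernel_ridge_residuals(y, x))  # model x = g(y) + e
    return "x->y" if forward < backward else "y->x"


rng = np.random.default_rng(0)
x = rng.uniform(-2, 2, 300)
y = x**3 + rng.uniform(-1, 1, 300)  # ground truth: x -> y with non-Gaussian noise
print(anm_direction(x, y))
```

In the true direction the residual is just the additive noise, which is independent of x; in the reverse direction the residual spread changes with y (wide near y = 0, narrow for large |y|), which HSIC picks up.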

Key challenges include: (1) testing independence in high-dimensional nonlinear settings, (2) scaling to larger graphs, (3) handling diverse nonlinear function types (MLP, GP, polynomial, sigmoid), and (4) robustness to different noise distributions, including the harder Gaussian-noise case.
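One way to see why scaling matters: the number of labeled DAGs on d nodes grows super-exponentially (25 on three nodes, 29,281 on five), so exhaustive search over structures is out of reach even at the 8-20 node sizes in the evaluation scenarios. The hypothetical helper `count_dags` below counts them with Robinson's recurrence a(m) = sum over k of (-1)^(k+1) C(m, k) 2^(k(m-k)) a(m-k), with a(0) = 1.

```python
from math import comb


def count_dags(n: int) -> int:
    """Number of labeled DAGs on n nodes via Robinson's recurrence."""
    a = [1]  # a[0] = 1: the empty graph
    for m in range(1, n + 1):
        a.append(sum((-1) ** (k + 1) * comb(m, k) * 2 ** (k * (m - k)) * a[m - k]
                     for k in range(1, m + 1)))
    return a[n]


for nodes in (3, 5, 8, 12, 20):
    print(nodes, count_dags(nodes))
```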

Evaluation Scenarios

Label           Graph type       Nodes  Samples  Noise        Nonlinearity
ER8-MLP         Erdos-Renyi      8      500      Exponential  MLP
ER12-GP         Erdos-Renyi      12     1000     Laplace      GP
SF10-Mixed      Scale-Free (BA)  10     500      Uniform      Mixed
ER15-Sigmoid    Erdos-Renyi      15     1000     Exponential  Sigmoid
ER20-Mixed      Erdos-Renyi      20     2000     Laplace      Mixed
SF20-GP         Scale-Free (BA)  20     2000     Exponential  GP
ER12-LowSample  Erdos-Renyi      12     150      Laplace      Mixed
ER20-Gauss      Erdos-Renyi      20     2000     Gaussian     Mixed

Metrics

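All metrics are computed on the directed edge set: an estimated edge counts as correct only if both its skeleton and its direction match the true graph.

  • F1 (primary ranking metric), SHD (structural Hamming distance), Precision, Recall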

Baselines

  • cam: Causal Additive Models (Buehlmann et al., 2014) -- score-based with GAM regression
  • notears_mlp: NOTEARS with MLP parametrization (Zheng et al., 2020) -- continuous optimization
  • directlingam: DirectLiNGAM (Shimizu et al., 2011) -- linear reference baseline
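The notears_mlp baseline builds on the smooth acyclicity function of NOTEARS (Zheng et al., 2018), h(W) = tr(exp(W ∘ W)) - d, which is zero exactly when the weighted adjacency matrix W has no directed cycles; to our understanding, the MLP variant applies the same h to a matrix built from first-layer weight norms. The sketch below only evaluates h; `notears_acyclicity` is a hypothetical name and the optimization loop is omitted.

```python
import numpy as np
from scipy.linalg import expm


def notears_acyclicity(W: np.ndarray) -> float:
    """h(W) = tr(exp(W * W)) - d; zero iff the weighted graph W has no directed cycles."""
    d = W.shape[0]
    return float(np.trace(expm(W * W))) - d  # W * W is the elementwise (Hadamard) square


# NOTEARS-style orientation: W[i, j] != 0 means i -> j.
dag = np.array([[0.0, 1.5, 0.0],
                [0.0, 0.0, -0.7],
                [0.0, 0.0, 0.0]])            # 0 -> 1 -> 2
cycle = dag.copy()
cycle[2, 0] = 0.8                            # adds 2 -> 0, closing a cycle
print(notears_acyclicity(dag), notears_acyclicity(cycle))
```

Because tr(exp(A^T)) = tr(exp(A)), h gives the same value under the transposed causal-learn convention (B[i, j] != 0 means j -> i) used by this task.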

Code

custom_algorithm.py
```python
import numpy as np

# =====================================================================
# EDITABLE: implement run_causal_discovery below
# =====================================================================
def run_causal_discovery(X: np.ndarray) -> np.ndarray:
    """
    Input: X of shape (n_samples, n_variables)
    Output: adjacency matrix B of shape (n_variables, n_variables)
            B[i, j] != 0 means j -> i (follows causal-learn convention)
    """
    n = X.shape[1]
    return np.zeros((n, n))
# =====================================================================
```
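A minimal sketch of how the stub could be filled in, loosely following the two-phase RESIT procedure of Peters et al. (2014): first find a causal order by repeatedly removing the variable whose regression residual is least dependent on the remaining variables (the sink), then prune each node's predecessors with an independence test. Assumptions, none of which come from the benchmark: kernel ridge regression with RBF kernels as the nonlinear regressor, HSIC with a permutation p-value as the independence test, and the hypothetical keyword parameters `alpha`, `max_samples` and `seed`. It is untuned and no scores are claimed for it.

```python
import numpy as np


def _gram(Z: np.ndarray) -> np.ndarray:
    """RBF Gram matrix; bandwidth = median of the positive squared distances."""
    Z = Z.reshape(len(Z), -1)
    s = np.sum(Z**2, axis=1)
    sq = np.maximum(s[:, None] + s[None, :] - 2.0 * Z @ Z.T, 0.0)
    pos = sq[sq > 0]
    return np.exp(-sq / (np.median(pos) if pos.size else 1.0))


def _residuals(Z: np.ndarray, y: np.ndarray, lam: float = 1e-2) -> np.ndarray:
    """Kernel ridge residuals of y on the columns of Z (just centred y if Z is empty)."""
    y = y - y.mean()
    if Z.shape[1] == 0:
        return y
    K = _gram(Z)
    return y - K @ np.linalg.solve(K + lam * len(y) * np.eye(len(y)), y)


def _centre(K: np.ndarray) -> np.ndarray:
    H = np.eye(len(K)) - 1.0 / len(K)  # centring matrix I - 11^T / n
    return H @ K @ H


def _hsic(a: np.ndarray, b: np.ndarray) -> float:
    """Unnormalised biased HSIC, trace(HKH L); larger means more dependent."""
    return float(np.sum(_centre(_gram(a)) * _gram(b)))


def _hsic_pvalue(a: np.ndarray, b: np.ndarray, rng, n_perm: int = 100) -> float:
    """Permutation p-value for H0: a is independent of b."""
    Kc, L = _centre(_gram(a)), _gram(b)
    stat = np.sum(Kc * L)
    null = [np.sum(Kc * L[np.ix_(p, p)]) for p in (rng.permutation(len(L)) for _ in range(n_perm))]
    return (1 + sum(s >= stat for s in null)) / (1 + n_perm)


def run_causal_discovery(X: np.ndarray, alpha: float = 0.05,
                         max_samples: int = 300, seed: int = 0) -> np.ndarray:
    """RESIT-style sketch. Returns B with B[i, j] != 0 meaning j -> i."""
    rng = np.random.default_rng(seed)
    if len(X) > max_samples:  # kernel solves scale cubically in n, so subsample
        X = X[rng.choice(len(X), max_samples, replace=False)]
    X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-12)
    d = X.shape[1]

    # Phase 1: repeatedly remove the variable whose residual, after regressing it
    # on all other remaining variables, is least dependent on them (the sink).
    remaining, sinks = list(range(d)), []
    while len(remaining) > 1:
        scores = []
        for j in remaining:
            others = [k for k in remaining if k != j]
            scores.append(_hsic(_residuals(X[:, others], X[:, j]), X[:, others]))
        sink = remaining[int(np.argmin(scores))]
        sinks.append(sink)
        remaining.remove(sink)
    order = remaining + sinks[::-1]  # causal order, sources first

    # Phase 2: start from all predecessors and drop parent j whenever the residual
    # without j still looks independent of every predecessor.
    B = np.zeros((d, d))
    for pos, k in enumerate(order):
        preds = order[:pos]
        parents = list(preds)
        for j in preds:
            trial = [p for p in parents if p != j]
            if _hsic_pvalue(_residuals(X[:, trial], X[:, k]), X[:, preds], rng) > alpha:
                parents = trial
        if parents:
            B[k, parents] = 1.0  # causal-learn convention: B[k, j] != 0 means j -> k
    return B
```

Because every edge points from an earlier to a later node in the recovered order, the output is acyclic by construction. The order search fits roughly d^2/2 kernel regressions and the pruning phase adds permutation tests, which is why the sketch subsamples to `max_samples` rows on the 2000-sample scenarios.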
run_eval.py
```python
"""Evaluation harness for the causal-observational-nonlinear task."""
import argparse
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from data_gen import simulate_nonlinear_anm
from metrics import compute_metrics
from custom_algorithm import run_causal_discovery


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate a causal discovery algorithm on synthetic nonlinear ANM data."
```
data_gen.py
```python
"""Synthetic nonlinear additive noise model (ANM) data generator.

Structural equation: x_j = f_j(pa(j)) + e_j
where f_j is a nonlinear function and e_j is additive non-Gaussian noise.
"""
import numpy as np
import networkx as nx


def simulate_dag(n_nodes, graph_type, seed, er_prob=0.3, sf_m=2):
    """Return a binary adjacency matrix for a random DAG.

    Convention: adj[i, j] = 1 means i -> j (i is a parent of j).
    The DAG is enforced by keeping only edges i -> j with i < j, imposing
    a topological ordering by node index.
```
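Note that the data_gen.py docstring uses the opposite orientation from the required output: adj[i, j] = 1 means i -> j, while run_causal_discovery must return B[i, j] != 0 for j -> i, so the two are related by B = adj.T. The hypothetical `sample_er_anm` below mirrors the described conventions (Erdos-Renyi edges kept only for i < j, x_j = f_j(pa(j)) + e_j) with an assumed tanh nonlinearity, random edge weights and Laplace noise; it is not the benchmark's `simulate_nonlinear_anm`, whose body is not shown, and it uses NumPy only rather than networkx.

```python
import numpy as np


def sample_er_anm(n_nodes: int, n_samples: int, er_prob: float = 0.3, seed: int = 0):
    """Hypothetical ER DAG + ANM sampler.

    Returns X (n_samples, n_nodes), adj with adj[i, j] = 1 for i -> j, and
    B_true = adj.T in the causal-learn convention (B[i, j] != 0 means j -> i).
    """
    rng = np.random.default_rng(seed)
    # Keep only i -> j with i < j (strict upper triangle): node index is a topological order.
    adj = np.triu((rng.random((n_nodes, n_nodes)) < er_prob).astype(int), k=1)
    X = np.zeros((n_samples, n_nodes))
    for j in range(n_nodes):  # parents always have smaller indices, so they are already sampled
        parents = np.flatnonzero(adj[:, j])
        noise = rng.laplace(loc=0.0, scale=1.0, size=n_samples)
        if parents.size:
            w = rng.uniform(0.5, 2.0, size=parents.size) * rng.choice([-1, 1], size=parents.size)
            X[:, j] = np.tanh(X[:, parents] @ w) + noise  # x_j = f_j(pa(j)) + e_j
        else:
            X[:, j] = noise
    return X, adj, adj.T.copy()


X, adj, B_true = sample_er_anm(n_nodes=5, n_samples=100, seed=3)
print(adj)
print(B_true)
```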
metrics.py
```python
"""Evaluation metrics for directed causal graph recovery."""
import numpy as np


def compute_metrics(B_est, B_true, threshold=0.01):
    """Compute SHD, F1, precision, and recall for directed edge recovery.

    Convention: B[i, j] != 0 means j -> i.

    SHD definition (each type counts as exactly 1 error):
      - Reversed edge : correct skeleton edge but wrong direction
      - Extra edge    : present in estimate but absent in truth (non-reversal)
      - Missing edge  : present in truth but absent in estimate (non-reversal)

    F1 / precision / recall are computed on the directed edge set
```
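Since the body of compute_metrics is not shown, the following is only a hypothetical re-implementation of the definitions quoted in its docstring, to make the counting concrete. Assumptions: the threshold is applied to absolute values of both matrices, and when both directions of a true edge are estimated the spurious direction counts as an extra edge; the real function may differ on such details. In the small example, one correct edge, one reversed edge and one extra edge give SHD = 2, precision 1/3, recall 1/2 and F1 = 0.4.

```python
import numpy as np


def directed_edge_metrics(B_est, B_true, threshold=0.01):
    """Hypothetical SHD / precision / recall / F1 on directed edges.

    Convention: B[i, j] != 0 means j -> i, so each edge is stored as (parent, child) = (j, i).
    """
    def edges(B):
        child, parent = np.nonzero(np.abs(B) > threshold)
        return set(zip(parent.tolist(), child.tolist()))

    est, true = edges(B_est), edges(B_true)
    tp = len(est & true)
    precision = tp / len(est) if est else 0.0
    recall = tp / len(true) if true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0

    # Reversed: estimated edge whose reverse is true and the true direction was not estimated.
    reversed_edges = {(a, b) for (a, b) in est - true if (b, a) in true and (b, a) not in est}
    extra = est - true - reversed_edges
    missing = {(a, b) for (a, b) in true - est if (b, a) not in est}  # reversals not double-counted
    shd = len(reversed_edges) + len(extra) + len(missing)
    return {"shd": shd, "f1": f1, "precision": precision, "recall": recall}


B_true = np.zeros((3, 3))
B_true[1, 0] = 1.0   # 0 -> 1
B_true[2, 1] = 1.0   # 1 -> 2
B_est = np.zeros((3, 3))
B_est[1, 0] = 0.9    # 0 -> 1 (correct)
B_est[1, 2] = 0.5    # 2 -> 1 (reversed)
B_est[2, 0] = 0.3    # 0 -> 2 (extra)
B_est[0, 2] = 0.005  # below threshold, ignored
print(directed_edge_metrics(B_est, B_true))
```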

Results

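Scores per scenario ("-" marks a missing entry).

ER8-MLP
Model                          Type      SHD     F1     Precision  Recall
cam                            baseline  4.667   0.676  0.662      0.700
directlingam                   baseline  12.667  0.293  0.249      0.405
directlingam                   baseline  12.667  0.293  0.249      0.405
grandag                        baseline  10.333  0.092  0.108      0.085
notears_mlp                    baseline  8.000   0.000  0.000      0.000
notears_mlp                    baseline  5.000   0.472  0.540      0.419
anthropic/claude-opus-4.6      vanilla   8.000   0.444  0.364      0.571
google/gemini-3.1-pro-preview  vanilla   -       -      -          -
gpt-5.4-pro                    vanilla   6.000   0.632  0.500      0.857
anthropic/claude-opus-4.6      agent     9.000   0.353  0.300      0.429
google/gemini-3.1-pro-preview  agent     -       -      -          -
google/gemini-3.1-pro-preview  agent     4.000   0.706  0.600      0.857
gpt-5.4-pro                    agent     6.000   0.632  0.500      0.857

SF10-Mixed
Model                          Type      SHD     F1     Precision  Recall
cam                            baseline  8.000   0.729  0.712      0.771
directlingam                   baseline  12.667  0.537  0.484      0.604
directlingam                   baseline  16.667  0.427  0.375      0.500
grandag                        baseline  19.333  0.226  0.524      0.208
notears_mlp                    baseline  16.000  0.000  0.000      0.000
notears_mlp                    baseline  16.000  0.169  0.203      0.146
anthropic/claude-opus-4.6      vanilla   16.000  0.400  0.429      0.375
google/gemini-3.1-pro-preview  vanilla   -       -      -          -
gpt-5.4-pro                    vanilla   13.000  0.634  0.520      0.813
anthropic/claude-opus-4.6      agent     16.000  0.357  0.417      0.313
google/gemini-3.1-pro-preview  agent     -       -      -          -
google/gemini-3.1-pro-preview  agent     16.000  0.622  0.483      0.875
gpt-5.4-pro                    agent     13.000  0.634  0.520      0.813

ER15-Sigmoid
Model                          Type      SHD     F1     Precision  Recall
cam                            baseline  26.000  0.594  0.487      0.763
directlingam                   baseline  3.333   0.947  0.928      0.968
directlingam                   baseline  32.000  0.511  0.445      0.629
grandag                        baseline  30.000  0.097  0.622      0.054
notears_mlp                    baseline  30.667  0.000  0.000      0.000
notears_mlp                    baseline  21.000  0.427  0.486      0.385
anthropic/claude-opus-4.6      vanilla   20.000  0.677  0.564      0.846
google/gemini-3.1-pro-preview  vanilla   -       -      -          -
gpt-5.4-pro                    vanilla   29.000  0.587  0.449      0.846
anthropic/claude-opus-4.6      agent     20.000  0.644  0.576      0.731
google/gemini-3.1-pro-preview  agent     -       -      -          -
google/gemini-3.1-pro-preview  agent     17.000  0.754  0.605      1.000
gpt-5.4-pro                    agent     29.000  0.587  0.449      0.846

SF20-GP
Model                          Type      SHD     F1     Precision  Recall
cam                            baseline  -       -      -          -
directlingam                   baseline  -       -      -          -
directlingam                   baseline  60.000  0.319  0.243      0.463
grandag                        baseline  32.000  0.193  0.944      0.111
notears_mlp                    baseline  -       -      -          -
notears_mlp                    baseline  35.667  0.153  0.189      0.130
anthropic/claude-opus-4.6      vanilla   16.000  0.720  0.692      0.750
google/gemini-3.1-pro-preview  vanilla   -       -      -          -
gpt-5.4-pro                    vanilla   37.000  0.557  0.443      0.750
anthropic/claude-opus-4.6      agent     13.000  0.743  0.765      0.722
google/gemini-3.1-pro-preview  agent     -       -      -          -
google/gemini-3.1-pro-preview  agent     17.000  0.740  0.730      0.750
gpt-5.4-pro                    agent     37.000  0.557  0.443      0.750

ER20-Gauss
Model                          Type      SHD     F1     Precision  Recall
cam                            baseline  -       -      -          -
directlingam                   baseline  -       -      -          -
directlingam                   baseline  96.000  0.245  0.199      0.322
grandag                        baseline  58.333  0.112  0.774      0.060
notears_mlp                    baseline  -       -      -          -
notears_mlp                    baseline  52.333  0.277  0.470      0.198
anthropic/claude-opus-4.6      vanilla   52.000  0.358  0.396      0.328
google/gemini-3.1-pro-preview  vanilla   -       -      -          -
gpt-5.4-pro                    vanilla   62.000  0.387  0.364      0.414
anthropic/claude-opus-4.6      agent     50.000  0.360  0.429      0.310
google/gemini-3.1-pro-preview  agent     -       -      -          -
google/gemini-3.1-pro-preview  agent     79.000  0.355  0.287      0.466
gpt-5.4-pro                    agent     62.000  0.387  0.364      0.414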

Agent Conversations