causal-observational-nonlinear
Description
Causal Discovery: Observational Nonlinear Data
Objective
Implement a causal discovery algorithm that recovers the DAG structure from purely observational data generated by a nonlinear Additive Noise Model (ANM). Your code goes in bench/custom_algorithm.py.
Background
Additive Noise Models (ANMs) assume structural equations of the form x_j = f_j(parents(j)) + e_j, where f_j are nonlinear functions and e_j are mutually independent noise terms. Under mild conditions on f_j and the noise distribution, the causal DAG is identifiable from observational data alone (Hoyer et al., 2008; Peters et al., 2014). This goes beyond LiNGAM, which assumes linear mechanisms, and beyond PC and GES, which are typically run with linear-Gaussian tests or scores and recover only a Markov equivalence class.
Key challenges include: (1) testing independence in high-dimensional nonlinear settings, (2) scaling to larger graphs, (3) handling diverse nonlinear function types (MLP, GP, polynomial, sigmoid), and (4) robustness to different noise distributions including the harder Gaussian noise case.
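The identifiability argument can be illustrated on a single pair of variables: regress each variable on the other and check in which direction the residual looks independent of the regressor. The sketch below is a minimal illustration, not any baseline's method. It assumes polynomial regression as the nonlinear fit and a biased HSIC estimate with a fixed Gaussian bandwidth as the dependence score. All function names are hypothetical.

```python
import numpy as np


def _centered_gram(v, bandwidth=1.0):
    """Centered Gaussian-kernel Gram matrix of a standardized 1-D sample."""
    v = (v - v.mean()) / v.std()
    sq_dist = (v[:, None] - v[None, :]) ** 2
    K = np.exp(-sq_dist / (2.0 * bandwidth ** 2))
    # Double centering: equivalent to H @ K @ H with H = I - 11^T / n.
    return K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()


def hsic(a, b):
    """Biased HSIC estimate; values close to zero suggest independence."""
    n = len(a)
    return float((_centered_gram(a) * _centered_gram(b)).sum() / n ** 2)


def residual_dependence(cause, effect, degree=5):
    """Fit effect ~ poly(cause) and score dependence between residual and cause."""
    c = (cause - cause.mean()) / cause.std()
    coeffs = np.polyfit(c, effect, degree)
    residual = effect - np.polyval(coeffs, c)
    return hsic(residual, c)


def anm_direction(x, y):
    """Prefer the direction whose residual is less dependent on the regressor."""
    fwd = residual_dependence(x, y)  # hypothesis x -> y
    bwd = residual_dependence(y, x)  # hypothesis y -> x
    return ("x->y" if fwd < bwd else "y->x"), fwd, bwd


# Toy pair (assumed for illustration): y = x^3 + uniform noise.
rng = np.random.default_rng(0)
x = rng.uniform(-2, 2, size=400)
y = x ** 3 + rng.uniform(-1, 1, size=400)
direction, fwd, bwd = anm_direction(x, y)
print(direction, fwd, bwd)
```

Extending this to full graphs is where the listed challenges appear: the regression has to handle several parents at once, and pairwise kernel scores cost O(n^2) memory per test.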
Evaluation Scenarios
| Label | Graph type | Nodes | Samples | Noise | Nonlinearity |
|---|---|---|---|---|---|
| ER8-MLP | Erdos-Renyi | 8 | 500 | Exponential | MLP |
| ER12-GP | Erdos-Renyi | 12 | 1000 | Laplace | GP |
| SF10-Mixed | Scale-Free (BA) | 10 | 500 | Uniform | Mixed |
| ER15-Sigmoid | Erdos-Renyi | 15 | 1000 | Exponential | Sigmoid |
| ER20-Mixed | Erdos-Renyi | 20 | 2000 | Laplace | Mixed |
| SF20-GP | Scale-Free (BA) | 20 | 2000 | Exponential | GP |
| ER12-LowSample | Erdos-Renyi | 12 | 150 | Laplace | Mixed |
| ER20-Gauss | Erdos-Renyi | 20 | 2000 | Gaussian | Mixed |
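To make the scenario columns concrete, here is a hedged sketch of how a small Erdos-Renyi ANM dataset could be sampled. It is not the benchmark's data_gen module. The sum-of-sigmoids mechanism, the weight range, the Laplace noise scale, and both function names are assumptions made for illustration. The edge probability 0.3 matches the er_prob default shown in the generator's signature.

```python
import numpy as np


def sample_er_dag(n_nodes, edge_prob, rng):
    """Random DAG with adj[i, j] = 1 meaning i -> j; only edges with i < j are kept."""
    adj = (rng.random((n_nodes, n_nodes)) < edge_prob).astype(int)
    return np.triu(adj, k=1)


def simulate_sigmoid_anm(adj, n_samples, rng):
    """Sample x_j = sum_k sigmoid(w_k * x_k) + e_j over parents k, with Laplace noise e_j."""
    n_nodes = adj.shape[0]
    X = np.zeros((n_samples, n_nodes))
    for j in range(n_nodes):  # node index order is a topological order
        parents = np.flatnonzero(adj[:, j])
        noise = rng.laplace(size=n_samples)
        if parents.size == 0:
            X[:, j] = noise
            continue
        # Assumed weight range: magnitude in [0.5, 2.0] with a random sign.
        w = rng.uniform(0.5, 2.0, size=parents.size) * rng.choice([-1.0, 1.0], size=parents.size)
        X[:, j] = (1.0 / (1.0 + np.exp(-X[:, parents] * w))).sum(axis=1) + noise
    return X


rng = np.random.default_rng(0)
adj = sample_er_dag(n_nodes=8, edge_prob=0.3, rng=rng)
X = simulate_sigmoid_anm(adj, n_samples=500, rng=rng)
print(adj.sum(), X.shape)
```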
Metrics
All computed on the directed edge set (skeleton + direction must be correct):
- F1 (primary ranking metric), SHD, Precision, Recall
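As a rough illustration of what "directed edge set" means for these metrics, the sketch below scores an estimate against a ground truth under the B[i, j] != 0 means j -> i convention. It is not the benchmark's compute_metrics. The function name is hypothetical, and SHD is assumed to count each unordered node pair whose edge status differs exactly once, so a reversed edge costs 1. The 0.01 threshold mirrors the default in the metrics signature.

```python
import numpy as np


def directed_edge_metrics(B_est, B_true, threshold=0.01):
    """Precision, recall, F1 and SHD on directed edges (B[i, j] != 0 means j -> i)."""
    est = np.abs(np.asarray(B_est, dtype=float)) > threshold
    true = np.abs(np.asarray(B_true, dtype=float)) > threshold

    tp = int(np.sum(est & true))  # same edge, same direction
    fp = int(est.sum()) - tp      # includes reversed edges
    fn = int(true.sum()) - tp

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0

    # A node pair {i, j} is wrong if either directed entry differs;
    # counting pairs once makes a reversal cost 1 rather than 2.
    diff = est != true
    pair_wrong = diff | diff.T
    shd = int(np.triu(pair_wrong, k=1).sum())
    return {"precision": precision, "recall": recall, "f1": f1, "shd": shd}


# Truth: 0 -> 1 and 1 -> 2.
B_true = np.zeros((3, 3))
B_true[1, 0] = 1
B_true[2, 1] = 1

# Estimate: 0 -> 1 (correct), 2 -> 1 (reversed), 0 -> 2 (extra).
B_est = np.zeros((3, 3))
B_est[1, 0] = 1
B_est[1, 2] = 1
B_est[2, 0] = 1

scores = directed_edge_metrics(B_est, B_true)
print(scores)
```

In this toy case the reversed edge lowers both precision and recall, which is why a correct skeleton with wrong orientations still scores poorly on F1.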
Baselines
- cam: Causal Additive Models (Buehlmann et al., 2014) -- score-based with GAM regression
- notears_mlp: NOTEARS with MLP parametrization (Zheng et al., 2020) -- continuous optimization
- directlingam: DirectLiNGAM (Shimizu et al., 2011) -- linear reference baseline
Code
    import numpy as np

    # =====================================================================
    # EDITABLE: implement run_causal_discovery below
    # =====================================================================
    def run_causal_discovery(X: np.ndarray) -> np.ndarray:
        """
        Input: X of shape (n_samples, n_variables)
        Output: adjacency matrix B of shape (n_variables, n_variables)
        B[i, j] != 0 means j -> i (follows causal-learn convention)
        """
        n = X.shape[1]
        return np.zeros((n, n))
    # =====================================================================

    """Evaluation harness for the causal-observational-nonlinear task."""
    import argparse
    import os
    import sys

    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

    from data_gen import simulate_nonlinear_anm
    from metrics import compute_metrics
    from custom_algorithm import run_causal_discovery


    def main():
        parser = argparse.ArgumentParser(
            description="Evaluate a causal discovery algorithm on synthetic nonlinear ANM data."

    """Synthetic nonlinear additive noise model (ANM) data generator.

    Structural equation: x_j = f_j(pa(j)) + e_j
    where f_j is a nonlinear function and e_j is additive non-Gaussian noise.
    """
    import numpy as np
    import networkx as nx


    def simulate_dag(n_nodes, graph_type, seed, er_prob=0.3, sf_m=2):
        """Return a binary adjacency matrix for a random DAG.

        Convention: adj[i, j] = 1 means i -> j (i is a parent of j).
        The DAG is enforced by keeping only edges i -> j with i < j, imposing
        a topological ordering by node index.

    """Evaluation metrics for directed causal graph recovery."""
    import numpy as np


    def compute_metrics(B_est, B_true, threshold=0.01):
        """Compute SHD, F1, precision, and recall for directed edge recovery.

        Convention: B[i, j] != 0 means j -> i.

        SHD definition (each type counts as exactly 1 error):
        - Reversed edge : correct skeleton edge but wrong direction
        - Extra edge : present in estimate but absent in truth (non-reversal)
        - Missing edge : present in truth but absent in estimate (non-reversal)

        F1 / precision / recall are computed on the directed edge set
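Note that the listings use two opposite index conventions: the generator documents adj[i, j] = 1 as i -> j, while run_causal_discovery and compute_metrics use B[i, j] != 0 as j -> i. An algorithm that builds its graph in parent-to-child form therefore needs a transpose before returning. A minimal sketch, with a hypothetical helper name:

```python
import numpy as np


def parent_child_to_B(adj):
    """Convert adj[i, j] = 1 (i -> j) into B[j, i] = 1 (B[row, col] means col -> row)."""
    return np.asarray(adj).T.copy()


adj = np.zeros((3, 3), dtype=int)
adj[0, 1] = 1  # 0 -> 1
adj[1, 2] = 1  # 1 -> 2

B = parent_child_to_B(adj)
print(B)
```

Returning the untransposed matrix would reverse every edge, which keeps the skeleton correct but drives directed precision and recall to zero.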
Results
| Model | Type | shd ER8-MLP ↓ | f1 ER8-MLP ↑ | precision ER8-MLP ↑ | recall ER8-MLP ↑ | shd SF10-Mixed ↓ | f1 SF10-Mixed ↑ | precision SF10-Mixed ↑ | recall SF10-Mixed ↑ | shd ER15-Sigmoid ↓ | f1 ER15-Sigmoid ↑ | precision ER15-Sigmoid ↑ | recall ER15-Sigmoid ↑ | shd SF20-GP ↓ | f1 SF20-GP ↑ | precision SF20-GP ↑ | recall SF20-GP ↑ | shd ER20-Gauss ↓ | f1 ER20-Gauss ↑ | precision ER20-Gauss ↑ | recall ER20-Gauss ↑ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| cam | baseline | 4.667 | 0.676 | 0.662 | 0.700 | 8.000 | 0.729 | 0.712 | 0.771 | 26.000 | 0.594 | 0.487 | 0.763 | - | - | - | - | - | - | - | - |
| directlingam | baseline | 12.667 | 0.293 | 0.249 | 0.405 | 12.667 | 0.537 | 0.484 | 0.604 | 3.333 | 0.947 | 0.928 | 0.968 | - | - | - | - | - | - | - | - |
| directlingam | baseline | 12.667 | 0.293 | 0.249 | 0.405 | 16.667 | 0.427 | 0.375 | 0.500 | 32.000 | 0.511 | 0.445 | 0.629 | 60.000 | 0.319 | 0.243 | 0.463 | 96.000 | 0.245 | 0.199 | 0.322 |
| grandag | baseline | 10.333 | 0.092 | 0.108 | 0.085 | 19.333 | 0.226 | 0.524 | 0.208 | 30.000 | 0.097 | 0.622 | 0.054 | 32.000 | 0.193 | 0.944 | 0.111 | 58.333 | 0.112 | 0.774 | 0.060 |
| notears_mlp | baseline | 8.000 | 0.000 | 0.000 | 0.000 | 16.000 | 0.000 | 0.000 | 0.000 | 30.667 | 0.000 | 0.000 | 0.000 | - | - | - | - | - | - | - | - |
| notears_mlp | baseline | 5.000 | 0.472 | 0.540 | 0.419 | 16.000 | 0.169 | 0.203 | 0.146 | 21.000 | 0.427 | 0.486 | 0.385 | 35.667 | 0.153 | 0.189 | 0.130 | 52.333 | 0.277 | 0.470 | 0.198 |
| anthropic/claude-opus-4.6 | vanilla | 8.000 | 0.444 | 0.364 | 0.571 | 16.000 | 0.400 | 0.429 | 0.375 | 20.000 | 0.677 | 0.564 | 0.846 | 16.000 | 0.720 | 0.692 | 0.750 | 52.000 | 0.358 | 0.396 | 0.328 |
| google/gemini-3.1-pro-preview | vanilla | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
| gpt-5.4-pro | vanilla | 6.000 | 0.632 | 0.500 | 0.857 | 13.000 | 0.634 | 0.520 | 0.813 | 29.000 | 0.587 | 0.449 | 0.846 | 37.000 | 0.557 | 0.443 | 0.750 | 62.000 | 0.387 | 0.364 | 0.414 |
| anthropic/claude-opus-4.6 | agent | 9.000 | 0.353 | 0.300 | 0.429 | 16.000 | 0.357 | 0.417 | 0.313 | 20.000 | 0.644 | 0.576 | 0.731 | 13.000 | 0.743 | 0.765 | 0.722 | 50.000 | 0.360 | 0.429 | 0.310 |
| google/gemini-3.1-pro-preview | agent | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
| google/gemini-3.1-pro-preview | agent | 4.000 | 0.706 | 0.600 | 0.857 | 16.000 | 0.622 | 0.483 | 0.875 | 17.000 | 0.754 | 0.605 | 1.000 | 17.000 | 0.740 | 0.730 | 0.750 | 79.000 | 0.355 | 0.287 | 0.466 |
| gpt-5.4-pro | agent | 6.000 | 0.632 | 0.500 | 0.857 | 13.000 | 0.634 | 0.520 | 0.813 | 29.000 | 0.587 | 0.449 | 0.846 | 37.000 | 0.557 | 0.443 | 0.750 | 62.000 | 0.387 | 0.364 | 0.414 |