causal-observational-linear-gaussian

Causal Inference · causal-learn · rigorous codebase

Description

Causal Discovery: Observational Linear Gaussian Data (CPDAG Recovery)

Objective

Implement a causal discovery algorithm that recovers the CPDAG from purely observational data generated by a linear Gaussian structural equation model (SEM). Your code goes in bench/custom_algorithm.py.

Background

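Under linear Gaussian assumptions, observational data can identify only the Markov equivalence class (MEC) of the true DAG, not a unique fully directed DAG. The MEC is represented by a CPDAG (completed partially directed acyclic graph). In a CPDAG, an edge is directed when every DAG in the class orients it the same way, and undirected otherwise.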
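Verma and Pearl's characterization makes this concrete. Two DAGs are Markov equivalent exactly when they share the same skeleton and the same v-structures (colliders a -> c <- b with a and b non-adjacent). The following is a minimal pure-Python sketch of that check. DAGs are given as sets of (parent, child) pairs, a representation chosen here only for illustration.

```python
from itertools import combinations


def skeleton(edges):
    """Undirected adjacencies of a DAG given as (parent, child) pairs."""
    return {frozenset(e) for e in edges}


def v_structures(edges):
    """Colliders a -> c <- b whose endpoints a and b are non-adjacent."""
    skel = skeleton(edges)
    parents = {}
    for a, c in edges:
        parents.setdefault(c, set()).add(a)
    found = set()
    for c, pa in parents.items():
        for a, b in combinations(sorted(pa), 2):
            if frozenset((a, b)) not in skel:
                found.add((frozenset((a, b)), c))
    return found


def markov_equivalent(dag1, dag2):
    """Verma-Pearl criterion: same skeleton and same v-structures."""
    return skeleton(dag1) == skeleton(dag2) and v_structures(dag1) == v_structures(dag2)


chain = {("X1", "X2"), ("X2", "X3")}           # X1 -> X2 -> X3
reversed_chain = {("X3", "X2"), ("X2", "X1")}  # X1 <- X2 <- X3
collider = {("X1", "X2"), ("X3", "X2")}        # X1 -> X2 <- X3

print(markov_equivalent(chain, reversed_chain))  # True
print(markov_equivalent(chain, collider))        # False
```

This is why the CPDAG of the chain leaves X1 - X2 - X3 undirected, while the collider keeps X1 -> X2 <- X3 oriented.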

Interface

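Your run_causal_discovery(X) implementation receives X as a NumPy array of shape (n_samples, n_variables). It must return a causallearn.graph.GeneralGraph.GeneralGraph object representing the estimated CPDAG. The provided stub creates one GraphNode per column of X, named X1, X2, ..., Xd.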
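causal-learn documents the output graph of its PC search through an endpoint matrix G.graph. In that matrix, G.graph[j, i] == 1 and G.graph[i, j] == -1 encode i -> j, and G.graph[i, j] == G.graph[j, i] == -1 encode the undirected edge i --- j. The sketch below builds a matrix in that convention as a plain nested list from two edge lists. This can help when checking an estimate by hand before it is wrapped in a GeneralGraph. The helper name and the edge-list input format are hypothetical.

```python
def endpoint_matrix(names, directed, undirected):
    """Hypothetical helper: encode a CPDAG in causal-learn's documented endpoint convention.

    names      -- node names in column order, e.g. ["X1", "X2", "X3"]
    directed   -- iterable of (u, v) name pairs meaning u -> v
    undirected -- iterable of (u, v) name pairs meaning u --- v
    Entry m[a][b] is the mark at node a on the a-b edge: -1 = tail, 1 = arrowhead, 0 = no edge.
    """
    index = {name: k for k, name in enumerate(names)}
    d = len(names)
    m = [[0] * d for _ in range(d)]
    for u, v in directed:
        i, j = index[u], index[v]
        m[i][j] = -1  # tail at u
        m[j][i] = 1   # arrowhead at v
    for u, v in undirected:
        i, j = index[u], index[v]
        m[i][j] = m[j][i] = -1  # tails at both ends
    return m


# CPDAG X1 -> X2 <- X3, X3 --- X4
m = endpoint_matrix(["X1", "X2", "X3", "X4"],
                    directed=[("X1", "X2"), ("X3", "X2")],
                    undirected=[("X3", "X4")])
for row in m:
    print(row)
# [0, -1, 0, 0]
# [1, 0, 1, 0]
# [0, -1, 0, -1]
# [0, 0, -1, 0]
```

Inside run_causal_discovery, the same edges would then be added to the GeneralGraph built over the stub's X1, ..., Xd nodes. Check causal-learn's GeneralGraph API for the exact edge-adding methods.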

Evaluation Scenarios

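| Label | Graph type | Nodes | Params | Samples | Noise |
|---|---|---|---|---|---|
| ER10 | Erdos-Renyi | 10 | p=0.3 | 500 | 1.0 |
| ER20 | Erdos-Renyi | 20 | p=0.2 | 1000 | 1.0 |
| SF50 | Scale-Free (BA) | 50 | m=2 | 2000 | 1.0 |
| ER10-Hard | Erdos-Renyi | 10 | p=0.5 (denser) | 200 | 1.0 |
| ER20-Hard | Erdos-Renyi | 20 | p=0.35 (denser) | 400 | 1.0 |
| SF50-Hard | Scale-Free (BA) | 50 | m=3 (denser) | 1000 | 1.0 |
| ER10-Noisy | Erdos-Renyi | 10 | p=0.5 (denser) | 200 | 2.5 |
| ER20-Noisy | Erdos-Renyi | 20 | p=0.35 (denser) | 400 | 2.5 |
| SF50-Noisy | Scale-Free (BA) | 50 | m=3 (denser) | 1000 | 2.5 |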
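For local experiments, an ER scenario can be approximated with a short simulator. This is a hypothetical stand-in for data_gen.simulate_linear_gaussian, whose body is not shown, and it rests on three assumptions:

- It draws a random causal order and keeps each forward pair with probability p. data_gen.py may define the edge probability differently.
- It draws edge weights uniformly from ±[0.5, 2.0], an assumed range.
- It reads the Noise column as the standard deviation of the Gaussian noise terms, which is also an assumption.

```python
import random


def simulate_er_linear_gaussian(n_nodes, p, n_samples, noise_scale, seed=0):
    """Hypothetical ER DAG + linear Gaussian SEM simulator (assumed weight range and noise meaning)."""
    rng = random.Random(seed)
    order = list(range(n_nodes))
    rng.shuffle(order)  # random causal order, so column index != causal position
    parents = {j: [] for j in range(n_nodes)}  # child -> [(parent, weight), ...]
    edges = []
    for a in range(n_nodes):
        for b in range(a + 1, n_nodes):
            if rng.random() < p:  # keep each forward pair with probability p
                w = rng.choice((-1.0, 1.0)) * rng.uniform(0.5, 2.0)  # assumed weight range
                parents[order[b]].append((order[a], w))
                edges.append((order[a], order[b]))
    rows = []
    for _ in range(n_samples):
        x = [0.0] * n_nodes
        for node in order:  # parents always precede children in `order`
            mean = sum(w * x[pa] for pa, w in parents[node])
            x[node] = mean + rng.gauss(0.0, noise_scale)  # noise_scale used as a std deviation
        rows.append(x)
    return rows, edges


# Roughly the ER10 setting: 10 nodes, p=0.3, 500 samples, noise 1.0
X, true_edges = simulate_er_linear_gaussian(10, 0.3, 500, 1.0, seed=1)
print(len(X), len(X[0]), len(true_edges))
```

Because every edge points forward in a single random order, the sampled graph is acyclic by construction.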

Metrics

Metrics are computed between the estimated CPDAG and the ground-truth CPDAG:

  • SHD: structural Hamming distance (SHD(...).get_shd())
  • Adjacency precision / recall (AdjacencyConfusion)
  • Arrow precision / recall (ArrowConfusion)
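To sanity-check reported numbers offline, simplified versions of these metrics can be computed from edge sets. This is an illustrative approximation in the spirit of the listed metrics. It is not causal-learn's AdjacencyConfusion, ArrowConfusion, or SHD code, and its SHD variant may differ from SHD(...).get_shd() on cases such as reversed edges.

- A CPDAG is passed as a pair of edge sets: (directed (u, v) pairs, undirected pairs).
- Arrowhead precision counts estimated arrowheads that also appear in the truth.
- The SHD variant charges one unit per missing or extra adjacency, plus one per shared adjacency whose orientation differs.

```python
def _adjacencies(directed, undirected):
    return {frozenset(e) for e in directed} | {frozenset(e) for e in undirected}


def _marks(directed, undirected):
    """Map each adjacency to its orientation: the (tail, head) pair, or None if undirected."""
    marks = {frozenset(e): None for e in undirected}
    marks.update({frozenset(e): tuple(e) for e in directed})
    return marks


def _pr(tp, n_est, n_true):
    # Like metrics.py's _safe_div: an empty denominator yields 0.0
    precision = tp / n_est if n_est else 0.0
    recall = tp / n_true if n_true else 0.0
    return precision, recall


def cpdag_metrics(est, truth):
    """est / truth: (directed_pairs, undirected_pairs) for two CPDAGs on the same nodes."""
    est_adj, true_adj = _adjacencies(*est), _adjacencies(*truth)
    adj_p, adj_r = _pr(len(est_adj & true_adj), len(est_adj), len(true_adj))
    est_arrows, true_arrows = set(est[0]), set(truth[0])  # one arrowhead per directed edge
    arr_p, arr_r = _pr(len(est_arrows & true_arrows), len(est_arrows), len(true_arrows))
    est_marks, true_marks = _marks(*est), _marks(*truth)
    shd = len(est_adj ^ true_adj)  # missing or extra adjacencies
    shd += sum(est_marks[e] != true_marks[e] for e in est_adj & true_adj)  # orientation mismatches
    return {"shd": shd, "adj_precision": adj_p, "adj_recall": adj_r,
            "arrow_precision": arr_p, "arrow_recall": arr_r}


truth = ({("X1", "X2"), ("X3", "X2")}, {("X3", "X4")})  # X1 -> X2 <- X3, X3 --- X4
est = ({("X1", "X2")}, {("X2", "X3"), ("X3", "X4")})    # missed the X3 -> X2 arrowhead
print(cpdag_metrics(est, truth))
# {'shd': 1, 'adj_precision': 1.0, 'adj_recall': 1.0, 'arrow_precision': 1.0, 'arrow_recall': 0.5}
```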

Baselines

  • pc: Peter-Clark (constraint-based)
  • ges: Greedy Equivalence Search (score-based)
  • grasp: Greedy Relaxations of the Sparsest Permutation (permutation-based)
  • boss: Best Order Score Search (permutation-based)
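PC, the constraint-based baseline, decides adjacencies with conditional-independence tests. For linear Gaussian data the usual choice is the Fisher z-test on partial correlations, which causal-learn exposes as its "fisherz" test.

The sketch below is a self-contained stand-in, not causal-learn's code. It computes partial correlations from a correlation matrix with the standard recursive formula and returns a two-sided p-value. The recursion is exponential in the conditioning-set size. Practical implementations invert the correlation submatrix instead.

```python
import math


def partial_corr(corr, x, y, z):
    """Partial correlation of x and y given the list z, via the recursive formula."""
    if not z:
        return corr[x][y]
    w, rest = z[-1], z[:-1]
    r_xy = partial_corr(corr, x, y, rest)
    r_xw = partial_corr(corr, x, w, rest)
    r_yw = partial_corr(corr, y, w, rest)
    return (r_xy - r_xw * r_yw) / math.sqrt((1.0 - r_xw ** 2) * (1.0 - r_yw ** 2))


def fisher_z_pvalue(corr, n, x, y, z=()):
    """Two-sided p-value for H0: x independent of y given z, assuming Gaussian data."""
    z = list(z)
    r = partial_corr(corr, x, y, z)
    r = max(min(r, 1.0 - 1e-12), -1.0 + 1e-12)  # keep atanh finite
    stat = math.sqrt(n - len(z) - 3) * math.atanh(r)  # Fisher z, scaled by sqrt(n - |z| - 3)
    phi = 0.5 * (1.0 + math.erf(abs(stat) / math.sqrt(2.0)))  # standard normal CDF
    return 2.0 * (1.0 - phi)


# Standardized chain X0 -> X1 -> X2 with edge correlations 0.5, so corr(X0, X2) = 0.25
corr = [[1.0, 0.5, 0.25],
        [0.5, 1.0, 0.5],
        [0.25, 0.5, 1.0]]
print(fisher_z_pvalue(corr, 500, 0, 2))       # tiny: X0 and X2 are marginally dependent
print(fisher_z_pvalue(corr, 500, 0, 2, [1]))  # 1.0: independent given X1
```

At alpha = 0.05, PC would keep the X0 - X2 adjacency after the marginal test. It would then remove that adjacency once X1 enters the conditioning set.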

Code

custom_algorithm.py
```python
import numpy as np
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode

# =====================================================================
# EDITABLE: implement run_causal_discovery below
# =====================================================================
def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
    """
    Input: X of shape (n_samples, n_variables)
    Output: estimated CPDAG as causallearn.graph.GeneralGraph.GeneralGraph
    """
    # Placeholder stub: returns a graph with nodes X1..Xd and no edges
    nodes = [GraphNode(f"X{i + 1}") for i in range(X.shape[1])]
    return GeneralGraph(nodes)
# =====================================================================
```
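Whatever search fills in this stub, the returned object should be a CPDAG rather than an arbitrary partially directed graph. A common final step, used for example by PC after it finds the skeleton and v-structures, is to apply Meek's orientation rules until nothing changes.

The pure-Python sketch below implements rules R1 to R3, which suffice when there is no background knowledge. It uses a hypothetical representation: a set of (a, b) pairs for a -> b and a set of frozenset pairs for undirected edges. It is an illustration, not causal-learn's implementation. It assumes the input already has every v-structure oriented and is otherwise consistent.

```python
def meek_orient(nodes, directed, undirected):
    """Apply Meek rules R1-R3 until no undirected edge can be oriented.

    directed:   set of (a, b) pairs meaning a -> b
    undirected: set of frozenset({a, b}) pairs meaning a --- b
    """
    directed, undirected = set(directed), set(undirected)

    def adjacent(a, b):
        return (a, b) in directed or (b, a) in directed or frozenset((a, b)) in undirected

    changed = True
    while changed:
        changed = False
        for edge in list(undirected):
            a, b = tuple(edge)
            for x, y in ((a, b), (b, a)):  # try orienting x -> y
                others = [c for c in nodes if c not in (x, y)]
                # R1: c -> x --- y with c, y non-adjacent  =>  x -> y
                r1 = any((c, x) in directed and not adjacent(c, y) for c in others)
                # R2: x -> c -> y and x --- y  =>  x -> y
                r2 = any((x, c) in directed and (c, y) in directed for c in others)
                # R3: x --- c1 -> y, x --- c2 -> y, c1 and c2 non-adjacent  =>  x -> y
                mids = [c for c in others
                        if frozenset((x, c)) in undirected and (c, y) in directed]
                r3 = any(not adjacent(c1, c2)
                         for i, c1 in enumerate(mids) for c2 in mids[i + 1:])
                if r1 or r2 or r3:
                    undirected.discard(edge)
                    directed.add((x, y))
                    changed = True
                    break
    return directed, undirected


# Pattern X1 -> X2 <- X3 plus X2 --- X4: R1 orients X2 -> X4
d, u = meek_orient(["X1", "X2", "X3", "X4"],
                   {("X1", "X2"), ("X3", "X2")}, {frozenset(("X2", "X4"))})
print(sorted(d), u)
```

A chain pattern with no v-structures, such as X1 --- X2 --- X3, comes back fully undirected, which is exactly its CPDAG.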
run_eval.py
```python
"""Evaluation harness for the causal-observational-linear-gaussian task."""
import argparse
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from custom_algorithm import run_causal_discovery
from data_gen import simulate_linear_gaussian
from metrics import compute_metrics


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate CPDAG recovery on synthetic linear Gaussian data."
```
data_gen.py
```python
"""Synthetic linear Gaussian DAG data generator for CPDAG benchmarking."""
import networkx as nx
import numpy as np

from causallearn.graph.Dag import Dag
from causallearn.graph.GraphNode import GraphNode


def simulate_dag(n_nodes, graph_type, seed, er_prob=0.5, sf_m=2):
    """Return a binary adjacency matrix with convention adj[i, j] = 1 for i -> j."""
    rng = np.random.default_rng(seed)
    graph_seed = int(rng.integers(0, 2**31 - 1))

    if graph_type == "er":
        graph = nx.erdos_renyi_graph(n_nodes, er_prob, seed=graph_seed, directed=True)
```
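simulate_dag documents the convention adj[i, j] = 1 for i -> j. With directed=True, nx.erdos_renyi_graph considers ordered pairs, so it can emit both i -> j and j -> i. The generator must therefore enforce acyclicity in code not shown here. When adapting or testing such a generator, a quick check of the convention and of acyclicity is useful. The helper names below are hypothetical and operate on nested lists of 0/1 entries.

```python
from collections import deque


def edges_from_adjacency(adj):
    """List (i, j) pairs with adj[i][j] == 1, i.e. i -> j in data_gen.py's convention."""
    return [(i, j) for i, row in enumerate(adj) for j, v in enumerate(row) if v == 1]


def is_acyclic(adj):
    """Kahn's algorithm: True iff the directed graph encoded by adj has no cycle."""
    n = len(adj)
    indegree = [sum(adj[i][j] for i in range(n)) for j in range(n)]
    queue = deque(j for j in range(n) if indegree[j] == 0)
    visited = 0
    while queue:
        i = queue.popleft()
        visited += 1
        for j in range(n):
            if adj[i][j]:
                indegree[j] -= 1
                if indegree[j] == 0:
                    queue.append(j)
    return visited == n  # some node never reaches indegree 0 iff there is a cycle


dag = [[0, 1, 1], [0, 0, 1], [0, 0, 0]]     # 0 -> 1, 0 -> 2, 1 -> 2
cyclic = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]  # 0 -> 1 -> 2 -> 0
print(edges_from_adjacency(dag))  # [(0, 1), (0, 2), (1, 2)]
print(is_acyclic(dag), is_acyclic(cyclic))  # True False
```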
metrics.py
```python
"""Evaluation metrics for CPDAG recovery on linear Gaussian data."""
from causallearn.graph.AdjacencyConfusion import AdjacencyConfusion
from causallearn.graph.ArrowConfusion import ArrowConfusion
from causallearn.graph.Graph import Graph
from causallearn.graph.SHD import SHD
from causallearn.utils.DAG2CPDAG import dag2cpdag


def _safe_div(numerator, denominator):
    return numerator / denominator if denominator > 0 else 0.0


def _normalize_graph_output(graph_output):
    """Normalize algorithm output to a causallearn Graph object."""
    if isinstance(graph_output, dict) and "G" in graph_output:
```

Results

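| Model | Type | ER10 SHD | ER10 adj precision | ER10 adj recall | ER10 arrow precision | ER10 arrow recall | ER20 SHD | ER20 adj precision | ER20 adj recall | ER20 arrow precision | ER20 arrow recall | SF50 SHD | SF50 adj precision | SF50 adj recall | SF50 arrow precision | SF50 arrow recall | SF50-Hard SHD | SF50-Hard adj precision | SF50-Hard adj recall | SF50-Hard arrow precision | SF50-Hard arrow recall | ER20-Noisy SHD | ER20-Noisy adj precision | ER20-Noisy adj recall | ER20-Noisy arrow precision | ER20-Noisy arrow recall |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| boss | baseline | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 11.000 | 0.979 | 0.979 | 0.914 | 1.000 | - | - | - | - | - | 7.000 | 0.985 | 0.927 | 0.969 | 0.912 |
| boss | baseline | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | 7.000 | 0.986 | 0.979 | 0.971 | 0.964 | - | - | - | - | - |
| boss | baseline | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 11.000 | 0.979 | 0.979 | 0.914 | 1.000 | 6.000 | 0.986 | 0.986 | 0.971 | 0.971 | 7.000 | 0.985 | 0.927 | 0.969 | 0.912 |
| ges | baseline | 12.000 | 0.800 | 0.857 | 0.300 | 0.273 | 30.000 | 0.679 | 0.927 | 0.481 | 0.735 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
| ges | baseline | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | 95.000 | 0.511 | 0.696 | 0.217 | 0.294 |
| ges | baseline | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
| ges | baseline | 12.000 | 0.800 | 0.857 | 0.300 | 0.273 | 30.000 | 0.679 | 0.927 | 0.481 | 0.735 | - | - | - | - | - | - | - | - | - | - | 95.000 | 0.511 | 0.696 | 0.217 | 0.294 |
| grasp | baseline | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 20.000 | 0.990 | 0.979 | 0.804 | 1.000 | 15.000 | 0.986 | 0.979 | 0.977 | 0.913 | 47.000 | 0.747 | 0.855 | 0.545 | 0.618 |
| grasp | baseline | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 26.000 | 0.979 | 0.979 | 0.750 | 0.973 | 25.000 | 0.951 | 0.965 | 0.897 | 0.884 | 58.000 | 0.674 | 0.841 | 0.459 | 0.574 |
| pc | baseline | 8.000 | 1.000 | 0.857 | 0.500 | 0.545 | 31.000 | 0.857 | 0.585 | 0.500 | 0.265 | 93.000 | 0.803 | 0.510 | 0.311 | 0.189 | - | - | - | - | - | - | - | - | - | - |
| pc | baseline | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | 143.000 | 0.627 | 0.298 | 0.404 | 0.167 | - | - | - | - | - |
| pc | baseline | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | 69.000 | 0.850 | 0.246 | 0.188 | 0.044 |
| pc | baseline | 8.000 | 1.000 | 0.857 | 0.500 | 0.545 | 31.000 | 0.857 | 0.585 | 0.500 | 0.265 | 93.000 | 0.803 | 0.510 | 0.311 | 0.189 | 143.000 | 0.627 | 0.298 | 0.404 | 0.167 | 69.000 | 0.850 | 0.246 | 0.188 | 0.044 |
| anthropic/claude-opus-4.6 | vanilla | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 13.000 | 0.932 | 1.000 | 0.847 | 0.973 | 22.000 | 0.892 | 1.000 | 0.865 | 0.971 | 7.000 | 0.943 | 0.957 | 0.942 | 0.956 |
| google/gemini-3.1-pro-preview | vanilla | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 5.000 | 0.930 | 0.976 | 0.892 | 0.971 | 17.000 | 0.950 | 0.990 | 0.969 | 0.838 | 48.000 | 0.851 | 0.972 | 0.793 | 0.833 | 9.000 | 0.930 | 0.957 | 0.914 | 0.941 |
| gpt-5.4-pro | vanilla | 11.000 | 1.000 | 0.786 | 0.273 | 0.273 | 38.000 | 0.958 | 0.561 | 0.130 | 0.088 | 91.000 | 0.897 | 0.542 | 0.196 | 0.149 | 202.000 | 0.386 | 0.383 | 0.179 | 0.181 | 71.000 | 0.840 | 0.304 | 0.080 | 0.029 |
| anthropic/claude-opus-4.6 | agent | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 4.000 | 0.952 | 0.976 | 0.917 | 0.971 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 25.000 | 0.897 | 0.986 | 0.849 | 0.935 | 5.000 | 0.971 | 0.971 | 0.956 | 0.956 |
| google/gemini-3.1-pro-preview | agent | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 8.000 | 0.854 | 1.000 | 0.825 | 0.971 | 75.000 | 0.640 | 1.000 | 0.457 | 0.797 | 149.000 | 0.534 | 0.957 | 0.433 | 0.775 | 6.000 | 0.957 | 0.971 | 0.942 | 0.956 |
| gpt-5.4-pro | agent | 11.000 | 1.000 | 0.786 | 0.273 | 0.273 | 38.000 | 0.958 | 0.561 | 0.130 | 0.088 | 91.000 | 0.897 | 0.542 | 0.196 | 0.149 | - | - | - | - | - | 71.000 | 0.840 | 0.304 | 0.080 | 0.029 |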

Agent Conversations