causal-observational-linear-gaussian
Causal Inference · causal-learn · rigorous codebase
Description
Causal Discovery: Observational Linear Gaussian Data (CPDAG Recovery)
Objective
Implement a causal discovery algorithm that recovers the CPDAG from purely observational data generated by a linear Gaussian SEM. Your code goes in bench/custom_algorithm.py.
Background
Under linear Gaussian assumptions, observational data can only identify the Markov Equivalence Class (MEC), not a unique fully directed DAG. The MEC is represented by a CPDAG.
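To make the equivalence class concrete, the sketch below checks whether two DAGs are Markov equivalent using the standard characterization: same skeleton and same v-structures. The helper names (`skeleton`, `v_structures`, `markov_equivalent`) are illustrative and not part of the benchmark; the adjacency convention `adj[i, j] = 1` for `i -> j` follows the docstring in data_gen.py.

```python
import numpy as np


def skeleton(adj: np.ndarray) -> np.ndarray:
    """Boolean matrix that is True where i and j are adjacent in either direction."""
    a = adj.astype(bool)
    return a | a.T


def v_structures(adj: np.ndarray) -> set:
    """Set of (i, k, j) with i -> k <- j, i < j, and i, j non-adjacent."""
    a = adj.astype(bool)
    skel = skeleton(adj)
    found = set()
    for k in range(a.shape[0]):
        parents = np.flatnonzero(a[:, k])  # sorted ascending
        for idx, i in enumerate(parents):
            for j in parents[idx + 1:]:
                if not skel[i, j]:
                    found.add((int(i), int(k), int(j)))
    return found


def markov_equivalent(adj_a: np.ndarray, adj_b: np.ndarray) -> bool:
    """Two DAGs are Markov equivalent iff skeleton and v-structures coincide."""
    same_skeleton = np.array_equal(skeleton(adj_a), skeleton(adj_b))
    return same_skeleton and v_structures(adj_a) == v_structures(adj_b)


# adj[i, j] = 1 means i -> j
chain = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])     # X0 -> X1 -> X2
reversed_chain = chain.T                                 # X0 <- X1 <- X2
collider = np.array([[0, 1, 0], [0, 0, 0], [0, 1, 0]])  # X0 -> X1 <- X2

chain_vs_reversed = markov_equivalent(chain, reversed_chain)
chain_vs_collider = markov_equivalent(chain, collider)
```

The two chains share a CPDAG (`X0 -- X1 -- X2`), so no algorithm can tell them apart from observational data alone, whereas the collider forms its own class.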
Interface
Your run_causal_discovery(X) implementation must return a causallearn.graph.GeneralGraph.GeneralGraph object representing the estimated CPDAG.
Evaluation Scenarios
| Label | Graph type | Nodes | Params | Samples | Noise |
|---|---|---|---|---|---|
| ER10 | Erdos-Renyi | 10 | p=0.3 | 500 | 1.0 |
| ER20 | Erdos-Renyi | 20 | p=0.2 | 1000 | 1.0 |
| SF50 | Scale-Free (BA) | 50 | m=2 | 2000 | 1.0 |
| ER10-Hard | Erdos-Renyi | 10 | p=0.5 (denser) | 200 | 1.0 |
| ER20-Hard | Erdos-Renyi | 20 | p=0.35 (denser) | 400 | 1.0 |
| SF50-Hard | Scale-Free (BA) | 50 | m=3 (denser) | 1000 | 1.0 |
| ER10-Noisy | Erdos-Renyi | 10 | p=0.5 (denser) | 200 | 2.5 |
| ER20-Noisy | Erdos-Renyi | 20 | p=0.35 (denser) | 400 | 2.5 |
| SF50-Noisy | Scale-Free (BA) | 50 | m=3 (denser) | 1000 | 2.5 |
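For intuition about what a scenario row describes, here is a minimal sketch of sampling from a linear Gaussian SEM over a random Erdos-Renyi DAG. It is not the harness's generator: the function names, the edge-weight range, and the reading of the Noise column as a noise standard deviation are all assumptions made for illustration.

```python
import numpy as np


def random_er_dag(d: int, p: float, seed: int = 0) -> np.ndarray:
    """Random DAG adjacency (adj[i, j] = 1 means i -> j) with edge probability p."""
    rng = np.random.default_rng(seed)
    upper = np.triu(rng.random((d, d)) < p, k=1).astype(int)
    perm = rng.permutation(d)
    # Relabelling nodes with one permutation on rows and columns keeps the graph acyclic.
    return upper[np.ix_(perm, perm)]


def sample_linear_gaussian(adj, n_samples, noise_scale=1.0, seed=0):
    """Draw samples from X = X W + E with Gaussian noise E."""
    rng = np.random.default_rng(seed)
    d = adj.shape[0]
    # Assumed weight range (illustrative only): magnitude in [0.5, 2.0], random sign.
    magnitudes = rng.uniform(0.5, 2.0, size=(d, d))
    signs = rng.choice([-1.0, 1.0], size=(d, d))
    W = magnitudes * signs * adj
    noise = rng.normal(0.0, noise_scale, size=(n_samples, d))
    # X = X W + E  =>  X (I - W) = E  =>  X = E (I - W)^{-1}
    X = noise @ np.linalg.inv(np.eye(d) - W)
    return X, W


# Roughly the ER10 setting: 10 nodes, p = 0.3, 500 samples, noise 1.0.
adj = random_er_dag(10, 0.3, seed=1)
X, W = sample_linear_gaussian(adj, n_samples=500, noise_scale=1.0, seed=1)
```

The resulting `X` has the `(n_samples, n_variables)` shape that `run_causal_discovery` receives.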
Metrics
Metrics are computed between the estimated CPDAG and the ground-truth CPDAG:
- SHD (SHD(...).get_shd())
- Adjacency Precision / Recall (AdjacencyConfusion)
- Arrow Precision / Recall (ArrowConfusion)
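The adjacency metrics can be understood without the library. The sketch below computes adjacency precision and recall directly from two adjacency matrices; it is a plain NumPy approximation of the idea, not causal-learn's `AdjacencyConfusion` implementation, and the function name is made up for this example. Returning 0.0 on an empty denominator mirrors `_safe_div` in metrics.py.

```python
import numpy as np


def adjacency_precision_recall(est: np.ndarray, true: np.ndarray):
    """Compare skeletons only: an edge counts if i and j are adjacent in any orientation."""
    est_skel = (est + est.T) > 0
    true_skel = (true + true.T) > 0
    upper = np.triu_indices(est.shape[0], k=1)  # count each unordered pair once
    e, t = est_skel[upper], true_skel[upper]
    true_positives = int(np.sum(e & t))
    n_est, n_true = int(e.sum()), int(t.sum())
    precision = true_positives / n_est if n_est > 0 else 0.0
    recall = true_positives / n_true if n_true > 0 else 0.0
    return precision, recall


# Truth: 0 -> 1 -> 2.  Estimate: 0 -> 1 and a spurious 0 -> 2, missing 1 -> 2.
true_adj = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
est_adj = np.array([[0, 1, 1], [0, 0, 0], [0, 0, 0]])
precision, recall = adjacency_precision_recall(est_adj, true_adj)
```

Here one of the two estimated adjacencies is correct and one of the two true adjacencies is found, so both values are 0.5.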
Baselines
- pc: Peter-Clark (constraint-based)
- ges: Greedy Equivalence Search (score-based)
- grasp: Greedy Relaxations of the Sparsest Permutation
- boss: Best Order Score Search
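Constraint-based methods such as PC rely on conditional independence tests; for Gaussian data the usual choice is a Fisher z-test on partial correlations. The sketch below is a self-contained NumPy version of that test, meant as an illustration of the building block rather than the code the pc baseline actually runs. The function name and the clipping constant are assumptions.

```python
import math

import numpy as np


def fisher_z_pvalue(X: np.ndarray, i: int, j: int, cond=()) -> float:
    """Two-sided p-value for H0: X_i independent of X_j given X_cond (Gaussian data)."""
    n = X.shape[0]
    idx = [i, j, *cond]
    corr = np.corrcoef(X[:, idx], rowvar=False)
    precision = np.linalg.inv(corr)
    # Partial correlation of the first two variables given the rest.
    r = -precision[0, 1] / math.sqrt(precision[0, 0] * precision[1, 1])
    r = min(max(r, -0.999999), 0.999999)  # keep the log finite
    z = 0.5 * math.log((1 + r) / (1 - r))
    stat = math.sqrt(n - len(cond) - 3) * abs(z)
    # 2 * (1 - Phi(stat)) written with the complementary error function.
    return math.erfc(stat / math.sqrt(2))


# Chain X0 -> X1 -> X2: X0 and X2 are dependent, but independent given X1.
rng = np.random.default_rng(0)
x0 = rng.normal(size=2000)
x1 = 0.8 * x0 + rng.normal(size=2000)
x2 = 0.8 * x1 + rng.normal(size=2000)
data = np.column_stack([x0, x1, x2])

p_marginal = fisher_z_pvalue(data, 0, 2)
p_conditional = fisher_z_pvalue(data, 0, 2, cond=(1,))
```

PC removes the edge between two variables as soon as some conditioning set yields a p-value above the chosen significance level, which is why small samples and dense graphs (the Hard and Noisy scenarios) tend to cost it recall.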
Code
custom_algorithm.py

    import numpy as np
    from causallearn.graph.GeneralGraph import GeneralGraph
    from causallearn.graph.GraphNode import GraphNode

    # =====================================================================
    # EDITABLE: implement run_causal_discovery below
    # =====================================================================
    def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
        """
        Input: X of shape (n_samples, n_variables)
        Output: estimated CPDAG as causallearn.graph.GeneralGraph.GeneralGraph
        """
        nodes = [GraphNode(f"X{i + 1}") for i in range(X.shape[1])]
        return GeneralGraph(nodes)
    # =====================================================================

run_eval.py

    """Evaluation harness for the causal-observational-linear-gaussian task."""
    import argparse
    import os
    import sys

    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

    from custom_algorithm import run_causal_discovery
    from data_gen import simulate_linear_gaussian
    from metrics import compute_metrics


    def main():
        parser = argparse.ArgumentParser(
            description="Evaluate CPDAG recovery on synthetic linear Gaussian data."

data_gen.py

    """Synthetic linear Gaussian DAG data generator for CPDAG benchmarking."""
    import networkx as nx
    import numpy as np

    from causallearn.graph.Dag import Dag
    from causallearn.graph.GraphNode import GraphNode


    def simulate_dag(n_nodes, graph_type, seed, er_prob=0.5, sf_m=2):
        """Return a binary adjacency matrix with convention adj[i, j] = 1 for i -> j."""
        rng = np.random.default_rng(seed)
        graph_seed = int(rng.integers(0, 2**31 - 1))

        if graph_type == "er":
            graph = nx.erdos_renyi_graph(n_nodes, er_prob, seed=graph_seed, directed=True)

metrics.py

    """Evaluation metrics for CPDAG recovery on linear Gaussian data."""
    from causallearn.graph.AdjacencyConfusion import AdjacencyConfusion
    from causallearn.graph.ArrowConfusion import ArrowConfusion
    from causallearn.graph.Graph import Graph
    from causallearn.graph.SHD import SHD
    from causallearn.utils.DAG2CPDAG import dag2cpdag


    def _safe_div(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0.0


    def _normalize_graph_output(graph_output):
        """Normalize algorithm output to a causallearn Graph object."""
        if isinstance(graph_output, dict) and "G" in graph_output:

Results
| Model | Type | shd ER10 ↓ | adj precision ER10 ↑ | adj recall ER10 ↑ | arrow precision ER10 ↑ | arrow recall ER10 ↑ | shd ER20 ↓ | adj precision ER20 ↑ | adj recall ER20 ↑ | arrow precision ER20 ↑ | arrow recall ER20 ↑ | shd SF50 ↓ | adj precision SF50 ↑ | adj recall SF50 ↑ | arrow precision SF50 ↑ | arrow recall SF50 ↑ | shd SF50-Hard ↓ | adj precision SF50-Hard ↑ | adj recall SF50-Hard ↑ | arrow precision SF50-Hard ↑ | arrow recall SF50-Hard ↑ | shd ER20-Noisy ↓ | adj precision ER20-Noisy ↑ | adj recall ER20-Noisy ↑ | arrow precision ER20-Noisy ↑ | arrow recall ER20-Noisy ↑ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| boss | baseline | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 11.000 | 0.979 | 0.979 | 0.914 | 1.000 | - | - | - | - | - | 7.000 | 0.985 | 0.927 | 0.969 | 0.912 |
| boss | baseline | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | 7.000 | 0.986 | 0.979 | 0.971 | 0.964 | - | - | - | - | - |
| boss | baseline | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 11.000 | 0.979 | 0.979 | 0.914 | 1.000 | 6.000 | 0.986 | 0.986 | 0.971 | 0.971 | 7.000 | 0.985 | 0.927 | 0.969 | 0.912 |
| ges | baseline | 12.000 | 0.800 | 0.857 | 0.300 | 0.273 | 30.000 | 0.679 | 0.927 | 0.481 | 0.735 | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - |
| ges | baseline | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | 95.000 | 0.511 | 0.696 | 0.217 | 0.294 |
| ges | baseline | 12.000 | 0.800 | 0.857 | 0.300 | 0.273 | 30.000 | 0.679 | 0.927 | 0.481 | 0.735 | - | - | - | - | - | - | - | - | - | - | 95.000 | 0.511 | 0.696 | 0.217 | 0.294 |
| grasp | baseline | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 20.000 | 0.990 | 0.979 | 0.804 | 1.000 | 15.000 | 0.986 | 0.979 | 0.977 | 0.913 | 47.000 | 0.747 | 0.855 | 0.545 | 0.618 |
| grasp | baseline | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 26.000 | 0.979 | 0.979 | 0.750 | 0.973 | 25.000 | 0.951 | 0.965 | 0.897 | 0.884 | 58.000 | 0.674 | 0.841 | 0.459 | 0.574 |
| pc | baseline | 8.000 | 1.000 | 0.857 | 0.500 | 0.545 | 31.000 | 0.857 | 0.585 | 0.500 | 0.265 | 93.000 | 0.803 | 0.510 | 0.311 | 0.189 | - | - | - | - | - | - | - | - | - | - |
| pc | baseline | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | 143.000 | 0.627 | 0.298 | 0.404 | 0.167 | - | - | - | - | - |
| pc | baseline | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | - | 69.000 | 0.850 | 0.246 | 0.188 | 0.044 |
| pc | baseline | 8.000 | 1.000 | 0.857 | 0.500 | 0.545 | 31.000 | 0.857 | 0.585 | 0.500 | 0.265 | 93.000 | 0.803 | 0.510 | 0.311 | 0.189 | 143.000 | 0.627 | 0.298 | 0.404 | 0.167 | 69.000 | 0.850 | 0.246 | 0.188 | 0.044 |
| anthropic/claude-opus-4.6 | vanilla | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 13.000 | 0.932 | 1.000 | 0.847 | 0.973 | 22.000 | 0.892 | 1.000 | 0.865 | 0.971 | 7.000 | 0.943 | 0.957 | 0.942 | 0.956 |
| google/gemini-3.1-pro-preview | vanilla | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 5.000 | 0.930 | 0.976 | 0.892 | 0.971 | 17.000 | 0.950 | 0.990 | 0.969 | 0.838 | 48.000 | 0.851 | 0.972 | 0.793 | 0.833 | 9.000 | 0.930 | 0.957 | 0.914 | 0.941 |
| gpt-5.4-pro | vanilla | 11.000 | 1.000 | 0.786 | 0.273 | 0.273 | 38.000 | 0.958 | 0.561 | 0.130 | 0.088 | 91.000 | 0.897 | 0.542 | 0.196 | 0.149 | 202.000 | 0.386 | 0.383 | 0.179 | 0.181 | 71.000 | 0.840 | 0.304 | 0.080 | 0.029 |
| anthropic/claude-opus-4.6 | agent | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 4.000 | 0.952 | 0.976 | 0.917 | 0.971 | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 25.000 | 0.897 | 0.986 | 0.849 | 0.935 | 5.000 | 0.971 | 0.971 | 0.956 | 0.956 |
| google/gemini-3.1-pro-preview | agent | 0.000 | 1.000 | 1.000 | 1.000 | 1.000 | 8.000 | 0.854 | 1.000 | 0.825 | 0.971 | 75.000 | 0.640 | 1.000 | 0.457 | 0.797 | 149.000 | 0.534 | 0.957 | 0.433 | 0.775 | 6.000 | 0.957 | 0.971 | 0.942 | 0.956 |
| gpt-5.4-pro | agent | 11.000 | 1.000 | 0.786 | 0.273 | 0.273 | 38.000 | 0.958 | 0.561 | 0.130 | 0.088 | 91.000 | 0.897 | 0.542 | 0.196 | 0.149 | - | - | - | - | - | 71.000 | 0.840 | 0.304 | 0.080 | 0.029 |