causal-discovery-discrete

Causal Inference | causal-bnlearn | rigorous codebase

Description

Causal Discovery on Real-World Bayesian Network Datasets (bnlearn)

Research Question

Design a causal discovery algorithm that recovers the CPDAG (Completed Partially Directed Acyclic Graph) from purely observational discrete data sampled from real-world Bayesian networks in the bnlearn repository.

Background

The bnlearn repository (https://www.bnlearn.com/bnrepository/) hosts a collection of well-known Bayesian network benchmarks from diverse domains (medicine, biology, meteorology, insurance, agriculture, IT). Each network has a known ground-truth DAG with discrete variables and conditional probability tables. Given observational samples from these networks, the task is to recover the causal structure.
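Sampling from such a network is usually done by ancestral (forward) sampling. You visit the variables in a topological order and draw each one from the CPT row selected by its already-sampled parents. Below is a minimal sketch. It assumes a hypothetical three-node chain A -> B -> C with made-up CPTs. This is not a bnlearn network, and it is not the pgmpy sampler that data_gen.py relies on.

```python
import numpy as np

# Hypothetical toy network A -> B -> C with made-up CPTs (not a bnlearn network).
# cpts[v] maps a tuple of parent states to a probability vector over v's states.
parents = {"A": [], "B": ["A"], "C": ["B"]}
cpts = {
    "A": {(): [0.7, 0.3]},
    "B": {(0,): [0.9, 0.1], (1,): [0.2, 0.8]},
    "C": {(0,): [0.6, 0.4], (1,): [0.1, 0.9]},
}
order = ["A", "B", "C"]  # a topological order: parents always come first


def forward_sample(n, rng):
    """Draw n i.i.d. rows by ancestral sampling; returns an (n, 3) integer array."""
    col = {v: k for k, v in enumerate(order)}
    X = np.zeros((n, len(order)), dtype=int)
    for i in range(n):
        for v in order:
            # Parents are already filled in for row i because of the topological order.
            pa_state = tuple(int(X[i, col[p]]) for p in parents[v])
            probs = cpts[v][pa_state]
            X[i, col[v]] = rng.choice(len(probs), p=probs)
    return X


X = forward_sample(500, np.random.default_rng(0))
print(X.shape)  # (500, 3)
```

The result has the same shape as the input to run_causal_discovery: one integer-coded column per variable.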

Under the causal Markov and faithfulness assumptions, and assuming no latent confounders, observational data can identify the true DAG only up to its Markov Equivalence Class (MEC), which is represented by a CPDAG. The challenge lies in handling discrete variables with varying cardinalities, network sizes from 5 to 76 nodes, and differing edge densities.
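A classical characterization makes the MEC concrete. Two DAGs are Markov equivalent exactly when they share the same skeleton and the same v-structures (a -> c <- b with a and b non-adjacent). The sketch below checks this for DAGs given as 0/1 adjacency matrices. It uses an assumed encoding in which A[i, j] = 1 means i -> j. The example shows why a chain and its reversal cannot be told apart from observational data, while a collider can.

```python
from itertools import combinations

import numpy as np


def skeleton(A):
    """Undirected adjacency: True where i -> j or j -> i."""
    A = np.asarray(A, dtype=bool)
    return A | A.T


def v_structures(A):
    """Set of (a, c, b) with a -> c <- b, a < b, and a, b non-adjacent."""
    A = np.asarray(A, dtype=bool)
    S = skeleton(A)
    out = set()
    for c in range(A.shape[0]):
        pa = np.flatnonzero(A[:, c])  # parents of c, in increasing order
        for a, b in combinations(pa, 2):
            if not S[a, b]:
                out.add((int(a), c, int(b)))
    return out


def markov_equivalent(A1, A2):
    return bool(np.array_equal(skeleton(A1), skeleton(A2))) and v_structures(A1) == v_structures(A2)


# Nodes 0, 1, 2: chain 0->1->2, reversed chain 2->1->0, collider 0->1<-2.
chain = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
reverse = chain.T
collider = np.array([[0, 1, 0], [0, 0, 0], [0, 1, 0]])
print(markov_equivalent(chain, reverse))   # True
print(markov_equivalent(chain, collider))  # False
```

This is why the edges of a chain stay undirected in the CPDAG, whereas the two edges of a collider are compelled.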

Task

Implement a causal discovery algorithm in bench/custom_algorithm.py. Your run_causal_discovery(X) function receives integer-encoded discrete observational data and must return the estimated CPDAG as a causallearn.graph.GeneralGraph.GeneralGraph object.

Interface

def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
    """
    Input:  X of shape (n_samples, n_variables), integer-encoded discrete data
    Output: estimated CPDAG as causallearn.graph.GeneralGraph.GeneralGraph
    """

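A CPDAG mixes two kinds of edges. Directed edges carry orientations shared by every DAG in the MEC; undirected edges have orientations the MEC leaves open. The causal-learn PC documentation describes how causal-learn's GeneralGraph encodes these in its endpoint matrix G.graph:

  • i -> j is graph[j, i] = 1 and graph[i, j] = -1
  • i -- j is graph[i, j] = graph[j, i] = -1

The sketch below builds such a matrix in plain NumPy from hypothetical edge lists, so it runs without causal-learn installed. Turning the matrix into a GeneralGraph object is not shown. The function name cpdag_endpoint_matrix is illustrative only.

```python
import numpy as np

TAIL, ARROW = -1, 1  # endpoint codes as described for causal-learn's GeneralGraph.graph


def cpdag_endpoint_matrix(n, directed, undirected):
    """Endpoint matrix for a CPDAG on n nodes (hypothetical helper).

    directed:   iterable of (i, j) meaning i -> j
    undirected: iterable of (i, j) meaning i -- j
    graph[i, j] holds the endpoint mark at node i on the i-j edge.
    """
    g = np.zeros((n, n), dtype=int)
    for i, j in directed:
        g[i, j], g[j, i] = TAIL, ARROW
    for i, j in undirected:
        g[i, j], g[j, i] = TAIL, TAIL
    return g


# Hypothetical CPDAGs on X1..X3 (0-based indices):
collider = cpdag_endpoint_matrix(3, directed=[(0, 2), (1, 2)], undirected=[])  # X1 -> X3 <- X2
chain = cpdag_endpoint_matrix(3, directed=[], undirected=[(0, 1), (1, 2)])     # X1 -- X2 -- X3
print(collider)
print(chain)
```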
Evaluation Scenarios

Small Networks (<20 nodes)

Label       Network     Nodes  Edges  Samples  Domain
Cancer      Cancer      5      4      500      Medical
Earthquake  Earthquake  5      4      500      Seismology
Survey      Survey      6      6      500      Social science
Asia        Asia        8      8      1000     Medical (lung diseases)
Sachs       Sachs       11     17     1000     Biology (protein signaling)

Medium Networks (20–50 nodes)

Label       Network     Nodes  Edges  Samples  Domain
Child       Child       20     25     2000     Medical
Insurance   Insurance   27     52     5000     Automotive insurance
Water       Water       32     66     5000     Water treatment
Mildew      Mildew      35     46     5000     Agriculture (crop disease)
Alarm       Alarm       37     46     5000     Medical monitoring
Barley      Barley      48     84     10000    Agriculture

Large Networks (50–100 nodes)

Label       Network     Nodes  Edges  Samples  Domain
Hailfinder  Hailfinder  56     66     10000    Meteorology
Hepar2      Hepar2      70     123    10000    Medical (liver disorders)
Win95pts    Win95pts    76     112    10000    IT (Windows troubleshooting)

Metrics

Metrics are computed between the estimated CPDAG and the ground-truth CPDAG (converted from the true DAG via dag2cpdag):

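  • SHD (Structural Hamming Distance): number of edge errors (missing, extra, or wrongly oriented edges); lower is better
  • Adjacency Precision / Recall: quality of skeleton (adjacency) recovery; higher is better
  • Arrow Precision / Recall: accuracy of edge orientations (arrowheads); higher is better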
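The three metric families can be sketched on endpoint matrices in the causal-learn convention, where i -> j is graph[j, i] = 1 and graph[i, j] = -1. This is a simplified re-implementation for intuition only, with a hypothetical function name. The harness's metrics.py uses causal-learn's SHD, AdjacencyConfusion and ArrowConfusion instead, and their edge-case conventions may differ.

```python
import numpy as np


def _adjacent(g):
    """Symmetric boolean adjacency: any nonzero endpoint mark means an edge."""
    return (g != 0) | (g.T != 0)


def _ratio(num, den):
    return float(num) / float(den) if den > 0 else 0.0


def simple_metrics(est, truth):
    """Simplified SHD and adjacency/arrow precision-recall on endpoint matrices."""
    iu = np.triu_indices(truth.shape[0], k=1)  # each unordered pair once
    a_est, a_true = _adjacent(est)[iu], _adjacent(truth)[iu]
    adj_tp = np.sum(a_est & a_true)
    # graph[j, i] == 1 is an arrowhead at node j on the i-j edge.
    h_est, h_true = est == 1, truth == 1
    arr_tp = np.sum(h_est & h_true)
    # One SHD error per pair whose marks differ: missing, extra, or misoriented.
    differs = (est != truth) | (est.T != truth.T)
    return {
        "shd": int(np.sum(differs[iu])),
        "adj_precision": _ratio(adj_tp, a_est.sum()),
        "adj_recall": _ratio(adj_tp, a_true.sum()),
        "arrow_precision": _ratio(arr_tp, h_est.sum()),
        "arrow_recall": _ratio(arr_tp, h_true.sum()),
    }


# Truth: X1 -> X3 <- X2.  Estimate: right skeleton, but both edges left undirected.
truth = np.zeros((3, 3), dtype=int)
truth[0, 2], truth[2, 0] = -1, 1
truth[1, 2], truth[2, 1] = -1, 1
est = np.zeros((3, 3), dtype=int)
est[0, 2] = est[2, 0] = -1
est[1, 2] = est[2, 1] = -1
print(simple_metrics(est, truth))
# {'shd': 2, 'adj_precision': 1.0, 'adj_recall': 1.0, 'arrow_precision': 0.0, 'arrow_recall': 0.0}
```

A perfect skeleton can therefore still score zero on both arrow metrics if no orientations are recovered.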

Baselines

  • pc: PC (Peter-Clark) algorithm with a chi-squared conditional independence (CI) test (constraint-based)
  • ges: Greedy Equivalence Search with BDeu score (score-based)
  • grasp: Greedy Relaxations of the Sparsest Permutation with BDeu (permutation-based, SOTA)
  • boss: Best Order Score Search with BDeu (permutation-based, SOTA)
  • hc: Hill-Climbing search with BDeu score (score-based)
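Four of the five baselines score candidate graphs with BDeu. For a node X with parent set Pa, the BDeu local score with equivalent sample size α is:

score(X | Pa) = Σ_j [ lnΓ(α/q) − lnΓ(α/q + N_j) + Σ_k ( lnΓ(α/(r·q) + N_jk) − lnΓ(α/(r·q)) ) ]

The symbols are:

  • r is the number of states of X
  • q is the number of parent configurations
  • N_jk counts rows with parent configuration j and X = k
  • N_j = Σ_k N_jk

The sketch below computes this score from integer-coded data with NumPy and math.lgamma. The function name is hypothetical, and ess=1.0 is an illustrative default; the document does not state the baselines' settings. This is not causal-learn's implementation.

```python
import math

import numpy as np


def bdeu_local_score(X, child, parents, ess=1.0):
    """BDeu log marginal likelihood of column `child` given `parents` (column indices).

    Assumes integer codes 0..r-1 per column; cardinality is inferred as
    max + 1, so states that never occur in X are not counted.
    """
    X = np.asarray(X, dtype=int)
    r = int(X[:, child].max()) + 1
    card = [int(X[:, p].max()) + 1 for p in parents]
    q = float(np.prod(card)) if parents else 1.0  # number of parent configurations
    if parents:
        config = np.ravel_multi_index(tuple(X[:, p] for p in parents), card)
    else:
        config = np.zeros(len(X), dtype=int)
    # Count table over observed parent configurations only (unseen ones contribute 0).
    _, inv = np.unique(config, return_inverse=True)
    inv = np.ravel(inv)
    counts = np.zeros((inv.max() + 1, r))
    np.add.at(counts, (inv, X[:, child]), 1)
    a_j, a_jk = ess / q, ess / (q * r)
    score = 0.0
    for row in counts:
        score += math.lgamma(a_j) - math.lgamma(a_j + row.sum())
        score += sum(math.lgamma(a_jk + n) - math.lgamma(a_jk) for n in row)
    return score


rng = np.random.default_rng(0)
a = rng.integers(0, 2, size=1000)
b = np.where(rng.random(1000) < 0.9, a, 1 - a)  # B copies A 90% of the time
D = np.column_stack([a, b])
print(bdeu_local_score(D, 1, [0]) > bdeu_local_score(D, 1, []))  # True: A helps explain B
```

BDeu is score-equivalent, so A -> B and B -> A receive the same total score on the same data. This is one reason it pairs naturally with searches over equivalence classes such as GES.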

Code

custom_algorithm.py
import numpy as np
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode

# =====================================================================
# EDITABLE: implement run_causal_discovery below
# =====================================================================
def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
    """
    Input: X of shape (n_samples, n_variables), integer-encoded discrete data
    Output: estimated CPDAG as causallearn.graph.GeneralGraph.GeneralGraph
    """
    # Placeholder: one node per column (named X1..Xn) and no edges.
    nodes = [GraphNode(f"X{i + 1}") for i in range(X.shape[1])]
    return GeneralGraph(nodes)
# =====================================================================
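A possible starting point for the stub above is a PC-style search. The sketch below implements only its adjacency (skeleton) phase, in the non-stable variant. It uses a G-test (likelihood-ratio) of conditional independence, a swapped-in alternative to the chi-squared test of the pc baseline. Degrees of freedom count only the states observed within each stratum of the conditioning set. This is one common adjustment for sparse tables; causal-learn's tests may compute it differently. Several parts are omitted: the orientation phase (v-structures plus Meek rules) and the conversion to a GeneralGraph. The names g_test and pc_skeleton, alpha=0.01, and max_cond=2 are hypothetical choices.

```python
from itertools import combinations

import numpy as np
from scipy.stats import chi2


def g_test(X, x, y, z):
    """p-value of a G-test for X[:, x] _||_ X[:, y] given the columns in list z."""
    if z:
        _, strata = np.unique(X[:, z], axis=0, return_inverse=True)
        strata = np.ravel(strata)
    else:
        strata = np.zeros(len(X), dtype=int)
    g, dof = 0.0, 0
    for s in np.unique(strata):
        rows = strata == s
        _, xi = np.unique(X[rows, x], return_inverse=True)
        _, yi = np.unique(X[rows, y], return_inverse=True)
        table = np.zeros((xi.max() + 1, yi.max() + 1))
        np.add.at(table, (xi, yi), 1)
        expected = table.sum(axis=1, keepdims=True) * table.sum(axis=0, keepdims=True) / table.sum()
        nz = table > 0
        g += 2.0 * np.sum(table[nz] * np.log(table[nz] / expected[nz]))
        # Only states observed in this stratum count toward the degrees of freedom.
        dof += (table.shape[0] - 1) * (table.shape[1] - 1)
    return 1.0 if dof == 0 else float(chi2.sf(g, dof))


def pc_skeleton(X, alpha=0.01, max_cond=2):
    """Adjacency phase of (non-stable) PC with G-tests; returns neighbour sets and sepsets."""
    X = np.asarray(X, dtype=int)
    d = X.shape[1]
    adj = {i: set(range(d)) - {i} for i in range(d)}
    sepset = {}
    for size in range(max_cond + 1):
        for i in range(d):
            for j in sorted(adj[i]):
                # Condition on subsets of i's other current neighbours.
                for cond in combinations(sorted(adj[i] - {j}), size):
                    if g_test(X, i, j, list(cond)) > alpha:
                        adj[i].discard(j)
                        adj[j].discard(i)
                        sepset[frozenset((i, j))] = set(cond)
                        break
    return adj, sepset


# Deterministic toy data for the chain A -> B -> C. Each row pattern is repeated in
# exact proportion to P(a) P(b | a) P(c | b), so A and C are exactly independent given B.
patterns, counts = [], []
for a in (0, 1):
    for b in (0, 1):
        for c in (0, 1):
            patterns.append((a, b, c))
            counts.append(round(1000 * (0.9 if b == a else 0.1) * (0.9 if c == b else 0.1)))
data = np.repeat(np.array(patterns), counts, axis=0)
adj, sepset = pc_skeleton(data)
print(adj)     # {0: {1}, 1: {0, 2}, 2: {1}}
print(sepset)  # {frozenset({0, 2}): {1}}
```

The separating sets are what the orientation phase would use next. Here B is in sepset(A, C), so A - B - C would not be oriented as a collider.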
run_eval.py
"""Evaluation harness for the causal-discovery-discrete task."""
import argparse
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from custom_algorithm import run_causal_discovery
from data_gen import load_and_sample
from metrics import compute_metrics


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate CPDAG recovery on bnlearn discrete data."
data_gen.py
"""Data loader for bnlearn Bayesian network benchmarks.

Uses pgmpy's bundled BIF files (no network access needed) to load
real-world Bayesian networks, sample discrete observational data,
and extract the ground-truth DAG.
"""
import numpy as np
import pandas as pd

from causallearn.graph.Dag import Dag
from causallearn.graph.GraphNode import GraphNode

# All discrete bnlearn networks supported by pgmpy
SUPPORTED_NETWORKS = [
    "cancer", "earthquake", "survey", "asia", "sachs",
metrics.py
"""Evaluation metrics for CPDAG recovery on bnlearn discrete data."""
from causallearn.graph.AdjacencyConfusion import AdjacencyConfusion
from causallearn.graph.ArrowConfusion import ArrowConfusion
from causallearn.graph.Graph import Graph
from causallearn.graph.SHD import SHD
from causallearn.utils.DAG2CPDAG import dag2cpdag


def _safe_div(numerator, denominator):
    return numerator / denominator if denominator > 0 else 0.0


def _normalize_graph_output(graph_output):
    """Normalize algorithm output to a causallearn Graph object."""
    if isinstance(graph_output, dict) and "G" in graph_output:

Results

Scores on five of the evaluation networks. SHD: lower is better. Adjacency (Adj) and arrow precision (P) / recall (R): higher is better. A dash means no result was reported.

Cancer

Model                          Type      SHD       Adj P   Adj R   Arrow P  Arrow R
boss                           baseline  4.000     1.000   0.333   0.000    0.000
ges                            baseline  4.000     1.000   0.333   0.000    0.000
grasp                          baseline  4.000     1.000   0.333   0.000    0.000
hc                             baseline  4.000     1.000   0.333   0.000    0.000
pc                             baseline  4.000     0.833   0.333   0.167    0.083
anthropic/claude-opus-4.6      vanilla   4.000     1.000   0.500   0.000    0.000
google/gemini-3.1-pro-preview  vanilla   4.000     0.000   0.000   0.000    0.000
gpt-5.4-pro                    vanilla   3.000     1.000   0.500   0.500    0.250
anthropic/claude-opus-4.6      agent     4.000     1.000   0.500   0.000    0.000
google/gemini-3.1-pro-preview  agent     9.000     0.250   0.500   0.167    0.250
gpt-5.4-pro                    agent     4.000     1.000   0.500   0.000    0.000

Child

Model                          Type      SHD       Adj P   Adj R   Arrow P  Arrow R
boss                           baseline  6.333     0.986   0.920   0.944    0.641
ges                            baseline  2.333     1.000   0.947   1.000    0.872
grasp                          baseline  9.000     0.956   0.867   0.905    0.513
hc                             baseline  8.333     0.919   0.907   1.000    0.641
pc                             baseline  12.667    0.986   0.880   0.511    0.744
anthropic/claude-opus-4.6      vanilla   5.000     1.000   0.920   1.000    0.615
google/gemini-3.1-pro-preview  vanilla   25.000    1.000   0.040   0.000    0.000
gpt-5.4-pro                    vanilla   11.000    1.000   0.880   0.556    0.769
anthropic/claude-opus-4.6      agent     5.000     1.000   0.920   1.000    0.615
google/gemini-3.1-pro-preview  agent     151.000   0.038   0.200   0.000    0.000
gpt-5.4-pro                    agent     8.000     1.000   0.880   1.000    0.385

Alarm

Model                          Type      SHD       Adj P   Adj R   Arrow P  Arrow R
boss                           baseline  23.333    0.897   0.942   0.960    0.563
ges                            baseline  2.667     0.993   0.957   0.984    0.944
grasp                          baseline  3.333     0.985   0.949   0.976    0.937
hc                             baseline  37.000    0.741   0.913   0.474    0.500
pc                             baseline  8.000     0.956   0.913   0.925    0.865
anthropic/claude-opus-4.6      vanilla   6.000     1.000   0.891   0.973    0.857
google/gemini-3.1-pro-preview  vanilla   46.000    0.000   0.000   0.000    0.000
gpt-5.4-pro                    vanilla   4.000     0.978   0.957   0.951    0.929
anthropic/claude-opus-4.6      agent     6.000     1.000   0.891   0.973    0.857
google/gemini-3.1-pro-preview  agent     574.000   0.013   0.152   0.007    0.095
gpt-5.4-pro                    agent     4.000     0.978   0.957   0.951    0.929

Hailfinder

Model                          Type      SHD       Adj P   Adj R   Arrow P  Arrow R
boss                           baseline  54.000    0.693   0.717   0.806    0.463
ges                            baseline  39.667    0.845   0.823   0.558    0.721
grasp                          baseline  55.000    0.675   0.707   0.577    0.429
hc                             baseline  64.333    0.727   0.793   0.439    0.415
pc                             baseline  45.333    0.858   0.515   0.874    0.517
anthropic/claude-opus-4.6      vanilla   34.000    1.000   0.515   0.939    0.633
google/gemini-3.1-pro-preview  vanilla   68.000    0.400   0.030   0.200    0.020
gpt-5.4-pro                    vanilla   39.000    0.846   0.833   0.563    0.735
anthropic/claude-opus-4.6      agent     34.000    1.000   0.515   0.939    0.633
google/gemini-3.1-pro-preview  agent     476.000   0.024   0.151   0.014    0.122
gpt-5.4-pro                    agent     39.000    0.846   0.833   0.563    0.735

Win95pts

Model                          Type      SHD       Adj P   Adj R   Arrow P  Arrow R
boss                           baseline  -         -       -       -        -
ges                            baseline  48.667    0.801   0.827   0.726    0.777
grasp                          baseline  62.000    0.783   0.804   0.654    0.680
hc                             baseline  100.000   0.635   0.792   0.466    0.553
pc                             baseline  51.667    0.932   0.693   0.843    0.590
anthropic/claude-opus-4.6      vanilla   43.000    0.952   0.705   0.867    0.650
google/gemini-3.1-pro-preview  vanilla   117.000   0.167   0.009   0.000    0.000
gpt-5.4-pro                    vanilla   49.000    0.800   0.821   0.726    0.770
anthropic/claude-opus-4.6      agent     43.000    0.952   0.705   0.867    0.650
google/gemini-3.1-pro-preview  agent     2582.000  0.010   0.223   0.004    0.100
gpt-5.4-pro                    agent     49.000    0.800   0.821   0.726    0.770

Agent Conversations