causal-discovery-discrete
Description
Causal Discovery on Real-World Bayesian Network Datasets (bnlearn)
Research Question
Design a causal discovery algorithm that recovers the CPDAG (Completed Partially Directed Acyclic Graph) from purely observational discrete data sampled from real-world Bayesian networks in the bnlearn repository.
Background
The bnlearn repository (https://www.bnlearn.com/bnrepository/) hosts a collection of well-known Bayesian network benchmarks from diverse domains (medicine, biology, meteorology, insurance, agriculture, IT). Each network has a known ground-truth DAG with discrete variables and conditional probability tables. Given observational samples from these networks, the task is to recover the causal structure.
Under the causal Markov and faithfulness assumptions, and in the absence of latent confounders, observational data identify the true DAG only up to its Markov Equivalence Class (MEC), which is represented by a CPDAG. The challenge lies in handling discrete data with varying cardinalities, network sizes (5–76 nodes), and edge densities.
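To make the notion of a Markov equivalence class concrete, the sketch below uses the standard characterization that two DAGs over the same nodes are Markov equivalent exactly when they share the same skeleton and the same v-structures (colliders whose two parents are not adjacent). The function names and the edge-list representation are illustrative assumptions, not part of the benchmark code.

```python
from itertools import combinations


def skeleton(edges):
    """Undirected adjacencies of a DAG given as (parent, child) pairs."""
    return {frozenset(edge) for edge in edges}


def v_structures(edges):
    """Colliders a -> c <- b whose endpoints a and b are not adjacent."""
    skel = skeleton(edges)
    parents = {}
    for parent, child in edges:
        parents.setdefault(child, set()).add(parent)
    found = set()
    for child, pa in parents.items():
        for a, b in combinations(sorted(pa), 2):
            if frozenset((a, b)) not in skel:
                found.add((a, child, b))
    return found


def markov_equivalent(dag1, dag2):
    """Same skeleton and same v-structures (assumes identical node sets)."""
    return (skeleton(dag1) == skeleton(dag2)
            and v_structures(dag1) == v_structures(dag2))


# Hypothetical three-node examples.
chain = [("X", "Y"), ("Y", "Z")]      # X -> Y -> Z
fork = [("Y", "X"), ("Y", "Z")]       # X <- Y -> Z
collider = [("X", "Y"), ("Z", "Y")]   # X -> Y <- Z

print(markov_equivalent(chain, fork))      # chain and fork share one MEC
print(markov_equivalent(chain, collider))  # the collider is in a different MEC
```

The chain and the fork cannot be told apart from observational data alone, so their shared CPDAG leaves both edges undirected, whereas the collider's orientations are identifiable.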
Task
Implement a causal discovery algorithm in bench/custom_algorithm.py. Your run_causal_discovery(X) function receives integer-encoded discrete observational data and must return the estimated CPDAG as a causallearn.graph.GeneralGraph.GeneralGraph object.
Interface
def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
    """
    Input: X of shape (n_samples, n_variables), integer-encoded discrete data
    Output: estimated CPDAG as causallearn.graph.GeneralGraph.GeneralGraph
    """
Evaluation Scenarios
Small Networks (<20 nodes)
| Label | Network | Nodes | Edges | Samples | Domain |
|---|---|---|---|---|---|
| Cancer | Cancer | 5 | 4 | 500 | Medical |
| Earthquake | Earthquake | 5 | 4 | 500 | Seismology |
| Survey | Survey | 6 | 6 | 500 | Social science |
| Asia | Asia | 8 | 8 | 1000 | Medical (lung diseases) |
| Sachs | Sachs | 11 | 17 | 1000 | Biology (protein signaling) |
Medium Networks (20–50 nodes)
| Label | Network | Nodes | Edges | Samples | Domain |
|---|---|---|---|---|---|
| Child | Child | 20 | 25 | 2000 | Medical |
| Insurance | Insurance | 27 | 52 | 5000 | Automotive insurance |
| Water | Water | 32 | 66 | 5000 | Water treatment |
| Mildew | Mildew | 35 | 46 | 5000 | Agriculture (crop disease) |
| Alarm | Alarm | 37 | 46 | 5000 | Medical monitoring |
| Barley | Barley | 48 | 84 | 10000 | Agriculture |
Large Networks (50–100 nodes)
| Label | Network | Nodes | Edges | Samples | Domain |
|---|---|---|---|---|---|
| Hailfinder | Hailfinder | 56 | 66 | 10000 | Meteorology |
| Hepar2 | Hepar2 | 70 | 123 | 10000 | Medical (liver disorders) |
| Win95pts | Win95pts | 76 | 112 | 10000 | IT (Windows troubleshooting) |
Metrics
Metrics are computed between the estimated CPDAG and the ground-truth CPDAG (obtained from the true DAG via dag2cpdag):
- SHD (Structural Hamming Distance): total edge errors (lower is better)
- Adjacency Precision / Recall: skeleton recovery quality
- Arrow Precision / Recall: edge orientation accuracy
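The metrics above can be illustrated on a toy graph. The following is a simplified sketch, not the harness implementation: it assumes edges are given as hypothetical `(a, b, kind)` triples with `kind` either `"->"` or `"--"`, and it counts every node pair whose edge is missing, extra, or marked differently as one SHD error. The SHD class in causal-learn may weight some of these cases differently, so the numbers are for intuition only.

```python
def safe_div(numerator, denominator):
    """Return 0.0 instead of raising when the denominator is zero."""
    return numerator / denominator if denominator > 0 else 0.0


def edge_marks(edges):
    """Map each unordered node pair to a canonical edge mark."""
    marks = {}
    for a, b, kind in edges:
        marks[frozenset((a, b))] = ("->", a, b) if kind == "->" else ("--",)
    return marks


def simple_shd(est_edges, true_edges):
    """Count node pairs whose edge is missing, extra, or marked differently."""
    est, true = edge_marks(est_edges), edge_marks(true_edges)
    return sum(1 for pair in est.keys() | true.keys()
               if est.get(pair) != true.get(pair))


def adjacency_scores(est_edges, true_edges):
    """Precision and recall of the skeleton, ignoring orientation."""
    est, true = set(edge_marks(est_edges)), set(edge_marks(true_edges))
    tp = len(est & true)
    return safe_div(tp, len(est)), safe_div(tp, len(true))


# Hypothetical ground-truth CPDAG: A -> C <- B, C -- D.
true_cpdag = [("A", "C", "->"), ("B", "C", "->"), ("C", "D", "--")]
# Hypothetical estimate: one reversed edge, one missing edge, one extra edge.
est_cpdag = [("A", "C", "->"), ("C", "B", "->"), ("A", "D", "--")]

shd = simple_shd(est_cpdag, true_cpdag)
adj_precision, adj_recall = adjacency_scores(est_cpdag, true_cpdag)
print(shd, adj_precision, adj_recall)
```

Here the reversed B/C edge hurts SHD but not the adjacency scores, which is why the two families of metrics are reported separately.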
Baselines
- pc: Peter-Clark algorithm with chi-squared CI test (constraint-based)
- ges: Greedy Equivalence Search with BDeu score (score-based)
- grasp: Greedy Relaxations of the Sparsest Permutation with BDeu (permutation-based, SOTA)
- boss: Best Order Score Search with BDeu (permutation-based, SOTA)
- hc: Hill-Climbing search with BDeu score (score-based)
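Since four of the five baselines rely on the BDeu score, a minimal sketch of the local (per-node) log BDeu score may help when designing a custom scoring rule. It assumes the standard formulation with a uniform prior and an equivalent sample size `ess`. The function name, argument layout, and toy data are illustrative assumptions and do not reflect how the baselines are implemented.

```python
from collections import Counter
from math import lgamma


def bdeu_local_score(data, child, parents, cards, ess=1.0):
    """Log BDeu score of column `child` given the columns in `parents`.

    data: list of rows of integer-encoded values
    cards: number of states of each column
    ess: equivalent sample size (assumed hyperparameter)
    """
    r = cards[child]
    q = 1
    for p in parents:
        q *= cards[p]
    n_ij, n_ijk = Counter(), Counter()
    for row in data:
        config = tuple(row[p] for p in parents)
        n_ij[config] += 1
        n_ijk[(config, row[child])] += 1
    a_ij, a_ijk = ess / q, ess / (q * r)
    # Unobserved parent configurations and states contribute exactly zero,
    # so summing over the observed counts is sufficient.
    score = sum(lgamma(a_ij) - lgamma(a_ij + n) for n in n_ij.values())
    score += sum(lgamma(a_ijk + n) - lgamma(a_ijk) for n in n_ijk.values())
    return score


cards = [2, 2]
# Toy data: column 1 is a copy of column 0 (strong dependence).
dependent = [(0, 0)] * 50 + [(1, 1)] * 50
# Toy data: the two columns are empirically independent.
independent = [(0, 0), (0, 1), (1, 0), (1, 1)] * 25

dep_with = bdeu_local_score(dependent, 1, [0], cards)
dep_without = bdeu_local_score(dependent, 1, [], cards)
ind_with = bdeu_local_score(independent, 1, [0], cards)
ind_without = bdeu_local_score(independent, 1, [], cards)
print(dep_with > dep_without, ind_with < ind_without)
```

On the dependent data the score favors adding the parent, while on the independent data the extra parameters are penalized and the empty parent set scores higher. Score-based searches exploit exactly this trade-off.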
Code
import numpy as np
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.GraphNode import GraphNode

# =====================================================================
# EDITABLE: implement run_causal_discovery below
# =====================================================================
def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
    """
    Input: X of shape (n_samples, n_variables), integer-encoded discrete data
    Output: estimated CPDAG as causallearn.graph.GeneralGraph.GeneralGraph
    """
    nodes = [GraphNode(f"X{i + 1}") for i in range(X.shape[1])]
    return GeneralGraph(nodes)
# =====================================================================
"""Evaluation harness for the causal-discovery-discrete task."""
import argparse
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from custom_algorithm import run_causal_discovery
from data_gen import load_and_sample
from metrics import compute_metrics


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate CPDAG recovery on bnlearn discrete data."
"""Data loader for bnlearn Bayesian network benchmarks.

Uses pgmpy's bundled BIF files (no network access needed) to load
real-world Bayesian networks, sample discrete observational data,
and extract the ground-truth DAG.
"""
import numpy as np
import pandas as pd

from causallearn.graph.Dag import Dag
from causallearn.graph.GraphNode import GraphNode

# All discrete bnlearn networks supported by pgmpy
SUPPORTED_NETWORKS = [
    "cancer", "earthquake", "survey", "asia", "sachs",
"""Evaluation metrics for CPDAG recovery on bnlearn discrete data."""
from causallearn.graph.AdjacencyConfusion import AdjacencyConfusion
from causallearn.graph.ArrowConfusion import ArrowConfusion
from causallearn.graph.Graph import Graph
from causallearn.graph.SHD import SHD
from causallearn.utils.DAG2CPDAG import dag2cpdag


def _safe_div(numerator, denominator):
    return numerator / denominator if denominator > 0 else 0.0


def _normalize_graph_output(graph_output):
    """Normalize algorithm output to a causallearn Graph object."""
    if isinstance(graph_output, dict) and "G" in graph_output:
Results
| Model | Type | shd Cancer ↓ | adj precision Cancer ↑ | adj recall Cancer ↑ | arrow precision Cancer ↑ | arrow recall Cancer ↑ | shd Child ↓ | adj precision Child ↑ | adj recall Child ↑ | arrow precision Child ↑ | arrow recall Child ↑ | shd Alarm ↓ | adj precision Alarm ↑ | adj recall Alarm ↑ | arrow precision Alarm ↑ | arrow recall Alarm ↑ | shd Hailfinder ↓ | adj precision Hailfinder ↑ | adj recall Hailfinder ↑ | arrow precision Hailfinder ↑ | arrow recall Hailfinder ↑ | shd Win95pts ↓ | adj precision Win95pts ↑ | adj recall Win95pts ↑ | arrow precision Win95pts ↑ | arrow recall Win95pts ↑ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| boss | baseline | 4.000 | 1.000 | 0.333 | 0.000 | 0.000 | 6.333 | 0.986 | 0.920 | 0.944 | 0.641 | 23.333 | 0.897 | 0.942 | 0.960 | 0.563 | 54.000 | 0.693 | 0.717 | 0.806 | 0.463 | - | - | - | - | - |
| ges | baseline | 4.000 | 1.000 | 0.333 | 0.000 | 0.000 | 2.333 | 1.000 | 0.947 | 1.000 | 0.872 | 2.667 | 0.993 | 0.957 | 0.984 | 0.944 | 39.667 | 0.845 | 0.823 | 0.558 | 0.721 | 48.667 | 0.801 | 0.827 | 0.726 | 0.777 |
| grasp | baseline | 4.000 | 1.000 | 0.333 | 0.000 | 0.000 | 9.000 | 0.956 | 0.867 | 0.905 | 0.513 | 3.333 | 0.985 | 0.949 | 0.976 | 0.937 | 55.000 | 0.675 | 0.707 | 0.577 | 0.429 | 62.000 | 0.783 | 0.804 | 0.654 | 0.680 |
| hc | baseline | 4.000 | 1.000 | 0.333 | 0.000 | 0.000 | 8.333 | 0.919 | 0.907 | 1.000 | 0.641 | 37.000 | 0.741 | 0.913 | 0.474 | 0.500 | 64.333 | 0.727 | 0.793 | 0.439 | 0.415 | 100.000 | 0.635 | 0.792 | 0.466 | 0.553 |
| pc | baseline | 4.000 | 0.833 | 0.333 | 0.167 | 0.083 | 12.667 | 0.986 | 0.880 | 0.511 | 0.744 | 8.000 | 0.956 | 0.913 | 0.925 | 0.865 | 45.333 | 0.858 | 0.515 | 0.874 | 0.517 | 51.667 | 0.932 | 0.693 | 0.843 | 0.590 |
| anthropic/claude-opus-4.6 | vanilla | 4.000 | 1.000 | 0.500 | 0.000 | 0.000 | 5.000 | 1.000 | 0.920 | 1.000 | 0.615 | 6.000 | 1.000 | 0.891 | 0.973 | 0.857 | 34.000 | 1.000 | 0.515 | 0.939 | 0.633 | 43.000 | 0.952 | 0.705 | 0.867 | 0.650 |
| google/gemini-3.1-pro-preview | vanilla | 4.000 | 0.000 | 0.000 | 0.000 | 0.000 | 25.000 | 1.000 | 0.040 | 0.000 | 0.000 | 46.000 | 0.000 | 0.000 | 0.000 | 0.000 | 68.000 | 0.400 | 0.030 | 0.200 | 0.020 | 117.000 | 0.167 | 0.009 | 0.000 | 0.000 |
| gpt-5.4-pro | vanilla | 3.000 | 1.000 | 0.500 | 0.500 | 0.250 | 11.000 | 1.000 | 0.880 | 0.556 | 0.769 | 4.000 | 0.978 | 0.957 | 0.951 | 0.929 | 39.000 | 0.846 | 0.833 | 0.563 | 0.735 | 49.000 | 0.800 | 0.821 | 0.726 | 0.770 |
| anthropic/claude-opus-4.6 | agent | 4.000 | 1.000 | 0.500 | 0.000 | 0.000 | 5.000 | 1.000 | 0.920 | 1.000 | 0.615 | 6.000 | 1.000 | 0.891 | 0.973 | 0.857 | 34.000 | 1.000 | 0.515 | 0.939 | 0.633 | 43.000 | 0.952 | 0.705 | 0.867 | 0.650 |
| google/gemini-3.1-pro-preview | agent | 9.000 | 0.250 | 0.500 | 0.167 | 0.250 | 151.000 | 0.038 | 0.200 | 0.000 | 0.000 | 574.000 | 0.013 | 0.152 | 0.007 | 0.095 | 476.000 | 0.024 | 0.151 | 0.014 | 0.122 | 2582.000 | 0.010 | 0.223 | 0.004 | 0.100 |
| gpt-5.4-pro | agent | 4.000 | 1.000 | 0.500 | 0.000 | 0.000 | 8.000 | 1.000 | 0.880 | 1.000 | 0.385 | 4.000 | 0.978 | 0.957 | 0.951 | 0.929 | 39.000 | 0.846 | 0.833 | 0.563 | 0.735 | 49.000 | 0.800 | 0.821 | 0.726 | 0.770 |