Agent Conversation: anthropic/claude-opus-4.6
Task: causal-discovery-discrete
Causal Discovery on Real-World Bayesian Network Datasets (bnlearn)
Research Question
Design a causal discovery algorithm that recovers the CPDAG (Completed Partially Directed Acyclic Graph) from purely observational discrete data sampled from real-world Bayesian networks in the bnlearn repository.
Background
The bnlearn repository (https://www.bnlearn.com/bnrepository/) hosts a collection of well-known Bayesian network benchmarks from diverse domains (medicine, biology, meteorology, insurance, agriculture, IT). Each network has a known ground-truth DAG with discrete variables and conditional probability tables. Given observational samples from these networks, the task is to recover the causal structure.
Under the faithfulness assumption, observational data can identify the Markov Equivalence Class (MEC) of the true DAG, represented by a CPDAG. The challenge lies in handling discrete data with varying cardinalities, network sizes (5–76 nodes), and edge densities.
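The identifiability statement above rests on Verma and Pearl's characterisation: two DAGs are Markov equivalent exactly when they share the same skeleton and the same v-structures (unshielded colliders). The sketch below checks this in plain Python. It uses a hypothetical parents-dict encoding and is not causal-learn's `GeneralGraph`.

```python
from itertools import combinations

# Hypothetical encoding (not causal-learn's): each node maps to the set of its parents.
DAG = dict[str, set[str]]


def skeleton(dag: DAG) -> set[frozenset[str]]:
    """Undirected adjacencies of the DAG."""
    return {frozenset((p, c)) for c, parents in dag.items() for p in parents}


def v_structures(dag: DAG) -> set[tuple[str, str, str]]:
    """Unshielded colliders a -> c <- b (a, b non-adjacent), stored as (a, c, b) with a < b."""
    skel = skeleton(dag)
    found = set()
    for c, parents in dag.items():
        for a, b in combinations(sorted(parents), 2):
            if frozenset((a, b)) not in skel:
                found.add((a, c, b))
    return found


def markov_equivalent(g1: DAG, g2: DAG) -> bool:
    """Same skeleton and same v-structures (Verma and Pearl's criterion)."""
    return skeleton(g1) == skeleton(g2) and v_structures(g1) == v_structures(g2)


chain = {"A": set(), "B": {"A"}, "C": {"B"}}          # A -> B -> C
fork = {"B": set(), "A": {"B"}, "C": {"B"}}           # A <- B -> C
collider = {"A": set(), "C": set(), "B": {"A", "C"}}  # A -> B <- C
print(markov_equivalent(chain, fork))      # True
print(markov_equivalent(chain, collider))  # False
```

The chain and the fork fall in one equivalence class, so a CPDAG leaves both of their edges undirected. The collider is alone in its class, so both of its arrows are compelled.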
Task
Implement a causal discovery algorithm in bench/custom_algorithm.py. Your run_causal_discovery(X) function receives integer-encoded discrete observational data and must return the estimated CPDAG as a causallearn.graph.GeneralGraph.GeneralGraph object.
Interface
def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
    """
    Input: X of shape (n_samples, n_variables), integer-encoded discrete data
    Output: estimated CPDAG as causallearn.graph.GeneralGraph.GeneralGraph
    """
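The brief does not show how the harness integer-encodes the bnlearn samples. Purely as an assumption, one common convention maps each column's category labels to 0..k-1 in sorted order, which can be sketched as follows. The helper name `integer_encode` is hypothetical.

```python
from typing import Sequence


def integer_encode(rows: Sequence[Sequence[str]]) -> tuple[list[list[int]], list[list[str]]]:
    """Hypothetical helper: encode each column's labels as 0..k-1 in sorted label order."""
    n_cols = len(rows[0])
    # Per-column sorted category list; a label's code is its position in that list.
    categories = [sorted({row[j] for row in rows}) for j in range(n_cols)]
    lookup = [{label: code for code, label in enumerate(cats)} for cats in categories]
    codes = [[lookup[j][row[j]] for j in range(n_cols)] for row in rows]
    return codes, categories


rows = [["yes", "low"], ["no", "high"], ["yes", "high"]]
codes, cats = integer_encode(rows)
print(codes)                   # [[1, 1], [0, 0], [1, 0]]
print([len(c) for c in cats])  # per-column cardinalities: [2, 2]
```

Passing `codes` through `np.asarray` would give an array of shape (n_samples, n_variables) like the `X` described above. The per-column cardinalities here count observed categories only, as `np.unique` does.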
Evaluation Scenarios
Small Networks (<20 nodes)
| Label | Network | Nodes | Edges | Samples | Domain |
|---|---|---|---|---|---|
| Cancer | Cancer | 5 | 4 | 500 | Medical |
| Earthquake | Earthquake | 5 | 4 | 500 | Seismology |
| Survey | Survey | 6 | 6 | 500 | Social science |
| Asia | Asia | 8 | 8 | 1000 | Medical (lung diseases) |
| Sachs | Sachs | 11 | 17 | 1000 | Biology (protein signaling) |
Medium Networks (20–50 nodes)
| Label | Network | Nodes | Edges | Samples | Domain |
|---|---|---|---|---|---|
| Child | Child | 20 | 25 | 2000 | Medical |
| Insurance | Insurance | 27 | 52 | 5000 | Automotive insurance |
| Water | Water | 32 | 66 | 5000 | Water treatment |
| Mildew | Mildew | 35 | 46 | 5000 | Agriculture (crop disease) |
| Alarm | Alarm | 37 | 46 | 5000 | Medical monitoring |
| Barley | Barley | 48 | 84 | 10000 | Agriculture |
Large Networks (50–100 nodes)
| Label | Network | Nodes | Edges | Samples | Domain |
|---|---|---|---|---|---|
| Hailfinder | Hailfinder | 56 | 66 | 10000 | Meteorology |
| Hepar2 | Hepar2 | 70 | 123 | 10000 | Medical (liver disorders) |
| Win95pts | Win95pts | 76 | 112 | 10000 | IT (Windows troubleshooting) |
Metrics
Metrics are computed between the estimated CPDAG and the ground-truth CPDAG (converted from the true DAG via dag2cpdag):
- SHD (Structural Hamming Distance): total edge errors
Tool Input
{
  "op": "replace",
  "filename": "causal-bnlearn/bench/custom_algorithm.py",
  "start_line": 3,
  "end_line": 14,
  "content": "from causallearn.graph.GraphNode import GraphNode\nfrom causallearn.search.ScoreBased.GES import ges; from causallearn.utils.cit import CIT\ndef run_causal_discovery(X: np.ndarray) -> GeneralGraph:\n    n_samples, n_vars = X.shape\n    if n_vars <= 12:\n        from causallearn.search.PermutationBased.BOSS import boss; return boss(X, score_func=\"local_score_BDeu\", parameters={\"equivalent_sample_size\": 10})\n    G = ges(X, score_func=\"local_score_BDeu\", parameters={\"equivalent_sample_size\": 10})['G']\n    nmap = {n.get_name(): int(n.get_name()[1:]) - 1 for n in G.get_nodes()}; cit_test = CIT(X, \"chisq\")\n    for e in list(G.get_graph_edges()):\n        i, j = nmap[e.get_node1().get_name()], nmap[e.get_node2().get_name()]; S = [nmap[n.get_name()] for n in G.get_adjacent_nodes(e.get_node1()) if n != e.get_node2()][:3]\n        if cit_test(i, j, S) > 0.05: G.remove_edge(e)\n    return G"
}
OK: Replaced lines 3..14 with 12 line(s) in causal-bnlearn/bench/custom_algorithm.py. Editable range: 3–14.
[Current file: causal-bnlearn/bench/custom_algorithm.py | editable: 3–14 | total: 15 lines]
1| import numpy as np
2| from causallearn.graph.GeneralGraph import GeneralGraph
3 from causallearn.graph.GraphNode import GraphNode
4 from causallearn.search.ScoreBased.GES import ges; from causallearn.utils.cit import CIT
5 def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
6     n_samples, n_vars = X.shape
7     if n_vars <= 12:
8         from causallearn.search.PermutationBased.BOSS import boss; return boss(X, score_func="local_score_BDeu", parameters={"equivalent_sample_size": 10})
9     G = ges(X, score_func="local_score_BDeu", parameters={"equivalent_sample_size": 10})['G']
10     nmap = {n.get_name(): int(n.get_name()[1:]) - 1 for n in G.get_nodes()}; cit_test = CIT(X, "chisq")
11     for e in list(G.get_graph_edg
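The first edit above prunes the GES output with conditional-independence tests. For every edge, it conditions on at most three other neighbours of the edge's first endpoint and deletes the edge if the p-value exceeds 0.05. The sketch below shows that loop on a hypothetical adjacency dict with a pluggable test function rather than causal-learn's `GeneralGraph` and `CIT`. It works on the skeleton only. Which endpoint counts as "first" depends on `GeneralGraph`'s edge order, and the sketch simply uses the smaller index.

```python
from typing import Callable, Sequence

# Hypothetical stand-in for causal-learn's CIT object: ci_test(i, j, S) -> p-value.
CITest = Callable[[int, int, Sequence[int]], float]


def prune_edges(adj: dict[int, set[int]], ci_test: CITest,
                alpha: float = 0.05, max_cond: int = 3) -> dict[int, set[int]]:
    """Remove i - j if i and j test independent given up to max_cond other neighbours of i."""
    adj = {v: set(nbrs) for v, nbrs in adj.items()}  # work on a copy
    edges = sorted({(min(i, j), max(i, j)) for i in adj for j in adj[i]})
    for i, j in edges:
        # Neighbours are read from the graph as it is being pruned, as in the edit above.
        cond = sorted(adj[i] - {j})[:max_cond]
        if ci_test(i, j, cond) > alpha:  # independence not rejected at level alpha
            adj[i].discard(j)
            adj[j].discard(i)
    return adj


calls = []


def fake_test(i: int, j: int, cond: Sequence[int]) -> float:
    """Toy test that records its calls and treats only the pair (0, 2) as independent."""
    calls.append((i, j, list(cond)))
    return 0.9 if (i, j) == (0, 2) else 0.0


triangle = {0: {1, 2}, 1: {0, 2}, 2: {0, 1}}
print(prune_edges(triangle, fake_test))  # {0: {1}, 1: {0, 2}, 2: {1}}
```

Because the sketch only deletes adjacencies, any edge orientations would have to be recomputed separately afterwards.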
Tool Input
{
  "op": "replace",
  "filename": "causal-bnlearn/bench/custom_algorithm.py",
  "start_line": 3,
  "end_line": 14,
  "content": "from causallearn.search.ScoreBased.GES import ges; from causallearn.utils.cit import CIT\ndef run_causal_discovery(X: np.ndarray) -> GeneralGraph:\n    n, p = X.shape; G = ges(X, score_func=\"local_score_BDeu\")['G']\n    cit = CIT(X, \"gsq\"); nds = G.get_nodes(); nm = {nds[k]: k for k in range(len(nds))}\n    cards = [len(np.unique(X[:, k])) for k in range(p)]\n    for edge in list(G.get_graph_edges()):\n        i, j = nm[edge.get_node1()], nm[edge.get_node2()]\n        nbrs = sorted((set(nm[nb] for nb in G.get_adjacent_nodes(edge.get_node1())) | set(nm[nb] for nb in G.get_adjacent_nodes(edge.get_node2()))) - {i, j})\n        tests = [[]] + [[k] for k in nbrs] + ([nbrs] if len(nbrs) > 1 else [])\n        ok = [t for t in tests if cards[i]*cards[j]*int(np.prod([cards[k] for k in t]+[1])) <= n//5]\n        if any(cit(i, j, t) > 0.01 for t in ok): G.remove_edge(edge)\n    return G"
}
OK: Replaced lines 3..14 with 12 line(s) in causal-bnlearn/bench/custom_algorithm.py. Editable range: 3–14.
[Current file: causal-bnlearn/bench/custom_algorithm.py | editable: 3–14 | total: 15 lines]
1| import numpy as np
2| from causallearn.graph.GeneralGraph import GeneralGraph
3 from causallearn.search.ScoreBased.GES import ges; from causallearn.utils.cit import CIT
4 def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
5     n, p = X.shape; G = ges(X, score_func="local_score_BDeu")['G']
6     cit = CIT(X, "gsq"); nds = G.get_nodes(); nm = {nds[k]: k for k in range(len(nds))}
7     cards = [len(np.unique(X[:, k])) for k in range(p)]
8     for edge in list(G.get_graph_edges()):
9         i, j = nm[edge.get_node1()], nm[edge.get_node2()]
10         nbrs = sorted((set(nm[nb] for nb in G.get_adjacent_nodes(edge.get_node1())) | set(nm[nb] for nb in G.get_adjacent_nodes(edge.get_node2()))) - {i, j})
11         tests = [[]]
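The version above, which Test #1 scores, switches pruning to `CIT(X, "gsq")`, a G² conditional-independence test for discrete data. The following is a minimal, dependency-free sketch of such a test. It uses the textbook degrees of freedom (|X|-1)(|Y|-1)·∏|Z| and an exact chi-square tail for integer degrees of freedom. causal-learn's own implementation may handle zero cells and degrees of freedom differently, so treat this as an illustration, not a drop-in replacement.

```python
import math
from collections import Counter
from typing import Sequence


def chi2_sf(x: float, k: int) -> float:
    """Upper tail P(chi2_k > x) for integer k >= 1, using the closed forms for integer dof."""
    if k < 1:
        raise ValueError("k must be >= 1")
    if x <= 0:
        return 1.0
    half = x / 2
    if k % 2 == 0:
        # Even k: exp(-x/2) * sum_{i=0}^{k/2-1} (x/2)^i / i!, each term in log space.
        s = sum(math.exp(i * math.log(half) - math.lgamma(i + 1) - half) for i in range(k // 2))
    else:
        # Odd k: erfc(sqrt(x/2)) + sqrt(2x/pi) e^(-x/2) * sum_{i=1}^{(k-1)/2} x^(i-1) / (2i-1)!!
        s = math.erfc(math.sqrt(half))
        for i in range(1, (k - 1) // 2 + 1):
            log_dfact = math.lgamma(2 * i + 1) - i * math.log(2) - math.lgamma(i + 1)
            s += math.exp(0.5 * math.log(2 * x / math.pi) - half + (i - 1) * math.log(x) - log_dfact)
    return min(1.0, s)


def g_square_test(data: Sequence[Sequence[int]], i: int, j: int,
                  cond: Sequence[int]) -> tuple[float, int, float]:
    """G^2 test of X_i independent of X_j given X_cond; returns (G2, dof, p-value)."""
    def z(row: Sequence[int]) -> tuple[int, ...]:
        return tuple(row[k] for k in cond)

    def card(col: int) -> int:
        return len({row[col] for row in data})  # observed cardinality

    n_xyz = Counter((r[i], r[j], z(r)) for r in data)
    n_xz = Counter((r[i], z(r)) for r in data)
    n_yz = Counter((r[j], z(r)) for r in data)
    n_z = Counter(z(r) for r in data)
    # G2 = 2 * sum over non-empty cells of N_xyz * log(N_xyz * N_z / (N_xz * N_yz))
    g2 = 2.0 * sum(c * math.log(c * n_z[zz] / (n_xz[x, zz] * n_yz[y, zz]))
                   for (x, y, zz), c in n_xyz.items())
    # Textbook dof: (|X_i| - 1)(|X_j| - 1) * prod_k |X_k|
    dof = (card(i) - 1) * (card(j) - 1) * math.prod(card(k) for k in cond)
    p = chi2_sf(g2, dof) if dof > 0 else 1.0
    return g2, dof, p


# X0 and X1 are copies of X2: dependent marginally, independent given X2.
data = [(0, 0, 0)] * 50 + [(1, 1, 1)] * 50
print(g_square_test(data, 0, 1, []))   # G2 = 200*ln 2 ~ 138.6, dof 1, p near 0
print(g_square_test(data, 0, 1, [2]))  # (0.0, 2, 1.0)
```

The dof term also explains the `n//5` guard in the edit above. The product of cardinalities grows quickly with the size of the conditioning set, so large sets spread the samples too thinly for the chi-square approximation.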
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### Cancer (scripts/eval_cancer.sh)
Results (Cancer): SHD=4 AdjP=1.0000 AdjR=0.5000 ArrowP=0.0000 ArrowR=0.0000
### Earthquake (scripts/eval_earthquake.sh)
Results (Earthquake): SHD=3 AdjP=0.6667 AdjR=0.5000 ArrowP=0.6667 ArrowR=0.5000
### Survey (scripts/eval_survey.sh)
Results (Survey): SHD=6 AdjP=1.0000 AdjR=0.1667 ArrowP=0.0000 ArrowR=0.0000
### Asia (scripts/eval_asia.sh)
Results (Asia): SHD=3 AdjP=1.0000 AdjR=0.6250 ArrowP=1.0000 ArrowR=0.6000
### Sachs (scripts/eval_sachs.sh)
Results (Sachs): SHD=14 AdjP=0.9333 AdjR=0.8235 ArrowP=0.0000 ArrowR=0.0000
### Child (scripts/eval_child.sh)
Results (Child): SHD=5 AdjP=1.0000 AdjR=0.9200 ArrowP=1.0000 ArrowR=0.6154
### Insurance (scripts/eval_insurance.sh)
Results (Insurance): SHD=25 AdjP=1.0000 AdjR=0.7692 ArrowP=0.6286 ArrowR=0.6471
### Water (scripts/eval_water.sh)
Results (Water): SHD=48 AdjP=0.9643 AdjR=0.4091 ArrowP=0.6957
Tool Input
{
  "op": "replace",
  "filename": "causal-bnlearn/bench/custom_algorithm.py",
  "start_line": 3,
  "end_line": 14,
  "content": "from causallearn.search.ScoreBased.GES import ges; from causallearn.utils.cit import CIT\ndef run_causal_discovery(X: np.ndarray) -> GeneralGraph:\n    n, p = X.shape; G = ges(X, score_func=\"local_score_BDeu\")['G']\n    if p <= 20: return G\n    cit = CIT(X, \"gsq\"); nds = G.get_nodes(); nm = {nds[k]: k for k in range(len(nds))}; cards = [len(np.unique(X[:, k])) for k in range(p)]\n    for edge in list(G.get_graph_edges()):\n        i, j = nm[edge.get_node1()], nm[edge.get_node2()]\n        for nd in [edge.get_node1(), edge.get_node2()]:\n            cond = sorted([nm[nb] for nb in G.get_adjacent_nodes(nd) if nm[nb] not in (i, j)])\n            cells = cards[i]*cards[j]*int(np.prod([cards[k] for k in cond]+[1]))\n            if cond and cells <= n//5 and cit(i, j, cond) > 0.01: G.remove_edge(edge); break\n    return G"
}
OK: Replaced lines 3..14 with 12 line(s) in causal-bnlearn/bench/custom_algorithm.py. Editable range: 3–14.
[Current file: causal-bnlearn/bench/custom_algorithm.py | editable: 3–14 | total: 15 lines]
1| import numpy as np
2| from causallearn.graph.GeneralGraph import GeneralGraph
3 from causallearn.search.ScoreBased.GES import ges; from causallearn.utils.cit import CIT
4 def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
5     n, p = X.shape; G = ges(X, score_func="local_score_BDeu")['G']
6     if p <= 20: return G
7     cit = CIT(X, "gsq"); nds = G.get_nodes(); nm = {nds[k]: k for k in range(len(nds))}; cards = [len(np.unique(X[:, k])) for k in range(p)]
8     for edge in list(G.get_graph_edges()):
9         i, j = nm[edge.get_node1()], nm[edge.get_node2()]
10         for nd in [edge.get_node1(), edge.get_node2()]:
11             cond = sorted([nm[nb] for nb in G.get_adjacent_nodes(nd) if nm[nb] not in (i, j)])
1
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### Cancer (scripts/eval_cancer.sh)
Results (Cancer): SHD=4 AdjP=1.0000 AdjR=0.5000 ArrowP=0.0000 ArrowR=0.0000
### Earthquake (scripts/eval_earthquake.sh)
Results (Earthquake): SHD=1 AdjP=0.8000 AdjR=1.0000 ArrowP=0.8000 ArrowR=1.0000
### Survey (scripts/eval_survey.sh)
Results (Survey): SHD=6 AdjP=1.0000 AdjR=0.1667 ArrowP=0.0000 ArrowR=0.0000
### Asia (scripts/eval_asia.sh)
Results (Asia): SHD=1 AdjP=1.0000 AdjR=0.8750 ArrowP=1.0000 ArrowR=1.0000
### Sachs (scripts/eval_sachs.sh)
Results (Sachs): SHD=14 AdjP=0.9333 AdjR=0.8235 ArrowP=0.0000 ArrowR=0.0000
### Child (scripts/eval_child.sh)
Results (Child): SHD=5 AdjP=1.0000 AdjR=0.9200 ArrowP=1.0000 ArrowR=0.6154
### Insurance (scripts/eval_insurance.sh)
Results (Insurance): SHD=27 AdjP=0.9524 AdjR=0.7692 ArrowP=0.6286 ArrowR=0.6471
### Water (scripts/eval_water.sh)
Results (Water): SHD=49 AdjP=0.9630 AdjR=0.3939 ArrowP=0.6818
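Because each test log is truncated and reports several metrics per network, a small helper can line up SHD between two runs. The sketch below assumes the `Results (Name): SHD=<int>` format seen in these logs and reads only SHD. The helper names are hypothetical. The sample strings are excerpts of the Test #1 and Test #2 lines above. Note that those two runs differ in both the `p <= 20` shortcut and the choice of conditioning sets, so the deltas mix the two changes.

```python
import re

# Assumes the "Results (Name): SHD=<int> ..." format of the test logs above.
RESULT = re.compile(r"Results \((\w+)\): SHD=(\d+)")


def shd_by_network(log_text: str) -> dict[str, int]:
    """Map network label -> SHD for every result found in one test log."""
    return {name: int(value) for name, value in RESULT.findall(log_text)}


def shd_delta(before: str, after: str) -> dict[str, int]:
    """SHD(after) - SHD(before) for networks reported in both logs; negative means fewer errors."""
    a, b = shd_by_network(before), shd_by_network(after)
    return {net: b[net] - a[net] for net in sorted(a.keys() & b.keys())}


test1 = "Results (Earthquake): SHD=3 AdjP=0.6667 Results (Asia): SHD=3 AdjP=1.0000 Results (Insurance): SHD=25"
test2 = "Results (Earthquake): SHD=1 AdjP=0.8000 Results (Asia): SHD=1 AdjP=1.0000 Results (Insurance): SHD=27"
print(shd_delta(test1, test2))  # {'Asia': -2, 'Earthquake': -2, 'Insurance': 2}
```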
Tool Input
{
  "op": "replace",
  "filename": "causal-bnlearn/bench/custom_algorithm.py",
  "start_line": 3,
  "end_line": 14,
  "content": "from causallearn.search.ScoreBased.GES import ges; from causallearn.utils.cit import CIT\ndef run_causal_discovery(X: np.ndarray) -> GeneralGraph:\n    n, p = X.shape; G = ges(X, score_func=\"local_score_BDeu\")['G']\n    if p <= 20: return G\n    cit = CIT(X, \"gsq\"); nds = G.get_nodes(); nm = {nds[k]: k for k in range(len(nds))}; cards = [len(np.unique(X[:, k])) for k in range(p)]\n    for edge in list(G.get_graph_edges()):\n        i, j = nm[edge.get_node1()], nm[edge.get_node2()]\n        nbrs = sorted((set(nm[nb] for nb in G.get_adjacent_nodes(edge.get_node1())) | set(nm[nb] for nb in G.get_adjacent_nodes(edge.get_node2()))) - {i, j})\n        tests = [[]] + [[k] for k in nbrs] + ([nbrs] if len(nbrs) > 1 else [])\n        ok = [t for t in tests if cards[i]*cards[j]*int(np.prod([cards[k] for k in t]+[1])) <= n//5]\n        if any(cit(i, j, t) > 0.01 for t in ok): G.remove_edge(edge)\n    return G"
}
OK: Replaced lines 3..14 with 12 line(s) in causal-bnlearn/bench/custom_algorithm.py. Editable range: 3–14.
[Current file: causal-bnlearn/bench/custom_algorithm.py | editable: 3–14 | total: 15 lines]
1| import numpy as np
2| from causallearn.graph.GeneralGraph import GeneralGraph
3 from causallearn.search.ScoreBased.GES import ges; from causallearn.utils.cit import CIT
4 def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
5     n, p = X.shape; G = ges(X, score_func="local_score_BDeu")['G']
6     if p <= 20: return G
7     cit = CIT(X, "gsq"); nds = G.get_nodes(); nm = {nds[k]: k for k in range(len(nds))}; cards = [len(np.unique(X[:, k])) for k in range(p)]
8     for edge in list(G.get_graph_edges()):
9         i, j = nm[edge.get_node1()], nm[edge.get_node2()]
10         nbrs = sorted((set(nm[nb] for nb in G.get_adjacent_nodes(edge.get_node1())) | set(nm[nb] for nb in G.get_adjacent_nodes(edge.get_node2()))) - {i, j})
11
[Test #3] (0 tests remaining)
### Cancer (scripts/eval_cancer.sh)
Results (Cancer): SHD=4 AdjP=1.0000 AdjR=0.5000 ArrowP=0.0000 ArrowR=0.0000
### Earthquake (scripts/eval_earthquake.sh)
Results (Earthquake): SHD=1 AdjP=0.8000 AdjR=1.0000 ArrowP=0.8000 ArrowR=1.0000
### Survey (scripts/eval_survey.sh)
Results (Survey): SHD=6 AdjP=1.0000 AdjR=0.1667 ArrowP=0.0000 ArrowR=0.0000
### Asia (scripts/eval_asia.sh)
Results (Asia): SHD=1 AdjP=1.0000 AdjR=0.8750 ArrowP=1.0000 ArrowR=1.0000
### Sachs (scripts/eval_sachs.sh)
Results (Sachs): SHD=14 AdjP=0.9333 AdjR=0.8235 ArrowP=0.0000 ArrowR=0.0000
### Child (scripts/eval_child.sh)
Results (Child): SHD=5 AdjP=1.0000 AdjR=0.9200 ArrowP=1.0000 ArrowR=0.6154
### Insurance (scripts/eval_insurance.sh)
Results (Insurance): SHD=25 AdjP=1.0000 AdjR=0.7692 ArrowP=0.6286 ArrowR=0.6471
### Water (scripts/eval_water.sh)
Results (Water): SHD=48 AdjP=0.9643 AdjR=0.4091 ArrowP=0.6957 ArrowR=0.2667
### Mildew (scripts/eval