Agent Conversation: gpt-5.4-pro
Task: causal-discovery-discrete
# Causal Discovery on Real-World Bayesian Network Datasets (bnlearn)
## Research Question
Design a causal discovery algorithm that recovers the CPDAG (Completed Partially Directed Acyclic Graph) from purely observational discrete data sampled from real-world Bayesian networks in the bnlearn repository.
## Background
The bnlearn repository (https://www.bnlearn.com/bnrepository/) hosts a collection of well-known Bayesian network benchmarks from diverse domains (medicine, biology, meteorology, insurance, agriculture, IT). Each network has a known ground-truth DAG with discrete variables and conditional probability tables. Given observational samples from these networks, the task is to recover the causal structure.
Under the faithfulness assumption, observational data can identify the Markov Equivalence Class (MEC) of the true DAG, represented by a CPDAG. The challenge lies in handling discrete data with varying cardinalities, network sizes (5–76 nodes), and edge densities.
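Concretely, the identifiable part of the structure is the DAG's skeleton plus its v-structure orientations, which Meek's rules then propagate. A minimal numpy sketch of this DAG-to-CPDAG step — simplified, since it deliberately omits Meek's propagation rules and so under-orients relative to a full dag2cpdag — might look like:

```python
import numpy as np

def dag_to_partial_cpdag(A: np.ndarray) -> np.ndarray:
    """Simplified CPDAG sketch: keep the skeleton, orient only v-structure
    arrows (i -> k <- j with i, j non-adjacent); Meek rules are omitted.
    A[i, j] = 1 encodes the DAG edge i -> j. Output: C[i, j] = C[j, i] = 1
    for an undirected edge, C[i, j] = 1 and C[j, i] = 0 for a compelled i -> j.
    """
    p = A.shape[0]
    C = ((A + A.T) > 0).astype(int)          # undirected skeleton
    compelled = np.zeros_like(A)
    for k in range(p):
        parents = np.where(A[:, k] == 1)[0]
        for a in range(len(parents)):
            for b in range(a + 1, len(parents)):
                i, j = parents[a], parents[b]
                if A[i, j] == 0 and A[j, i] == 0:   # non-adjacent parents: v-structure
                    compelled[i, k] = compelled[j, k] = 1
    # drop the back-direction of every compelled arrow
    C[compelled.T == 1] = 0
    C[compelled == 1] = 1
    return C
```

For a chain X1 → X2 → X3 every edge stays undirected (the equivalence class contains all chain orientations), while a collider X1 → X3 ← X2 keeps both arrowheads.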
## Task
Implement a causal discovery algorithm in bench/custom_algorithm.py. Your run_causal_discovery(X) function receives integer-encoded discrete observational data and must return the estimated CPDAG as a causallearn.graph.GeneralGraph.GeneralGraph object.
## Interface
```python
def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
    """
    Input:  X of shape (n_samples, n_variables), integer-encoded discrete data
    Output: estimated CPDAG as causallearn.graph.GeneralGraph.GeneralGraph
    """
```
## Evaluation Scenarios
### Small Networks (<20 nodes)
| Label | Network | Nodes | Edges | Samples | Domain |
|---|---|---|---|---|---|
| Cancer | Cancer | 5 | 4 | 500 | Medical |
| Earthquake | Earthquake | 5 | 4 | 500 | Seismology |
| Survey | Survey | 6 | 6 | 500 | Social science |
| Asia | Asia | 8 | 8 | 1000 | Medical (lung diseases) |
| Sachs | Sachs | 11 | 17 | 1000 | Biology (protein signaling) |
### Medium Networks (20–50 nodes)
| Label | Network | Nodes | Edges | Samples | Domain |
|---|---|---|---|---|---|
| Child | Child | 20 | 25 | 2000 | Medical |
| Insurance | Insurance | 27 | 52 | 5000 | Automotive insurance |
| Water | Water | 32 | 66 | 5000 | Water treatment |
| Mildew | Mildew | 35 | 46 | 5000 | Agriculture (crop disease) |
| Alarm | Alarm | 37 | 46 | 5000 | Medical monitoring |
| Barley | Barley | 48 | 84 | 10000 | Agriculture |
### Large Networks (50–100 nodes)
| Label | Network | Nodes | Edges | Samples | Domain |
|---|---|---|---|---|---|
| Hailfinder | Hailfinder | 56 | 66 | 10000 | Meteorology |
| Hepar2 | Hepar2 | 70 | 123 | 10000 | Medical (liver disorders) |
| Win95pts | Win95pts | 76 | 112 | 10000 | IT (Windows troubleshooting) |
## Metrics
Metrics are computed between the estimated CPDAG and the ground-truth CPDAG (converted from the true DAG via dag2cpdag):
- SHD (Structural Hamming Distance): total number of edge errors (missing, extra, or wrongly oriented edges)
- AdjP / AdjR: adjacency precision and recall of the undirected skeleton
- ArrowP / ArrowR: precision and recall of directed edge orientations
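As a sketch of the first metric, one common SHD variant counts each vertex pair whose edge status differs between the two CPDAGs, so a missing, extra, or differently oriented edge each costs 1 (the benchmark's exact convention may differ, e.g. by charging orientation errors separately):

```python
import numpy as np

def shd(A, B):
    """Structural Hamming Distance between two CPDAG adjacency codings.
    A[i, j] = 1 alone encodes a directed edge i -> j; A[i, j] = A[j, i] = 1
    an undirected edge; 0/0 no edge. Each differing vertex pair counts once,
    so a reversed or mis-oriented edge costs 1 (one common SHD variant).
    """
    p = A.shape[0]
    d = 0
    for i in range(p):
        for j in range(i + 1, p):
            if (A[i, j], A[j, i]) != (B[i, j], B[j, i]):
                d += 1
    return d
```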
### Tool Input
1{
2 "op": "replace",
3 "filename": "causal-bnlearn/bench/custom_algorithm.py",
4 "content": "import numpy as np\nfrom causallearn.graph.GeneralGraph import GeneralGraph\nfrom causallearn.graph.GraphNode import GraphNode\n\n# =====================================================================\n# EDITABLE: implement run_causal_discovery below\n# =====================================================================\ndef run_causal_discovery(X: np.ndarray) -> GeneralGraph:\n \"\"\"Adaptive blanket-validated hybrid causal discovery for discrete data.\"\"\"\n from causallearn.search.ScoreBased.GES import ges\n from causallearn.search.ConstraintBased.PC import pc\n n, p = X.shape\n nodes = [GraphNode(f\"X{i + 1}\") for i in range(p)]\n cards = np.array([len(np.unique(X[:, i])) for i in range(p)], dtype=int)\n rs = np.random.RandomState(0)\n idx = rs.permutation(n)\n cut = max(1, n - max(64, n // 5))\n tr, va = X[idx[:cut]], X[idx[cut:]]\n if len(va) == 0:\n tr, va = X[:-1], X[-1:]\n\n def mi(a, b, ka, kb):\n c = np.bincount((a.astype(int) * kb + b.astype(int)), minlength=ka * kb).reshape(ka, kb).astype(float)\n nn, pa, pb = c.sum(), c.sum(1, keepdims=True), c.sum(0, keepdims=True)\n ex = pa @ pb\n nz = c > 0\n return float((c[nz] / nn * (np.log(c[nz] * nn) - np.log(ex[nz]))).sum())\n\n MI = np.zeros((p, p))\n for i in range(p):\n for j in range(i + 1, p):\n MI[i, j] = MI[j, i] = mi(tr[:, i], tr[:, j], int(cards[i]), int(cards[j]))\n marg = [np.bincount(tr[:, i].astype(int), minlength=int(cards[i])).astype(float) for i in range(p)]\n state_cap = int(max(24, min(256, 2 * np.sqrt(len(tr)))))\n\n def norm_graph(g):\n return g[\"G\"] if isinstance(g, dict) else getattr(g, \"G\", g)\n\n def edge_count(G):\n A = np.asarray(G.graph)\n S = ((A != 0) | (A.T != 0))\n return int(np.count_nonzero(np.triu(S, 1)))\n\n def score_graph(G):\n A = np.asarray(G.graph)\n S = ((A != 0) | (A.T != 0))\n ll = 0.0\n for i in range(p):\n nbr = np.where(S[i])[0]\n if len(nbr):\n order = nbr[np.argsort(MI[i, nbr])[::-1]]\n sel, prod = [], 1\n for j in order:\n if len(sel) >= 
4:\n break\n cj = int(cards[j])\n if prod * cj <= state_cap or not sel:\n sel.append(int(j))\n prod *= cj\n nbr = np.array(sel, dtype=int)\n base = marg[i]\n if len(nbr) == 0:\n prob = (base + 0.5) / (base.sum() + 0.5 * cards[i])\n ll += np.log(prob[va[:, i].astype(int)]).sum()\n continue\n mult = np.cumprod(np.r_[1, cards[nbr[:-1]]]).astype(int)\n ktr = (tr[:, nbr].astype(int) * mult).sum(1)\n kva = (va[:, nbr].astype(int) * mult).sum(1)\n R = int(np.prod(cards[nbr]))\n joint = np.bincount(ktr * int(cards[i]) + tr[:, i].astype(int), minlength=R * int(cards[i])).reshape(R, int(cards[i])).astype(float)\n den = joint.sum(1)\n num = joint[kva, va[:, i].astype(int)]\n seen = den[kva] > 0\n prob = np.empty(len(va), float)\n prob[seen] = (num[seen] + 0.5) / (den[kva][seen] + 0.5 * cards[i])\n prob[~seen] = (base[va[:, i].astype(int)][~seen] + 0.5) / (base.sum() + 0.5 * cards[i])\n ll += np.log(prob).sum()\n return ll / max(1, len(va)) - 0.02 * edge_count(G) / max(1, p)\n\n avg_card = float(cards.mean())\n cand_specs = [(\"ges\", lambda: norm_graph(ges(X, score_func=\"local_score_BDeu\")))]\n if p <= 14:\n from causallearn.search.PermutationBased.BOSS import boss\n cand_specs.append((\"boss\", lambda: norm_graph(boss(X, score_func=\"local_score_BDeu\"))))\n if p <= 35:\n from causallearn.search.PermutationBased.GRaSP import grasp\n cand_specs.append((\"grasp\", lambda: norm_graph(grasp(X, score_func=\"local_score_BDeu\"))))\n pc_tests = [\"chisq\"] + ([\"gsq\"] if (avg_card > 3 or p >= 30) else [])\n for test_name in pc_tests:\n cand_specs.append((f\"pc_{test_name}\", lambda t=test_name: norm_graph(pc(X, alpha=0.05, indep_test=t, stable=True, show_progress=False))))\n\n scored = []\n for name, fn in cand_specs:\n try:\n G = fn()\n scored.append((score_graph(G), -edge_count(G), name, G))\n except Exception:\n pass\n if not scored:\n try:\n return norm_graph(ges(X, score_func=\"local_score_BDeu\"))\n except Exception:\n return GeneralGraph(nodes)\n 
scored.sort(reverse=True)\n best = scored[0]\n for cand in scored[1:]:\n if best[0] - cand[0] < 0.03 and cand[1] > best[1]:\n best = cand\n return best[-1]\n# =====================================================================\n",
5 "after_line": 14,
6 "start_line": 3,
7 "end_line": 14
8}
OK: Replaced lines 3..14 with 110 line(s) in causal-bnlearn/bench/custom_algorithm.py. Editable range: 3–112.
[Current file: causal-bnlearn/bench/custom_algorithm.py | editable: 3–112 | total: 113 lines]
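The `mi` helper embedded in the tool call above is a plug-in mutual-information estimator over a contingency table built with `np.bincount`. Unpacked into frequency form, the same computation reads:

```python
import numpy as np

def mutual_information(a, b, ka, kb):
    """Plug-in MI (in nats) between two integer-coded discrete columns
    with cardinalities ka and kb, via a bincount contingency table."""
    c = np.bincount(a * kb + b, minlength=ka * kb).reshape(ka, kb).astype(float)
    n = c.sum()
    pxy = c / n                               # joint frequencies
    px = pxy.sum(axis=1, keepdims=True)       # row marginal
    py = pxy.sum(axis=0, keepdims=True)       # column marginal
    nz = pxy > 0                              # skip empty cells (0 * log 0 = 0)
    return float((pxy[nz] * np.log(pxy[nz] / (px @ py)[nz])).sum())
```

The tool-call version keeps everything in counts (`c[nz] * nn / ex[nz]`), which is algebraically the same ratio; these pairwise MI values are then used to rank which neighbors enter each node's capped conditioning set.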
1| import numpy as np
2| from causallearn.graph.GeneralGraph import GeneralGraph
3 import numpy as np
4 from causallearn.graph.GeneralGraph import GeneralGraph
5 from causallearn.graph.GraphNode import GraphNode
6
7 # =====================================================================
8 # EDITABLE: implement run_causal_discovery below
9 # =====================================================================
10 def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
11 """Adaptive blanket-validated hybrid causal discovery for discrete data."""
12 from causallearn.search.ScoreBased.GES import ges
13 from causallearn.search.ConstraintBased.PC import pc
14 n, p = X.shape
15 nodes = [G
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)
### Cancer (scripts/eval_cancer.sh)
Results (Cancer): SHD=3 AdjP=1.0000 AdjR=0.5000 ArrowP=0.5000 ArrowR=0.2500
### Earthquake (scripts/eval_earthquake.sh)
Results (Earthquake): SHD=1 AdjP=0.8000 AdjR=1.0000 ArrowP=0.8000 ArrowR=1.0000
### Survey (scripts/eval_survey.sh)
Results (Survey): SHD=7 AdjP=0.6667 AdjR=0.3333 ArrowP=0.0000 ArrowR=0.0000
### Asia (scripts/eval_asia.sh)
Results (Asia): SHD=1 AdjP=1.0000 AdjR=0.8750 ArrowP=1.0000 ArrowR=1.0000
### Sachs (scripts/eval_sachs.sh)
Results (Sachs): SHD=14 AdjP=0.9333 AdjR=0.8235 ArrowP=0.0000 ArrowR=0.0000
### Child (scripts/eval_child.sh)
Results (Child): SHD=11 AdjP=1.0000 AdjR=0.8800 ArrowP=0.5556 ArrowR=0.7692
### Insurance (scripts/eval_insurance.sh)
Results (Insurance): SHD=27 AdjP=0.9000 AdjR=0.6923 ArrowP=0.7059 ArrowR=0.7059
### Water (scripts/eval_water.sh)
Results (Water): SHD=50 AdjP=0.9630 AdjR=0.3939 ArrowP=0.785
### Tool Input
1{
2 "op": "replace",
3 "filename": "causal-bnlearn/bench/custom_algorithm.py",
4 "content": "import numpy as np\nfrom causallearn.graph.GeneralGraph import GeneralGraph\nfrom causallearn.graph.GraphNode import GraphNode\n\n# =====================================================================\n# EDITABLE: implement run_causal_discovery below\n# =====================================================================\ndef run_causal_discovery(X: np.ndarray) -> GeneralGraph:\n \"\"\"Direction-aware consensus selection over discrete causal discovery candidates.\"\"\"\n from causallearn.search.ScoreBased.GES import ges\n from causallearn.search.ConstraintBased.PC import pc\n\n n, p = X.shape\n nodes = [GraphNode(f\"X{i + 1}\") for i in range(p)]\n cards = np.array([len(np.unique(X[:, i])) for i in range(p)], dtype=int)\n rs = np.random.RandomState(0)\n idx = rs.permutation(n)\n cut = max(1, n - max(64, n // 5))\n tr, va = X[idx[:cut]], X[idx[cut:]]\n if len(va) == 0:\n tr, va = X[:-1], X[-1:]\n\n def norm_graph(g):\n return g[\"G\"] if isinstance(g, dict) else getattr(g, \"G\", g)\n\n def mi(a, b, ka, kb):\n c = np.bincount(a.astype(int) * kb + b.astype(int), minlength=ka * kb).reshape(ka, kb).astype(float)\n nn, pa, pb = c.sum(), c.sum(1, keepdims=True), c.sum(0, keepdims=True)\n ex = pa @ pb\n nz = c > 0\n return float((c[nz] / nn * (np.log(c[nz] * nn) - np.log(ex[nz]))).sum())\n\n MI = np.zeros((p, p), float)\n for i in range(p):\n for j in range(i + 1, p):\n MI[i, j] = MI[j, i] = mi(tr[:, i], tr[:, j], int(cards[i]), int(cards[j]))\n marg = [np.bincount(tr[:, i].astype(int), minlength=int(cards[i])).astype(float) for i in range(p)]\n state_cap = int(max(24, min(320, 3 * np.sqrt(len(tr)))))\n ll_cache = {}\n\n def masks(G):\n A = np.asarray(G.graph)\n D = (A == -1) & (A.T == 1) # i -> j\n S = (A != 0) | (A.T != 0)\n U = S & ~D & ~D.T\n return S, D, U\n\n def edge_count(S):\n return int(np.count_nonzero(np.triu(S, 1)))\n\n def dir_count(D):\n return int(D.sum())\n\n def pick_cond(i, parents=(), neighbors=()):\n used, sel, prod = set(), [], 1\n 
for group, quota in ((parents, 3), (neighbors, 2)):\n order = sorted({int(j) for j in group if j != i and int(j) not in used}, key=lambda j: MI[i, j], reverse=True)\n taken = 0\n for j in order:\n cj = int(cards[j])\n if len(sel) >= 5:\n break\n if prod * cj <= state_cap or not sel or (group is parents and taken < quota):\n sel.append(j)\n used.add(j)\n prod *= cj\n taken += 1\n return tuple(sorted(sel))\n\n def node_ll(i, cond):\n key = (int(i), tuple(sorted(int(j) for j in cond if j != i)))\n if key in ll_cache:\n return ll_cache[key]\n base = marg[i]\n if len(key[1]) == 0:\n prob = (base + 0.5) / (base.sum() + 0.5 * cards[i])\n val = float(np.log(prob[va[:, i].astype(int)]).sum() / max(1, len(va)))\n else:\n c = np.array(key[1], dtype=int)\n mult = np.cumprod(np.r_[1, cards[c[:-1]]]).astype(int)\n ktr = (tr[:, c].astype(int) * mult).sum(1)\n kva = (va[:, c].astype(int) * mult).sum(1)\n R = int(np.prod(cards[c]))\n joint = np.bincount(ktr * int(cards[i]) + tr[:, i].astype(int), minlength=R * int(cards[i])).reshape(R, int(cards[i])).astype(float)\n den = joint.sum(1)\n num = joint[kva, va[:, i].astype(int)]\n seen = den[kva] > 0\n prob = np.empty(len(va), float)\n prob[seen] = (num[seen] + 0.5) / (den[kva][seen] + 0.5 * cards[i])\n prob[~seen] = (base[va[:, i].astype(int)][~seen] + 0.5) / (base.sum() + 0.5 * cards[i])\n val = float(np.log(prob).sum() / max(1, len(va)))\n ll_cache[key] = val\n return val\n\n def factor_score(S, D):\n U = S & ~D & ~D.T\n total = 0.0\n for i in range(p):\n pa = np.where(D[:, i])[0]\n un = np.where(U[:, i])[0]\n total += node_ll(i, pick_cond(i, pa, un))\n bonus = 0.0\n ii, jj = np.where(D)\n for a, b in zip(ii.tolist(), jj.tolist()):\n pa_b = np.where(D[:, b])[0]\n un_b = np.where(U[:, b])[0]\n gain = node_ll(b, pick_cond(b, pa_b, un_b)) - node_ll(b, pick_cond(b, [k for k in pa_b if k != a], un_b))\n pa_a = np.where(D[:, a])[0]\n un_a = np.where(U[:, a])[0]\n rev = node_ll(a, pick_cond(a, pa_a, un_a)) - node_ll(a, pick_cond(a, [*pa_a, 
b], un_a))\n bonus += max(-0.02, min(0.02, gain + rev))\n return total + 0.35 * bonus - 0.015 * edge_count(S) / max(1, p)\n\n avg_card = float(cards.mean())\n cand_specs = [(\"ges\", lambda: norm_graph(ges(X, score_func=\"local_score_BDeu\")))]\n if p <= 28:\n from causallearn.search.PermutationBased.BOSS import boss\n cand_specs.append((\"boss\", lambda: norm_graph(boss(X, score_func=\"local_score_BDeu\"))))\n if p <= 30:\n from causallearn.search.PermutationBased.GRaSP import grasp\n cand_specs.append((\"grasp\", lambda: norm_graph(grasp(X, score_func=\"local_score_BDeu\"))))\n pc_tests = [\"chisq\", \"gsq\"] if p <= 15 else ([\"gsq\"] if (avg_card > 2.5 or p >= 30) else [\"chisq\"])\n for test_name in pc_tests:\n cand_specs.append((f\"pc_{test_name}\", lambda t=test_name: norm_graph(pc(X, alpha=0.05, indep_test=t, stable=True, show_progress=False))))\n\n cands = []\n for name, fn in cand_specs:\n try:\n G = fn()\n S, D, _ = masks(G)\n pred = max(factor_score(S, D), factor_score(S, D.T))\n cands.append({\"name\": name, \"G\": G, \"S\": S, \"D\": D, \"pred\": pred, \"e\": edge_count(S), \"d\": dir_count(D)})\n except Exception:\n pass\n if not cands:\n try:\n return norm_graph(ges(X, score_func=\"local_score_BDeu\"))\n except Exception:\n return GeneralGraph(nodes)\n\n sup = sum(c[\"S\"].astype(float) for c in cands) / len(cands)\n dsup = sum(c[\"D\"].astype(float) for c in cands) / len(cands)\n for c in cands:\n agr = float(sup[c[\"S\"]].mean()) if c[\"e\"] else 0.0\n dagr = float(dsup[c[\"D\"]].mean()) if c[\"d\"] else 0.0\n decis = c[\"d\"] / max(1, c[\"e\"])\n bias = 0.008 * decis if (p <= 30 and c[\"name\"] in {\"boss\", \"grasp\", \"ges\"}) else 0.0\n c[\"score\"] = c[\"pred\"] + (0.035 if p <= 35 else 0.02) * agr + 0.015 * dagr + bias\n\n cands.sort(key=lambda c: (c[\"score\"], c[\"pred\"], -c[\"e\"]), reverse=True)\n best = cands[0]\n if p <= 30:\n perms = [c for c in cands if c[\"name\"] in {\"boss\", \"grasp\"}]\n if perms:\n alt = max(perms, key=lambda 
c: c[\"score\"])\n if best[\"score\"] - alt[\"score\"] < 0.015 and alt[\"d\"] >= best[\"d\"]:\n best = alt\n return best[\"G\"]\n# =====================================================================\n",
5 "after_line": 112,
6 "start_line": 3,
7 "end_line": 112
8}
OK: Replaced lines 3..112 with 159 line(s) in causal-bnlearn/bench/custom_algorithm.py. Editable range: 3–161.
[Current file: causal-bnlearn/bench/custom_algorithm.py | editable: 3–161 | total: 162 lines]
1| import numpy as np
2| from causallearn.graph.GeneralGraph import GeneralGraph
3 import numpy as np
4 from causallearn.graph.GeneralGraph import GeneralGraph
5 from causallearn.graph.GraphNode import GraphNode
6
7 # =====================================================================
8 # EDITABLE: implement run_causal_discovery below
9 # =====================================================================
10 def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
11 """Direction-aware consensus selection over discrete causal discovery candidates."""
12 from causallearn.search.ScoreBased.GES import ges
13 from causallearn.search.ConstraintBased.PC import pc
14
15 n, p = X.shape
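Both revisions rank candidate graphs by a held-out, Laplace-smoothed conditional log-likelihood per node: fit counts on a training split, score the validation split. A stripped-down sketch of that per-node score with the same 0.5 pseudocount (the tool-call version additionally caps the conditioning-set size via `state_cap` and falls back to the marginal for unseen parent configurations, both omitted here):

```python
import numpy as np

def node_holdout_ll(tr, va, i, parents, cards):
    """Mean held-out log-likelihood of column i given discrete parents,
    using a contingency table fit on tr and 0.5 Laplace smoothing."""
    ki = int(cards[i])
    if not parents:
        counts = np.bincount(tr[:, i], minlength=ki).astype(float)
        prob = (counts + 0.5) / (counts.sum() + 0.5 * ki)
        return float(np.log(prob[va[:, i]]).mean())
    pa = np.asarray(parents, dtype=int)
    mult = np.cumprod(np.r_[1, cards[pa[:-1]]]).astype(int)   # mixed-radix key
    ktr = (tr[:, pa] * mult).sum(axis=1)                      # parent config per row
    kva = (va[:, pa] * mult).sum(axis=1)
    R = int(np.prod(cards[pa]))
    joint = np.bincount(ktr * ki + tr[:, i], minlength=R * ki).reshape(R, ki).astype(float)
    den = joint.sum(axis=1)[kva]              # count of each validation row's parent config
    num = joint[kva, va[:, i]]                # count of (config, value) pairs
    prob = (num + 0.5) / (den + 0.5 * ki)     # smoothed conditional probability
    return float(np.log(prob).mean())
```

A node scored with its true parent should beat the same node scored marginally whenever the dependence is real, which is exactly the signal the candidate selection exploits.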
[Test #2] (1 test remaining; call submit(n=2) to submit this result)
### Cancer (scripts/eval_cancer.sh)
Results (Cancer): SHD=4 AdjP=1.0000 AdjR=0.5000 ArrowP=0.0000 ArrowR=0.0000
### Earthquake (scripts/eval_earthquake.sh)
Results (Earthquake): SHD=2 AdjP=0.7500 AdjR=0.7500 ArrowP=1.0000 ArrowR=0.7500
### Survey (scripts/eval_survey.sh)
Results (Survey): SHD=6 AdjP=1.0000 AdjR=0.1667 ArrowP=0.0000 ArrowR=0.0000
### Asia (scripts/eval_asia.sh)
Results (Asia): SHD=6 AdjP=0.7778 AdjR=0.8750 ArrowP=1.0000 ArrowR=0.4000
### Sachs (scripts/eval_sachs.sh)
Results (Sachs): SHD=1 AdjP=1.0000 AdjR=0.9412 ArrowP=0.0000 ArrowR=0.0000
### Child (scripts/eval_child.sh)
Results (Child): SHD=8 AdjP=1.0000 AdjR=0.8800 ArrowP=1.0000 ArrowR=0.3846
### Insurance (scripts/eval_insurance.sh)
Results (Insurance): SHD=31 AdjP=0.9500 AdjR=0.7308 ArrowP=0.7273 ArrowR=0.4706
### Water (scripts/eval_water.sh)
Results (Water): SHD=48 AdjP=1.0000 AdjR=0.4242 ArrowP=0.9231 A
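The second revision also rewards cross-method agreement: the candidates' skeleton masks are averaged into a support matrix, and each candidate earns a bonus proportional to the mean support over its own edges. The core of that consensus step, as a sketch:

```python
import numpy as np

def agreement_scores(skeletons):
    """Given boolean symmetric adjacency masks from several algorithms,
    return each candidate's mean consensus support over its own edges."""
    sup = np.mean([S.astype(float) for S in skeletons], axis=0)  # edge support in [0, 1]
    scores = []
    for S in skeletons:
        # mean support over this candidate's edges; 0.0 for an empty graph
        scores.append(float(sup[S].mean()) if S.any() else 0.0)
    return scores
```

A candidate whose every edge is also found by the other methods scores 1.0; edges unique to one method drag its score down, which is the tie-breaking signal added on top of the held-out likelihood.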
### Tool Input
1{
2 "n": 2
3}
[submit] Submitting result from test #2 as final.
### Cancer (scripts/eval_cancer.sh)
Results (Cancer): SHD=4 AdjP=1.0000 AdjR=0.5000 ArrowP=0.0000 ArrowR=0.0000
### Earthquake (scripts/eval_earthquake.sh)
Results (Earthquake): SHD=2 AdjP=0.7500 AdjR=0.7500 ArrowP=1.0000 ArrowR=0.7500
### Survey (scripts/eval_survey.sh)
Results (Survey): SHD=6 AdjP=1.0000 AdjR=0.1667 ArrowP=0.0000 ArrowR=0.0000
### Asia (scripts/eval_asia.sh)
Results (Asia): SHD=6 AdjP=0.7778 AdjR=0.8750 ArrowP=1.0000 ArrowR=0.4000
### Sachs (scripts/eval_sachs.sh)
Results (Sachs): SHD=1 AdjP=1.0000 AdjR=0.9412 ArrowP=0.0000 ArrowR=0.0000
### Child (scripts/eval_child.sh)
Results (Child): SHD=8 AdjP=1.0000 AdjR=0.8800 ArrowP=1.0000 ArrowR=0.3846
### Insurance (scripts/eval_insurance.sh)
Results (Insurance): SHD=31 AdjP=0.9500 AdjR=0.7308 ArrowP=0.7273 ArrowR=0.4706
### Water (scripts/eval_water.sh)
Results (Water): SHD=48 AdjP=1.0000 AdjR=0.4242 ArrowP=0.9231 ArrowR=0.2000
### M