Agent Conversation: gpt-5.4-pro

Model: gpt-5.4-pro
Total steps: 5
Messages: 11
User (Step 0)

Task: causal-discovery-discrete

Causal Discovery on Real-World Bayesian Network Datasets (bnlearn)

Research Question

Design a causal discovery algorithm that recovers the CPDAG (Completed Partially Directed Acyclic Graph) from purely observational discrete data sampled from real-world Bayesian networks in the bnlearn repository.

Background

The bnlearn repository (https://www.bnlearn.com/bnrepository/) hosts a collection of well-known Bayesian network benchmarks from diverse domains (medicine, biology, meteorology, insurance, agriculture, IT). Each network has a known ground-truth DAG with discrete variables and conditional probability tables. Given observational samples from these networks, the task is to recover the causal structure.

Under the causal Markov and faithfulness assumptions, and with no hidden confounders, observational data can identify the true DAG only up to its Markov Equivalence Class (MEC), which is represented by a CPDAG. The challenge lies in handling discrete data with varying cardinalities, network sizes (5–76 nodes), and edge densities.
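Concretely, the standard characterization due to Verma and Pearl is that two DAGs are Markov equivalent exactly when they share the same skeleton and the same v-structures (unshielded colliders a -> c <- b with a and b non-adjacent). That is why a CPDAG keeps the skeleton and directs only the edges whose orientation is shared by every DAG in the class. A minimal pure-Python sketch of the equivalence check follows; representing a DAG as a list of (parent, child) pairs and the function names are illustrative assumptions, not part of the task interface.

```python
from itertools import combinations


def skeleton(edges):
    """Undirected adjacencies of a DAG given as (parent, child) pairs."""
    return {frozenset(e) for e in edges}


def v_structures(edges):
    """Unshielded colliders, returned as (a, c, b) with a < b and a -> c <- b."""
    skel = skeleton(edges)
    parents = {}
    for a, c in edges:
        parents.setdefault(c, set()).add(a)
    found = set()
    for c, pa in parents.items():
        for a, b in combinations(sorted(pa), 2):
            if frozenset((a, b)) not in skel:  # a and b are non-adjacent
                found.add((a, c, b))
    return found


def markov_equivalent(edges1, edges2):
    """Same skeleton and same v-structures <=> same Markov equivalence class."""
    return skeleton(edges1) == skeleton(edges2) and v_structures(edges1) == v_structures(edges2)


chain = [("X", "Y"), ("Y", "Z")]
reversed_chain = [("Z", "Y"), ("Y", "X")]
collider = [("X", "Y"), ("Z", "Y")]
print(markov_equivalent(chain, reversed_chain))  # True
print(markov_equivalent(chain, collider))  # False
```

The chain and its reversal land in one class, so no orientation between them is identifiable from observational data alone. The collider is the case where the data do pin down edge directions.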

Task

Implement a causal discovery algorithm in bench/custom_algorithm.py. Your run_causal_discovery(X) function receives integer-encoded discrete observational data and must return the estimated CPDAG as a causallearn.graph.GeneralGraph.GeneralGraph object.

Interface

import numpy as np
from causallearn.graph.GeneralGraph import GeneralGraph

def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
    """
    Input:  X of shape (n_samples, n_variables), integer-encoded discrete data
    Output: estimated CPDAG as causallearn.graph.GeneralGraph.GeneralGraph
    """
    ...

Evaluation Scenarios

Small Networks (<20 nodes)

Label       Network     Nodes  Edges  Samples  Domain
Cancer      Cancer      5      4      500      Medical
Earthquake  Earthquake  5      4      500      Seismology
Survey      Survey      6      6      500      Social science
Asia        Asia        8      8      1000     Medical (lung diseases)
Sachs       Sachs       11     17     1000     Biology (protein signaling)

Medium Networks (20–50 nodes)

Label       Network     Nodes  Edges  Samples  Domain
Child       Child       20     25     2000     Medical
Insurance   Insurance   27     52     5000     Automotive insurance
Water       Water       32     66     5000     Water treatment
Mildew      Mildew      35     46     5000     Agriculture (crop disease)
Alarm       Alarm       37     46     5000     Medical monitoring
Barley      Barley      48     84     10000    Agriculture

Large Networks (50–100 nodes)

Label       Network     Nodes  Edges  Samples  Domain
Hailfinder  Hailfinder  56     66     10000    Meteorology
Hepar2      Hepar2      70     123    10000    Medical (liver disorders)
Win95pts    Win95pts    76     112    10000    IT (Windows troubleshooting)

Metrics

Metrics are computed between the estimated CPDAG and the ground-truth CPDAG (obtained from the true DAG via dag2cpdag):

  • SHD (Structural Hamming Distance): total edge err
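For orientation, here is a simplified numpy sketch of SHD and of adjacency and arrowhead precision/recall on causallearn-style endpoint matrices. In that encoding, graph[i, j] == -1 with graph[j, i] == 1 means i --> j, and both entries equal to -1 means i --- j. The per-pair SHD convention and the zero-denominator handling are assumptions, and the sketch may not match the benchmark's exact implementation. The edge_matrix helper is hypothetical.

```python
import numpy as np

# Assumed encoding (causallearn GeneralGraph.graph convention):
#   A[i, j] == -1 and A[j, i] == 1   ->  i --> j
#   A[i, j] == -1 and A[j, i] == -1  ->  i --- j
#   A[i, j] == 0 and A[j, i] == 0    ->  no edge


def adjacency(A):
    # Adjacent pairs, ignoring endpoint marks (symmetric mask).
    return (A != 0) | (A.T != 0)


def arrowheads(A):
    # H[i, j] is True when the i-j edge has an arrowhead at j (A[j, i] == 1).
    return A.T == 1


def shd(est, true):
    """One common convention: +1 per node pair whose edge (or its marks) differs."""
    p = est.shape[0]
    return sum(
        (est[i, j], est[j, i]) != (true[i, j], true[j, i])
        for i in range(p)
        for j in range(i + 1, p)
    )


def precision_recall(est_mask, true_mask):
    tp = int((est_mask & true_mask).sum())
    # Assumption: report 0.0 when a denominator is zero.
    prec = tp / est_mask.sum() if est_mask.sum() else 0.0
    rec = tp / true_mask.sum() if true_mask.sum() else 0.0
    return float(prec), float(rec)


def edge_matrix(p, directed=(), undirected=()):
    # Hypothetical helper for building small test graphs.
    A = np.zeros((p, p), dtype=int)
    for i, j in directed:
        A[i, j], A[j, i] = -1, 1
    for i, j in undirected:
        A[i, j] = A[j, i] = -1
    return A


true = edge_matrix(3, directed=[(0, 1)], undirected=[(1, 2)])
est = edge_matrix(3, directed=[(0, 1), (2, 1)])
print(shd(est, true))  # 1
print(precision_recall(adjacency(est), adjacency(true)))  # (1.0, 1.0)
print(precision_recall(arrowheads(est), arrowheads(true)))  # (0.5, 1.0)
```

In this example, orienting the undirected edge 1 --- 2 as 2 --> 1 leaves the skeleton intact (perfect adjacency scores). It adds one spurious arrowhead, which costs one SHD unit and halves arrowhead precision.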
Assistant (Step 1): edit
Tool Input
1{
2  "op": "replace",
3  "filename": "causal-bnlearn/bench/custom_algorithm.py",
4  "content": "import numpy as np\nfrom causallearn.graph.GeneralGraph import GeneralGraph\nfrom causallearn.graph.GraphNode import GraphNode\n\n# =====================================================================\n# EDITABLE: implement run_causal_discovery below\n# =====================================================================\ndef run_causal_discovery(X: np.ndarray) -> GeneralGraph:\n    \"\"\"Adaptive blanket-validated hybrid causal discovery for discrete data.\"\"\"\n    from causallearn.search.ScoreBased.GES import ges\n    from causallearn.search.ConstraintBased.PC import pc\n    n, p = X.shape\n    nodes = [GraphNode(f\"X{i + 1}\") for i in range(p)]\n    cards = np.array([len(np.unique(X[:, i])) for i in range(p)], dtype=int)\n    rs = np.random.RandomState(0)\n    idx = rs.permutation(n)\n    cut = max(1, n - max(64, n // 5))\n    tr, va = X[idx[:cut]], X[idx[cut:]]\n    if len(va) == 0:\n        tr, va = X[:-1], X[-1:]\n\n    def mi(a, b, ka, kb):\n        c = np.bincount((a.astype(int) * kb + b.astype(int)), minlength=ka * kb).reshape(ka, kb).astype(float)\n        nn, pa, pb = c.sum(), c.sum(1, keepdims=True), c.sum(0, keepdims=True)\n        ex = pa @ pb\n        nz = c > 0\n        return float((c[nz] / nn * (np.log(c[nz] * nn) - np.log(ex[nz]))).sum())\n\n    MI = np.zeros((p, p))\n    for i in range(p):\n        for j in range(i + 1, p):\n            MI[i, j] = MI[j, i] = mi(tr[:, i], tr[:, j], int(cards[i]), int(cards[j]))\n    marg = [np.bincount(tr[:, i].astype(int), minlength=int(cards[i])).astype(float) for i in range(p)]\n    state_cap = int(max(24, min(256, 2 * np.sqrt(len(tr)))))\n\n    def norm_graph(g):\n        return g[\"G\"] if isinstance(g, dict) else getattr(g, \"G\", g)\n\n    def edge_count(G):\n        A = np.asarray(G.graph)\n        S = ((A != 0) | (A.T != 0))\n        return int(np.count_nonzero(np.triu(S, 1)))\n\n    def score_graph(G):\n        A = np.asarray(G.graph)\n        S = ((A != 0) | (A.T != 0))\n        ll 
= 0.0\n        for i in range(p):\n            nbr = np.where(S[i])[0]\n            if len(nbr):\n                order = nbr[np.argsort(MI[i, nbr])[::-1]]\n                sel, prod = [], 1\n                for j in order:\n                    if len(sel) >= 4:\n                        break\n                    cj = int(cards[j])\n                    if prod * cj <= state_cap or not sel:\n                        sel.append(int(j))\n                        prod *= cj\n                nbr = np.array(sel, dtype=int)\n            base = marg[i]\n            if len(nbr) == 0:\n                prob = (base + 0.5) / (base.sum() + 0.5 * cards[i])\n                ll += np.log(prob[va[:, i].astype(int)]).sum()\n                continue\n            mult = np.cumprod(np.r_[1, cards[nbr[:-1]]]).astype(int)\n            ktr = (tr[:, nbr].astype(int) * mult).sum(1)\n            kva = (va[:, nbr].astype(int) * mult).sum(1)\n            R = int(np.prod(cards[nbr]))\n            joint = np.bincount(ktr * int(cards[i]) + tr[:, i].astype(int), minlength=R * int(cards[i])).reshape(R, int(cards[i])).astype(float)\n            den = joint.sum(1)\n            num = joint[kva, va[:, i].astype(int)]\n            seen = den[kva] > 0\n            prob = np.empty(len(va), float)\n            prob[seen] = (num[seen] + 0.5) / (den[kva][seen] + 0.5 * cards[i])\n            prob[~seen] = (base[va[:, i].astype(int)][~seen] + 0.5) / (base.sum() + 0.5 * cards[i])\n            ll += np.log(prob).sum()\n        return ll / max(1, len(va)) - 0.02 * edge_count(G) / max(1, p)\n\n    avg_card = float(cards.mean())\n    cand_specs = [(\"ges\", lambda: norm_graph(ges(X, score_func=\"local_score_BDeu\")))]\n    if p <= 14:\n        from causallearn.search.PermutationBased.BOSS import boss\n        cand_specs.append((\"boss\", lambda: norm_graph(boss(X, score_func=\"local_score_BDeu\"))))\n    if p <= 35:\n        from causallearn.search.PermutationBased.GRaSP import grasp\n        
cand_specs.append((\"grasp\", lambda: norm_graph(grasp(X, score_func=\"local_score_BDeu\"))))\n    pc_tests = [\"chisq\"] + ([\"gsq\"] if (avg_card > 3 or p >= 30) else [])\n    for test_name in pc_tests:\n        cand_specs.append((f\"pc_{test_name}\", lambda t=test_name: norm_graph(pc(X, alpha=0.05, indep_test=t, stable=True, show_progress=False))))\n\n    scored = []\n    for name, fn in cand_specs:\n        try:\n            G = fn()\n            scored.append((score_graph(G), -edge_count(G), name, G))\n        except Exception:\n            pass\n    if not scored:\n        try:\n            return norm_graph(ges(X, score_func=\"local_score_BDeu\"))\n        except Exception:\n            return GeneralGraph(nodes)\n    scored.sort(reverse=True)\n    best = scored[0]\n    for cand in scored[1:]:\n        if best[0] - cand[0] < 0.03 and cand[1] > best[1]:\n            best = cand\n    return best[-1]\n# =====================================================================\n",
5  "after_line": 14,
6  "start_line": 3,
7  "end_line": 14
8}
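The score_graph helper in this first edit ranks each candidate graph by predictive log-likelihood on a hold-out split of max(64, n // 5) rows. For each node, it conditions on at most four graph neighbors ranked by mutual information, subject to a joint-state cap. It then subtracts a small edge penalty and, among near-ties, prefers the sparser candidate. The per-node piece of that computation is restated below as a standalone function. The name heldout_node_ll and its signature are hypothetical; the add-0.5 smoothing and the fallback to the marginal for unseen parent configurations mirror the agent's code.

```python
import numpy as np


def heldout_node_ll(train, valid, i, cond, cards, alpha=0.5):
    """Mean held-out log-likelihood of column i given the columns in `cond`.

    Counts come from `train`; each validation row is scored with an
    add-alpha smoothed conditional distribution. Conditioning configurations
    never seen in training fall back to the smoothed marginal of column i.
    """
    k = int(cards[i])
    base = np.bincount(train[:, i], minlength=k).astype(float)
    marg = (base + alpha) / (base.sum() + alpha * k)
    yv = valid[:, i]
    if len(cond) == 0:
        return float(np.log(marg[yv]).mean())
    c = np.asarray(cond, dtype=int)
    # Mixed-radix encoding: one integer id per joint configuration of cond.
    mult = np.cumprod(np.r_[1, cards[c[:-1]]]).astype(int)
    ktr = (train[:, c] * mult).sum(1)
    kva = (valid[:, c] * mult).sum(1)
    R = int(np.prod(cards[c]))
    joint = np.bincount(ktr * k + train[:, i], minlength=R * k).reshape(R, k).astype(float)
    den = joint.sum(1)[kva]
    num = joint[kva, yv]
    prob = np.where(den > 0, (num + alpha) / (den + alpha * k), marg[yv])
    return float(np.log(prob).mean())


cards = np.array([2, 2])
train = np.array([[0, 0], [1, 1], [0, 0], [1, 1]])
valid = np.array([[0, 0], [1, 1]])
print(heldout_node_ll(train, valid, 1, [], cards))   # log(0.5), about -0.693
print(heldout_node_ll(train, valid, 1, [0], cards))  # log(2.5 / 3), about -0.182
```

One design caveat follows from the code itself. The candidate searches (ges, boss, grasp, pc) all run on the full X, so the validation rows are held out only from CPT estimation, not from structure learning.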
Tool Result (Step 1)
OK: Replaced lines 3..14 with 110 line(s) in causal-bnlearn/bench/custom_algorithm.py. Editable range: 3–112.

[Current file: causal-bnlearn/bench/custom_algorithm.py | editable: 3–112 | total: 113 lines]
     1| import numpy as np
     2| from causallearn.graph.GeneralGraph import GeneralGraph
     3  import numpy as np
     4  from causallearn.graph.GeneralGraph import GeneralGraph
     5  from causallearn.graph.GraphNode import GraphNode
     6  
     7  # =====================================================================
     8  # EDITABLE: implement run_causal_discovery below
     9  # =====================================================================
    10  def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
    11      """Adaptive blanket-validated hybrid causal discovery for discrete data."""
    12      from causallearn.search.ScoreBased.GES import ges
    13      from causallearn.search.ConstraintBased.PC import pc
    14      n, p = X.shape
    15      nodes = [G
Assistant (Step 2): test
Tool Result (Step 2)
[Test #1] (2 tests remaining; call submit(n=1) to submit this result)

### Cancer (scripts/eval_cancer.sh)
Results (Cancer):
  SHD=3  AdjP=1.0000 AdjR=0.5000  ArrowP=0.5000 ArrowR=0.2500

### Earthquake (scripts/eval_earthquake.sh)
Results (Earthquake):
  SHD=1  AdjP=0.8000 AdjR=1.0000  ArrowP=0.8000 ArrowR=1.0000

### Survey (scripts/eval_survey.sh)
Results (Survey):
  SHD=7  AdjP=0.6667 AdjR=0.3333  ArrowP=0.0000 ArrowR=0.0000

### Asia (scripts/eval_asia.sh)
Results (Asia):
  SHD=1  AdjP=1.0000 AdjR=0.8750  ArrowP=1.0000 ArrowR=1.0000

### Sachs (scripts/eval_sachs.sh)
Results (Sachs):
  SHD=14  AdjP=0.9333 AdjR=0.8235  ArrowP=0.0000 ArrowR=0.0000

### Child (scripts/eval_child.sh)
Results (Child):
  SHD=11  AdjP=1.0000 AdjR=0.8800  ArrowP=0.5556 ArrowR=0.7692

### Insurance (scripts/eval_insurance.sh)
Results (Insurance):
  SHD=27  AdjP=0.9000 AdjR=0.6923  ArrowP=0.7059 ArrowR=0.7059

### Water (scripts/eval_water.sh)
Results (Water):
  SHD=50  AdjP=0.9630 AdjR=0.3939  ArrowP=0.785
Assistant (Step 3): edit
Tool Input
1{
2  "op": "replace",
3  "filename": "causal-bnlearn/bench/custom_algorithm.py",
4  "content": "import numpy as np\nfrom causallearn.graph.GeneralGraph import GeneralGraph\nfrom causallearn.graph.GraphNode import GraphNode\n\n# =====================================================================\n# EDITABLE: implement run_causal_discovery below\n# =====================================================================\ndef run_causal_discovery(X: np.ndarray) -> GeneralGraph:\n    \"\"\"Direction-aware consensus selection over discrete causal discovery candidates.\"\"\"\n    from causallearn.search.ScoreBased.GES import ges\n    from causallearn.search.ConstraintBased.PC import pc\n\n    n, p = X.shape\n    nodes = [GraphNode(f\"X{i + 1}\") for i in range(p)]\n    cards = np.array([len(np.unique(X[:, i])) for i in range(p)], dtype=int)\n    rs = np.random.RandomState(0)\n    idx = rs.permutation(n)\n    cut = max(1, n - max(64, n // 5))\n    tr, va = X[idx[:cut]], X[idx[cut:]]\n    if len(va) == 0:\n        tr, va = X[:-1], X[-1:]\n\n    def norm_graph(g):\n        return g[\"G\"] if isinstance(g, dict) else getattr(g, \"G\", g)\n\n    def mi(a, b, ka, kb):\n        c = np.bincount(a.astype(int) * kb + b.astype(int), minlength=ka * kb).reshape(ka, kb).astype(float)\n        nn, pa, pb = c.sum(), c.sum(1, keepdims=True), c.sum(0, keepdims=True)\n        ex = pa @ pb\n        nz = c > 0\n        return float((c[nz] / nn * (np.log(c[nz] * nn) - np.log(ex[nz]))).sum())\n\n    MI = np.zeros((p, p), float)\n    for i in range(p):\n        for j in range(i + 1, p):\n            MI[i, j] = MI[j, i] = mi(tr[:, i], tr[:, j], int(cards[i]), int(cards[j]))\n    marg = [np.bincount(tr[:, i].astype(int), minlength=int(cards[i])).astype(float) for i in range(p)]\n    state_cap = int(max(24, min(320, 3 * np.sqrt(len(tr)))))\n    ll_cache = {}\n\n    def masks(G):\n        A = np.asarray(G.graph)\n        D = (A == -1) & (A.T == 1)  # i -> j\n        S = (A != 0) | (A.T != 0)\n        U = S & ~D & ~D.T\n        return S, D, U\n\n    def edge_count(S):\n        
return int(np.count_nonzero(np.triu(S, 1)))\n\n    def dir_count(D):\n        return int(D.sum())\n\n    def pick_cond(i, parents=(), neighbors=()):\n        used, sel, prod = set(), [], 1\n        for group, quota in ((parents, 3), (neighbors, 2)):\n            order = sorted({int(j) for j in group if j != i and int(j) not in used}, key=lambda j: MI[i, j], reverse=True)\n            taken = 0\n            for j in order:\n                cj = int(cards[j])\n                if len(sel) >= 5:\n                    break\n                if prod * cj <= state_cap or not sel or (group is parents and taken < quota):\n                    sel.append(j)\n                    used.add(j)\n                    prod *= cj\n                    taken += 1\n        return tuple(sorted(sel))\n\n    def node_ll(i, cond):\n        key = (int(i), tuple(sorted(int(j) for j in cond if j != i)))\n        if key in ll_cache:\n            return ll_cache[key]\n        base = marg[i]\n        if len(key[1]) == 0:\n            prob = (base + 0.5) / (base.sum() + 0.5 * cards[i])\n            val = float(np.log(prob[va[:, i].astype(int)]).sum() / max(1, len(va)))\n        else:\n            c = np.array(key[1], dtype=int)\n            mult = np.cumprod(np.r_[1, cards[c[:-1]]]).astype(int)\n            ktr = (tr[:, c].astype(int) * mult).sum(1)\n            kva = (va[:, c].astype(int) * mult).sum(1)\n            R = int(np.prod(cards[c]))\n            joint = np.bincount(ktr * int(cards[i]) + tr[:, i].astype(int), minlength=R * int(cards[i])).reshape(R, int(cards[i])).astype(float)\n            den = joint.sum(1)\n            num = joint[kva, va[:, i].astype(int)]\n            seen = den[kva] > 0\n            prob = np.empty(len(va), float)\n            prob[seen] = (num[seen] + 0.5) / (den[kva][seen] + 0.5 * cards[i])\n            prob[~seen] = (base[va[:, i].astype(int)][~seen] + 0.5) / (base.sum() + 0.5 * cards[i])\n            val = float(np.log(prob).sum() / max(1, len(va)))\n        
ll_cache[key] = val\n        return val\n\n    def factor_score(S, D):\n        U = S & ~D & ~D.T\n        total = 0.0\n        for i in range(p):\n            pa = np.where(D[:, i])[0]\n            un = np.where(U[:, i])[0]\n            total += node_ll(i, pick_cond(i, pa, un))\n        bonus = 0.0\n        ii, jj = np.where(D)\n        for a, b in zip(ii.tolist(), jj.tolist()):\n            pa_b = np.where(D[:, b])[0]\n            un_b = np.where(U[:, b])[0]\n            gain = node_ll(b, pick_cond(b, pa_b, un_b)) - node_ll(b, pick_cond(b, [k for k in pa_b if k != a], un_b))\n            pa_a = np.where(D[:, a])[0]\n            un_a = np.where(U[:, a])[0]\n            rev = node_ll(a, pick_cond(a, pa_a, un_a)) - node_ll(a, pick_cond(a, [*pa_a, b], un_a))\n            bonus += max(-0.02, min(0.02, gain + rev))\n        return total + 0.35 * bonus - 0.015 * edge_count(S) / max(1, p)\n\n    avg_card = float(cards.mean())\n    cand_specs = [(\"ges\", lambda: norm_graph(ges(X, score_func=\"local_score_BDeu\")))]\n    if p <= 28:\n        from causallearn.search.PermutationBased.BOSS import boss\n        cand_specs.append((\"boss\", lambda: norm_graph(boss(X, score_func=\"local_score_BDeu\"))))\n    if p <= 30:\n        from causallearn.search.PermutationBased.GRaSP import grasp\n        cand_specs.append((\"grasp\", lambda: norm_graph(grasp(X, score_func=\"local_score_BDeu\"))))\n    pc_tests = [\"chisq\", \"gsq\"] if p <= 15 else ([\"gsq\"] if (avg_card > 2.5 or p >= 30) else [\"chisq\"])\n    for test_name in pc_tests:\n        cand_specs.append((f\"pc_{test_name}\", lambda t=test_name: norm_graph(pc(X, alpha=0.05, indep_test=t, stable=True, show_progress=False))))\n\n    cands = []\n    for name, fn in cand_specs:\n        try:\n            G = fn()\n            S, D, _ = masks(G)\n            pred = max(factor_score(S, D), factor_score(S, D.T))\n            cands.append({\"name\": name, \"G\": G, \"S\": S, \"D\": D, \"pred\": pred, \"e\": edge_count(S), \"d\": 
dir_count(D)})\n        except Exception:\n            pass\n    if not cands:\n        try:\n            return norm_graph(ges(X, score_func=\"local_score_BDeu\"))\n        except Exception:\n            return GeneralGraph(nodes)\n\n    sup = sum(c[\"S\"].astype(float) for c in cands) / len(cands)\n    dsup = sum(c[\"D\"].astype(float) for c in cands) / len(cands)\n    for c in cands:\n        agr = float(sup[c[\"S\"]].mean()) if c[\"e\"] else 0.0\n        dagr = float(dsup[c[\"D\"]].mean()) if c[\"d\"] else 0.0\n        decis = c[\"d\"] / max(1, c[\"e\"])\n        bias = 0.008 * decis if (p <= 30 and c[\"name\"] in {\"boss\", \"grasp\", \"ges\"}) else 0.0\n        c[\"score\"] = c[\"pred\"] + (0.035 if p <= 35 else 0.02) * agr + 0.015 * dagr + bias\n\n    cands.sort(key=lambda c: (c[\"score\"], c[\"pred\"], -c[\"e\"]), reverse=True)\n    best = cands[0]\n    if p <= 30:\n        perms = [c for c in cands if c[\"name\"] in {\"boss\", \"grasp\"}]\n        if perms:\n            alt = max(perms, key=lambda c: c[\"score\"])\n            if best[\"score\"] - alt[\"score\"] < 0.015 and alt[\"d\"] >= best[\"d\"]:\n                best = alt\n    return best[\"G\"]\n# =====================================================================\n",
5  "after_line": 112,
6  "start_line": 3,
7  "end_line": 112
8}
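This second edit adds a consensus term. Each candidate's predicted score is raised according to how often its adjacencies, and with a smaller weight its directed edges, also appear across all candidates. In the agent's code the adjacency weight is 0.035 when p <= 35 and 0.02 otherwise, and the directed-edge weight is 0.015. The sketch below computes only those two agreement statistics, using causallearn's endpoint encoding. The function names are hypothetical, and the weighting is left to the caller.

```python
import numpy as np


def skeleton_mask(A):
    # Adjacent pairs, ignoring endpoint marks.
    return (A != 0) | (A.T != 0)


def directed_mask(A):
    # D[i, j] is True for i --> j (A[i, j] == -1 and A[j, i] == 1).
    return (A == -1) & (A.T == 1)


def consensus_scores(mats):
    """Per candidate: mean support of its adjacencies and of its directed edges.

    Support of a pair is the fraction of candidates (including the one being
    scored) that contain it.
    """
    S = [skeleton_mask(A) for A in mats]
    D = [directed_mask(A) for A in mats]
    sup = np.mean(S, axis=0)
    dsup = np.mean(D, axis=0)
    agr = [float(sup[s].mean()) if s.any() else 0.0 for s in S]
    dagr = [float(dsup[d].mean()) if d.any() else 0.0 for d in D]
    return agr, dagr


def edge_matrix(p, directed=(), undirected=()):
    # Hypothetical helper for building small test graphs.
    A = np.zeros((p, p), dtype=int)
    for i, j in directed:
        A[i, j], A[j, i] = -1, 1
    for i, j in undirected:
        A[i, j] = A[j, i] = -1
    return A


A1 = edge_matrix(3, directed=[(0, 1), (1, 2)])
A2 = edge_matrix(3, directed=[(0, 1)])
A3 = edge_matrix(3, directed=[(1, 2)], undirected=[(0, 1)])
agr, dagr = consensus_scores([A1, A2, A3])
print(agr)   # approximately [0.833, 1.0, 0.833]
print(dagr)  # approximately [0.667, 0.667, 0.667]
```

Because support is averaged over a pool that includes the candidate being scored, every candidate votes for its own edges. With only three to five candidates, that self-vote is a sizable share of the agreement term.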
Tool Result (Step 3)
OK: Replaced lines 3..112 with 159 line(s) in causal-bnlearn/bench/custom_algorithm.py. Editable range: 3–161.

[Current file: causal-bnlearn/bench/custom_algorithm.py | editable: 3–161 | total: 162 lines]
     1| import numpy as np
     2| from causallearn.graph.GeneralGraph import GeneralGraph
     3  import numpy as np
     4  from causallearn.graph.GeneralGraph import GeneralGraph
     5  from causallearn.graph.GraphNode import GraphNode
     6  
     7  # =====================================================================
     8  # EDITABLE: implement run_causal_discovery below
     9  # =====================================================================
    10  def run_causal_discovery(X: np.ndarray) -> GeneralGraph:
    11      """Direction-aware consensus selection over discrete causal discovery candidates."""
    12      from causallearn.search.ScoreBased.GES import ges
    13      from causallearn.search.ConstraintBased.PC import pc
    14  
    15      n, p = X.shape
   
Assistant (Step 4): test
Tool Result (Step 4)
[Test #2] (1 test remaining; call submit(n=2) to submit this result)

### Cancer (scripts/eval_cancer.sh)
Results (Cancer):
  SHD=4  AdjP=1.0000 AdjR=0.5000  ArrowP=0.0000 ArrowR=0.0000

### Earthquake (scripts/eval_earthquake.sh)
Results (Earthquake):
  SHD=2  AdjP=0.7500 AdjR=0.7500  ArrowP=1.0000 ArrowR=0.7500

### Survey (scripts/eval_survey.sh)
Results (Survey):
  SHD=6  AdjP=1.0000 AdjR=0.1667  ArrowP=0.0000 ArrowR=0.0000

### Asia (scripts/eval_asia.sh)
Results (Asia):
  SHD=6  AdjP=0.7778 AdjR=0.8750  ArrowP=1.0000 ArrowR=0.4000

### Sachs (scripts/eval_sachs.sh)
Results (Sachs):
  SHD=1  AdjP=1.0000 AdjR=0.9412  ArrowP=0.0000 ArrowR=0.0000

### Child (scripts/eval_child.sh)
Results (Child):
  SHD=8  AdjP=1.0000 AdjR=0.8800  ArrowP=1.0000 ArrowR=0.3846

### Insurance (scripts/eval_insurance.sh)
Results (Insurance):
  SHD=31  AdjP=0.9500 AdjR=0.7308  ArrowP=0.7273 ArrowR=0.4706

### Water (scripts/eval_water.sh)
Results (Water):
  SHD=48  AdjP=1.0000 AdjR=0.4242  ArrowP=0.9231 A
Assistant (Step 5): submit
Tool Input
1{
2  "n": 2
3}
Tool Result (Step 5)
[submit] Submitting result from test #2 as final.

### Cancer (scripts/eval_cancer.sh)
Results (Cancer):
  SHD=4  AdjP=1.0000 AdjR=0.5000  ArrowP=0.0000 ArrowR=0.0000

### Earthquake (scripts/eval_earthquake.sh)
Results (Earthquake):
  SHD=2  AdjP=0.7500 AdjR=0.7500  ArrowP=1.0000 ArrowR=0.7500

### Survey (scripts/eval_survey.sh)
Results (Survey):
  SHD=6  AdjP=1.0000 AdjR=0.1667  ArrowP=0.0000 ArrowR=0.0000

### Asia (scripts/eval_asia.sh)
Results (Asia):
  SHD=6  AdjP=0.7778 AdjR=0.8750  ArrowP=1.0000 ArrowR=0.4000

### Sachs (scripts/eval_sachs.sh)
Results (Sachs):
  SHD=1  AdjP=1.0000 AdjR=0.9412  ArrowP=0.0000 ArrowR=0.0000

### Child (scripts/eval_child.sh)
Results (Child):
  SHD=8  AdjP=1.0000 AdjR=0.8800  ArrowP=1.0000 ArrowR=0.3846

### Insurance (scripts/eval_insurance.sh)
Results (Insurance):
  SHD=31  AdjP=0.9500 AdjR=0.7308  ArrowP=0.7273 ArrowR=0.4706

### Water (scripts/eval_water.sh)
Results (Water):
  SHD=48  AdjP=1.0000 AdjR=0.4242  ArrowP=0.9231 ArrowR=0.2000

### M