graph-link-prediction
Description
Task: Graph Link Prediction
Research Question
Design a novel link prediction method for graphs. The goal is to learn an encoder that maps nodes to embeddings and a decoder that scores candidate edges, such that the model accurately predicts missing or future links across diverse graph types.
Background
Link prediction is a fundamental graph learning task: given a partially observed graph, predict which unobserved edges are likely to exist. It has applications in social network analysis (friend recommendation), citation networks (paper recommendation), knowledge graph completion, and biological interaction prediction.
Classical approaches include:
- GCN + dot product: GCN encodes nodes; dot product of embeddings scores edges. Simple but often competitive.
- VGAE (Variational Graph Auto-Encoder): Probabilistic encoder with KL regularization for robust link prediction.
- Node2Vec: Random-walk-based embedding with biased walks capturing structural roles and communities.
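The Node2Vec entry rests on its biased second-order random walk. The sketch below uses a hypothetical helper, `node2vec_walk`, over an adjacency dict. It shows how the return parameter `p` and the in-out parameter `q` reweight each step depending on the previous node. In Node2Vec the resulting walks are then fed to a skip-gram model to learn the embeddings; that stage is not shown here.

```python
import random


def node2vec_walk(adj, start, length, p, q, rng):
    """One biased second-order random walk in the style of Node2Vec.

    adj: dict mapping node -> set of neighbours (unweighted, undirected graph).
    Unnormalised weight for stepping from `cur` to `nxt`, given previous node `prev`:
      1/p  if nxt == prev                 (return to where we came from)
      1    if nxt is also adjacent to prev (stay local, BFS-like)
      1/q  otherwise                       (move outward, DFS-like)
    """
    walk = [start]
    while len(walk) < length:
        cur = walk[-1]
        nbrs = sorted(adj[cur])  # sorted for reproducibility given a seeded rng
        if not nbrs:
            break  # dead end: isolated node
        if len(walk) == 1:
            walk.append(rng.choice(nbrs))  # first step has no "previous" node: uniform
            continue
        prev = walk[-2]
        weights = [
            1.0 / p if n == prev else 1.0 if n in adj[prev] else 1.0 / q
            for n in nbrs
        ]
        walk.append(rng.choices(nbrs, weights=weights, k=1)[0])
    return walk
```

A small `p` makes the walk backtrack often, keeping it near the start node. A small `q` pushes it outward, which tends to emphasise community structure over local structural roles.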
Recent SOTA methods explore richer structural information:
- SEAL: Extracts the k-hop enclosing subgraph around each candidate edge, labels its nodes with the Double-Radius Node Labeling (DRNL) trick, and classifies the edge with a GNN.
- Neo-GNN: Learns neighborhood overlap features from the adjacency matrix to augment GNN predictions.
- BUDDY: Subgraph sketching with MinHash/HyperLogLog features for scalable structural information.
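SEAL's DRNL step can be written as two breadth-first searches. For a target link (x, y), each node's distance to x is computed with y masked out, and its distance to y is computed with x masked out. The two distances are then combined with the DRNL hash 1 + min(dx, dy) + (d//2)·((d//2) + (d%2) − 1), where d = dx + dy. The targets get label 1, and nodes unreachable from either target get 0. The sketch below assumes the enclosing subgraph has already been extracted into an adjacency dict. `bfs_dist` and `drnl_labels` are hypothetical helper names.

```python
from collections import deque


def bfs_dist(adj, source, removed):
    """Hop distances from `source`, treating node `removed` as absent.

    Nodes unreachable from `source` do not appear in the returned dict.
    """
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v != removed and v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def drnl_labels(adj, x, y):
    """Double-Radius Node Labeling for target link (x, y) inside a subgraph."""
    dx = bfs_dist(adj, x, removed=y)  # distance to x with y masked out
    dy = bfs_dist(adj, y, removed=x)  # distance to y with x masked out
    labels = {}
    for node in adj:
        if node in (x, y):
            labels[node] = 1  # the two target nodes
        elif node in dx and node in dy:
            d = dx[node] + dy[node]
            half, rem = divmod(d, 2)
            labels[node] = 1 + min(dx[node], dy[node]) + half * (half + rem - 1)
        else:
            labels[node] = 0  # unreachable from at least one target
    return labels
```

In SEAL these integer labels are typically one-hot encoded and concatenated with the node features before the subgraph GNN runs.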
What to Implement
Implement the LinkPredictor class in custom_linkpred.py. You must implement:
- __init__(self, in_channels, hidden_channels, num_layers, dropout): Set up your model.
- encode(self, x, edge_index) -> Tensor [N, hidden_channels]: Encode nodes into embeddings.
- decode(self, z_src, z_dst) -> Tensor [num_edges]: Score candidate edges from source/destination embeddings.
- forward(self, x, edge_index, edge_label_index) -> Tensor [num_edges]: Full forward pass.
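Here is a minimal sketch of a class that satisfies this interface. It assumes a GCN encoder and, as the decoder, an MLP over the Hadamard (element-wise) product of the two endpoint embeddings. This is one common design, not a prescribed method. To stay self-contained, it hand-rolls symmetric GCN propagation in plain PyTorch through a hypothetical helper, `gcn_propagate`. Inside the template, `GCNConv` from `torch_geometric.nn` would be the idiomatic replacement. Returning raw logits rather than probabilities is also an assumption about what the fixed training loop expects.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def gcn_propagate(x, edge_index, num_nodes):
    """GCN aggregation D^-1/2 (A + I) D^-1/2 x, written with index_add_."""
    loops = torch.arange(num_nodes, device=x.device)
    src = torch.cat([edge_index[0], loops])  # message senders, plus self-loops
    dst = torch.cat([edge_index[1], loops])  # message receivers, plus self-loops
    deg = torch.zeros(num_nodes, dtype=x.dtype, device=x.device)
    deg.index_add_(0, dst, torch.ones_like(dst, dtype=x.dtype))  # degree incl. self-loop
    norm = deg[src].rsqrt() * deg[dst].rsqrt()  # symmetric normalisation per edge
    out = torch.zeros_like(x)
    out.index_add_(0, dst, x[src] * norm.unsqueeze(-1))  # sum normalised messages
    return out


class LinkPredictor(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers, dropout):
        super().__init__()
        dims = [in_channels] + [hidden_channels] * num_layers
        self.lins = nn.ModuleList(
            nn.Linear(a, b) for a, b in zip(dims[:-1], dims[1:])
        )
        self.dropout = dropout
        # Decoder (an assumption): MLP applied to z_src * z_dst.
        self.scorer = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 1),
        )

    def encode(self, x, edge_index):
        num_nodes = x.size(0)
        for i, lin in enumerate(self.lins):
            x = gcn_propagate(lin(x), edge_index, num_nodes)
            if i < len(self.lins) - 1:  # no activation after the last layer
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x  # [N, hidden_channels]

    def decode(self, z_src, z_dst):
        return self.scorer(z_src * z_dst).squeeze(-1)  # [num_edges] raw logits

    def forward(self, x, edge_index, edge_label_index):
        z = self.encode(x, edge_index)
        return self.decode(z[edge_label_index[0]], z[edge_label_index[1]])
```

The Hadamard-product MLP strictly generalises the plain dot product, because summing `z_src * z_dst` is the dot product. It can still be evaluated edge-by-edge on any candidate set.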
Input Format
- x: Node features [N, in_channels]; the feature dimension varies by dataset.
- edge_index: Training graph edges [2, E_train] in COO format (undirected).
- edge_label_index: Candidate edges to score [2, num_candidates].
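The sketch below shows the shapes involved. An undirected graph in COO form stores each edge in both directions, and a candidate set for `edge_label_index` can mix held-out positives with sampled non-edges. The template's own negative sampling lives in its fixed section and is not shown. The naive uniform sampler here, `sample_negatives`, is a hypothetical stand-in, as is `to_undirected_coo`. In practice, `torch_geometric.utils.to_undirected` and `negative_sampling` cover the same ground.

```python
import torch


def to_undirected_coo(edge_list):
    """Turn (u, v) pairs into a [2, 2E] COO tensor holding both directions."""
    ei = torch.tensor(edge_list, dtype=torch.long).t()  # [2, E]
    return torch.cat([ei, ei.flip(0)], dim=1)  # append (v, u) for every (u, v)


def sample_negatives(edge_index, num_nodes, num_samples, generator=None):
    """Naive uniform negative sampling: random pairs that are neither edges nor self-loops.

    Loops forever if fewer than `num_samples` non-edges exist; fine for sparse graphs.
    """
    existing = set(map(tuple, edge_index.t().tolist()))
    negatives = []
    while len(negatives) < num_samples:
        u, v = torch.randint(num_nodes, (2,), generator=generator).tolist()
        if u != v and (u, v) not in existing:
            negatives.append((u, v))
    return torch.tensor(negatives, dtype=torch.long).t()  # [2, num_samples]
```

A candidate tensor could then be formed as `torch.cat([pos_edges, neg_edges], dim=1)`, with labels of 1 for the positive columns and 0 for the negative ones.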
Available PyG Modules (pre-installed)
You can use any PyTorch Geometric module: GCNConv, SAGEConv, GATConv, GINConv, GraphConv, MessagePassing, global pooling, etc. Also available: torch_geometric.utils (negative_sampling, to_undirected, degree, etc.), torch_geometric.nn, and torch_geometric.transforms.
Evaluation
The model is tested on 3 benchmark datasets:
Citation Networks (metrics: AUC, MRR, Hits@20; higher is better):
- Cora: 2,708 nodes, 10,556 edges, 1,433 features, 7 classes. 85/5/10 train/validation/test edge split.
- CiteSeer: 3,327 nodes, 9,104 edges, 3,703 features, 6 classes. 85/5/10 train/validation/test edge split.
Collaboration Network (metrics: Hits@50, MRR; higher is better):
- ogbl-collab: 235,868 nodes, 1,285,465 edges, 128 features. Official OGB split.
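The three metrics can be sketched from lists of positive and negative scores. The exact conventions in the template's fixed evaluation code are not shown here, so the following are assumptions. Hits@K follows the OGB style: a positive counts as a hit if it scores strictly above the K-th highest negative. MRR ranks each positive against the shared pool of negatives and counts ties against the positive. AUC is the probability that a random positive outscores a random negative. Function names are hypothetical.

```python
def auc(pos, neg):
    """ROC-AUC as P(random positive > random negative); ties count 0.5."""
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


def hits_at_k(pos, neg, k):
    """OGB-style Hits@K: fraction of positives strictly above the K-th best negative."""
    if len(neg) < k:
        return 1.0  # fewer than K negatives: every positive is trivially in the top K
    threshold = sorted(neg, reverse=True)[k - 1]
    return sum(p > threshold for p in pos) / len(pos)


def mrr(pos, neg):
    """Mean reciprocal rank of each positive among the shared negatives (pessimistic ties)."""
    total = 0.0
    for p in pos:
        rank = 1 + sum(n >= p for n in neg)
        total += 1.0 / rank
    return total / len(pos)
```

The pairwise AUC loop is quadratic. That is fine for illustration, but at the scale of ogbl-collab a rank-based formulation would be used instead.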
Editable Region
Lines 122-196 of custom_linkpred.py are editable (between EDITABLE SECTION START and EDITABLE SECTION END markers). You may define the LinkPredictor class and any helper functions/classes within this region. The class must conform to the interface above.
Code
```python
"""
Graph Link Prediction — Self-contained template.
Predicts missing links in graphs using learned node representations and a
link scoring function. Evaluated on citation networks (Cora, CiteSeer) and
a collaboration network (ogbl-collab).

Structure:
Lines 1-121: FIXED — Imports, data loading, negative sampling, evaluation
Lines 122-196: EDITABLE — LinkPredictor class (model + scoring)
Lines 197+: FIXED — Training loop, metric computation, CLI
"""
import os
import sys
import math
import argparse
```
Results
| Model | Type | AUC Cora ↑ | MRR Cora ↑ | Hits@20 Cora ↑ | AUC CiteSeer ↑ | MRR CiteSeer ↑ | Hits@20 CiteSeer ↑ |
|---|---|---|---|---|---|---|---|
| buddy | baseline | 90.940 | 32.053 | 67.107 | 87.767 | 30.757 | 64.173 |
| buddy | baseline | 90.530 | 30.573 | 64.197 | 87.573 | 41.017 | 63.150 |
| buddy | baseline | 90.847 | 37.987 | 63.443 | 87.750 | 44.353 | 63.667 |
| buddy | baseline | 89.877 | 25.337 | 64.960 | 88.207 | 39.430 | 63.663 |
| gcn_dot | baseline | 78.903 | 16.677 | 43.137 | 82.713 | 22.010 | 55.897 |
| gcn_dot | baseline | 89.437 | 27.710 | 66.853 | 89.600 | 44.167 | 72.750 |
| neo_gnn | baseline | 88.723 | 37.700 | 68.060 | 90.273 | 49.267 | 73.923 |
| neo_gnn | baseline | 87.667 | 26.233 | 60.403 | 88.137 | 41.240 | 60.587 |
| neo_gnn | baseline | 86.660 | 25.913 | 58.190 | 88.200 | 31.213 | 61.023 |
| neo_gnn | baseline | 86.453 | 30.963 | 59.393 | 87.710 | 38.313 | 60.147 |
| neo_gnn | baseline | 86.663 | 26.843 | 58.130 | 87.233 | 32.487 | 59.707 |
| node2vec | baseline | 68.103 | 5.623 | 16.193 | 64.893 | 6.497 | 22.563 |
| seal | baseline | 92.980 | 30.547 | 65.277 | 92.537 | 33.963 | 72.747 |
| vgae | baseline | 84.797 | 13.583 | 41.177 | 83.733 | 22.317 | 39.633 |
| vgae | baseline | 84.587 | 17.297 | 44.657 | 84.890 | 19.140 | 43.590 |
| vgae | baseline | - | - | - | - | - | - |