graph-node-classification
Description
Graph Neural Network: Node Classification Message Passing
Research Question
Design a novel message passing mechanism for graph neural networks that improves node classification performance across citation network benchmarks.
Background
Graph Neural Networks (GNNs) learn node representations by iteratively aggregating information from neighboring nodes through message passing. The core design choices are:
- Message construction: how to compute messages from source to target nodes (e.g., linear transform, attention-weighted, edge-conditioned)
- Aggregation: how to combine incoming messages (e.g., sum, mean, max, attention-weighted)
- Update: how to integrate aggregated messages with the node's own representation (e.g., residual, gated, concatenation)
Classic approaches include GCN (symmetric normalization), GAT (attention-based weighting), and GraphSAGE (mean aggregation with self/neighbor separation). Recent advances like Graph Transformers (GPS) combine local message passing with global self-attention, while methods like NAGphormer use multi-hop tokenization with Transformer encoders.
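The three design choices above (message construction, aggregation, update) can be sketched in a few lines of plain numpy. This is a toy illustration only, not tied to PyG: the 4-node graph, the random weight matrix, the mean aggregation, and the tanh residual update are all assumptions chosen for brevity.

```python
import numpy as np

# Toy sketch of one message-passing round (assumed 4-node graph with
# bidirectional path edges; not any specific library's API).
rng = np.random.default_rng(0)
num_nodes, dim = 4, 8
x = rng.standard_normal((num_nodes, dim))
# edge_index[0] = source nodes, edge_index[1] = target nodes (directed edges)
edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])
W = rng.standard_normal((dim, dim))

# 1) Message construction: linear transform of each source node's features
messages = x[edge_index[0]] @ W            # [num_edges, dim]

# 2) Aggregation: mean of incoming messages per target node
agg = np.zeros((num_nodes, dim))
np.add.at(agg, edge_index[1], messages)
deg = np.bincount(edge_index[1], minlength=num_nodes).clip(min=1)
agg /= deg[:, None]

# 3) Update: residual combination with the node's own representation
x_new = x + np.tanh(agg)
print(x_new.shape)  # (4, 8)
```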
Task
Modify the `CustomMessagePassingLayer` class and the `CustomGNN` model in `custom_nodecls.py` to implement a novel message passing mechanism. Your implementation must work within PyTorch Geometric's `MessagePassing` framework.
Model Interface
```python
class CustomMessagePassingLayer(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        # Define learnable parameters and layers
        ...

    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        # x: [num_nodes, in_channels], edge_index: [2, num_edges]
        # Returns: [num_nodes, out_channels]
        ...

    def message(self, x_j: Tensor, ...) -> Tensor:
        # Define per-edge message computation
        ...


class CustomGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=2, dropout=0.5):
        ...

    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        # Returns: [num_nodes, out_channels] (logits)
        ...
```
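One concrete direction for the `message()` computation is attention-weighted messages in the spirit of GAT, which can be prototyped outside the framework first. The numpy sketch below is illustrative only: the 5-node graph, the single attention vector `a`, and the lack of a learnable transform are assumptions, and a real layer would compute this inside `message()` using PyG's sparse `softmax` over edges.

```python
import numpy as np

# Hedged sketch of attention-weighted message passing (GAT-flavored).
rng = np.random.default_rng(1)
num_nodes, dim = 5, 4
x = rng.standard_normal((num_nodes, dim))
edge_index = np.array([[0, 1, 2, 3, 4, 0],
                       [1, 1, 1, 2, 2, 0]])
src, dst = edge_index

# Per-edge score from concatenated source/target features (one attention head)
a = rng.standard_normal(2 * dim)
scores = np.concatenate([x[src], x[dst]], axis=1) @ a   # [num_edges]

# Sparse softmax: normalize scores over edges sharing the same target node
alpha = np.exp(scores - scores.max())
denom = np.zeros(num_nodes)
np.add.at(denom, dst, alpha)
alpha = alpha / denom[dst]

# Attention-weighted aggregation of source features per target node
out = np.zeros((num_nodes, dim))
np.add.at(out, dst, alpha[:, None] * x[src])
```

Nodes with no incoming edges (here, nodes 3 and 4) simply keep a zero aggregate; in a real layer, `add_self_loops` guards against this.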
Available PyG Utilities
- `MessagePassing` base class: `self.propagate(edge_index, ...)` orchestrates message/aggregate/update
- `add_self_loops(edge_index)`: add self-loop edges
- `degree(index, num_nodes)`: compute node degrees
- `softmax(src, index)`: sparse softmax over edges
- Convolution layers for reference: `GCNConv`, `GATConv`, `SAGEConv` (imported but read-only)
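As an example of how these utilities combine, `add_self_loops` and `degree` together yield GCN's symmetric per-edge normalization 1/sqrt(deg(i)·deg(j)). A numpy sketch of that computation (the 3-node toy graph is an assumption; PyG computes the same weights on tensors):

```python
import numpy as np

# Sketch of GCN-style symmetric normalization on a toy 3-node graph.
num_nodes = 3
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])

# Add self-loop edges (i -> i), as add_self_loops would
loops = np.arange(num_nodes)
src = np.concatenate([edge_index[0], loops])
dst = np.concatenate([edge_index[1], loops])

# Node degrees including self-loops, as degree() would compute
deg = np.bincount(dst, minlength=num_nodes)

# Per-edge weight 1/sqrt(deg(src) * deg(dst)), the GCN normalization
norm = 1.0 / np.sqrt(deg[src] * deg[dst])
```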
Evaluation
Trained and evaluated on three citation network benchmarks (semi-supervised node classification with standard Planetoid splits):
- Cora: 2,708 nodes, 5,429 edges, 7 classes, 1,433 features
- CiteSeer: 3,327 nodes, 4,732 edges, 6 classes, 3,703 features
- PubMed: 19,717 nodes, 44,338 edges, 3 classes, 500 features
Metrics: test accuracy and macro F1 score (higher is better). Training: 200 epochs with early stopping (patience=50), Adam optimizer, lr=0.01, weight_decay=5e-4.
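The early-stopping protocol amounts to tracking the best validation score and halting once `patience` epochs pass without improvement. A minimal sketch of that logic (the function name and the `val_curve` stand-in for per-epoch validation accuracy are hypothetical; real training would evaluate the model each epoch and snapshot its weights at each new best):

```python
# Early-stopping skeleton matching the protocol above (200 epochs, patience=50).
def train_with_early_stopping(val_curve, max_epochs=200, patience=50):
    best_acc, best_epoch = -1.0, -1
    for epoch in range(min(max_epochs, len(val_curve))):
        acc = val_curve[epoch]                 # stand-in for evaluate(model, val_mask)
        if acc > best_acc:
            best_acc, best_epoch = acc, epoch  # real code would also snapshot weights
        elif epoch - best_epoch >= patience:
            break                              # no improvement for `patience` epochs
    return best_acc, best_epoch
```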
Code
```python
# Custom GNN message passing mechanism for node classification — MLS-Bench
#
# EDITABLE section: CustomMessagePassingLayer + CustomGNN classes (lines 48-157).
# FIXED sections: everything else (config, data loading, training loop, evaluation).

import os
import copy
import random
import math
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
```
Results
| Model | Type | Cora acc. ↑ | Cora macro-F1 ↑ | CiteSeer acc. ↑ | CiteSeer macro-F1 ↑ | PubMed acc. ↑ | PubMed macro-F1 ↑ |
|---|---|---|---|---|---|---|---|
| gat | baseline | 0.826 | 0.813 | 0.708 | 0.660 | 0.778 | 0.773 |
| gcn | baseline | 0.821 | 0.808 | 0.718 | 0.683 | 0.786 | 0.784 |
| gps | baseline | 0.671 | 0.653 | 0.519 | 0.440 | - | - |
| gps | baseline | 0.686 | 0.672 | 0.516 | 0.426 | - | - |
| gps | baseline | 0.760 | 0.746 | 0.661 | 0.629 | - | - |
| graphsage | baseline | 0.792 | 0.783 | 0.660 | 0.622 | 0.769 | 0.766 |
| nagphormer | baseline | 0.613 | 0.530 | 0.498 | 0.431 | 0.488 | 0.341 |
| nagphormer | baseline | 0.770 | 0.759 | 0.664 | 0.620 | 0.761 | 0.761 |
| revgat | baseline | 0.752 | 0.746 | 0.649 | 0.613 | 0.755 | 0.751 |