graph-signal-propagation

Graph Learning | ChebNetII | rigorous codebase

Description

Graph Signal Propagation: Spectral/Spatial Graph Filters

Research Question

Design a novel graph signal propagation filter for node feature aggregation in graph neural networks. The filter should effectively handle both homophilic and heterophilic graphs.

Background

Graph Neural Networks propagate node features through the graph structure using graph filters. The choice of filter is critical. Simple low-pass filters (like GCN's first-order approximation of spectral convolution) work well on homophilic graphs, where connected nodes tend to share labels. They often perform poorly on heterophilic graphs, where connected nodes often have different labels, because smoothing over neighbors blends features from different classes. Modern spectral methods learn polynomial filters in various bases:

  • Monomial basis (GPRGNN): h(A) = sum_k gamma_k A^k -- simple but can be numerically unstable
  • Bernstein basis (BernNet): non-negative, excellent controllability, but O(K^2) propagation steps
  • Chebyshev interpolation (ChebNetII): mitigates the Runge phenomenon, O(K) propagation steps
  • Jacobi polynomials (JacobiConv): orthogonal, fast convergence, generalizes Chebyshev/Legendre
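A small NumPy sketch can make the basis trade-offs concrete. It evaluates each basis on the normalized-Laplacian spectrum lam in [0, 2]:

  • Monomials in the adjacency eigenvalue 1 - lam (the GPRGNN view).
  • BernNet's Bernstein polynomials comb(K, k) / 2^K (2 - lam)^(K-k) lam^k.
  • Chebyshev polynomials T_k on the shifted variable lam - 1 in [-1, 1].

The function names are illustrative and not taken from any of these codebases. The sketch only inspects the basis functions; it does not model propagation cost.

```python
import math
import numpy as np

def monomial_basis(lam, K):
    # GPRGNN-style powers of the normalized adjacency, whose eigenvalues are mu = 1 - lam
    mu = 1.0 - lam
    return np.stack([mu ** k for k in range(K + 1)], axis=1)

def bernstein_basis(lam, K):
    # BernNet basis: comb(K, k) / 2^K * (2 - lam)^(K - k) * lam^k for lam in [0, 2]
    return np.stack(
        [math.comb(K, k) / 2 ** K * (2.0 - lam) ** (K - k) * lam ** k for k in range(K + 1)],
        axis=1,
    )

def chebyshev_basis(lam, K):
    # T_k(x) on x = lam - 1 in [-1, 1], built with T_k = 2x T_{k-1} - T_{k-2}
    x = lam - 1.0
    T = [np.ones_like(x), x]
    for _ in range(2, K + 1):
        T.append(2.0 * x * T[-1] - T[-2])
    return np.stack(T[: K + 1], axis=1)

lam = np.linspace(0.0, 2.0, 201)  # spectrum of the symmetric normalized Laplacian
K = 10
B = bernstein_basis(lam, K)
print("Bernstein non-negative:", bool((B >= 0).all()))
print("Bernstein sums to 1:", bool(np.allclose(B.sum(axis=1), 1.0)))
print("cond(monomial):", np.linalg.cond(monomial_basis(lam, K)))
print("cond(Chebyshev):", np.linalg.cond(chebyshev_basis(lam, K)))
```

On this evenly spaced grid, the monomial design matrix is much worse conditioned than the Chebyshev one. This is one concrete reading of "can be numerically unstable". The Bernstein functions stay non-negative and sum to 1 at every lam, which is what makes non-negative Bernstein coefficients easy to interpret as a filter shape.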

Key design axes include: polynomial basis choice, coefficient initialization and constraints, normalization (GCN vs Laplacian), and interaction with the MLP encoder.
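The two normalizations can be compared on a dense toy graph. The NumPy sketch below (helper names are hypothetical) builds:

  • D^{-1/2}AD^{-1/2}, optionally with GCN's self-loop renormalization.
  • The symmetric Laplacian L = I - D^{-1/2}AD^{-1/2}.

PyG's own gcn_norm adds self-loops unless it is called with add_self_loops=False. On the 4-cycle below, self-loops move the lowest adjacency eigenvalue from -1 to -1/3.

```python
import numpy as np

def sym_norm_adj(A, self_loops=False):
    # D^{-1/2} A D^{-1/2}; self_loops=True gives GCN's renormalized D~^{-1/2}(A + I)D~^{-1/2}
    if self_loops:
        A = A + np.eye(A.shape[0])
    deg = A.sum(axis=1)
    d_inv_sqrt = np.zeros_like(deg)
    d_inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5  # isolated nodes keep a zero scale
    return d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

def sym_laplacian(A):
    # L = I - D^{-1/2} A D^{-1/2}; eigenvalues lie in [0, 2]
    return np.eye(A.shape[0]) - sym_norm_adj(A)

# 4-cycle: bipartite, so its Laplacian reaches the top of the spectrum at 2
A = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 0, 1, 0]], dtype=float)
print(np.linalg.eigvalsh(sym_laplacian(A)))                  # eigenvalues of L
print(np.linalg.eigvalsh(sym_norm_adj(A, self_loops=True)))  # GCN-normalized adjacency
```

Without self-loops, L = I - D^{-1/2}AD^{-1/2}, so a polynomial in the Laplacian can always be rewritten as a polynomial in the normalized adjacency and vice versa. The practical difference lies in which spectral interval the basis is designed for.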

Task

Modify the CustomProp (propagation layer) and CustomFilter (full model) classes in custom_filter.py. The propagation layer defines how node features are filtered across the graph; the model wraps it with an MLP encoder and output head.

Interface

class CustomProp(MessagePassing):
    def __init__(self, K, alpha=0.1, **kwargs):
        # K: polynomial order, alpha: teleport probability
        ...
    def forward(self, x, edge_index, edge_weight=None):
        # x: [num_nodes, channels], edge_index: [2, num_edges]
        # Returns: filtered features [num_nodes, channels]
        ...

class CustomFilter(nn.Module):
    def __init__(self, num_features, num_classes, hidden=64, K=10,
                 alpha=0.1, dropout=0.5, dprate=0.5):
        ...
    def forward(self, data):
        # data: PyG Data object with data.x, data.edge_index
        # Returns: log_softmax predictions [num_nodes, num_classes]
        ...
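As a reference point for what CustomProp.forward might compute, here is a dense NumPy sketch of a K-step polynomial filter sum_k gamma_k A_norm^k x over the self-loop-normalized adjacency. It is not a MessagePassing subclass, and all names are hypothetical.

The interface describes alpha as a teleport probability, so the sketch assumes personalized-PageRank weights: gamma_k = alpha (1 - alpha)^k and gamma_K = (1 - alpha)^K. These are APPNP's fixed weights, which GPRGNN also uses as its PPR initialization. In a real layer, each `A_norm @ h` would be one self.propagate call, and gamma could be a learnable parameter.

```python
import numpy as np

def ppr_coefficients(K, alpha):
    # gamma_k = alpha * (1 - alpha)^k for k < K, and gamma_K = (1 - alpha)^K (sums to 1)
    gamma = alpha * (1.0 - alpha) ** np.arange(K + 1)
    gamma[-1] = (1.0 - alpha) ** K
    return gamma

def gcn_norm_dense(A):
    # D~^{-1/2} (A + I) D~^{-1/2}; self-loops guarantee every degree is >= 1
    A_hat = A + np.eye(A.shape[0])
    d_inv_sqrt = A_hat.sum(axis=1) ** -0.5
    return d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]

def propagate_dense(x, A, K=10, alpha=0.1, gamma=None):
    # Hypothetical dense reference for CustomProp.forward: sum_k gamma_k * A_norm^k @ x
    if gamma is None:
        gamma = ppr_coefficients(K, alpha)
    A_norm = gcn_norm_dense(A)
    h = x
    out = gamma[0] * h
    for k in range(1, K + 1):
        h = A_norm @ h  # one propagation step (one self.propagate call in PyG)
        out = out + gamma[k] * h
    return out

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
x = rng.standard_normal((4, 3))  # [num_nodes, channels]
out = propagate_dense(x, A, K=10, alpha=0.1)
print(out.shape)  # (4, 3)
```

With these coefficients, the output equals K steps of APPNP's iteration h <- (1 - alpha) A_norm h + alpha x. This gives a quick sanity check for any refactored implementation.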

Available Utilities

  • gcn_norm(edge_index) -- GCN normalization D^{-1/2}AD^{-1/2}
  • get_laplacian(edge_index, normalization='sym') -- symmetric normalized Laplacian L = I - D^{-1/2}AD^{-1/2}
  • add_self_loops(edge_index, edge_weight, fill_value) -- add self-loops
  • self.propagate(edge_index, x=x, norm=norm) -- single-step message passing
  • cheby(i, x) -- evaluate Chebyshev polynomial T_i(x)
  • comb(n, k) -- binomial coefficient (from scipy)
  • Constants: K, ALPHA, HIDDEN, DROPOUT, DPRATE
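The cheby and get_laplacian/add_self_loops utilities combine naturally in a ChebNetII-style Chebyshev interpolation. The NumPy sketch below works under these assumptions:

  • cheby is a local stand-in for the benchmark helper.
  • gamma_j are filter values at the Chebyshev nodes x_j = cos((K - j + 1/2) pi / (K + 1)), ordered by increasing x.
  • The resulting polynomial is meant to be applied to the shifted Laplacian L - I, whose spectrum lies in [-1, 1].

With the listed utilities, L - I can be formed by adding self-loops with fill_value=-1 to get_laplacian's output; the duplicate diagonal entries sum during message passing. All function names in the sketch are hypothetical.

```python
import numpy as np

def cheby(i, x):
    # T_i(x) via the three-term recurrence (stand-in for the benchmark's cheby helper)
    if i == 0:
        return 1.0
    t_prev, t = 1.0, x
    for _ in range(i - 1):
        t_prev, t = t, 2 * x * t - t_prev
    return t

def chebyshev_nodes(K):
    # x_j = cos((K - j + 1/2) * pi / (K + 1)), j = 0..K, in increasing order
    j = np.arange(K + 1)
    return np.cos((K - j + 0.5) * np.pi / (K + 1))

def interp_coefficients(gamma):
    # w_k = 2 / (K + 1) * sum_j gamma_j * T_k(x_j)
    K = len(gamma) - 1
    xs = chebyshev_nodes(K)
    return np.array([2.0 / (K + 1) * sum(g * cheby(k, xj) for g, xj in zip(gamma, xs))
                     for k in range(K + 1)])

def filter_response(w, x):
    # h(x) = w_0 / 2 + sum_{k >= 1} w_k T_k(x); the k = 0 term is halved
    return w[0] / 2 + sum(w[k] * cheby(k, x) for k in range(1, len(w)))

K = 10
rng = np.random.default_rng(0)
gamma = np.maximum(rng.standard_normal(K + 1), 0.0)  # non-negative node values, as after a ReLU
w = interp_coefficients(gamma)
print(np.allclose(filter_response(w, chebyshev_nodes(K)), gamma))  # True
```

The interpolant passes exactly through the learnable node values, so gamma directly controls the filter's shape at the nodes. Propagation then needs only the K + 1 Chebyshev terms, computed with the same recurrence on L - I.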

Evaluation

Evaluated on 4 benchmark datasets spanning homophilic and heterophilic graphs:

  • Cora (2,708 nodes, 7 classes, homophilic) -- citation network
  • CiteSeer (3,327 nodes, 6 classes, homophilic) -- citation network
  • Texas (183 nodes, 5 classes, heterophilic) -- WebKB webpage network
  • Cornell (183 nodes, 5 classes, heterophilic) -- WebKB webpage network
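The homophilic/heterophilic split above is commonly quantified by the edge homophily ratio. This is the fraction of edges whose two endpoints share a class label: close to 1 for homophilic graphs and low for heterophilic ones. A minimal NumPy sketch, assuming a PyG-style edge_index of shape [2, num_edges] and integer labels y (the function name is hypothetical):

```python
import numpy as np

def edge_homophily(edge_index, y):
    # edge_index: int array [2, num_edges]; y: int labels [num_nodes]
    src, dst = edge_index
    return float(np.mean(y[src] == y[dst]))

# 4-cycle with both edge directions listed, as in an undirected PyG graph
edge_index = np.array([[0, 1, 1, 2, 2, 3, 3, 0],
                       [1, 0, 2, 1, 3, 2, 0, 3]])
print(edge_homophily(edge_index, np.array([0, 0, 0, 0])))  # 1.0: every edge is intra-class
print(edge_homophily(edge_index, np.array([0, 1, 0, 1])))  # 0.0: every edge crosses classes
```

At a ratio near 0, a low-pass filter mixes each node's features with those of the other class. This is the regime where learned high-pass or band-pass responses can help.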

Each dataset is evaluated on 10 random splits (60%/20%/20% train/val/test), training with early stopping on each split. Metric: mean test accuracy over the 10 runs (higher is better).
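Here is a sketch of the random-split protocol. It assumes a plain node-level permutation with fractions rounded to the nearest integer. The fixed code in custom_filter.py may split differently, for example per class. random_split and summarize are hypothetical names.

```python
import numpy as np

def random_split(num_nodes, seed, train_frac=0.6, val_frac=0.2):
    # Shuffle node ids with a seeded RNG, then cut into train/val/test index arrays
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_nodes)
    n_train = int(round(train_frac * num_nodes))
    n_val = int(round(val_frac * num_nodes))
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]

def summarize(test_accs):
    # Mean (the reported metric) and standard deviation over runs
    accs = np.asarray(test_accs, dtype=float)
    return accs.mean(), accs.std()

# One split per seed, e.g. for Texas (183 nodes)
for seed in range(10):
    train_idx, val_idx, test_idx = random_split(183, seed)
    # train with early stopping on val_idx, then record accuracy on test_idx
print(len(train_idx), len(val_idx), len(test_idx))  # 110 37 36
```

On the 183-node WebKB graphs, the test set has only a few dozen nodes. A handful of misclassified nodes therefore shifts per-split accuracy noticeably, which is why the mean over 10 splits is reported.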

Code

custom_filter.py
# Custom graph signal propagation filter for MLS-Bench
#
# EDITABLE section: CustomProp (propagation layer) + CustomFilter (full model).
# FIXED sections: everything else (config, data loading, training loop, evaluation).

import os
import math
import random
import time
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

Results

Mean test accuracy per dataset:

Model       Type      Cora   CiteSeer  Texas  Cornell
appnp       baseline  0.879  0.804     0.450  0.692
bernnet     baseline  0.883  0.800     0.898  0.887
bernnet     baseline  0.855  0.780     0.909  0.825
chebnetii   baseline  0.872  0.800     0.877  0.847
gcn         baseline  0.871  0.798     0.746  0.610
gcn         baseline  0.869  0.792     0.605  0.636
gprgnn      baseline  0.888  0.802     0.644  0.730
gprgnn      baseline  0.888  0.797     0.666  0.769
gprgnn      baseline  0.876  0.790     0.626  0.710
jacobiconv  baseline  0.879  0.775     0.845  0.737
sgc         baseline  0.850  0.768     0.267  0.303