graph-temporal

Graph Learning | BasicTS | rigorous codebase

Description

Graph Message Passing for Spatio-Temporal Traffic Forecasting

Research Question

Design a novel graph message passing mechanism for spatial aggregation in spatio-temporal traffic forecasting networks.

Background

Traffic forecasting on sensor networks requires modeling both temporal dynamics and spatial dependencies between sensors. While temporal modeling (via convolutions or RNNs) is relatively well-understood, the spatial component — how information is passed between graph nodes — remains an active area of research.

Classical approaches include:

  • Spectral methods: Chebyshev polynomial approximation of graph convolutions (STGCN)
  • Diffusion methods: Random-walk-based diffusion on directed graphs (DCRNN, Graph WaveNet)
  • Attention methods: Spatial attention mechanisms (ASTGCN, STAEformer)
  • Adaptive methods: Learned graph structures combined with multi-hop propagation (MTGNN)
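The first two families can be made concrete with a minimal PyTorch sketch of the propagation terms they build on. It assumes node features of shape [B, N, D] and a dense, non-negative adjacency. The function names are illustrative and are not taken from the STGCN or DCRNN code. The learnable weights that those models apply to each term are omitted.

```python
import torch


def scaled_laplacian(adj: torch.Tensor) -> torch.Tensor:
    # L = I - D^{-1/2} A D^{-1/2}, rescaled to L~ = 2L / lambda_max - I so its spectrum lies in [-1, 1].
    # Assumes a symmetric, non-negative adj (eigvalsh requires a symmetric matrix).
    n = adj.shape[0]
    eye = torch.eye(n, dtype=adj.dtype)
    d_inv_sqrt = adj.sum(dim=1).clamp(min=1e-12).pow(-0.5)
    lap = eye - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    lam_max = torch.linalg.eigvalsh(lap).max()
    return 2.0 * lap / lam_max - eye


def chebyshev_basis(x: torch.Tensor, l_tilde: torch.Tensor, k: int) -> torch.Tensor:
    # Chebyshev terms T_0(L~)x, ..., T_{k-1}(L~)x via T_k = 2 L~ T_{k-1} - T_{k-2}.
    terms = [x]
    if k > 1:
        terms.append(torch.einsum("nm,bmd->bnd", l_tilde, x))
    for _ in range(2, k):
        terms.append(2 * torch.einsum("nm,bmd->bnd", l_tilde, terms[-1]) - terms[-2])
    return torch.stack(terms, dim=0)  # [k, B, N, D]


def diffusion_terms(x: torch.Tensor, adj: torch.Tensor, k: int) -> torch.Tensor:
    # DCRNN-style bidirectional random walk: forward P = A / rowsum(A), backward P = A^T / rowsum(A^T).
    p_fwd = adj / adj.sum(dim=1, keepdim=True).clamp(min=1e-12)
    p_bwd = adj.T / adj.T.sum(dim=1, keepdim=True).clamp(min=1e-12)
    out = [x]
    for p in (p_fwd, p_bwd):
        h = x
        for _ in range(k):
            h = torch.einsum("nm,bmd->bnd", p, h)  # one more hop along this direction
            out.append(h)
    return torch.stack(out, dim=0)  # [1 + 2k, B, N, D]
```

A full layer would multiply each stacked term by its own learnable weight matrix and sum the results. The diffusion variant treats adj and adj.T separately, which is how it can represent directed road distances.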

The task is to design a spatial aggregation layer that captures the complex, distance-dependent, and potentially asymmetric relationships between traffic sensors.
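One way to address both requirements is to restrict attention to the graph's edges. The sketch below is a hypothetical module; MaskedSpatialAttention is not part of the codebase. It biases attention scores by log(adj), so a larger distance-based weight gives a neighbor a larger prior share. Separate query and key projections allow the learned weights to be asymmetric even when adj is symmetric.

```python
import torch
import torch.nn as nn


class MaskedSpatialAttention(nn.Module):
    # Hypothetical illustration, not part of custom_graph_model.py.

    def __init__(self, dim: int):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x: [B, N, D], adj: [N, N] with non-negative weights
        n = adj.shape[0]
        scores = torch.einsum("bnd,bmd->bnm", self.q(x), self.k(x)) * self.scale
        # Adding log(adj) makes the softmax weight proportional to adj[i, j] * exp(score[i, j]).
        scores = scores + torch.log(adj.clamp(min=1e-12))
        # Attend only along edges of adj, plus a self-loop so no row is fully masked.
        allowed = (adj > 0) | torch.eye(n, dtype=torch.bool, device=adj.device)
        scores = scores.masked_fill(~allowed, float("-inf"))
        attn = torch.softmax(scores, dim=-1)  # attn[b, i, j] != attn[b, j, i] in general
        return torch.einsum("bnm,bmd->bnd", attn, self.v(x))
```

The self-loop keeps a sensor with no neighbors from producing NaNs; such a node simply attends to itself.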

Task

Modify the SpatialLayer class in custom_graph_model.py. This class defines the graph message passing component used within each spatio-temporal block. The temporal backbone (gated dilated causal convolutions) and training pipeline are fixed.
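For orientation, the fixed temporal unit can be pictured as follows. This is a generic sketch of a gated dilated causal convolution: a tanh filter multiplied by a sigmoid gate, made causal by padding only on the left. The class name and the [B, C, T] layout are assumptions, and the actual backbone in custom_graph_model.py may arrange tensors differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedDilatedCausalConv(nn.Module):
    # Illustrative gated TCN unit; not the task's backbone code.

    def __init__(self, channels: int, kernel_size: int = 2, dilation: int = 1):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation  # left padding that keeps length T
        self.filter = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)
        self.gate = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, T]; padding only on the left means output[t] depends on inputs at times <= t.
        x = F.pad(x, (self.pad, 0))
        return torch.tanh(self.filter(x)) * torch.sigmoid(self.gate(x))
```

Stacking such units with growing dilations (1, 2, 4, ...) widens the receptive field over the 12 input steps without pooling.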

Your SpatialLayer receives:

  • x: Node features [B, N, D] — B=batch, N=nodes (sensors), D=features
  • adj: Normalized adjacency matrix [N, N] — symmetric-normalized, weighted by sensor distance

It must return spatially aggregated node features [B, N, D'], where D' = out_dim.
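The adjacency preprocessing is fixed, and its code is not shown here. A common construction for these datasets, assumed below, applies a thresholded Gaussian kernel to pairwise sensor distances (as in DCRNN) and then normalizes symmetrically. The helper names and the 0.1 threshold are hypothetical. The last function shows the basic aggregation step for the input shapes listed above.

```python
import torch


def gaussian_kernel_adj(dist: torch.Tensor, threshold: float = 0.1) -> torch.Tensor:
    # Hypothetical helper: W[i, j] = exp(-(d[i, j] / sigma)^2), sigma = std of finite distances.
    # Unreachable pairs can be encoded as inf, which maps to weight 0.
    sigma = dist[torch.isfinite(dist)].std()
    w = torch.exp(-((dist / sigma) ** 2))
    return w.masked_fill(w < threshold, 0.0)


def sym_normalize(w: torch.Tensor) -> torch.Tensor:
    # D^{-1/2} W D^{-1/2}, using row sums of W as degrees.
    d_inv_sqrt = w.sum(dim=1).clamp(min=1e-12).pow(-0.5)
    return d_inv_sqrt[:, None] * w * d_inv_sqrt[None, :]


def aggregate(x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
    # One propagation step: out[b, i] = sum_j adj[i, j] * x[b, j]; [B, N, D] -> [B, N, D].
    return torch.einsum("nm,bmd->bnd", adj, x)  # equivalent to adj @ x via batch broadcasting
```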

Interface

import torch
import torch.nn as nn


class SpatialLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.0):
        ...
    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x: [B, N, D], adj: [N, N] -> output: [B, N, D']
        ...

The class is instantiated with SpatialLayer(hidden_dim, hidden_dim, dropout) where hidden_dim=32 by default. You may add parameters, intermediate layers, or learnable components as needed.
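As a starting point, here is one possible SpatialLayer that satisfies the interface. It propagates features for a few hops along both adj and adj.T, projects the concatenated hop features, and blends the result with a gated skip connection. It is a sketch under these assumptions, not a reference solution. The extra hops argument has a default, so the required SpatialLayer(hidden_dim, hidden_dim, dropout) call still works.

```python
import torch
import torch.nn as nn


class SpatialLayer(nn.Module):
    # Sketch: bidirectional K-hop propagation with a gated residual.

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.0, hops: int = 2):
        super().__init__()
        self.hops = hops
        # Features: identity + `hops` steps along adj + `hops` steps along adj.T
        self.proj = nn.Linear(in_dim * (1 + 2 * hops), out_dim)
        self.gate = nn.Linear(in_dim, out_dim)
        self.skip = nn.Identity() if in_dim == out_dim else nn.Linear(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x: [B, N, D], adj: [N, N] -> output: [B, N, D']
        feats = [x]
        for a in (adj, adj.transpose(0, 1)):
            h = x
            for _ in range(self.hops):
                h = torch.einsum("nm,bmd->bnd", a, h)
                feats.append(h)
        agg = self.proj(torch.cat(feats, dim=-1))
        g = torch.sigmoid(self.gate(x))  # per-node, per-channel mixing gate
        out = g * agg + (1 - g) * self.skip(x)
        return self.dropout(out)
```

Because the adjacency is symmetric-normalized, adj.T adds information only when the underlying distance matrix is asymmetric; otherwise the two directions coincide and the extra projection weights are redundant.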

Evaluation

Trained and evaluated on three traffic datasets:

  • METR-LA (207 sensors, traffic speed, Los Angeles highway network)
  • PEMS-BAY (325 sensors, traffic speed, San Francisco Bay Area)
  • PEMS04 (307 sensors, traffic flow, California district 4)

All datasets use input_len=12 and output_len=12 (5-minute intervals, so 1 hour of history is used to predict the next hour). Metrics: MAE, RMSE, and MAPE (lower is better). Data is Z-score normalized; metrics are computed after the inverse transform.
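The evaluation protocol can be sketched as follows. Predictions and targets are mapped back to original units with the stored mean and standard deviation before scoring. The helper names are hypothetical, and excluding zero targets from MAPE is an assumption here, since the pipeline's exact masking is not specified above.

```python
import torch


def inverse_zscore(y: torch.Tensor, mean: float, std: float) -> torch.Tensor:
    # Undo Z-score normalization: y * std + mean.
    return y * std + mean


def forecast_metrics(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8):
    # MAE, RMSE, MAPE in original units; MAPE skips near-zero targets (assumption).
    err = pred - target
    mae = err.abs().mean()
    rmse = err.pow(2).mean().sqrt()
    mask = target.abs() > eps
    mape = (err[mask].abs() / target[mask].abs()).mean()
    return mae.item(), rmse.item(), mape.item()
```

Because scoring happens after the inverse transform, MAE and RMSE are in each dataset's native units (speed for METR-LA and PEMS-BAY, flow for PEMS04), so their values are not comparable across datasets.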

Code

custom_graph_model.py
1"""Graph-temporal forecasting model with editable spatial message passing.
2
3Fixed: temporal backbone (dilated causal conv), adjacency loading, output projection.
4Editable: SpatialLayer -- the graph message passing component (lines 72--130).
5"""
6import math
7import os
8import pickle
9
10import numpy as np
11import torch
12import torch.nn as nn
13import torch.nn.functional as F
14from dataclasses import dataclass, field
15from typing import Optional

Results

                      MAE                           MAPE                          RMSE
Model       Type      METR-LA   PEMS-BAY  PEMS04    METR-LA   PEMS-BAY  PEMS04    METR-LA   PEMS-BAY  PEMS04
astgcn      baseline  6.663     2.988     55.630    0.172     0.067     0.770     10.597    5.781     71.576
dcrnn       baseline  4.872     -         -         0.137     -         -         8.798     -         -
dcrnn       baseline  3.957     2.230     34.879    0.114     0.053     0.406     7.772     4.933     49.344
gwnet       baseline  4.154     2.306     33.386    0.121     0.056     0.385     8.035     5.001     47.268
mtgnn       baseline  4.462     2.336     32.755    0.126     0.056     0.369     8.461     5.043     46.932
staeformer  baseline  4.028     2.115     32.044    0.109     0.049     0.386     7.731     4.607     44.366
stgcn       baseline  4.312     2.348     36.202    0.122     0.056     0.462     8.419     5.269     50.247