graph-temporal
Description
Graph Message Passing for Spatio-Temporal Traffic Forecasting
Research Question
Design a novel graph message passing mechanism for spatial aggregation in spatio-temporal traffic forecasting networks.
Background
Traffic forecasting on sensor networks requires modeling both temporal dynamics and spatial dependencies between sensors. While temporal modeling (via convolutions or RNNs) is relatively well-understood, the spatial component — how information is passed between graph nodes — remains an active area of research.
Classical approaches include:
- Spectral methods: Chebyshev polynomial approximation of graph convolutions (STGCN)
- Diffusion methods: Random walk-based diffusion on directed graphs (DCRNN, Graph WaveNet)
- Attention methods: Spatial attention mechanisms (ASTGCN, STAEformer)
- Adaptive methods: Learned graph structures combined with multi-hop propagation (MTGNN)
The task is to design a spatial aggregation layer that can capture complex, distance-dependent, and potentially asymmetric relationships between traffic sensors.
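As a concrete reference point, the diffusion-style propagation used in DCRNN and Graph WaveNet can be sketched as repeated application of the normalized adjacency matrix. This is a simplified illustration, not code from the benchmark; the hop-concatenation output layout is an illustrative choice:

```python
import torch

def diffusion_aggregate(x: torch.Tensor, adj: torch.Tensor, K: int = 2) -> torch.Tensor:
    """K-hop diffusion sketch.

    x:   [B, N, D] node features
    adj: [N, N] normalized adjacency / transition matrix
    Returns [B, N, D * (K + 1)]: the input plus one feature slice per hop.
    """
    hops = [x]
    h = x
    for _ in range(K):
        # Propagate one hop: each node receives an adjacency-weighted
        # combination of its neighbors' features.
        h = torch.einsum("nm,bmd->bnd", adj, h)
        hops.append(h)
    return torch.cat(hops, dim=-1)
```

Spectral (Chebyshev) and attention-based layers differ in how the mixing matrix is built, but share this core pattern of multiplying node features by an N-by-N operator.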
Task
Modify the SpatialLayer class in custom_graph_model.py. This class defines the graph message passing component used within each spatio-temporal block. The temporal backbone (gated dilated causal convolutions) and training pipeline are fixed.
Your SpatialLayer receives:
- x: node features [B, N, D] (B = batch, N = nodes/sensors, D = features)
- adj: normalized adjacency matrix [N, N], symmetric-normalized and weighted by sensor distance
And must return spatially aggregated node features [B, N, D'].
Interface
```python
class SpatialLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.0):
        ...

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x: [B, N, D], adj: [N, N] -> output: [B, N, D']
        ...
```
The class is instantiated with SpatialLayer(hidden_dim, hidden_dim, dropout) where hidden_dim=32 by default. You may add parameters, intermediate layers, or learnable components as needed.
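A minimal GCN-style implementation of this interface might look as follows. This is a starting-point sketch, not the reference solution; the einsum aggregation and the residual connection are illustrative choices:

```python
import torch
import torch.nn as nn

class SpatialLayer(nn.Module):
    """One-hop graph convolution baseline: aggregate neighbors, then project."""

    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.0):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        # Residual connection only when shapes allow it (illustrative choice).
        self.residual = in_dim == out_dim

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x: [B, N, D], adj: [N, N] -> [B, N, D']
        h = torch.einsum("nm,bmd->bnd", adj, x)  # one-hop neighborhood aggregation
        h = self.dropout(self.linear(h))
        return h + x if self.residual else h
```

Richer designs (multi-hop propagation, learned adjacency, spatial attention) slot into the same forward signature, since only x and adj are available to the layer.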
Evaluation
Trained and evaluated on three traffic datasets:
- METR-LA (207 sensors, traffic speed, Los Angeles highway network)
- PEMS-BAY (325 sensors, traffic speed, San Francisco Bay Area)
- PEMS04 (307 sensors, traffic flow, California district 4)
All use input_len=12, output_len=12 (5-minute intervals, 1 hour history -> 1 hour prediction).
Metrics: MAE, RMSE, MAPE (lower is better). Data is Z-score normalized; metrics computed after inverse transform.
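The metric pipeline described above (z-score normalization undone before scoring) can be sketched as below. This is a hedged illustration of the standard formulas, not the benchmark's exact evaluation code; the epsilon clipping in MAPE is an assumption, since real pipelines often mask zero targets instead:

```python
import numpy as np

def inverse_zscore(x: np.ndarray, mean: float, std: float) -> np.ndarray:
    """Undo z-score normalization before computing metrics."""
    return x * std + mean

def mae(pred: np.ndarray, true: np.ndarray) -> float:
    return float(np.abs(pred - true).mean())

def rmse(pred: np.ndarray, true: np.ndarray) -> float:
    return float(np.sqrt(((pred - true) ** 2).mean()))

def mape(pred: np.ndarray, true: np.ndarray, eps: float = 1e-8) -> float:
    # Clip the denominator to avoid division by zero (illustrative choice).
    return float((np.abs(pred - true) / np.clip(np.abs(true), eps, None)).mean())
```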
Code
```python
"""Graph-temporal forecasting model with editable spatial message passing.

Fixed: temporal backbone (dilated causal conv), adjacency loading, output projection.
Editable: SpatialLayer -- the graph message passing component (lines 72--130).
"""
import math
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from typing import Optional
```
Results
| Model | Type | MAE METR-LA ↓ | MAE PEMS-BAY ↓ | MAE PEMS04 ↓ | MAPE METR-LA ↓ | MAPE PEMS-BAY ↓ | MAPE PEMS04 ↓ | RMSE METR-LA ↓ | RMSE PEMS-BAY ↓ | RMSE PEMS04 ↓ |
|---|---|---|---|---|---|---|---|---|---|---|
| astgcn | baseline | 6.663 | 2.988 | 55.630 | 0.172 | 0.067 | 0.770 | 10.597 | 5.781 | 71.576 |
| dcrnn | baseline | 4.872 | - | - | 0.137 | - | - | 8.798 | - | - |
| dcrnn | baseline | 3.957 | 2.230 | 34.879 | 0.114 | 0.053 | 0.406 | 7.772 | 4.933 | 49.344 |
| gwnet | baseline | 4.154 | 2.306 | 33.386 | 0.121 | 0.056 | 0.385 | 8.035 | 5.001 | 47.268 |
| mtgnn | baseline | 4.462 | 2.336 | 32.755 | 0.126 | 0.056 | 0.369 | 8.461 | 5.043 | 46.932 |
| staeformer | baseline | 4.028 | 2.115 | 32.044 | 0.109 | 0.049 | 0.386 | 7.731 | 4.607 | 44.366 |
| stgcn | baseline | 4.312 | 2.348 | 36.202 | 0.122 | 0.056 | 0.462 | 8.419 | 5.269 | 50.247 |