stf-traffic-forecast
Description
Spatial-Temporal Traffic Forecasting: Custom Model Design
Objective
Design and implement a custom deep learning model for spatial-temporal traffic forecasting. Put your code in custom_model.py, which must define both the Custom model class and the CustomConfig config class. Three reference implementations (STID, DLinear, StemGNN) are provided as read-only files.
Background
Spatial-temporal forecasting predicts future values across a network of spatial nodes (e.g., traffic sensors), leveraging both temporal patterns and spatial correlations between nodes. Unlike standard time series forecasting, STF models must capture inter-node dependencies (e.g., traffic at nearby sensors is correlated). Key design choices include:
- Spatial modeling: learnable node embeddings, graph convolutions, spatial attention
- Temporal modeling: RNNs, temporal convolutions, Transformers
- Spatial-temporal fusion: how to combine spatial and temporal information
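To make the three design choices above concrete, here is a minimal sketch in plain PyTorch. It is not one of the provided reference implementations; the class name `TinySTFusion` and the sizes `hidden` and `node_dim` are hypothetical. It assumes a learnable embedding per sensor for the spatial part, a linear projection over the input window for the temporal part, and concatenation followed by a small MLP for fusion.

```python
import torch
import torch.nn as nn


class TinySTFusion(nn.Module):
    """Hypothetical sketch: node embeddings + linear temporal projection, fused by concatenation."""

    def __init__(self, num_nodes: int, input_len: int, output_len: int,
                 hidden: int = 32, node_dim: int = 16):
        super().__init__()
        # Spatial modeling: one learnable vector per sensor (assumed design, not from the task files).
        self.node_emb = nn.Parameter(torch.empty(num_nodes, node_dim))
        nn.init.xavier_uniform_(self.node_emb)
        # Temporal modeling: project each node's history window to a hidden vector.
        self.temporal = nn.Linear(input_len, hidden)
        # Fusion: concatenate temporal and spatial features, then regress the whole horizon.
        self.head = nn.Sequential(
            nn.Linear(hidden + node_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_len),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # inputs: [batch_size, input_len, num_nodes]
        x = inputs.transpose(1, 2)                                  # [B, N, input_len]
        h = self.temporal(x)                                        # [B, N, hidden]
        e = self.node_emb.unsqueeze(0).expand(x.size(0), -1, -1)    # [B, N, node_dim]
        out = self.head(torch.cat([h, e], dim=-1))                  # [B, N, output_len]
        return out.transpose(1, 2)                                  # [B, output_len, N]


# Shape check with METR-LA-sized input (207 sensors, 12 input steps).
model = TinySTFusion(num_nodes=207, input_len=12, output_len=12)
prediction = model(torch.randn(4, 12, 207))
print(prediction.shape)
```

In this sketch the node embedding is the only thing that distinguishes one sensor from another, so it captures per-node identity but not explicit pairwise interactions; graph convolutions or spatial attention would be the natural replacements if inter-node mixing is needed.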
Model Interface
def forward(self, inputs: torch.Tensor, inputs_timestamps: torch.Tensor) -> torch.Tensor:
    """
    inputs: [batch_size, input_len, num_features]
        - input_len=12 (1 hour of 5-minute intervals)
        - num_features = number of spatial nodes (sensors)
    inputs_timestamps: [batch_size, input_len, 2]
        - channel 0: normalized time-of-day (0 to 1)
        - channel 1: normalized day-of-week (0 to 1)
    Returns: [batch_size, output_len, num_features]
        - output_len=12 (predict next 1 hour)
    """
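One way to use `inputs_timestamps` is to turn the two normalized channels into integer indices for embedding lookups. The sketch below is an illustration under assumptions, not part of the provided code: the class name `TimestampEmbedding` is hypothetical, 288 slots per day follows from the 5-minute interval, and the exact normalization (multiplying by 288 and by 7 to recover indices) is assumed and should be checked against the data pipeline.

```python
import torch
import torch.nn as nn


class TimestampEmbedding(nn.Module):
    """Hypothetical helper: embeds the time-of-day and day-of-week of the last input step."""

    def __init__(self, dim: int = 8, slots_per_day: int = 288, days_per_week: int = 7):
        super().__init__()
        self.slots_per_day = slots_per_day
        self.days_per_week = days_per_week
        self.tod = nn.Embedding(slots_per_day, dim)   # one vector per 5-minute slot
        self.dow = nn.Embedding(days_per_week, dim)   # one vector per weekday

    def forward(self, inputs_timestamps: torch.Tensor) -> torch.Tensor:
        # inputs_timestamps: [batch_size, input_len, 2]; only the last step is used here.
        last = inputs_timestamps[:, -1, :]            # [B, 2]
        # Assumed normalization: index / slots_per_day and index / days_per_week.
        # Rounding guards against float error; clamping keeps a value of 1.0 in range.
        tod_idx = torch.round(last[:, 0] * self.slots_per_day).long()
        tod_idx = tod_idx.clamp(0, self.slots_per_day - 1)
        dow_idx = torch.round(last[:, 1] * self.days_per_week).long()
        dow_idx = dow_idx.clamp(0, self.days_per_week - 1)
        return torch.cat([self.tod(tod_idx), self.dow(dow_idx)], dim=-1)  # [B, 2 * dim]


# Shape check with random normalized timestamps.
embedder = TimestampEmbedding(dim=8)
time_features = embedder(torch.rand(4, 12, 2))
print(time_features.shape)
```

The resulting `[batch_size, 2 * dim]` vector can be broadcast across nodes and concatenated with per-node features before the prediction head.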
Evaluation
Trained and evaluated on three traffic datasets:
- METR-LA (207 sensors, traffic speed, Los Angeles highway)
- PEMS-BAY (325 sensors, traffic speed, San Francisco Bay Area)
- PEMS04 (307 sensors, traffic flow, California district 4)
All three datasets use input_len=12 and output_len=12. Metrics: MAE, RMSE, MAPE (lower is better). Data is Z-score normalized, and metrics are computed on the original scale after the inverse transform. Missing values (recorded as 0.0) are masked during loss computation.
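The evaluation recipe described here (inverse Z-score transform, then a masked metric) can be sketched as follows. This is an illustrative version, not the training script's actual implementation; the function names `inverse_zscore` and `masked_mae` and the `null_val` parameter are chosen for this example, and the mean and standard deviation are made-up numbers.

```python
import torch


def inverse_zscore(x: torch.Tensor, mean: float, std: float) -> torch.Tensor:
    """Undo Z-score normalization: x_raw = x_norm * std + mean."""
    return x * std + mean


def masked_mae(pred: torch.Tensor, target: torch.Tensor, null_val: float = 0.0) -> torch.Tensor:
    """MAE over entries whose raw target is not the missing-value marker."""
    mask = (target != null_val).float()
    # Guard against division by zero when every target entry is missing.
    denom = mask.sum().clamp(min=1.0)
    return (torch.abs(pred - target) * mask).sum() / denom


# Toy example: the middle reading is missing (0.0) and is excluded from the metric.
mean, std = 54.0, 4.0                                # made-up dataset statistics
pred_norm = torch.tensor([[1.0, 0.0, -0.5]])         # model output in normalized space
target_raw = torch.tensor([[60.0, 0.0, 50.0]])       # ground truth on the original scale

pred_raw = inverse_zscore(pred_norm, mean, std)      # [[58., 54., 52.]]
mae = masked_mae(pred_raw, target_raw)               # (|58-60| + |52-50|) / 2 = 2.0
print(float(mae))
```

Masking on the raw-scale target matters: after normalization a missing reading is no longer 0.0, so the mask has to be built before or independently of the Z-score transform.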
Available Modules
You may import and use components from basicts.modules:
- basicts.modules.mlps: MLP layers (MLPLayer, ResMLPLayer)
- basicts.modules.norm: Normalization (RevIN, LayerNorm)
- basicts.modules.embed: Sequence embeddings
- basicts.modules.transformer: Transformer components (Encoder, MultiHeadAttention)
- basicts.modules.activations: Activation functions
Code
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import Optional

from basicts.configs import BasicTSModelConfig


@dataclass
class CustomConfig(BasicTSModelConfig):
    """Configuration for the Custom spatial-temporal forecasting model.

    Required fields (set by training script):
        input_len: Length of input historical sequence.
        output_len: Length of output prediction sequence.
Additional context files (read-only):
- BasicTS/src/basicts/modules/mlps.py
- BasicTS/src/basicts/modules/embed/__init__.py
Results
| Model | Type | mae METR-LA ↓ | rmse METR-LA ↓ | mape METR-LA ↓ | mae PEMS-BAY ↓ | rmse PEMS-BAY ↓ | mape PEMS-BAY ↓ | mae PEMS04 ↓ | rmse PEMS04 ↓ | mape PEMS04 ↓ |
|---|---|---|---|---|---|---|---|---|---|---|
| dlinear | baseline | 4.062 | 7.893 | 0.112 | 2.133 | 4.992 | 0.047 | 28.471 | 44.672 | 0.335 |
| itransformer | baseline | 3.912 | 7.785 | 0.109 | 2.021 | 4.844 | 0.044 | 25.719 | 40.899 | 0.276 |
| softs | baseline | 3.910 | 7.820 | 0.109 | 1.985 | 4.777 | 0.044 | 25.920 | 41.204 | 0.274 |
| stemgnn | baseline | 4.012 | 7.804 | 0.114 | 2.077 | 4.696 | 0.047 | 26.237 | 40.325 | 0.262 |
| stid | baseline | 3.219 | 6.554 | 0.093 | 1.664 | 3.758 | 0.038 | 19.811 | 31.415 | 0.153 |
| timemixer | baseline | 3.950 | 7.886 | 0.109 | 2.064 | 4.989 | 0.045 | 27.751 | 43.212 | 0.283 |
| timesnet | baseline | 3.896 | 7.638 | 0.108 | 2.072 | 4.897 | 0.047 | 22.160 | 35.045 | 0.209 |