stf-traffic-forecast

Time Series | BasicTS | rigorous codebase

Description

Spatial-Temporal Traffic Forecasting: Custom Model Design

Objective

Design and implement a custom deep learning model for spatial-temporal traffic forecasting. Your code goes in custom_model.py, which must define both the Custom model class and the CustomConfig configuration class. Three reference implementations (STID, DLinear, StemGNN) are provided read-only.

Background

Spatial-temporal forecasting (STF) predicts future values across a network of spatial nodes (e.g., traffic sensors), leveraging both temporal patterns and spatial correlations between nodes. Unlike standard time series forecasting, STF models must capture inter-node dependencies (e.g., traffic at nearby sensors is correlated). Key design choices include:

  • Spatial modeling: learnable node embeddings, graph convolutions, spatial attention
  • Temporal modeling: RNNs, temporal convolutions, Transformers
  • Spatial-temporal fusion: how to combine spatial and temporal information
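As one concrete instance of the spatial-modeling bullet, the sketch below combines learnable node embeddings with a graph convolution over an adjacency matrix learned from those embeddings, A = softmax(relu(E1 E2^T)), the "adaptive adjacency" idea introduced by Graph WaveNet. The class name, sizes, and residual layout are illustrative assumptions, not part of the task or of basicts.modules.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveGraphConv(nn.Module):
    """Hypothetical spatial layer: one graph convolution over a learned adjacency.

    Two learnable node-embedding tables define A = softmax(relu(E1 @ E2^T)),
    so node-to-node weights are learned from data instead of a fixed road graph.
    """

    def __init__(self, num_nodes: int, embed_dim: int, hidden_dim: int):
        super().__init__()
        self.src_emb = nn.Parameter(torch.randn(num_nodes, embed_dim))
        self.dst_emb = nn.Parameter(torch.randn(num_nodes, embed_dim))
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def adjacency(self) -> torch.Tensor:
        # [N, N]; softmax over dim=1 makes each row a distribution over neighbours
        return F.softmax(F.relu(self.src_emb @ self.dst_emb.T), dim=1)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [batch, num_nodes, hidden_dim] -> same shape
        a = self.adjacency()
        neighbours = torch.einsum("nm,bmd->bnd", a, h)  # weighted sum of other nodes' features
        return h + self.proj(neighbours)  # residual keeps each node's own signal
```

Placing this layer after a temporal encoder (for example an MLP over each node's 12-step window) is one simple way to realize the fusion bullet; attention-based or recurrent alternatives are equally valid choices.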

Model Interface

def forward(self, inputs: torch.Tensor, inputs_timestamps: torch.Tensor) -> torch.Tensor:
    """
    inputs: [batch_size, input_len, num_features]
        - input_len=12 (1 hour of 5-minute intervals)
        - num_features = number of spatial nodes (sensors)
    inputs_timestamps: [batch_size, input_len, 2]
        - channel 0: normalized time-of-day (0 to 1)
        - channel 1: normalized day-of-week (0 to 1)
    Returns: [batch_size, output_len, num_features]
        - output_len=12 (predict next 1 hour)
    """

Evaluation

Trained and evaluated on three traffic datasets:

  • METR-LA (207 sensors, traffic speed, Los Angeles highway)
  • PEMS-BAY (325 sensors, traffic speed, San Francisco Bay Area)
  • PEMS04 (307 sensors, traffic flow, California district 4)

All three datasets use input_len=12 and output_len=12. Metrics are MAE, RMSE, and MAPE (lower is better). Data are Z-score normalized, and metrics are computed after the inverse transform back to the original scale. Missing values (recorded as 0.0) are masked during loss computation.
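The evaluation steps described above can be sketched as follows: predictions are mapped back to the original scale with x = z * std + mean, entries whose ground truth equals the 0.0 missing-value marker are dropped, and MAE, RMSE, and MAPE are averaged over the remaining entries. This is a minimal illustration, not BasicTS's metric code; the function names and the `atol` tolerance are assumptions, and applying the same mask to all three metrics is assumed rather than stated in the task.

```python
import torch


def inverse_zscore(x_norm: torch.Tensor, mean, std) -> torch.Tensor:
    # Undo Z-score normalization; mean/std may be scalars or broadcastable per-node tensors.
    return x_norm * std + mean


def masked_metrics(pred: torch.Tensor, target: torch.Tensor,
                   null_val: float = 0.0, atol: float = 1e-4):
    """Return (MAE, RMSE, MAPE) over entries whose target is not `null_val`.

    Both tensors must already be on the original (inverse-transformed) scale.
    If every entry is missing, the means are NaN.
    """
    missing = torch.isclose(target, torch.full_like(target, null_val), rtol=0.0, atol=atol)
    valid = ~missing
    err = pred[valid] - target[valid]
    mae = err.abs().mean()
    rmse = err.pow(2).mean().sqrt()
    # With null_val=0.0 the mask has already removed zero targets, so the division is safe.
    mape = (err.abs() / target[valid].abs()).mean()
    return mae.item(), rmse.item(), mape.item()
```

For example, with mean=10 and std=5, normalized predictions [-1.0, 0.4, 1.6] become [5, 12, 18]; against targets [0, 10, 20] the first entry is masked, giving MAE 2, RMSE 2, and MAPE 0.15.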

Available Modules

You may import and use components from basicts.modules:

  • basicts.modules.mlps: MLP layers (MLPLayer, ResMLPLayer)
  • basicts.modules.norm: Normalization (RevIN, LayerNorm)
  • basicts.modules.embed: Sequence embeddings
  • basicts.modules.transformer: Transformer components (Encoder, MultiHeadAttention)
  • basicts.modules.activations: Activation functions
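basicts.modules.norm lists RevIN, but its constructor and call signature are not shown in this task, so the following is a standalone sketch of the technique itself (reversible instance normalization): each sample's input window is normalized per node with its own mean and standard deviation before the backbone runs, and the backbone's output is mapped back with the same statistics. `RevINSketch` and its methods are hypothetical; check the actual module before relying on it.

```python
import torch
import torch.nn as nn


class RevINSketch(nn.Module):
    """Reversible instance normalization over the time axis of [B, L, N] tensors."""

    def __init__(self, num_nodes: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        self.eps = eps
        self.affine = affine
        if affine:
            # Learnable per-node scale and shift applied after normalization.
            self.weight = nn.Parameter(torch.ones(num_nodes))
            self.bias = nn.Parameter(torch.zeros(num_nodes))

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics per sample and per node, computed over the input window (dim=1).
        self.mean = x.mean(dim=1, keepdim=True).detach()
        var = ((x - self.mean) ** 2).mean(dim=1, keepdim=True)
        self.std = (var + self.eps).sqrt().detach()
        x = (x - self.mean) / self.std
        if self.affine:
            x = x * self.weight + self.bias
        return x

    def denormalize(self, y: torch.Tensor) -> torch.Tensor:
        # y: [B, output_len, N]; reuses the statistics stored by normalize().
        if self.affine:
            y = (y - self.bias) / (self.weight + self.eps ** 2)
        return y * self.std + self.mean
```

Typical use wraps a backbone: `y = revin.denormalize(backbone(revin.normalize(inputs)))`. Since the data are already globally Z-score normalized, RevIN adds a second, per-window normalization, which mainly helps when the level of a series shifts between windows.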

Code

custom_model.py
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import Optional

from basicts.configs import BasicTSModelConfig


@dataclass
class CustomConfig(BasicTSModelConfig):
    """Configuration for the Custom spatial-temporal forecasting model.

    Required fields (set by training script):
        input_len: Length of input historical sequence.
        output_len: Length of output prediction sequence.

Additional context files (read-only):

  • BasicTS/src/basicts/modules/mlps.py
  • BasicTS/src/basicts/modules/embed/__init__.py

Results

Model         Type      | METR-LA                | PEMS-BAY               | PEMS04
                        | MAE    RMSE   MAPE     | MAE    RMSE   MAPE     | MAE     RMSE    MAPE
dlinear       baseline  | 4.062  7.893  0.112    | 2.133  4.992  0.047    | 28.471  44.672  0.335
itransformer  baseline  | 3.912  7.785  0.109    | 2.021  4.844  0.044    | 25.719  40.899  0.276
softs         baseline  | 3.910  7.820  0.109    | 1.985  4.777  0.044    | 25.920  41.204  0.274
stemgnn       baseline  | 4.012  7.804  0.114    | 2.077  4.696  0.047    | 26.237  40.325  0.262
stid          baseline  | 3.219  6.554  0.093    | 1.664  3.758  0.038    | 19.811  31.415  0.153
timemixer     baseline  | 3.950  7.886  0.109    | 2.064  4.989  0.045    | 27.751  43.212  0.283
timesnet      baseline  | 3.896  7.638  0.108    | 2.072  4.897  0.047    | 22.160  35.045  0.209
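To check whether a custom model clears the bar set by these baselines, the numbers can be transcribed into a small script that reports the lowest value per column (lower is better for all three metrics). The dictionary copies the results table; `best_per_column` and `beats_all_baselines` are hypothetical helpers, not part of BasicTS.

```python
COLUMNS = [
    "METR-LA MAE", "METR-LA RMSE", "METR-LA MAPE",
    "PEMS-BAY MAE", "PEMS-BAY RMSE", "PEMS-BAY MAPE",
    "PEMS04 MAE", "PEMS04 RMSE", "PEMS04 MAPE",
]

# Baseline results transcribed from the results table, in COLUMNS order.
RESULTS = {
    "dlinear":      [4.062, 7.893, 0.112, 2.133, 4.992, 0.047, 28.471, 44.672, 0.335],
    "itransformer": [3.912, 7.785, 0.109, 2.021, 4.844, 0.044, 25.719, 40.899, 0.276],
    "softs":        [3.910, 7.820, 0.109, 1.985, 4.777, 0.044, 25.920, 41.204, 0.274],
    "stemgnn":      [4.012, 7.804, 0.114, 2.077, 4.696, 0.047, 26.237, 40.325, 0.262],
    "stid":         [3.219, 6.554, 0.093, 1.664, 3.758, 0.038, 19.811, 31.415, 0.153],
    "timemixer":    [3.950, 7.886, 0.109, 2.064, 4.989, 0.045, 27.751, 43.212, 0.283],
    "timesnet":     [3.896, 7.638, 0.108, 2.072, 4.897, 0.047, 22.160, 35.045, 0.209],
}


def best_per_column(results):
    """Return {column: (model, value)} holding the lowest (best) value of each column."""
    best = {}
    for i, col in enumerate(COLUMNS):
        model = min(results, key=lambda m: results[m][i])
        best[col] = (model, results[model][i])
    return best


def beats_all_baselines(candidate, results):
    """True if `candidate` (9 values in COLUMNS order) is strictly lower in every column."""
    best = best_per_column(results)
    return all(value < best[col][1] for value, col in zip(candidate, COLUMNS))
```

On these numbers STID holds the best value in every column, so beating all baselines amounts to beating STID on all nine metrics.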