ai4sci-weather-forecast-aggregation
Description
Weather Forecast Variable Aggregation
Research Question
How should a weather forecasting model aggregate information across heterogeneous meteorological variables for optimal prediction?
Background
Modern weather forecasting models process many meteorological variables simultaneously (e.g., temperature, geopotential, wind, and humidity at multiple pressure levels). ClimaX (Nguyen et al., ICML 2023) tokenizes each variable independently with per-variable patch embeddings, then aggregates the resulting tokens into a unified spatial representation before feeding it into a Vision Transformer backbone. The default aggregation uses a learnable query that cross-attends over the variable tokens at each spatial location, but this is only one design choice; other aggregation strategies may capture inter-variable correlations more effectively.
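The default learnable-query cross-attention can be sketched as follows. This is a hypothetical re-implementation written for illustration, not code copied from ClimaX's `arch.py`, and the class name `CrossAttentionAggregator` is made up. It folds batch and location together so that each of the B·L spatial locations becomes a short sequence of V variable tokens, and one shared learnable query attends over that sequence with `nn.MultiheadAttention`.

```python
import torch
import torch.nn as nn


class CrossAttentionAggregator(nn.Module):
    """Hypothetical learnable-query cross-attention over variables (illustrative only)."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        # One shared query vector, broadcast to every (batch, location) pair.
        self.query = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, v, l, d = x.shape
        # [B, V, L, D] -> [B, L, V, D] -> [B*L, V, D]: one sequence of V tokens per location.
        tokens = x.permute(0, 2, 1, 3).reshape(b * l, v, d)
        query = self.query.expand(b * l, -1, -1)  # [B*L, 1, D]
        out, _ = self.attn(query, tokens, tokens)  # [B*L, 1, D]
        return out.reshape(b, l, d)  # [B, L, D]


if __name__ == "__main__":
    agg = CrossAttentionAggregator(embed_dim=64, num_heads=4)
    x = torch.randn(2, 5, 8, 64)  # small B, V, L, D for a quick shape check
    print(agg(x).shape)  # torch.Size([2, 8, 64])
```

Neither the query nor the keys carry any ordering information in this sketch, so the output is invariant to the order of the variables; variable identity has to be encoded in the tokens themselves. In ClimaX a learned per-variable embedding is added to the tokens before aggregation; whether the FIXED section here does the same is not stated.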
Task
Modify the VariableAggregator class in custom_forecast.py to implement a novel variable aggregation mechanism. The module receives per-variable patch embeddings and must produce a single aggregated representation per spatial location.
Interface
```python
class VariableAggregator(nn.Module):
    def __init__(self, embed_dim, num_heads, num_vars):
        """
        Args:
            embed_dim (int): Embedding dimension D (1024).
            num_heads (int): Number of attention heads (16).
            num_vars (int): Number of input variables V (48).
        """
        ...

    def forward(self, x):
        """
        Args:
            x: [B, V, L, D], per-variable patch embeddings
                B = batch size
                V = number of meteorological variables (48)
                L = number of spatial patches (512 = 16x32)
                D = embedding dimension (1024)
        Returns:
            [B, L, D]: aggregated representation per spatial location
        """
        ...
```
The input contains 48 variables: 3 surface constants (land-sea mask, orography, latitude), 3 surface fields (2m temperature, 10m wind u/v), and 42 pressure-level fields (6 variables, namely geopotential, u/v wind, temperature, and relative/specific humidity, each at 7 pressure levels between 50 and 925 hPa). Each variable has been independently tokenized into L=512 patch embeddings of dimension D=1024.
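A structure-aware aggregator (for example, one that first pools within each variable family across pressure levels) needs to know which channel index holds which variable. The sketch below builds such a layout under explicit assumptions: the variable names, the family order, and the level list `[50, 250, 500, 600, 700, 850, 925]` are hypothetical choices consistent with the counts above (42 = 6 × 7), not identifiers taken from ClimaX or this task's code. It also checks that a 5.625-degree grid (32 × 64) with an assumed 2 × 2 patch yields L = 16 × 32 = 512.

```python
# Hypothetical channel layout: names, family order, and levels are assumptions.
CONSTANTS = ["land_sea_mask", "orography", "latitude"]
SURFACE = ["2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind"]
ATMOS = ["geopotential", "u_component_of_wind", "v_component_of_wind",
         "temperature", "relative_humidity", "specific_humidity"]
LEVELS_HPA = [50, 250, 500, 600, 700, 850, 925]  # assumed 7 levels spanning 50-925 hPa


def build_variable_list():
    """Return 48 channel names: 3 constants, 3 surface fields, 6 x 7 pressure-level fields."""
    return CONSTANTS + SURFACE + [f"{var}_{lvl}" for var in ATMOS for lvl in LEVELS_HPA]


def group_indices(names):
    """Map each variable family to its channel indices (e.g. for grouped aggregation)."""
    groups = {"constant": [], "surface": []}
    for i, name in enumerate(names):
        if name in CONSTANTS:
            groups["constant"].append(i)
        elif name in SURFACE:
            groups["surface"].append(i)
        else:
            family = name.rsplit("_", 1)[0]  # drop the pressure-level suffix
            groups.setdefault(family, []).append(i)
    return groups


# A 5.625-degree global grid is 32 x 64; a 2 x 2 patch (assumed) gives 16 x 32 patches.
GRID_H, GRID_W = int(180 / 5.625), int(360 / 5.625)
PATCH = 2
NUM_PATCHES = (GRID_H // PATCH) * (GRID_W // PATCH)

if __name__ == "__main__":
    names = build_variable_list()
    print(len(names), NUM_PATCHES)  # 48 512
```

The resulting index lists can be used to gather, for example, the 7 temperature channels out of `x[:, idx]` before pooling within a family.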
Available Components
You have access to standard PyTorch modules (nn.Linear, nn.MultiheadAttention, nn.LayerNorm, etc.) and torch.nn.functional. The FIXED section imports torch, torch.nn, and torch.nn.functional as F.
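The smallest implementation that satisfies the interface simply averages over the variable axis. The sketch below is written in the spirit of the `mean_pooling` baseline listed in the Results table, but it is a hypothetical reconstruction rather than that baseline's actual code. It accepts `num_heads` only to match the required constructor signature and uses small sizes in the usage check.

```python
import torch
import torch.nn as nn


class VariableAggregator(nn.Module):
    """Minimal sketch that satisfies the interface by averaging over variables."""

    def __init__(self, embed_dim, num_heads, num_vars):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_vars = num_vars
        # num_heads is accepted only to match the required signature.

    def forward(self, x):
        # x: [B, V, L, D] -> mean over the variable axis V -> [B, L, D]
        if x.dim() != 4 or x.shape[1] != self.num_vars or x.shape[3] != self.embed_dim:
            raise ValueError(
                f"expected [B, {self.num_vars}, L, {self.embed_dim}], got {tuple(x.shape)}"
            )
        return x.mean(dim=1)


if __name__ == "__main__":
    agg = VariableAggregator(embed_dim=64, num_heads=4, num_vars=48)
    out = agg(torch.randn(2, 48, 8, 64))  # small L and D for a quick check
    print(out.shape)  # torch.Size([2, 8, 64])
```

This module has no learnable parameters, so every variable contributes equally at every location; any richer aggregator has to replace the `x.mean(dim=1)` step while keeping the same input and output shapes.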
Evaluation
The model is fine-tuned from pretrained ClimaX weights on ERA5 reanalysis data at 5.625-degree resolution and evaluated on three forecasting targets:
- z500-3day: Geopotential at 500 hPa, 3-day lead time
- t850-5day: Temperature at 850 hPa, 5-day lead time
- wind10m-7day: 10m wind (reported in the Results table as the 10m u component of wind), 7-day lead time
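Metric: Latitude-weighted RMSE (lower is better). Because the cells of a regular latitude-longitude grid shrink toward the poles as the meridians converge, squared errors are weighted by the cosine of latitude (normalized to have unit mean over latitudes), so that each grid cell contributes roughly in proportion to its area.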
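The weighting can be made concrete with a short sketch. It follows the formulation described in the ClimaX paper (cosine-of-latitude weights normalized to mean 1, a weighted spatial mean per sample, a square root, then an average over samples); it is not guaranteed to match `ClimaX/src/climax/utils/metrics.py` line for line. The function name `lat_weighted_rmse` and the cell-centred latitude grid in the usage example are assumptions.

```python
import numpy as np
import torch


def lat_weighted_rmse(pred: torch.Tensor, target: torch.Tensor, lat_deg: np.ndarray) -> torch.Tensor:
    """Latitude-weighted RMSE for one variable.

    pred, target: [B, H, W] fields in physical units; lat_deg: [H] row latitudes in degrees.
    """
    w = np.cos(np.deg2rad(lat_deg))
    w = w / w.mean()  # normalize so the latitude weights average to 1
    w = torch.as_tensor(w, dtype=pred.dtype, device=pred.device).view(1, -1, 1)
    sq_err = (pred - target) ** 2
    # Weighted spatial mean per sample, square root, then average over samples.
    return torch.sqrt((sq_err * w).mean(dim=(-2, -1))).mean()


if __name__ == "__main__":
    # Assumed cell-centred latitudes for a 32 x 64 grid at 5.625 degrees.
    lat = np.linspace(-87.1875, 87.1875, 32)
    target = torch.zeros(4, 32, 64)
    pred = target + 2.0  # uniform error of 2 everywhere
    print(lat_weighted_rmse(pred, target, lat))  # approximately 2.0, since weights average to 1
```

Because the weights average to 1, a spatially uniform error is reported unchanged, while the same error confined to a polar row counts far less than one on an equatorial row.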
Code
1"""Custom Weather Forecast Variable Aggregation Script2Based on ClimaX (Nguyen et al., 2023), evaluated on ERA5 at 5.625 deg.34The EDITABLE section contains the variable aggregation module that combines5per-variable patch embeddings into a unified spatial representation.6Everything else (ViT backbone, data loading, training loop) is FIXED.7"""89import math10import os11import time12from functools import lru_cache1314import numpy as np15import torch
Additional context files (read-only):
- ClimaX/src/climax/arch.py
- ClimaX/src/climax/parallelpatchembed.py
- ClimaX/src/climax/utils/metrics.py
Results
| Model | Type | z500-3day: lat-weighted RMSE, geopotential at 500 hPa (m²/s²) ↓ | t850-5day: lat-weighted RMSE, temperature at 850 hPa (K) ↓ | wind10m-7day: lat-weighted RMSE, 10m u component of wind (m/s) ↓ |
|---|---|---|---|---|
| cross_attention | baseline | 245.130 | 3.974 | 4.418 |
| learned_weighted_sum | baseline | 485.598 | 4.558 | 4.648 |
| mean_pooling | baseline | 490.765 | 4.528 | 4.643 |
| self_attention | baseline | 325.445 | 4.122 | 4.491 |
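In the table, cross_attention has the lowest error on all three targets. The source does not describe how the learned_weighted_sum and self_attention baselines are built, so the classes below are hypothetical readings of those names, offered only to make the comparison concrete: `LearnedWeightedSum` learns one softmax-normalized scalar per variable, and `SelfAttentionPool` lets the V variable tokens at each location attend to each other (with a residual connection and `nn.LayerNorm`, both assumptions) and then mean-pools them.

```python
import torch
import torch.nn as nn


class LearnedWeightedSum(nn.Module):
    """Hypothetical reading of 'learned_weighted_sum': one learned weight per variable."""

    def __init__(self, num_vars: int):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(num_vars))  # uniform weights at init

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, V, L, D]
        w = torch.softmax(self.logits, dim=0)  # [V], sums to 1
        return torch.einsum("v,bvld->bld", w, x)  # [B, L, D]


class SelfAttentionPool(nn.Module):
    """Hypothetical reading of 'self_attention': variables attend to each other, then mean-pool."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, V, L, D]
        b, v, l, d = x.shape
        tokens = x.permute(0, 2, 1, 3).reshape(b * l, v, d)  # [B*L, V, D]
        mixed, _ = self.attn(tokens, tokens, tokens)  # self-attention across variables
        tokens = self.norm(tokens + mixed)  # residual connection + LayerNorm
        return tokens.mean(dim=1).reshape(b, l, d)  # pool over V -> [B, L, D]
```

With zero-initialized logits, `LearnedWeightedSum` starts out identical to mean pooling, and its weights are shared across all spatial locations; `SelfAttentionPool` instead mixes information between variables at each location before pooling.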