ai4sci-weather-forecast-aggregation

AI for Science | ClimaX | rigorous codebase

Description

Weather Forecast Variable Aggregation

Research Question

How should a weather forecasting model aggregate information across heterogeneous meteorological variables for optimal prediction?

Background

Modern weather forecasting models process many meteorological variables simultaneously (temperature, pressure, wind, and humidity at various pressure levels). ClimaX (Nguyen et al., ICML 2023) tokenizes each variable independently via per-variable patch embeddings, then aggregates the resulting tokens into a unified spatial representation before passing it to a Vision Transformer backbone. The default aggregation uses a learnable query that cross-attends over the variable tokens at each spatial location, but this is only one design choice; other aggregation strategies may capture inter-variable correlations more effectively.
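To make the default mechanism concrete, here is a minimal sketch of learnable-query cross-attention over variable tokens, written from the description above rather than copied from ClimaX. The class name CrossAttentionAggregator is hypothetical. Each spatial location is treated as an independent sequence of V tokens, and a single shared query attends over them with nn.MultiheadAttention.

```python
import torch
import torch.nn as nn


class CrossAttentionAggregator(nn.Module):
    """Sketch of learnable-query cross-attention over variables (hypothetical name)."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        # One learnable query shared by every spatial location. Zero init plus the
        # zero-initialized in_proj bias of nn.MultiheadAttention means attention
        # starts out uniform over variables.
        self.var_query = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.var_agg = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, v, l, d = x.shape
        # [B, V, L, D] -> [B*L, V, D]: each location becomes a sequence of V tokens.
        tokens = x.permute(0, 2, 1, 3).reshape(b * l, v, d)
        query = self.var_query.expand(b * l, -1, -1)  # [B*L, 1, D]
        out, _ = self.var_agg(query, tokens, tokens)  # [B*L, 1, D]
        # Drop the singleton query axis and restore [B, L, D].
        return out.squeeze(1).reshape(b, l, d)


if __name__ == "__main__":
    agg = CrossAttentionAggregator(embed_dim=64, num_heads=4)
    x = torch.randn(2, 48, 8, 64)  # [B, V, L, D] with small L and D for a quick check
    print(agg(x).shape)  # torch.Size([2, 8, 64])
```

Because one query attends over the V tokens without any ordering information inside the aggregator, the output is invariant to the order of variables; in ClimaX, variable identity is supplied by a learnable variable embedding added to the tokens beforehand (see ClimaX/src/climax/arch.py).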

Task

Modify the VariableAggregator class in custom_forecast.py to implement a novel variable aggregation mechanism. The module receives per-variable patch embeddings and must produce a single aggregated representation per spatial location.

Interface

class VariableAggregator(nn.Module):
    def __init__(self, embed_dim, num_heads, num_vars):
        """
        Args:
            embed_dim (int): Embedding dimension D (1024).
            num_heads (int): Number of attention heads (16).
            num_vars (int): Number of input variables V (48).
        """
        ...

    def forward(self, x):
        """
        Args:
            x: [B, V, L, D] — per-variable patch embeddings
                B = batch size
                V = number of meteorological variables (48)
                L = number of spatial patches (512 = 16x32)
                D = embedding dimension (1024)

        Returns:
            [B, L, D] — aggregated representation per spatial location
        """
        ...
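As a shape check for the interface, the simplest conforming module averages over the variable axis. This is a placeholder sketch, not a recommended design: it ignores num_heads and learns nothing, and it only illustrates the [B, V, L, D] to [B, L, D] contract.

```python
import torch
import torch.nn as nn


class VariableAggregator(nn.Module):
    """Mean-pooling placeholder that satisfies the interface (sketch only)."""

    def __init__(self, embed_dim, num_heads, num_vars):
        super().__init__()
        # num_heads is unused here; kept only to match the required signature.
        self.embed_dim = embed_dim
        self.num_vars = num_vars

    def forward(self, x):
        # x: [B, V, L, D]; check the expected layout before reducing.
        assert x.dim() == 4 and x.shape[1] == self.num_vars and x.shape[-1] == self.embed_dim
        # Average over the variable axis (dim=1) -> [B, L, D].
        return x.mean(dim=1)


if __name__ == "__main__":
    agg = VariableAggregator(embed_dim=16, num_heads=4, num_vars=48)  # small D for a quick check
    print(agg(torch.randn(2, 48, 512, 16)).shape)  # torch.Size([2, 512, 16])
```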

The input contains 48 variables: 3 surface constants (land-sea mask, orography, latitude), 3 surface fields (2m temperature, 10m u/v wind), and 42 pressure-level fields (6 quantities, namely geopotential, u/v wind, temperature, and relative/specific humidity, each at 7 pressure levels spanning 50 to 925 hPa). Each variable is tokenized independently into L=512 patch embeddings of dimension D=1024.
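The counts above can be checked with a few lines of arithmetic. In this sketch the variable names and the exact list of seven pressure levels are assumptions (the text only gives the 50 to 925 hPa range, and 42 = 6 × 7 implies seven levels), and the patch size of 2 is assumed because it turns the 32 × 64 grid of a 5.625-degree field into the stated 16 × 32 patch grid.

```python
# Hypothetical variable names; LEVELS_HPA is an assumed set of 7 levels in 50-925 hPa.
CONSTANTS = ["land_sea_mask", "orography", "latitude"]
SURFACE = ["2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind"]
PRESSURE_FIELDS = ["geopotential", "u_component_of_wind", "v_component_of_wind",
                   "temperature", "relative_humidity", "specific_humidity"]
LEVELS_HPA = [50, 250, 500, 600, 700, 850, 925]

VARIABLES = CONSTANTS + SURFACE + [f"{f}_{lvl}" for f in PRESSURE_FIELDS for lvl in LEVELS_HPA]

# 5.625-degree grid: 180 / 5.625 = 32 latitudes, 360 / 5.625 = 64 longitudes.
GRID_H, GRID_W = int(180 / 5.625), int(360 / 5.625)
PATCH = 2  # assumed patch size that yields the stated 16 x 32 patch grid
NUM_PATCHES = (GRID_H // PATCH) * (GRID_W // PATCH)

print(len(VARIABLES), GRID_H, GRID_W, NUM_PATCHES)  # 48 32 64 512
```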

Available Components

You have access to standard PyTorch modules (nn.Linear, nn.MultiheadAttention, nn.LayerNorm, etc.) and torch.nn.functional. The FIXED section imports torch, torch.nn (as nn), and torch.nn.functional (as F).

Evaluation

The model is fine-tuned from pretrained ClimaX weights on ERA5 reanalysis data at 5.625-degree resolution and evaluated on three forecasting targets:

  • z500-3day: Geopotential height at 500 hPa, 3-day lead time
  • t850-5day: Temperature at 850 hPa, 5-day lead time
  • wind10m-7day: 10m wind (the results table reports the u component), 7-day lead time
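The three targets can be summarized as a small configuration. This is a hypothetical sketch: ForecastTarget and the variable keys are illustrative names (the wind key follows the u-component column of the results table), and only the target names and lead times in days come from the list above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ForecastTarget:
    """Hypothetical container for one evaluation target."""
    name: str
    variable: str  # illustrative variable key, not taken from the codebase
    lead_days: int

    @property
    def lead_hours(self) -> int:
        return 24 * self.lead_days


TARGETS = [
    ForecastTarget("z500-3day", "geopotential_500", 3),
    ForecastTarget("t850-5day", "temperature_850", 5),
    ForecastTarget("wind10m-7day", "10m_u_component_of_wind", 7),
]

for t in TARGETS:
    print(t.name, t.variable, t.lead_hours)  # lead times of 72, 120 and 168 hours
```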

Metric: Latitude-weighted RMSE (lower is better). On a regular latitude-longitude grid, cells shrink toward the poles as meridians converge, so the metric weights each squared error by the cosine of its latitude.
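A minimal NumPy sketch of the metric follows. It assumes fields shaped [N, H, W] with H as the latitude axis, cosine weights normalized to mean 1, and RMSE taken per sample and then averaged over samples. These conventions are assumptions here; the authoritative version is in ClimaX/src/climax/utils/metrics.py, listed among the context files.

```python
import numpy as np


def lat_weighted_rmse(pred: np.ndarray, target: np.ndarray, lat_deg: np.ndarray) -> float:
    """Latitude-weighted RMSE for fields shaped [N, H, W] (H = latitudes)."""
    w = np.cos(np.deg2rad(lat_deg))
    w = w / w.mean()                      # assumed: weights normalized to mean 1
    sq_err = (pred - target) ** 2         # [N, H, W]
    weighted = sq_err * w[None, :, None]  # broadcast weights along the latitude axis
    per_sample = np.sqrt(weighted.mean(axis=(1, 2)))  # one RMSE per sample
    return float(per_sample.mean())       # assumed: average over samples


if __name__ == "__main__":
    lat = np.linspace(-87.1875, 87.1875, 32)  # assumed cell-centre latitudes at 5.625 degrees
    target = np.zeros((4, 32, 64))
    pred = target + 2.0                       # uniform error of 2.0
    print(lat_weighted_rmse(pred, target, lat))  # approximately 2.0
```

Normalizing the weights to mean 1 keeps the metric in the variable's own units: a spatially uniform error gives the same value as unweighted RMSE, while the same error concentrated near a pole counts for much less than near the equator.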

Code

custom_forecast.py
"""Custom Weather Forecast Variable Aggregation Script
Based on ClimaX (Nguyen et al., 2023), evaluated on ERA5 at 5.625 deg.

The EDITABLE section contains the variable aggregation module that combines
per-variable patch embeddings into a unified spatial representation.
Everything else (ViT backbone, data loading, training loop) is FIXED.
"""

import math
import os
import time
from functools import lru_cache

import numpy as np
import torch

Additional context files (read-only):

  • ClimaX/src/climax/arch.py
  • ClimaX/src/climax/parallelpatchembed.py
  • ClimaX/src/climax/utils/metrics.py

Results

Latitude-weighted RMSE (w rmse), lower is better:

Model                 Type      geopotential 500 (z500-3day)  temperature 850 (t850-5day)  10m u component of wind (wind10m-7day)
cross_attention       baseline  245.130                       3.974                        4.418
learned_weighted_sum  baseline  485.598                       4.558                        4.648
mean_pooling          baseline  490.765                       4.528                        4.643
self_attention        baseline  325.445                       4.122                        4.491
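The baseline implementations are not shown on this page. As a hedged illustration of what the learned_weighted_sum name suggests, the sketch below (hypothetical class LearnedWeightedSum, assuming one softmax-normalized scalar weight per variable) follows the VariableAggregator signature; it is not the benchmark's actual baseline code.

```python
import torch
import torch.nn as nn


class LearnedWeightedSum(nn.Module):
    """Hypothetical learned weighted sum over variables (one scalar per variable)."""

    def __init__(self, embed_dim: int, num_heads: int, num_vars: int):
        super().__init__()
        # embed_dim and num_heads are unused; kept to match the interface.
        self.logits = nn.Parameter(torch.zeros(num_vars))  # zeros -> uniform weights

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = torch.softmax(self.logits, dim=0)     # [V], non-negative, sums to 1
        return torch.einsum("v,bvld->bld", w, x)  # convex combination -> [B, L, D]


if __name__ == "__main__":
    agg = LearnedWeightedSum(embed_dim=16, num_heads=4, num_vars=48)
    print(agg(torch.randn(2, 48, 8, 16)).shape)  # torch.Size([2, 8, 16])
```

With zero-initialized logits this starts out identical to mean pooling, and because the weights are global rather than input-dependent it cannot reweight variables per location. That may be one reason the attention-based rows reach lower RMSE, though the table alone does not establish the cause.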