ttt-memory
Description
Test-Time Memory Management Strategy for Titans
Objective
Design a better test-time memory management strategy for the Titans neural memory module. The memory is a small MLP that learns at inference time — for each input token, it decides what to store, measures how surprising the input is, and updates its parameters accordingly. Your goal is to minimize test MSE on time series forecasting tasks by improving how the memory learns.
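The per-token loop described above can be sketched as follows. This is a minimal illustration, not the NeuralMemory implementation. It assumes a two-layer MLP memory, fixed linear key/value projections, and a plain SGD step. The names `memory_mlp`, `W_K`, `W_V`, `LR`, and `write()` are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
emb_dim = 8

# Hypothetical memory: a small MLP whose weights are updated during inference.
memory_mlp = nn.Sequential(nn.Linear(emb_dim, 16), nn.SiLU(), nn.Linear(16, emb_dim))
# Hypothetical key/value projections, kept fixed in this sketch.
W_K = nn.Linear(emb_dim, emb_dim, bias=False)
W_V = nn.Linear(emb_dim, emb_dim, bias=False)
LR = 0.1


def write(x: torch.Tensor) -> float:
    """Store one token in memory and return its surprise (loss before the update)."""
    with torch.no_grad():
        k, v = W_K(x), W_V(x)  # what to store: a key -> value association
    loss = ((memory_mlp(k) - v) ** 2).mean()  # surprise: recall error for this pair
    grads = torch.autograd.grad(loss, list(memory_mlp.parameters()))
    with torch.no_grad():
        for p, g in zip(memory_mlp.parameters(), grads):
            p.sub_(LR * g)  # plain SGD step on the memory's own weights
    return loss.item()


# Write the same token repeatedly; its surprise shrinks as it is memorized.
x = torch.randn(emb_dim)
surprises = [write(x) for _ in range(100)]
print(f"first surprise {surprises[0]:.4f}, last surprise {surprises[-1]:.4f}")
```

Here the surprise is reported before each update, so the first value measures how novel the association was when it first arrived.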
Background
Titans ("Titans: Learning to Memorize at Test Time") uses a neural memory module (NeuralMemory) within the MAC (Memory as a Context) architecture. Its memory management strategy has four components:
- init_memory_state() — Initialize projections and optimizer state (K/V linear layers, learning rates, momentum buffers, etc.)
- compute_keys_values() — Transform input into key-value pairs that determine what gets stored in memory
- compute_surprise() — Measure how novel/surprising the current input is (the loss signal that drives updates)
- apply_update() — Update the memory's parameters based on gradients (the optimization rule)
The default uses MSE-based surprise with SGD+momentum and weight decay. You can modify any or all of these components.
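For reference, the Titans paper formulates the memory update as gradient descent with momentum plus a forgetting gate. The following assumes the default strategy here follows that formulation, with weight decay acting as the forgetting gate. Under that assumption, one step for token $x_t$ reads:

$$
\begin{aligned}
k_t &= x_t W_K, \qquad v_t = x_t W_V \\
\ell(\mathcal{M}_{t-1}; x_t) &= \lVert \mathcal{M}_{t-1}(k_t) - v_t \rVert_2^2 \\
S_t &= \eta_t S_{t-1} - \theta_t \nabla \ell(\mathcal{M}_{t-1}; x_t) \\
\mathcal{M}_t &= (1 - \alpha_t)\,\mathcal{M}_{t-1} + S_t
\end{aligned}
$$

Here $\theta_t$ scales the momentary surprise (the gradient), $\eta_t$ decays past surprise (the momentum buffer), and $\alpha_t \in [0, 1]$ controls how much of the old memory is forgotten. Under this reading, compute_keys_values() produces $k_t$ and $v_t$, compute_surprise() produces $\ell$, and apply_update() implements the last two lines.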
What You Can Modify
Edit the four functions in custom_memory.py (lines 14-98). You may:
- Change the key/value computation (different projections, normalization, etc.)
- Design a new surprise metric (contrastive loss, cosine similarity, information-theoretic measures, etc.)
- Implement a different optimizer (Adam, RMSprop, natural gradient, etc.)
- Add gating mechanisms (only update when surprise exceeds a threshold)
- Change hyperparameters (learning rate, momentum, weight decay)
- Add auxiliary state to the memory object in init_memory_state()
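The following standalone sketch combines two of these options: a different optimizer and a surprise gate. It implements an Adam-style step with decoupled weight decay and skips the step when the surprise is at or below a threshold. The function name `gated_adam_update`, its arguments, and the default hyperparameters are hypothetical. How it plugs into apply_update() depends on that function's signature in custom_memory.py, which is not shown here.

```python
import torch


def gated_adam_update(params, grads, state, surprise, lr=1e-2, betas=(0.9, 0.999),
                      eps=1e-8, weight_decay=1e-3, threshold=0.0):
    """Adam step with decoupled weight decay, skipped when surprise <= threshold.

    `state` is a dict that stores per-parameter first/second moments and a step count.
    Returns True if an update was applied.
    """
    if surprise <= threshold:
        return False  # gate: input is not surprising enough to be worth storing
    beta1, beta2 = betas
    if "m" not in state:
        state["step"] = 0
        state["m"] = [torch.zeros_like(p) for p in params]
        state["v"] = [torch.zeros_like(p) for p in params]
    state["step"] += 1
    t = state["step"]
    with torch.no_grad():
        for p, g, m, v in zip(params, grads, state["m"], state["v"]):
            m.mul_(beta1).add_(g, alpha=1 - beta1)         # first moment (momentum)
            v.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # second moment
            m_hat = m / (1 - beta1 ** t)                   # bias correction
            v_hat = v / (1 - beta2 ** t)
            p.mul_(1 - lr * weight_decay)                  # decoupled weight decay (forgetting)
            p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)
    return True


# Usage: fit a toy parameter vector, letting the gate stop updates once loss is tiny.
target = torch.tensor([1.0, -2.0, 0.5, 3.0])
w = torch.zeros(4, requires_grad=True)
state = {}
applied = []
for _ in range(500):
    loss = ((w - target) ** 2).mean()
    (g,) = torch.autograd.grad(loss, [w])
    applied.append(gated_adam_update([w], [g], state, loss.item(), lr=0.05, threshold=1e-4))
final_loss = ((w - target) ** 2).mean().item()
print(f"final loss {final_loss:.2e}, updates applied: {sum(applied)} of {len(applied)}")
```

Decoupled weight decay (as in AdamW) keeps the forgetting term separate from the adaptive step. As a result, the second-moment estimate does not rescale the decay rate.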
Evaluation
The strategy is evaluated on three benchmarks (lower MSE is better):
- sinwave: Synthetic sine wave forecasting (context_window=16, 20 epochs)
- weather: Real-world weather data forecasting (context_window=16, 20 epochs)
- sinwave-long: Longer-context sine wave forecasting (context_window=32, 30 epochs) — tests memory capacity
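To make the context_window setting concrete, the following hypothetical illustration turns a 1-D series into (context, next value) pairs and scores a forecast with MSE. It assumes one-step-ahead targets and uses a naive last-value forecaster. The actual data pipeline lives in titans-lmm/train_utils.py and may differ. The helper `make_windows` is hypothetical.

```python
import math

import torch


def make_windows(series: torch.Tensor, context_window: int):
    """Split a 1-D series into (context, next value) pairs (one-step-ahead assumption)."""
    contexts = series.unfold(0, context_window, 1)[:-1]  # each row: context_window consecutive values
    targets = series[context_window:]                    # the value right after each context
    return contexts, targets


t = torch.linspace(0, 8 * math.pi, 400)
series = torch.sin(t)
contexts, targets = make_windows(series, context_window=16)

# Naive baseline: predict the last value of each context window.
preds = contexts[:, -1]
mse = torch.mean((preds - targets) ** 2).item()
print(contexts.shape, targets.shape, f"last-value MSE: {mse:.5f}")
```

Doubling context_window to 32, as in sinwave-long, gives the memory twice as many tokens to store per sample.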
Constraints
- You cannot modify files outside of custom_memory.py
- The four function signatures must remain the same
- The memory module must remain compatible with the MACTitan architecture
- No new file creation allowed
Code
```python
"""Custom test-time memory management strategy.

Four functions controlling how the neural memory learns at inference time:
1. init_memory_state() — set up projections and optimizer state
2. compute_keys_values() — what to store
3. compute_surprise() — when/how much to update
4. apply_update() — how to update parameters
"""
# ── EDITABLE REGION START ──
import torch
from torch import nn
from torch.nn.functional import normalize

def init_memory_state(memory, emb_dim):
    """Initialize key-value projections and any optimizer state.
```
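The excerpt stops at the docstring of init_memory_state(). The sketch below shows one hypothetical way to fill it in, together with a key/value helper that uses the imported `normalize`. It is not the file's actual body. The attribute names (`to_keys`, `to_values`, `lr`, `momentum`, `weight_decay`, `momentum_buffers`) and the assumption that the memory MLP lives at `memory.mlp` are illustrative. NeuralMemory may expect different names. The helper `keys_values_sketch` is also hypothetical.

```python
import torch
from torch import nn
from torch.nn.functional import normalize


def init_memory_state(memory, emb_dim):
    """Hypothetical body: attach K/V projections and SGD+momentum state to `memory`."""
    # Attribute names below are illustrative; NeuralMemory may expect different ones.
    memory.to_keys = nn.Linear(emb_dim, emb_dim, bias=False)
    memory.to_values = nn.Linear(emb_dim, emb_dim, bias=False)
    memory.lr = 1e-2            # hypothetical step size
    memory.momentum = 0.9       # hypothetical momentum coefficient
    memory.weight_decay = 1e-3  # hypothetical decay (forgetting) rate
    # One momentum buffer per memory-MLP parameter; assumes the MLP lives at `memory.mlp`.
    memory.momentum_buffers = [torch.zeros_like(p) for p in memory.mlp.parameters()]


def keys_values_sketch(memory, x):
    """Hypothetical helper: L2-normalized keys and raw values for a batch of tokens."""
    k = normalize(memory.to_keys(x), dim=-1)  # unit-norm keys
    v = memory.to_values(x)
    return k, v


# Usage with a stand-in module (not the real NeuralMemory):
mem = nn.Module()
mem.mlp = nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 8))
init_memory_state(mem, emb_dim=8)
k, v = keys_values_sketch(mem, torch.randn(5, 8))
print(k.shape, v.shape, len(mem.momentum_buffers))
```

Normalizing keys bounds their scale, so the size of each memory write depends on the value error rather than on the raw key magnitude.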
Additional context files (read-only):
- titans-lmm/neural_memory.py
- titans-lmm/titans.py
- titans-lmm/train_utils.py
Results
| Strategy | Type | Test MSE, sinwave ↓ | Test MSE, weather ↓ | Test MSE, sinwave-long ↓ |
|---|---|---|---|---|
| adam_gated | baseline | 0.019 | 0.052 | 0.032 |
| contrastive_ema | baseline | 0.018 | 0.050 | 0.033 |
| titans_default | baseline | 0.019 | 0.051 | 0.032 |