ai4bio-protein-inverse-folding

AI for Biology | ProteinInvBench | rigorous codebase

Description

Task: Protein Inverse Folding — Structure Encoder Design

Research Question

Design a novel GNN-based structure encoder for protein inverse folding: given backbone atom coordinates (N, CA, C, O), predict the amino acid sequence that would fold into that structure.

Background

Protein inverse folding (also called fixed-backbone design, a core task in computational protein design) is a central problem in structural biology. Given a protein backbone structure, the goal is to predict the amino acid sequence most likely to fold into that structure. This is the inverse of the protein folding problem, which predicts structure from sequence.

The key challenge is encoding the 3D protein backbone graph into rich per-residue embeddings that capture local geometry, long-range interactions, and structural motifs. Existing approaches differ primarily in how they encode the protein structure:

  • StructGNN: Uses message-passing on a k-nearest-neighbor graph with distance and angular features.
  • GVP (Geometric Vector Perceptron): Uses SE(3)-equivariant message passing with scalar and vector features.
  • ProteinMPNN: Uses a message-passing encoder with edge updates, followed by an autoregressive decoder with masking.
  • PiFold: Uses a specialized encoder with virtual atoms, multi-scale distance features, and dihedral features, plus a non-autoregressive CNN decoder.
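PiFold's virtual atoms add extra points to the four backbone atoms, placed relative to each residue's local frame (learned, as we understand the PiFold paper). A simpler, fixed variant is an idealized Cβ computed from N, CA and C, which the public ProteinMPNN code uses for its distance features. The NumPy sketch below reproduces that construction, assuming the coefficients match that release; `virtual_cb` is a hypothetical helper, not part of custom_invfold.py. The final check shows that the point moves rigidly with the backbone, so any distance involving it stays rotation- and translation-invariant.

```python
import numpy as np

def virtual_cb(X):
    """Idealized virtual C-beta position from backbone atoms (hypothetical helper).

    X: array of shape (..., 4, 3), atoms ordered [N, CA, C, O].
    Returns an array of shape (..., 3).
    Coefficients assumed to match the public ProteinMPNN code.
    """
    N, CA, C = X[..., 0, :], X[..., 1, :], X[..., 2, :]
    b = CA - N              # N -> CA bond vector
    c = C - CA              # CA -> C bond vector
    a = np.cross(b, c)      # normal to the N-CA-C plane
    return -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA

# Example on random coordinates (real inputs would come from the dataset)
rng = np.random.default_rng(0)
X = rng.normal(size=(2, 5, 4, 3))   # (B, L, 4, 3)
cb = virtual_cb(X)                   # (B, L, 3)

# Rotating and translating the backbone rotates and translates the virtual atom
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0, 0.0, 1.0]])
t = np.array([1.0, -2.0, 3.0])
assert np.allclose(virtual_cb(X @ R.T + t), cb @ R.T + t)
```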

The structure encoder is the critical component: all methods share the same input format (backbone coordinates) and output format (amino acid log-probabilities), but differ in how they transform structure into sequence-informative representations.

What to Implement

Modify the editable section of custom_invfold.py (lines 86-238). You must implement:

  1. StructureEncoder: A GNN module that takes backbone coordinates X (B, L, 4, 3) and mask (B, L), and produces per-residue embeddings h_V (B, L, hidden_dim).
  2. InverseFoldingModel: Wraps the encoder with a decoder head that outputs amino acid log-probabilities (B, L, 20).

Interface

class StructureEncoder(nn.Module):
    def __init__(self, hidden_dim=128, ...):
        ...
    def forward(self, X, mask):
        """
        X: (B, L, 4, 3) backbone coordinates [N, CA, C, O]
        mask: (B, L) binary mask (1 for valid residues, 0 for padding)
        Returns: h_V (B, L, hidden_dim) per-residue embeddings
        """
        ...

class InverseFoldingModel(nn.Module):
    def __init__(self, hidden_dim=128, ...):
        ...
    def forward(self, X, mask):
        """
        Returns: log_probs (B, L, 20) amino acid log-probabilities
        """
        ...
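To make the data flow of this interface concrete, here is a shape-level sketch in NumPy (the template itself uses PyTorch `nn.Module`s, and every operation below has a direct torch counterpart). `ToyInverseFoldingModel` is hypothetical and is not the template's starter MPNN: it uses random, untrained weights, per-residue intra-backbone distances as node features, and one round of mean-aggregated message passing over a k-NN graph of CA atoms. Because all inputs are distances, the output is invariant to rotating or translating the structure.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)          # numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

class ToyInverseFoldingModel:
    """Hypothetical sketch: StructureEncoder (encode) + linear decoder head (forward)."""

    def __init__(self, hidden_dim=128, k=30, num_aa=20, seed=0):
        rng = np.random.default_rng(seed)
        self.k = k
        self.W_node = rng.normal(scale=0.1, size=(6, hidden_dim))
        self.W_msg = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
        self.w_dist = rng.normal(scale=0.1, size=(hidden_dim,))
        self.W_out = rng.normal(scale=0.1, size=(hidden_dim, num_aa))

    def encode(self, X, mask):
        """X: (B, L, 4, 3), mask: (B, L) -> h_V: (B, L, hidden_dim)."""
        B, L = mask.shape
        # Node features: the 6 pairwise distances among N, CA, C, O of each residue
        i, j = np.triu_indices(4, k=1)
        node_feats = np.linalg.norm(X[:, :, i] - X[:, :, j], axis=-1)  # (B, L, 6)
        h = np.maximum(node_feats @ self.W_node, 0.0)                   # (B, L, H)
        # k-NN graph on CA atoms (each residue is its own nearest neighbor);
        # pairs involving padding get infinite distance so they are never valid
        ca = X[:, :, 1]
        D = np.linalg.norm(ca[:, :, None] - ca[:, None, :], axis=-1)     # (B, L, L)
        D = np.where((mask[:, :, None] * mask[:, None, :]) > 0, D, np.inf)
        k = min(self.k, L)
        E_idx = np.argsort(D, axis=-1)[:, :, :k]                         # (B, L, k)
        D_nbr = np.take_along_axis(D, E_idx, axis=-1)                    # (B, L, k)
        valid = np.isfinite(D_nbr)
        # Messages: transformed neighbor state plus a learned distance term
        h_nbr = h[np.arange(B)[:, None, None], E_idx]                    # (B, L, k, H)
        d = np.where(valid, D_nbr, 0.0)[..., None]
        msg = np.maximum(h_nbr @ self.W_msg + d * self.w_dist, 0.0) * valid[..., None]
        agg = msg.sum(axis=2) / np.maximum(valid.sum(axis=2), 1)[..., None]
        return (h + agg) * mask[..., None]                               # zero out padding

    def forward(self, X, mask):
        """Returns log_probs: (B, L, num_aa)."""
        return log_softmax(self.encode(X, mask) @ self.W_out)

rng = np.random.default_rng(1)
X = rng.normal(scale=5.0, size=(2, 10, 4, 3))
mask = np.ones((2, 10))
mask[1, 7:] = 0            # second chain has only 7 real residues
model = ToyInverseFoldingModel(hidden_dim=32, k=4)
log_probs = model.forward(X, mask)
print(log_probs.shape)     # (2, 10, 20)
```

A real encoder would stack several such layers, add edge features and updates, and be trained with a masked cross-entropy loss. The sketch only fixes the tensor shapes and the masking pattern.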

Helper functions available in the FIXED section above the editable region:

  • _rbf(D, ...): Radial basis function encoding of distances
  • _dihedrals(X): Backbone dihedral angles (phi, psi, omega) as sin/cos features
  • _orientations(X): Local coordinate frame (forward + binormal vectors)
  • knn_graph(X_ca, mask, k): Build k-nearest neighbor graph from CA coordinates
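The fixed helpers are not reproduced in this excerpt, so the following NumPy functions are hypothetical re-implementations. They only illustrate what such features typically compute. The names (`rbf`, `knn_graph_sketch`, `dihedral`, `backbone_dihedral_features`) and the RBF defaults (16 Gaussians centered between 2 and 22 Å, values common in StructGNN/ProteinMPNN-style code) are assumptions. The template's `_rbf`, `_dihedrals` and `knn_graph` may differ in signature, defaults and padding conventions.

```python
import numpy as np

def rbf(D, d_min=2.0, d_max=22.0, count=16):
    """Gaussian radial basis encoding: (...,) distances -> (..., count) features."""
    mu = np.linspace(d_min, d_max, count)          # evenly spaced centers
    sigma = (d_max - d_min) / count                # shared width
    return np.exp(-(((D[..., None] - mu) / sigma) ** 2))

def knn_graph_sketch(X_ca, mask, k):
    """X_ca: (B, L, 3), mask: (B, L) -> (neighbor distances, neighbor indices), each (B, L, k).
    Each residue's own index appears among its neighbors (distance 0)."""
    D = np.linalg.norm(X_ca[:, :, None] - X_ca[:, None, :], axis=-1)
    D = np.where((mask[:, :, None] * mask[:, None, :]) > 0, D, np.inf)
    k = min(k, X_ca.shape[1])
    E_idx = np.argsort(D, axis=-1, kind="stable")[:, :, :k]
    return np.take_along_axis(D, E_idx, axis=-1), E_idx

def dihedral(p0, p1, p2, p3):
    """Signed torsion angle (radians) defined by four points, batched over leading axes."""
    b0, b1, b2 = p0 - p1, p2 - p1, p3 - p2
    b1 = b1 / np.linalg.norm(b1, axis=-1, keepdims=True)
    v = b0 - np.sum(b0 * b1, axis=-1, keepdims=True) * b1   # b0 with its axial part removed
    w = b2 - np.sum(b2 * b1, axis=-1, keepdims=True) * b1   # b2 with its axial part removed
    x = np.sum(v * w, axis=-1)
    y = np.sum(np.cross(b1, v) * w, axis=-1)
    return np.arctan2(y, x)

def backbone_dihedral_features(X):
    """X: (L, 4, 3) for one chain -> (L, 6) = [sin, cos] of (phi, psi, omega).
    Angles undefined at the chain termini are left at 0 before sin/cos."""
    N, CA, C = X[:, 0], X[:, 1], X[:, 2]
    L = X.shape[0]
    phi, psi, omega = np.zeros(L), np.zeros(L), np.zeros(L)
    phi[1:] = dihedral(C[:-1], N[1:], CA[1:], C[1:])          # C(i-1), N, CA, C
    psi[:-1] = dihedral(N[:-1], CA[:-1], C[:-1], N[1:])       # N, CA, C, N(i+1)
    omega[:-1] = dihedral(CA[:-1], C[:-1], N[1:], CA[1:])     # CA, C, N(i+1), CA(i+1)
    ang = np.stack([phi, psi, omega], axis=-1)
    return np.concatenate([np.sin(ang), np.cos(ang)], axis=-1)

# Four CA atoms on a line at x = 0, 1, 3, 6
X_ca = np.array([[[0.0, 0, 0], [1, 0, 0], [3, 0, 0], [6, 0, 0]]])  # (1, 4, 3)
D_nbr, E_idx = knn_graph_sketch(X_ca, np.ones((1, 4)), k=2)
print(E_idx[0].tolist())   # [[0, 1], [1, 0], [2, 1], [3, 2]]
edge_feats = rbf(D_nbr)    # (1, 4, 2, 16)
```

Encoding angles as sin/cos pairs avoids the discontinuity at ±π, and the RBF expansion turns a scalar distance into a smooth, bounded vector that a linear layer can use directly.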

Evaluation

The model is evaluated on three benchmarks:

  • CATH 4.2: Standard protein design benchmark (single-chain, ~18k train / 608 test).
  • CATH 4.3: Updated CATH with more diverse structures (~21k train / 1120 test).
  • TS50: a 50-chain test set used to assess out-of-distribution generalization (models are trained on CATH 4.2).

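Primary metric: Recovery (the fraction of residues whose predicted amino acid matches the native sequence; higher is better). Secondary metric: Perplexity (the exponential of the mean per-residue cross-entropy; lower is better, and a uniform guess over the 20 amino acids gives a perplexity of 20).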
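Both metrics can be computed from the model's (B, L, 20) log-probabilities, the native sequence and the padding mask. The sketch below assumes that both are averaged over all valid residues in a batch (a micro-average). The benchmark's own evaluation code may aggregate differently, for example by taking the median of per-protein recovery. `recovery_and_perplexity` is a hypothetical helper.

```python
import numpy as np

def recovery_and_perplexity(log_probs, S, mask):
    """log_probs: (B, L, 20); S: (B, L) integer native sequence; mask: (B, L) 1 = valid.
    Returns (recovery, perplexity), both averaged over valid residues only."""
    pred = log_probs.argmax(axis=-1)                                    # (B, L) predicted residues
    nll = -np.take_along_axis(log_probs, S[..., None], axis=-1)[..., 0] # (B, L) per-residue NLL
    n = mask.sum()
    recovery = ((pred == S) * mask).sum() / n
    perplexity = np.exp((nll * mask).sum() / n)                         # exp(mean cross-entropy)
    return recovery, perplexity

B, L = 1, 4
log_probs = np.full((B, L, 20), np.log(1.0 / 20))   # uniform predictor
S = np.array([[0, 1, 2, 3]])
mask = np.ones((B, L))
rec, ppl = recovery_and_perplexity(log_probs, S, mask)
print(rec, ppl)   # 0.25 and ~20.0 (argmax picks index 0 on ties; uniform => perplexity 20)
```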

Code

custom_invfold.py
"""
Protein Inverse Folding — Self-contained template.
Given backbone structure (N, CA, C, O coordinates), predict amino acid sequence.

Structure:
    Lines 1-75: FIXED — Imports, constants, data loading, featurization
    Lines 76-230: EDITABLE — StructureEncoder + decoder (starter: simple MPNN)
    Lines 231+: FIXED — Training loop, evaluation, metrics
"""
import os
import sys
import json
import math
import time
import argparse

Results

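Model       | Type     | Recovery CATH 4.2 | Perplexity CATH 4.2 | Recovery CATH 4.3 | Perplexity CATH 4.3 | Recovery TS50 | Perplexity TS50
------------|----------|-------------------|---------------------|-------------------|---------------------|---------------|----------------
gvp         | baseline | 0.432             | 5.873               | 0.456             | 5.377               | 0.484         | 5.039
pifold      | baseline | 0.539             | 4.152               | 0.489             | 5.024               | 0.568         | 4.369
proteinmpnn | baseline | 0.460             | 5.451               | 0.511             | 4.535               | 0.538         | 4.187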