graph-generation

Graph Learning | pytorch-geometric | rigorous codebase

Description

Graph Generation Model Design

Research Question

Design a novel generative model architecture for unconditional graph generation that produces realistic graph structures matching the statistical properties of a training distribution.

Background

Graph generation is a fundamental problem in machine learning with applications in drug discovery, social network modeling, and materials science. The goal is to learn the distribution of a set of graphs and generate new graphs that are statistically indistinguishable from the training data.

Existing approaches span several paradigms:

  • Autoregressive: GraphRNN (You et al., 2018) generates graphs node by node with RNNs; GRAN (Liao et al., 2019) uses graph attention to generate one block of nodes and its edges per step
  • VAE-based: GraphVAE (Simonovsky & Komodakis, 2018) encodes graphs into latent space and decodes adjacency matrices
  • Flow-based: MoFlow (Zang & Wang, 2020) uses normalizing flows for invertible graph generation
  • Diffusion/Score-based: GDSS (Jo et al., 2022) applies score-based SDEs to graph generation; DiGress (Vignac et al., 2023) uses discrete denoising diffusion

Evaluation uses Maximum Mean Discrepancy (MMD) between graph statistics (degree distribution, clustering coefficients, orbit counts) of generated and reference graphs. Lower MMD indicates the generated graphs better match the training distribution.
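As a rough illustration of the metric, the sketch below computes a biased squared-MMD estimate between two sets of degree histograms using a Gaussian kernel. The kernel, the bandwidth sigma, and the helper names degree_histogram and mmd_squared are assumptions made for illustration; the benchmark's fixed evaluation code may use a different kernel or estimator.

```python
import numpy as np


def degree_histogram(adj, n_nodes, n_bins):
    """Normalized degree histogram over the first n_nodes (unpadded) nodes."""
    degrees = adj[:n_nodes, :n_nodes].sum(axis=1).astype(int)
    hist = np.bincount(degrees, minlength=n_bins)[:n_bins].astype(float)
    return hist / max(hist.sum(), 1.0)


def gaussian_kernel(x, y, sigma=1.0):
    """Gaussian kernel on the squared L2 distance (an assumed kernel choice)."""
    return np.exp(-np.sum((x - y) ** 2) / (2.0 * sigma ** 2))


def mmd_squared(samples_p, samples_q, sigma=1.0):
    """Biased estimate of MMD^2 = E[k(p,p')] + E[k(q,q')] - 2 E[k(p,q)]."""
    def mean_kernel(a, b):
        return np.mean([gaussian_kernel(x, y, sigma) for x in a for y in b])

    return (mean_kernel(samples_p, samples_p)
            + mean_kernel(samples_q, samples_q)
            - 2.0 * mean_kernel(samples_p, samples_q))
```

Two identical sets of histograms give a value of zero under this estimator, and the value grows as the two sets of statistics drift apart.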

What You Can Modify

The GraphGenerator class (lines 341-485) in custom_graphgen.py. This class must implement:

  1. __init__(self, max_nodes, **kwargs): Initialize model parameters and optimizer
  2. train_step(self, adj, node_counts) -> dict: Perform one training step on a batch of adjacency matrices. Must return a dict containing at least 'loss' (float).
  3. sample(self, n_samples, device) -> (adj, node_counts): Generate graphs. Returns:
    • adj: Tensor [n_samples, max_nodes, max_nodes] — binary symmetric adjacency matrices (no self-loops)
    • node_counts: Tensor [n_samples] — number of nodes per graph (minimum 2)

The input adjacency matrices are binary, symmetric, zero-diagonal, and padded to max_nodes.
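To make the input format concrete, here is a small sketch that pads one graph to max_nodes and checks the stated invariants. The helper names pad_adjacency and is_valid_adjacency are hypothetical and are not part of the benchmark code.

```python
import numpy as np


def pad_adjacency(adj, max_nodes):
    """Embed an n x n adjacency matrix in the top-left of a zero max_nodes x max_nodes matrix."""
    n = adj.shape[0]
    padded = np.zeros((max_nodes, max_nodes), dtype=adj.dtype)
    padded[:n, :n] = adj
    return padded


def is_valid_adjacency(adj):
    """True if adj is binary, symmetric, and has a zero diagonal."""
    binary = np.isin(adj, (0, 1)).all()
    symmetric = np.array_equal(adj, adj.T)
    no_self_loops = not np.diag(adj).any()
    return bool(binary and symmetric and no_self_loops)
```

The same check can be run on the output of sample to catch asymmetric matrices or self-loops before evaluation.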

You may define helper classes/functions within the editable region. The model's optimizer should be created inside __init__ and updated inside train_step.

Available imports (in the FIXED section): torch, torch.nn, torch.nn.functional, torch.optim, numpy, math.
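The interface above can be sketched with a deliberately simple placeholder. It is not one of the listed baselines and not a proposed solution: each edge is an independent Bernoulli variable with a learned logit, and node counts are drawn from the counts seen during training. The lr argument and the attribute names edge_logits and count_hist are assumptions, and the sketch assumes the model, adj, and node_counts all live on the same device.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphGenerator(nn.Module):
    """Placeholder baseline: independent Bernoulli edges with learned logits."""

    def __init__(self, max_nodes, lr=1e-2, **kwargs):
        super().__init__()
        self.max_nodes = max_nodes
        self.edge_logits = nn.Parameter(torch.zeros(max_nodes, max_nodes))
        # Histogram of node counts seen in training (index = node count).
        self.register_buffer("count_hist", torch.zeros(max_nodes + 1))
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def _symmetric_logits(self):
        return (self.edge_logits + self.edge_logits.t()) / 2

    def train_step(self, adj, node_counts):
        self.train()
        target = adj.float()
        idx = torch.arange(self.max_nodes, device=adj.device)
        valid = idx[None, :] < node_counts[:, None]            # [B, N] real nodes
        mask = valid[:, :, None] & valid[:, None, :]           # [B, N, N] real pairs
        mask = mask & (idx[:, None] < idx[None, :])            # strict upper triangle
        logits = self._symmetric_logits().expand_as(target)
        per_edge = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
        loss = (per_edge * mask).sum() / mask.sum().clamp(min=1)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.count_hist += torch.bincount(
            node_counts.long(), minlength=self.max_nodes + 1
        ).float()
        return {"loss": loss.item()}

    @torch.no_grad()
    def sample(self, n_samples, device):
        self.eval()
        n = self.max_nodes
        hist = self.count_hist.clone()
        hist[:2] = 0                      # enforce a minimum of 2 nodes
        if hist.sum() == 0:
            hist[n] = 1.0                 # fallback if no usable counts were seen
        node_counts = torch.multinomial(hist, n_samples, replacement=True).to(device)
        probs = torch.sigmoid(self._symmetric_logits()).to(device)
        draws = (torch.rand(n_samples, n, n, device=device) < probs).float()
        upper = torch.triu(draws, diagonal=1)                  # drop diagonal and lower half
        idx = torch.arange(n, device=device)
        valid = (idx[None, :] < node_counts[:, None]).float()
        upper = upper * valid[:, :, None] * valid[:, None, :]  # zero out padded nodes
        adj = upper + upper.transpose(1, 2)                    # symmetrize
        return adj, node_counts
```

Sampling only the strict upper triangle and mirroring it guarantees symmetry and a zero diagonal by construction, which is a convenient pattern for more expressive models as well.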

Evaluation

  • Metrics (all lower is better):
    • mmd_degree: MMD of degree distributions
    • mmd_clustering: MMD of clustering coefficient distributions
    • mmd_orbit: MMD of 4-node orbit count distributions
    • mmd_avg: Average of the three MMD metrics
  • Datasets:
    • community_small: 100 synthetic 2-community graphs (12-20 nodes)
    • ego_small: 200 ego graphs from Citeseer (4-18 nodes)
    • enzymes: 587 protein structure graphs from BRENDA (10-125 nodes)
  • Training: 3000 epochs, batch size 32, single GPU
  • Seeds: Multiple seeds for statistical reliability
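For orientation, two of the per-graph statistics behind these metrics can be computed directly from an adjacency matrix. The sketch below uses the standard identity that the i-th diagonal entry of A^3 equals twice the number of triangles through node i. The function name is hypothetical, and the benchmark's fixed code may compute these statistics differently; orbit counts are omitted because they typically need a dedicated graphlet counter.

```python
import numpy as np


def degrees_and_clustering(adj):
    """Per-node degree and local clustering coefficient of a binary, symmetric adjacency matrix."""
    adj = np.asarray(adj, dtype=float)
    deg = adj.sum(axis=1)
    # diag(A^3)[i] counts closed walks of length 3 from node i, i.e. 2 * triangles at i.
    triangles = np.diag(adj @ adj @ adj) / 2.0
    # Number of neighbor pairs that could be connected.
    possible = deg * (deg - 1) / 2.0
    clustering = np.divide(
        triangles, possible, out=np.zeros_like(triangles), where=possible > 0
    )
    return deg, clustering
```

Nodes with fewer than two neighbors get a clustering coefficient of 0 here, which is a common convention.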

Code

custom_graphgen.py
"""Graph Generation Benchmark.

Train a generative model on small graph datasets and evaluate using MMD statistics.

FIXED: Dataset loading/generation, graph statistics computation, MMD evaluation,
       training loop orchestration, argument parsing.
EDITABLE: GraphGenerator class (the generative model).

Usage:
    python pytorch-geometric/custom_graphgen.py --dataset community_small --seed 42
"""

import argparse
import math
import os

Results

Each metric group lists community_small (comm), ego_small (ego), and enzymes (enz), in that order.

Model     Type         mmd_avg                 mmd_clustering          mmd_degree              mmd_orbit
                        comm     ego     enz    comm     ego     enz    comm     ego     enz    comm     ego     enz
digress   baseline     0.060   0.230   0.300   0.105   0.449   0.533   0.075   0.240   0.368   0.000   0.000   0.001
gdss      baseline     0.110   0.179   0.179   0.113   0.173   0.248   0.216   0.363   0.287   0.000   0.000   0.001
gdss      baseline     0.125   0.176   0.192   0.116   0.173   0.284   0.258   0.356   0.289   0.000   0.000   0.001
gdss      baseline     0.075  10.000   3.527   0.130  10.000   3.735   0.095  10.000   3.512   0.000  10.000   3.334
gdss      baseline     0.073   6.848  10.000   0.136   7.056  10.000   0.082   6.823  10.000   0.000   6.667  10.000
gdss      baseline     0.116   0.178   0.185   0.114   0.174   0.265   0.234   0.361   0.289   0.000   0.000   0.001
gran      baseline     0.045   0.141   0.227   0.119   0.179   0.387   0.015   0.243   0.293   0.000   0.000   0.001
graphvae  baseline     0.049   0.063   0.327   0.103   0.101   0.613   0.043   0.088   0.369   0.000   0.000   0.000
moflow    baseline     0.132   0.171   0.186   0.114   0.172   0.275   0.281   0.342   0.284   0.000   0.000   0.001