graph-generation
Description
Graph Generation Model Design
Research Question
Design a novel generative model architecture for unconditional graph generation that produces realistic graph structures matching the statistical properties of a training distribution.
Background
Graph generation is a fundamental problem in machine learning with applications in drug discovery, social network modeling, and materials science. The goal is to learn the distribution of a set of graphs and generate new graphs that are statistically indistinguishable from the training data.
Existing approaches span several paradigms:
- Autoregressive: GraphRNN (You et al., 2018) generates graphs node by node with RNNs; GRAN (Liao et al., 2019) uses graph attention to generate one block of nodes and their edges at a time
- VAE-based: GraphVAE (Simonovsky & Komodakis, 2018) encodes graphs into latent space and decodes adjacency matrices
- Flow-based: MoFlow (Zang & Wang, 2020) uses normalizing flows for invertible graph generation
- Diffusion/Score-based: GDSS (Jo et al., 2022) applies score-based SDEs to graph generation; DiGress (Vignac et al., 2023) uses discrete denoising diffusion
Evaluation uses Maximum Mean Discrepancy (MMD) between graph statistics (degree distribution, clustering coefficients, orbit counts) of generated and reference graphs. Lower MMD indicates the generated graphs better match the training distribution.
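To make the metric concrete, here is a minimal pure-Python sketch of a squared-MMD estimate over degree histograms. The Gaussian kernel, the bandwidth `sigma`, and the helper names are assumptions chosen for illustration; the benchmark's fixed evaluation code may use a different kernel or estimator.

```python
import math


def degree_histogram(adj, max_degree):
    """Normalized degree histogram of one graph given as a 0/1 adjacency matrix."""
    counts = [0.0] * (max_degree + 1)
    for row in adj:
        counts[sum(row)] += 1.0
    total = sum(counts)
    return [c / total for c in counts]


def gaussian_kernel(p, q, sigma=1.0):
    """Gaussian kernel on the Euclidean distance between two histograms (assumed choice)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(p, q))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mean_kernel(xs, ys, kernel):
    """Average kernel value over all pairs drawn from xs and ys."""
    return sum(kernel(x, y) for x in xs for y in ys) / (len(xs) * len(ys))


def mmd_squared(xs, ys, kernel=gaussian_kernel):
    """Biased estimate of squared MMD: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]."""
    return (mean_kernel(xs, xs, kernel)
            + mean_kernel(ys, ys, kernel)
            - 2.0 * mean_kernel(xs, ys, kernel))


# Toy graphs: a 3-node path and a 3-node triangle.
path = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
triangle = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]

reference = [degree_histogram(g, max_degree=2) for g in (path, path)]
generated_same = [degree_histogram(g, max_degree=2) for g in (path, path)]
generated_diff = [degree_histogram(g, max_degree=2) for g in (triangle, triangle)]

mmd_same = mmd_squared(reference, generated_same)  # identical statistics
mmd_diff = mmd_squared(reference, generated_diff)  # mismatched statistics
```

In this toy setup `mmd_same` is zero because both sets share the same histograms, while `mmd_diff` is strictly positive, which is the sense in which a lower MMD indicates a closer match.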
What You Can Modify
The GraphGenerator class (lines 341-485) in custom_graphgen.py. This class must implement:
- __init__(self, max_nodes, **kwargs): Initialize model parameters and optimizer.
- train_step(self, adj, node_counts) -> dict: Perform one training step on a batch of adjacency matrices. Must return a dict containing at least 'loss' (float).
- sample(self, n_samples, device) -> (adj, node_counts): Generate graphs. Returns:
  - adj: Tensor [n_samples, max_nodes, max_nodes], binary symmetric adjacency matrices (no self-loops)
  - node_counts: Tensor [n_samples], number of nodes per graph (minimum 2)
The input adjacency matrices are binary, symmetric, zero-diagonal, and padded to max_nodes.
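As an illustration of that input format, the sketch below zero-pads a single adjacency matrix to max_nodes and checks the four stated properties. The helper names `pad_adjacency` and `is_valid_padded_adjacency` are hypothetical and are not part of custom_graphgen.py; NumPy is used here only for brevity.

```python
import numpy as np


def pad_adjacency(adj, max_nodes):
    """Zero-pad an n x n adjacency matrix to max_nodes x max_nodes."""
    n = adj.shape[0]
    if n > max_nodes:
        raise ValueError("graph has more nodes than max_nodes")
    padded = np.zeros((max_nodes, max_nodes), dtype=adj.dtype)
    padded[:n, :n] = adj
    return padded


def is_valid_padded_adjacency(adj, node_count):
    """Check that a padded matrix is binary, symmetric, zero-diagonal, and empty in the padding."""
    binary = np.isin(adj, (0, 1)).all()
    symmetric = np.array_equal(adj, adj.T)
    zero_diagonal = not np.diagonal(adj).any()
    padding_empty = not adj[node_count:, :].any() and not adj[:, node_count:].any()
    return bool(binary and symmetric and zero_diagonal and padding_empty)


# A 3-node triangle padded to max_nodes = 5.
triangle = np.array([[0, 1, 1],
                     [1, 0, 1],
                     [1, 1, 0]], dtype=np.float32)
padded = pad_adjacency(triangle, max_nodes=5)
```

The same checks are a reasonable sanity test for the output of sample, since generated matrices are described with the same constraints.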
You may define helper classes/functions within the editable region. The model's optimizer should be created inside __init__ and updated inside train_step.
Available imports (in the FIXED section): torch, torch.nn, torch.nn.functional, torch.optim, numpy, math.
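The interface described above can be sketched with a deliberately simple toy baseline: independent Bernoulli edges with one learnable logit per node pair, plus a categorical model over node counts. This is an illustrative sketch, not the benchmark's reference implementation. Subclassing nn.Module, the Adam optimizer, and the `lr` keyword are all assumptions, and the model is neither permutation-invariant nor expected to be competitive.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphGenerator(nn.Module):
    """Toy baseline: independent Bernoulli edges plus a categorical node-count model."""

    def __init__(self, max_nodes, lr=1e-2, **kwargs):  # `lr` is a hypothetical kwarg
        super().__init__()
        self.max_nodes = max_nodes
        self.edge_logits = nn.Parameter(torch.zeros(max_nodes, max_nodes))
        self.count_logits = nn.Parameter(torch.zeros(max_nodes + 1))
        # The optimizer is created here and stepped inside train_step.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def _pair_mask(self, node_counts):
        """Boolean mask [B, N, N] that is True for pairs (i, j) with i < j < node_count."""
        n, dev = self.max_nodes, node_counts.device
        valid = torch.arange(n, device=dev)[None, :] < node_counts[:, None]
        pair = valid[:, :, None] & valid[:, None, :]
        upper = torch.triu(torch.ones(n, n, dtype=torch.bool, device=dev), diagonal=1)
        return pair & upper

    def train_step(self, adj, node_counts):
        self.optimizer.zero_grad()
        mask = self._pair_mask(node_counts)
        logits = self.edge_logits.expand_as(adj)
        # Edge likelihood over real (non-padded) node pairs only.
        edge_loss = F.binary_cross_entropy_with_logits(logits[mask], adj[mask].float())
        # Likelihood of the observed node counts.
        count_loss = F.cross_entropy(
            self.count_logits.expand(len(node_counts), -1), node_counts.long())
        loss = edge_loss + count_loss
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

    @torch.no_grad()
    def sample(self, n_samples, device):
        probs = F.softmax(self.count_logits, dim=0).clone()
        probs[:2] = 0.0  # enforce the minimum of 2 nodes
        node_counts = torch.multinomial(probs, n_samples, replacement=True).to(device)
        edge_probs = torch.sigmoid(self.edge_logits).to(device)
        draws = torch.rand(n_samples, self.max_nodes, self.max_nodes, device=device)
        upper = ((draws < edge_probs) & self._pair_mask(node_counts)).float()
        adj = upper + upper.transpose(1, 2)  # symmetric, zero diagonal, zero padding
        return adj, node_counts


# Usage on a tiny synthetic batch (4 graphs, max_nodes = 6).
torch.manual_seed(0)
model = GraphGenerator(max_nodes=6)
batch_adj = torch.zeros(4, 6, 6)
batch_adj[:, 0, 1] = batch_adj[:, 1, 0] = 1.0
batch_counts = torch.tensor([3, 4, 5, 6])
out = model.train_step(batch_adj, batch_counts)
samples, counts = model.sample(8, torch.device("cpu"))
```

Sampling only the strict upper triangle and then mirroring it is one simple way to satisfy the symmetry and no-self-loop requirements by construction, rather than repairing the matrix afterwards.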
Evaluation
- Metrics (all lower is better):
  - mmd_degree: MMD of degree distributions
  - mmd_clustering: MMD of clustering coefficient distributions
  - mmd_orbit: MMD of 4-orbit count distributions
  - mmd_avg: Average of the three MMD metrics
- Datasets:
  - community_small: 100 synthetic 2-community graphs (12-20 nodes)
  - ego_small: 200 ego graphs from Citeseer (4-18 nodes)
  - enzymes: 587 protein structure graphs from BRENDA (10-125 nodes)
- Training: 3000 epochs, batch size 32, single GPU
- Seeds: Multiple seeds for statistical reliability
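For readers unfamiliar with the statistics behind these metrics, the following pure-Python sketch computes per-node local clustering coefficients (the quantity whose distribution mmd_clustering compares) and the mmd_avg aggregate. The function names are illustrative assumptions; the benchmark's fixed code may compute these statistics differently.

```python
def local_clustering(adj):
    """Local clustering coefficient of each node in a 0/1 symmetric adjacency matrix."""
    n = len(adj)
    coeffs = []
    for v in range(n):
        nbrs = [u for u in range(n) if adj[v][u]]
        k = len(nbrs)
        if k < 2:
            # Fewer than two neighbours: no triangle is possible.
            coeffs.append(0.0)
            continue
        # Count edges among the neighbours of v.
        links = sum(adj[a][b] for i, a in enumerate(nbrs) for b in nbrs[i + 1:])
        coeffs.append(2.0 * links / (k * (k - 1)))
    return coeffs


def mmd_avg(mmd_degree, mmd_clustering, mmd_orbit):
    """Unweighted mean of the three MMD metrics."""
    return (mmd_degree + mmd_clustering + mmd_orbit) / 3.0


# A triangle (nodes 0-2) with a pendant node 3 attached to node 0.
graph = [[0, 1, 1, 1],
         [1, 0, 1, 0],
         [1, 1, 0, 0],
         [1, 0, 0, 0]]
coeffs = local_clustering(graph)

# Values taken from the digress row on community_small in the Results table.
digress_avg = mmd_avg(0.075, 0.105, 0.000)
```

With the digress community_small values from the Results table (degree 0.075, clustering 0.105, orbit 0.000), the average comes out at 0.060, matching the reported mmd_avg.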
Code
    """Graph Generation Benchmark.

    Train a generative model on small graph datasets and evaluate using MMD statistics.

    FIXED: Dataset loading/generation, graph statistics computation, MMD evaluation,
    training loop orchestration, argument parsing.
    EDITABLE: GraphGenerator class (the generative model).

    Usage:
        python pytorch-geometric/custom_graphgen.py --dataset community_small --seed 42
    """

    import argparse
    import math
    import os
Results
| Model | Type | mmd_avg (community_small) ↓ | mmd_avg (ego_small) ↓ | mmd_avg (enzymes) ↓ | mmd_clustering (community_small) ↓ | mmd_clustering (ego_small) ↓ | mmd_clustering (enzymes) ↓ | mmd_degree (community_small) ↓ | mmd_degree (ego_small) ↓ | mmd_degree (enzymes) ↓ | mmd_orbit (community_small) ↓ | mmd_orbit (ego_small) ↓ | mmd_orbit (enzymes) ↓ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| digress | baseline | 0.060 | 0.230 | 0.300 | 0.105 | 0.449 | 0.533 | 0.075 | 0.240 | 0.368 | 0.000 | 0.000 | 0.001 |
| gdss | baseline | 0.110 | 0.179 | 0.179 | 0.113 | 0.173 | 0.248 | 0.216 | 0.363 | 0.287 | 0.000 | 0.000 | 0.001 |
| gdss | baseline | 0.125 | 0.176 | 0.192 | 0.116 | 0.173 | 0.284 | 0.258 | 0.356 | 0.289 | 0.000 | 0.000 | 0.001 |
| gdss | baseline | 0.075 | 10.000 | 3.527 | 0.130 | 10.000 | 3.735 | 0.095 | 10.000 | 3.512 | 0.000 | 10.000 | 3.334 |
| gdss | baseline | 0.073 | 6.848 | 10.000 | 0.136 | 7.056 | 10.000 | 0.082 | 6.823 | 10.000 | 0.000 | 6.667 | 10.000 |
| gdss | baseline | 0.116 | 0.178 | 0.185 | 0.114 | 0.174 | 0.265 | 0.234 | 0.361 | 0.289 | 0.000 | 0.000 | 0.001 |
| gran | baseline | 0.045 | 0.141 | 0.227 | 0.119 | 0.179 | 0.387 | 0.015 | 0.243 | 0.293 | 0.000 | 0.000 | 0.001 |
| graphvae | baseline | 0.049 | 0.063 | 0.327 | 0.103 | 0.101 | 0.613 | 0.043 | 0.088 | 0.369 | 0.000 | 0.000 | 0.000 |
| moflow | baseline | 0.132 | 0.171 | 0.186 | 0.114 | 0.172 | 0.275 | 0.281 | 0.342 | 0.284 | 0.000 | 0.000 | 0.001 |