Models and Architectures
Overview
TorchDrug provides a comprehensive collection of pre-built model architectures for various graph-based learning tasks. This reference catalogs all available models with their characteristics, use cases, and implementation details.
Graph Neural Networks
GCN (Graph Convolutional Network)
Type: Spatial message passing
Paper: Semi-Supervised Classification with Graph Convolutional Networks (Kipf & Welling, 2017)
Characteristics:
- Simple and efficient aggregation
- Normalized adjacency matrix convolution
- Works well for homophilic graphs
- Good baseline for many tasks
Best For:
- Initial experiments and baselines
- When computational efficiency is important
- Graphs with clear local structure
Parameters:
- input_dim: Node feature dimension
- hidden_dims: List of hidden layer dimensions
- edge_input_dim: Edge feature dimension (optional)
- batch_norm: Apply batch normalization
- activation: Activation function (relu, elu, etc.)
- dropout: Dropout rate
Use Cases:
- Molecular property prediction
- Citation network classification
- Social network analysis
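A minimal instantiation sketch, assuming the constructor matches the parameter list above; the dimensions are illustrative:

```python
from torchdrug import models

# GCN encoder; 69 node features is a placeholder for dataset.node_feature_dim.
model = models.GCN(
    input_dim=69,
    hidden_dims=[256, 256, 256],
    batch_norm=True,
    activation="relu",
)
```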
GAT (Graph Attention Network)
Type: Attention-based message passing
Paper: Graph Attention Networks (Veličković et al., 2018)
Characteristics:
- Learns attention weights for neighbors
- Different importance for different neighbors
- Multi-head attention for robustness
- Handles varying node degrees naturally
Best For:
- When neighbor importance varies
- Heterogeneous graphs
- Interpretable predictions
Parameters:
- input_dim, hidden_dims: Standard dimensions
- num_heads: Number of attention heads
- negative_slope: LeakyReLU slope
- concat: Concatenate or average multi-head outputs
Use Cases:
- Protein-protein interaction prediction
- Molecule generation with attention to reactive sites
- Knowledge graph reasoning with relation importance
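A sketch of a multi-head GAT encoder. Note that recent TorchDrug releases name the head-count argument num_head (singular); treat the exact names as assumptions and verify against your installed version:

```python
from torchdrug import models

# 4 attention heads; negative_slope is the LeakyReLU slope from the parameter list.
model = models.GAT(
    input_dim=69,
    hidden_dims=[256, 256, 256],
    num_head=4,
    negative_slope=0.2,
)
```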
GIN (Graph Isomorphism Network)
Type: Maximally powerful message passing
Paper: How Powerful are Graph Neural Networks? (Xu et al., 2019)
Characteristics:
- As expressive as the 1-WL test, the theoretical ceiling for standard message-passing GNNs
- Injective aggregation function
- Can distinguish graph structures GCN cannot
- Often best performance on molecular tasks
Best For:
- Molecular property prediction (frequently state-of-the-art)
- Tasks requiring structural discrimination
- Graph classification
Parameters:
- input_dim, hidden_dims: Standard dimensions
- edge_input_dim: Include edge features
- batch_norm: Typically set to True
- readout: Graph pooling ("sum", "mean", "max")
- eps: Learnable or fixed epsilon
Use Cases:
- Drug property prediction (BBBP, HIV, etc.)
- Molecular generation
- Reaction prediction
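A sketch of a GIN encoder with edge features and sum readout; the dimensions are illustrative placeholders for the dataset's feature sizes:

```python
from torchdrug import models

# GIN with bond (edge) features; 19 stands in for dataset.edge_feature_dim.
model = models.GIN(
    input_dim=69,
    hidden_dims=[256, 256, 256, 256],
    edge_input_dim=19,
    batch_norm=True,
    readout="sum",
)
```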
RGCN (Relational Graph Convolutional Network)
Type: Multi-relational message passing
Paper: Modeling Relational Data with Graph Convolutional Networks (Schlichtkrull et al., 2018)
Characteristics:
- Handles multiple edge/relation types
- Relation-specific weight matrices
- Basis decomposition for parameter efficiency
- Essential for knowledge graphs
Best For:
- Knowledge graph reasoning
- Heterogeneous molecular graphs
- Multi-relational data
Parameters:
- num_relation: Number of relation types
- hidden_dims: Layer dimensions
- num_bases: Basis decomposition (reduce parameters)
Use Cases:
- Knowledge graph completion
- Retrosynthesis (different bond types)
- Protein interaction networks
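A sketch of an RGCN over a molecular graph with 4 bond/relation types (an illustrative value). The basis-decomposition argument (num_bases above) may not be exposed in every TorchDrug version, so it is omitted here:

```python
from torchdrug import models

# Relation-specific convolutions; num_relation should match the dataset's bond types.
model = models.RGCN(
    input_dim=69,
    hidden_dims=[256, 256, 256],
    num_relation=4,
    batch_norm=True,
)
```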
MPNN (Message Passing Neural Network)
Type: General message passing framework
Paper: Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017)
Characteristics:
- Flexible message and update functions
- Edge features in message computation
- GRU updates for node hidden states
- Set2Set readout for graph representation
Best For:
- Quantum chemistry predictions
- Tasks with important edge information
- When node states evolve over multiple iterations
Parameters:
- input_dim, hidden_dim: Feature dimensions
- edge_input_dim: Edge feature dimension
- num_layer: Message passing iterations
- num_mlp_layer: MLP layers in message function
Use Cases:
- QM9 quantum property prediction
- Molecular dynamics
- 3D conformation-aware tasks
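A sketch of an MPNN; edge features are required because the message function is edge-conditioned. The dimensions are illustrative:

```python
from torchdrug import models

# GRU node updates over 3 message-passing iterations.
model = models.MPNN(
    input_dim=69,
    hidden_dim=256,
    edge_input_dim=19,
    num_layer=3,      # message passing iterations
    num_mlp_layer=2,  # depth of the edge-conditioned message MLP
)
```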
SchNet (Continuous-Filter Convolutional Network)
Type: 3D geometry-aware convolution
Paper: SchNet: A continuous-filter convolutional neural network for modeling quantum interactions (Schütt et al., 2017)
Characteristics:
- Operates on 3D atomic coordinates
- Continuous-filter convolutions
- Rotation and translation invariant
- Excellent for quantum chemistry
Best For:
- 3D molecular structure tasks
- Quantum property prediction
- Protein structure analysis
- Energy and force prediction
Parameters:
- input_dim: Atom features
- hidden_dims: Layer dimensions
- num_gaussian: RBF basis functions for distances
- cutoff: Interaction cutoff distance
Use Cases:
- QM9 property prediction
- Molecular dynamics simulations
- Protein-ligand binding with structures
- Crystal property prediction
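A sketch of a SchNet encoder over 3D conformers; the cutoff is in angstroms and num_gaussian sets the resolution of the radial basis expansion of interatomic distances. Values are illustrative:

```python
from torchdrug import models

# Continuous-filter convolutions over atoms within a 5 Å cutoff.
model = models.SchNet(
    input_dim=69,
    hidden_dims=[128, 128, 128],
    cutoff=5.0,
    num_gaussian=100,
)
```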
ChebNet (Chebyshev Spectral CNN)
Type: Spectral convolution
Paper: Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering (Defferrard et al., 2016)
Characteristics:
- Spectral graph convolution
- Chebyshev polynomial approximation
- K-localized filters: the polynomial order controls the receptive field
- Computationally efficient
Best For:
- Tasks requiring larger receptive fields
- When the graph Laplacian is informative
- Theoretical analysis
Parameters:
- input_dim, hidden_dims: Dimensions
- num_cheb: Order of Chebyshev polynomial
Use Cases:
- Citation network classification
- Brain network analysis
- Signal processing on graphs
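A hedged instantiation sketch. The polynomial-order argument is listed as num_cheb above, but some TorchDrug versions expose it as k; treat the argument name as an assumption and check the installed API:

```python
from torchdrug import models

# Order-2 Chebyshev filters (name of the order argument varies by version).
model = models.ChebNet(
    input_dim=69,
    hidden_dims=[256, 256],
    k=2,
)
```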
NFP (Neural Fingerprint)
Type: Molecular fingerprint learning
Paper: Convolutional Networks on Graphs for Learning Molecular Fingerprints (Duvenaud et al., 2015)
Characteristics:
- Learns differentiable molecular fingerprints
- Alternative to hand-crafted fingerprints (ECFP)
- Circular convolutions like ECFP
- Interpretable learned features
Best For:
- Molecular similarity learning
- Property prediction with limited data
- When interpretability is important
Parameters:
- input_dim, output_dim: Feature dimensions
- hidden_dims: Layer dimensions
- num_layer: Circular convolution depth
Use Cases:
- Virtual screening
- Molecular similarity search
- QSAR modeling
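A sketch producing a 128-dimensional learned fingerprint, assuming the constructor matches the parameter list above:

```python
from torchdrug import models

# Differentiable circular-convolution fingerprint (ECFP-style).
model = models.NeuralFingerprint(
    input_dim=69,
    output_dim=128,
    hidden_dims=[256, 256],
)
```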
Protein-Specific Models
GearNet (Geometry-Aware Relational Graph Network)
Type: Protein structure encoder
Paper: Protein Representation Learning by Geometric Structure Pretraining (Zhang et al., 2023)
Characteristics:
- Incorporates 3D geometric information
- Multiple edge types (sequential, spatial, KNN)
- Designed specifically for proteins
- State-of-the-art on protein tasks
Best For:
- Protein structure prediction
- Protein function prediction
- Protein-protein interaction
- Any task with protein 3D structures
Parameters:
- input_dim: Residue features
- hidden_dims: Layer dimensions
- num_relation: Edge types (sequence, radius, KNN)
- edge_input_dim: Geometric features (distances, angles)
- batch_norm: Typically true
Use Cases:
- Enzyme function prediction (EnzymeCommission)
- Protein fold recognition
- Contact prediction
- Binding site identification
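A sketch of a residue-level GearNet encoder. The sizes (21 residue-type features, 7 edge relations, 59 geometric edge features) mirror the configuration used in the GearNet paper and are illustrative rather than required:

```python
from torchdrug import models

# Geometry-aware relational encoder over residue graphs.
model = models.GearNet(
    input_dim=21,
    hidden_dims=[512, 512, 512, 512, 512, 512],
    num_relation=7,        # sequential, radius, and KNN edge types
    edge_input_dim=59,     # distances, angles, and relation encodings
    batch_norm=True,
)
```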
ESM (Evolutionary Scale Modeling)
Type: Protein language model (transformer)
Paper: Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences (Rives et al., 2021)
Characteristics:
- Pre-trained on 250M+ protein sequences
- Captures evolutionary and structural information
- Transformer architecture
- Transfer learning for downstream tasks
Best For:
- Any sequence-based protein task
- When no structure is available
- Transfer learning with limited data
Variants:
- ESM-1b: 650M parameters
- ESM-2: Multiple sizes (8M to 15B parameters)
Use Cases:
- Protein function prediction
- Variant effect prediction
- Protein design
- Structure prediction (ESMFold)
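A sketch of loading the ESM wrapper; the weights are downloaded to the given path on first use. The path here is a placeholder:

```python
from torchdrug import models

# Pre-trained ESM-1b encoder; the path is an illustrative local cache directory.
model = models.ESM(path="~/esm-model-weights/", model="ESM-1b")
```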
ProteinBERT
Type: Masked language model for proteins
Characteristics:
- BERT-style pre-training
- Masked amino acid prediction
- Bidirectional context
- Good for sequence-based tasks
Use Cases:
- Function annotation
- Subcellular localization
- Stability prediction
ProteinCNN / ProteinResNet
Type: Convolutional networks for sequences
Characteristics:
- 1D convolutions on sequences
- Local pattern recognition
- Faster than transformers
- Good for motif detection
Use Cases:
- Binding site prediction
- Secondary structure prediction
- Domain identification
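A sketch of a convolutional sequence encoder; the argument names follow TorchDrug's ProteinCNN as documented in its tutorials, but treat them as assumptions to verify:

```python
from torchdrug import models

# 1D CNN over amino acid sequences; 21 covers the standard residue alphabet.
model = models.ProteinCNN(
    input_dim=21,
    hidden_dims=[1024, 1024],
    kernel_size=5,
    padding=2,
)
```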
ProteinLSTM
Type: Recurrent network for sequences
Characteristics:
- Bidirectional LSTM
- Captures long-range dependencies
- Sequential processing
- Good baseline for sequence tasks
Use Cases:
- Per-residue sequential annotation
- Sequence-level classification
- Baseline comparisons for protein language models
Knowledge Graph Models
TransE (Translation Embedding)
Type: Translation-based embedding
Paper: Translating Embeddings for Modeling Multi-relational Data (Bordes et al., 2013)
Characteristics:
- h + r ≈ t (head + relation ≈ tail)
- Simple and interpretable
- Works well for 1-to-1 relations
- Memory efficient
Best For:
- Large knowledge graphs
- Initial experiments
- Interpretable embeddings
Parameters:
- num_entity, num_relation: Graph size
- embedding_dim: Embedding dimension (typically 50-500)
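A sketch with sizes mirroring FB15k-237 (purely illustrative). Under the translation assumption h + r ≈ t, a triple's score is the negative distance between h + r and t, so plausible triples have small translation error:

```python
from torchdrug import models

# TransE embeddings; entity/relation counts should come from the dataset.
model = models.TransE(
    num_entity=14541,
    num_relation=237,
    embedding_dim=200,
)
```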
RotatE (Rotation Embedding)
Type: Rotation in complex space
Paper: RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space (Sun et al., 2019)
Characteristics:
- Relations as rotations in complex space
- Handles symmetric, antisymmetric, inverse, and composition patterns
- State-of-the-art on many benchmarks
Best For:
- Most knowledge graph tasks
- Complex relation patterns
- When accuracy is critical
Parameters:
- num_entity, num_relation: Graph size
- embedding_dim: Must be even (complex embeddings)
- max_score: Score clipping value
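A sketch of a RotatE model; as noted above, embedding_dim should be even because the embedding splits into real and imaginary halves. The graph sizes are illustrative:

```python
from torchdrug import models

# RotatE with complex embeddings; max_score clips the triple score.
model = models.RotatE(
    num_entity=14541,
    num_relation=237,
    embedding_dim=1000,  # even, split across real/imaginary parts
    max_score=12,
)
```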
DistMult
Type: Bilinear model
Characteristics:
- Symmetric relation modeling
- Fast and efficient
- Cannot model antisymmetric relations
Best For:
- Symmetric relations (e.g., "similar to")
- When speed is critical
- Large-scale graphs
ComplEx
Type: Complex-valued embeddings
Characteristics:
- Handles asymmetric and symmetric relations
- Better than DistMult for most graphs
- Good balance of expressiveness and efficiency
Best For:
- General knowledge graph completion
- Mixed relation types
- When RotatE is too complex
SimplE
Type: Enhanced canonical polyadic (CP) decomposition
Characteristics:
- Two embeddings per entity (head and tail roles)
- Fully expressive
- Parameter count comparable to ComplEx
Best For:
- When full expressiveness is needed
- When inverse relations are important
Generative Models
GraphAutoregressiveFlow (GraphAF)
Type: Normalizing flow for molecules
Paper: GraphAF: a Flow-based Autoregressive Model for Molecular Graph Generation (Shi et al., 2020)
Characteristics:
- Exact likelihood computation
- Invertible transformations
- Stable training (no adversarial objective)
- Conditional generation support
Best For:
- Molecular generation
- Density estimation
- Interpolation between molecules
Parameters:
- input_dim: Atom features
- hidden_dims: Coupling layers
- num_flow: Number of flow transformations
Use Cases:
- De novo drug design
- Chemical space exploration
- Property-targeted generation
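A sketch following the structure of TorchDrug's GraphAF tutorial: an RGCN encoder plus Gaussian priors over node and edge types. The sizes (9 atom types, 3 bond types) are illustrative and should come from the dataset:

```python
import torch
from torchdrug import models
from torchdrug.layers import distribution

# Shared graph encoder for the node and edge flows.
encoder = models.RGCN(input_dim=9, num_relation=3,
                      hidden_dims=[256, 256, 256], batch_norm=True)
# Priors over atom types and bond types (+1 for the "no bond" category).
node_prior = distribution.IndependentGaussian(torch.zeros(9), torch.ones(9))
edge_prior = distribution.IndependentGaussian(torch.zeros(4), torch.ones(4))
node_flow = models.GraphAF(encoder, node_prior, num_layer=12)
edge_flow = models.GraphAF(encoder, edge_prior, use_edge=True, num_layer=12)
```

The two flows are typically passed to an autoregressive generation task that alternates node and edge sampling steps.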
Pre-training Models
InfoGraph
Type: Contrastive learning
Characteristics: - Maximizes mutual information - Graph-level and node-level contrast - Unsupervised pre-training - Good for small datasets
Use Cases: - Pre-train molecular encoders - Few-shot learning - Transfer learning
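A sketch of wrapping a GIN encoder with InfoGraph for unsupervised pre-training; the wrapped encoder can later be fine-tuned on labeled property data. Dimensions are illustrative:

```python
from torchdrug import models

# Encoder to pre-train with graph/node mutual-information contrast.
gin = models.GIN(input_dim=69, hidden_dims=[256, 256, 256],
                 batch_norm=True, readout="sum")
model = models.InfoGraph(gin, separate_model=False)
```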
MultiviewContrast
Type: Multi-view contrastive learning for proteins
Characteristics:
- Contrasts different views of proteins
- Geometric pre-training
- Uses 3D structure information
- Excellent for protein models
Use Cases:
- Pre-train GearNet on protein structures
- Transfer to property prediction
- Limited labeled data scenarios
Model Selection Guide
By Task Type
Molecular Property Prediction:
1. GIN (first choice)
2. GAT (interpretability)
3. SchNet (when 3D coordinates are available)
Protein Tasks:
1. ESM (sequence only)
2. GearNet (structure available)
3. ProteinBERT (sequence, lighter than ESM)
Knowledge Graphs:
1. RotatE (best performance)
2. ComplEx (good balance)
3. TransE (large graphs, efficiency)
Molecular Generation:
1. GraphAutoregressiveFlow (exact likelihood)
2. GCPN with GIN backbone (property optimization)
Retrosynthesis:
1. GIN (synthon completion)
2. RGCN (center identification with bond types)
By Dataset Size
Small (< 1k samples):
- Use pre-trained models (ESM for proteins)
- Simpler architectures (GCN, ProteinCNN)
- Heavy regularization
Medium (1k-100k samples):
- GIN for molecules
- GAT for interpretability
- Standard training
Large (> 100k samples):
- Any model works
- Deeper architectures
- Can train from scratch
By Computational Budget
Low:
- GCN (simplest)
- DistMult (KG)
- ProteinLSTM
Medium:
- GIN
- GAT
- ComplEx
High:
- ESM (large variants)
- SchNet (3D)
- RotatE with high embedding dimension
Implementation Tips
- Start Simple: Begin with a GCN or GIN baseline (see the end-to-end sketch after this list)
- Use Pre-trained: ESM for proteins, InfoGraph for molecules
- Tune Depth: 3-5 layers typically sufficient
- Batch Normalization: Usually helps (except KG embeddings)
- Residual Connections: Important for deep networks
- Readout Function: "mean" usually works well
- Edge Features: Include when available (bonds, distances)
- Regularization: Dropout, weight decay, early stopping
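Putting the tips together, a minimal end-to-end sketch using TorchDrug's Engine workflow. The dataset, split, and hyperparameters are illustrative choices, not tuned recommendations:

```python
import torch
from torchdrug import core, datasets, models, tasks

# GIN baseline on BBBP with a standard 80/10/10 random split.
dataset = datasets.BBBP("~/molecule-datasets/")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, lengths)

model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256],
                   batch_norm=True, readout="sum")
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=("auprc", "auroc"))

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
engine = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=128)
engine.train(num_epoch=10)
engine.evaluate("valid")
```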