Core Concepts and Technical Details
Overview
This reference covers TorchDrug's fundamental architecture, design principles, and technical implementation details.
Architecture Philosophy
Modular Design
TorchDrug separates concerns into distinct modules:
- Representation Models (torchdrug.models): Encode graphs into embeddings
- Task Definitions (torchdrug.tasks): Define learning objectives and evaluation
- Data Handling (torchdrug.data, torchdrug.datasets): Graph structures and datasets
- Core Components (torchdrug.core): Base classes and utilities
Benefits:
- Reuse representations across tasks
- Mix and match components
- Easy experimentation and prototyping
- Clear separation of concerns
Configurable System
All components inherit from core.Configurable:
- Serialize to configuration dictionaries
- Reconstruct from configurations
- Save and load complete pipelines
- Reproducible experiments
Core Components
core.Configurable
Base class for all TorchDrug components.
Key Methods:
- config_dict(): Serialize to dictionary
- load_config_dict(config): Load from dictionary
- save(file): Save to file
- load(file): Load from file
Example:
from torchdrug import core, models
model = models.GIN(input_dim=10, hidden_dims=[256, 256])
# Save configuration
config = model.config_dict()
# {'class': 'GIN', 'input_dim': 10, 'hidden_dims': [256, 256], ...}
# Reconstruct model
model2 = core.Configurable.load_config_dict(config)
core.Registry
Decorator for registering models, tasks, and datasets.
Usage:
import torch.nn as nn
from torchdrug import core

@core.Registry.register("models.CustomModel")
class CustomModel(nn.Module, core.Configurable):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = hidden_dim
        self.linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, graph, input, all_loss=None, metric=None):
        # A real model would use the graph structure; this just projects node features
        node_feature = self.linear(input)
        return {"node_feature": node_feature}
Benefits:
- Models automatically serializable
- String-based model specification
- Easy model lookup and instantiation (see the sketch below)
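A minimal lookup sketch via the registry (assuming the built-in GIN model, which TorchDrug registers on import):

from torchdrug import core, models  # importing models populates the registry

ModelClass = core.Registry.search("GIN")  # look up a registered class by name
model = ModelClass(input_dim=10, hidden_dims=[256, 256])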
Data Structures
Graph
Core data structure representing molecular or protein graphs.
Attributes:
- num_node: Number of nodes
- num_edge: Number of edges
- node_feature: Node feature tensor [num_node, feature_dim]
- edge_feature: Edge feature tensor [num_edge, feature_dim]
- edge_list: Edge connectivity [num_edge, 2 or 3]
- num_relation: Number of edge types (for multi-relational)
Methods:
- node_mask(mask): Select subset of nodes
- edge_mask(mask): Select subset of edges
- undirected(): Make graph undirected
- directed(): Make graph directed
Batching:
- Graphs are batched into a single disconnected graph (see PackedGraph below)
- Automatic batching in DataLoader
- Node/edge indices are preserved per graph
Molecule (extends Graph)
Specialized graph for molecules.
Additional Attributes:
- atom_type: Atomic numbers
- bond_type: Bond types (single, double, triple, aromatic)
- formal_charge: Atomic formal charges
- explicit_hs: Explicit hydrogen counts
Methods:
- from_smiles(smiles): Create from SMILES string
- from_molecule(mol): Create from RDKit molecule
- to_smiles(): Convert to SMILES
- to_molecule(): Convert to RDKit molecule
- ion_to_molecule(): Neutralize charges
Example:
from torchdrug import data
# From SMILES
mol = data.Molecule.from_smiles("CCO")
# Atom features
print(mol.atom_type) # [6, 6, 8] (C, C, O)
print(mol.bond_type) # bond type codes, one entry per directed edge (both bonds here are single)
Protein (extends Graph)
Specialized graph for proteins.
Additional Attributes:
- residue_type: Amino acid types
- atom_name: Atom names (CA, CB, etc.)
- atom_type: Atomic numbers
- residue_number: Residue numbering
- chain_id: Chain identifiers
Methods:
- from_pdb(pdb_file): Load from PDB file
- from_sequence(sequence): Create from sequence
- to_pdb(pdb_file): Save to PDB file
Graph Construction:
- Nodes typically represent residues (not atoms)
- Edges can be sequential, spatial (radius/KNN), or contact-based
- Configurable edge construction strategies
Example:
from torchdrug import data, layers
from torchdrug.layers import geometry

# Load protein
protein = data.Protein.from_pdb("1a3x.pdb")

# Build a residue-level graph with sequential and spatial edges
# (layers.GraphConstruction operates on packed proteins)
graph_construction = layers.GraphConstruction(
    node_layers=[geometry.AlphaCarbonNode()],             # one node per residue, at the C-alpha atom
    edge_layers=[geometry.SequentialEdge(max_distance=2), # edges along the chain
                 geometry.SpatialEdge(radius=10.0)]       # edges within a 10 A radius
)
graph = graph_construction(data.Protein.pack([protein]))
PackedGraph
Efficient batching structure for heterogeneous graphs.
Purpose:
- Batch graphs of different sizes
- Single GPU memory allocation
- Efficient parallel processing
Attributes:
- num_nodes: Tensor of node counts per graph
- num_edges: Tensor of edge counts per graph
- node2graph: Graph index for each node
Use Cases:
- Automatic in DataLoader
- Custom batching strategies (see the sketch below)
- Multi-graph operations
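A minimal sketch of explicit packing with data.Graph.pack, showing the attributes above:

from torchdrug import data

g1 = data.Graph(edge_list=[[0, 1], [1, 2]], num_node=3)
g2 = data.Graph(edge_list=[[0, 1]], num_node=2)

batch = data.Graph.pack([g1, g2])  # returns a PackedGraph
print(batch.num_nodes)             # tensor([3, 2])
print(batch.node2graph)            # tensor([0, 0, 0, 1, 1])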
Model Interface
Forward Function Signature
All TorchDrug models follow a standardized interface:
def forward(self, graph, input, all_loss=None, metric=None):
    """
    Args:
        graph (Graph): Batch of graphs
        input (Tensor): Node input features
        all_loss (Tensor, optional): Accumulator for losses
        metric (dict, optional): Dictionary for metrics

    Returns:
        dict: Output dictionary with representation keys
    """
    # Model computation
    output = self.layers(graph, input)
    return {
        "node_feature": output,
        "graph_feature": graph_pooling(output)
    }
Key Points:
- graph: Batched graph structure
- input: Node features [num_node, input_dim]
- all_loss: Accumulated loss (for multi-task)
- metric: Shared metric dictionary
- Returns dict with representation types
Essential Attributes
All models must define:
- input_dim: Expected input feature dimension
- output_dim: Output representation dimension
Purpose:
- Automatic dimension checking
- Compose models in pipelines (see the sketch after the example)
- Error checking and validation
Example:
class CustomModel(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = hidden_dim
        # ... layers ...
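These attributes let downstream code derive sizes instead of hard-coding them; a minimal composition sketch (GIN's output_dim defaults to its last hidden dimension):

import torch.nn as nn
from torchdrug import models

encoder = models.GIN(input_dim=10, hidden_dims=[256, 256])

# size the prediction head from the encoder rather than hard-coding 256
head = nn.Linear(encoder.output_dim, 1)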
Task Interface
Core Task Methods
All tasks implement these methods:
from torchdrug import metrics, tasks

class CustomTask(tasks.Task):
    def preprocess(self, train_set, valid_set, test_set):
        """Dataset-specific preprocessing (optional)"""
        pass

    def predict(self, batch):
        """Generate predictions for a batch"""
        graph, label = batch
        output = self.model(graph, graph.node_feature.float())
        pred = self.mlp(output["graph_feature"])
        return pred

    def target(self, batch):
        """Extract ground truth labels"""
        graph, label = batch
        return label

    def forward(self, batch):
        """Compute training loss"""
        pred = self.predict(batch)
        target = self.target(batch)
        loss = self.criterion(pred, target)
        return loss

    def evaluate(self, pred, target):
        """Compute evaluation metrics"""
        return {
            "auroc": metrics.area_under_roc(pred, target),
            "auprc": metrics.area_under_prc(pred, target),
        }
Task Components
Typical Task Structure:
1. Representation Model: Encodes the graph into embeddings
2. Readout/Prediction Head: Maps embeddings to predictions
3. Loss Function: Training objective
4. Metrics: Evaluation measures
Example:
from torchdrug import tasks, models
# Representation model
model = models.GIN(input_dim=10, hidden_dims=[256, 256])
# Task wraps model with prediction head
task = tasks.PropertyPrediction(
    model=model,
    task=["task1", "task2"],  # multi-task
    criterion="bce",
    metric=["auroc", "auprc"],
    num_mlp_layer=2
)
Training Workflow
Standard Training Loop
import torch
from torchdrug import data, models, tasks, datasets

# 1. Load dataset
dataset = datasets.BBBP("~/datasets/")
train_set, valid_set, test_set = dataset.split()

# 2. Create data loaders (torchdrug's DataLoader collates graphs into PackedGraph batches)
train_loader = data.DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = data.DataLoader(valid_set, batch_size=32)

# 3. Define model and task
model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256])
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=["auroc", "auprc"])
task.preprocess(train_set, valid_set, test_set)  # infer output dimensions from the data

# 4. Setup optimizer
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)

# 5. Training loop
for epoch in range(100):
    # Train
    task.train()
    for batch in train_loader:
        loss, metric = task(batch)  # tasks return (loss, metric)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validate
    task.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in valid_loader:
            preds.append(task.predict(batch))
            targets.append(task.target(batch))
    preds = torch.cat(preds)
    targets = torch.cat(targets)
    metrics = task.evaluate(preds, targets)
    print(f"Epoch {epoch}: {metrics}")
PyTorch Lightning Integration
TorchDrug tasks are compatible with PyTorch Lightning:
import torch
import pytorch_lightning as pl

class LightningWrapper(pl.LightningModule):
    def __init__(self, task):
        super().__init__()
        self.task = task

    def training_step(self, batch, batch_idx):
        loss, metric = self.task(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        pred = self.task.predict(batch)
        target = self.task.target(batch)
        return {"pred": pred, "target": target}

    # note: this hook was renamed on_validation_epoch_end in Lightning >= 2.0
    def validation_epoch_end(self, outputs):
        preds = torch.cat([o["pred"] for o in outputs])
        targets = torch.cat([o["target"] for o in outputs])
        metrics = self.task.evaluate(preds, targets)
        self.log_dict(metrics)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
Loss Functions
Built-in Criteria
Classification:
- "bce": Binary cross-entropy
- "ce": Cross-entropy (multi-class)
Regression:
- "mse": Mean squared error
- "mae": Mean absolute error
Knowledge Graph:
- "bce": Binary classification of triples
- "ce": Cross-entropy ranking loss
- "margin": Margin-based ranking
Custom Loss
class CustomTask(tasks.Task):
    def forward(self, batch):
        pred = self.predict(batch)
        target = self.target(batch)
        # Custom loss computation
        loss = custom_loss_function(pred, target)
        return loss
Metrics
Common Metrics
Classification:
- AUROC: Area under the ROC curve
- AUPRC: Area under the precision-recall curve
- Accuracy: Overall accuracy
- F1: Harmonic mean of precision and recall

Regression:
- MAE: Mean absolute error
- RMSE: Root mean squared error
- R²: Coefficient of determination
- Pearson: Pearson correlation

Ranking (Knowledge Graph):
- MR: Mean rank
- MRR: Mean reciprocal rank
- Hits@K: Fraction of correct answers ranked in the top K
Multi-Task Metrics
For multi-label or multi-task settings:
- Metrics computed per task
- Macro-averaged across tasks (a short sketch follows)
- Can be weighted by task importance
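A minimal macro-averaging sketch with torchdrug.metrics, assuming pred and target are [num_sample, num_task] tensors:

import torch
from torchdrug import metrics

def macro_auroc(pred, target):
    # AUROC per task, then an unweighted average across tasks
    per_task = [metrics.area_under_roc(pred[:, i], target[:, i])
                for i in range(pred.shape[1])]
    return torch.stack(per_task).mean()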
Data Transforms
Molecule Transforms
from torchdrug import transforms

# Add a virtual node connected to all atoms
transform1 = transforms.VirtualNode()

# Shuffle node order (useful for order-sensitive models)
transform2 = transforms.Shuffle()

# Compose transforms
transform = transforms.Compose([transform1, transform2])
dataset = datasets.BBBP("~/datasets/", transform=transform)
Protein Transforms
# Truncate long proteins to at most 500 residues
transform = transforms.TruncateProtein(max_length=500)
dataset = datasets.Fold("~/datasets/", transform=transform)
Best Practices
Memory Efficiency
- Gradient Accumulation: Simulate larger batches when memory is tight (see the sketch below)
- Mixed Precision: FP16/AMP training
- Batch Size Tuning: Balance speed and memory
- Data Loading: Multiple workers for I/O
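A minimal PyTorch sketch combining both (accum_steps is illustrative; task and train_loader as in the training loop above):

import torch

scaler = torch.cuda.amp.GradScaler()
accum_steps = 4  # effective batch size = batch_size * accum_steps

for step, batch in enumerate(train_loader):
    with torch.cuda.amp.autocast():
        loss, metric = task(batch)
    scaler.scale(loss / accum_steps).backward()  # accumulate scaled gradients
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()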
Reproducibility
- Set Seeds: PyTorch, NumPy, Python random
- Deterministic Operations: torch.use_deterministic_algorithms(True) (see the helper below)
- Save Configurations: Use core.Configurable
- Version Control: Track the TorchDrug version
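A minimal seed helper covering the randomness sources above:

import random
import numpy as np
import torch

def set_seed(seed=0):
    random.seed(seed)                         # Python RNG
    np.random.seed(seed)                      # NumPy RNG
    torch.manual_seed(seed)                   # CPU (and current GPU) RNG
    torch.cuda.manual_seed_all(seed)          # all GPU RNGs
    torch.use_deterministic_algorithms(True)  # error on nondeterministic ops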
Debugging
- Check Dimensions: Verify input_dim and output_dim
- Validate Batching: Print batch statistics
- Monitor Gradients: Watch for vanishing/exploding
- Overfit Small Batch: Ensure model capacity
Performance Optimization
- GPU Utilization: Monitor with nvidia-smi
- Profile Code: Use the PyTorch profiler
- Optimize Data Loading: Prefetch, pin memory (see the sketch below)
- Compile Models: Use TorchScript if possible
Advanced Topics
Multi-Task Learning
Train single model on multiple related tasks:
task = tasks.PropertyPrediction(
    model,
    task={"task1": 1.0, "task2": 1.0, "task3": 2.0},  # dict values weight each task; task 3 counts double
    criterion="bce",
    metric=["auroc"]
)
Transfer Learning
- Pre-train on large dataset
- Fine-tune on target dataset
- Optionally freeze early layers (see the sketch below)
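A minimal sketch, assuming weights saved earlier with torch.save and a GIN-style encoder that keeps its message-passing layers in model.layers (the checkpoint filename is hypothetical):

import torch

checkpoint = torch.load("gin_pretrained.pth")  # hypothetical pre-trained weights
task.model.load_state_dict(checkpoint)

# freeze the first two message-passing layers during fine-tuning
for layer in task.model.layers[:2]:
    for param in layer.parameters():
        param.requires_grad = False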
Self-Supervised Pre-training
Use built-in pre-training tasks (a minimal example follows the list):
- AttributeMasking: Mask node features
- EdgePrediction: Predict edge existence
- ContextPrediction: Contrastive learning
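For instance, attribute masking over a GIN encoder (mask_rate follows the common 15% convention):

from torchdrug import models, tasks

model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[256, 256, 256])
task = tasks.AttributeMasking(model, mask_rate=0.15)  # mask and predict 15% of node attributes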
Custom Layers
Extend TorchDrug with custom GNN layers by subclassing layers.MessagePassingBase; a minimal sum-aggregation sketch:

from torch_scatter import scatter_add
from torchdrug import layers

class CustomConv(layers.MessagePassingBase):
    def message(self, graph, input):
        # message for each edge: the feature of its source node
        node_in = graph.edge_list[:, 0]
        return input[node_in]

    def aggregate(self, graph, message):
        # sum incoming messages at each target node
        node_out = graph.edge_list[:, 1]
        return scatter_add(message, node_out, dim=0, dim_size=graph.num_node)

    def combine(self, input, update):
        # residual combination of previous state and aggregated update
        return input + update
Common Pitfalls
- Forgetting input_dim and output_dim: Models won't compose
- Not Batching Properly: Use PackedGraph for variable-sized graphs
- Data Leakage: Be careful with scaffold splits and pre-training
- Ignoring Edge Features: Bonds/spatial info can be critical
- Wrong Evaluation Metrics: Match metrics to the task (e.g., AUROC for imbalanced classification)
- Insufficient Regularization: Use dropout, weight decay, early stopping
- Not Validating Chemistry: Generated molecules must be valid
- Overfitting Small Datasets: Use pre-training or simpler models