Knowledge Graph Reasoning
Overview
Knowledge graphs represent structured information as entities and relations in a graph format. TorchDrug provides comprehensive support for knowledge graph completion (link prediction) using embedding-based models and neural reasoning approaches.
Available Datasets
General Knowledge Graphs
FB15k (Freebase subset):
- 14,951 entities
- 1,345 relation types
- 592,213 triples
- General world knowledge from Freebase

FB15k-237:
- 14,541 entities
- 237 relation types
- 310,116 triples
- Filtered version of FB15k with inverse relations removed
- More challenging benchmark

WN18 (WordNet):
- 40,943 entities (word senses)
- 18 relation types (lexical relations)
- 151,442 triples
- Linguistic knowledge graph

WN18RR:
- 40,943 entities
- 11 relation types
- 93,003 triples
- Filtered version of WN18 with easy inverse patterns removed
Biomedical Knowledge Graphs
Hetionet:
- 45,158 entities (genes, compounds, diseases, pathways, etc.)
- 24 relation types (treats, causes, binds, etc.)
- 2,250,197 edges
- Integrates 29 public biomedical databases
- Designed for drug repurposing and disease understanding
Task: KnowledgeGraphCompletion
The primary task for knowledge graphs is link prediction: given a head entity and a relation, predict the tail entity (or vice versa).
Task Modes
Head Prediction:
- Given (?, relation, tail), predict the head entity
- "What can cause Disease X?"

Tail Prediction:
- Given (head, relation, ?), predict the tail entity
- "What diseases does Gene X cause?"

Both:
- Predict both head and tail
- Standard evaluation protocol
Evaluation Metrics
Ranking Metrics:
- Mean Rank (MR): average rank of the correct entity
- Mean Reciprocal Rank (MRR): average of 1/rank
- Hits@K: percentage of correct entities ranked in the top K predictions
- Typically reported for K = 1, 3, 10

Filtered vs Raw:
- Filtered: remove other known true triples from the ranking
- Raw: rank among all possible entities
- Filtered is the standard for evaluation (see the sketch below)
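A minimal sketch (plain PyTorch with illustrative names, not TorchDrug API) of how the filtered metrics fall out of a single rank computation, given a scores tensor over all entities for one (head, relation, ?) query and the set of tails already known to be true:

import torch

def filtered_metrics(scores, target, known_true, ks=(1, 3, 10)):
    # Mask the other known true entities so they cannot outrank the target
    scores = scores.clone()
    others = [e for e in known_true if e != target]
    if others:
        scores[others] = float("-inf")
    # Filtered rank = 1 + number of entities scored strictly above the target
    rank = int((scores > scores[target]).sum()) + 1
    metrics = {"MR": rank, "MRR": 1.0 / rank}
    for k in ks:
        metrics[f"hits@{k}"] = float(rank <= k)
    return metrics

Averaging these per-query values over the test set (and over both head and tail prediction) gives the reported dataset-level MR, MRR, and Hits@K.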
Embedding Models
Translational Models
TransE (Translation Embedding):
- Represents relations as translations in embedding space
- h + r ≈ t (head + relation ≈ tail); see the scoring sketch below
- Simple and effective baseline
- Works well for 1-to-1 relations
- Struggles with N-to-N relations

RotatE (Rotation Embedding):
- Represents relations as rotations in complex space
- Handles symmetric and inverse relations better than TransE
- Strong results on many standard benchmarks
- Can model composition patterns
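For reference, minimal sketches of the two scoring functions (plain PyTorch, not TorchDrug's internal implementations; gamma plays the role of TorchDrug's max_score margin):

import torch

def transe_score(h, r, t, gamma=9.0):
    # TransE: t should lie near h + r; higher score = more plausible
    return gamma - torch.norm(h + r - t, p=1, dim=-1)

def rotate_score(h, r, t, gamma=9.0):
    # RotatE: h, t are complex tensors, r holds phase angles in [-pi, pi];
    # each relation becomes an element-wise rotation of unit modulus
    rotation = torch.polar(torch.ones_like(r), r)
    return gamma - (h * rotation - t).abs().sum(dim=-1)

The unit-modulus constraint is what makes relations invertible (rotate backwards for the inverse) and composable (compose rotations by adding phases).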
Semantic Matching Models
DistMult:
- Bilinear scoring function
- Handles symmetric relations naturally
- Cannot model asymmetric relations (see the sketch below for why)
- Fast and memory efficient

ComplEx:
- Complex-valued embeddings
- Models asymmetric and inverse relations
- Outperforms DistMult on most graphs
- Balances expressiveness and efficiency

SimplE:
- Extends DistMult with inverse relations
- Fully expressive (can represent any relation pattern)
- Two embeddings per entity (head role and tail role)
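The contrast between DistMult and ComplEx is easiest to see in code (a minimal sketch, not the TorchDrug implementation):

import torch

def distmult_score(h, r, t):
    # Bilinear and symmetric in h and t: score(h, r, t) == score(t, r, h),
    # which is why DistMult cannot model asymmetric relations
    return (h * r * t).sum(dim=-1)

def complex_score(h, r, t):
    # h, r, t are complex tensors; conjugating t breaks the h/t symmetry,
    # making asymmetric and inverse relations representable
    return (h * r * t.conj()).real.sum(dim=-1)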
Neural Logic Models
NeuralLP (Neural Logic Programming):
- Learns logical rules through differentiable operations
- Predictions can be interpreted via the learned rules
- Good for sparse knowledge graphs
- Computationally more expensive

KBGAT (Knowledge Base Graph Attention):
- Graph attention networks for KG completion
- Learns entity representations from the neighborhood
- Handles unseen entities through inductive learning
- Better for incomplete graphs
Training Workflow
Basic Pipeline
import torch
from torchdrug import core, datasets, models, tasks

# Load the dataset and its standard train/valid/test split
dataset = datasets.FB15k237("~/kg-datasets/")
train_set, valid_set, test_set = dataset.split()

# Define model
model = models.RotatE(
    num_entity=dataset.num_entity,
    num_relation=dataset.num_relation,
    embedding_dim=2000,
    max_score=9
)

# Define task with self-adversarial negative sampling
task = tasks.KnowledgeGraphCompletion(
    model,
    num_negative=128,
    adversarial_temperature=2,
    criterion="bce"
)

# Train and evaluate with TorchDrug's core.Engine solver
optimizer = torch.optim.Adam(task.parameters(), lr=2e-5)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     batch_size=1024)
solver.train(num_epoch=200)
solver.evaluate("valid")
Negative Sampling
Strategies:
- Uniform: sample entities uniformly at random
- Self-Adversarial: weight samples by the current model's scores
- Type-Constrained: sample only entity types valid for the relation
Parameters:
- num_negative: number of negative samples per positive triple
- adversarial_temperature: temperature for self-adversarial weighting; higher temperature puts more focus on hard negatives (see the sketch below)
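A minimal sketch of the self-adversarial weighting (following the formulation popularized by the RotatE paper; tensor shapes and names are assumptions, not TorchDrug internals):

import torch
import torch.nn.functional as F

def self_adversarial_loss(pos_score, neg_scores, temperature=2.0):
    # pos_score: (batch,), neg_scores: (batch, num_negative)
    # Weight each negative by the model's current belief that it is true;
    # detach so the weights act as constants in the gradient
    weights = F.softmax(temperature * neg_scores, dim=-1).detach()
    pos_loss = F.logsigmoid(pos_score)
    neg_loss = (weights * F.logsigmoid(-neg_scores)).sum(dim=-1)
    return -(pos_loss + neg_loss).mean()

Detaching the weights is the key detail: they steer the loss toward hard negatives without contributing gradients themselves.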
Loss Functions
Binary Cross-Entropy (BCE):
- Treats each triple as an independent binary classification
- Balanced classification between positive and negative samples

Margin Loss:
- Ensures positive scores exceed negative scores by at least a margin
- max(0, margin + score_neg - score_pos)

Logistic Loss:
- Smooth version of the margin loss
- Better gradient properties (see the sketch below)
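Minimal sketches of the margin and logistic losses (plain PyTorch, assuming per-triple scores where higher means more plausible):

import torch.nn.functional as F

def margin_loss(pos_score, neg_score, margin=1.0):
    # Penalize any negative that comes within `margin` of the positive
    return F.relu(margin + neg_score - pos_score).mean()

def logistic_loss(pos_score, neg_score):
    # log(1 + exp(-y * score)) with y = +1 / -1; softplus(-x) = -logsigmoid(x)
    return (F.softplus(-pos_score) + F.softplus(neg_score)).mean()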
Model Selection Guide
By Relation Patterns
1-to-1 Relations:
- TransE works well
- Any model will likely succeed

1-to-N Relations:
- DistMult, ComplEx, SimplE
- Avoid TransE

N-to-1 Relations:
- DistMult, ComplEx, SimplE
- Avoid TransE

N-to-N Relations:
- ComplEx, SimplE, RotatE
- Most challenging pattern

Symmetric Relations:
- DistMult, ComplEx
- RotatE (a phase of 0 or π makes a relation symmetric)

Antisymmetric Relations:
- ComplEx, SimplE, RotatE
- Avoid DistMult

Inverse Relations:
- ComplEx, SimplE, RotatE
- Important for bidirectional reasoning

Composition:
- RotatE (best)
- TransE (reasonable)
- Captures multi-hop paths
By Dataset Characteristics
Small Graphs (< 50k entities):
- ComplEx or SimplE
- Lower embedding dimensions (200-500)

Large Graphs (> 100k entities):
- DistMult for efficiency
- RotatE for accuracy
- Higher dimensions (500-2000)

Sparse Graphs:
- NeuralLP (learns rules from limited data)
- Pre-train embeddings on larger graphs

Dense, Complete Graphs:
- Any embedding model works well
- Choose based on relation patterns

Biomedical/Domain Graphs:
- Consider type constraints in sampling
- Use domain-specific negative sampling
- Hetionet benefits from relation-specific models
Advanced Techniques
Multi-Hop Reasoning
Chain multiple relations to answer complex queries:
- "What drugs treat diseases caused by gene X?"
- Requires path-based or rule-based reasoning
- NeuralLP naturally supports this; a sketch of differentiable path following appears below
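A NeuralLP-flavored sketch of differentiable path following over dense relation adjacency matrices (the relation names and the dense representation are illustrative; real systems use sparse operators):

import torch

def follow_path(start, adjacency, path):
    # start: (num_entity,) one-hot (or soft) distribution over entities
    # adjacency[rel]: (num_entity, num_entity) 0/1 matrix for relation rel
    x = start.float()
    for rel in path:
        # One hop: propagate mass along every edge of this relation
        x = x @ adjacency[rel].float()
    return x  # unnormalized scores over candidate answer entities

# Hypothetical 2-hop query: drugs that treat diseases caused by gene g
# answers = follow_path(gene_one_hot, adjacency, ["causes", "is_treated_by"])

NeuralLP learns soft, weighted versions of such paths end-to-end, which is why its predictions can be read back as logical rules.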
Temporal Knowledge Graphs
Extend to time-varying facts:
- Add temporal information to triples
- Predict future facts
- Requires temporal encoding in the model
Few-Shot Learning
Handle relations with few training examples:
- Meta-learning approaches
- Transfer from related relations
- Important for emerging knowledge
Inductive Learning
Generalize to unseen entities:
- KBGAT and other GNN-based methods
- Use entity features/descriptions
- Critical for evolving knowledge graphs
Biomedical Applications
Drug Repurposing
Predict "drug treats disease" links in Hetionet: 1. Train on known drug-disease associations 2. Predict new treatment candidates 3. Filter by mechanism (gene, pathway involvement) 4. Validate predictions experimentally
Disease Gene Discovery
Identify genes associated with diseases:
1. Model gene-disease-pathway networks
2. Predict missing gene-disease links
3. Incorporate protein interactions and expression data
4. Prioritize candidates for validation
Protein Function Prediction
Link proteins to biological processes:
1. Integrate protein interactions and GO terms
2. Predict missing GO annotations
3. Transfer function from similar proteins
Common Issues and Solutions
Issue: Poor performance on specific relation types
- Solution: Analyze the relation patterns and choose an appropriate model, or use relation-specific models

Issue: Overfitting on small graphs
- Solution: Reduce the embedding dimension, increase regularization, or use simpler models

Issue: Slow training on large graphs
- Solution: Reduce the number of negative samples, use DistMult for efficiency, or implement mini-batch training

Issue: Cannot handle new entities
- Solution: Use inductive models (KBGAT), incorporate entity features, or pre-compute embeddings for new entities from their neighbors
Best Practices
- Start with ComplEx or RotatE for most tasks
- Use self-adversarial negative sampling
- Tune embedding dimension (typically 500-2000)
- Apply regularization to prevent overfitting
- Use filtered evaluation metrics
- Analyze performance per relation type
- Consider relation-specific models for heterogeneous graphs
- Validate predictions with domain experts