references/layers_reference.md

PyTorch Geometric Neural Network Layers Reference

This document provides a comprehensive reference of all neural network layers available in torch_geometric.nn.

Layer Capability Flags

When selecting layers, consider these capability flags:

  • SparseTensor: Supports torch_sparse.SparseTensor format for efficient sparse operations
  • edge_weight: Handles one-dimensional edge weight data
  • edge_attr: Processes multi-dimensional edge feature information
  • Bipartite: Works with bipartite graphs (different source/target node dimensions)
  • Static: Operates on static graphs with batched node features
  • Lazy: Enables initialization without specifying input channel dimensions

Convolutional Layers

Standard Graph Convolutions

GCNConv - Graph Convolutional Network layer - Implements spectral graph convolution with symmetric normalization - Supports: SparseTensor, edge_weight, Bipartite, Lazy - Use for: Citation networks, social networks, general graph learning - Example: GCNConv(in_channels, out_channels, improved=False, cached=True)

SAGEConv - GraphSAGE layer - Inductive learning via neighborhood sampling and aggregation - Supports: SparseTensor, Bipartite, Lazy - Use for: Large graphs, inductive learning, heterogeneous features - Example: SAGEConv(in_channels, out_channels, aggr='mean')

GATConv - Graph Attention Network layer - Multi-head attention mechanism for adaptive neighbor weighting - Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy - Use for: Tasks requiring variable neighbor importance - Example: GATConv(in_channels, out_channels, heads=8, dropout=0.6)

GraphConv - Simple graph convolution (Morris et al.) - Basic message passing with optional edge weights - Supports: SparseTensor, edge_weight, Bipartite, Lazy - Use for: Baseline models, simple graph structures - Example: GraphConv(in_channels, out_channels, aggr='add')

GINConv - Graph Isomorphism Network layer - Maximally powerful GNN for graph isomorphism testing - Supports: Bipartite - Use for: Graph classification, molecular property prediction - Example: GINConv(nn.Sequential(nn.Linear(in_channels, out_channels), nn.ReLU()))

TransformerConv - Graph Transformer layer - Combines graph structure with transformer attention - Supports: SparseTensor, Bipartite, Lazy - Use for: Long-range dependencies, complex graphs - Example: TransformerConv(in_channels, out_channels, heads=8, beta=True)

ChebConv - Chebyshev spectral graph convolution - Uses Chebyshev polynomials for efficient spectral filtering - Supports: SparseTensor, edge_weight, Bipartite, Lazy - Use for: Spectral graph learning, efficient convolutions - Example: ChebConv(in_channels, out_channels, K=3)

SGConv - Simplified Graph Convolution - Pre-computes fixed number of propagation steps - Supports: SparseTensor, edge_weight, Bipartite, Lazy - Use for: Fast training, shallow models - Example: SGConv(in_channels, out_channels, K=2)

APPNP - Approximate Personalized Propagation of Neural Predictions - Separates feature transformation from propagation - Supports: SparseTensor, edge_weight, Lazy - Use for: Deep propagation without oversmoothing - Example: APPNP(K=10, alpha=0.1)

ARMAConv - ARMA graph convolution - Uses ARMA filters for graph filtering - Supports: SparseTensor, edge_weight, Bipartite, Lazy - Use for: Advanced spectral methods - Example: ARMAConv(in_channels, out_channels, num_stacks=3, num_layers=2)

GATv2Conv - Improved Graph Attention Network - Fixes static attention computation issue in GAT - Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy - Use for: Better attention learning than original GAT - Example: GATv2Conv(in_channels, out_channels, heads=8)

SuperGATConv - Self-supervised Graph Attention - Adds self-supervised attention mechanism - Supports: SparseTensor, edge_attr, Bipartite, Static, Lazy - Use for: Self-supervised learning, limited labels - Example: SuperGATConv(in_channels, out_channels, heads=8)

GMMConv - Gaussian Mixture Model Convolution - Uses Gaussian kernels in pseudo-coordinate space - Supports: Bipartite - Use for: Point clouds, spatial data - Example: GMMConv(in_channels, out_channels, dim=3, kernel_size=5)

SplineConv - Spline-based convolution - B-spline basis functions for spatial filtering - Supports: Bipartite - Use for: Irregular grids, continuous spaces - Example: SplineConv(in_channels, out_channels, dim=2, kernel_size=5)

NNConv - Neural Network Convolution - Edge features processed by neural networks - Supports: edge_attr, Bipartite - Use for: Rich edge features, molecular graphs - Example: NNConv(in_channels, out_channels, nn=edge_nn, aggr='mean')

CGConv - Crystal Graph Convolution - Designed for crystalline materials - Supports: Bipartite - Use for: Materials science, crystal structures - Example: CGConv(in_channels, dim=3, batch_norm=True)

EdgeConv - Edge Convolution (Dynamic Graph CNN) - Dynamically computes edges based on feature space - Supports: Static - Use for: Point clouds, dynamic graphs - Example: EdgeConv(nn=edge_nn, aggr='max')

PointNetConv - PointNet++ convolution - Local and global feature learning for point clouds - Use for: 3D point cloud processing - Example: PointNetConv(local_nn, global_nn)

ResGatedGraphConv - Residual Gated Graph Convolution - Gating mechanism with residual connections - Supports: edge_attr, Bipartite, Lazy - Use for: Deep GNNs, complex features - Example: ResGatedGraphConv(in_channels, out_channels)

GENConv - Generalized Graph Convolution - Generalizes multiple GNN variants - Supports: SparseTensor, edge_weight, edge_attr, Bipartite, Lazy - Use for: Flexible architecture exploration - Example: GENConv(in_channels, out_channels, aggr='softmax', num_layers=2)

FiLMConv - Feature-wise Linear Modulation - Conditions on global features - Supports: Bipartite, Lazy - Use for: Conditional generation, multi-task learning - Example: FiLMConv(in_channels, out_channels, num_relations=5)

PANConv - Path Attention Network - Attention over multi-hop paths - Supports: SparseTensor, Lazy - Use for: Complex connectivity patterns - Example: PANConv(in_channels, out_channels, filter_size=3)

ClusterGCNConv - Cluster-GCN convolution - Efficient training via graph clustering - Supports: edge_attr, Lazy - Use for: Very large graphs - Example: ClusterGCNConv(in_channels, out_channels)

MFConv - Multi-scale Feature Convolution - Aggregates features at multiple scales - Supports: SparseTensor, Lazy - Use for: Multi-scale patterns - Example: MFConv(in_channels, out_channels)

RGCNConv - Relational Graph Convolution - Handles multiple edge types - Supports: SparseTensor, edge_weight, Lazy - Use for: Knowledge graphs, heterogeneous graphs - Example: RGCNConv(in_channels, out_channels, num_relations=10)

FAConv - Frequency Adaptive Convolution - Adaptive filtering in spectral domain - Supports: SparseTensor, Lazy - Use for: Spectral graph learning - Example: FAConv(in_channels, eps=0.1, dropout=0.5)

Molecular and 3D Convolutions

SchNet - Continuous-filter convolutional layer - Designed for molecular dynamics - Use for: Molecular property prediction, 3D molecules - Example: SchNet(hidden_channels=128, num_filters=64, num_interactions=6)

DimeNet - Directional Message Passing - Uses directional information and angles - Use for: 3D molecular structures, chemical properties - Example: DimeNet(hidden_channels=128, out_channels=1, num_blocks=6)

PointTransformerConv - Point cloud transformer - Transformer for 3D point clouds - Use for: 3D vision, point cloud segmentation - Example: PointTransformerConv(in_channels, out_channels)

Hypergraph Convolutions

HypergraphConv - Hypergraph convolution - Operates on hyperedges (edges connecting multiple nodes) - Supports: Lazy - Use for: Multi-way relationships, chemical reactions - Example: HypergraphConv(in_channels, out_channels)

HGTConv - Heterogeneous Graph Transformer - Transformer for heterogeneous graphs with multiple types - Supports: Lazy - Use for: Heterogeneous networks, knowledge graphs - Example: HGTConv(in_channels, out_channels, metadata, heads=8)

Aggregation Operators

Aggr - Base aggregation class - Flexible aggregation across nodes

SumAggregation - Sum aggregation - Example: SumAggregation()

MeanAggregation - Mean aggregation - Example: MeanAggregation()

MaxAggregation - Max aggregation - Example: MaxAggregation()

SoftmaxAggregation - Softmax-weighted aggregation - Learnable attention weights - Example: SoftmaxAggregation(learn=True)

PowerMeanAggregation - Power mean aggregation - Learnable power parameter - Example: PowerMeanAggregation(learn=True)

LSTMAggregation - LSTM-based aggregation - Sequential processing of neighbors - Example: LSTMAggregation(in_channels, out_channels)

SetTransformerAggregation - Set Transformer aggregation - Transformer for permutation-invariant aggregation - Example: SetTransformerAggregation(in_channels, out_channels)

MultiAggregation - Multiple aggregations - Combines multiple aggregation methods - Example: MultiAggregation(['mean', 'max', 'std'])

Pooling Layers

Global Pooling

global_mean_pool - Global mean pooling - Averages node features per graph - Example: global_mean_pool(x, batch)

global_max_pool - Global max pooling - Max over node features per graph - Example: global_max_pool(x, batch)

global_add_pool - Global sum pooling - Sums node features per graph - Example: global_add_pool(x, batch)

global_sort_pool - Global sort pooling - Sorts and concatenates top-k nodes - Example: global_sort_pool(x, batch, k=30)

GlobalAttention - Global attention pooling - Learnable attention weights for aggregation - Example: GlobalAttention(gate_nn)

Set2Set - Set2Set pooling - LSTM-based attention mechanism - Example: Set2Set(in_channels, processing_steps=3)

Hierarchical Pooling

TopKPooling - Top-k pooling - Keeps top-k nodes based on projection scores - Example: TopKPooling(in_channels, ratio=0.5)

SAGPooling - Self-Attention Graph Pooling - Uses self-attention for node selection - Example: SAGPooling(in_channels, ratio=0.5)

ASAPooling - Adaptive Structure Aware Pooling - Structure-aware node selection - Example: ASAPooling(in_channels, ratio=0.5)

PANPooling - Path Attention Pooling - Attention over paths for pooling - Example: PANPooling(in_channels, ratio=0.5)

EdgePooling - Edge contraction pooling - Pools by contracting edges - Example: EdgePooling(in_channels)

MemPooling - Memory-based pooling - Learnable cluster assignments - Example: MemPooling(in_channels, out_channels, heads=4, num_clusters=10)

avg_pool / max_pool - Average/Max pool with clustering - Pools nodes within clusters - Example: avg_pool(cluster, data)

Normalization Layers

BatchNorm - Batch normalization - Normalizes features across batch - Example: BatchNorm(in_channels)

LayerNorm - Layer normalization - Normalizes features per sample - Example: LayerNorm(in_channels)

InstanceNorm - Instance normalization - Normalizes per sample and graph - Example: InstanceNorm(in_channels)

GraphNorm - Graph normalization - Graph-specific normalization - Example: GraphNorm(in_channels)

PairNorm - Pair normalization - Prevents oversmoothing in deep GNNs - Example: PairNorm(scale_individually=False)

MessageNorm - Message normalization - Normalizes messages during passing - Example: MessageNorm(learn_scale=True)

DiffGroupNorm - Differentiable Group Normalization - Learnable grouping for normalization - Example: DiffGroupNorm(in_channels, groups=10)

Model Architectures

Pre-Built Models

GCN - Complete Graph Convolutional Network - Multi-layer GCN with dropout - Example: GCN(in_channels, hidden_channels, num_layers, out_channels)

GraphSAGE - Complete GraphSAGE model - Multi-layer SAGE with dropout - Example: GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)

GIN - Complete Graph Isomorphism Network - Multi-layer GIN for graph classification - Example: GIN(in_channels, hidden_channels, num_layers, out_channels)

GAT - Complete Graph Attention Network - Multi-layer GAT with attention - Example: GAT(in_channels, hidden_channels, num_layers, out_channels, heads=8)

PNA - Principal Neighbourhood Aggregation - Combines multiple aggregators and scalers - Example: PNA(in_channels, hidden_channels, num_layers, out_channels)

EdgeCNN - Edge Convolution CNN - Dynamic graph CNN for point clouds - Example: EdgeCNN(out_channels, num_layers=3, k=20)

Auto-Encoders

GAE - Graph Auto-Encoder - Encodes graphs into latent space - Example: GAE(encoder)

VGAE - Variational Graph Auto-Encoder - Probabilistic graph encoding - Example: VGAE(encoder)

ARGA - Adversarially Regularized Graph Auto-Encoder - GAE with adversarial regularization - Example: ARGA(encoder, discriminator)

ARGVA - Adversarially Regularized Variational Graph Auto-Encoder - VGAE with adversarial regularization - Example: ARGVA(encoder, discriminator)

Knowledge Graph Embeddings

TransE - Translating embeddings - Learns entity and relation embeddings - Example: TransE(num_nodes, num_relations, hidden_channels)

RotatE - Rotational embeddings - Embeddings in complex space - Example: RotatE(num_nodes, num_relations, hidden_channels)

ComplEx - Complex embeddings - Complex-valued embeddings - Example: ComplEx(num_nodes, num_relations, hidden_channels)

DistMult - Bilinear diagonal model - Simplified bilinear model - Example: DistMult(num_nodes, num_relations, hidden_channels)

Utility Layers

Sequential - Sequential container - Chains multiple layers - Example: Sequential('x, edge_index', [(GCNConv(16, 64), 'x, edge_index -> x'), nn.ReLU()])

JumpingKnowledge - Jumping knowledge connections - Combines representations from all layers - Modes: 'cat', 'max', 'lstm' - Example: JumpingKnowledge(mode='cat')

DeepGCNLayer - Deep GCN layer wrapper - Enables very deep GNNs with skip connections - Example: DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1)

MLP - Multi-layer perceptron - Standard feedforward network - Example: MLP([in_channels, 64, 64, out_channels], dropout=0.5)

Linear - Lazy linear layer - Linear transformation with lazy initialization - Example: Linear(in_channels, out_channels, bias=True)

Dense Layers

For dense (non-sparse) graph representations:

DenseGCNConv - Dense GCN layer DenseSAGEConv - Dense SAGE layer DenseGINConv - Dense GIN layer DenseGraphConv - Dense graph convolution

These are useful when working with small, fully-connected, or densely represented graphs.

Usage Tips

  1. Start simple: Begin with GCNConv or GATConv for most tasks
  2. Consider data type: Use molecular layers (SchNet, DimeNet) for 3D structures
  3. Check capabilities: Match layer capabilities to your data (edge features, bipartite, etc.)
  4. Deep networks: Use normalization (PairNorm, LayerNorm) and JumpingKnowledge for deep GNNs
  5. Large graphs: Use scalable layers (SAGE, Cluster-GCN) with neighbor sampling
  6. Heterogeneous: Use RGCNConv, HGTConv, or to_hetero() conversion
  7. Lazy initialization: Use lazy layers when input dimensions vary or are unknown

Common Patterns

Basic GNN

from torch_geometric.nn import GCNConv, global_mean_pool

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return global_mean_pool(x, batch)

Deep GNN with Normalization

class DeepGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_layers, out_channels):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.norms.append(LayerNorm(hidden_channels))

        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.norms.append(LayerNorm(hidden_channels))

        self.convs.append(GCNConv(hidden_channels, out_channels))
        self.jk = JumpingKnowledge(mode='cat')

    def forward(self, x, edge_index, batch):
        xs = []
        for conv, norm in zip(self.convs[:-1], self.norms):
            x = conv(x, edge_index)
            x = norm(x)
            x = F.relu(x)
            xs.append(x)

        x = self.convs[-1](x, edge_index)
        xs.append(x)

        x = self.jk(xs)
        return global_mean_pool(x, batch)
← Back to torch_geometric