# scripts/quick_trainer_setup.py

"""
Quick Trainer Setup Examples for PyTorch Lightning.

This script provides ready-to-use Trainer configurations for common use cases.
Copy and modify these configurations for your specific needs.
"""

import lightning as L
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    DeviceStatsMonitor,
    RichProgressBar,
)
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy


# =============================================================================
# 1. BASIC TRAINING (Single GPU/CPU)
# =============================================================================

def basic_trainer():
    """
    Build a minimal Trainer for quick prototyping.

    Use for: Small models, debugging, single GPU training

    Returns:
        An ``L.Trainer`` capped at 10 epochs with automatic accelerator and
        device selection, the progress bar enabled and ``logger=True``.
    """
    settings = dict(
        max_epochs=10,
        accelerator="auto",  # Automatically select GPU/CPU
        devices="auto",      # Use all available devices
        enable_progress_bar=True,
        logger=True,
    )
    return L.Trainer(**settings)


# =============================================================================
# 2. DEBUGGING CONFIGURATION
# =============================================================================

def debug_trainer():
    """
    Build a CPU-only Trainer geared towards bug hunting.

    Use for: Finding bugs, testing code quickly

    Returns:
        An ``L.Trainer`` with ``fast_dev_run`` and ``detect_anomaly`` switched
        on, logging requested at every step, and the progress bar enabled.
    """
    debug_flags = {
        "fast_dev_run": True,     # Run 1 batch through train/val/test
        "accelerator": "cpu",     # Use CPU for easier debugging
        "detect_anomaly": True,   # Detect NaN/Inf in gradients
        "log_every_n_steps": 1,   # Log every step
        "enable_progress_bar": True,
    }
    return L.Trainer(**debug_flags)


# =============================================================================
# 3. PRODUCTION TRAINING (Single GPU)
# =============================================================================

def production_single_gpu_trainer(
    max_epochs=100,
    log_dir="logs",
    checkpoint_dir="checkpoints"
):
    """
    Build a single-GPU Trainer with checkpointing, early stopping and logging.

    Use for: Final training runs on single GPU

    Args:
        max_epochs: Upper bound on the number of training epochs.
        log_dir: Directory passed to the TensorBoard logger as ``save_dir``.
        checkpoint_dir: Directory passed to ``ModelCheckpoint`` as ``dirpath``.

    Returns:
        The configured ``L.Trainer``.
    """
    # Experiment tracking
    tensorboard = pl_loggers.TensorBoardLogger(save_dir=log_dir, name="my_model")

    # Callbacks, listed in the order they are registered on the trainer
    trainer_callbacks = [
        ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename="{epoch:02d}-{val_loss:.2f}",
            monitor="val_loss",
            mode="min",
            save_top_k=3,
            save_last=True,
            verbose=True,
        ),
        EarlyStopping(monitor="val_loss", patience=10, mode="min", verbose=True),
        LearningRateMonitor(logging_interval="step"),
    ]

    return L.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=1,
        precision="16-mixed",        # Mixed precision for speed
        callbacks=trainer_callbacks,
        logger=tensorboard,
        log_every_n_steps=50,
        gradient_clip_val=1.0,       # Clip gradients
        enable_progress_bar=True,
    )


# =============================================================================
# 4. MULTI-GPU TRAINING (DDP)
# =============================================================================

def multi_gpu_ddp_trainer(
    max_epochs=100,
    num_gpus=4,
    log_dir="logs",
    checkpoint_dir="checkpoints"
):
    """
    Build a multi-GPU Trainer that uses Distributed Data Parallel.

    Use for: Models <500M parameters, standard deep learning models

    Args:
        max_epochs: Upper bound on the number of training epochs.
        num_gpus: Value passed to the Trainer as ``devices``.
        log_dir: Directory passed to the Weights & Biases logger as ``save_dir``.
        checkpoint_dir: Directory passed to ``ModelCheckpoint`` as ``dirpath``.

    Returns:
        The configured ``L.Trainer``.
    """
    # Callbacks, listed in the order they are registered on the trainer
    trainer_callbacks = [
        ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename="{epoch:02d}-{val_loss:.2f}",
            monitor="val_loss",
            mode="min",
            save_top_k=3,
            save_last=True,
        ),
        EarlyStopping(monitor="val_loss", patience=10, mode="min"),
        LearningRateMonitor(logging_interval="step"),
    ]

    # Experiment tracking
    wandb = pl_loggers.WandbLogger(project="my-project", save_dir=log_dir)

    # Distribution strategy
    ddp = DDPStrategy(
        find_unused_parameters=False,
        gradient_as_bucket_view=True,
    )

    return L.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=num_gpus,
        strategy=ddp,
        precision="16-mixed",
        callbacks=trainer_callbacks,
        logger=wandb,
        log_every_n_steps=50,
        gradient_clip_val=1.0,
        sync_batchnorm=True,         # Sync batch norm across GPUs
    )


# =============================================================================
# 5. LARGE MODEL TRAINING (FSDP)
# =============================================================================

def large_model_fsdp_trainer(
    max_epochs=100,
    num_gpus=8,
    log_dir="logs",
    checkpoint_dir="checkpoints"
):
    """
    Training for large models (500M+ parameters) with FSDP.

    Use for: Large transformers, models that don't fit in single GPU

    Args:
        max_epochs: Upper bound on the number of training epochs.
        num_gpus: Value passed to the Trainer as ``devices``.
        log_dir: Directory passed to the Weights & Biases logger as ``save_dir``.
        checkpoint_dir: Directory passed to ``ModelCheckpoint`` as ``dirpath``.

    Returns:
        The configured ``L.Trainer``.

    Note:
        Gradients are clipped by *value* rather than by norm. Lightning's FSDP
        precision plugin has rejected the default norm-based clipping with a
        ``MisconfigurationException`` once training starts; confirm against
        the installed Lightning version before switching back to ``"norm"``.
    """
    import torch.nn as nn

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch:02d}-{val_loss:.2f}",
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        save_last=True,
    )

    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Logger
    wandb_logger = pl_loggers.WandbLogger(
        project="large-model",
        save_dir=log_dir,
    )

    # Trainer with FSDP
    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=num_gpus,
        strategy=FSDPStrategy(
            sharding_strategy="FULL_SHARD",
            # Recompute activations of these layer types during backward
            activation_checkpointing_policy={
                nn.TransformerEncoderLayer,
                nn.TransformerDecoderLayer,
            },
            cpu_offload=False,       # Set True if GPU memory insufficient
        ),
        precision="bf16-mixed",      # BFloat16 for A100/H100
        callbacks=[
            checkpoint_callback,
            lr_monitor,
        ],
        logger=wandb_logger,
        log_every_n_steps=10,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="value",  # norm clipping is refused under FSDP
        accumulate_grad_batches=4,   # Gradient accumulation
    )

    return trainer


# =============================================================================
# 6. VERY LARGE MODEL TRAINING (DeepSpeed)
# =============================================================================

def deepspeed_trainer(
    max_epochs=100,
    num_gpus=8,
    stage=3,
    log_dir="logs",
    checkpoint_dir="checkpoints"
):
    """
    Training for very large models with DeepSpeed.

    Use for: Models >10B parameters, maximum memory efficiency

    Args:
        max_epochs: Upper bound on the number of training epochs.
        num_gpus: Value passed to the Trainer as ``devices``.
        stage: Suffix appended to ``"deepspeed_stage_"`` to form the strategy
            name handed to the Trainer (e.g. ``3`` -> ``"deepspeed_stage_3"``).
        log_dir: Directory passed to the Weights & Biases logger as ``save_dir``.
        checkpoint_dir: Directory passed to ``ModelCheckpoint`` as ``dirpath``.

    Returns:
        The configured ``L.Trainer``.
    """
    # Callbacks
    # ModelCheckpoint refuses save_top_k=3 when no quantity is monitored, so
    # rank checkpoints by the built-in "step" value: with mode="max" the
    # three most recent periodic checkpoints are the ones kept.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch:02d}-{step:06d}",
        monitor="step",
        mode="max",
        save_top_k=3,
        save_last=True,
        every_n_train_steps=1000,    # Save every N steps
    )

    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Logger
    wandb_logger = pl_loggers.WandbLogger(
        project="very-large-model",
        save_dir=log_dir,
    )

    # Select DeepSpeed stage
    strategy = f"deepspeed_stage_{stage}"

    # Trainer
    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=num_gpus,
        strategy=strategy,
        precision="16-mixed",
        callbacks=[
            checkpoint_callback,
            lr_monitor,
        ],
        logger=wandb_logger,
        log_every_n_steps=10,
        gradient_clip_val=1.0,
        accumulate_grad_batches=4,
    )

    return trainer


# =============================================================================
# 7. HYPERPARAMETER TUNING
# =============================================================================

def hyperparameter_tuning_trainer(max_epochs=50):
    """
    Build a stripped-down Trainer for hyperparameter search.

    Use for: Quick experiments, hyperparameter tuning

    Args:
        max_epochs: Upper bound on the number of training epochs.

    Returns:
        An ``L.Trainer`` on one device with checkpointing, logging and the
        progress bar turned off, and train/val batches limited to 0.5.
    """
    # Everything that only adds overhead to a search trial is disabled.
    quiet = dict(
        enable_checkpointing=False,  # Don't save checkpoints
        logger=False,                # Disable logging
        enable_progress_bar=False,
    )
    # Run each trial on a reduced share of the data.
    subset = dict(
        limit_train_batches=0.5,     # Use 50% of training data
        limit_val_batches=0.5,       # Use 50% of validation data
    )
    return L.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices=1,
        **quiet,
        **subset,
    )


# =============================================================================
# 8. OVERFITTING TEST
# =============================================================================

def overfit_test_trainer(num_batches=10):
    """
    Build a Trainer that repeatedly fits a small, fixed set of batches.

    Use for: Testing if model can learn, debugging

    Args:
        num_batches: Value passed to the Trainer as ``overfit_batches``.

    Returns:
        An ``L.Trainer`` capped at 100 epochs on one device, logging every
        step with the progress bar enabled.
    """
    options = {
        "max_epochs": 100,
        "accelerator": "auto",
        "devices": 1,
        "overfit_batches": num_batches,  # Overfit on N batches
        "log_every_n_steps": 1,
        "enable_progress_bar": True,
    }
    return L.Trainer(**options)


# =============================================================================
# 9. TIME-LIMITED TRAINING (SLURM)
# =============================================================================

def time_limited_trainer(
    max_time_hours=23.5,
    max_epochs=1000,
    checkpoint_dir="checkpoints"
):
    """
    Training with time limit for SLURM clusters.

    Use for: Cluster jobs with time limits

    Args:
        max_time_hours: Wall-clock budget in hours; converted to a
            ``timedelta`` and passed to the Trainer as ``max_time``.
        max_epochs: Upper bound on the number of training epochs.
        checkpoint_dir: Directory passed to ``ModelCheckpoint`` as ``dirpath``.

    Returns:
        The configured ``L.Trainer``.
    """
    from datetime import timedelta

    # ModelCheckpoint refuses save_top_k=3 when no quantity is monitored, so
    # rank checkpoints by the built-in "step" value: with mode="max" the
    # three most recent periodic checkpoints are the ones kept.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        monitor="step",
        mode="max",
        save_top_k=3,
        save_last=True,              # Important for resuming
        every_n_epochs=5,
    )

    trainer = L.Trainer(
        max_epochs=max_epochs,
        max_time=timedelta(hours=max_time_hours),
        accelerator="gpu",
        devices="auto",
        callbacks=[checkpoint_callback],
        log_every_n_steps=50,
    )

    return trainer


# =============================================================================
# 10. REPRODUCIBLE RESEARCH
# =============================================================================

def reproducible_trainer(seed=42, max_epochs=100, checkpoint_dir="checkpoints"):
    """
    Fully reproducible trainer for research papers.

    Use for: Publications, reproducible results

    Args:
        seed: Seed passed to ``L.seed_everything`` (with ``workers=True``).
        max_epochs: Upper bound on the number of training epochs.
        checkpoint_dir: Directory passed to ``ModelCheckpoint`` as ``dirpath``.
            Defaults to the previously hard-coded ``"checkpoints"``, matching
            the parameter offered by the other factories in this file.

    Returns:
        The configured ``L.Trainer``.

    Note:
        Calling this function seeds the global random number generators as a
        side effect, before the trainer is built.
    """
    # Set seed
    L.seed_everything(seed, workers=True)

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch:02d}-{val_loss:.2f}",
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        save_last=True,
    )

    # Trainer
    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=1,
        precision="32-true",         # Full precision for reproducibility
        deterministic=True,          # Use deterministic algorithms
        benchmark=False,             # Disable cudnn benchmarking
        callbacks=[checkpoint_callback],
        log_every_n_steps=50,
    )

    return trainer


# =============================================================================
# USAGE EXAMPLES
# =============================================================================

if __name__ == "__main__":
    def _build_or_skip(factory, **kwargs):
        """Return ``factory(**kwargs)``; on failure print the reason and return None.

        The GPU and Weights & Biases examples cannot be built on every
        machine, so the demo reports the problem instead of aborting.
        """
        try:
            return factory(**kwargs)
        except Exception as exc:  # demo boundary: hardware/optional deps vary
            print(f"   - Skipped ({type(exc).__name__}: {exc})")
            return None

    print("PyTorch Lightning Trainer Configurations\n")

    # Example 1: Basic training
    print("1. Basic Trainer:")
    trainer = basic_trainer()
    print(f"   - Max epochs: {trainer.max_epochs}")
    print(f"   - Accelerator: {trainer.accelerator}")
    print()

    # Example 2: Debug training
    print("2. Debug Trainer:")
    trainer = debug_trainer()
    print(f"   - Fast dev run: {trainer.fast_dev_run}")
    # NOTE(review): the Trainer appears to keep this flag only as the private
    # `_detect_anomaly`; read it defensively so the demo never raises here.
    anomaly_flag = getattr(
        trainer, "detect_anomaly", getattr(trainer, "_detect_anomaly", "n/a")
    )
    print(f"   - Detect anomaly: {anomaly_flag}")
    print()

    # Example 3: Production single GPU (needs a GPU)
    print("3. Production Single GPU Trainer:")
    trainer = _build_or_skip(production_single_gpu_trainer, max_epochs=100)
    if trainer is not None:
        print(f"   - Max epochs: {trainer.max_epochs}")
        print(f"   - Precision: {trainer.precision}")
        print(f"   - Callbacks: {len(trainer.callbacks)}")
    print()

    # Example 4: Multi-GPU DDP (needs 4 GPUs and wandb)
    print("4. Multi-GPU DDP Trainer:")
    trainer = _build_or_skip(multi_gpu_ddp_trainer, num_gpus=4)
    if trainer is not None:
        print(f"   - Strategy: {trainer.strategy}")
        print(f"   - Devices: {trainer.num_devices}")
    print()

    # Example 5: FSDP for large models (needs 8 GPUs and wandb)
    print("5. FSDP Trainer for Large Models:")
    trainer = _build_or_skip(large_model_fsdp_trainer, num_gpus=8)
    if trainer is not None:
        print(f"   - Strategy: {trainer.strategy}")
        print(f"   - Precision: {trainer.precision}")
    print()

    print("\nTo use these configurations:")
    print("1. Import the desired function")
    print("2. Create trainer: trainer = production_single_gpu_trainer()")
    print("3. Train model: trainer.fit(model, datamodule=dm)")
# ← Back to pytorch-lightning