scripts/template_datamodule.py

"""
Template for creating a PyTorch Lightning DataModule.

This template provides a complete boilerplate for building a LightningDataModule
with all essential methods and best practices for data handling.
"""

import lightning as L
from torch.utils.data import Dataset, DataLoader, random_split
import torch


class CustomDataset(Dataset):
    """
    Custom Dataset implementation.

    Replace this with your actual dataset implementation.
    """

    def __init__(self, data_path, transform=None):
        """
        Initialize the dataset.

        Args:
            data_path: Path to data directory
            transform: Optional transforms to apply
        """
        self.data_path = data_path
        self.transform = transform

        # Load your data here
        # self.data = load_data(data_path)
        # self.labels = load_labels(data_path)

        # Placeholder data
        self.data = torch.randn(1000, 3, 224, 224)
        self.labels = torch.randint(0, 10, (1000,))

    def __len__(self):
        """Return the size of the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Get a single item from the dataset.

        Args:
            idx: Index of the item

        Returns:
            Tuple of (data, label)
        """
        sample = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            sample = self.transform(sample)

        return sample, label


class TemplateDataModule(L.LightningDataModule):
    """
    Template LightningDataModule for data handling.

    This class encapsulates all data processing steps:
    1. Download/prepare data (prepare_data)
    2. Create datasets (setup)
    3. Create dataloaders (train/val/test/predict_dataloader)

    Args:
        data_dir: Directory containing the data
        batch_size: Batch size for dataloaders
        num_workers: Number of workers for data loading
        train_val_split: Train/validation split ratio
        pin_memory: Whether to pin memory for faster GPU transfer
    """

    def __init__(
        self,
        data_dir: str = "./data",
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: float = 0.8,
        pin_memory: bool = True,
    ):
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters()

        # Initialize as None (will be set in setup)
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.predict_dataset = None

    def prepare_data(self):
        """
        Download and prepare data.

        This method is called only once and on a single process.
        Do not set state here (e.g., self.x = y) because it's not
        transferred to other processes.

        Use this for:
        - Downloading datasets
        - Tokenizing text
        - Saving processed data to disk
        """
        # Example: Download data if not exists
        # if not os.path.exists(self.hparams.data_dir):
        #     download_dataset(self.hparams.data_dir)

        # Example: Process and save data
        # process_and_save(self.hparams.data_dir)

        pass

    def setup(self, stage: str = None):
        """
        Create datasets for each stage.

        This method is called on every process in distributed training.
        Set state here (e.g., self.train_dataset = ...).

        Args:
            stage: Current stage ('fit', 'validate', 'test', or 'predict')
        """
        # Define transforms
        train_transform = self._get_train_transforms()
        test_transform = self._get_test_transforms()

        # Setup for training and validation
        if stage == "fit" or stage is None:
            # Load full dataset
            full_dataset = CustomDataset(
                self.hparams.data_dir, transform=train_transform
            )

            # Split into train and validation
            train_size = int(self.hparams.train_val_split * len(full_dataset))
            val_size = len(full_dataset) - train_size

            self.train_dataset, self.val_dataset = random_split(
                full_dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(42),
            )

            # Apply test transforms to validation set
            # (Note: random_split doesn't support different transforms,
            # you may need to implement a custom wrapper)

        # Setup for testing
        if stage == "test" or stage is None:
            self.test_dataset = CustomDataset(
                self.hparams.data_dir, transform=test_transform
            )

        # Setup for prediction
        if stage == "predict":
            self.predict_dataset = CustomDataset(
                self.hparams.data_dir, transform=test_transform
            )

    def _get_train_transforms(self):
        """
        Define training transforms/augmentations.

        Returns:
            Training transforms
        """
        # Example with torchvision:
        # from torchvision import transforms
        # return transforms.Compose([
        #     transforms.RandomHorizontalFlip(),
        #     transforms.RandomRotation(10),
        #     transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                         std=[0.229, 0.224, 0.225])
        # ])

        return None

    def _get_test_transforms(self):
        """
        Define test/validation transforms (no augmentation).

        Returns:
            Test/validation transforms
        """
        # Example with torchvision:
        # from torchvision import transforms
        # return transforms.Compose([
        #     transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                         std=[0.229, 0.224, 0.225])
        # ])

        return None

    def train_dataloader(self):
        """
        Create training dataloader.

        Returns:
            Training DataLoader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=True if self.hparams.num_workers > 0 else False,
            drop_last=True,  # Drop last incomplete batch
        )

    def val_dataloader(self):
        """
        Create validation dataloader.

        Returns:
            Validation DataLoader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=True if self.hparams.num_workers > 0 else False,
        )

    def test_dataloader(self):
        """
        Create test dataloader.

        Returns:
            Test DataLoader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
        )

    def predict_dataloader(self):
        """
        Create prediction dataloader.

        Returns:
            Prediction DataLoader
        """
        return DataLoader(
            self.predict_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=False,
            num_workers=self.hparams.num_workers,
        )

    # Optional: State management for checkpointing

    def state_dict(self):
        """
        Save DataModule state for checkpointing.

        Returns:
            State dictionary
        """
        return {"train_val_split": self.hparams.train_val_split}

    def load_state_dict(self, state_dict):
        """
        Restore DataModule state from checkpoint.

        Args:
            state_dict: State dictionary
        """
        self.hparams.train_val_split = state_dict["train_val_split"]

    # Optional: Teardown for cleanup

    def teardown(self, stage: str = None):
        """
        Cleanup after training/testing/prediction.

        Args:
            stage: Current stage ('fit', 'validate', 'test', or 'predict')
        """
        # Clean up resources
        if stage == "fit":
            self.train_dataset = None
            self.val_dataset = None
        elif stage == "test":
            self.test_dataset = None
        elif stage == "predict":
            self.predict_dataset = None


# Example usage
if __name__ == "__main__":
    # Create DataModule
    dm = TemplateDataModule(
        data_dir="./data",
        batch_size=64,
        num_workers=8,
        train_val_split=0.8,
    )

    # Setup for training
    dm.prepare_data()
    dm.setup(stage="fit")

    # Get dataloaders
    train_loader = dm.train_dataloader()
    val_loader = dm.val_dataloader()

    print(f"Train dataset size: {len(dm.train_dataset)}")
    print(f"Validation dataset size: {len(dm.val_dataset)}")
    print(f"Train batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")

    # Example: Use with Trainer
    # from template_lightning_module import TemplateLightningModule
    # model = TemplateLightningModule()
    # trainer = L.Trainer(max_epochs=10)
    # trainer.fit(model, datamodule=dm)
← Back to pytorch-lightning