scripts/transfer_learning.py

#!/usr/bin/env python3
"""
Transfer Learning Script for DeepChem

Use pretrained models (ChemBERTa, GROVER) for molecular property prediction
with transfer learning. Particularly useful for small datasets. MolFormer is
accepted by --model, but fine-tuning for it is not implemented yet.

Usage:
    python transfer_learning.py --model chemberta --data my_data.csv --target activity
    python transfer_learning.py --model grover --dataset bbbp
"""

import argparse
import deepchem as dc
import sys


# Registry of the pretrained models this script knows about. The keys double
# as the accepted values of the --model command-line option; each entry holds
# a display name, a one-line description, and the checkpoint identifier
# (None when the model is not loaded through such an identifier).
PRETRAINED_MODELS = dict(
    chemberta=dict(
        name='ChemBERTa',
        description='BERT pretrained on 77M molecules from ZINC15',
        model_id='seyonec/ChemBERTa-zinc-base-v1',
    ),
    grover=dict(
        name='GROVER',
        description='Graph transformer pretrained on 10M molecules',
        model_id=None,  # GROVER uses its own loading mechanism
    ),
    molformer=dict(
        name='MolFormer',
        description='Transformer pretrained on molecular structures',
        model_id='ibm/MoLFormer-XL-both-10pct',
    ),
)


def train_chemberta(train_dataset, valid_dataset, test_dataset, task_type='classification', n_tasks=1, n_epochs=10):
    """
    Fine-tune ChemBERTa on a dataset.

    The datasets are expected to carry SMILES strings as features; the model
    tokenizes them itself.

    Args:
        train_dataset: Training dataset
        valid_dataset: Validation dataset (evaluated only, not used for fitting)
        test_dataset: Test dataset
        task_type: 'classification' or 'regression'
        n_tasks: Number of prediction tasks
        n_epochs: Number of fine-tuning epochs

    Returns:
        Tuple ``(model, results)``. ``results`` maps the split name
        ('Train', 'Valid', 'Test') to the scores returned by
        ``model.evaluate`` for that split.

    Raises:
        ValueError: If ``task_type`` is neither 'classification' nor
            'regression'.
    """
    # Reject bad input before any banner is printed or weights are downloaded.
    if task_type not in ('classification', 'regression'):
        raise ValueError(
            f"Unsupported task_type {task_type!r}; "
            "expected 'classification' or 'regression'"
        )

    print("=" * 70)
    print("Fine-tuning ChemBERTa")
    print("=" * 70)
    print("\nChemBERTa is a BERT model pretrained on 77M molecules from ZINC15.")
    print("It uses SMILES strings as input and has learned rich molecular")
    print("representations that transfer well to downstream tasks.")

    print("\nLoading pretrained ChemBERTa model...")
    # Imported locally so the transformer stack is only required when the
    # ChemBERTa path is actually taken.
    from deepchem.models.torch_models.chemberta import Chemberta

    checkpoint = PRETRAINED_MODELS['chemberta']['model_id']
    # dc.models.HuggingFaceModel needs an instantiated model and tokenizer, not
    # a checkpoint name; the Chemberta wrapper builds both from the name.
    model = Chemberta(
        task=task_type,
        tokenizer_path=checkpoint,
        n_tasks=n_tasks,
        batch_size=32,
        learning_rate=2e-5,  # Lower LR for fine-tuning
    )
    # The constructor creates a freshly initialised network, so the pretrained
    # weights have to be pulled in explicitly.
    # NOTE(review): confirm that loading from a hub checkpoint keeps the task
    # head configuration for regression / multi-task runs in the installed
    # DeepChem version.
    model.load_from_pretrained(checkpoint, from_hf_checkpoint=True)

    print(f"\nFine-tuning for {n_epochs} epochs...")
    print("(This may take a while on the first run as the model is downloaded)")
    model.fit(train_dataset, nb_epoch=n_epochs)
    print("Fine-tuning complete!")

    # Evaluate
    print("\n" + "=" * 70)
    print("Model Evaluation")
    print("=" * 70)

    if task_type == 'classification':
        metrics = [
            dc.metrics.Metric(dc.metrics.roc_auc_score, name='ROC-AUC'),
            dc.metrics.Metric(dc.metrics.accuracy_score, name='Accuracy'),
        ]
    else:
        metrics = [
            dc.metrics.Metric(dc.metrics.r2_score, name='R²'),
            dc.metrics.Metric(dc.metrics.mean_absolute_error, name='MAE'),
        ]

    results = {}
    for name, dataset in [('Train', train_dataset), ('Valid', valid_dataset), ('Test', test_dataset)]:
        print(f"\n{name} Set:")
        scores = model.evaluate(dataset, metrics)
        results[name] = scores
        for metric_name, score in scores.items():
            print(f"  {metric_name}: {score:.4f}")

    return model, results


def train_grover(train_dataset, test_dataset, task_type='classification', n_tasks=1, n_epochs=20):
    """
    Fine-tune GROVER on a dataset and report metrics on the train/test splits.

    Args:
        train_dataset: Dataset the model is fitted on
        test_dataset: Held-out dataset used for evaluation only
        task_type: 'classification' selects ROC-AUC and accuracy; any other
            value selects R² and MAE
        n_tasks: Number of prediction tasks
        n_epochs: Number of fine-tuning epochs

    Returns:
        Tuple ``(model, results)``; ``results`` maps 'Train' and 'Test' to the
        scores returned by ``model.evaluate``.
    """
    rule = "=" * 70

    print(rule)
    print("Fine-tuning GROVER")
    print(rule)
    print("\nGROVER is a graph transformer pretrained on 10M molecules using")
    print("self-supervised learning. It learns both node and graph-level")
    print("representations through masked atom/bond prediction tasks.")

    print("\nCreating GROVER model...")
    # NOTE(review): confirm these keyword arguments against the installed
    # DeepChem GroverModel signature - the published API appears to require
    # feature dimensions and atom/bond vocabularies, and to select
    # classification/regression through `mode` rather than `task`.
    grover_kwargs = {
        'task': task_type,
        'n_tasks': n_tasks,
        'model_dir': './grover_pretrained',
    }
    model = dc.models.GroverModel(**grover_kwargs)

    print(f"\nFine-tuning for {n_epochs} epochs...")
    model.fit(train_dataset, nb_epoch=n_epochs)
    print("Fine-tuning complete!")

    # Evaluate
    print("\n" + rule)
    print("Model Evaluation")
    print(rule)

    # (scoring function, display name) pairs for the selected task type.
    if task_type == 'classification':
        metric_specs = (
            (dc.metrics.roc_auc_score, 'ROC-AUC'),
            (dc.metrics.accuracy_score, 'Accuracy'),
        )
    else:
        metric_specs = (
            (dc.metrics.r2_score, 'R²'),
            (dc.metrics.mean_absolute_error, 'MAE'),
        )
    metrics = [dc.metrics.Metric(score_fn, name=label) for score_fn, label in metric_specs]

    results = {}
    splits = {'Train': train_dataset, 'Test': test_dataset}
    for split_name, split_data in splits.items():
        print(f"\n{split_name} Set:")
        split_scores = model.evaluate(split_data, metrics)
        results[split_name] = split_scores
        for label, value in split_scores.items():
            print(f"  {label}: {value:.4f}")

    return model, results


def load_molnet_dataset(dataset_name, model_type):
    """
    Load a MoleculeNet dataset with appropriate featurization.

    Args:
        dataset_name: Name of MoleculeNet dataset (one of 'tox21', 'bbbp',
            'bace', 'hiv', 'delaney', 'freesolv', 'lipo')
        model_type: Type of pretrained model being used; selects the featurizer

    Returns:
        Tuple ``(tasks, datasets, transformers)`` exactly as returned by the
        underlying ``dc.molnet`` loader; ``datasets`` holds the scaffold-split
        train/valid/test datasets.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    # Dataset key -> name of the loader function in dc.molnet. Names are
    # resolved lazily so an unknown dataset is rejected before any loader is
    # looked up or any data is downloaded.
    molnet_loaders = {
        'tox21': 'load_tox21',
        'bbbp': 'load_bbbp',
        'bace': 'load_bace_classification',
        'hiv': 'load_hiv',
        'delaney': 'load_delaney',
        'freesolv': 'load_freesolv',
        'lipo': 'load_lipo',
    }

    if dataset_name not in molnet_loaders:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # ChemBERTa and MolFormer tokenize SMILES strings themselves, so the
    # features must be the untouched strings. DummyFeaturizer passes them
    # through (the same choice load_custom_dataset makes); the 'Raw' preset
    # hands back molecule objects by default, which a tokenizer cannot consume.
    if model_type in ('chemberta', 'molformer'):
        featurizer = dc.feat.DummyFeaturizer()
    # GROVER needs graph features
    # NOTE(review): confirm 'GraphConv' is what GroverModel consumes - DeepChem
    # documents a dedicated GroverFeaturizer for that model.
    elif model_type == 'grover':
        featurizer = 'GraphConv'
    else:
        featurizer = 'ECFP'

    print(f"\nLoading {dataset_name} dataset...")
    load_func = getattr(dc.molnet, molnet_loaders[dataset_name])
    tasks, datasets, transformers = load_func(
        featurizer=featurizer,
        splitter='scaffold'
    )

    return tasks, datasets, transformers


def load_custom_dataset(data_path, target_cols, smiles_col, model_type):
    """
    Read a CSV of molecules, featurize it for the chosen model, and split it.

    Args:
        data_path: Path to CSV file
        target_cols: List of target column names
        smiles_col: Name of SMILES column
        model_type: Type of pretrained model being used; selects the featurizer

    Returns:
        Tuple ``(train, valid, test)`` produced by an 80/10/10 scaffold split.
    """
    print(f"\nLoading custom data from {data_path}...")

    # Pick the featurizer class for the model, then instantiate it once.
    if model_type == 'grover':
        # NOTE(review): confirm this featurizer matches what GroverModel
        # expects; the MoleculeNet loader in this file uses 'GraphConv'.
        featurizer_cls = dc.feat.MolGraphConvFeaturizer
    elif model_type in ('chemberta', 'molformer'):
        # These models handle featurization themselves.
        featurizer_cls = dc.feat.DummyFeaturizer
    else:
        featurizer_cls = dc.feat.CircularFingerprint
    featurizer = featurizer_cls()

    full_dataset = dc.data.CSVLoader(
        tasks=target_cols,
        feature_field=smiles_col,
        featurizer=featurizer,
    ).create_dataset(data_path)

    print(f"Loaded {len(full_dataset)} molecules")

    # Split data
    print("Splitting data with scaffold splitter...")
    fractions = {'frac_train': 0.8, 'frac_valid': 0.1, 'frac_test': 0.1}
    train_set, valid_set, test_set = dc.splits.ScaffoldSplitter().train_valid_test_split(
        full_dataset, **fractions
    )

    for label, subset in (('Training', train_set), ('Validation', valid_set), ('Test', test_set)):
        print(f"  {label}: {len(subset)}")

    return train_set, valid_set, test_set


def main(argv=None):
    """
    Command-line entry point: parse options, load data, fine-tune, report.

    Args:
        argv: Optional list of argument strings. When None, argparse reads
            the process arguments, so ``main()`` behaves as before.

    Returns:
        Exit code: 0 on success; 1 when the dataset options are invalid, the
        selected model has no training routine yet, or loading/training
        raised an exception. (argparse itself exits the process on malformed
        options.)
    """
    parser = argparse.ArgumentParser(
        description='Transfer learning for molecular property prediction'
    )
    parser.add_argument(
        '--model',
        type=str,
        choices=list(PRETRAINED_MODELS.keys()),
        required=True,
        help='Pretrained model to use'
    )
    parser.add_argument(
        '--dataset',
        type=str,
        choices=['tox21', 'bbbp', 'bace', 'hiv', 'delaney', 'freesolv', 'lipo'],
        default=None,
        help='MoleculeNet dataset to use'
    )
    parser.add_argument(
        '--data',
        type=str,
        default=None,
        help='Path to custom CSV file'
    )
    parser.add_argument(
        '--target',
        nargs='+',
        default=['target'],
        help='Target column name(s) for custom data'
    )
    parser.add_argument(
        '--smiles-col',
        type=str,
        default='smiles',
        help='SMILES column name for custom data'
    )
    parser.add_argument(
        '--task-type',
        type=str,
        choices=['classification', 'regression'],
        default='classification',
        help='Type of prediction task'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=10,
        help='Number of fine-tuning epochs'
    )

    args = parser.parse_args(argv)

    # Validate arguments: exactly one data source must be given. Both checks
    # test "was the option supplied" so an empty string is treated the same
    # way in each.
    has_dataset = args.dataset is not None
    has_data = args.data is not None
    if not has_dataset and not has_data:
        print("Error: Must specify either --dataset or --data", file=sys.stderr)
        return 1

    if has_dataset and has_data:
        print("Error: Cannot specify both --dataset and --data", file=sys.stderr)
        return 1

    # Print model info
    model_info = PRETRAINED_MODELS[args.model]
    print("\n" + "=" * 70)
    print(f"Transfer Learning with {model_info['name']}")
    print("=" * 70)
    print(f"\n{model_info['description']}")

    # Fail fast: reject models without a training routine before spending
    # time loading and featurizing a dataset.
    if args.model not in ('chemberta', 'grover'):
        print(f"Error: Model {args.model} not yet implemented", file=sys.stderr)
        return 1

    # MoleculeNet datasets offered by --dataset that are classification tasks;
    # the remaining choices are treated as regression.
    classification_datasets = ('tox21', 'bbbp', 'bace', 'hiv')

    try:
        # Load dataset
        if has_dataset:
            tasks, datasets, _ = load_molnet_dataset(args.dataset, args.model)
            train, valid, test = datasets
            task_type = 'classification' if args.dataset in classification_datasets else 'regression'
            n_tasks = len(tasks)
        else:
            train, valid, test = load_custom_dataset(
                args.data,
                args.target,
                args.smiles_col,
                args.model
            )
            task_type = args.task_type
            n_tasks = len(args.target)

        # Train model
        if args.model == 'chemberta':
            train_chemberta(
                train, valid, test,
                task_type=task_type,
                n_tasks=n_tasks,
                n_epochs=args.epochs
            )
        else:
            # Only 'grover' remains after the check above; its trainer does
            # not take a validation split.
            train_grover(
                train, test,
                task_type=task_type,
                n_tasks=n_tasks,
                n_epochs=args.epochs
            )

        print("\n" + "=" * 70)
        print("Transfer Learning Complete!")
        print("=" * 70)
        print("\nTip: Pretrained models often work best with:")
        print("  - Small datasets (< 1000 samples)")
        print("  - Lower learning rates (1e-5 to 5e-5)")
        print("  - Fewer epochs (5-20)")
        print("  - Avoiding overfitting through early stopping")

        return 0

    except Exception as e:
        # Top-level boundary: report the failure with its traceback and turn
        # it into a non-zero exit code instead of an unhandled crash.
        print(f"\nError: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1


if __name__ == '__main__':
    sys.exit(main())
← Back to deepchem