# scripts/train_template.py

#!/usr/bin/env python3
"""
PufferLib Training Template

This template provides a complete training script for reinforcement learning
with PufferLib. Customize the environment, policy, and training configuration
as needed for your use case.
"""

import argparse
import torch
import torch.nn as nn
import pufferlib
from pufferlib import PuffeRL
from pufferlib.pytorch import layer_init


class Policy(nn.Module):
    """Example actor-critic policy: a shared MLP encoder with two linear heads.

    The encoder is two ``hidden_size``-wide linear layers, each followed by a
    ReLU.  ``actor`` maps the encoded features to ``action_space.n`` logits and
    ``critic`` maps them to a single value estimate.

    Observation spaces of any rank >= 1 are accepted: the encoder input width
    is the product of all observation dimensions, and ``forward`` flattens the
    trailing observation dimensions into one feature vector.  For rank-1
    observation spaces this is identical to using ``shape[0]`` directly.

    Args:
        observation_space: Space object exposing a non-empty ``shape``.
        action_space: Space object exposing ``n`` (number of discrete actions).
        hidden_size: Width of both encoder layers.

    Raises:
        ValueError: If ``observation_space.shape`` is empty (scalar
            observations have no feature dimension to encode).
    """

    def __init__(self, observation_space, action_space, hidden_size=256):
        super().__init__()

        self.observation_space = observation_space
        self.action_space = action_space

        obs_shape = tuple(observation_space.shape)
        if not obs_shape:
            raise ValueError(
                "Policy requires an observation space with at least one "
                "dimension"
            )

        # Number of trailing tensor dims that belong to a single observation;
        # forward() collapses exactly these into one feature axis.
        self._obs_ndim = len(obs_shape)

        # Total number of features per observation (product of all dims).
        input_size = 1
        for dim in obs_shape:
            input_size *= int(dim)

        # Encoder network
        self.encoder = nn.Sequential(
            layer_init(nn.Linear(input_size, hidden_size)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.ReLU()
        )

        # Policy head (actor)
        self.actor = layer_init(nn.Linear(hidden_size, action_space.n), std=0.01)

        # Value head (critic)
        self.critic = layer_init(nn.Linear(hidden_size, 1), std=1.0)

    def forward(self, observations):
        """Compute action logits and a value estimate for ``observations``.

        Args:
            observations: Tensor whose trailing dimensions match
                ``observation_space.shape``; any leading dimensions are kept.
                NOTE(review): no dtype conversion is done here, so the tensor
                must already have a dtype the linear layers accept.

        Returns:
            Tuple ``(logits, value)`` where ``logits`` has a last dimension of
            ``action_space.n`` and ``value`` has a last dimension of 1.
        """
        if self._obs_ndim > 1:
            # Collapse multi-dimensional observations into a flat feature
            # vector while leaving any leading (batch) dimensions untouched.
            observations = observations.flatten(start_dim=-self._obs_ndim)

        features = self.encoder(observations)
        logits = self.actor(features)
        value = self.critic(features)
        return logits, value


def make_env(env_name='procgen-coinrun', num_envs=256):
    """Create the training environment. Customize this for your task.

    Called with no arguments this builds the same environment as before
    ('procgen-coinrun' with 256 parallel environments).

    Args:
        env_name: Environment name forwarded to ``pufferlib.make``.
        num_envs: Number of parallel environments to request.

    Returns:
        The object returned by ``pufferlib.make``.
    """
    # Option 1 (default): use a named environment such as 'procgen-coinrun'.
    # Option 2: use a Gymnasium environment by passing
    #           env_name='gym-CartPole-v1'.
    # Option 3: use a custom environment instead of a named one:
    #     from my_envs import MyEnvironment
    #     return pufferlib.emulate(MyEnvironment, num_envs=num_envs)
    return pufferlib.make(env_name, num_envs=num_envs)


def create_policy(env):
    """Build a ``Policy`` sized from the environment's spaces.

    Args:
        env: Object exposing ``observation_space`` and ``action_space``.

    Returns:
        A ``Policy`` instance with a hidden size of 256.
    """
    policy_kwargs = {
        'observation_space': env.observation_space,
        'action_space': env.action_space,
        'hidden_size': 256,
    }
    return Policy(**policy_kwargs)


def train(args):
    """Main training function.

    Seeds torch, builds the environment, policy, logger and ``PuffeRL``
    trainer from ``args``, then alternates ``trainer.evaluate()`` and
    ``trainer.train()`` for ``args.num_iterations`` iterations, saving
    checkpoints and printing progress along the way.  A final checkpoint is
    written to ``<checkpoint_dir>/final_model.pt`` when the loop finishes.

    Args:
        args: Namespace-like object providing ``seed``, ``env_name``,
            ``num_envs``, ``num_workers``, ``device``, ``logger``,
            ``project``, ``exp_name``, ``neptune_token``, ``learning_rate``,
            ``batch_size``, ``n_epochs``, ``gamma``, ``gae_lambda``,
            ``clip_coef``, ``ent_coef``, ``vf_coef``, ``max_grad_norm``,
            ``compile``, ``num_iterations``, ``save_freq``, ``log_freq`` and
            ``checkpoint_dir``.  ``save_freq`` / ``log_freq`` of 0 disable
            periodic checkpointing / progress printing respectively.
            ``args.checkpoint_dir`` is expected to exist already.
    """
    # Set random seeds
    torch.manual_seed(args.seed)

    # Create environment
    print(f"Creating environment with {args.num_envs} parallel environments...")
    env = pufferlib.make(
        args.env_name,
        num_envs=args.num_envs,
        num_workers=args.num_workers
    )

    # Create policy
    print("Initializing policy...")
    policy = create_policy(env)

    # Resolve the device once so the policy placement and the trainer agree.
    # Previously a 'cuda' request on a machine without CUDA left the policy on
    # the CPU while still telling the trainer to use 'cuda'.
    device = args.device
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'

    if device == 'cuda':
        policy = policy.cuda()
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("Using CPU")

    # Create logger
    if args.logger == 'wandb':
        from pufferlib import WandbLogger
        logger = WandbLogger(
            project=args.project,
            name=args.exp_name,
            config=vars(args)
        )
    elif args.logger == 'neptune':
        from pufferlib import NeptuneLogger
        logger = NeptuneLogger(
            project=args.project,
            name=args.exp_name,
            api_token=args.neptune_token
        )
    else:
        from pufferlib import NoLogger
        logger = NoLogger()

    # Create trainer
    print("Creating trainer...")
    trainer = PuffeRL(
        env=env,
        policy=policy,
        device=device,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        clip_coef=args.clip_coef,
        ent_coef=args.ent_coef,
        vf_coef=args.vf_coef,
        max_grad_norm=args.max_grad_norm,
        logger=logger,
        compile=args.compile
    )

    # Training loop
    print(f"Starting training for {args.num_iterations} iterations...")
    for iteration in range(1, args.num_iterations + 1):
        # Collect rollouts
        rollout_data = trainer.evaluate()

        # Train on batch (return value is not used here)
        trainer.train()

        # Log results
        trainer.mean_and_log()

        # Save checkpoint; a frequency of 0 disables periodic saves instead of
        # raising ZeroDivisionError on the modulo.
        if args.save_freq != 0 and iteration % args.save_freq == 0:
            checkpoint_path = f"{args.checkpoint_dir}/checkpoint_{iteration}.pt"
            trainer.save_checkpoint(checkpoint_path)
            print(f"Saved checkpoint to {checkpoint_path}")

        # Print progress; a frequency of 0 disables progress output.
        if args.log_freq != 0 and iteration % args.log_freq == 0:
            mean_reward = rollout_data.get('mean_reward', 0)
            sps = rollout_data.get('sps', 0)
            print(f"Iteration {iteration}/{args.num_iterations} | "
                  f"Mean Reward: {mean_reward:.2f} | "
                  f"SPS: {sps:,.0f}")

    print("Training complete!")

    # Save final model
    final_path = f"{args.checkpoint_dir}/final_model.pt"
    trainer.save_checkpoint(final_path)
    print(f"Saved final model to {final_path}")


def main(argv=None):
    """Parse command-line options, create the checkpoint directory and train.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``.  ``None`` (the default) keeps the normal
            command-line behavior; passing a list allows calling this entry
            point programmatically.

    Raises:
        SystemExit: Raised by ``argparse`` for ``--help`` (exit code 0) or for
            invalid arguments (exit code 2).
    """
    import os

    parser = argparse.ArgumentParser(description='PufferLib Training')

    # Environment
    parser.add_argument('--env-name', type=str, default='procgen-coinrun',
                        help='Environment name')
    parser.add_argument('--num-envs', type=int, default=256,
                        help='Number of parallel environments')
    parser.add_argument('--num-workers', type=int, default=8,
                        help='Number of vectorization workers')

    # Training
    parser.add_argument('--num-iterations', type=int, default=10000,
                        help='Number of training iterations')
    parser.add_argument('--learning-rate', type=float, default=3e-4,
                        help='Learning rate')
    parser.add_argument('--batch-size', type=int, default=32768,
                        help='Batch size for training')
    parser.add_argument('--n-epochs', type=int, default=4,
                        help='Number of training epochs per batch')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu'], help='Device to use')

    # PPO Parameters
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='Discount factor')
    parser.add_argument('--gae-lambda', type=float, default=0.95,
                        help='GAE lambda')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                        help='PPO clipping coefficient')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                        help='Entropy coefficient')
    parser.add_argument('--vf-coef', type=float, default=0.5,
                        help='Value function coefficient')
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                        help='Maximum gradient norm')

    # Logging
    parser.add_argument('--logger', type=str, default='none',
                        choices=['wandb', 'neptune', 'none'],
                        help='Logger to use')
    parser.add_argument('--project', type=str, default='pufferlib-training',
                        help='Project name for logging')
    parser.add_argument('--exp-name', type=str, default='experiment',
                        help='Experiment name')
    parser.add_argument('--neptune-token', type=str, default=None,
                        help='Neptune API token')
    parser.add_argument('--log-freq', type=int, default=10,
                        help='Logging frequency (iterations)')

    # Checkpointing
    parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
                        help='Directory to save checkpoints')
    parser.add_argument('--save-freq', type=int, default=100,
                        help='Checkpoint save frequency (iterations)')

    # Misc
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--compile', action='store_true',
                        help='Use torch.compile for faster training')

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Create checkpoint directory before training writes into it
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # Run training
    train(args)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
# ← Back to pufferlib