PyHealth Training, Evaluation, and Interpretability
Overview
PyHealth provides comprehensive tools for training models, evaluating predictions, ensuring model reliability, and interpreting results for clinical applications.
Trainer Class
Core Functionality
The Trainer class manages the complete model training and evaluation workflow with PyTorch integration.
Initialization:
from pyhealth.trainer import Trainer
trainer = Trainer(
model=model, # PyHealth or PyTorch model
device="cuda", # or "cpu"
)
Training
train() method
Trains models with comprehensive monitoring and checkpointing.
Parameters:
- train_dataloader: Training data loader
- val_dataloader: Validation data loader (optional)
- test_dataloader: Test data loader (optional)
- epochs: Number of training epochs
- optimizer: Optimizer instance or class
- learning_rate: Learning rate (default: 1e-3)
- weight_decay: L2 regularization (default: 0)
- max_grad_norm: Gradient clipping threshold
- monitor: Metric to monitor (e.g., "pr_auc_score")
- monitor_criterion: "max" or "min"
- save_path: Checkpoint save directory
Usage:
trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
test_dataloader=test_loader,
epochs=50,
optimizer=torch.optim.Adam,
learning_rate=1e-3,
weight_decay=1e-5,
max_grad_norm=5.0,
monitor="pr_auc_score",
monitor_criterion="max",
save_path="./checkpoints"
)
Training Features:
- Automatic Checkpointing: Saves the best model based on the monitored metric
- Early Stopping: Stops training when the monitored metric stops improving
- Gradient Clipping: Prevents exploding gradients (see the sketch after this list)
- Progress Tracking: Displays training progress and metrics
- Multi-GPU Support: Automatic device placement
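The Trainer applies gradient clipping internally when max_grad_norm is set. As a rough illustration of what that option does, here is a minimal hand-written training step using torch.nn.utils.clip_grad_norm_; the training_step function and the assumption that the model returns a dict containing "loss" are illustrative, not part of the Trainer API.
import torch
def training_step(model, batch, optimizer, max_grad_norm=5.0):
    optimizer.zero_grad()
    loss = model(**batch)["loss"]  # PyHealth models typically return a dict with a "loss" entry
    loss.backward()
    # Rescale gradients so their global L2 norm does not exceed max_grad_norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()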
Inference
inference() method
Performs predictions on datasets.
Parameters:
- dataloader: Data loader for inference
- additional_outputs: List of additional outputs to return
- return_patient_ids: Return patient identifiers
Usage:
predictions = trainer.inference(
dataloader=test_loader,
additional_outputs=["attention_weights", "embeddings"],
return_patient_ids=True
)
Returns:
- y_pred: Model predictions
- y_true: Ground truth labels
- patient_ids: Patient identifiers (if requested)
- Additional outputs (if specified)
Evaluation
evaluate() method
Computes comprehensive evaluation metrics.
Parameters:
- dataloader: Data loader for evaluation
- metrics: List of metric functions
Usage:
from pyhealth.metrics import binary_metrics_fn
results = trainer.evaluate(
dataloader=test_loader,
metrics=["accuracy", "pr_auc_score", "roc_auc_score", "f1_score"]
)
print(results)
# Output: {'accuracy': 0.85, 'pr_auc_score': 0.78, 'roc_auc_score': 0.82, 'f1_score': 0.73}
Checkpoint Management
save() method
trainer.save("./models/best_model.pt")
load() method
trainer.load("./models/best_model.pt")
Evaluation Metrics
Binary Classification Metrics
Available Metrics:
- accuracy: Overall accuracy
- precision: Positive predictive value
- recall: Sensitivity/True positive rate
- f1_score: F1 score (harmonic mean of precision and recall)
- roc_auc_score: Area under ROC curve
- pr_auc_score: Area under precision-recall curve
- cohen_kappa: Inter-rater reliability
Usage:
from pyhealth.metrics import binary_metrics_fn
# Comprehensive binary metrics
metrics = binary_metrics_fn(
y_true=labels,
y_pred=predictions,
metrics=["accuracy", "f1_score", "pr_auc_score", "roc_auc_score"]
)
Threshold Selection:
# Default threshold: 0.5
predictions_binary = (predictions > 0.5).astype(int)
# Optimal threshold by F1
import numpy as np
from sklearn.metrics import f1_score
thresholds = np.arange(0.1, 0.9, 0.05)
f1_scores = [f1_score(y_true, (y_pred > t).astype(int)) for t in thresholds]
optimal_threshold = thresholds[np.argmax(f1_scores)]
Best Practices:
- Use AUROC: Overall model discrimination
- Use AUPRC: Especially informative for imbalanced classes
- Use F1: Balance precision and recall
- Report confidence intervals: Bootstrap resampling (see the sketch below)
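A minimal sketch of bootstrap confidence intervals for AUROC, using only NumPy and scikit-learn; the resample count, 95% interval, and function name are illustrative choices rather than PyHealth defaults.
import numpy as np
from sklearn.metrics import roc_auc_score
def bootstrap_auroc_ci(y_true, y_prob, n_resamples=1000, alpha=0.05, seed=42):
    """Percentile bootstrap confidence interval for AUROC"""
    rng = np.random.default_rng(seed)
    y_true, y_prob = np.asarray(y_true), np.asarray(y_prob)
    scores = []
    for _ in range(n_resamples):
        idx = rng.integers(0, len(y_true), len(y_true))
        if len(np.unique(y_true[idx])) < 2:
            continue  # a resample needs both classes for AUROC
        scores.append(roc_auc_score(y_true[idx], y_prob[idx]))
    lower, upper = np.percentile(scores, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return lower, upper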
Multi-Class Classification Metrics
Available Metrics:
- accuracy: Overall accuracy
- macro_f1: Unweighted mean F1 across classes
- micro_f1: Global F1 (total TP, FP, FN)
- weighted_f1: Weighted mean F1 by class frequency
- cohen_kappa: Multi-class kappa
Usage:
from pyhealth.metrics import multiclass_metrics_fn
metrics = multiclass_metrics_fn(
y_true=labels,
y_pred=predictions,
metrics=["accuracy", "macro_f1", "weighted_f1"]
)
Per-Class Metrics:
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred,
target_names=["Wake", "N1", "N2", "N3", "REM"]))
Confusion Matrix:
from sklearn.metrics import confusion_matrix
import seaborn as sns
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d')
Multi-Label Classification Metrics
Available Metrics:
- jaccard_score: Intersection over union
- hamming_loss: Fraction of incorrect labels
- example_f1: F1 per example (micro average)
- label_f1: F1 per label (macro average)
Usage:
from pyhealth.metrics import multilabel_metrics_fn
# y_pred: [n_samples, n_labels] binary matrix
metrics = multilabel_metrics_fn(
y_true=label_matrix,
y_pred=pred_matrix,
metrics=["jaccard_score", "example_f1", "label_f1"]
)
Drug Recommendation Metrics:
# Jaccard similarity (intersection/union)
jaccard = len(set(true_drugs) & set(pred_drugs)) / len(set(true_drugs) | set(pred_drugs))
# Precision@k: Precision for top-k predictions
def precision_at_k(y_true, y_pred, k=10):
    # y_true: collection of true label indices; y_pred: score vector over all labels
    top_k_pred = y_pred.argsort()[-k:]
    return len(set(y_true) & set(top_k_pred)) / k
Regression Metrics
Available Metrics:
- mean_absolute_error: Average absolute error
- mean_squared_error: Average squared error
- root_mean_squared_error: RMSE
- r2_score: Coefficient of determination
Usage:
from pyhealth.metrics import regression_metrics_fn
metrics = regression_metrics_fn(
y_true=true_values,
y_pred=predictions,
metrics=["mae", "rmse", "r2"]
)
Percentage Error Metrics:
import numpy as np
# Mean Absolute Percentage Error
mape = np.mean(np.abs((y_true - y_pred) / y_true)) * 100
# Median Absolute Percentage Error (robust to outliers)
medape = np.median(np.abs((y_true - y_pred) / y_true)) * 100
Fairness Metrics
Purpose: Assess model bias across demographic groups
Available Metrics:
- demographic_parity: Equal positive prediction rates
- equalized_odds: Equal TPR and FPR across groups
- equal_opportunity: Equal TPR across groups
- predictive_parity: Equal PPV across groups
Usage:
from pyhealth.metrics import fairness_metrics_fn
fairness_results = fairness_metrics_fn(
y_true=labels,
y_pred=predictions,
sensitive_attributes=demographics, # e.g., race, gender
metrics=["demographic_parity", "equalized_odds"]
)
Example:
# Evaluate fairness across gender
from sklearn.metrics import recall_score
male_mask = (demographics == "male")
female_mask = (demographics == "female")
male_tpr = recall_score(y_true[male_mask], y_pred[male_mask])
female_tpr = recall_score(y_true[female_mask], y_pred[female_mask])
tpr_disparity = abs(male_tpr - female_tpr)
print(f"TPR disparity: {tpr_disparity:.3f}")
Calibration and Uncertainty Quantification
Model Calibration
Purpose: Ensure predicted probabilities match actual frequencies
Calibration Plot:
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt
fraction_of_positives, mean_predicted_value = calibration_curve(
y_true, y_prob, n_bins=10
)
plt.plot(mean_predicted_value, fraction_of_positives, marker='o')
plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect calibration')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.legend()
Expected Calibration Error (ECE):
import numpy as np
def expected_calibration_error(y_true, y_prob, n_bins=10):
    """Compute the Expected Calibration Error (ECE)"""
    bins = np.linspace(0, 1, n_bins + 1)
    # clip so that probabilities exactly equal to 1.0 fall in the last bin
    bin_indices = np.clip(np.digitize(y_prob, bins) - 1, 0, n_bins - 1)
    ece = 0
    for i in range(n_bins):
        mask = bin_indices == i
        if mask.sum() > 0:
            bin_accuracy = y_true[mask].mean()
            bin_confidence = y_prob[mask].mean()
            ece += mask.sum() / len(y_true) * abs(bin_accuracy - bin_confidence)
    return ece
Calibration Methods:
- Platt Scaling: Logistic regression on validation predictions
from sklearn.linear_model import LogisticRegression
calibrator = LogisticRegression()
calibrator.fit(val_predictions.reshape(-1, 1), val_labels)
calibrated_probs = calibrator.predict_proba(test_predictions.reshape(-1, 1))[:, 1]
- Isotonic Regression: Non-parametric calibration
from sklearn.isotonic import IsotonicRegression
calibrator = IsotonicRegression(out_of_bounds='clip')
calibrator.fit(val_predictions, val_labels)
calibrated_probs = calibrator.predict(test_predictions)
- Temperature Scaling: Scale logits before softmax
import torch
import torch.nn.functional as F
from scipy.optimize import minimize
def find_temperature(logits, labels):
    """Find the temperature that minimizes NLL on the validation set"""
    def nll(temp):
        # cross_entropy expects raw logits, so scale them directly (no softmax here)
        scaled_logits = logits / float(temp)
        return F.cross_entropy(scaled_logits, labels).item()
    result = minimize(nll, x0=1.0, method='BFGS')
    return float(result.x[0])
temperature = find_temperature(val_logits, val_labels)
calibrated_logits = test_logits / temperature
Uncertainty Quantification
Conformal Prediction:
Provides prediction sets with a guaranteed coverage level.
Usage:
from pyhealth.metrics import prediction_set_metrics_fn
import numpy as np
# Calibrate on validation set: nonconformity score = 1 - probability of the true class
scores = 1 - val_predictions[np.arange(len(val_labels)), val_labels]
quantile_level = np.quantile(scores, 0.9) # 90% coverage
# Generate prediction sets on test set
prediction_sets = test_predictions > (1 - quantile_level)
# Evaluate
metrics = prediction_set_metrics_fn(
y_true=test_labels,
prediction_sets=prediction_sets,
metrics=["coverage", "average_size"]
)
Monte Carlo Dropout:
Estimates uncertainty by running multiple stochastic forward passes with dropout kept active at inference time.
import torch
def predict_with_uncertainty(model, dataloader, num_samples=20):
    """Predict with uncertainty using MC dropout"""
    model.train()  # keep dropout layers active during inference
    predictions = []
    for _ in range(num_samples):
        batch_preds = []
        for batch in dataloader:
            with torch.no_grad():
                output = model(batch)
            batch_preds.append(output)
        predictions.append(torch.cat(batch_preds))
    predictions = torch.stack(predictions)
    mean_pred = predictions.mean(dim=0)
    std_pred = predictions.std(dim=0)  # per-sample predictive uncertainty
    return mean_pred, std_pred
Ensemble Uncertainty:
import numpy as np
# Train multiple models with different random seeds (train_model is a placeholder)
models = [train_model(seed=i) for i in range(5)]
# Predict with the ensemble
ensemble_preds = []
for model in models:
    pred = model.predict(test_data)
    ensemble_preds.append(pred)
mean_pred = np.mean(ensemble_preds, axis=0)
std_pred = np.std(ensemble_preds, axis=0)  # disagreement across models = uncertainty
Interpretability
Attention Visualization
For Transformer and RETAIN models:
# Get attention weights during inference
outputs = trainer.inference(
test_loader,
additional_outputs=["attention_weights"]
)
attention = outputs["attention_weights"]
# Visualize attention for sample
import matplotlib.pyplot as plt
import seaborn as sns
sample_idx = 0
sample_attention = attention[sample_idx] # [seq_length, seq_length]
sns.heatmap(sample_attention, cmap='viridis')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Attention Weights')
plt.show()
RETAIN Interpretation:
# RETAIN provides visit-level and feature-level attention
visit_attention = outputs["visit_attention"] # Which visits are important
feature_attention = outputs["feature_attention"] # Which features are important
# Find most influential visit
most_important_visit = visit_attention[sample_idx].argmax()
# Find most influential features in that visit
important_features = feature_attention[sample_idx, most_important_visit].argsort()[-10:]
Feature Importance
Permutation Importance:
from sklearn.inspection import permutation_importance
# permutation_importance expects a scikit-learn-style estimator with fit/predict;
# wrap the PyHealth model (or use a fitted sklearn baseline) before calling this.
result = permutation_importance(
    model, X_test, y_test,
    n_repeats=10,
    scoring='roc_auc'
)
# Sort features by importance
indices = result.importances_mean.argsort()[::-1]
for i in indices[:10]:
    print(f"{feature_names[i]}: {result.importances_mean[i]:.3f}")
SHAP Values:
import shap
# Create explainer
explainer = shap.DeepExplainer(model, train_data)
# Compute SHAP values
shap_values = explainer.shap_values(test_data)
# Visualize
shap.summary_plot(shap_values, test_data, feature_names=feature_names)
ChEFER (Clinical Health Event Feature Extraction and Ranking)
PyHealth's Interpretability Tool:
from pyhealth.explain import ChEFER
explainer = ChEFER(model=model, dataset=test_dataset)
# Get feature importance for prediction
importance_scores = explainer.explain(
patient_id="patient_123",
visit_id="visit_456"
)
# Visualize top features
explainer.plot_importance(importance_scores, top_k=20)
Complete Training Pipeline Example
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import Transformer
from pyhealth.trainer import Trainer
from pyhealth.metrics import binary_metrics_fn
import torch
# 1. Load and prepare data
dataset = MIMIC4Dataset(root="/path/to/mimic4")
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)
# 2. Split data
train_data, val_data, test_data = split_by_patient(
sample_dataset, ratios=[0.7, 0.1, 0.2], seed=42
)
# 3. Create data loaders
train_loader = get_dataloader(train_data, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_data, batch_size=64, shuffle=False)
test_loader = get_dataloader(test_data, batch_size=64, shuffle=False)
# 4. Initialize model
model = Transformer(
dataset=sample_dataset,
feature_keys=["diagnoses", "procedures", "medications"],
mode="binary",
embedding_dim=128,
num_heads=8,
num_layers=3,
dropout=0.3
)
# 5. Train model
trainer = Trainer(model=model, device="cuda")
trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=50,
optimizer=torch.optim.Adam,
learning_rate=1e-3,
weight_decay=1e-5,
monitor="pr_auc_score",
monitor_criterion="max",
save_path="./checkpoints/mortality_model"
)
# 6. Evaluate on test set
test_results = trainer.evaluate(
test_loader,
metrics=["accuracy", "precision", "recall", "f1_score",
"roc_auc_score", "pr_auc_score"]
)
print("Test Results:")
for metric, value in test_results.items():
    print(f"{metric}: {value:.4f}")
# 7. Get predictions for analysis
predictions = trainer.inference(test_loader, return_patient_ids=True)
y_pred, y_true, patient_ids = predictions
# 8. Calibration analysis
from sklearn.calibration import calibration_curve
fraction_pos, mean_pred = calibration_curve(y_true, y_pred, n_bins=10)
ece = expected_calibration_error(y_true, y_pred)
print(f"Expected Calibration Error: {ece:.4f}")
# 9. Save final model
trainer.save("./models/mortality_transformer_final.pt")
Best Practices
Training
- Monitor multiple metrics: Track both loss and task-specific metrics
- Use validation set: Prevent overfitting with early stopping
- Gradient clipping: Stabilize training (max_grad_norm=5.0)
- Learning rate scheduling: Reduce LR on plateau (see the sketch after this list)
- Checkpoint best model: Save based on validation performance
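A minimal sketch of reduce-on-plateau scheduling using plain PyTorch; whether a scheduler can be passed to Trainer.train() depends on the PyHealth version, so this only shows the underlying mechanism, with train_one_epoch, evaluate_on_validation, and num_epochs as placeholders.
import torch
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3
)
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)           # placeholder training loop
    val_metric = evaluate_on_validation(model)  # e.g., validation PR-AUC
    scheduler.step(val_metric)                  # reduce LR when the metric stops improving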
Evaluation
- Use task-appropriate metrics: AUROC/AUPRC for binary, macro-F1 for imbalanced multi-class
- Report confidence intervals: Bootstrap or cross-validation
- Stratified evaluation: Report metrics by subgroups (see the sketch after this list)
- Clinical metrics: Include clinically relevant thresholds
- Fairness assessment: Evaluate across demographic groups
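A minimal sketch of stratified (subgroup) evaluation, computing binary metrics per demographic group; the groups array is an assumed, aligned array of group labels, and the metric names follow those used elsewhere in this document.
import numpy as np
from pyhealth.metrics import binary_metrics_fn
for group in np.unique(groups):
    mask = groups == group
    group_metrics = binary_metrics_fn(
        y_true[mask], y_prob[mask],
        metrics=["roc_auc_score", "pr_auc_score"]
    )
    print(f"{group}: {group_metrics}")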
Deployment
- Calibrate predictions: Ensure probabilities are reliable
- Quantify uncertainty: Provide confidence estimates
- Monitor performance: Track metrics in production
- Handle distribution shift: Detect when data changes (see the sketch after this list)
- Interpretability: Provide explanations for predictions
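One simple way to flag covariate shift is a two-sample Kolmogorov-Smirnov test per feature between training data and incoming production data; a minimal sketch in which the feature matrices and the 0.01 significance threshold are illustrative.
import numpy as np
from scipy.stats import ks_2samp
def detect_feature_shift(train_features, prod_features, alpha=0.01):
    """Return indices of features whose distribution differs significantly"""
    shifted = []
    for j in range(train_features.shape[1]):
        stat, p_value = ks_2samp(train_features[:, j], prod_features[:, j])
        if p_value < alpha:
            shifted.append(j)
    return shifted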