references/tasks.md

PyHealth Clinical Prediction Tasks

Overview

PyHealth provides 20+ predefined clinical prediction tasks for common healthcare AI applications. Each task function transforms raw patient data into structured input-output pairs for model training.

Task Function Structure

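Class-based tasks inherit from BaseTask and provide the following. The *_fn task functions named on this page are plain functions that take one patient and return a list of sample dictionaries.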

  • input_schema: Defines input features (diagnoses, medications, labs, etc.)
  • output_schema: Defines prediction targets (labels, values)
  • pre_filter(): Optional patient/visit filtering logic
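The three parts listed above can be pictured with a dependency-free stand-in. ToyInHospitalMortality, the dict-based patients, and the method signatures are assumptions made for illustration. PyHealth's real BaseTask works on the library's own patient objects, and its exact signatures may differ.

```python
from typing import Any, Dict, List


class ToyInHospitalMortality:
    """Hypothetical, dependency-free mimic of a class-based task (not PyHealth code)."""

    task_name = "toy_in_hospital_mortality"
    # Feature name -> feature type (type strings are illustrative)
    input_schema = {"conditions": "sequence", "drugs": "sequence"}
    # Target name -> label type
    output_schema = {"mortality": "binary"}

    def pre_filter(self, patients: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Optional filtering: drop patients with no recorded visits
        return [p for p in patients if p["visits"]]

    def __call__(self, patient: Dict[str, Any]) -> List[Dict[str, Any]]:
        # One sample per visit; the keys mirror the two schemas above
        return [
            {
                "patient_id": patient["patient_id"],
                "visit_id": v["visit_id"],
                "conditions": v["conditions"],
                "drugs": v["drugs"],
                "mortality": v["died"],
            }
            for v in patient["visits"]
        ]


task = ToyInHospitalMortality()
patients = [
    {"patient_id": "p1", "visits": [{"visit_id": "v1", "conditions": ["I10"],
                                     "drugs": ["lisinopril"], "died": 0}]},
    {"patient_id": "p2", "visits": []},
]
samples = [s for p in task.pre_filter(patients) for s in task(p)]
```

Here p2 is removed by pre_filter, and p1 yields one sample whose keys are exactly the identifiers plus the schema keys.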

Usage Pattern:

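```python
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import mortality_prediction_mimic4_fn

# Load the tables the mortality task reads: diagnoses, procedures, prescriptions
dataset = MIMIC4Dataset(
    root="/path/to/data",
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
)
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)
```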

Electronic Health Record (EHR) Tasks

Mortality Prediction

Purpose: Predict patient death risk at next visit or within specified timeframe

MIMIC-III Mortality (mortality_prediction_mimic3_fn)
  • Predicts death at the next hospital visit
  • Binary classification task
  • Input: Historical diagnoses, procedures, medications
  • Output: Binary label (deceased/alive)

MIMIC-IV Mortality (mortality_prediction_mimic4_fn)
  • Updated version for the MIMIC-IV dataset
  • Enhanced feature set
  • Improved label quality

eICU Mortality (mortality_prediction_eicu_fn)
  • Multi-center ICU mortality prediction
  • Samples come from many hospitals, so hospital-level variation is present in the data

OMOP Mortality (mortality_prediction_omop_fn)
  • Standardized mortality prediction
  • Works with the OMOP Common Data Model

In-Hospital Mortality (inhospital_mortality_prediction_mimic4_fn)
  • Predicts death during the current hospitalization
  • Real-time risk assessment
  • Earlier prediction window than next-visit mortality

StageNet Mortality (mortality_prediction_mimic4_fn_stagenet)
  • Specialized for the StageNet model architecture
  • Temporal stage-aware prediction
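The next-visit formulation pairs the features of one visit with the outcome recorded at the following visit. This is a plain-Python sketch, not PyHealth's source. The visit dict layout, the "discharge_status" field (1 = died), and the rule that skips visits missing any feature type are assumptions, and the codes and drug names are toy data.

```python
from typing import Dict, List


def next_visit_mortality(patient_id: str, visits: List[Dict]) -> List[Dict]:
    """Toy next-visit mortality labeling (illustrative, not PyHealth source).

    Visits are assumed to be in chronological order, each holding "visit_id",
    "conditions", "procedures", "drugs", and "discharge_status" (1 = died).
    """
    samples = []
    # Pair each visit with the one after it; the last visit has no successor
    for cur, nxt in zip(visits, visits[1:]):
        # Skip pairs whose outcome is missing or malformed
        if nxt["discharge_status"] not in (0, 1):
            continue
        # Skip visits missing any of the three input feature types
        if not (cur["conditions"] and cur["procedures"] and cur["drugs"]):
            continue
        samples.append({
            "patient_id": patient_id,
            "visit_id": cur["visit_id"],
            "conditions": cur["conditions"],
            "procedures": cur["procedures"],
            "drugs": cur["drugs"],
            "label": int(nxt["discharge_status"]),
        })
    return samples


visits = [
    {"visit_id": "v1", "conditions": ["4280"], "procedures": ["3995"],
     "drugs": ["furosemide"], "discharge_status": 0},
    {"visit_id": "v2", "conditions": ["5849"], "procedures": ["3995"],
     "drugs": ["heparin"], "discharge_status": 0},
    {"visit_id": "v3", "conditions": ["0389"], "procedures": ["9671"],
     "drugs": ["vancomycin"], "discharge_status": 1},
]
samples = next_visit_mortality("p1", visits)
```

Three visits produce two samples. The last visit supplies only an outcome, which is one reason next-visit tasks need patients with at least two admissions.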

Hospital Readmission Prediction

Purpose: Identify patients at risk of hospital readmission within specified timeframe (typically 30 days)

MIMIC-III Readmission (readmission_prediction_mimic3_fn)
  • Readmission within a configurable number of days after discharge
  • Binary classification
  • Input: Diagnosis history, medications, demographics
  • Output: Binary label (readmitted/not readmitted)

MIMIC-IV Readmission (readmission_prediction_mimic4_fn)
  • Enhanced readmission features
  • Improved temporal modeling

eICU Readmission (readmission_prediction_eicu_fn)
  • ICU-specific readmission risk
  • Multi-site data

OMOP Readmission (readmission_prediction_omop_fn)
  • Standardized readmission prediction
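A readmission label depends on the gap between one discharge and the next admission. This sketch assumes each visit carries "admit" and "discharge" datetimes. The helper name and the 30-day default are illustrative choices, not PyHealth's API.

```python
from datetime import datetime, timedelta
from typing import Dict, List


def readmission_labels(visits: List[Dict], window_days: int = 30) -> List[Dict]:
    """Label each visit 1 if the next admission starts within window_days of discharge.

    Visits are assumed to carry "visit_id", "admit", and "discharge" datetimes.
    """
    # Sort chronologically so "next admission" is well defined
    ordered = sorted(visits, key=lambda v: v["admit"])
    window = timedelta(days=window_days)
    samples = []
    # The last visit is dropped: with no later admission on record, "not
    # readmitted" cannot be told apart from "lost to follow-up"
    for cur, nxt in zip(ordered, ordered[1:]):
        gap = nxt["admit"] - cur["discharge"]
        samples.append({"visit_id": cur["visit_id"], "label": int(gap <= window)})
    return samples


visits = [
    {"visit_id": "a", "admit": datetime(2020, 1, 1), "discharge": datetime(2020, 1, 5)},
    {"visit_id": "b", "admit": datetime(2020, 1, 20), "discharge": datetime(2020, 1, 25)},
    {"visit_id": "c", "admit": datetime(2020, 6, 1), "discharge": datetime(2020, 6, 3)},
]
thirty_day = readmission_labels(visits)
ten_day = readmission_labels(visits, window_days=10)
```

The 15-day gap between visits a and b counts as a readmission under a 30-day window but not under a 10-day one. The window definition therefore shifts label prevalence.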

Length of Stay Prediction

Purpose: Estimate hospital stay duration for resource planning and patient management

MIMIC-III Length of Stay (length_of_stay_prediction_mimic3_fn)
  • Multi-class classification: the stay duration in days is binned into categories rather than predicted as a continuous value
  • Input: Admission diagnoses, vitals, demographics
  • Output: Length-of-stay category (binned days)

MIMIC-IV Length of Stay (length_of_stay_prediction_mimic4_fn)
  • Enhanced features for LOS prediction
  • Better temporal granularity

eICU Length of Stay (length_of_stay_prediction_eicu_fn)
  • ICU stay duration prediction
  • Multi-hospital data

OMOP Length of Stay (length_of_stay_prediction_omop_fn)
  • Standardized LOS prediction
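When length of stay is treated as classification rather than regression, raw durations must be bucketed. The 10 buckets below (under one day, one class per day for days 1 to 7, 8 to 14 days, over two weeks) are an assumption chosen for illustration. Check the task you use for its actual boundaries.

```python
from datetime import datetime


def los_category(admit: datetime, discharge: datetime) -> int:
    """Map a stay to one of 10 classes (bucket boundaries are an assumption)."""
    days = (discharge - admit).days  # whole days; the fractional part is truncated
    if days < 1:
        return 0      # under one day
    if days <= 7:
        return days   # one class per day for days 1..7
    if days <= 14:
        return 8      # 8 to 14 days
    return 9          # more than two weeks
```

A regression variant would instead use fractional days, e.g. (discharge - admit).total_seconds() / 86400, and a regression-mode model.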

Drug Recommendation

Purpose: Suggest appropriate medications based on patient history and current conditions

MIMIC-III Drug Recommendation (drug_recommendation_mimic3_fn)
  • Multi-label classification
  • Input: Diagnoses, previous medications, demographics
  • Output: Set of recommended drug codes
  • Drug-drug interactions are not encoded by the task itself; safety-aware models account for them (see Key Considerations)

MIMIC-IV Drug Recommendation (drug_recommendation_mimic4_fn)
  • Same multi-label formulation on updated MIMIC-IV medication data

eICU Drug Recommendation (drug_recommendation_eicu_fn)
  • Critical care medication recommendations

OMOP Drug Recommendation (drug_recommendation_omop_fn)
  • Standardized drug recommendation

Key Considerations:
  • Handles polypharmacy scenarios
  • Multi-label prediction (multiple drugs per patient)
  • Can integrate with SafeDrug/GAMENet models for safety-aware recommendations
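A multi-label target is usually a 0/1 vector over a fixed drug vocabulary. Safety is often summarized as the share of recommended drug pairs that appear in a known interaction list, in the style of the DDI rate reported by safety-aware work such as SafeDrug and GAMENet. The vocabulary and interaction list below are toy data, and both helpers are hypothetical.

```python
from itertools import combinations
from typing import FrozenSet, List, Set


def multi_hot(drugs: Set[str], vocab: List[str]) -> List[int]:
    """Encode a set of drug codes as a 0/1 vector over a fixed vocabulary."""
    return [1 if code in drugs else 0 for code in vocab]


def ddi_rate(drugs: Set[str], interacting: Set[FrozenSet[str]]) -> float:
    """Fraction of drug pairs in one recommendation that appear in a known DDI list."""
    pairs = [frozenset(p) for p in combinations(sorted(drugs), 2)]
    if not pairs:
        return 0.0
    return sum(p in interacting for p in pairs) / len(pairs)


vocab = ["aspirin", "heparin", "metformin", "warfarin"]
recommended = {"aspirin", "warfarin", "metformin"}
known_ddi = {frozenset({"aspirin", "warfarin"})}  # toy interaction list
target = multi_hot(recommended, vocab)
rate = ddi_rate(recommended, known_ddi)
```

Three drugs form three pairs, and one of them is flagged, so the rate is one third.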

Specialized Clinical Tasks

Medical Coding

MIMIC-III ICD-9 Coding (icd9_coding_mimic3_fn)
  • Assigns ICD-9 diagnosis/procedure codes to clinical notes
  • Multi-label text classification
  • Input: Clinical text/documentation
  • Output: Set of ICD-9 codes
  • Supports both diagnosis and procedure coding
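Because each note maps to a set of codes, coding output is commonly scored with micro-averaged precision, recall, and F1 pooled over all notes. This is a minimal sketch on toy code sets, and micro_prf is a hypothetical helper name.

```python
from typing import List, Set, Tuple


def micro_prf(gold: List[Set[str]], pred: List[Set[str]]) -> Tuple[float, float, float]:
    """Micro-averaged precision, recall, and F1 over per-document code sets."""
    tp = sum(len(g & p) for g, p in zip(gold, pred))   # codes both predicted and correct
    n_pred = sum(len(p) for p in pred)
    n_gold = sum(len(g) for g in gold)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


gold = [{"4280", "5849"}, {"25000"}]
pred = [{"4280"}, {"25000", "4019"}]
p, r, f1 = micro_prf(gold, pred)
```

Two of three predicted codes are correct and two of three gold codes are found, so precision, recall, and F1 all equal 2/3.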

Patient Linkage

MIMIC-III Patient Linking (patient_linkage_mimic3_fn)
  • Record matching and deduplication
  • Binary classification (same patient or not)
  • Input: Demographic and clinical features from two records
  • Output: Binary match label (models score it as a match probability)
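Linkage training data is a set of record pairs labeled by whether both records belong to the same patient. In this sketch the ground-truth patient_id is used only to build labels, never as a feature. The record layout is an assumption. Comparing all pairs grows quadratically, so real pipelines usually restrict comparisons first (blocking).

```python
from itertools import combinations
from typing import Dict, List


def make_pairs(records: List[Dict]) -> List[Dict]:
    """Build all record pairs, labeled 1 when both belong to the same patient.

    Each record is assumed to carry "record_id" and a ground-truth "patient_id".
    """
    pairs = []
    for a, b in combinations(records, 2):
        pairs.append({
            "record_a": a["record_id"],
            "record_b": b["record_id"],
            "label": int(a["patient_id"] == b["patient_id"]),
        })
    return pairs


records = [
    {"record_id": "r1", "patient_id": "p1"},
    {"record_id": "r2", "patient_id": "p1"},
    {"record_id": "r3", "patient_id": "p2"},
]
pairs = make_pairs(records)
```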

Physiological Signal Tasks

Sleep Staging

Purpose: Classify sleep stages from EEG/physiological signals for sleep disorder diagnosis

ISRUC Sleep Staging (sleep_staging_isruc_fn)
  • Multi-class classification (Wake, N1, N2, N3, REM)
  • Input: Multi-channel EEG signals
  • Output: Sleep stage per epoch (typically 30 seconds)

SleepEDF Sleep Staging (sleep_staging_sleepedf_fn)
  • Standard sleep staging task
  • PSG signal processing

SHHS Sleep Staging (sleep_staging_shhs_fn)
  • Large-scale sleep study data
  • Population-level sleep analysis

Standardized Labels:
  • Wake (W)
  • Non-REM Stage 1 (N1)
  • Non-REM Stage 2 (N2)
  • Non-REM Stage 3 (N3/Deep Sleep)
  • REM (Rapid Eye Movement)
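Per-epoch staging means cutting the recording into fixed-length windows and pairing each window with the scorer's stage. The sketch uses a single channel for brevity; a multi-channel version would slice every channel the same way. The integer ordering of stages and the helper name are assumptions.

```python
from typing import Dict, List, Sequence, Tuple

# Integer ids for the five standardized stages (ordering is an assumption)
STAGE_TO_ID: Dict[str, int] = {"W": 0, "N1": 1, "N2": 2, "N3": 3, "REM": 4}


def to_epochs(signal: Sequence[float], stages: Sequence[str],
              fs: int, epoch_seconds: int = 30) -> List[Tuple[List[float], int]]:
    """Cut a single-channel signal into fixed-length epochs paired with stage ids.

    stages[k] is the label for epoch k; trailing samples that do not fill a
    whole epoch, and labels without a matching epoch, are dropped.
    """
    step = fs * epoch_seconds                      # samples per epoch
    n_epochs = min(len(signal) // step, len(stages))
    return [
        (list(signal[k * step:(k + 1) * step]), STAGE_TO_ID[stages[k]])
        for k in range(n_epochs)
    ]


fs = 2  # toy sampling rate in Hz; real EEG is sampled much faster
signal = [float(x) for x in range(130)]
epochs = to_epochs(signal, ["W", "N2", "REM"], fs=fs)
```

With 60 samples per epoch, 130 samples give two whole epochs, so the third label is left unused.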

EEG Analysis

Abnormality Detection (abnormality_detection_tuab_fn)
  • Binary classification (normal/abnormal EEG)
  • Clinical screening application
  • Input: Multi-channel EEG recordings
  • Output: Binary label

Event Detection (event_detection_tuev_fn)
  • Identify specific EEG events (spikes, seizures)
  • Multi-class classification
  • Input: EEG time series
  • Output: Event type and timing

Seizure Detection (seizure_detection_tusz_fn)
  • Specialized epileptic seizure detection
  • Critical for epilepsy monitoring
  • Input: Continuous EEG
  • Output: Seizure/non-seizure classification
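Continuous-EEG seizure detection is often framed as labeling sliding windows by whether they overlap an annotated seizure interval. The window length, stride, and (onset, offset) interval format in this sketch are hypothetical choices, not values taken from the TUSZ task.

```python
from typing import List, Tuple


def window_labels(duration_s: float, seizures: List[Tuple[float, float]],
                  window_s: float = 10.0, stride_s: float = 5.0) -> List[Tuple[float, int]]:
    """Label sliding windows 1 if they overlap any annotated seizure interval.

    Returns (window_start, label) pairs; intervals are (onset, offset) in seconds.
    """
    out = []
    start = 0.0
    while start + window_s <= duration_s:
        end = start + window_s
        # Overlap test: window [start, end) intersects seizure [on, off)
        hit = any(on < end and off > start for on, off in seizures)
        out.append((start, int(hit)))
        start += stride_s
    return out


labels = window_labels(30.0, [(12.0, 18.0)])
```

Overlapping windows let a short seizure appear in several consecutive windows, at the cost of correlated samples.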

Medical Imaging Tasks

COVID-19 Chest X-ray Classification

COVID-19 CXR (covid_classification_cxr_fn)
  • Multi-class image classification
  • Classes: COVID-19, bacterial pneumonia, viral pneumonia, normal
  • Input: Chest X-ray images
  • Output: Disease classification

Text-Based Tasks

Medical Transcription Classification

Medical Specialty Classification (medical_transcription_classification_fn)
  • Classify clinical notes by medical specialty
  • Multi-class text classification
  • Input: Clinical transcription text
  • Output: Medical specialty (Cardiology, Neurology, etc.)

Custom Task Creation

Creating Custom Tasks

Define a custom task as a function that takes a patient and returns a list of sample dictionaries:

```python
from pyhealth.data import Patient


def custom_task_fn(patient: Patient):
    """Custom prediction task: predict an outcome at each visit from earlier visits."""
    samples = []

    # len(patient) is the number of visits; patient[i] is the i-th visit
    for i in range(len(patient)):
        # Skip visits without at least two prior visits of history
        if i < 2:
            continue
        visit = patient[i]

        # Collect features from all previous visits
        conditions, procedures, drugs = [], [], []
        for j in range(i):
            past_visit = patient[j]
            conditions.extend(past_visit.get_code_list(table="diagnoses_icd"))
            procedures.extend(past_visit.get_code_list(table="procedures_icd"))
            drugs.extend(past_visit.get_code_list(table="prescriptions"))

        # Skip samples with no usable history
        if not conditions or not procedures or not drugs:
            continue

        # Define the prediction target at the current visit.
        # Example target (replace with your own): in-hospital death at this visit
        if visit.discharge_status not in [0, 1]:
            continue
        label = int(visit.discharge_status)

        # Flat sample dict: identifiers, one key per feature, and the label
        samples.append({
            "patient_id": patient.patient_id,
            "visit_id": visit.visit_id,
            "conditions": conditions,
            "procedures": procedures,
            "drugs": drugs,
            "label": label,
        })

    return samples


# Apply custom task (the dataset must load the three tables read above)
sample_dataset = dataset.set_task(custom_task_fn)
```

Task Function Components

  1. Input Schema Definition
      • Specify which features to extract
      • Define feature types (codes, sequences, values)
      • Set temporal windows

  2. Output Schema Definition
      • Define prediction targets
      • Set label types (binary, multi-class, multi-label, regression)
      • Specify evaluation metrics

  3. Filtering Logic
      • Exclude patients/visits with insufficient data
      • Apply inclusion/exclusion criteria
      • Handle missing data

  4. Sample Generation
      • Create input-output pairs
      • Maintain patient/visit identifiers
      • Preserve temporal ordering
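The filtering and ordering steps above can be sketched in plain Python. The sketch also records why each patient was excluded, which makes exclusion criteria easy to report. The patient dict layout and the thresholds (min_visits, min_age) are hypothetical, not PyHealth defaults.

```python
from collections import Counter
from typing import Dict, List, Tuple


def apply_criteria(patients: List[Dict], min_visits: int = 2,
                   min_age: int = 18) -> Tuple[List[Dict], Counter]:
    """Keep eligible patients and count why the others were excluded.

    min_visits and min_age are illustrative thresholds, not library defaults.
    """
    kept, excluded = [], Counter()
    for p in patients:
        if p.get("age") is None:
            excluded["missing_age"] += 1
        elif p["age"] < min_age:
            excluded["under_min_age"] += 1
        elif len(p["visits"]) < min_visits:
            excluded["too_few_visits"] += 1
        else:
            # Preserve temporal ordering for downstream sample generation
            kept.append({**p, "visits": sorted(p["visits"], key=lambda v: v["time"])})
    return kept, excluded


patients = [
    {"patient_id": "p1", "age": 65, "visits": [{"time": 2}, {"time": 1}]},
    {"patient_id": "p2", "age": 12, "visits": [{"time": 1}, {"time": 2}, {"time": 3}]},
    {"patient_id": "p3", "age": None, "visits": [{"time": 1}]},
    {"patient_id": "p4", "age": 40, "visits": [{"time": 1}]},
]
kept, excluded = apply_criteria(patients)
```

Each patient is counted under the first criterion it fails, so the reasons sum to the number of excluded patients.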

Task Selection Guidelines

Clinical Prediction Tasks

Use when: Working with structured EHR data (diagnoses, medications, procedures)

Datasets: MIMIC-III, MIMIC-IV, eICU, OMOP

Common tasks:
  • Mortality prediction for risk stratification
  • Readmission prediction for care transition planning
  • Length of stay for resource allocation
  • Drug recommendation for clinical decision support

Signal Processing Tasks

Use when: Working with physiological time-series data

Datasets: SleepEDF, SHHS, ISRUC, TUEV, TUAB, TUSZ

Common tasks:
  • Sleep staging for sleep disorder diagnosis
  • EEG abnormality detection for screening
  • Seizure detection for epilepsy monitoring

Imaging Tasks

Use when: Working with medical images

Datasets: COVID-19 CXR

Common tasks:
  • Disease classification from radiographs
  • Abnormality detection

Text Tasks

Use when: Working with clinical notes and documentation

Datasets: Medical Transcriptions, MIMIC-III (with notes)

Common tasks:
  • Medical coding from clinical text
  • Specialty classification
  • Clinical information extraction

Task Output Structure

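Each task function returns a list of sample dictionaries, and set_task collects them into a sample dataset. Samples are flat: identifiers, one key per input feature, and a label key:

```python
sample = {
    "patient_id": "unique_patient_id",
    "visit_id": "unique_visit_id",  # if applicable
    # Input features, one key per feature (names depend on the task)
    "conditions": ["code_1", "code_2"],
    "procedures": ["code_3"],
    "drugs": ["code_4"],
    # Prediction target: 0/1, class index, list of codes, or a value
    "label": 1,
}
```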
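A quick check before training is to confirm that every sample carries the identifier, feature, and label keys a model will request. validate_samples is a hypothetical helper, not part of PyHealth. The default label_key of "label" is an assumption.

```python
from typing import Dict, Iterable, List


def validate_samples(samples: List[Dict], feature_keys: Iterable[str],
                     label_key: str = "label") -> List[str]:
    """Return human-readable problems found in a list of sample dicts."""
    required = {"patient_id", *feature_keys, label_key}
    problems = []
    for i, s in enumerate(samples):
        missing = required - set(s)   # keys the model needs but the sample lacks
        if missing:
            problems.append(f"sample {i}: missing {sorted(missing)}")
    return problems


samples = [
    {"patient_id": "p1", "conditions": ["I10"], "label": 0},
    {"patient_id": "p2", "label": 1},
]
problems = validate_samples(samples, feature_keys=["conditions"])
```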

Integration with Models

Tasks define the input/output contract for models: the feature keys and label key passed to a model must name keys that exist in the task's samples.

```python
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import mortality_prediction_mimic4_fn
from pyhealth.models import Transformer

# 1. Create task-specific dataset
dataset = MIMIC4Dataset(
    root="/path/to/data",
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
)
sample_dataset = dataset.set_task(mortality_prediction_mimic4_fn)

# 2. Point the model at the keys the task produces
model = Transformer(
    dataset=sample_dataset,
    feature_keys=["conditions", "procedures", "drugs"],
    label_key="label",
    mode="binary",  # matches the task's binary mortality label
)
```
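The mode argument has to agree with the task's label type. A hypothetical infer_mode helper can guess it from the labels. The mode strings here follow PyHealth's binary/multiclass/multilabel/regression naming, but the heuristic itself is an assumption, and passing mode explicitly is safer.

```python
from typing import Any, List


def infer_mode(labels: List[Any]) -> str:
    """Guess a model mode from task labels (hypothetical heuristic)."""
    if not labels:
        raise ValueError("no labels to inspect")
    if all(isinstance(y, (list, set, tuple)) for y in labels):
        return "multilabel"   # each sample carries a collection of targets
    if any(isinstance(y, float) for y in labels):
        return "regression"   # continuous targets
    if set(labels) <= {0, 1}:
        return "binary"       # two classes encoded as 0/1
    return "multiclass"       # any other discrete label set
```

For example, next-visit mortality labels map to "binary", binned length-of-stay classes to "multiclass", and drug sets to "multilabel".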

Best Practices

  1. Match task to clinical question: Choose predefined tasks when available for standardized benchmarking

  2. Consider temporal windows: Ensure sufficient history for meaningful predictions

  3. Handle class imbalance: Many clinical outcomes are rare (mortality, readmission)

  4. Validate clinical relevance: Ensure prediction windows align with clinical decision-making timelines

  5. Use appropriate metrics: Different tasks require different evaluation metrics (AUROC for binary, macro-F1 for multi-class)

  6. Document exclusion criteria: Track which patients/visits are filtered and why

  7. Preserve patient privacy: Always use de-identified data and follow HIPAA/GDPR guidelines
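Class imbalance and metric choice interact. With rare outcomes, an all-negative classifier scores high accuracy, while AUROC stays at 0.5 for uninformative scores. The sketch computes AUROC as the probability that a random positive outscores a random negative, with ties counted as half. It uses an O(P×N) pairwise loop for clarity; on real data, prefer a rank-based or library implementation.

```python
from typing import Sequence


def auroc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """AUROC as P(score of random positive > score of random negative), ties = 0.5."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


y = [0, 0, 0, 0, 1, 1]                       # one third positive
scores = [0.1, 0.4, 0.35, 0.8, 0.7, 0.9]
informative = auroc(y, scores)
constant = auroc(y, [0.0] * len(y))          # uninformative scores
```

Seven of the eight positive/negative pairs are ranked correctly, giving 0.875. Constant scores tie every pair and give 0.5.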
