scripts/filter_molecules.py

#!/usr/bin/env python3
"""
Batch molecular filtering using medchem library.

This script provides a production-ready workflow for filtering compound libraries
using medchem rules, structural alerts, and custom constraints.

Usage:
    python filter_molecules.py input.csv --rules rule_of_five,rule_of_cns --nibr --output filtered.csv
    python filter_molecules.py input.sdf --rules rule_of_drug --lilly --complexity 400 --output results.csv
    python filter_molecules.py smiles.txt --nibr --pains --n-jobs -1 --output clean.csv
"""

import argparse
import sys
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import json

try:
    import pandas as pd
    import datamol as dm
    import medchem as mc
    from rdkit import Chem
    from tqdm import tqdm
except ImportError as e:
    print(f"Error: Missing required package: {e}")
    print("Install dependencies: pip install medchem datamol pandas tqdm")
    sys.exit(1)


def load_molecules(input_file: Path, smiles_column: str = "smiles") -> Tuple[pd.DataFrame, List[Chem.Mol]]:
    """
    Load molecules from various file formats.

    Supports:
    - CSV/TSV with SMILES column
    - SDF files
    - Plain text files with one SMILES per line

    Entries that cannot be turned into a molecule (unparseable SMILES,
    missing or blank SMILES cells, unreadable SDF records) are dropped with a
    warning, so row ``i`` of the returned DataFrame describes molecule ``i``.

    Args:
        input_file: Input path; the reader is chosen from its (case-insensitive)
            suffix.
        smiles_column: Name of the SMILES column (CSV/TSV only).

    Returns:
        Tuple of (DataFrame with metadata, list of RDKit molecules).

    Note:
        Exits the process with status 1 on an unsupported suffix or when
        ``smiles_column`` is missing from a CSV/TSV file.
    """
    suffix = input_file.suffix.lower()

    if suffix == ".sdf":
        print(f"Loading SDF file: {input_file}")
        supplier = Chem.SDMolSupplier(str(input_file))
        records = list(supplier)
        mols = [mol for mol in records if mol is not None]
        n_unreadable = len(records) - len(mols)
        if n_unreadable:
            # Report skipped records the same way invalid SMILES are reported below.
            print(f"Warning: {n_unreadable} unreadable SDF records skipped")

        # Create DataFrame from SDF properties
        data = []
        for mol in mols:
            props = mol.GetPropsAsDict()
            props["smiles"] = Chem.MolToSmiles(mol)
            data.append(props)
        df = pd.DataFrame(data)

    elif suffix in (".csv", ".tsv"):
        print(f"Loading CSV/TSV file: {input_file}")
        sep = "\t" if suffix == ".tsv" else ","
        df = pd.read_csv(input_file, sep=sep)

        if smiles_column not in df.columns:
            print(f"Error: Column '{smiles_column}' not found in file")
            print(f"Available columns: {', '.join(map(str, df.columns))}")
            sys.exit(1)

        print("Converting SMILES to molecules...")
        # Missing cells arrive as NaN (a float), and a blank string may parse
        # to an empty molecule. Only hand non-blank strings to dm.to_mol and
        # mark everything else invalid (None) so the row is dropped below.
        mols = [
            dm.to_mol(smi) if isinstance(smi, str) and smi.strip() else None
            for smi in tqdm(df[smiles_column], desc="Parsing")
        ]

    elif suffix == ".txt":
        print(f"Loading text file: {input_file}")
        with open(input_file, encoding="utf-8") as f:
            smiles_list = [line.strip() for line in f if line.strip()]

        df = pd.DataFrame({"smiles": smiles_list})
        print("Converting SMILES to molecules...")
        mols = [dm.to_mol(smi) for smi in tqdm(smiles_list, desc="Parsing")]

    else:
        print(f"Error: Unsupported file format: {suffix}")
        print("Supported formats: .csv, .tsv, .sdf, .txt")
        sys.exit(1)

    # Filter out invalid molecules, keeping DataFrame rows aligned with `mols`.
    valid_indices = [i for i, mol in enumerate(mols) if mol is not None]
    if len(valid_indices) < len(mols):
        n_invalid = len(mols) - len(valid_indices)
        print(f"Warning: {n_invalid} invalid molecules removed")
        df = df.iloc[valid_indices].reset_index(drop=True)
        mols = [mols[i] for i in valid_indices]

    print(f"Loaded {len(mols)} valid molecules")
    return df, mols


def apply_rule_filters(mols: List[Chem.Mol], rules: List[str], n_jobs: int) -> pd.DataFrame:
    """Apply medicinal chemistry rule filters.

    Args:
        mols: Molecules to evaluate.
        rules: Rule names passed to ``mc.rules.RuleFilters(rule_list=...)``.
        n_jobs: Parallel job count forwarded to the filter call.

    Returns:
        One row per molecule with a fresh 0..n-1 index, which keeps it
        positionally aligned with ``mols``. It holds the columns produced by
        the rule filter plus a boolean ``passes_all_rules`` column.
    """
    print(f"\nApplying rule filters: {', '.join(rules)}")

    rfilter = mc.rules.RuleFilters(rule_list=rules)
    results = rfilter(mols=mols, n_jobs=n_jobs, progress=True)

    # Positional index so pd.concat(axis=1) in main() lines rows up with `mols`.
    df_results = pd.DataFrame(results).reset_index(drop=True)

    # NOTE(review): if the filter echoes the input molecules back in a "mol"
    # column, drop it. RDKit objects are meaningless in the CSV output and are
    # always truthy, so they would distort the AND below.
    df_results = df_results.drop(columns=["mol"], errors="ignore")

    # AND over the requested rule columns when present. Otherwise fall back
    # to the boolean columns, and finally to every column.
    rule_cols = [rule for rule in rules if rule in df_results.columns]
    if not rule_cols:
        rule_cols = [
            col for col in df_results.columns
            if pd.api.types.is_bool_dtype(df_results[col])
        ] or list(df_results.columns)
    df_results["passes_all_rules"] = df_results[rule_cols].all(axis=1)

    return df_results


def apply_structural_alerts(mols: List[Chem.Mol], alert_type: str, n_jobs: int) -> pd.DataFrame:
    """Screen molecules against one family of structural alerts.

    Args:
        mols: Molecules to screen.
        alert_type: Alert family to run: ``"common"``, ``"nibr"``,
            ``"lilly"`` or ``"pains"``.
        n_jobs: Worker count handed to the medchem filter objects. The PAINS
            branch loops serially and does not use it.

    Returns:
        One row per molecule. Columns by family:
        common -> has_common_alerts, num_common_alerts, common_alert_details;
        nibr -> passes_nibr;
        lilly -> lilly_demerits, passes_lilly, lilly_patterns;
        pains -> passes_pains.

    Raises:
        ValueError: ``alert_type`` is not one of the four names above.
    """
    print(f"\nApplying {alert_type} structural alerts...")

    def run_common() -> pd.DataFrame:
        hits = mc.structural.CommonAlertsFilters()(mols=mols, n_jobs=n_jobs, progress=True)
        return pd.DataFrame({
            "has_common_alerts": [hit["has_alerts"] for hit in hits],
            "num_common_alerts": [hit["num_alerts"] for hit in hits],
            "common_alert_details": [
                ", ".join(hit["alert_details"]) if hit["alert_details"] else ""
                for hit in hits
            ],
        })

    def run_nibr() -> pd.DataFrame:
        verdicts = mc.structural.NIBRFilters()(mols=mols, n_jobs=n_jobs, progress=True)
        return pd.DataFrame({"passes_nibr": verdicts})

    def run_lilly() -> pd.DataFrame:
        scored = mc.structural.LillyDemeritsFilters()(mols=mols, n_jobs=n_jobs, progress=True)
        return pd.DataFrame({
            "lilly_demerits": [entry["demerits"] for entry in scored],
            "passes_lilly": [entry["passes"] for entry in scored],
            "lilly_patterns": [
                ", ".join(match["pattern"] for match in entry["matched_patterns"])
                for entry in scored
            ],
        })

    def run_pains() -> pd.DataFrame:
        flags = [mc.rules.basic_rules.pains_filter(mol) for mol in tqdm(mols, desc="PAINS")]
        return pd.DataFrame({"passes_pains": flags})

    runners = {
        "common": run_common,
        "nibr": run_nibr,
        "lilly": run_lilly,
        "pains": run_pains,
    }
    runner = runners.get(alert_type)
    if runner is None:
        raise ValueError(f"Unknown alert type: {alert_type}")
    return runner()


def apply_complexity_filter(mols: List[Chem.Mol], max_complexity: float, method: str = "bertz") -> pd.DataFrame:
    """Score molecular complexity and flag molecules at or below a ceiling.

    Args:
        mols: Molecules to score.
        max_complexity: Inclusive upper bound. A molecule passes when its
            score is ``<= max_complexity``.
        method: Scoring method name forwarded to
            ``mc.complexity.calculate_complexity``.

    Returns:
        DataFrame with ``complexity_score`` and ``passes_complexity`` columns,
        one row per molecule.
    """
    print(f"\nCalculating molecular complexity (method={method}, max={max_complexity})...")

    # Score every molecule first, then compare against the threshold.
    scores = []
    for mol in tqdm(mols, desc="Complexity"):
        scores.append(mc.complexity.calculate_complexity(mol, method=method))

    within_limit = [score <= max_complexity for score in scores]

    return pd.DataFrame({"complexity_score": scores, "passes_complexity": within_limit})


def apply_constraints(mols: List[Chem.Mol], constraints: Dict, n_jobs: int) -> pd.DataFrame:
    """Check molecules against user-supplied property constraints.

    Args:
        mols: Molecules to check.
        constraints: Keyword arguments forwarded verbatim to
            ``mc.constraints.Constraints``.
        n_jobs: Parallel job count forwarded to the constraint call.

    Returns:
        DataFrame with ``passes_constraints`` and ``constraint_violations``
        (comma-joined violation names, or ``""`` when there are none).
    """
    print(f"\nApplying constraints: {constraints}")

    checker = mc.constraints.Constraints(**constraints)
    outcomes = checker(mols=mols, n_jobs=n_jobs, progress=True)

    def describe(violations) -> str:
        # Empty/None violation lists render as an empty cell.
        return ", ".join(violations) if violations else ""

    pass_flags = [outcome["passes"] for outcome in outcomes]
    violation_text = [describe(outcome["violations"]) for outcome in outcomes]

    return pd.DataFrame({
        "passes_constraints": pass_flags,
        "constraint_violations": violation_text,
    })


def apply_chemical_groups(mols: List[Chem.Mol], groups: List[str]) -> pd.DataFrame:
    """Flag, per molecule, which of the requested chemical groups matched.

    Args:
        mols: Molecules to scan.
        groups: Group names passed to ``mc.groups.ChemicalGroup``.

    Returns:
        DataFrame with one boolean ``has_<group>`` column per requested group.
        A match result is considered present when it is truthy.
    """
    print(f"\nDetecting chemical groups: {', '.join(groups)}")

    detector = mc.groups.ChemicalGroup(groups=groups)
    matches = detector.get_all_matches(mols)

    columns = {
        f"has_{name}": [bool(match.get(name)) for match in matches]
        for name in groups
    }
    return pd.DataFrame(columns)


def _pct(count, total: int) -> float:
    """Return ``count`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
    return 100 * count / total if total else 0.0


def generate_summary(df: pd.DataFrame, output_file: Path):
    """Generate filtering summary report.

    Writes ``<output stem>_summary.txt`` next to ``output_file``. It holds
    pass counts and percentages for each filter family present in ``df``,
    detected from column names such as ``rule_*``, ``passes_nibr`` and
    ``complexity_score``. ``df`` itself is not modified.

    Args:
        df: Combined per-molecule results table (may be empty).
        output_file: Path of the results CSV. Only its directory and stem
            are used.
    """
    summary_file = output_file.parent / f"{output_file.stem}_summary.txt"
    total = len(df)

    with open(summary_file, "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("MEDCHEM FILTERING SUMMARY\n")
        f.write("=" * 80 + "\n\n")

        f.write(f"Total molecules processed: {total}\n\n")

        # Rule results
        rule_cols = [col for col in df.columns if col.startswith("rule_") or col == "passes_all_rules"]
        if rule_cols:
            f.write("RULE FILTERS:\n")
            f.write("-" * 40 + "\n")
            for col in rule_cols:
                if df[col].dtype == bool:
                    n_pass = df[col].sum()
                    f.write(f"  {col}: {n_pass} passed ({_pct(n_pass, total):.1f}%)\n")
            f.write("\n")

        # Structural alerts
        alert_keywords = ("alert", "nibr", "lilly", "pains")
        alert_cols = [col for col in df.columns if any(key in col.lower() for key in alert_keywords)]
        if alert_cols:
            f.write("STRUCTURAL ALERTS:\n")
            f.write("-" * 40 + "\n")
            if "has_common_alerts" in df.columns:
                # astype(bool) makes `~` a logical NOT even for object-dtype
                # columns, where `~True` would be the integer -2.
                n_clean = (~df["has_common_alerts"].astype(bool)).sum()
                f.write(f"  No common alerts: {n_clean} ({_pct(n_clean, total):.1f}%)\n")
            if "passes_nibr" in df.columns:
                n_pass = df["passes_nibr"].sum()
                f.write(f"  Passes NIBR: {n_pass} ({_pct(n_pass, total):.1f}%)\n")
            if "passes_lilly" in df.columns:
                n_pass = df["passes_lilly"].sum()
                f.write(f"  Passes Lilly: {n_pass} ({_pct(n_pass, total):.1f}%)\n")
                avg_demerits = df["lilly_demerits"].mean()
                f.write(f"  Average Lilly demerits: {avg_demerits:.1f}\n")
            if "passes_pains" in df.columns:
                n_pass = df["passes_pains"].sum()
                f.write(f"  Passes PAINS: {n_pass} ({_pct(n_pass, total):.1f}%)\n")
            f.write("\n")

        # Complexity
        if "complexity_score" in df.columns:
            f.write("COMPLEXITY:\n")
            f.write("-" * 40 + "\n")
            avg_complexity = df["complexity_score"].mean()
            f.write(f"  Average complexity: {avg_complexity:.1f}\n")
            if "passes_complexity" in df.columns:
                n_pass = df["passes_complexity"].sum()
                f.write(f"  Within threshold: {n_pass} ({_pct(n_pass, total):.1f}%)\n")
            f.write("\n")

        # Constraints
        if "passes_constraints" in df.columns:
            f.write("CONSTRAINTS:\n")
            f.write("-" * 40 + "\n")
            n_pass = df["passes_constraints"].sum()
            f.write(f"  Passes all constraints: {n_pass} ({_pct(n_pass, total):.1f}%)\n")
            f.write("\n")

        # Overall pass rate. Use a local Series rather than adding a column,
        # so the caller's DataFrame is left untouched.
        pass_cols = [col for col in df.columns if col.startswith("passes_")]
        if pass_cols:
            n_pass = df[pass_cols].all(axis=1).sum()
            f.write("OVERALL:\n")
            f.write("-" * 40 + "\n")
            f.write(f"  Molecules passing all filters: {n_pass} ({_pct(n_pass, total):.1f}%)\n")

        f.write("\n" + "=" * 80 + "\n")

    print(f"\nSummary report saved to: {summary_file}")


def _split_list(value: str) -> List[str]:
    """Split a comma-separated CLI value into stripped, non-empty items."""
    return [item.strip() for item in value.split(",") if item.strip()]


def _parse_range(parser: argparse.ArgumentParser, value: str, flag: str) -> Tuple[float, float]:
    """Parse a ``min,max`` CLI value; malformed input exits via ``parser.error``."""
    parts = value.split(",")
    if len(parts) != 2:
        parser.error(f"{flag} expects 'min,max', got {value!r}")
    try:
        low, high = float(parts[0]), float(parts[1])
    except ValueError:
        parser.error(f"{flag} expects numeric bounds, got {value!r}")
    if low > high:
        parser.error(f"{flag} minimum exceeds maximum in {value!r}")
    return low, high


def _build_constraints(parser: argparse.ArgumentParser, args: argparse.Namespace) -> Dict:
    """Collect the property constraints requested on the command line."""
    constraints = {}
    if args.mw_range:
        constraints["mw_range"] = _parse_range(parser, args.mw_range, "--mw-range")
    if args.logp_range:
        constraints["logp_range"] = _parse_range(parser, args.logp_range, "--logp-range")
    # Compare against None so an explicit 0 (e.g. --hbd-max 0) is honoured.
    for key in ("tpsa_max", "hbd_max", "hba_max", "rotatable_bonds_max"):
        value = getattr(args, key)
        if value is not None:
            constraints[key] = value
    return constraints


def main():
    """Command-line entry point.

    Parses options, runs the requested filters, writes the results CSV and,
    unless ``--no-summary`` is given, a summary report.
    """
    parser = argparse.ArgumentParser(
        description="Batch molecular filtering using medchem",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )

    # Input/Output
    parser.add_argument("input", type=Path, help="Input file (CSV, TSV, SDF, or TXT)")
    parser.add_argument("--output", "-o", type=Path, required=True, help="Output CSV file")
    parser.add_argument("--smiles-column", default="smiles", help="Name of SMILES column (default: smiles)")

    # Rule filters
    parser.add_argument("--rules", help="Comma-separated list of rules (e.g., rule_of_five,rule_of_cns)")

    # Structural alerts
    parser.add_argument("--common-alerts", action="store_true", help="Apply common structural alerts")
    parser.add_argument("--nibr", action="store_true", help="Apply NIBR filters")
    parser.add_argument("--lilly", action="store_true", help="Apply Lilly demerits filter")
    parser.add_argument("--pains", action="store_true", help="Apply PAINS filter")

    # Complexity
    parser.add_argument("--complexity", type=float, help="Maximum complexity threshold")
    parser.add_argument("--complexity-method", default="bertz", choices=["bertz", "whitlock", "barone"],
                        help="Complexity calculation method")

    # Constraints
    parser.add_argument("--mw-range", help="Molecular weight range (e.g., 200,500)")
    # A value starting with '-' is taken as an option flag unless attached with '='.
    parser.add_argument("--logp-range", help="LogP range (e.g., --logp-range=-2,5; use '=' when the minimum is negative)")
    parser.add_argument("--tpsa-max", type=float, help="Maximum TPSA")
    parser.add_argument("--hbd-max", type=int, help="Maximum H-bond donors")
    parser.add_argument("--hba-max", type=int, help="Maximum H-bond acceptors")
    parser.add_argument("--rotatable-bonds-max", type=int, help="Maximum rotatable bonds")

    # Chemical groups
    parser.add_argument("--groups", help="Comma-separated chemical groups to detect")

    # Processing options
    parser.add_argument("--n-jobs", type=int, default=-1, help="Number of parallel jobs (-1 = all cores)")
    parser.add_argument("--no-summary", action="store_true", help="Don't generate summary report")
    parser.add_argument("--filter-output", action="store_true", help="Only output molecules passing all filters")

    args = parser.parse_args()

    # Validate constraint options before the potentially slow loading/filtering.
    constraints = _build_constraints(parser, args)

    # Load molecules
    df, mols = load_molecules(args.input, args.smiles_column)

    # Each filter contributes a DataFrame row-aligned with `mols`.
    result_dfs = [df]

    # Rules
    rule_list = _split_list(args.rules) if args.rules else []
    if rule_list:
        result_dfs.append(apply_rule_filters(mols, rule_list, args.n_jobs))

    # Structural alerts, in a fixed order so output columns are stable.
    alert_flags = (
        ("common", args.common_alerts),
        ("nibr", args.nibr),
        ("lilly", args.lilly),
        ("pains", args.pains),
    )
    for alert_type, enabled in alert_flags:
        if enabled:
            result_dfs.append(apply_structural_alerts(mols, alert_type, args.n_jobs))

    # Complexity (`is not None` so a threshold of 0 is still applied)
    if args.complexity is not None:
        result_dfs.append(apply_complexity_filter(mols, args.complexity, args.complexity_method))

    # Constraints
    if constraints:
        result_dfs.append(apply_constraints(mols, constraints, args.n_jobs))

    # Chemical groups
    group_list = _split_list(args.groups) if args.groups else []
    if group_list:
        result_dfs.append(apply_chemical_groups(mols, group_list))

    # Combine results
    df_final = pd.concat(result_dfs, axis=1)
    # Keep the unfiltered table. The summary reports pass rates over every
    # processed molecule, not only the survivors of --filter-output (which
    # would always read 100%).
    df_all = df_final

    # Filter output if requested
    if args.filter_output:
        pass_cols = [col for col in df_final.columns if col.startswith("passes_")]
        if pass_cols:
            passes_all = df_final[pass_cols].all(axis=1)
            df_final = df_final.assign(passes_all=passes_all)[passes_all]
            print(f"\nFiltered to {len(df_final)} molecules passing all filters")
        else:
            print("\nWarning: --filter-output given but no pass/fail filters were applied; keeping all molecules")

    # Save results
    args.output.parent.mkdir(parents=True, exist_ok=True)
    df_final.to_csv(args.output, index=False)
    print(f"\nResults saved to: {args.output}")

    # Generate summary
    if not args.no_summary:
        generate_summary(df_all, args.output)

    print("\nDone!")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
← Back to medchem