scripts/visualize_graph.py

#!/usr/bin/env python3
"""
Visualize PyTorch Geometric graph structures using networkx and matplotlib.

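This script provides utilities to visualize Data objects, including:
- Graph structure (nodes and edges)
- Node features or labels (as node colors)
- Edge attributes (as edge colors)
- Community/cluster assignments (any per-node attribute used for node colors)
- Degree distributions and summary graph statistics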

Usage:
    python visualize_graph.py --dataset Cora --output graph.png

Or import and use:
    from scripts.visualize_graph import visualize_data
    visualize_data(data, title="My Graph", show_labels=True)
"""

import argparse
import matplotlib.pyplot as plt
import networkx as nx
import torch
from typing import Optional, Union
import numpy as np


def _per_item_values(data, attr: str, expected_len: int, kind: str) -> Optional[np.ndarray]:
    """Return ``getattr(data, attr)`` as a 1-D numpy array of length ``expected_len``, or None.

    None is returned (with a printed note) when the attribute is missing, is
    not a tensor (PyG exposes unset fields such as ``y``/``edge_attr`` as None),
    or does not have exactly ``expected_len`` entries along its first axis.
    """
    value = getattr(data, attr, None)
    if not isinstance(value, torch.Tensor) or value.dim() == 0 or value.size(0) != expected_len:
        print(f"Ignoring {kind} color attribute '{attr}': expected a tensor with {expected_len} entries.")
        return None
    values = value.detach().cpu().numpy()
    if values.dtype == bool:
        # Boolean masks (e.g. train_mask) are plotted as 0/1.
        values = values.astype(int)
    if values.ndim > 1:
        # Multi-dimensional features - use the first feature column.
        values = values.reshape(values.shape[0], -1)[:, 0]
    return values


def visualize_data(
    data,
    title: str = "Graph Visualization",
    node_color_attr: Optional[str] = None,
    edge_color_attr: Optional[str] = None,
    show_labels: bool = False,
    node_size: int = 300,
    figsize: tuple = (12, 10),
    layout: str = "spring",
    output_path: Optional[str] = None,
    max_nodes: Optional[int] = None,
):
    """
    Visualize a PyTorch Geometric Data object.

    Args:
        data: PyTorch Geometric Data object (uses ``edge_index``, ``num_nodes``,
            and ``subgraph`` when sampling is triggered by ``max_nodes``)
        title: Plot title
        node_color_attr: Data attribute to use for node colors (e.g., 'y', 'train_mask').
            Ignored, with a printed note, unless it is a tensor with one entry per node.
        edge_color_attr: Data attribute to use for edge colors. Ignored, with a
            printed note, unless it is a tensor with one entry per ``edge_index`` column.
        show_labels: Whether to show node labels
        node_size: Size of nodes in visualization
        figsize: Figure size (width, height)
        layout: Graph layout algorithm ('spring', 'circular', 'kamada_kawai', 'spectral')
        output_path: Path to save figure (if None, displays interactively)
        max_nodes: Maximum number of nodes to visualize (samples if exceeded)

    Raises:
        ValueError: If ``layout`` is not one of the supported names.
    """
    layout_fns = {
        "spring": lambda graph: nx.spring_layout(graph, k=0.5, iterations=50),
        "circular": nx.circular_layout,
        "kamada_kawai": nx.kamada_kawai_layout,
        "spectral": nx.spectral_layout,
    }
    # Validate before any figure exists so a bad name cannot leak an open figure.
    if layout not in layout_fns:
        raise ValueError(f"Unknown layout: {layout}")

    # Sample nodes if graph is too large
    if max_nodes and data.num_nodes > max_nodes:
        print(f"Graph has {data.num_nodes} nodes. Sampling {max_nodes} nodes for visualization.")
        # Sorted so the relabelled subgraph keeps the original relative node order.
        node_indices = torch.randperm(data.num_nodes)[:max_nodes].sort().values
        data = data.subgraph(node_indices)

    # Convert to networkx graph: undirected only when every edge has its reverse.
    G = nx.Graph() if is_undirected(data.edge_index) else nx.DiGraph()
    G.add_nodes_from(range(data.num_nodes))
    src, dst = data.edge_index.cpu().tolist()
    edge_pairs = list(zip(src, dst))
    G.add_edges_from(edge_pairs)
    # networkx collapses duplicate and (for Graph) reverse pairs, so the drawn
    # edge order is fixed here and edge colors are aligned to it below.
    edgelist = list(G.edges())

    pos = layout_fns[layout](G)

    # Determine node colors
    node_colors = 'skyblue'
    if node_color_attr:
        values = _per_item_values(data, node_color_attr, data.num_nodes, "node")
        if values is not None:
            node_colors = values

    # Determine edge colors, mapped from edge_index columns onto networkx edges.
    edge_colors = 'gray'
    if edge_color_attr:
        values = _per_item_values(data, edge_color_attr, len(edge_pairs), "edge")
        if values is not None and edgelist:
            lookup = {}
            for pair, value in zip(edge_pairs, values.tolist()):
                # First occurrence wins when a pair is repeated.
                lookup.setdefault(pair, value)
            # Every drawn (u, v) is present in edge_pairs: for a Graph the edge
            # set is symmetric, for a DiGraph the edges are exactly the pairs.
            edge_colors = [lookup[edge] for edge in edgelist]

    fig, ax = plt.subplots(figsize=figsize)
    try:
        nx.draw_networkx_nodes(
            G, pos,
            node_color=node_colors,
            node_size=node_size,
            cmap=plt.cm.viridis,
            ax=ax
        )

        nx.draw_networkx_edges(
            G, pos,
            edgelist=edgelist,
            edge_color=edge_colors,
            edge_cmap=plt.cm.viridis,
            alpha=0.3,
            arrows=isinstance(G, nx.DiGraph),
            arrowsize=10,
            ax=ax
        )

        if show_labels:
            nx.draw_networkx_labels(G, pos, font_size=8, ax=ax)

        ax.set_title(title, fontsize=16, fontweight='bold')
        ax.axis('off')

        # Add colorbar if using numeric node colors
        if isinstance(node_colors, np.ndarray) and node_colors.size > 0:
            sm = plt.cm.ScalarMappable(
                cmap=plt.cm.viridis,
                norm=plt.Normalize(vmin=node_colors.min(), vmax=node_colors.max())
            )
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
            cbar.set_label(node_color_attr, rotation=270, labelpad=20)

        fig.tight_layout()

        if output_path:
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
            print(f"Figure saved to {output_path}")
        else:
            plt.show()
    finally:
        # Always release this specific figure, even if drawing or saving fails.
        plt.close(fig)


def is_undirected(edge_index):
    """Check if graph is undirected.

    Returns True when every edge ``(u, v)`` listed in ``edge_index`` also
    appears as ``(v, u)``. Duplicate columns are ignored, self-loops are
    trivially symmetric, and an empty edge set counts as undirected.

    Args:
        edge_index: Two-row index container unpackable as ``row, col``, each
            row supporting ``.tolist()`` (e.g. a ``(2, E)`` tensor).
    """
    row, col = edge_index
    edges = set(zip(row.tolist(), col.tolist()))
    # Equivalent to comparing with the reversed edge set: reversal is a
    # bijection, so the sets are equal iff every reverse is already present.
    return all((v, u) in edges for u, v in edges)


def plot_degree_distribution(data, output_path: Optional[str] = None):
    """Plot the degree distribution of the graph.

    Degrees are counted over the target row ``edge_index[1]`` (in-degree for
    directed graphs; when every edge is stored in both directions, in-degree
    equals out-degree).

    The left panel is a histogram over all nodes. The right panel is a
    log-log plot of degree frequency; zero-degree (isolated) nodes are left
    out there because log(0) is undefined.

    Args:
        data: PyTorch Geometric Data object with ``edge_index`` and ``num_nodes``.
        output_path: Path to save the figure; if None, the figure is shown.
    """
    from torch_geometric.utils import degree

    deg = degree(data.edge_index[1], data.num_nodes).cpu().numpy()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        # Histogram
        ax1.hist(deg, bins=50, edgecolor='black', alpha=0.7)
        ax1.set_xlabel('Degree', fontsize=12)
        ax1.set_ylabel('Frequency', fontsize=12)
        ax1.set_title('Degree Distribution', fontsize=14, fontweight='bold')
        ax1.grid(alpha=0.3)

        # Log-log plot (positive degrees only; log(0) is undefined)
        unique_degrees, counts = np.unique(deg, return_counts=True)
        positive = unique_degrees > 0
        if positive.any():
            ax2.loglog(unique_degrees[positive], counts[positive], 'o-', alpha=0.7)
        ax2.set_xlabel('Degree (log scale)', fontsize=12)
        ax2.set_ylabel('Frequency (log scale)', fontsize=12)
        ax2.set_title('Degree Distribution (Log-Log)', fontsize=14, fontweight='bold')
        ax2.grid(alpha=0.3)

        fig.tight_layout()

        if output_path:
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
            print(f"Degree distribution saved to {output_path}")
        else:
            plt.show()
    finally:
        # Release this specific figure even if plotting or saving fails.
        plt.close(fig)


def plot_graph_statistics(data, output_path: Optional[str] = None):
    """Render a text panel of summary graph statistics, then print it.

    Reported values: node/edge counts, average and maximum degree (counted
    over ``edge_index[1]``), self-loop and undirectedness flags, feature
    counts when available, and a class count derived from a 1-D
    integer-valued ``y`` (``max(y) + 1``).

    Args:
        data: PyTorch Geometric Data object.
        output_path: Path to save the figure; if None, the figure is shown.
    """
    from torch_geometric.utils import degree, contains_self_loops, is_undirected as check_undirected

    # Compute statistics
    deg = degree(data.edge_index[1], data.num_nodes).cpu().numpy()
    has_nodes = deg.size > 0  # mean()/max() are undefined on an empty graph

    stats = {
        'Nodes': data.num_nodes,
        'Edges': data.num_edges,
        'Avg Degree': f"{deg.mean():.2f}" if has_nodes else "0.00",
        'Max Degree': int(deg.max()) if has_nodes else 0,
        'Self-loops': contains_self_loops(data.edge_index),
        'Undirected': check_undirected(data.edge_index),
    }

    if hasattr(data, 'num_node_features'):
        stats['Node Features'] = data.num_node_features
    # PyG exposes unset fields (edge_attr, y) as None, so check values, not attributes.
    if getattr(data, 'edge_attr', None) is not None and hasattr(data, 'num_edge_features'):
        stats['Edge Features'] = data.num_edge_features
    y = getattr(data, 'y', None)
    if (
        isinstance(y, torch.Tensor)
        and y.dim() == 1
        and y.numel() > 0
        and not torch.is_floating_point(y)  # regression targets are not classes
    ):
        stats['Classes'] = int(y.max().item()) + 1

    lines = [f"{key:20s}: {value}" for key, value in stats.items()]
    text = "Graph Statistics\n" + "=" * 40 + "\n\n" + "\n".join(lines) + "\n"

    # Create text plot
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.axis('off')
        ax.text(0.1, 0.5, text, fontsize=14, family='monospace',
                verticalalignment='center', transform=ax.transAxes)

        fig.tight_layout()

        if output_path:
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
            print(f"Statistics saved to {output_path}")
        else:
            plt.show()
    finally:
        # Release this specific figure even if rendering or saving fails.
        plt.close(fig)

    # Print to console as well
    print("\n" + text)


def _derived_output_path(output: Optional[str], suffix: str) -> Optional[str]:
    """Insert ``suffix`` before the extension of ``output`` (``graph.png`` -> ``graph<suffix>.png``).

    Works for any extension (or none), so auxiliary plots never reuse the
    main output path. Returns None when ``output`` is empty or None, so the
    caller shows the plot interactively instead.
    """
    if not output:
        return None
    from pathlib import Path

    path = Path(output)
    return str(path.with_name(f"{path.stem}{suffix}{path.suffix}"))


def main():
    """Command-line entry point: load a dataset and render the requested plots.

    Exits with status 1 if the dataset cannot be loaded as either a
    Planetoid or a TUDataset.
    """
    parser = argparse.ArgumentParser(description="Visualize PyTorch Geometric graphs")
    parser.add_argument('--dataset', type=str, default='Cora',
                        help='Dataset name (e.g., Cora, CiteSeer, ENZYMES)')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file path for visualization')
    parser.add_argument('--node-color', type=str, default='y',
                        help='Attribute to use for node colors')
    parser.add_argument('--layout', type=str, default='spring',
                        choices=['spring', 'circular', 'kamada_kawai', 'spectral'],
                        help='Graph layout algorithm')
    parser.add_argument('--show-labels', action='store_true',
                        help='Show node labels')
    parser.add_argument('--max-nodes', type=int, default=500,
                        help='Maximum nodes to visualize')
    parser.add_argument('--stats', action='store_true',
                        help='Show graph statistics')
    parser.add_argument('--degree', action='store_true',
                        help='Show degree distribution')

    args = parser.parse_args()

    # Load dataset
    print(f"Loading dataset: {args.dataset}")
    root = f'/tmp/{args.dataset}'

    try:
        # Try Planetoid datasets
        from torch_geometric.datasets import Planetoid
        data = Planetoid(root=root, name=args.dataset)[0]
    except Exception as planetoid_err:
        # `except Exception` (not bare `except`) so Ctrl-C / SystemExit still propagate.
        try:
            # Try TUDataset
            from torch_geometric.datasets import TUDataset
            data = TUDataset(root=root, name=args.dataset)[0]
        except Exception as e:
            print(f"Error loading dataset: {e}")
            print(f"(Planetoid loader also failed: {planetoid_err})")
            print("Supported datasets: Cora, CiteSeer, PubMed, ENZYMES, PROTEINS, etc.")
            raise SystemExit(1)

    print(f"Loaded {args.dataset}: {data.num_nodes} nodes, {data.num_edges} edges")

    # Generate visualizations
    if args.stats:
        plot_graph_statistics(data, _derived_output_path(args.output, '_stats'))

    if args.degree:
        plot_degree_distribution(data, _derived_output_path(args.output, '_degree'))

    # Main visualization
    visualize_data(
        data,
        title=f"{args.dataset} Graph",
        node_color_attr=args.node_color,
        show_labels=args.show_labels,
        layout=args.layout,
        output_path=args.output,
        max_nodes=args.max_nodes
    )


# Script entry point: run the CLI only when executed directly, not on import.
if "__main__" == __name__:
    main()
← Back to torch_geometric