#!/usr/bin/env python3
"""
State-space trajectory figure: PCA of per-time-patch encoder hidden states
across full EEG recordings, AD vs healthy, Cunningham/Churchland-style.

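For each test patient of --task (default alzheimer_adeeg: 10 AD + 10 healthy;
--include_train / --max_patients widen or cap the patient set):
  1. Load the full recording and push it through the frozen v6 encoder in chunks.
  2. The encoder is causal and already pools across electrodes internally,
     so model.encode() returns (1, T_patches, hidden_dim) (or feature-head
     outputs with --embedding features).
  3. Trim the causal warm-up patches.
  4. Pool all patient embeddings, fit a 2-D projection (PCA by default;
     PLS, supervised UMAP or jPCA via --method), and project each patient.
  5. Smooth each projected component with a centred moving average.
  6. Plot trajectories: reds for AD, blues for healthy, light->dark
     gradient along time, black start dot, end arrow (--animate writes a
     GIF instead of PNG/PDF).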
"""

import argparse
import json
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.cross_decomposition import PLSRegression
from sklearn.decomposition import PCA
from tqdm import tqdm

import matplotlib.animation as animation
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import to_rgb
from scipy.interpolate import CubicSpline

sys.path.insert(0, str(Path(__file__).parent.parent))

from eval.evaluate_full_recording import (  # noqa: E402
    embed_full_recording,
    extract_recording_state,
    load_full_recording,
    load_model,
    resolve_h5_path,
)
from data.eeg_pretrain_dataset_safe import (  # noqa: E402
    get_dataset_channel_mask,
    get_reference_type_id,
)


@torch.no_grad()
def embed_full_recording_per_patch(model, h5_path, dataset, device, chunk_patches=1500,
                                    embedding: str = 'hidden', patch_size: int = 32):
    """Embed a full recording patch-by-patch with the frozen encoder.

    The recording is loaded with ``load_full_recording`` and pushed through
    ``model.encode`` in chunks of ``chunk_patches * patch_size`` samples; the
    per-chunk outputs are concatenated along the time axis.

    Args:
        model: Encoder exposing ``max_channels``, ``encode(...)`` and, for
            ``embedding='features'``, ``feature_head``.
        h5_path: Path of the recording file.
        dataset: Dataset name used for the channel mask, reference type and
            condition key.
        device: Torch device the chunks are moved to.
        chunk_patches: Number of patches per ``model.encode`` call.
        embedding: 'hidden' -> ``out['hidden']``, the final-layer per-patch
            hidden states (D=hidden_dim, e.g. 512). 'features' ->
            ``model.feature_head`` applied to those hidden states
            (D=n_channels*9, e.g. 576: 5 band powers + 3 Hjorth + 1 var per channel).
        patch_size: Samples per encoder patch; 32 is the value this script
            was written for.

    Returns:
        np.ndarray of shape (T_patches, D), or None if the recording has fewer
        than ``patch_size`` samples.

    Raises:
        ValueError: if ``embedding`` is neither 'hidden' nor 'features'.
    """
    if embedding not in ('hidden', 'features'):
        raise ValueError(f"embedding must be 'hidden' or 'features', got {embedding!r}")

    max_ch = model.max_channels
    eeg, file_mask, coords = load_full_recording(h5_path, max_channels=max_ch)
    if eeg.shape[1] < patch_size:
        # Shorter than a single patch: nothing to encode.
        return None

    # Prefer the mask stored with the file; fall back to the dataset default.
    if file_mask is not None:
        channel_mask = torch.from_numpy(file_mask).unsqueeze(0).to(device)
    else:
        channel_mask = torch.from_numpy(
            get_dataset_channel_mask(dataset, max_channels=max_ch)
        ).unsqueeze(0).to(device)

    channel_coords = torch.from_numpy(coords).unsqueeze(0).to(device)
    ref_ids = torch.tensor([get_reference_type_id(dataset)], dtype=torch.long, device=device)
    rec_state = extract_recording_state(h5_path, dataset)
    condition_keys = [(dataset, rec_state)]

    # NOTE(review): every chunk is a separate model.encode() call. If the causal
    # encoder does not carry state across calls, the first patches of *every*
    # chunk (not only the first one, which main() trims) lack warm-up context.
    chunk_samples = chunk_patches * patch_size
    all_out = []
    for start in range(0, eeg.shape[1], chunk_samples):
        chunk = eeg[:, start:start + chunk_samples].unsqueeze(0).to(device)
        # Apply the channel mask along the channel axis (presumably 0 for
        # absent channels — confirm mask semantics).
        chunk = chunk * channel_mask.unsqueeze(-1)
        out = model.encode(
            chunk, condition_keys=condition_keys, channel_mask=channel_mask,
            reference_type_ids=ref_ids, channel_coords=channel_coords,
        )
        h = out['hidden']
        if embedding == 'features':
            h = model.feature_head(h)  # (1, T, n_ch*9)
        all_out.append(h.cpu())

    all_out = torch.cat(all_out, dim=1)  # (1, T, D)
    return all_out.squeeze(0).numpy()  # (T, D)


def load_test_patients(task: str, include_train: bool = False, max_patients: int = None):
    """Return a list of ``(h5_path, label, dataset, patient_id)`` tuples.

    Reads ``data/splits/classification_v3/<task>/splits.json``. Each split maps
    a label name to a list of file paths. ``'positive'`` becomes label 1 and any
    other label name becomes 0. Paths are resolved with ``resolve_h5_path``, and
    unresolvable paths are skipped. ``patient_id`` is the resolved file's stem.

    Args:
        task: Task directory name under ``data/splits/classification_v3``.
        include_train: If True, train-split patients are appended after the
            test-split ones.
        max_patients: Optional cap on the total number of patients. When set,
            at most ``max_patients // 2`` patients per class are kept. Test-split
            patients are preferred over train-split ones. The cap applies
            whether or not ``include_train`` is set.
    """
    split_file = Path(f'data/splits/classification_v3/{task}/splits.json')
    with open(split_file, encoding='utf-8') as f:
        data = json.load(f)
    datasets = data.get('datasets', ['unknown'])

    def _items_from_split(split_name):
        # Collect resolvable (path, label, dataset, stem) tuples for one split.
        out = []
        for label_name, files in data[split_name].items():
            label = 1 if label_name == 'positive' else 0
            for fpath in files:
                # First listed dataset whose name occurs in the path, else the first one.
                ds = next((d for d in datasets if d in fpath), datasets[0])
                resolved = resolve_h5_path(fpath, ds)
                if resolved:
                    out.append((resolved, label, ds, Path(resolved).stem))
        return out

    items = _items_from_split('test')
    if include_train:
        items += _items_from_split('train')
    if max_patients is None:
        return items

    # Equal per-class cap. `items` is ordered test-first, so slicing each class
    # keeps test patients first and only fills the remainder from train.
    per_class = max_patients // 2
    pos = [x for x in items if x[1] == 1]
    neg = [x for x in items if x[1] == 0]
    return pos[:per_class] + neg[:per_class]


def fit_jpca(per_patient_dict, k_pca: int = 6):
    """Fit Churchland & Cunningham (2012)-style jPCA and project each patient.

    1. Stack all per-patient (T, D) arrays into (sum_T, D) and PCA-reduce to
       ``k_pca`` dims.
    2. Fit a skew-symmetric M minimising ||dX - X @ M.T||_F by least squares
       over its k(k-1)/2 free parameters (M = -M.T holds by construction).
       Finite differences are taken within each patient, so no step is formed
       across the boundary between two recordings.
    3. The eigenvalue of M with the largest positive imaginary part defines the
       rotational plane. The plane is spanned by the orthonormalised real and
       imaginary parts of its eigenvector.
    4. Project every per-patient trajectory onto that plane.

    Args:
        per_patient_dict: dict pid -> (T, D) array.
        k_pca: Number of PCA dimensions kept before fitting M (must be >= 2).

    Returns:
        (proj_per_patient, omega_top, r2):
        - dict pid -> (T, 2) projection;
        - |Im(top eigenvalue)| (radians per step);
        - R^2 of the skew-symmetric fit of dX relative to predicting dX by its
          column means.

    Raises:
        ValueError: if ``k_pca < 2``.
    """
    if k_pca < 2:
        raise ValueError(f'k_pca must be >= 2 to define a 2-D plane, got {k_pca}')

    pids = list(per_patient_dict.keys())
    X_concat = np.concatenate([per_patient_dict[p] for p in pids], axis=0)
    pca = PCA(n_components=k_pca)
    pca.fit(X_concat)
    # Reduce each patient once; reused for the dynamics fit and the final projection.
    reduced = {p: pca.transform(per_patient_dict[p]) for p in pids}

    # Per-patient (T-1, k) states and finite differences, concatenated so the
    # dynamics are fit jointly across patients.
    X_minus = np.concatenate([reduced[p][:-1] for p in pids], axis=0)        # (Tsum, k)
    dX = np.concatenate([np.diff(reduced[p], axis=0) for p in pids], axis=0)  # (Tsum, k)

    # Skew-symmetric LS: parameterize M = sum_p theta_p (E_ij - E_ji) for i<j.
    T, k = X_minus.shape
    pairs = [(i, j) for i in range(k) for j in range(i + 1, k)]
    A = np.zeros((T * k, len(pairs)))
    rows = np.arange(T)
    for p_idx, (i, j) in enumerate(pairs):
        # d(M @ x)[c] / d theta_p:  c == i -> +x[j]  (M[i, j] = +theta)
        #                           c == j -> -x[i]  (M[j, i] = -theta)
        A[i * T + rows, p_idx] = X_minus[:, j]
        A[j * T + rows, p_idx] = -X_minus[:, i]
    y = dX.flatten(order='F')  # column-major: y[c*T + t] = dX[t, c]
    theta, *_ = np.linalg.lstsq(A, y, rcond=None)

    M = np.zeros((k, k))
    for p_idx, (i, j) in enumerate(pairs):
        M[i, j] = theta[p_idx]
        M[j, i] = -theta[p_idx]

    eigvals, eigvecs = np.linalg.eig(M)
    # A real skew-symmetric M has eigenvalues in conjugate pairs +/- i*omega.
    # Always take the +i*omega member, so the plane orientation (and thus the
    # rotation direction in the plot) does not depend on solver ordering.
    top = int(np.argmax(eigvals.imag))
    omega = float(np.abs(eigvals[top].imag))
    v = eigvecs[:, top]
    basis_re = np.real(v)
    basis_re = basis_re / (np.linalg.norm(basis_re) + 1e-12)
    basis_im = np.imag(v) - (np.imag(v) @ basis_re) * basis_re
    im_norm = np.linalg.norm(basis_im)
    if im_norm < 1e-12:
        # No rotation was fitted (e.g. M == 0): the eigenvector is real, so its
        # imaginary part gives no second axis. Fall back to the two leading PCs
        # instead of projecting onto a zero vector.
        basis = np.eye(k)[:, :2]
    else:
        basis = np.stack([basis_re, basis_im / im_norm], axis=1)  # (k, 2)

    # R^2 of the skew fit relative to predicting dX by its column means.
    dX_pred = X_minus @ M.T
    r2 = 1.0 - np.sum((dX - dX_pred) ** 2) / np.sum((dX - dX.mean(0)) ** 2)
    print(f'  jPCA: skew-LS fit R^2={r2:.3f}, top |Im(eig)|={omega:.4f} '
          f'(period~{2 * np.pi / max(omega, 1e-6):.0f} steps)')

    proj = {p: reduced[p] @ basis for p in pids}
    return proj, omega, r2


def smooth_columns(arr: np.ndarray, window: int) -> np.ndarray:
    """Apply a centred moving average down each column of ``arr``.

    Near the ends the window simply shrinks (``min_periods=1``), so no NaNs
    are produced. A window of 1 or less returns ``arr`` itself, untouched.
    """
    if window > 1:
        rolled = pd.DataFrame(arr).rolling(window, min_periods=1, center=True)
        return rolled.mean().to_numpy()
    return arr


def interpolate_trajectory(xy: np.ndarray, n_out: int) -> np.ndarray:
    """Resample a (T, D) trajectory to (n_out, D) on a uniform time grid.

    Input and output samples are spread evenly over t in [0, 1], so the first
    and last points are preserved.
    - Upsampling (``n_out > T``) a trajectory of at least 4 points uses a
      cubic spline (scipy's default not-a-knot boundary), fit jointly over all
      columns.
    - Otherwise (``T < 4`` or ``n_out <= T``), each column is interpolated
      linearly.
    D is normally 2 (x, y), but any number of columns is supported.
    """
    xy = np.asarray(xy, dtype=float)
    n = len(xy)
    t_in = np.linspace(0, 1, n)
    t_out = np.linspace(0, 1, n_out)
    if n < 4 or n_out <= n:
        # Too few points for a stable spline, or downsampling: linear per column.
        return np.stack([np.interp(t_out, t_in, col) for col in xy.T], axis=1)
    return CubicSpline(t_in, xy, axis=0)(t_out)


def gradient_line(ax, xy: np.ndarray, base_color, lw: float = 1.6):
    """Add ``xy`` to ``ax`` as a polyline that darkens along time.

    Each of the len(xy)-1 segments is coloured as a blend of white and
    ``base_color``. The base-colour weight rises linearly from 0.25 on the
    first segment (so the start never turns pure white) to 1.0 on the last
    segment. Nothing is drawn for fewer than two points.
    """
    pts = np.asarray(xy)
    if len(pts) < 2:
        return

    weights = np.linspace(0.25, 1.0, len(pts) - 1).reshape(-1, 1)
    white = np.ones(3)
    seg_colors = white * (1 - weights) + np.asarray(to_rgb(base_color)) * weights
    segments = np.stack((pts[:-1], pts[1:]), axis=1)  # (n_seg, 2, 2)
    ax.add_collection(LineCollection(
        segments,
        colors=seg_colors,
        linewidths=lw,
        capstyle='round',
        joinstyle='round',
        alpha=0.95,
    ))


def animate_trajectories(traj_by_patient, labels_by_patient, axis_labels,
                         output_base, n_frames=120, interp_points=600, fps=24):
    """Write ``<output_base>.gif`` in which every trajectory is traced out over time.

    Each patient's trajectory is first resampled to ``interp_points`` points
    with ``interpolate_trajectory``, so all patients advance at the same
    relative pace. Frame f (0-based) reveals the first
    ``max(2, int((f + 1) / n_frames * interp_points))`` points, coloured
    light->dark over the revealed part (Reds for label 1 / AD, Blues for
    label 0 / healthy). Axis limits are fixed to the full resampled extent
    plus a 5% margin. Black dots mark the starts, and end arrows are added on
    the last frame.

    Args:
        traj_by_patient: dict pid -> (T, 2) projected trajectory.
        labels_by_patient: dict pid -> 1 (AD) or 0 (healthy).
        axis_labels: (xlabel, ylabel).
        output_base: Output path without extension.
        n_frames: Number of animation frames.
        interp_points: Resampled length of every trajectory.
        fps: GIF frame rate; the on-screen interval is ``1000 // fps`` ms.
    """
    from matplotlib.lines import Line2D

    pids = list(traj_by_patient)
    ad_pids = [pid for pid in pids if labels_by_patient[pid] == 1]
    hc_pids = [pid for pid in pids if labels_by_patient[pid] == 0]

    # Mid-to-dark shade per patient so the light start of each gradient stays visible.
    color_for = {}
    for group, cmap in ((ad_pids, plt.cm.Reds), (hc_pids, plt.cm.Blues)):
        shades = cmap(np.linspace(0.55, 0.95, max(len(group), 1)))
        color_for.update(zip(group, shades))

    # Common length per patient so animation progress is comparable.
    resampled = {pid: interpolate_trajectory(traj_by_patient[pid], interp_points)
                 for pid in pids}

    # Fixed plot bounds from all resampled points, with a 5% margin.
    stacked = np.concatenate(list(resampled.values()), axis=0)
    lo = stacked.min(axis=0)
    hi = stacked.max(axis=0)
    margin = 0.05 * (hi - lo)

    fig, ax = plt.subplots(figsize=(8, 7), dpi=120)
    ax.set_xlim(lo[0] - margin[0], hi[0] + margin[0])
    ax.set_ylim(lo[1] - margin[1], hi[1] + margin[1])
    ax.set_aspect('equal', adjustable='datalim')
    ax.set_xlabel(axis_labels[0])
    ax.set_ylabel(axis_labels[1])
    ax.set_title('EEG state-space trajectories — AD vs Healthy\n'
                 'Per-patch hidden state from causal encoder, full recording')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # One initially empty LineCollection per patient, plus its start dot.
    collections = {}
    for pid in pids:
        coll = LineCollection([], linewidths=1.6, capstyle='round', joinstyle='round',
                              alpha=0.95)
        ax.add_collection(coll)
        collections[pid] = coll
        start = resampled[pid]
        ax.scatter(start[0, 0], start[0, 1], s=18, c='black', zorder=5,
                   edgecolors='none')

    ax.legend(
        handles=[
            Line2D([0], [0], color=plt.cm.Reds(0.85), lw=2,
                   label=f'AD  (n={len(ad_pids)})'),
            Line2D([0], [0], color=plt.cm.Blues(0.85), lw=2,
                   label=f'Healthy  (n={len(hc_pids)})'),
            Line2D([0], [0], marker='o', color='black', lw=0, markersize=6,
                   label='start'),
        ],
        loc='best', frameon=False, fontsize=9,
    )
    fig.tight_layout()

    arrows = {}
    last_frame = n_frames - 1
    white = np.array([1.0, 1.0, 1.0])

    def update(frame):
        # Reveal the first (frame + 1) / n_frames of every resampled trajectory.
        shown = max(2, int((frame + 1) / n_frames * interp_points))
        drawn = []
        for pid in pids:
            pts = resampled[pid][:shown]
            base = color_for[pid]
            n_seg = max(len(pts) - 1, 1)
            w = np.linspace(0.25, 1.0, n_seg)[:, None]
            seg_colors = (1 - w) * white + w * np.array(to_rgb(base))
            if len(pts) >= 2:
                segs = np.stack([pts[:-1], pts[1:]], axis=1)
            else:
                segs = np.zeros((0, 2, 2))
            coll = collections[pid]
            coll.set_segments(segs)
            coll.set_color(seg_colors)
            drawn.append(coll)

            # Final frame: attach one end arrow per patient.
            if frame == last_frame and pid not in arrows and len(pts) >= 2:
                arrows[pid] = ax.annotate(
                    '', xy=pts[-1], xytext=pts[-2],
                    arrowprops=dict(arrowstyle='->', color=base, lw=1.4,
                                    shrinkA=0, shrinkB=0),
                    zorder=6,
                )
        return drawn

    anim = animation.FuncAnimation(fig, update, frames=n_frames, interval=1000 // fps,
                                   blit=False, repeat=False)
    gif_path = f'{output_base}.gif'
    print(f'Writing GIF: {gif_path}  ({n_frames} frames @ {fps} fps)')
    anim.save(gif_path, writer=animation.PillowWriter(fps=fps))
    plt.close(fig)
    print(f'Saved: {gif_path}')


def plot_trajectories(traj_by_patient, labels_by_patient, axis_labels, output_base):
    """Save a static trajectory figure to ``<output_base>.png`` and ``.pdf``.

    Each patient is drawn with ``gradient_line`` (light->dark along time), a
    black start dot, and an end arrow in its base colour. AD patients
    (label 1) get shades of Reds; healthy patients (label 0) get shades of
    Blues. The figure is always closed afterwards.

    Args:
        traj_by_patient: dict pid -> (T, 2) projected trajectory.
        labels_by_patient: dict pid -> 0/1.
        axis_labels: (xlabel, ylabel) strings for the projection.
        output_base: Output path without extension; its directory must exist.
    """
    from matplotlib.lines import Line2D

    pids = list(traj_by_patient.keys())
    ad_pids = [p for p in pids if labels_by_patient[p] == 1]
    hc_pids = [p for p in pids if labels_by_patient[p] == 0]

    # Per-patient base colors from Reds/Blues, mid-to-dark range so the light
    # start of each gradient stays visible.
    reds = plt.cm.Reds(np.linspace(0.55, 0.95, max(len(ad_pids), 1)))
    blues = plt.cm.Blues(np.linspace(0.55, 0.95, max(len(hc_pids), 1)))

    png = f'{output_base}.png'
    pdf = f'{output_base}.pdf'
    fig, ax = plt.subplots(figsize=(8, 7), dpi=150)
    try:
        for pid_list, palette in ((ad_pids, reds), (hc_pids, blues)):
            for pid, base in zip(pid_list, palette):
                xy = traj_by_patient[pid]
                gradient_line(ax, xy, base, lw=1.4)
                # Start dot
                ax.scatter(xy[0, 0], xy[0, 1], s=18, c='black', zorder=5,
                           edgecolors='none')
                # End arrow along the last short segment, in the base color
                if len(xy) >= 2:
                    ax.annotate(
                        '',
                        xy=xy[-1], xytext=xy[-2],
                        arrowprops=dict(arrowstyle='->', color=base, lw=1.4,
                                        shrinkA=0, shrinkB=0),
                        zorder=6,
                    )

        ax.set_xlabel(axis_labels[0])
        ax.set_ylabel(axis_labels[1])
        ax.set_title('EEG state-space trajectories — AD vs Healthy\n'
                     'Per-patch hidden state from causal encoder, full recording')

        # Tidy: equal aspect, no top/right spines, no grid
        ax.set_aspect('equal', adjustable='datalim')
        for s in ('top', 'right'):
            ax.spines[s].set_visible(False)
        ax.tick_params(direction='out', length=4)

        # Minimal legend
        legend_handles = [
            Line2D([0], [0], color=plt.cm.Reds(0.85), lw=2,
                   label=f'AD  (n={len(ad_pids)})'),
            Line2D([0], [0], color=plt.cm.Blues(0.85), lw=2,
                   label=f'Healthy  (n={len(hc_pids)})'),
            Line2D([0], [0], marker='o', color='black', lw=0, markersize=6,
                   label='start'),
            Line2D([0], [0], marker=r'$\rightarrow$', color='gray', lw=0,
                   markersize=12, label='end'),
        ]
        ax.legend(handles=legend_handles, loc='best', frameon=False, fontsize=9)

        fig.tight_layout()
        fig.savefig(png, dpi=300, bbox_inches='tight')
        fig.savefig(pdf, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is explicitly closed;
        # release it even if drawing or saving fails.
        plt.close(fig)
    print(f'Saved: {png}')
    print(f'Saved: {pdf}')


def _parse_args():
    """Build the command-line parser, parse ``sys.argv`` and validate option combinations."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint',
        default='checkpoints/S_variable64_v6_r3_20260329_160409/best_model.pt',
    )
    parser.add_argument('--task', default='alzheimer_adeeg')
    parser.add_argument('--output', default='eval/trajectory_ad_vs_healthy')
    parser.add_argument('--smooth', type=int, default=6,
                        help='Centered moving-average window applied to projected 2D points (post-projection)')
    parser.add_argument('--embed_smooth', type=int, default=1,
                        help='Centered moving-average window applied to the 512-D embeddings BEFORE projection')
    parser.add_argument('--embed_subsample', type=int, default=1,
                        help='Keep every Nth patch from each recording AFTER embed-smoothing, BEFORE projection')
    parser.add_argument('--warmup_trim', type=int, default=30,
                        help='Drop first N patches per recording (causal warm-up)')
    parser.add_argument('--center', choices=['none', 'per_patient'], default='none',
                        help='Subtract per-patient mean from each row before projection. '
                             'Removes between-recording offset, keeps only within-recording dynamics.')
    parser.add_argument('--method', choices=['pca', 'pls', 'umap', 'jpca'], default='pca',
                        help='pca = unsupervised top-2 components. '
                             'pls = supervised linear projection (fit on per-patient mean embeddings). '
                             'umap = supervised UMAP fit on per-patch embeddings with per-patch labels. '
                             'jpca = Churchland skew-symmetric rotational dynamics fit.')
    parser.add_argument('--embedding', choices=['hidden', 'features'], default='hidden',
                        help='hidden = final-layer per-patch hidden states (D=512). '
                             'features = model.feature_head output (D=576: 64 channels x 9 features).')
    parser.add_argument('--jpca_pca_dim', type=int, default=6,
                        help='jPCA only: PCA-reduce to this many dims before fitting skew-symmetric M.')
    parser.add_argument('--umap_subsample', type=int, default=5,
                        help='UMAP only: keep every Nth patch when fitting (memory/speed).')
    parser.add_argument('--umap_neighbors', type=int, default=30,
                        help='UMAP only: n_neighbors hyperparameter.')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--classifier_pkl', default=None,
                        help='Path to pickled classifier dict (e.g. server/classifiers/alzheimer_adeeg.pkl). '
                             'If set, predict on per-patient mean embeddings and filter.')
    parser.add_argument('--keep', choices=['all', 'correct', 'confident_correct'], default='all',
                        help="Filter patients by classifier prediction. 'correct' = pred matches label. "
                             "'confident_correct' = correct AND prob>=0.7 for true class.")
    parser.add_argument('--include_train', action='store_true',
                        help='Include train split patients (in addition to test).')
    parser.add_argument('--max_patients', type=int, default=None,
                        help='Cap total patients (split equally between classes).')
    parser.add_argument('--animate', action='store_true',
                        help='Save an animated GIF instead of static PNG/PDF.')
    parser.add_argument('--anim_frames', type=int, default=120,
                        help='Number of frames in the animation.')
    parser.add_argument('--anim_interp', type=int, default=600,
                        help='Per-patient cubic-spline interpolation length (smoother curves with more points).')
    parser.add_argument('--anim_fps', type=int, default=24, help='GIF frames per second.')
    args = parser.parse_args()

    if args.method == 'pls' and args.center == 'per_patient':
        # PLS is fit on per-patient mean embeddings, which per-patient centring
        # forces to ~0, so the fit would only see numerical noise.
        parser.error('--method pls fits on per-patient mean embeddings, which are all ~0 '
                     'after --center per_patient; use --center none with pls.')
    return args


def _embed_patients(model, items, args):
    """Embed each patient, skip too-short recordings, and trim the causal warm-up.

    Returns (per_patient, labels): dicts pid -> (T, D) array and pid -> 0/1.
    """
    per_patient, labels = {}, {}
    for h5_path, label, dataset, pid in tqdm(items, desc='Embedding'):
        emb = embed_full_recording_per_patch(model, h5_path, dataset, args.device,
                                              embedding=args.embedding)
        if emb is None or emb.shape[0] <= args.warmup_trim + 2:
            print(f'  SKIP {pid}: too short')
            continue
        per_patient[pid] = emb[args.warmup_trim:]
        labels[pid] = label
    return per_patient, labels


def _filter_with_classifier(per_patient, labels, classifier_pkl, keep_mode):
    """Predict P(AD) from per-patient mean embeddings and keep patients per ``keep_mode``.

    keep_mode options:
    - 'all': keep everyone; predictions are only printed.
    - 'correct': keep patients whose thresholded prediction (P(AD) >= 0.5)
      matches their label.
    - 'confident_correct': additionally require the true-class probability
      to be >= 0.7.

    Returns the filtered (per_patient, labels) dicts.
    """
    import pickle

    # SECURITY: pickle.load can execute arbitrary code stored in the file;
    # only pass classifier files from a trusted source.
    with open(classifier_pkl, 'rb') as f:
        clf_data = pickle.load(f)
    clf = clf_data['model']
    method = clf_data.get('method', 'xgboost')
    auc = clf_data.get('auc')
    # 'auc' is optional in the pickle; formatting None with :.3f would raise.
    auc_str = f'{auc:.3f}' if auc is not None else 'n/a'
    print(f"\nClassifier: {method} (saved AUC={auc_str})")

    pids = list(per_patient)
    means = np.stack([per_patient[pid].mean(axis=0) for pid in pids])
    means_in = clf_data['pca'].transform(means) if 'pca' in clf_data else means
    probs = clf.predict_proba(means_in)[:, 1]  # P(AD)

    print('\nPer-patient classifier predictions:')
    print(f'  {"patient":<35} {"true":<5} {"P(AD)":<7} {"correct":<8} {"keep"}')
    keep_pids = []
    for pid, prob in zip(pids, probs):
        true = labels[pid]
        correct = int(prob >= 0.5) == true
        true_class_prob = prob if true == 1 else 1 - prob
        if keep_mode == 'all':
            keep = True
        elif keep_mode == 'correct':
            keep = correct
        else:  # confident_correct
            keep = correct and true_class_prob >= 0.7
        tag = 'AD' if true == 1 else 'HC'
        mark = ' OK ' if correct else 'MISS'
        kp = ' YES' if keep else '  no'
        print(f'  {tag} {pid:<32} {true:<5} {prob:.3f}   {mark}     {kp}')
        if keep:
            keep_pids.append(pid)

    per_patient = {pid: per_patient[pid] for pid in keep_pids}
    labels = {pid: labels[pid] for pid in keep_pids}
    n_pos = sum(1 for v in labels.values() if v == 1)
    n_neg = len(labels) - n_pos
    print(f'\n  Kept {len(labels)} patients ({n_pos} AD / {n_neg} HC)')
    return per_patient, labels


def _preprocess(per_patient, args):
    """Optionally smooth, subsample and per-patient-centre each (T, D) embedding."""
    if args.embed_smooth > 1 or args.embed_subsample > 1:
        # smooth_columns is a no-op for windows <= 1; a step of 1 keeps every patch.
        step = max(1, args.embed_subsample)
        per_patient = {pid: smooth_columns(emb, args.embed_smooth)[::step]
                       for pid, emb in per_patient.items()}
        sample_pid = next(iter(per_patient))
        dim = per_patient[sample_pid].shape[1]
        print(f'  Smoothed {dim}-D embeddings (window={args.embed_smooth}), '
              f'subsampled every {args.embed_subsample}th patch. '
              f'Example: {sample_pid} now {per_patient[sample_pid].shape[0]} patches.')

    if args.center == 'per_patient':
        per_patient = {pid: emb - emb.mean(axis=0, keepdims=True)
                       for pid, emb in per_patient.items()}
        print('  Subtracted per-patient mean from each recording.')
    return per_patient


def _project_trajectories(per_patient, labels, args):
    """Fit the 2-D projection selected by ``args.method`` and project every patient.

    Returns (traj, axis_labels): dict pid -> (T, 2) trajectory, smoothed with
    ``args.smooth``, and the (xlabel, ylabel) pair for the plot.
    """
    if args.method == 'jpca':
        print(f'\nFitting jPCA (k={args.jpca_pca_dim}) on {args.embedding} embeddings')
        proj_traj, omega, r2 = fit_jpca(per_patient, k_pca=args.jpca_pca_dim)
        axis_labels = (f'jPC1 (rot, R^2={r2:.2f})', f'jPC2 (omega={omega:.3f}/step)')
        traj = {pid: smooth_columns(proj_traj[pid], args.smooth) for pid in proj_traj}
        return traj, axis_labels

    X_all = np.concatenate(list(per_patient.values()), axis=0)

    if args.method == 'pca':
        print(f'\nFitting PCA on pooled matrix: {X_all.shape}')
        pca = PCA(n_components=2)
        pca.fit(X_all)
        var = pca.explained_variance_ratio_
        print(f'  PC1 var={var[0]:.3%}  PC2 var={var[1]:.3%}  sum={var.sum():.3%}')
        axis_labels = (f'PC1  ({var[0]:.1%} var.)', f'PC2  ({var[1]:.1%} var.)')
        project = pca.transform

    elif args.method == 'umap':
        import umap
        # Per-patch labels broadcast from patient labels, in the same order as X_all.
        labels_per_patch = np.concatenate([
            np.full(emb.shape[0], labels[pid], dtype=np.int64)
            for pid, emb in per_patient.items()
        ])
        sub = max(1, args.umap_subsample)
        Xfit = X_all[::sub]
        yfit = labels_per_patch[::sub]
        print(f'\nFitting supervised UMAP on subsampled patches: '
              f'{Xfit.shape}  (every {sub}th patch), '
              f'n_neighbors={args.umap_neighbors}, target_metric=categorical')
        reducer = umap.UMAP(
            n_components=2,
            n_neighbors=args.umap_neighbors,
            min_dist=0.1,
            metric='euclidean',
            target_metric='categorical',
            random_state=42,
            verbose=True,
        )
        reducer.fit(Xfit, y=yfit)
        axis_labels = ('UMAP 1  (supervised on AD/HC)', 'UMAP 2')
        project = reducer.transform

    else:  # 'pls' — supervised on per-patient means
        pids_ord = list(per_patient)
        X_means = np.stack([per_patient[p].mean(axis=0) for p in pids_ord])  # (n_patients, D)
        y = np.array([labels[p] for p in pids_ord], dtype=float)             # (n_patients,)
        print(f'\nFitting PLS(2) on per-patient means: {X_means.shape}, '
              f'labels {int(y.sum())}+ / {int((1 - y).sum())}-')
        pls = PLSRegression(n_components=2, scale=False)
        pls.fit(X_means, y)
        # Sanity: how well does the 2-component PLS separate the means?
        means_2d = pls.transform(X_means)
        ad_mean_pc1 = means_2d[y == 1, 0].mean()
        hc_mean_pc1 = means_2d[y == 0, 0].mean()
        print(f'  Patient-mean separation along PLS1: '
              f'AD={ad_mean_pc1:+.3f}, HC={hc_mean_pc1:+.3f}, '
              f'gap={abs(ad_mean_pc1 - hc_mean_pc1):.3f}')
        axis_labels = ('PLS1  (AD vs HC discriminant)', 'PLS2  (orthogonal)')
        project = pls.transform

    traj = {pid: smooth_columns(project(emb), args.smooth)
            for pid, emb in per_patient.items()}
    return traj, axis_labels


def _save(traj, labels, axis_labels, args):
    """Write either the animated GIF or the static PNG/PDF to ``args.output``."""
    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    if args.animate:
        animate_trajectories(traj, labels, axis_labels, args.output,
                             n_frames=args.anim_frames,
                             interp_points=args.anim_interp,
                             fps=args.anim_fps)
    else:
        plot_trajectories(traj, labels, axis_labels, args.output)


def main():
    """CLI entry point: embed patients, optionally filter and preprocess, project, plot."""
    args = _parse_args()

    print(f'Loading model: {args.checkpoint}')
    model, _config = load_model(args.checkpoint, args.device)

    print(f'Loading test patients for task: {args.task} '
          f'(include_train={args.include_train}, max={args.max_patients})')
    items = load_test_patients(args.task, include_train=args.include_train,
                               max_patients=args.max_patients)
    n_pos = sum(1 for _, lab, _, _ in items if lab == 1)
    n_neg = sum(1 for _, lab, _, _ in items if lab == 0)
    print(f'  {len(items)} patients ({n_pos} AD / {n_neg} healthy)')

    per_patient, labels = _embed_patients(model, items, args)

    print('\nPer-patient patch counts (post warm-up trim):')
    for pid, emb in per_patient.items():
        cls = 'AD' if labels[pid] == 1 else 'HC'
        print(f'  {cls}  {pid}: {emb.shape[0]} patches')
    if not per_patient:
        sys.exit('No usable recordings (all missing or too short); nothing to plot.')

    if args.classifier_pkl:
        per_patient, labels = _filter_with_classifier(per_patient, labels,
                                                      args.classifier_pkl, args.keep)
        if not per_patient:
            sys.exit(f'No patients left after --keep {args.keep} filtering; nothing to plot.')

    per_patient = _preprocess(per_patient, args)
    traj, axis_labels = _project_trajectories(per_patient, labels, args)
    _save(traj, labels, axis_labels, args)


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
