#!/usr/bin/env python3
"""Publication-quality block diagram of the EEG transformer encoder.

Layout: three-panel figure
  - Left  panel: end-to-end flow (input → patch embed → prefix → transformer → heads)
  - Top right: zoom into one transformer block
  - Bot right: zoom into the variable-channel spatial patch embedder
                + three example electrode layouts (4ch / 19ch / 64ch)

No dataset/parameter statistics. Architecture only.
"""

from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch, FancyBboxPatch, Rectangle, Circle, Polygon
import matplotlib as mpl


# ─── PALETTE (muted, print-safe) ─────────────────────────────────────────────
# Fill colors are keyed to the diagram element they mark. C_LINE, C_TEXT and
# C_MUTE are the stroke, primary-text and secondary-annotation colors used as
# defaults by the drawing primitives below.
C_INPUT    = '#cfd8dc'   # neutral grey-blue – data
C_EMBED    = '#b5d4c2'   # sage green        – patch embedder
C_PREFIX_E = '#f3c7a0'   # warm tan          – experiment / eeg prefix tokens
C_PREFIX_C = '#e89e6b'   # darker tan        – condition (LLM-derived)
C_PREFIX_R = '#f8e1c8'   # pale tan          – register tokens
C_PATCH    = '#a8c2e6'   # light blue        – per-time-patch tokens
C_BLOCK    = '#c9b6dd'   # lavender          – transformer block
C_ATTN     = '#9d7fc4'   # darker lavender   – attention sub-block
C_FFN      = '#bea4d6'   # mid lavender      – FFN sub-block
C_NORM     = '#e6dcef'   # very pale lavender– RMSNorm
C_HEAD_R   = '#e6a4a4'   # dusty red         – next-patch head
C_HEAD_F   = '#dfa9bf'   # dusty pink        – feature head
C_LINE     = '#37474f'   # outlines, arrows and connector lines
C_TEXT     = '#1c2126'   # primary label text
C_MUTE     = '#5b6770'   # secondary / annotation text

mpl.rcParams['font.family'] = 'DejaVu Sans'
# fonttype 42 embeds text as TrueType (Type 42) fonts in PDF/PS output rather
# than the default Type 3, so labels remain real text in the vector file.
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype']  = 42


# ─── PRIMITIVES ──────────────────────────────────────────────────────────────
def rbox(ax, x, y, w, h, *, fc, ec=C_LINE, lw=1.0, r=0.04, z=2):
    """Add a rounded rectangle to *ax* and return the created patch.

    (x, y) is the lower-left corner and (w, h) the size, in data units.
    ``fc`` / ``ec`` are the face and edge colors, ``lw`` the edge width,
    ``r`` the corner rounding size and ``z`` the zorder.
    """
    # pad=0 keeps the drawn outline exactly on the requested rectangle.
    corner_style = f'round,pad=0,rounding_size={r}'
    box = FancyBboxPatch(
        (x, y), w, h,
        boxstyle=corner_style,
        facecolor=fc,
        edgecolor=ec,
        linewidth=lw,
        zorder=z,
    )
    ax.add_patch(box)
    return box


def label(ax, x, y, s, *, size=10, weight='normal', color=C_TEXT,
          ha='center', va='center', style='normal', z=4):
    """Place the string *s* at (x, y) on *ax*.

    Text is centered on the point by default; ``size`` is the font size,
    ``weight`` / ``style`` / ``color`` the font appearance, ``ha`` / ``va``
    the alignment and ``z`` the zorder.
    """
    text_kwargs = dict(
        ha=ha,
        va=va,
        fontsize=size,
        weight=weight,
        color=color,
        style=style,
        zorder=z,
    )
    ax.text(x, y, s, **text_kwargs)


def arr(ax, x1, y1, x2, y2, *, color=C_LINE, lw=1.4, ms=14, z=3,
        style='-|>'):
    """Draw an arrow on *ax* from (x1, y1) to (x2, y2).

    ``style`` is the arrowstyle string, ``ms`` the mutation scale (head
    size), ``lw`` the line width, ``color`` the arrow color and ``z`` the
    zorder. Both shrink factors are zero, so the arrow spans the full
    distance between the two points.
    """
    tail = (x1, y1)
    head = (x2, y2)
    arrow = FancyArrowPatch(
        tail, head,
        arrowstyle=style,
        mutation_scale=ms,
        color=color,
        linewidth=lw,
        shrinkA=0,
        shrinkB=0,
        zorder=z,
    )
    ax.add_patch(arrow)


# ─── PANEL A: main flow ──────────────────────────────────────────────────────
def draw_main_flow(ax):
    """Draw panel A: the encoder forward pass, laid out top to bottom.

    Stages are drawn in this order: input EEG, time patches, the spatial
    patch embedder, the prefix + patch token row with its brackets, the
    transformer stack (with a causal-mask inset), the per-patch hidden
    states, the two prediction heads with their losses, and finally the
    color legend strip.

    All coordinates are in the axes' own data units (x in 0-10, y in 0-24)
    with an equal aspect ratio; the axes frame is hidden.
    """
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 24)
    ax.set_aspect('equal')
    ax.axis('off')

    mid = 5.0                 # vertical centre line of the flow column
    col_w = 7.0               # width of the full-width stage boxes
    left = mid - col_w/2      # left edge of the full-width stage boxes

    def flow_step(y_from, y_to, note_y, note):
        """Downward arrow on the centre line with an italic side note."""
        arr(ax, mid, y_from, mid, y_to)
        label(ax, mid + 0.20, note_y, note,
              size=8.5, color=C_MUTE, ha='left', style='italic')

    # ── Panel title ──
    label(ax, mid, 23.4, 'A.  Encoder forward pass',
          size=13, weight='bold', ha='center')

    # ── 1. Input EEG ──
    rbox(ax, left, 21.7, col_w, 1.1, fc=C_INPUT)
    label(ax, mid, 22.45, 'Input EEG signal', size=11, weight='bold')
    label(ax, mid, 21.97,
          'channels  $\\times$  time     (variable channel count)',
          size=9, color=C_MUTE)
    flow_step(21.7, 20.95, 21.32, 'split time into patches')

    # ── 2. Patches (slightly narrower box than the other stages) ──
    patch_x, patch_y = left + 0.7, 19.85
    patch_w, patch_h = col_w - 1.4, 1.05
    rbox(ax, patch_x, patch_y, patch_w, patch_h, fc=C_INPUT)
    label(ax, mid, 20.65, 'Patches', size=10.5, weight='bold')
    label(ax, mid, 20.20,
          'patches  $\\times$  channels  $\\times$  patch length',
          size=9, color=C_MUTE)
    flow_step(19.85, 19.10, 19.47,
              'spatial attention across electrodes (per patch)')

    # ── 3. Patch embedder ──
    rbox(ax, left, 17.80, col_w, 1.30, fc=C_EMBED)
    label(ax, mid, 18.65,
          'Variable-Channel Spatial Patch Embedder',
          size=11, weight='bold')
    label(ax, mid, 18.25,
          'electrode tokens  →  spatial self-attention  →  masked-mean pool',
          size=9, color=C_MUTE)
    label(ax, mid, 17.97,
          '2D-RoPE on scalp coordinates;  inactive channels masked',
          size=8.5, color=C_MUTE, style='italic')
    flow_step(17.80, 17.05, 17.42, 'one patch token per time window')

    # ── 4. Prefix tokens ──
    label(ax, mid, 16.65, 'Prepend learned prefix tokens',
          size=11, weight='bold')

    seq_y, seq_h = 15.5, 0.9
    n_reg = 8
    n_total_shown = 3 + n_reg + 4      # 3 context + registers + 4 patch cells
    tw = (col_w - 0.4) / n_total_shown
    sx = left + 0.2

    tokens = [('exp', C_PREFIX_E), ('cond', C_PREFIX_C), ('eeg', C_PREFIX_E)]
    tokens += [(f'r{k}', C_PREFIX_R) for k in range(1, n_reg + 1)]
    tokens += [(t, C_PATCH) for t in ('p$_1$', 'p$_2$', '...', 'p$_N$')]

    for slot, (name, fc) in enumerate(tokens):
        x0 = sx + slot * tw
        # 0.02 inset on each side leaves a thin gap between neighbouring cells
        rbox(ax, x0 + 0.02, seq_y, tw - 0.04, seq_h, fc=fc, lw=0.8, r=0.025)
        label(ax, x0 + tw/2, seq_y + seq_h/2, name,
              size=7.8, weight='bold')

    # bracket annotations under the row: (first cell, one-past-last cell, text)
    by_top = seq_y - 0.08
    by_bot = seq_y - 0.32
    groups = (
        (0, 3, 'modality / context'),
        (3, 3 + n_reg, 'register tokens'),
        (3 + n_reg, 3 + n_reg + 4, 'patch tokens'),
    )
    for first, last, text in groups:
        x0 = sx + first*tw + 0.02
        x1 = sx + last*tw - 0.02
        ax.plot([x0, x0, x1, x1],
                [by_top, by_bot, by_bot, by_top],
                color=C_LINE, lw=0.9)
        label(ax, (x0+x1)/2, by_bot - 0.25, text,
              size=8.2, color=C_MUTE)

    # condition annotation — centered under the prefix row
    label(ax, mid, by_bot - 0.55,
          'condition cell = mean-pool of frozen LLM encoding of (dataset, recording state)',
          size=7.8, color=C_MUTE, ha='center', va='top', style='italic')

    arr(ax, mid, 14.55, mid, 13.85)

    # ── 5. Transformer stack ──
    stack_y = 10.7
    stack_h = 3.1
    stack_top = stack_y + stack_h
    rbox(ax, left, stack_y, col_w, stack_h, fc=C_BLOCK, lw=1.2)
    label(ax, mid, stack_top - 0.40,
          'Transformer Encoder Stack',
          size=11.5, weight='bold')
    label(ax, mid, stack_top - 0.80,
          'N identical pre-norm blocks   •   causal self-attention',
          size=9, color=C_MUTE, style='italic')

    # a few stacked "block" bars topped by vertical dots to suggest N layers
    layer_w = col_w - 2.4
    layer_h = 0.32
    layer_x = left + 0.6
    base_y = stack_y + 0.40
    pitch = layer_h + 0.06
    n_layers_drawn = 4
    layer_cx = layer_x + layer_w/2
    for k in range(n_layers_drawn):
        y = base_y + k * pitch
        rbox(ax, layer_x, y, layer_w, layer_h, fc=C_NORM, lw=0.7, r=0.02)
        label(ax, layer_cx, y + layer_h/2, 'block', size=8, color=C_TEXT)
    label(ax, layer_cx, base_y + n_layers_drawn * pitch + 0.08,
          '$\\vdots$', size=12, color=C_TEXT)

    # causal mask inset on the right
    cm_s = 0.95
    cm_x = left + col_w - cm_s - 0.30
    cm_y = stack_y + 0.50
    _draw_causal_mask(ax, cm_x, cm_y, cm_s)
    label(ax, cm_x + cm_s/2, cm_y - 0.22, 'causal mask',
          size=7.8, color=C_MUTE)

    flow_step(stack_y, stack_y - 0.65, stack_y - 0.32, 'strip prefix tokens')

    # ── 6. Per-patch hidden states ──
    rbox(ax, left + 1.3, 8.85, col_w - 2.6, 1.0, fc=C_PATCH)
    label(ax, mid, 9.50, 'Per-patch hidden states',
          size=10.5, weight='bold')
    label(ax, mid, 9.10, 'one vector per patch token',
          size=8.5, color=C_MUTE)

    # branch arrows to two heads
    arr(ax, left + 2.3, 8.85, left + 1.3, 7.55)
    arr(ax, left + col_w - 2.3, 8.85, left + col_w - 1.3, 7.55)

    # ── 7. Two prediction heads (same internal layout, side by side) ──
    h_y, h_h = 5.85, 1.70
    half_w = col_w/2 - 0.3
    fhx = left + col_w/2 + 0.3
    heads = (
        (left, C_HEAD_R, 'Next-patch head', 'predicts the next raw patch'),
        (fhx,  C_HEAD_F, 'Feature head',    'predicts spectral features'),
    )
    for hx, fc, title, caption in heads:
        hcx = hx + half_w/2
        rbox(ax, hx, h_y, half_w, h_h, fc=fc)
        label(ax, hcx, h_y + h_h - 0.35, title, size=10.5, weight='bold')
        label(ax, hcx, h_y + h_h - 0.85, 'SwiGLU residual block',
              size=8.5, color=C_MUTE)
        label(ax, hcx, h_y + h_h - 1.15, '→ Linear projection',
              size=8.5, color=C_MUTE)
        label(ax, hcx, h_y + 0.27, caption,
              size=8.2, color=C_TEXT, style='italic')

    # losses, one under each head
    losses = (
        (left, 'mean-squared error'),
        (fhx,  'smooth $\\ell_1$ error'),
    )
    for hx, loss_name in losses:
        label(ax, hx + half_w/2, h_y - 0.40, loss_name,
              size=9.5, color=C_TEXT, weight='bold')

    label(ax, mid, h_y - 0.85,
          'self-supervised pre-training objective  (training only)',
          size=8.8, color=C_MUTE, style='italic')

    # legend strip at bottom
    _draw_legend(ax, 0.0, 2.55, 10.0, 1.55)


def _draw_causal_mask(ax, x, y, s):
    """Draw a small 6x6 lower-triangular mask icon.

    (x, y) is the lower-left corner and *s* the side length of the icon.
    Row 0 is drawn at the top; a cell is filled with the attention color
    when its column index does not exceed its row index, otherwise with a
    near-white tint. A thin frame is drawn around the whole grid.
    """
    n = 6
    cell = s / n
    side = cell*0.92          # slightly undersized so cells read as a grid
    for row in range(n):
        yc = y + (n - 1 - row) * cell
        for col in range(n):
            allowed = col <= row
            tile = Rectangle((x + col * cell, yc), side, side,
                             facecolor=C_ATTN if allowed else '#f5f0fa',
                             edgecolor='none', zorder=4)
            ax.add_patch(tile)
    frame = Rectangle((x, y), s, s, facecolor='none',
                      edgecolor=C_LINE, lw=0.6, zorder=5)
    ax.add_patch(frame)


def _draw_legend(ax, x, y, w, h):
    """Draw the color legend strip inside the box (x, y, w, h).

    Entries are laid out row-major on a grid of 4 rows by 2 columns below a
    'Legend' heading; each entry is a small color swatch followed by its
    left-aligned name.
    """
    rbox(ax, x, y, w, h, fc='#fafafa', ec='#cccccc', lw=0.7, r=0.04)
    label(ax, x + w/2, y + h - 0.18, 'Legend',
          size=9, weight='bold', color=C_TEXT)

    entries = [
        (C_INPUT,    'data tensor'),
        (C_EMBED,    'patch embedder'),
        (C_PATCH,    'patch / hidden token'),
        (C_BLOCK,    'transformer stack'),
        (C_PREFIX_E, 'prefix token (learned)'),
        (C_PREFIX_C, 'prefix token (LLM-derived)'),
        (C_HEAD_R,   'next-patch head'),
        (C_HEAD_F,   'feature head'),
    ]
    cols = 2
    rows = 4

    # The grid sits below the heading: 0.40 is reserved at the top.
    inner_top = y + h - 0.40
    inner_h = h - 0.50
    cell_w = w / cols
    cell_h = inner_h / rows

    for idx, (swatch_color, name) in enumerate(entries):
        row, column = divmod(idx, cols)
        swatch_x = x + column * cell_w + 0.30
        # bottom edge of the swatch, vertically centered in its grid cell
        swatch_y = inner_top - (row + 1) * cell_h + cell_h/2 - 0.13
        swatch = Rectangle((swatch_x, swatch_y), 0.28, 0.26,
                           facecolor=swatch_color, edgecolor=C_LINE,
                           lw=0.6, zorder=4)
        ax.add_patch(swatch)
        ax.text(swatch_x + 0.40, swatch_y + 0.13, name,
                fontsize=8.7, va='center', ha='left', color=C_TEXT, zorder=5)


# ─── PANEL B: transformer block zoom ─────────────────────────────────────────
def draw_block_zoom(ax):
    """Draw panel B: a zoomed-in view of one pre-norm transformer block.

    From top to bottom the panel shows: input, RMSNorm, grouped-query
    attention, a residual add, RMSNorm, the SwiGLU feed-forward, a second
    residual add, and the output. Each residual skip is a dashed connector
    that leaves the main path at a junction dot, runs down a lane to the
    right of the boxes, and enters the matching add node from the side.

    Coordinates are in the axes' own data units (x in 0-10, y in 0-13) with
    an equal aspect ratio; the axes frame is hidden.
    """
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 13)
    ax.set_aspect('equal')
    ax.axis('off')

    label(ax, 5.0, 12.45,
          'B.  One transformer block  (pre-norm)',
          size=13, weight='bold')

    cx = 5.5              # centre line of the main path
    w  = 4.4              # width of the attention / FFN boxes
    bx = cx - w/2
    skip_x = bx + w + 0.55   # lane to the right of the boxes for both skips

    # Vertical positions of the two add nodes. They are defined up front so
    # each dashed skip connector can end exactly where the arrow into its
    # add node begins.
    plus1_y = 8.55
    plus2_y = 5.55

    # input
    in_y = 11.35
    rbox(ax, bx + 1.2, in_y, 2.0, 0.55, fc=C_PATCH)
    label(ax, cx, in_y + 0.275, 'input  $x$', size=10, weight='bold')

    # split for skip 1: junction dot, then dashed lane down to the add node
    split1_y = 11.00
    arr(ax, cx, in_y, cx, split1_y, ms=10)
    ax.add_patch(Circle((cx, split1_y - 0.07), 0.07,
                        facecolor=C_LINE, zorder=5))
    ax.plot([cx, skip_x, skip_x],
            [split1_y - 0.07, split1_y - 0.07, plus1_y],
            color=C_LINE, lw=1.0, ls='--', zorder=2)

    # RMSNorm 1
    rbox(ax, bx + 1.0, 10.10, 2.4, 0.55, fc=C_NORM)
    label(ax, cx, 10.375, 'RMSNorm', size=9.5)
    arr(ax, cx, split1_y - 0.15, cx, 10.65, ms=10)

    # GQA box (title and detail line are drawn inside the box)
    gqa_y, gqa_h = 8.95, 0.80
    rbox(ax, bx, gqa_y, w, gqa_h, fc=C_ATTN)
    label(ax, cx, gqa_y + gqa_h - 0.25,
          'Grouped-Query Attention',
          size=10.5, weight='bold', color='white')
    label(ax, cx, gqa_y + 0.22,
          'QK-norm  •  RoPE  •  causal mask',
          size=8.0, color='white', style='italic')
    # single arrow from the norm box straight to the top edge of the GQA box
    arr(ax, cx, 10.10, cx, gqa_y + gqa_h, ms=10)

    # plus 1
    arr(ax, cx, gqa_y, cx, plus1_y + 0.20, ms=10)
    _draw_plus(ax, cx, plus1_y, r=0.20)
    arr(ax, skip_x, plus1_y, cx + 0.20, plus1_y, ms=10)

    # split for skip 2
    split2_y = 8.10
    arr(ax, cx, plus1_y - 0.20, cx, split2_y, ms=10)
    ax.add_patch(Circle((cx, split2_y - 0.07), 0.07,
                        facecolor=C_LINE, zorder=5))
    ax.plot([cx, skip_x, skip_x],
            [split2_y - 0.07, split2_y - 0.07, plus2_y],
            color=C_LINE, lw=1.0, ls='--', zorder=2)

    # RMSNorm 2
    rbox(ax, bx + 1.0, 7.20, 2.4, 0.55, fc=C_NORM)
    label(ax, cx, 7.475, 'RMSNorm', size=9.5)
    arr(ax, cx, split2_y - 0.15, cx, 7.75, ms=10)

    # SwiGLU FFN
    ffn_y, ffn_h = 6.05, 0.80
    rbox(ax, bx, ffn_y, w, ffn_h, fc=C_FFN)
    label(ax, cx, ffn_y + ffn_h - 0.25,
          'SwiGLU Feed-Forward',
          size=10.5, weight='bold', color='white')
    label(ax, cx, ffn_y + 0.22,
          'gated linear unit  •  SiLU activation',
          size=8.0, color='white', style='italic')
    # single arrow from the norm box straight to the top edge of the FFN box
    arr(ax, cx, 7.20, cx, ffn_y + ffn_h, ms=10)

    # plus 2
    arr(ax, cx, ffn_y, cx, plus2_y + 0.20, ms=10)
    _draw_plus(ax, cx, plus2_y, r=0.20)
    arr(ax, skip_x, plus2_y, cx + 0.20, plus2_y, ms=10)

    # output
    out_y = 4.65
    arr(ax, cx, plus2_y - 0.20, cx, out_y + 0.55, ms=10)
    rbox(ax, bx + 1.2, out_y, 2.0, 0.55, fc=C_PATCH)
    label(ax, cx, out_y + 0.275, 'output', size=10, weight='bold')

    # repeated × N annotation on far left: an arrow spanning input to output
    ax.annotate('', xy=(0.65, out_y + 0.28),
                xytext=(0.65, in_y + 0.28),
                arrowprops=dict(arrowstyle='-|>', color=C_MUTE,
                                lw=1.0, connectionstyle='arc3'),
                zorder=3)
    label(ax, 0.40, (out_y + in_y) / 2,
          'repeated\n× N layers',
          size=9, color=C_MUTE, ha='right', style='italic')

    # legend keys for skip / data flow
    label(ax, 5.0, 4.10,
          'solid arrow  =  forward path        dashed arrow  =  residual skip',
          size=8.5, color=C_MUTE, style='italic')


def _draw_plus(ax, x, y, r=0.18):
    """Draw an add node: a white circle of radius *r* at (x, y) with a '+'.

    The two strokes of the plus sign each extend 0.55 * r from the centre.
    """
    ring = Circle((x, y), r, facecolor='white',
                  edgecolor=C_LINE, lw=1.0, zorder=5)
    ax.add_patch(ring)

    arm = r*0.55
    stroke = dict(color=C_LINE, lw=1.4, zorder=6)
    ax.plot([x - arm, x + arm], [y, y], **stroke)      # horizontal stroke
    ax.plot([x, x], [y - arm, y + arm], **stroke)      # vertical stroke


# ─── PANEL C: patch embedder zoom ───────────────────────────────────────────
def draw_embedder_zoom(ax):
    """Draw panel C: the inside of the spatial patch embedder.

    Left half: one detailed 19-electrode scalp with faint attention lines,
    an arrow labelled 'masked-mean pool', and the resulting patch token.
    Right half (past a light divider): three small scalps showing the 4-,
    19- and 64-electrode example layouts.

    The axes deliberately do not use an equal aspect so the panel fills its
    grid cell. Round shapes are therefore drawn as ellipses whose y extent
    is stretched by 1 / px_aspect, where px_aspect is the on-screen length
    of one y data unit divided by that of one x data unit. The ratio is
    measured from the axes' current extent and limits, so it follows the
    figure size and grid layout in effect when this function is called; the
    axes should not be repositioned afterwards.
    """
    # NOTE: do NOT set_aspect('equal') here — we want the subplot to fill its
    # allotted width. Scalp drawing manually computes its own aspect.
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.axis('off')

    # Display length of one data unit along each axis. Only the ratio is
    # used, so the result is independent of the figure / savefig dpi.
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    extent = ax.get_window_extent()
    unit_x = extent.width / (x_hi - x_lo)
    unit_y = extent.height / (y_hi - y_lo)
    if unit_x > 0 and unit_y > 0:
        px_aspect = unit_y / unit_x
    else:
        # Degenerate extent (e.g. zero-sized axes): fall back to the value
        # hand-measured for the 18x12 in three-panel layout (≈ 0.535).
        px_aspect = 0.46 / 0.86

    label(ax, 5.0, 9.50,
          'C.  Inside the spatial patch embedder',
          size=13, weight='bold')
    label(ax, 5.0, 9.05,
          'one time-patch processed independently;  attention runs across electrodes',
          size=9, color=C_MUTE, style='italic')

    # divider between left (process) and right (variable channel sets)
    ax.plot([6.50, 6.50], [0.4, 8.40], color='#cccccc', lw=0.8, zorder=1)

    # ── LEFT half: scalp → masked-mean → patch token ──
    sx, sy, sr = 1.55, 5.20, 1.30
    _draw_scalp_aniso(ax, sx, sy, sr, px_aspect, layout='19', detailed=True)
    # captions sit below the scalp; sr/px_aspect is its radius in y units
    label(ax, sx, sy - sr/px_aspect - 0.45,
          'electrodes  →  query / key / value',
          size=9, color=C_TEXT, ha='center', weight='bold')
    label(ax, sx, sy - sr/px_aspect - 0.85,
          '2D-RoPE injects scalp coordinates',
          size=8.2, color=C_MUTE, ha='center', style='italic')

    # arrow → patch token
    arr_x0 = sx + sr + 0.15
    arr_x1 = 4.50
    ax.annotate('', xy=(arr_x1, sy), xytext=(arr_x0, sy),
                arrowprops=dict(arrowstyle='-|>', color=C_LINE, lw=1.6,
                                shrinkA=0, shrinkB=2),
                zorder=3)
    label(ax, (arr_x0 + arr_x1)/2, sy + 0.50,
          'masked-mean pool',
          size=9, color=C_TEXT, weight='bold', ha='center')
    label(ax, (arr_x0 + arr_x1)/2, sy + 0.20,
          '(inactive channels excluded)',
          size=7.8, color=C_MUTE, ha='center', style='italic')

    # patch token (output)
    rbox(ax, 4.65, sy - 0.55, 1.50, 1.10, fc=C_PATCH)
    label(ax, 4.65 + 1.50/2, sy + 0.18, 'patch', size=10, weight='bold')
    label(ax, 4.65 + 1.50/2, sy - 0.18, 'token', size=10, weight='bold')

    # ── RIGHT half: three example layouts stacked vertically ──
    label(ax, 8.30, 8.10, 'Same module handles',
          size=10, weight='bold', ha='center')
    label(ax, 8.30, 7.70, 'any electrode set',
          size=10, weight='bold', ha='center')

    layouts = [
        ('few electrodes',     '4'),
        ('standard 10-20 set', '19'),
        ('high-density set',   '64'),
    ]
    # 3 small scalps stacked vertically with sufficient spacing
    small_r = 0.42
    small_ry = small_r / px_aspect         # scalp radius in y data units
    label_gap = small_ry + 0.30            # centre → caption below the scalp
    spacing_y = 2 * small_ry + 0.85        # centre-to-centre distance
    by_top = 7.10
    cx_l = 8.30
    for i, (name, layout) in enumerate(layouts):
        cy_l = by_top - i * spacing_y
        _draw_scalp_aniso(ax, cx_l, cy_l, small_r, px_aspect,
                          layout=layout, detailed=False)
        label(ax, cx_l, cy_l - label_gap, name,
              size=8.5, color=C_TEXT, ha='center')


def _draw_scalp_aniso(ax, cx, cy, r, px_aspect, layout='19', detailed=True):
    """Draw a top-down scalp on an axis that does NOT use an equal aspect.

    Every round shape is drawn as an Ellipse whose y extent is divided by
    *px_aspect*, so it looks circular on screen. (cx, cy) is the centre and
    *r* the radius in x data units. *layout* selects the electrode set
    ('4', '19' or '64'; any other value draws a bare head). With
    *detailed* true and the '19' layout, faint lines are drawn between
    every pair of a fixed subset of electrodes.
    """
    from matplotlib.patches import Ellipse, Polygon

    ry = r / px_aspect   # radius expressed in y data units
    skin = '#f6f1e7'

    # head
    ax.add_patch(Ellipse((cx, cy), 2*r, 2*ry,
                         facecolor=skin, edgecolor=C_LINE, lw=1.2, zorder=2))
    # nose: small triangle on top of the head
    nose_outline = [(cx - 0.12*r, cy + ry),
                    (cx + 0.12*r, cy + ry),
                    (cx, cy + ry * 1.18)]
    ax.add_patch(Polygon(nose_outline, facecolor=skin,
                         edgecolor=C_LINE, lw=1.0, zorder=2))
    # ears: left then right, tucked behind the head outline
    for ear_x in (cx - r, cx + r):
        ax.add_patch(Ellipse((ear_x, cy), 2*r*0.10, 2*ry*0.18,
                             facecolor=skin, edgecolor=C_LINE,
                             lw=1.0, zorder=1))

    # electrode positions normalized (-1, 1)
    positions = []
    if layout == '4':
        positions = [(-0.50, 0.55), (0.50, 0.55),
                     (-0.50, -0.55), (0.50, -0.55)]
    elif layout == '19':
        positions = [
            (0, 0.85), (0, -0.85), (-0.85, 0), (0.85, 0),
            (-0.45, 0.65), (0.45, 0.65),
            (-0.45, -0.65), (0.45, -0.65),
            (-0.7, 0.35), (0.7, 0.35), (-0.7, -0.35), (0.7, -0.35),
            (0, 0.42), (0, 0), (0, -0.42),
            (-0.42, 0.0), (0.42, 0.0),
            (-0.42, 0.30), (0.42, 0.30),
        ]
    elif layout == '64':
        # concentric rings, each rotated by a seeded (reproducible) phase
        rng = np.random.default_rng(7)
        rings = [(0.20, 4), (0.42, 10), (0.62, 14), (0.80, 18), (0.93, 18)]
        for ring_r, n in rings:
            phase = rng.uniform(0, 2*np.pi)
            angles = [2*np.pi * k / n + phase for k in range(n)]
            positions.extend((ring_r * np.cos(a), ring_r * np.sin(a))
                             for a in angles)
        positions = positions[:64]

    # electrode dots
    el_r_data = 0.075 * r if detailed else 0.085 * r
    dot_w = 2*el_r_data
    dot_h = 2*el_r_data/px_aspect
    for ex, ey in positions:
        ax.add_patch(Ellipse((cx + ex*r, cy + ey*ry), dot_w, dot_h,
                             facecolor='#263238', edgecolor='white',
                             lw=0.4, zorder=4))

    # detailed: faint attention lines between every pair of chosen electrodes
    if detailed and layout == '19':
        chosen = [0, 4, 8, 13, 11, 7, 1]
        anchors = [(cx + positions[k][0]*r, cy + positions[k][1]*ry)
                   for k in chosen]
        for a, (x1, y1) in enumerate(anchors):
            for x2, y2 in anchors[a + 1:]:
                ax.plot([x1, x2], [y1, y2],
                        color=C_ATTN, lw=0.7, alpha=0.55, zorder=3)


def _draw_scalp(ax, cx, cy, r, layout='19', simple=False):
    """Draw a top-down scalp with electrodes using true circles.

    Intended for an axis with an equal aspect ratio: the head, ears and
    electrodes are Circle patches. (cx, cy) is the centre and *r* the
    radius. *layout* selects the electrode set ('4', '19' or '64'; any
    other value draws a bare head). Unless *simple* is true, the '19'
    layout also gets faint lines between every pair of a fixed subset of
    electrodes.
    """
    skin = '#f6f1e7'

    # head outline
    ax.add_patch(Circle((cx, cy), r, facecolor=skin,
                        edgecolor=C_LINE, lw=1.2, zorder=2))
    # nose triangle
    nose_outline = [(cx - 0.12*r, cy + r),
                    (cx + 0.12*r, cy + r),
                    (cx, cy + r * 1.18)]
    ax.add_patch(Polygon(nose_outline, facecolor=skin,
                         edgecolor=C_LINE, lw=1.0, zorder=2))
    # ears: left then right, tucked behind the head outline
    for ear_x in (cx - r, cx + r):
        ax.add_patch(Circle((ear_x, cy), r*0.10, facecolor=skin,
                            edgecolor=C_LINE, lw=1.0, zorder=1))

    # electrode positions
    positions = []
    if layout == '4':
        positions = [(-0.50, 0.55), (0.50, 0.55),
                     (-0.50, -0.55), (0.50, -0.55)]
    elif layout == '19':
        positions = [
            (0, 0.85), (0, -0.85), (-0.85, 0), (0.85, 0),
            (-0.45, 0.65), (0.45, 0.65),
            (-0.45, -0.65), (0.45, -0.65),
            (-0.7, 0.35), (0.7, 0.35), (-0.7, -0.35), (0.7, -0.35),
            (0, 0.42), (0, 0), (0, -0.42),
            (-0.42, 0.0), (0.42, 0.0),
            (-0.42, 0.30), (0.42, 0.30),
        ]
    elif layout == '64':
        # concentric rings, each rotated by a seeded (reproducible) phase
        rng = np.random.default_rng(7)
        rings = [(0.20, 4), (0.42, 10), (0.62, 14), (0.80, 18), (0.93, 18)]
        for ring_r, n in rings:
            phase = rng.uniform(0, 2*np.pi)
            angles = [2*np.pi * k / n + phase for k in range(n)]
            positions.extend((ring_r * np.cos(a), ring_r * np.sin(a))
                             for a in angles)
        positions = positions[:64]

    # electrode dots
    el_r = 0.07 * r if simple else 0.075 * r
    el_color = '#263238'
    for ex, ey in positions:
        ax.add_patch(Circle((cx + ex*r, cy + ey*r), el_r,
                            facecolor=el_color, edgecolor='white',
                            lw=0.4, zorder=4))

    # detailed scalp: faint attention lines between every pair of chosen electrodes
    if not simple and layout == '19':
        chosen = [0, 4, 8, 13, 11, 7, 1]
        anchors = [(cx + positions[k][0]*r, cy + positions[k][1]*r)
                   for k in chosen]
        for a, (x1, y1) in enumerate(anchors):
            for x2, y2 in anchors[a + 1:]:
                ax.plot([x1, x2], [y1, y2],
                        color=C_ATTN, lw=0.7, alpha=0.55, zorder=3)


# ─── COMPOSE ─────────────────────────────────────────────────────────────────
def main():
    """Compose the three panels into one figure and save it as PNG and PDF.

    The figure uses a 2x2 grid: panel A spans the whole left column, panel B
    is the top-right cell and panel C the bottom-right cell. Both output
    files are written next to this script and their paths are printed.
    """
    fig = plt.figure(figsize=(18, 12), dpi=150)

    grid = fig.add_gridspec(
        nrows=2, ncols=2,
        width_ratios=[1.05, 1.0],
        height_ratios=[1.10, 0.78],
        wspace=0.05, hspace=0.06,
        left=0.01, right=0.99, top=0.95, bottom=0.02,
    )

    # Create all axes first, then fill them in panel order A, B, C.
    cells = (grid[:, 0], grid[0, 1], grid[1, 1])
    painters = (draw_main_flow, draw_block_zoom, draw_embedder_zoom)
    axes = [fig.add_subplot(cell) for cell in cells]
    for paint, panel_ax in zip(painters, axes):
        paint(panel_ax)

    fig.suptitle('EEG Transformer Encoder Architecture',
                 y=0.985, fontsize=17, weight='bold', color=C_TEXT)

    here = Path(__file__).parent
    png = here / 'architecture_v6.png'
    pdf = here / 'architecture_v6.pdf'
    shared = dict(bbox_inches='tight', facecolor='white')
    fig.savefig(png, dpi=300, **shared)
    fig.savefig(pdf, **shared)
    for written in (png, pdf):
        print(f'Saved: {written}')


if __name__ == '__main__':
    main()
