#!/usr/bin/env python3
"""Editable PPTX version of the EEG transformer encoder architecture diagram.

Mirrors the three-panel layout of eval/architecture_figure.py using native
PPTX shapes (rounded rectangles, connectors, text boxes) so every element is
editable in PowerPoint / Keynote / Google Slides.

Output: eval/architecture_v6.pptx
"""

from pathlib import Path

from pptx import Presentation
from pptx.dml.color import RGBColor
from pptx.enum.shapes import MSO_CONNECTOR, MSO_SHAPE
from pptx.enum.text import MSO_ANCHOR, PP_ALIGN
from pptx.util import Emu, Inches, Pt


# ─── PALETTE (matches eval/architecture_figure.py) ───────────────────────────
def rgb(hex_str):
    """Convert a hex color string such as '#cfd8dc' into an RGBColor.

    Leading '#' characters are stripped; the next three two-digit hex pairs
    are read as the red, green and blue components.
    """
    digits = hex_str.lstrip('#')
    red, green, blue = (int(digits[pos:pos + 2], 16) for pos in (0, 2, 4))
    return RGBColor(red, green, blue)


# Named diagram colors; the trailing comment on each line notes where it is used.
C_INPUT    = rgb('#cfd8dc')  # Input EEG / Patches boxes, block Input x / Output
C_EMBED    = rgb('#b5d4c2')  # spatial patch-embedder box, electrode dots
C_PREFIX_E = rgb('#f3c7a0')  # 'exp' and 'eeg' prefix tokens
C_PREFIX_C = rgb('#e89e6b')  # 'cond' prefix token
C_PREFIX_R = rgb('#f8e1c8')  # register tokens r1..r8
C_PATCH    = rgb('#a8c2e6')  # patch tokens, per-patch hidden states, patch-token box
C_BLOCK    = rgb('#c9b6dd')  # transformer-stack background
C_ATTN     = rgb('#9d7fc4')  # grouped-query attention box
C_FFN      = rgb('#bea4d6')  # SwiGLU feed-forward box
C_NORM     = rgb('#e6dcef')  # RMSNorm boxes and the mini 'block' pills
C_HEAD_R   = rgb('#e6a4a4')  # next-patch head
C_HEAD_F   = rgb('#dfa9bf')  # feature head
C_LINE     = rgb('#37474f')  # default outline and arrow color
C_TEXT     = rgb('#1c2126')  # default text color
C_MUTE     = rgb('#5b6770')  # muted captions and box body text
C_WHITE    = rgb('#ffffff')  # scalp outlines and nasion notch fill


# ─── DRAWING HELPERS ─────────────────────────────────────────────────────────
def add_rbox(slide, x, y, w, h, fc, ec=C_LINE, line_w=0.75):
    """Add a rounded rectangle at (x, y) of size w x h, all in inches.

    The box is solid-filled with *fc*, outlined in *ec* at *line_w* points,
    has its inherited shadow turned off, starts with empty text and uses
    tight text margins (50000 EMU left/right, 20000 EMU top/bottom).
    Returns the new shape.
    """
    left, top, width, height = (Inches(v) for v in (x, y, w, h))
    box = slide.shapes.add_shape(
        MSO_SHAPE.ROUNDED_RECTANGLE, left, top, width, height
    )
    box.adjustments[0] = 0.10  # corner rounding

    fill = box.fill
    fill.solid()
    fill.fore_color.rgb = fc

    outline = box.line
    outline.color.rgb = ec
    outline.width = Pt(line_w)

    box.shadow.inherit = False

    frame = box.text_frame
    frame.text = ''
    frame.margin_left = frame.margin_right = Emu(50000)
    frame.margin_top = frame.margin_bottom = Emu(20000)
    return box


def set_text(shp, text, *, size=10, bold=False, italic=False, color=C_TEXT,
             align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.MIDDLE):
    """Replace all text in *shp* with *text*, formatted uniformly.

    Each newline-separated line becomes its own paragraph aligned per *align*,
    holding one Calibri run at *size* pt with the given bold, italic and color
    settings. Word wrap is switched on and the frame is vertically anchored
    per *anchor*.
    """
    frame = shp.text_frame
    frame.word_wrap = True
    frame.vertical_anchor = anchor
    frame.text = ''
    for idx, chunk in enumerate(text.split('\n')):
        # The cleared frame keeps one empty paragraph; reuse it for line 0.
        para = frame.add_paragraph() if idx else frame.paragraphs[0]
        para.alignment = align
        piece = para.add_run()
        piece.text = chunk
        font = piece.font
        font.size = Pt(size)
        font.bold = bold
        font.italic = italic
        font.color.rgb = color
        font.name = 'Calibri'


def add_text(slide, x, y, w, h, text, *, size=10, bold=False, italic=False,
             color=C_TEXT, align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.MIDDLE):
    """Add a free-standing text box at (x, y) of size w x h inches.

    The text and its formatting are applied with set_text using the given
    keyword options. Returns the text-box shape.
    """
    left, top, width, height = map(Inches, (x, y, w, h))
    box = slide.shapes.add_textbox(left, top, width, height)
    style = dict(size=size, bold=bold, italic=italic, color=color,
                 align=align, anchor=anchor)
    set_text(box, text, **style)
    return box


def add_box(slide, x, y, w, h, *, fc, title=None, body=None,
            title_size=10, body_size=8, ec=C_LINE):
    """Draw a rounded box (inches) holding a bold title and/or a muted body.

    With both *title* and *body*: the first centered paragraph is the title
    (bold, C_TEXT, *title_size* pt) and every newline-separated line of
    *body* becomes its own centered paragraph (C_MUTE, not italic,
    *body_size* pt). All runs use Calibri, and the text is vertically
    centered with word wrap on. With only one of them, that text is rendered
    through set_text: the title in bold, the body in C_MUTE.
    Returns the shape.
    """
    shp = add_rbox(slide, x, y, w, h, fc=fc, ec=ec)
    if title and body:
        # Title and body need different run styles, so the paragraphs are
        # built directly rather than via set_text. These two frame options are
        # the ones set_text would otherwise have applied.
        tf = shp.text_frame
        tf.word_wrap = True
        tf.vertical_anchor = MSO_ANCHOR.MIDDLE
        tf.text = ''

        def add_centered_run(paragraph, text, size, color, *,
                             bold=None, italic=None):
            # One centered Calibri run; bold/italic are written only if given.
            paragraph.alignment = PP_ALIGN.CENTER
            run = paragraph.add_run()
            run.text = text
            run.font.size = Pt(size)
            if bold is not None:
                run.font.bold = bold
            run.font.color.rgb = color
            run.font.name = 'Calibri'
            if italic is not None:
                run.font.italic = italic

        add_centered_run(tf.paragraphs[0], title, title_size, C_TEXT,
                         bold=True)
        for line in body.split('\n'):
            add_centered_run(tf.add_paragraph(), line, body_size, C_MUTE,
                             italic=False)
    elif title:
        set_text(shp, title, size=title_size, bold=True)
    elif body:
        set_text(shp, body, size=body_size, color=C_MUTE)
    return shp


def add_arrow(slide, x1, y1, x2, y2, *, color=C_LINE, line_w=1.5,
              dashed=False):
    """Draw a straight arrow connector from (x1, y1) to (x2, y2), in inches.

    The line is drawn in *color*, *line_w* points wide, and is dashed when
    *dashed* is true. A medium triangle arrowhead is added at the (x2, y2)
    end. Returns the connector shape.
    """
    from lxml import etree
    from pptx.enum.dml import MSO_LINE_DASH_STYLE
    from pptx.oxml.ns import qn

    conn = slide.shapes.add_connector(
        MSO_CONNECTOR.STRAIGHT, Inches(x1), Inches(y1), Inches(x2), Inches(y2)
    )
    line = conn.line
    line.color.rgb = color
    line.width = Pt(line_w)
    if dashed:
        # Public API: python-pptx places <a:prstDash> at its schema position
        # inside <a:ln> instead of blindly appending it after whatever child
        # elements happen to exist.
        line.dash_style = MSO_LINE_DASH_STYLE.DASH

    # python-pptx has no public arrowhead API, so write <a:tailEnd> directly.
    # In the <a:ln> child sequence only <a:extLst> may follow <a:tailEnd>, and
    # nothing above creates one, so appending keeps the element order valid.
    ln = line._get_or_add_ln()
    tail = etree.SubElement(ln, qn('a:tailEnd'))
    tail.set('type', 'triangle')
    tail.set('w', 'med')
    tail.set('h', 'med')
    return conn


def add_caption(slide, x_mid, y, text, *, size=8, italic=True, color=C_MUTE):
    """Add a one-line caption whose text box is centered on (x_mid, y).

    The box is 4.5 in wide and 0.26 in tall, so *y* is its vertical midpoint;
    the text is centered horizontally (italic and muted by default).
    Returns the text-box shape.
    """
    box_w = 4.5
    box_h = 0.26
    left = x_mid - box_w/2
    top = y - 0.13
    return add_text(slide, left, top, box_w, box_h, text,
                    size=size, italic=italic, color=color,
                    align=PP_ALIGN.CENTER)


def add_circle(slide, cx, cy, r, *, fc, ec=C_LINE, lw=0.5):
    """Add a circle of radius *r* centered on (cx, cy), all in inches.

    Solid-filled with *fc*, outlined in *ec* at *lw* points, with the
    inherited shadow disabled. Returns the oval shape.
    """
    diameter = Inches(2*r)
    oval = slide.shapes.add_shape(
        MSO_SHAPE.OVAL, Inches(cx - r), Inches(cy - r), diameter, diameter
    )
    oval.fill.solid()
    oval.fill.fore_color.rgb = fc
    edge = oval.line
    edge.color.rgb = ec
    edge.width = Pt(lw)
    oval.shadow.inherit = False
    return oval


# ─── PANEL A: Encoder forward pass ───────────────────────────────────────────
def panel_a(slide, x0, y0):
    """Draw panel A, the encoder's main vertical flow, anchored top-left at (x0, y0).

    Coordinates are in inches and the panel is about 5 in wide. Stages are
    drawn top to bottom:
    - input signal, patches and the spatial patch embedder;
    - the prefix/patch token row with group labels;
    - the transformer stack and per-patch hidden states;
    - the two pre-training heads with their losses.
    """
    cx = x0 + 2.5  # center column
    bw = 4.5
    bx = cx - bw/2

    # Title
    add_text(slide, x0, y0, 5.0, 0.35,
             'A.  Encoder forward pass',
             size=14, bold=True, align=PP_ALIGN.LEFT, anchor=MSO_ANCHOR.TOP)

    y = y0 + 0.45

    # 1. Input EEG
    add_box(slide, bx, y, bw, 0.65, fc=C_INPUT,
            title='Input EEG signal',
            body='channels × time   (variable channel count)',
            title_size=11, body_size=8)
    y_after_input = y + 0.65

    # arrow
    add_arrow(slide, cx, y_after_input + 0.02, cx, y_after_input + 0.45)
    add_caption(slide, cx, y_after_input + 0.23, 'split time into patches')
    y = y_after_input + 0.50

    # 2. Patches
    add_box(slide, bx + 0.5, y, bw - 1.0, 0.55, fc=C_INPUT,
            title='Patches',
            body='patches × channels × patch length',
            title_size=10, body_size=8)
    y_after_patch = y + 0.55

    add_arrow(slide, cx, y_after_patch + 0.02, cx, y_after_patch + 0.45)
    add_caption(slide, cx, y_after_patch + 0.23,
                'spatial attention across electrodes')
    y = y_after_patch + 0.50

    # 3. Patch embedder
    add_box(slide, bx, y, bw, 0.78, fc=C_EMBED,
            title='Variable-Channel Spatial Patch Embedder',
            body='electrode tokens → spatial self-attention → masked-mean pool\n'
                 '2D-RoPE on scalp coordinates;  inactive channels masked',
            title_size=10, body_size=7.5)
    y_after_emb = y + 0.78

    add_arrow(slide, cx, y_after_emb + 0.02, cx, y_after_emb + 0.45)
    add_caption(slide, cx, y_after_emb + 0.23, 'one patch token per time window')
    y = y_after_emb + 0.50

    # 4. Prefix tokens row
    add_text(slide, bx, y, bw, 0.25,
             'Prepend learned prefix tokens',
             size=10, bold=True, align=PP_ALIGN.CENTER,
             anchor=MSO_ANCHOR.MIDDLE)
    y += 0.27

    # Token row, grouped left to right as (bracket label, [(name, color)]).
    # The token width and each bracket's span are derived from these groups,
    # so resizing a group (e.g. the register count) keeps the row in the box.
    token_groups = [
        ('modality / context',
         [('exp', C_PREFIX_E), ('cond', C_PREFIX_C), ('eeg', C_PREFIX_E)]),
        ('register tokens',
         [(f'r{i+1}', C_PREFIX_R) for i in range(8)]),
        ('patch tokens',
         [(n, C_PATCH) for n in ['p₁', 'p₂', '…', 'pₙ']]),
    ]
    token_specs = [spec for _, specs in token_groups for spec in specs]
    tw = bw / len(token_specs)
    th = 0.30

    for i, (name, fc) in enumerate(token_specs):
        shp = add_rbox(slide, bx + i*tw + 0.01, y, tw - 0.02, th,
                       fc=fc, line_w=0.5)
        set_text(shp, name, size=7, bold=True)

    # bracket annotations: one muted label centered under each token group
    y_b = y + th + 0.05
    start = 0
    for label, specs in token_groups:
        add_text(slide, bx + start*tw, y_b, len(specs)*tw, 0.18, label,
                 size=7, italic=True, color=C_MUTE,
                 align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.TOP)
        start += len(specs)

    add_text(slide, bx, y_b + 0.20, bw, 0.20,
             'condition cell = mean-pool of frozen LLM encoding of (dataset, recording state)',
             size=7, italic=True, color=C_MUTE,
             align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.TOP)

    y = y_b + 0.45
    add_arrow(slide, cx, y, cx, y + 0.45)
    y += 0.50

    # 5. Transformer stack
    stack_h = 1.10
    add_rbox(slide, bx, y, bw, stack_h, fc=C_BLOCK, line_w=1.0)
    add_text(slide, bx, y + 0.04, bw, 0.30,
             'Transformer Encoder Stack',
             size=11, bold=True, align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.TOP)
    add_text(slide, bx, y + 0.32, bw, 0.20,
             'N identical pre-norm blocks  •  causal self-attention',
             size=7.5, italic=True, color=C_MUTE,
             align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.TOP)
    # mini stack of "block" pills
    pill_w = bw - 1.6
    pill_h = 0.10
    pill_x = bx + 0.4
    pill_y_base = y + 0.55
    for i in range(4):
        py = pill_y_base + i * (pill_h + 0.04)
        shp = add_rbox(slide, pill_x, py, pill_w, pill_h, fc=C_NORM, line_w=0.4)
        set_text(shp, 'block', size=6.5)
    add_text(slide, pill_x, pill_y_base + 4*(pill_h+0.04) - 0.04, pill_w, 0.13,
             '⋮', size=10, color=C_TEXT,
             align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.TOP)

    y_after_stack = y + stack_h
    add_arrow(slide, cx, y_after_stack + 0.02, cx, y_after_stack + 0.42)
    add_caption(slide, cx, y_after_stack + 0.22, 'strip prefix tokens')
    y = y_after_stack + 0.46

    # 6. Per-patch hidden states
    add_box(slide, bx + 0.4, y, bw - 0.8, 0.50, fc=C_PATCH,
            title='Per-patch hidden states',
            body='one vector per time window',
            title_size=10, body_size=7.5)
    y_h = y + 0.50

    # Branch arrows to two heads
    head_w = 1.95
    head_y = y_h + 0.55
    head_left_x = bx + 0.10
    head_right_x = bx + bw - head_w - 0.10
    add_arrow(slide, cx - 0.10, y_h + 0.02,
              head_left_x + head_w/2, head_y - 0.02)
    add_arrow(slide, cx + 0.10, y_h + 0.02,
              head_right_x + head_w/2, head_y - 0.02)

    # 7. Heads
    add_box(slide, head_left_x, head_y, head_w, 0.75, fc=C_HEAD_R,
            title='Next-patch head',
            body='SwiGLU + linear\npredict the next raw patch',
            title_size=9.5, body_size=7.5)
    add_box(slide, head_right_x, head_y, head_w, 0.75, fc=C_HEAD_F,
            title='Feature head',
            body='SwiGLU + linear\npredict spectral features',
            title_size=9.5, body_size=7.5)

    add_text(slide, head_left_x, head_y + 0.78, head_w, 0.20,
             'mean-squared error',
             size=7, italic=True, color=C_MUTE,
             align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.TOP)
    add_text(slide, head_right_x, head_y + 0.78, head_w, 0.20,
             'smooth-L1 error',
             size=7, italic=True, color=C_MUTE,
             align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.TOP)

    add_text(slide, x0, head_y + 1.05, 5.0, 0.20,
             'self-supervised pre-training objective (training only)',
             size=8, italic=True, color=C_MUTE,
             align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.TOP)


# ─── PANEL B: One transformer block ──────────────────────────────────────────
def panel_b(slide, x0, y0):
    """Draw panel B, one pre-norm transformer block, anchored top-left at (x0, y0).

    Coordinates are in inches. The main stream runs top to bottom:
    input → RMSNorm → grouped-query attention → (+ residual) → RMSNorm →
    SwiGLU feed-forward → (+ residual) → output. Each residual skip is a
    dashed path in a lane to the right of the boxes. Both skips leave the
    main stream in the arrow gap *above* their RMSNorm and rejoin just below
    the sub-layer, matching the pre-norm form x + F(Norm(x)).
    """
    add_text(slide, x0, y0, 5.0, 0.35,
             'B.  One transformer block (pre-norm)',
             size=14, bold=True, align=PP_ALIGN.LEFT, anchor=MSO_ANCHOR.TOP)

    cx = x0 + 2.5
    bw = 3.5
    bx = cx - bw/2
    x_res = bx + bw + 0.10  # x of the dashed residual lane

    y = y0 + 0.50

    # Input
    add_box(slide, bx + 1.0, y, bw - 2.0, 0.40, fc=C_INPUT,
            title='Input  x', body=None, title_size=10)
    y_in = y + 0.40

    # arrow down
    add_arrow(slide, cx, y_in + 0.02, cx, y_in + 0.30)
    y = y_in + 0.32

    # Norm
    add_box(slide, bx + 0.5, y, bw - 1.0, 0.32, fc=C_NORM,
            title='RMSNorm', body=None, title_size=9)
    y += 0.34
    add_arrow(slide, cx, y, cx, y + 0.20)
    y += 0.22

    # Attention
    add_box(slide, bx, y, bw, 0.65, fc=C_ATTN,
            title='Grouped-Query Attention',
            body='Q heads ≫ KV heads  •  causal mask  •  time-RoPE',
            title_size=10, body_size=7.5)
    y += 0.65

    # residual loop 1 (dashed): leaves the stream 0.12 in above Norm 1, runs
    # down the right-hand lane, and rejoins just below the attention box
    add_arrow(slide, x_res, y_in + 0.20, x_res, y + 0.10, dashed=True)
    add_arrow(slide, x_res, y + 0.10, cx + 0.05, y + 0.10, dashed=True)
    add_arrow(slide, cx, y + 0.05, cx, y + 0.40)
    add_text(slide, bx + bw, y_in + 0.22, 1.0, 0.18, 'residual',
             size=7, italic=True, color=C_MUTE,
             align=PP_ALIGN.LEFT, anchor=MSO_ANCHOR.TOP)
    y += 0.40

    # Pre-norm: residual 2 must branch off *before* Norm 2 (not between Norm 2
    # and the FFN). Start it 0.12 in above the Norm 2 box, the same clearance
    # residual 1 has above Norm 1.
    y_res2 = y - 0.12

    # Norm 2
    add_box(slide, bx + 0.5, y, bw - 1.0, 0.32, fc=C_NORM,
            title='RMSNorm', body=None, title_size=9)
    y += 0.34
    add_arrow(slide, cx, y, cx, y + 0.20)
    y += 0.22

    # FFN
    add_box(slide, bx, y, bw, 0.60, fc=C_FFN,
            title='SwiGLU Feed-Forward',
            body='gated linear unit  •  expand & project',
            title_size=10, body_size=7.5)
    y += 0.60

    # residual loop 2 (dashed): from above Norm 2 to just below the FFN
    add_arrow(slide, x_res, y_res2, x_res, y + 0.10, dashed=True)
    add_arrow(slide, x_res, y + 0.10, cx + 0.05, y + 0.10, dashed=True)
    add_arrow(slide, cx, y + 0.05, cx, y + 0.40)
    add_text(slide, bx + bw, y_res2 + 0.02, 1.0, 0.18, 'residual',
             size=7, italic=True, color=C_MUTE,
             align=PP_ALIGN.LEFT, anchor=MSO_ANCHOR.TOP)
    y += 0.40

    # Output
    add_box(slide, bx + 1.0, y, bw - 2.0, 0.40, fc=C_INPUT,
            title='Output', body=None, title_size=10)

    # Repeated annotation
    add_text(slide, x0, y0 + 0.50, 1.0, 4.5,
             '×\nrepeated\nN times',
             size=9, italic=True, color=C_MUTE,
             align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.MIDDLE)

    # legend at bottom
    add_text(slide, x0, y + 0.55, 5.0, 0.20,
             'solid arrow = forward path     dashed arrow = residual skip',
             size=7.5, italic=True, color=C_MUTE,
             align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.TOP)


# ─── PANEL C: Spatial patch embedder zoom ────────────────────────────────────
def panel_c(slide, x0, y0):
    """Draw panel C, a zoom into the spatial patch embedder, anchored at (x0, y0).

    Coordinates are in inches. The panel has three parts, left to right:
    - a scalp outline (with nasion notch) carrying the 19 standard 10-20
      electrodes;
    - an arrow labelled 'masked-mean pool' feeding a single patch-token box;
    - three miniature scalps with 5, 19 and 64 electrodes, illustrating that
      one module handles any electrode set.
    """
    def note(x, y, w, h, caption, size, align=PP_ALIGN.CENTER):
        # Muted, italic, top-anchored annotation text.
        add_text(slide, x, y, w, h, caption, size=size, italic=True,
                 color=C_MUTE, align=align, anchor=MSO_ANCHOR.TOP)

    add_text(slide, x0, y0, 6.5, 0.35,
             'C.  Inside the spatial patch embedder',
             size=14, bold=True, align=PP_ALIGN.LEFT, anchor=MSO_ANCHOR.TOP)
    note(x0, y0 + 0.30, 6.5, 0.20,
         'one time-patch processed independently  •  attention runs across electrodes',
         8, align=PP_ALIGN.LEFT)

    # Scalp outline plus a small triangle on top as the nasion marker
    head_cx, head_cy, head_r = x0 + 1.5, y0 + 1.85, 1.05
    add_circle(slide, head_cx, head_cy, head_r, fc=C_WHITE, ec=C_LINE, lw=1.0)
    nose = slide.shapes.add_shape(
        MSO_SHAPE.ISOSCELES_TRIANGLE,
        Inches(head_cx - 0.10), Inches(head_cy - head_r - 0.18),
        Inches(0.20), Inches(0.20),
    )
    nose.fill.solid()
    nose.fill.fore_color.rgb = C_WHITE
    nose.line.color.rgb = C_LINE
    nose.line.width = Pt(0.8)
    nose.shadow.inherit = False

    # 19-channel 10-20 positions, normalized to the unit head radius
    coords_19 = [
        ('Fp1', -0.30, -0.85), ('Fp2',  0.30, -0.85),
        ('F7',  -0.65, -0.55), ('F3',  -0.40, -0.45), ('Fz', 0.00, -0.45),
        ('F4',   0.40, -0.45), ('F8',   0.65, -0.55),
        ('T3',  -0.85,  0.00), ('C3',  -0.45,  0.00), ('Cz', 0.00,  0.00),
        ('C4',   0.45,  0.00), ('T4',   0.85,  0.00),
        ('T5',  -0.65,  0.55), ('P3',  -0.40,  0.45), ('Pz', 0.00,  0.45),
        ('P4',   0.40,  0.45), ('T6',   0.65,  0.55),
        ('O1',  -0.30,  0.85), ('O2',   0.30,  0.85),
    ]
    for _label, ex, ey in coords_19:
        add_circle(slide,
                   head_cx + ex * head_r * 0.85,
                   head_cy + ey * head_r * 0.85,
                   0.07, fc=C_EMBED, ec=C_LINE, lw=0.5)

    for dy, caption, size in ((0.10, 'electrodes  =  query / key / value', 8),
                              (0.30, '2D-RoPE injects scalp coordinates', 7.5)):
        note(head_cx - 1.1, head_cy + head_r + dy, 2.2, 0.20, caption, size)

    # Arrow from scalp to the pooled patch token, with labels above/below
    ax0 = head_cx + head_r + 0.20
    ax1 = ax0 + 0.85
    ay = head_cy
    add_arrow(slide, ax0, ay, ax1, ay)
    span = ax1 - ax0 + 0.10
    for dy, caption, size in ((-0.30, 'masked-mean pool', 8),
                              (0.10, '(active channels only)', 7)):
        note(ax0 - 0.05, ay + dy, span, 0.20, caption, size)

    # Patch token output
    pt_x = ax1 + 0.05
    pt_y = head_cy - 0.30
    add_box(slide, pt_x, pt_y, 1.10, 0.60, fc=C_PATCH,
            title='patch\ntoken', body=None, title_size=10)

    # Column of example montages on the right
    legend_x = pt_x + 1.40
    legend_y = y0 + 0.55
    add_text(slide, legend_x, legend_y, 1.7, 0.22,
             'Same module handles\nany electrode set',
             size=8.5, bold=True, color=C_TEXT,
             align=PP_ALIGN.CENTER, anchor=MSO_ANCHOR.TOP)

    montages = (
        ('few electrodes',
         [(-0.30, -0.85), (0.30, -0.85), (-0.65, -0.55), (0.65, -0.55),
          (0.00, -0.45)]),
        ('standard 10-20 set',
         [(ex, ey) for _label, ex, ey in coords_19]),
        ('high-density set',
         _hd_coords(64)),
    )
    mini_r = 0.36
    mini_cx = legend_x + 0.85
    top = legend_y + 0.50
    for caption, points in montages:
        mid_y = top + mini_r
        add_circle(slide, mini_cx, mid_y, mini_r,
                   fc=C_WHITE, ec=C_LINE, lw=0.7)
        for ex, ey in points:
            add_circle(slide,
                       mini_cx + ex * mini_r * 0.85,
                       mid_y + ey * mini_r * 0.85,
                       0.025, fc=C_EMBED, ec=C_LINE, lw=0.3)
        note(legend_x, top + 2*mini_r + 0.04, 1.7, 0.18, caption, 7.5)
        top += 0.95


def _hd_coords(n=64):
    """Quasi-uniform points inside unit disk for HD montage decoration."""
    import math
    pts = []
    rng = [(0.15, 6), (0.35, 10), (0.55, 14), (0.75, 18), (0.95, 16)]
    for r, k in rng:
        for i in range(k):
            theta = 2 * math.pi * i / k + r
            pts.append((r * math.cos(theta), r * math.sin(theta)))
    return pts[:n]


# ─── BUILD SLIDE ─────────────────────────────────────────────────────────────
def build():
    """Render the three panels onto one slide and save architecture_v6.pptx.

    The panels are laid out with fixed inch offsets whose natural extent (as
    currently drawn) exceeds a 13.33 x 7.5 in slide:
    - panel A's stages run to roughly y = 10 in;
    - panel B's output box and legend reach about y = 5.8 in, which used to
      collide with panel C at its fixed y0 = 4.40.
    So panel C is placed just below panel B's measured bottom edge, and the
    slide is enlarged (never shrunk) to the measured content extent plus a
    margin. The file is written next to this script and its path is printed.
    """
    prs = Presentation()
    prs.slide_width = Inches(13.33)
    prs.slide_height = Inches(7.5)
    blank = prs.slide_layouts[6]
    slide = prs.slides.add_slide(blank)

    def _extent(shapes):
        """Return (max right edge, max bottom edge) in EMU over *shapes*."""
        right = bottom = 0
        for shp in shapes:
            if any(v is None for v in (shp.left, shp.top, shp.width, shp.height)):
                continue  # no explicit position of its own; nothing to measure
            right = max(right, shp.left + shp.width)
            bottom = max(bottom, shp.top + shp.height)
        return right, bottom

    # Top title bar
    add_text(slide, 0.4, 0.18, 12.5, 0.40,
             'EEG Transformer Encoder — Architecture',
             size=18, bold=True, align=PP_ALIGN.LEFT, anchor=MSO_ANCHOR.MIDDLE)

    # Panel A: left column
    panel_a(slide, x0=0.30, y0=0.65)

    # Panel B: top right. Remember which shapes it adds so its real bottom
    # edge can be measured.
    n_before_b = len(slide.shapes)
    panel_b(slide, x0=5.50, y0=0.65)
    _, b_bottom = _extent(list(slide.shapes)[n_before_b:])

    # Panel C: bottom right. It is never placed higher than its original
    # y0 = 4.40 and never overlaps panel B.
    c_y0 = max(4.40, Emu(b_bottom).inches + 0.25)
    panel_c(slide, x0=5.50, y0=c_y0)

    # Grow the canvas (never shrink it) so no shape falls off the slide.
    margin = Inches(0.30)
    right, bottom = _extent(slide.shapes)
    prs.slide_width = Emu(max(prs.slide_width, right + margin))
    prs.slide_height = Emu(max(prs.slide_height, bottom + margin))

    out_dir = Path(__file__).parent
    out_path = out_dir / 'architecture_v6.pptx'
    prs.save(out_path)
    print(f'Saved: {out_path}')


if __name__ == '__main__':
    # Script entry point: build the slide and save architecture_v6.pptx
    # into this script's directory.
    build()
