#!/usr/bin/env python3
"""Session cost report — prices sourced from agent_notes/data/pricing.yaml."""
import sqlite3
from fnmatch import fnmatch
from pathlib import Path

# ANSI escape sequences used to style terminal output; NC resets all styling.
BOLD = "\033[1m"
DIM = "\033[2m"
YELLOW = "\033[0;33m"
GREEN = "\033[0;32m"
CYAN = "\033[0;36m"
NC = "\033[0m"

# Template placeholder: `{{PRICING}}` must be replaced with a dict literal before
# this script can run. The code below reads its "providers" list (each provider
# holding "models" with "match" and "price") and its "baseline" entry ("price",
# "label"). Presumably generated from agent_notes/data/pricing.yaml, as the
# module docstring states — confirm against the generator.
PRICING = {{PRICING}}

# Path of the opencode SQLite database, located under the user's home directory.
DB = Path.home() / ".local/share/opencode/opencode.db"

# Aggregates assistant-message usage for the most recent top-level session.
#
#   cs          the newest session (by time_created) that has no parent.
#   conv_start  start_ts = created-time of the message that follows the latest
#               gap of more than 1_800_000 between two consecutive messages of
#               that session, or 0 when no such gap exists. Durations below are
#               divided by 1000.0 to obtain seconds, so timestamps appear to be
#               milliseconds, which would make the gap 30 minutes — TODO confirm.
#
# The final SELECT yields one row per session (the root session plus sessions
# whose parent_id is the root), counting only assistant messages created at or
# after start_ts; child sessions must also have time_created >= start_ts.
# Columns, in order:
#   agent  $.agent of a message in the group, defaulting to 'lead'
#   model  $.modelID of the session's assistant message with the latest
#          $.time.completed (not restricted by start_ts)
#   inp, outp, cache  summed $.tokens.input / .output / .cache.read
#   sec    summed (completed - created) / 1000.0, rounded to one decimal;
#          messages missing either timestamp contribute 0
# NOTE(review): `agent` is a bare (non-aggregated) column under GROUP BY s.id,
# so SQLite takes it from one unspecified row of each group — confirm that all
# messages of a session carry the same agent.
SQL = """
WITH cs AS (SELECT id FROM session WHERE parent_id IS NULL ORDER BY time_created DESC LIMIT 1),
conv_start AS (
  SELECT COALESCE(
    (SELECT json_extract(m2.data,'$.time.created')
     FROM message m1 JOIN message m2 ON m1.session_id=m2.session_id
     WHERE m1.session_id=(SELECT id FROM cs)
       AND json_extract(m2.data,'$.time.created') > json_extract(m1.data,'$.time.created')
       AND json_extract(m2.data,'$.time.created') - json_extract(m1.data,'$.time.created') > 1800000
       AND NOT EXISTS (
         SELECT 1 FROM message mx WHERE mx.session_id=m1.session_id
           AND json_extract(mx.data,'$.time.created') > json_extract(m1.data,'$.time.created')
           AND json_extract(mx.data,'$.time.created') < json_extract(m2.data,'$.time.created'))
     ORDER BY json_extract(m1.data,'$.time.created') DESC LIMIT 1),
    0) AS start_ts
)
SELECT
  COALESCE(json_extract(m.data,'$.agent'), 'lead') AS agent,
  (SELECT json_extract(m2.data,'$.modelID') FROM message m2
   WHERE m2.session_id = s.id AND json_extract(m2.data,'$.role') = 'assistant'
   ORDER BY json_extract(m2.data,'$.time.completed') DESC LIMIT 1) AS model,
  SUM(json_extract(m.data,'$.tokens.input'))       AS inp,
  SUM(json_extract(m.data,'$.tokens.output'))      AS outp,
  SUM(json_extract(m.data,'$.tokens.cache.read'))  AS cache,
  ROUND(SUM(
    CASE WHEN json_extract(m.data,'$.time.completed') IS NOT NULL
              AND json_extract(m.data,'$.time.created') IS NOT NULL
    THEN (json_extract(m.data,'$.time.completed') - json_extract(m.data,'$.time.created')) / 1000.0
    ELSE 0 END
  ), 1) AS sec
FROM session s
JOIN message m ON m.session_id = s.id
CROSS JOIN cs
CROSS JOIN conv_start
WHERE (s.parent_id = cs.id OR s.id = cs.id)
  AND json_extract(m.data,'$.role') = 'assistant'
  AND json_extract(m.data,'$.time.created') >= conv_start.start_ts
  AND (s.time_created >= conv_start.start_ts OR s.id = (SELECT id FROM cs))
GROUP BY s.id
"""


def _build_price_table() -> list[tuple[list[str], dict]]:
    """Flatten ``PRICING["providers"]`` into ``(patterns, price)`` pairs.

    Each model entry contributes one pair, in provider/model order. A model's
    ``"match"`` value may be a single pattern or a list of patterns; a single
    pattern is wrapped in a one-element list so callers always get a list.
    """
    table = []
    for provider in PRICING.get("providers", []):
        for entry in provider.get("models", []):
            globs = entry["match"]
            if not isinstance(globs, list):
                globs = [globs]
            table.append((globs, entry["price"]))
    return table


# Built once at import time; get_price() scans this list in order.
_PRICE_TABLE = _build_price_table()
# Reference price used for the comparison column, and the label shown in its
# "vs <label>" header.
_BASELINE = PRICING["baseline"]["price"]
_BASELINE_LABEL = PRICING["baseline"]["label"]


def get_price(model_id: str) -> dict:
    """Return the price dict of the first table entry matching *model_id*.

    Patterns are shell-style globs tested with ``fnmatch``; entries are tried
    in table order and the first hit wins. When nothing matches, a fixed
    fallback price is returned.
    """
    for globs, rates in _PRICE_TABLE:
        for glob in globs:
            if fnmatch(model_id, glob):
                return rates
    # No pattern matched: fall back to a hard-coded default price.
    return {"in": 3.00, "out": 15.00, "cache": 0.30}


def calculate_cost(model_id: str, inp: int, outp: int, cache: int) -> float:
    """Return the cost of the given token counts at *model_id*'s price.

    Prices come from ``get_price`` and are divided by 1,000,000, i.e. they are
    treated as per-million-token rates.
    """
    rates = get_price(model_id)
    input_part = inp * rates["in"]
    output_part = outp * rates["out"]
    cache_part = cache * rates["cache"]
    return (input_part + output_part + cache_part) / 1_000_000


def baseline_cost(inp: int, outp: int, cache: int) -> float:
    """Return what the given token counts would cost at the baseline price.

    Uses the same per-million-token formula as ``calculate_cost``, but always
    with the ``_BASELINE`` rates.
    """
    weighted = (
        inp * _BASELINE["in"]
        + outp * _BASELINE["out"]
        + cache * _BASELINE["cache"]
    )
    return weighted / 1_000_000


def tier_color(model_id: str) -> str:
    """Pick the ANSI colour for a row from a substring of *model_id*.

    "opus" is checked before "sonnet"; any other model is rendered dim.
    """
    for keyword, colour in (("opus", YELLOW), ("sonnet", CYAN)):
        if keyword in model_id:
            return colour
    return DIM


def fmt_num(n: int) -> str:
    """Format a count compactly: plain below 1000, then ``k`` or ``m`` suffix.

    Values of at least 1,000 are shown in thousands with two decimals and
    values of at least 1,000,000 in millions with two decimals.

    Counts just below one million (e.g. 999_999) would round up to
    "1000.00k" in the thousands branch; those are promoted to the millions
    form so the result reads "1.00m".
    """
    if n >= 1_000_000:
        return f"{n / 1_000_000:.2f}m"
    if n >= 1_000:
        thousands = f"{n / 1_000:.2f}"
        if thousands == "1000.00":
            # Two-decimal rounding crossed the unit boundary; use millions.
            return f"{n / 1_000_000:.2f}m"
        return f"{thousands}k"
    return str(n)


def fmt_tokens(inp, outp, cache) -> str:
    """Render the three token counts as ``input/output/cache`` via ``fmt_num``."""
    return "/".join(fmt_num(count) for count in (inp, outp, cache))


def fmt_time(sec: float) -> str:
    """Format a duration in seconds as ``Ns``, ``Nm [Ns]`` or ``Nh [Nm]``.

    The value is rounded to whole seconds first. Zero-valued trailing parts
    are omitted, and seconds are not shown once the duration reaches an hour.
    """
    whole_seconds = int(round(sec))
    if whole_seconds < 60:
        return f"{whole_seconds}s"

    whole_minutes, seconds = divmod(whole_seconds, 60)
    if whole_minutes >= 60:
        hours, minutes = divmod(whole_minutes, 60)
        if minutes:
            return f"{hours}h {minutes}m"
        return f"{hours}h"

    if seconds:
        return f"{whole_minutes}m {seconds}s"
    return f"{whole_minutes}m"


def fmt_cost(c: float) -> str:
    """Format a cost as a dollar amount with four decimal places."""
    return "$" + format(c, ".4f")


def main() -> None:
    """Print a per-session cost table followed by a totals row.

    Runs ``SQL`` against the SQLite database at ``DB``, prices every returned
    row at both its model's price and the baseline price, and writes an
    ANSI-coloured table to stdout. Prints a short message and returns early
    when the database file does not exist or the query returns no rows.
    """
    if not DB.exists():
        print(f"Database not found: {DB}")
        return

    # Close the handle explicitly: sqlite3's connection context manager only
    # wraps a transaction and would leave the connection open.
    conn = sqlite3.connect(DB)
    try:
        rows = conn.execute(SQL).fetchall()
    finally:
        conn.close()
    if not rows:
        print("No sessions found.")
        return

    # NULL results from the query become neutral defaults.
    records = [
        (agent, model or "unknown", inp or 0, outp or 0, cache or 0, sec or 0)
        for agent, model, inp, outp, cache, sec in rows
    ]

    # Each row gains two trailing fields: actual cost and baseline cost.
    costs = [
        (agent, model, inp, outp, cache, sec,
         calculate_cost(model, inp, outp, cache),
         baseline_cost(inp, outp, cache))
        for agent, model, inp, outp, cache, sec in records
    ]

    # Totals are needed before printing because they take part in the column
    # width calculation; they are computed once and reused for the last row.
    total_inp = sum(inp for _, _, inp, *_ in costs)
    total_outp = sum(outp for _, _, _, outp, *_ in costs)
    total_cache = sum(cache for _, _, _, _, cache, *_ in costs)
    total_tokens = fmt_tokens(total_inp, total_outp, total_cache)

    # Time cell of the totals row: longest single row, then the sum of all rows.
    durations = [sec for _, _, _, _, _, sec, *_ in costs]
    total_time = f"{fmt_time(max(durations))} / {fmt_time(sum(durations))} seq"

    agent_col_w = max(len(f"{a}({m})") for a, m, *_ in costs) + 2
    tok_col_w = max(
        max(len(fmt_tokens(i, o, c)) for _, _, i, o, c, *_ in costs),
        len(total_tokens),
    ) + 2
    time_col_w = max(
        max(len(fmt_time(s)) for s in durations),
        len(total_time),
    ) + 2
    W = (agent_col_w, tok_col_w, time_col_w, 12, 12)

    header = (
        f"{'agent(model)':<{W[0]}}"
        f" {'in/out/cache':<{W[1]}}"
        f" {'time':<{W[2]}}"
        f" {'actual':<{W[3]}}"
        f" {f'vs {_BASELINE_LABEL}':<{W[4]}}"
    )
    print(BOLD + header + NC)
    print(DIM + "-" * len(header) + NC)

    total_actual = total_vs = 0.0
    for agent, model, inp, outp, cache, sec, actual, vs in costs:
        label = f"{agent}({model})"
        # Only the label cell is coloured; the remaining cells use the default.
        print(
            tier_color(model) + f"{label:<{W[0]}}" + NC
            + f" {fmt_tokens(inp, outp, cache):<{W[1]}}"
            + f" {fmt_time(sec):<{W[2]}}"
            + f" {fmt_cost(actual):<{W[3]}}"
            + f" {fmt_cost(vs):<{W[4]}}"
        )
        total_actual += actual
        total_vs += vs

    # Guard against a zero baseline total to avoid dividing by zero.
    saved_pct = round((1 - total_actual / total_vs) * 100) if total_vs else 0
    total_label = f"TOTAL (saved {saved_pct}%)"
    col = GREEN if total_actual <= 5 else YELLOW
    print(
        col + BOLD
        + f"{total_label:<{W[0]}}"
        + f" {total_tokens:<{W[1]}}"
        + f" {total_time:<{W[2]}}"
        + f" {fmt_cost(total_actual):<{W[3]}}"
        + f" {fmt_cost(total_vs):<{W[4]}}"
        + NC
    )


# Run the report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
