from __future__ import annotations
import math
from database.manager import DatabaseManager

def _one_hot_outcome(m) -> tuple[int,int,int]:
    if m.home_score is None or m.away_score is None:
        return (0,0,0)
    if m.home_score > m.away_score: return (1,0,0)
    if m.home_score == m.away_score: return (0,1,0)
    return (0,0,1)

def _safe_log(x: float) -> float:
    return math.log(max(1e-12, min(1.0, x)))

async def compute_metrics(db: DatabaseManager, model: str = "poisson+form") -> dict:
    """Grade a model's stored predictions against finished match results.

    Walks finished matches (up to 5000), pairs each with its stored
    prediction for *model*, and aggregates three standard metrics:
    top-1 accuracy, multi-class Brier score, and log-loss (all averaged
    per graded match).

    Matches without a stored prediction, or whose scores are missing
    despite the "finished" status, are skipped. Returns a dict with
    zeroed metrics when no match could be graded.
    """
    finished = await db.list_matches(status="finished", limit=5000)

    graded = 0          # matches actually scored
    correct = 0         # top-1 prediction hits
    brier_total = 0.0
    logloss_total = 0.0

    for match in finished:
        # NOTE(review): one awaited lookup per match (N+1 pattern) — fine at
        # this scale, but a batched fetch would be preferable if the API grows one.
        prediction = await db.get_prediction_for_match(match.id, model=model)
        if not prediction:
            continue
        outcome = _one_hot_outcome(match)
        if outcome == (0, 0, 0):
            # Sentinel: scores missing, match cannot be graded.
            continue

        probs = (prediction.p_home_win, prediction.p_draw, prediction.p_away_win)
        predicted_idx = probs.index(max(probs))  # first max wins ties, as before
        actual_idx = outcome.index(1)

        correct += predicted_idx == actual_idx
        # Multi-class Brier: sum of squared errors across the three outcomes.
        brier_total += sum((p - o) ** 2 for p, o in zip(probs, outcome))
        # Log-loss on the probability assigned to the realized outcome.
        logloss_total += -_safe_log(probs[actual_idx])
        graded += 1

    if graded == 0:
        return {"model": model, "n_matches": 0, "accuracy": 0.0, "brier": 0.0, "logloss": 0.0}
    return {
        "model": model,
        "n_matches": graded,
        "accuracy": float(correct / graded),
        "brier": float(brier_total / graded),
        "logloss": float(logloss_total / graded),
    }
