"""
Applied ML 2026 · Chapter 1 rig
Evaluating pass@k on a synthetic grade-school math benchmark.

What this script does:
  1. Generates a small benchmark of grade-school math problems with known
     ground-truth answers across three difficulty tiers.
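  2. Runs four solver configurations against each problem many times:
       - regex_baseline: pattern-match numbers, then guess the operation by
                         keyword priority ("gives" is checked before "picks")
       - parser_solver: execute the problem's structured program, with a
                        simulated 50% parse-failure rate on the hard tier
       - noisy_oracle: synthetic LLM substitute with controllable accuracy
                       and controllable sample-level correlation, run twice:
                       correlation 0.0 (i.i.d. samples) and 0.5 (a wrong
                       answer tends to be repeated on later samples)
  3. Computes pass@k for k in {1, 2, 3, 5, 10} using the unbiased HumanEval
     estimator, per difficulty tier and overall, with bootstrap 95% CIs.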
  4. Writes metrics.json and headline.json.

No external dependencies. Standard library only. Deterministic: the same
invocation produces the same numbers; only the rig_generated_at timestamp
and wall_clock_sec fields differ between runs.

Usage:
    python3 run.py --out results
    python3 run.py --seed 42 --n-problems 120 --n-samples 20 --out results
"""

from __future__ import annotations

import argparse
import json
import math
import random
import re
import sys
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path


# Per-solver seed offsets. run() seeds each solver's RNG with
# ``seed + SOLVER_SEEDS[name]``, so every solver gets its own reproducible
# stream. Every solver name registered in run() must have an entry here.
SOLVER_SEEDS = dict(
    regex_baseline=101,
    parser_solver=202,
    noisy_oracle_p70_corr0=303,
    noisy_oracle_p70_corr50=404,
)


@dataclass(frozen=True)
class Problem:
    """A grade-school math word problem with a known ground-truth answer.

    Frozen, so instances cannot be modified after they are generated.
    """

    # Integer id assigned sequentially by the benchmark builder; the noisy
    # oracle also uses it as the key for its per-problem memory.
    pid: int
    # Difficulty tier label: "easy", "medium" or "hard".
    difficulty: str
    # Natural-language problem statement that text-based solvers read.
    text: str
    # Ground-truth integer answer that solver outputs are compared against.
    answer: int
    # Canonical solution as (op, operand) steps: the first step is
    # ("start", value); later steps apply "+", "-", "*" or "//" to the
    # running value, left to right.
    program: list[tuple[str, int]]


def _generate_easy(pid: int, rng: random.Random) -> Problem:
    """One-step problem: add or subtract b from a (result always >= 1)."""
    a = rng.randint(2, 20)
    b = rng.randint(1, a - 1)  # b < a keeps the subtraction result positive
    op = rng.choice(["add", "sub"])
    if op == "add":
        text = (
            f"Anna has {a} apples and picks {b} more from the tree. "
            f"How many apples does she have now?"
        )
        return Problem(pid, "easy", text, a + b, [("start", a), ("+", b)])
    text = (
        f"Anna has {a} apples and gives {b} to her friend. "
        f"How many apples does she have left?"
    )
    return Problem(pid, "easy", text, a - b, [("start", a), ("-", b)])


def _generate_medium(pid: int, rng: random.Random) -> Problem:
    """Two-step problem: give b away and pick c more, in either order."""
    a = rng.randint(5, 20)
    b = rng.randint(1, a - 2)  # b <= a - 2, so giving first never empties the basket
    c = rng.randint(1, 10)
    label, op1, op2 = rng.choice([("give-buy", "-", "+"), ("buy-give", "+", "-")])
    if label == "give-buy":
        text = (
            f"Anna has {a} apples. She gives {b} to Ben, then picks "
            f"{c} more from the tree. How many apples does she have now?"
        )
        first, second = b, c
    else:
        text = (
            f"Anna has {a} apples. She picks {c} more from the tree, "
            f"then gives {b} to Ben. How many apples does she have now?"
        )
        first, second = c, b
    program = [("start", a), (op1, first), (op2, second)]
    val = a
    for op, operand in program[1:]:
        val = val + operand if op == "+" else val - operand
    return Problem(pid, "medium", text, val, program)


def _generate_hard(pid: int, rng: random.Random) -> Problem:
    """Multiply, subtract, then divide evenly among `friends` people."""
    per_tree = rng.randint(2, 6)
    trees = rng.randint(2, 5)
    give_away = rng.randint(1, 3)
    friends = rng.randint(2, 4)
    total = per_tree * trees
    # Give away one more apple at a time until the rest splits evenly. The
    # loop runs at most friends - 1 times. Since total >= 4 and the initial
    # give_away <= 3, remaining stays >= 0, but the answer can be 0.
    remaining = total - give_away
    while remaining % friends != 0:
        give_away += 1
        remaining = total - give_away
    answer = remaining // friends
    text = (
        f"Anna picks {per_tree} apples from each of {trees} trees. She "
        f"gives {give_away} to Ben, then splits the rest equally with "
        f"{friends - 1} friends (so {friends} people share). How many "
        f"apples does each person get?"
    )
    program = [
        ("start", per_tree),
        ("*", trees),
        ("-", give_away),
        ("//", friends),
    ]
    return Problem(pid, "hard", text, answer, program)


# Tier label -> generator. The keys are the only accepted difficulty values.
_GENERATORS = {
    "easy": _generate_easy,
    "medium": _generate_medium,
    "hard": _generate_hard,
}


def generate_problem(pid: int, difficulty: str, rng: random.Random) -> Problem:
    """Generate a single problem deterministically from (difficulty, rng).

    Args:
        pid: Id stored on the returned Problem.
        difficulty: One of "easy", "medium", "hard".
        rng: Source of randomness; the same rng state yields the same problem.

    Returns:
        A Problem whose ``answer`` equals the result of executing ``program``.

    Raises:
        ValueError: If ``difficulty`` is not a known tier. Previously any
            unknown label silently produced a problem relabelled "hard".
    """
    try:
        make = _GENERATORS[difficulty]
    except KeyError:
        raise ValueError(
            f"unknown difficulty {difficulty!r}; expected one of {sorted(_GENERATORS)}"
        ) from None
    return make(pid, rng)


def build_benchmark(n_problems: int, seed: int) -> list[Problem]:
    """Build the benchmark: equal easy/medium/hard blocks, leftovers as easy.

    Each tier receives ``n_problems // 3`` problems, laid out in the order
    easy, medium, hard. The ``n_problems % 3`` leftover problems are extra
    easy ones appended at the end. Ids run 0, 1, 2, ... in that order. All
    problems draw from one RNG seeded with ``seed``, so the benchmark is
    reproducible.
    """
    rng = random.Random(seed)
    share = n_problems // 3
    schedule = [tier for tier in ("easy", "medium", "hard") for _ in range(share)]
    schedule.extend(["easy"] * (n_problems - len(schedule)))
    return [generate_problem(pid, tier, rng) for pid, tier in enumerate(schedule)]


def regex_baseline(problem: Problem, rng: random.Random) -> int:
    """Keyword-heuristic baseline. Deliberately imperfect.

    Extracts every integer from the text and guesses a single operation.
    Only the first two numbers matter unless no keyword matches:

    * no number in the text -> a random guess in [0, 20] drawn from ``rng``;
    * exactly one number    -> that number;
    * text contains "gives " -> first minus second, clamped at 0. This check
      runs before "picks", regardless of which keyword appears first;
    * text contains "picks " -> first plus second;
    * otherwise             -> the sum of all numbers.
    """
    nums = [int(x) for x in re.findall(r"\d+", problem.text)]
    if not nums:
        return rng.randint(0, 20)
    if len(nums) == 1:
        # Every branch below needs two operands; the old "picks ... then"
        # path indexed nums[1] unconditionally and raised IndexError here.
        return nums[0]
    text_lower = problem.text.lower()
    if "gives " in text_lower:
        return max(0, nums[0] - nums[1])
    if "picks " in text_lower:
        return nums[0] + nums[1]
    return sum(nums)


def _execute_program(program: list[tuple[str, int]]) -> int:
    """Evaluate a Problem.program left to right.

    The first step is ``("start", value)``; each later step applies one of
    ``+``, ``-``, ``*`` or ``//`` to the running value.

    Raises:
        ValueError: On an unknown operator, instead of silently skipping it.
    """
    val = program[0][1]
    for op, operand in program[1:]:
        if op == "+":
            val += operand
        elif op == "-":
            val -= operand
        elif op == "*":
            val *= operand
        elif op == "//":
            val //= operand
        else:
            raise ValueError(f"unknown program op {op!r}")
    return val


def parser_solver(problem: Problem, rng: random.Random) -> int:
    """Execute the canonical program, with a simulated hard-tier parse failure.

    Non-hard problems are always solved exactly and consume no randomness.
    Hard problems draw one ``rng.random()``. Below 0.5 the program is executed
    correctly. Otherwise the solver "misparses" and returns an off-by-one
    answer that is guaranteed to be wrong.
    """
    if problem.difficulty != "hard" or rng.random() < 0.5:
        return _execute_program(problem.program)
    wrong = max(0, problem.answer + rng.choice([-1, 1]))
    # With answer == 0 the clamp can land back on the correct answer; a
    # simulated failure must never be scored as a pass.
    if wrong == problem.answer:
        wrong = problem.answer + 1
    return wrong


def make_noisy_oracle(accuracy_by_tier: dict[str, float], correlation: float):
    """
    Build a stateful, synthetic stand-in for sampling an LLM on a Problem.

    Each call of the returned ``oracle(problem, rng)`` works as follows:

    * If the oracle's last answer for ``problem.pid`` was wrong, it repeats
      that wrong answer with probability ``correlation``. This simulates the
      "stuck on a wrong answer" behavior that breaks naive pass@k intuition.
    * Otherwise it answers correctly with probability
      ``accuracy_by_tier[problem.difficulty]`` (0.5 for an unlisted tier).
    * A miss returns the true answer shifted by -1, +1, -2 or +2, clamped
      at 0. If the clamp lands on the true answer, it returns answer + 1.

    Because repeats skip the accuracy draw, with ``correlation > 0`` the
    per-sample accuracy is generally lower than the configured tier accuracy.
    The per-pid memory is private to each oracle this factory returns.
    """
    # pid -> the answer this oracle most recently gave for that problem.
    last_given: dict[int, int | None] = {}

    def oracle(problem: Problem, rng: random.Random) -> int:
        hit_rate = accuracy_by_tier.get(problem.difficulty, 0.5)
        truth = problem.answer
        previous = last_given.get(problem.pid)

        # The correlation draw is consumed only when a previous miss exists.
        was_wrong = previous is not None and previous != truth
        if was_wrong and rng.random() < correlation:
            return previous

        if rng.random() < hit_rate:
            guess = truth
        else:
            guess = max(0, truth + rng.choice([-1, 1, -2, 2]))
            if guess == truth:
                guess = truth + 1
        last_given[problem.pid] = guess
        return guess

    return oracle


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased estimator of pass@k from the HumanEval paper.

    pass@k = 1 - C(n - c, k) / C(n, k), evaluated as a running product so no
    large binomial coefficients are formed.

    Args:
        n: Number of samples drawn for the problem.
        c: Number of those samples that were correct (0 <= c <= n).
        k: Sample budget being scored.

    Returns:
        The estimated probability that at least one of k samples is correct.
        The estimator needs n >= k. For k > n this returns pass@n: 1.0 if any
        sample was correct, 0.0 otherwise.

    Raises:
        ValueError: If c is outside [0, n].
    """
    if not 0 <= c <= n:
        raise ValueError(f"need 0 <= c <= n, got n={n}, c={c}")
    if c == 0:
        # No correct sample can ever be drawn. Without this guard the check
        # below reported 1.0 whenever k > n.
        return 0.0
    if n - c < k:
        # Every size-k subset must contain at least one correct sample.
        return 1.0
    return 1.0 - math.prod((n - c - i) / (n - i) for i in range(k))


def bootstrap_ci(values: list[float], n_bootstrap: int, seed: int) -> tuple[float, float]:
    """Return the 95% percentile bootstrap CI for the mean of values.

    Draws ``n_bootstrap`` resamples of ``values`` with replacement, using an
    RNG seeded with ``seed`` so the interval is reproducible. The bounds are
    the sorted resample means at indices int(0.025 * B) and int(0.975 * B).

    Returns:
        ``(lo, hi)``, or ``(0.0, 0.0)`` when ``values`` is empty.

    Raises:
        ValueError: If ``n_bootstrap < 1`` and ``values`` is non-empty. This
            previously surfaced as an opaque IndexError.
    """
    if not values:
        return (0.0, 0.0)
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be >= 1, got {n_bootstrap}")
    rng = random.Random(seed)
    n = len(values)
    means = sorted(
        sum(values[rng.randrange(n)] for _ in range(n)) / n
        for _ in range(n_bootstrap)
    )
    lo = means[int(0.025 * n_bootstrap)]
    hi = means[int(0.975 * n_bootstrap)]
    return (lo, hi)


def _sample_solver(
    solver_fn,
    problems: list[Problem],
    n_samples: int,
    rng: random.Random,
) -> dict[str, list[tuple[int, int]]]:
    """Query ``solver_fn`` ``n_samples`` times per problem.

    Returns ``(n_samples, n_correct)`` pairs grouped by difficulty tier, in
    benchmark order. All calls share ``rng``, so results depend on call order.
    """
    per_tier: dict[str, list[tuple[int, int]]] = {"easy": [], "medium": [], "hard": []}
    for problem in problems:
        correct = 0
        for _ in range(n_samples):
            if solver_fn(problem, rng) == problem.answer:
                correct += 1
        per_tier[problem.difficulty].append((n_samples, correct))
    return per_tier


def _summarize(values: list[float], n_bootstrap: int, ci_seed: int) -> dict:
    """Mean and 95% bootstrap CI of non-empty ``values``, rounded to 4 places."""
    mean = sum(values) / len(values)
    lo, hi = bootstrap_ci(values, n_bootstrap, seed=ci_seed)
    return {
        "mean": round(mean, 4),
        "ci95_lo": round(lo, 4),
        "ci95_hi": round(hi, 4),
    }


def _score_solver(
    per_tier: dict[str, list[tuple[int, int]]],
    ks: list[int],
    n_bootstrap: int,
    seed: int,
) -> dict:
    """Turn per-problem (n, c) counts into per-tier and overall pass@k stats."""
    solver_report: dict = {"by_difficulty": {}, "overall": {}}
    all_pass_at: dict[int, list[float]] = {k: [] for k in ks}
    for tier, pairs in per_tier.items():
        tier_report: dict = {"n_problems": len(pairs)}
        # A tier is empty when n_problems < 3; report only its size rather
        # than dividing by zero.
        if pairs:
            for k in ks:
                scores = [pass_at_k(n, c, k) for n, c in pairs]
                tier_report[f"pass_at_{k}"] = _summarize(scores, n_bootstrap, seed + k)
                all_pass_at[k].extend(scores)
        solver_report["by_difficulty"][tier] = tier_report

    for k in ks:
        if all_pass_at[k]:
            solver_report["overall"][f"pass_at_{k}"] = _summarize(
                all_pass_at[k], n_bootstrap, seed + 100 + k
            )
    return solver_report


def _build_headline(report: dict, out: dict, n_problems: int, n_samples: int) -> dict:
    """Collect the overall pass@k means that are written to headline.json."""

    def overall_mean(solver: str, k: int) -> float:
        return report[solver]["overall"][f"pass_at_{k}"]["mean"]

    iid_10 = overall_mean("noisy_oracle_p70_corr0", 10)
    sticky_10 = overall_mean("noisy_oracle_p70_corr50", 10)
    return {
        "rig_generated_at": out["rig_generated_at"],
        "wall_clock_sec": out["wall_clock_sec"],
        "n_problems": n_problems,
        "n_samples_per_problem": n_samples,
        "regex_overall_pass_at_1": overall_mean("regex_baseline", 1),
        "parser_overall_pass_at_1": overall_mean("parser_solver", 1),
        "oracle_iid_pass_at_1": overall_mean("noisy_oracle_p70_corr0", 1),
        "oracle_iid_pass_at_5": overall_mean("noisy_oracle_p70_corr0", 5),
        "oracle_iid_pass_at_10": iid_10,
        "oracle_sticky_pass_at_5": overall_mean("noisy_oracle_p70_corr50", 5),
        "oracle_sticky_pass_at_10": sticky_10,
        "correlation_penalty_pass_at_10": round(iid_10 - sticky_10, 4),
    }


def run(
    n_problems: int,
    n_samples: int,
    seed: int,
    n_bootstrap: int,
    out_dir: Path,
) -> dict:
    """Evaluate every solver on a freshly generated benchmark and write reports.

    Args:
        n_problems: Benchmark size (>= 1); see build_benchmark for the tier split.
        n_samples: Samples per problem per solver. Must be at least the largest
            k scored (10), because the unbiased pass@k estimator needs n >= k.
        seed: Base seed for the benchmark, for each solver's RNG (offset by
            SOLVER_SEEDS), and for bootstrap resampling.
        n_bootstrap: Bootstrap resamples per confidence interval (>= 1).
        out_dir: Created if missing. Receives metrics.json and headline.json.

    Returns:
        The dict written to metrics.json.

    Raises:
        ValueError: If any of the size arguments is out of range.
    """
    ks = [1, 2, 3, 5, 10]
    if n_problems < 1:
        raise ValueError(f"n_problems must be >= 1, got {n_problems}")
    if n_samples < max(ks):
        raise ValueError(
            f"n_samples must be >= {max(ks)} (the largest k scored), got {n_samples}"
        )
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be >= 1, got {n_bootstrap}")

    t0 = time.time()
    problems = build_benchmark(n_problems, seed)

    # Both oracle variants share this read-only tier-accuracy table.
    oracle_accuracy = {"easy": 0.85, "medium": 0.70, "hard": 0.55}
    solvers = {
        "regex_baseline": regex_baseline,
        "parser_solver": parser_solver,
        "noisy_oracle_p70_corr0": make_noisy_oracle(oracle_accuracy, correlation=0.0),
        "noisy_oracle_p70_corr50": make_noisy_oracle(oracle_accuracy, correlation=0.5),
    }

    report: dict = {}
    for solver_name, solver_fn in solvers.items():
        # A fresh RNG per solver: each solver's samples are reproducible no
        # matter which other solvers run before it.
        rng = random.Random(seed + SOLVER_SEEDS[solver_name])
        per_tier = _sample_solver(solver_fn, problems, n_samples, rng)
        report[solver_name] = _score_solver(per_tier, ks, n_bootstrap, seed)

    out = {
        "rig_generated_at": datetime.now(timezone.utc).isoformat(),
        "wall_clock_sec": round(time.time() - t0, 3),
        "config": {
            "n_problems": n_problems,
            "n_samples_per_problem": n_samples,
            "seed": seed,
            "n_bootstrap": n_bootstrap,
            "difficulty_tiers": ["easy", "medium", "hard"],
        },
        "solvers": report,
    }

    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "metrics.json").write_text(json.dumps(out, indent=2))

    headline = _build_headline(report, out, n_problems, n_samples)
    (out_dir / "headline.json").write_text(json.dumps(headline, indent=2))

    return out


def main() -> int:
    """Command-line entry point: run the rig and print an overall summary.

    Returns:
        0 on success, used as the process exit status.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's hand-made layout (numbered steps, usage
        # lines); the default HelpFormatter re-wraps it into one paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="base RNG seed (default: %(default)s)"
    )
    parser.add_argument(
        "--n-problems",
        type=int,
        default=120,
        help="benchmark size, split across easy/medium/hard (default: %(default)s)",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=20,
        help="solver samples drawn per problem (default: %(default)s)",
    )
    parser.add_argument(
        "--n-bootstrap",
        type=int,
        default=1000,
        # '%%' because argparse %-formats help strings.
        help="bootstrap resamples per 95%% CI (default: %(default)s)",
    )
    parser.add_argument(
        "--out",
        default="results",
        help="output directory for metrics.json/headline.json",
    )
    args = parser.parse_args()

    out_dir = Path(args.out)
    out = run(
        n_problems=args.n_problems,
        n_samples=args.n_samples,
        seed=args.seed,
        n_bootstrap=args.n_bootstrap,
        out_dir=out_dir,
    )

    print(f"[applied_ml_2026/ch01] metrics.json + headline.json -> {out_dir}")
    print(f"[applied_ml_2026/ch01] wall_clock_sec = {out['wall_clock_sec']}")
    for solver_name, rep in out["solvers"].items():
        overall = rep["overall"]
        print(
            f"  {solver_name:32s}  "
            f"pass@1={overall['pass_at_1']['mean']:.3f}  "
            f"pass@5={overall['pass_at_5']['mean']:.3f}  "
            f"pass@10={overall['pass_at_10']['mean']:.3f}"
        )
    return 0


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    exit_code = main()
    sys.exit(exit_code)
