#!/usr/bin/env python3
"""Reproduce the Oregon BA+-residual / school poverty companion check.

The diagnostic is intentionally simple:

1. Aggregate Total Population assessment rows to school-level results.
2. Fit a participant-weighted line:

       Percent Proficient ~ ACS adult BA+

3. Compute each school's residual from that BA+-only expectation.
4. Test whether those residuals are still related to the school-sourced
   Students Experiencing Poverty field.

If high-poverty schools are systematically below the BA+-only line, ACS adult
education is missing part of the enrolled-student hardship picture. This
complements the two-factor joint regression in the sibling Oregon BA+/poverty
artifact.

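The script uses repository-relative paths only and writes four dated outputs
to the artifact's ``reports`` directory: the published summary CSV, a
school-level detail CSV, the companion figure, and a short Markdown memo.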
"""

from __future__ import annotations

import csv
import math
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

import numpy as np


# Matplotlib sometimes wants a writable cache. Keep it outside the published
# artifacts tree so test runs do not create stray site files.
# The directory is placed under the system temp dir and created eagerly.
# setdefault() leaves any MPLCONFIGDIR / XDG_CACHE_HOME value that is already
# present in the environment untouched. These statements deliberately run
# before the pyplot import that follows.
LOCAL_CACHE_DIR = Path(tempfile.gettempdir()) / "orschool_evidence_lab_mplcache"
LOCAL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
os.environ.setdefault("MPLCONFIGDIR", str(LOCAL_CACHE_DIR))
os.environ.setdefault("XDG_CACHE_HOME", str(LOCAL_CACHE_DIR))

import matplotlib.pyplot as plt


def find_data_root(start: Path) -> Path:
    """Locate the closest processed-data directory above ``start``.

    Each ancestor of ``start`` is probed, nearest first, for
    ``data/processed`` and then ``site/data/processed``. A directory counts
    as a match when it contains the Math dataset file used as a marker.

    Args:
        start: Path to search upward from (``start`` itself is not probed).

    Returns:
        The first matching processed-data directory.

    Raises:
        FileNotFoundError: If no probed directory contains the marker file.
    """

    marker = "SchoolDataMathWithAddressesAndCensusSES.csv"
    layouts = (("data", "processed"), ("site", "data", "processed"))

    # ``start.parent`` is listed explicitly ahead of ``start.parents`` so the
    # search still probes one directory when ``start.parents`` is empty.
    ancestors = [start.parent]
    ancestors.extend(start.parents)

    for ancestor in ancestors:
        for parts in layouts:
            processed_dir = ancestor.joinpath(*parts)
            if processed_dir.joinpath(marker).exists():
                return processed_dir

    raise FileNotFoundError(
        "Could not find processed dashboard data. Run from a repository checkout "
        "that contains data/processed or site/data/processed."
    )


# Resolved location of this script; every other path below is derived from it.
SCRIPT_PATH = Path(__file__).resolve()
# Processed-data directory found by walking up from the script. This lookup
# runs at import time and raises FileNotFoundError when nothing is found.
DATA_ROOT = find_data_root(SCRIPT_PATH)
# Two levels above the script file, i.e. the parent of the script's directory.
ARTIFACT_ROOT = SCRIPT_PATH.parents[1]
OUT_DIR = ARTIFACT_ROOT / "reports"
# Fixed date stamp embedded in the output filenames and the memo header.
RUN_DATE = "2026-05-19"

# Output files, all located under OUT_DIR.
SUMMARY_CSV = OUT_DIR / f"oregon_ba_school_poverty_residual_summary_{RUN_DATE}.csv"
DETAIL_CSV = OUT_DIR / f"oregon_ba_school_poverty_residual_detail_{RUN_DATE}.csv"
FIGURE_PNG = OUT_DIR / f"oregon_ba_residual_vs_school_poverty_{RUN_DATE}.png"
MEMO_MD = OUT_DIR / f"oregon_ba_school_poverty_residual_memo_{RUN_DATE}.md"

# Subject label -> processed input CSV under DATA_ROOT.
DATASETS = {
    "ELA": DATA_ROOT / "SchoolDataWithAddressesAndCensusSES.csv",
    "Math": DATA_ROOT / "SchoolDataMathWithAddressesAndCensusSES.csv",
    "Science": DATA_ROOT / "SchoolDataScienceWithAddressesAndCensusSES.csv",
}

# Grade-band label -> set of grade codes (as produced by grade_code) to keep.
# None means no grade filter is applied for that band.
GRADE_BANDS = {
    "all": None,
    "elementary": {"3", "4", "5"},
    "middle": {"6", "7", "8"},
    "high": {"11"},
}


@dataclass
class SchoolRecord:
    """One school's aggregated result and BA+-only residual for a single run.

    A run is one (subject, scope, exclusion) combination. The field order is
    also the column order of the detail CSV written by ``write_csv``.
    """

    subject: str  # Dataset label (a key of DATASETS)
    scope: str  # Grade-band label (a key of GRADE_BANDS)
    exclusion: str  # "noncharter_nonvirtual" or "all_schools"
    district_id: str
    district_name: str
    school_id: str
    school_name: str
    participants: float  # Sum of "Number of Participants" over the kept rows
    pct_proficient: float  # 100 * summed proficient / summed participants
    ba_pct: float  # Participant-weighted ACS BA+ rate, in percent
    school_poverty_pct: float  # Participant-weighted school poverty, in percent
    predicted_from_ba: float  # Fitted value from the BA+-only line
    ba_residual: float  # pct_proficient minus predicted_from_ba
    ba_residual_z: float  # ba_residual / residual SD; NaN when the SD is not > 0


@dataclass
class SummaryRow:
    """Participant-weighted summary statistics for one run.

    A run is one (subject, scope, exclusion) combination. The field order is
    also the column order of the summary CSV written by ``write_csv``.
    """

    subject: str  # Dataset label (a key of DATASETS)
    scope: str  # Grade-band label (a key of GRADE_BANDS)
    exclusion: str  # "noncharter_nonvirtual" or "all_schools"
    schools: int  # Number of school records in the run
    participants: float  # Total participants across those schools
    ba_r: float  # Weighted correlation of proficiency with BA+
    ba_slope_pp_per_10_ba: float  # BA+ slope multiplied by 10
    poverty_r: float  # Weighted correlation of proficiency with school poverty
    poverty_slope_pp_per_10_poverty: float  # Poverty slope multiplied by 10
    residual_vs_poverty_r: float  # Weighted correlation of BA+ residual with poverty
    residual_vs_poverty_slope_pp_per_10_poverty: float  # Residual-on-poverty slope * 10
    residual_sd: float  # Weighted root-mean-square residual of the BA+ line
    mean_residual_low_poverty_lt20: float  # Weighted mean residual where poverty < 20
    mean_residual_high_poverty_ge40: float  # Weighted mean residual where poverty >= 40
    high_minus_low_residual: float  # High-poverty mean minus low-poverty mean


def parse_num(value: object) -> float | None:
    try:
        text = str(value).strip().replace(",", "")
        if text == "":
            return None
        return float(text)
    except Exception:
        return None


def truthy(value: object) -> bool:
    """Return True when ``value``, as trimmed lowercase text, is an affirmative flag."""
    affirmative = ("1", "true", "t", "yes", "y")
    token = str(value).strip().lower()
    return token in affirmative


def is_charter(row: dict[str, str]) -> bool:
    """Report whether a data row describes a charter school.

    A non-blank "Is Charter School" value decides the answer via ``truthy``;
    otherwise the row counts as charter only when "ODE School Type" is
    "charter" (case-insensitive, trimmed).
    """
    explicit_flag = row.get("Is Charter School", "")
    if not str(explicit_flag).strip():
        school_type = str(row.get("ODE School Type", ""))
        return school_type.strip().lower() == "charter"
    return truthy(explicit_flag)


def is_virtual(row: dict[str, str]) -> bool:
    """Report whether a data row describes a virtual school.

    A non-blank "Is Virtual School" value decides the answer via ``truthy``;
    otherwise the row counts as virtual when "ODE Virtual Status" is
    "focus virtual" or "full virtual" (case-insensitive, trimmed).
    """
    explicit_flag = row.get("Is Virtual School", "")
    if not str(explicit_flag).strip():
        status = str(row.get("ODE Virtual Status", "")).strip().lower()
        return status in {"focus virtual", "full virtual"}
    return truthy(explicit_flag)


def grade_code(value: str) -> str:
    """Reduce a grade label to a compact code.

    Labels meaning "all grades" map to ``"all"``. Any other label is
    lowercased, has the word "grade" removed, and is trimmed. A falsy input
    yields an empty string.
    """
    cleaned = value.strip().lower() if value else ""
    if cleaned in ("all grades", "all grade", "all"):
        return "all"
    without_word = cleaned.replace("grade", "")
    return without_word.strip()


def normalize_rate(value: float) -> float:
    """Return ``value`` on a 0-100 percent scale.

    Values with magnitude at most 1.5 are treated as fractions and multiplied
    by 100; larger magnitudes are returned unchanged.
    """
    if abs(value) <= 1.5:
        return value * 100.0
    return value


def weighted_mean(values: np.ndarray, weights: np.ndarray) -> float:
    """Return the weighted average of ``values``.

    Only entries where both the value and the weight are finite and the weight
    is positive take part. Returns NaN when no entry qualifies.
    """
    usable = np.isfinite(values) & np.isfinite(weights) & (weights > 0)
    if int(usable.sum()) > 0:
        return float(np.average(values[usable], weights=weights[usable]))
    return math.nan


def weighted_corr(x: np.ndarray, y: np.ndarray, weights: np.ndarray) -> float:
    """Return the weighted Pearson correlation between ``x`` and ``y``.

    Entries with a non-finite ``x``, ``y`` or weight, or a non-positive
    weight, are dropped. Returns NaN when fewer than three entries remain or
    when either variable has no weighted variance.
    """
    keep = np.isfinite(x) & np.isfinite(y) & np.isfinite(weights) & (weights > 0)
    if int(keep.sum()) < 3:
        return math.nan

    xs, ys, ws = x[keep], y[keep], weights[keep]
    # Deviations from the weighted means, reused for variance and covariance.
    dx = xs - float(np.average(xs, weights=ws))
    dy = ys - float(np.average(ys, weights=ws))

    var_x = float(np.average(dx ** 2, weights=ws))
    var_y = float(np.average(dy ** 2, weights=ws))
    if var_x <= 0 or var_y <= 0:
        return math.nan

    covariance = float(np.average(dx * dy, weights=ws))
    return covariance / math.sqrt(var_x * var_y)


def weighted_slope(x: np.ndarray, y: np.ndarray, weights: np.ndarray) -> float:
    """Return the weighted least-squares slope of ``y`` on ``x``.

    Entries with a non-finite ``x``, ``y`` or weight, or a non-positive
    weight, are dropped. Returns NaN when fewer than three entries remain or
    when ``x`` has no weighted variance.
    """
    keep = np.isfinite(x) & np.isfinite(y) & np.isfinite(weights) & (weights > 0)
    if int(keep.sum()) < 3:
        return math.nan

    xs, ys, ws = x[keep], y[keep], weights[keep]
    dx = xs - float(np.average(xs, weights=ws))
    dy = ys - float(np.average(ys, weights=ws))

    var_x = float(np.average(dx ** 2, weights=ws))
    if var_x <= 0:
        return math.nan

    # Slope is weighted covariance over weighted variance of x.
    covariance = float(np.average(dx * dy, weights=ws))
    return covariance / var_x


def fit_ba_line(ba: np.ndarray, y: np.ndarray, weights: np.ndarray) -> tuple[float, float, float, float]:
    """Fit the weighted line ``y ~ ba`` and describe how well it fits.

    All four statistics are computed on the same rows: those where ``ba``,
    ``y`` and the weight are finite and the weight is positive. Previously the
    residual spread was computed on the unmasked arrays, so one NaN or
    zero-weight row made it NaN (or raised on empty input) even though the
    slope and correlation were fine.

    Args:
        ba: Predictor values (BA+ percent per school).
        y: Outcome values (percent proficient per school).
        weights: Non-negative weights (participants per school).

    Returns:
        ``(slope, intercept, residual_sd, r)`` where ``residual_sd`` is the
        weighted root-mean-square residual and ``r`` the weighted correlation.
        Statistics that cannot be computed are NaN.
    """
    usable = np.isfinite(ba) & np.isfinite(y) & np.isfinite(weights) & (weights > 0)
    ba_fit, y_fit, w_fit = ba[usable], y[usable], weights[usable]

    slope = weighted_slope(ba_fit, y_fit, w_fit)
    intercept = weighted_mean(y_fit, w_fit) - slope * weighted_mean(ba_fit, w_fit)
    r = weighted_corr(ba_fit, y_fit, w_fit)

    if ba_fit.size == 0 or not math.isfinite(slope):
        # No usable rows, or no fitted line: the residual spread is undefined.
        residual_sd = math.nan
    else:
        residual = y_fit - (slope * ba_fit + intercept)
        residual_sd = math.sqrt(float(np.average(residual**2, weights=w_fit)))
    return slope, intercept, residual_sd, r


def load_school_records(
    subject: str,
    path: Path,
    scope: str,
    exclusion: str,
) -> list[SchoolRecord]:
    """Read one subject file and build per-school BA+ residual records.

    Steps:
      1. Keep "Total Population (All Students)" rows that pass the school-type
         filter, the grade-band filter, and have all four numeric fields.
      2. Drop "all"-grade rows for schools that also have grade-specific rows
         (only relevant when ``scope`` is "all").
      3. Sum participants and proficient counts per (district, school) and
         participant-weight the BA+ and poverty rates.
      4. Fit the BA+-only line across schools and attach each school's
         prediction, residual and standardized residual.

    Args:
        subject: Label stored on every returned record.
        path: CSV file to read.
        scope: Key of ``GRADE_BANDS`` selecting the grade filter.
        exclusion: ``"noncharter_nonvirtual"`` drops charter and virtual
            schools; any other value keeps every school.

    Returns:
        One ``SchoolRecord`` per school in first-seen order, or an empty list
        when no row survives the filters.
    """
    drop_charter_virtual = exclusion == "noncharter_nonvirtual"
    grade_filter = GRADE_BANDS[scope]

    # Pass 1: collect the usable all-student rows.
    kept_rows: list[tuple[str, str, str, str, str, float, float, float, float]] = []
    with path.open(newline="", encoding="utf-8-sig") as handle:
        for row in csv.DictReader(handle):
            if row.get("Student Group") != "Total Population (All Students)":
                continue
            if drop_charter_virtual and (is_charter(row) or is_virtual(row)):
                continue

            grade = grade_code(row.get("Grade Level", ""))
            if grade_filter is not None and grade not in grade_filter:
                continue

            n_tested = parse_num(row.get("Number of Participants", ""))
            n_proficient = parse_num(row.get("Number Proficient", ""))
            ba_raw = parse_num(row.get("ACS_Ed_BA_or_Higher_Rate", ""))
            poverty_raw = parse_num(row.get("Students Experiencing Poverty", ""))
            if n_tested is None or n_proficient is None or ba_raw is None or poverty_raw is None:
                continue
            if n_tested <= 0:
                continue

            kept_rows.append(
                (
                    row.get("District ID", ""),
                    row.get("District Name", ""),
                    row.get("School ID", ""),
                    row.get("School Name", ""),
                    grade,
                    n_tested,
                    n_proficient,
                    normalize_rate(ba_raw),
                    normalize_rate(poverty_raw),
                )
            )

    # Pass 2: remember which schools have grade-specific rows, so a
    # pre-assembled All Grades row is not counted on top of them.
    schools_with_grade_rows = {
        (entry[0], entry[2]) for entry in kept_rows if entry[4] != "all"
    }

    # Pass 3: aggregate to one result per school. Participant weighting keeps
    # the residual check aligned with the student-weighted interpretation.
    labels: dict[tuple[str, str], tuple[str, str]] = {}
    # Per school: [participants, proficient, ba * participants, poverty * participants]
    sums: dict[tuple[str, str], list[float]] = {}
    for district_id, district_name, school_id, school_name, grade, n_tested, n_proficient, ba, poverty in kept_rows:
        key = (district_id, school_id)
        if scope == "all" and grade == "all" and key in schools_with_grade_rows:
            continue
        if key not in sums:
            labels[key] = (district_name, school_name)
            sums[key] = [0.0, 0.0, 0.0, 0.0]
        running = sums[key]
        running[0] = running[0] + n_tested
        running[1] = running[1] + n_proficient
        running[2] = running[2] + ba * n_tested
        running[3] = running[3] + poverty * n_tested

    if not sums:
        return []

    # Fit the BA+-only line across schools.
    school_totals = list(sums.values())
    n_arr = np.array([total[0] for total in school_totals], dtype=float)
    pct_arr = np.array([100.0 * total[1] / total[0] for total in school_totals], dtype=float)
    ba_arr = np.array([total[2] / total[0] for total in school_totals], dtype=float)
    slope, intercept, residual_sd, _ = fit_ba_line(ba_arr, pct_arr, n_arr)

    records: list[SchoolRecord] = []
    for key, (n_tested, n_proficient, ba_wsum, poverty_wsum) in sums.items():
        district_id, school_id = key
        district_name, school_name = labels[key]
        pct_proficient = 100.0 * n_proficient / n_tested
        ba_pct = ba_wsum / n_tested
        poverty_pct = poverty_wsum / n_tested
        predicted = slope * ba_pct + intercept
        residual = pct_proficient - predicted
        if residual_sd > 0:
            residual_z = residual / residual_sd
        else:
            residual_z = math.nan
        records.append(
            SchoolRecord(
                subject=subject,
                scope=scope,
                exclusion=exclusion,
                district_id=str(district_id),
                district_name=str(district_name),
                school_id=str(school_id),
                school_name=str(school_name),
                participants=n_tested,
                pct_proficient=pct_proficient,
                ba_pct=ba_pct,
                school_poverty_pct=poverty_pct,
                predicted_from_ba=predicted,
                ba_residual=residual,
                ba_residual_z=residual_z,
            )
        )
    return records


def summarize(records: list[SchoolRecord]) -> SummaryRow:
    """Condense one run's school records into a single summary row.

    Every statistic is participant-weighted. The low- and high-poverty groups
    are schools with poverty below 20 and at or above 40 respectively. The
    subject, scope and exclusion labels come from the first record, so
    ``records`` must be non-empty.
    """

    def column(name: str) -> np.ndarray:
        # Pull one numeric attribute from every record as a float array.
        return np.array([getattr(rec, name) for rec in records], dtype=float)

    ba = column("ba_pct")
    poverty = column("school_poverty_pct")
    y = column("pct_proficient")
    residual = column("ba_residual")
    weights = column("participants")

    in_low = poverty < 20
    in_high = poverty >= 40
    low_resid = weighted_mean(residual[in_low], weights[in_low])
    high_resid = weighted_mean(residual[in_high], weights[in_high])

    ba_slope, _, residual_sd, ba_r = fit_ba_line(ba, y, weights)
    poverty_slope = weighted_slope(poverty, y, weights)
    residual_slope = weighted_slope(poverty, residual, weights)

    head = records[0]
    return SummaryRow(
        subject=head.subject,
        scope=head.scope,
        exclusion=head.exclusion,
        schools=len(records),
        participants=float(weights.sum()),
        ba_r=ba_r,
        ba_slope_pp_per_10_ba=ba_slope * 10.0,
        poverty_r=weighted_corr(poverty, y, weights),
        poverty_slope_pp_per_10_poverty=poverty_slope * 10.0,
        residual_vs_poverty_r=weighted_corr(poverty, residual, weights),
        residual_vs_poverty_slope_pp_per_10_poverty=residual_slope * 10.0,
        residual_sd=residual_sd,
        mean_residual_low_poverty_lt20=low_resid,
        mean_residual_high_poverty_ge40=high_resid,
        high_minus_low_residual=high_resid - low_resid,
    )


def write_csv(path: Path, rows: Iterable[object]) -> None:
    """Write dataclass instances to ``path`` as a UTF-8 CSV file.

    Column names and their order come from the dataclass fields of the first
    row. The parent directory is created when missing. With no rows, the file
    is still written, with an empty field list.
    """
    materialized = list(rows)
    columns: list[str] = []
    if materialized:
        columns = list(materialized[0].__dataclass_fields__)

    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(
            {column: getattr(item, column) for column in columns}
            for item in materialized
        )


def make_figure(detail_rows: list[SchoolRecord]) -> None:
    """Draw the 2x3 residual-versus-poverty scatter grid and save it as a PNG.

    Rows of the grid are the "all" and "elementary" scopes; columns are ELA,
    Math and Science. Only non-charter, non-virtual records are plotted. Each
    panel shows the schools (dot size follows participants), a zero line, the
    participant-weighted trend line, and the weighted r and slope.

    A panel with no records is labelled "No data" instead of being fitted.
    This happens when the caller skipped that subject/scope, and previously
    crashed in ``np.nanmax`` on the empty array.

    Args:
        detail_rows: School records from every run; filtered here.
    """
    primary = [
        row
        for row in detail_rows
        if row.exclusion == "noncharter_nonvirtual" and row.scope in {"all", "elementary"}
    ]
    fig, axes = plt.subplots(2, 3, figsize=(15.5, 8.8), sharex=True, sharey=True)
    colors = {"ELA": "#386cb0", "Math": "#7b3294", "Science": "#008060"}
    for row_idx, scope in enumerate(["all", "elementary"]):
        for col_idx, subject in enumerate(["ELA", "Math", "Science"]):
            ax = axes[row_idx, col_idx]
            rows = [row for row in primary if row.scope == scope and row.subject == subject]
            ax.set_title(f"{subject} - {scope.title()}", fontsize=13, weight="bold")
            ax.axhline(0, color="#444444", lw=0.9, alpha=0.7)
            ax.grid(True, color="#ddd6c7", linewidth=0.7, alpha=0.65)

            if not rows:
                # Nothing to scatter or fit; keep the panel so the grid layout
                # stays stable across runs.
                ax.text(
                    0.5,
                    0.5,
                    "No data",
                    transform=ax.transAxes,
                    ha="center",
                    va="center",
                    fontsize=12,
                    color="#536255",
                )
                continue

            x = np.array([row.school_poverty_pct for row in rows], dtype=float)
            y = np.array([row.ba_residual for row in rows], dtype=float)
            w = np.array([row.participants for row in rows], dtype=float)
            # Square-root scaling keeps large schools visible without letting
            # them swamp the panel; sizes are clamped to a readable range.
            sizes = np.clip(np.sqrt(w) * 3.4, 12, 90)
            ax.scatter(x, y, s=sizes, color=colors[subject], alpha=0.28, linewidths=0)

            slope = weighted_slope(x, y, w)
            intercept = weighted_mean(y, w) - slope * weighted_mean(x, w)
            # Draw the trend line from 0 to at least 80, or the data maximum.
            x_line = np.array([0, max(80, float(np.nanmax(x)))], dtype=float)
            ax.plot(x_line, slope * x_line + intercept, color="#222222", lw=1.8)

            r = weighted_corr(x, y, w)
            ax.text(
                0.03,
                0.06,
                f"r = {r:.2f}\nslope = {slope*10:.1f} pp / +10 poverty",
                transform=ax.transAxes,
                fontsize=10.5,
                color="#243126",
                bbox={"boxstyle": "round,pad=0.25", "facecolor": "#f7f4ec", "edgecolor": "#d8d1c2", "alpha": 0.9},
            )

    for ax in axes[:, 0]:
        ax.set_ylabel("Residual from BA+-only prediction\n(actual minus predicted proficiency points)")
    for ax in axes[-1, :]:
        ax.set_xlabel("Students Experiencing Poverty (%)")

    fig.suptitle(
        "Oregon Schools: Poverty Still Predicts Residuals After a BA+-Only Model",
        fontsize=17,
        weight="bold",
        y=0.98,
    )
    fig.text(
        0.5,
        0.025,
        "Total Population, 2024-25; primary scope excludes charter and virtual schools. Dot size follows tested participants.",
        ha="center",
        fontsize=11,
        color="#536255",
    )
    fig.tight_layout(rect=[0, 0.05, 1, 0.94])
    # Match write_csv: do not rely on the caller having created the folder.
    FIGURE_PNG.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(FIGURE_PNG, dpi=170)
    plt.close(fig)


def fmt(value: float, digits: int = 3) -> str:
    """Format ``value`` with ``digits`` decimals; ``None`` or non-finite gives ``""``."""
    if value is not None and math.isfinite(value):
        return f"{value:.{digits}f}"
    return ""


def public_artifact_path(path: Path) -> str:
    """Return a stable relative path for memo text, avoiding local user paths.

    The path is expressed relative to the parent of ``ARTIFACT_ROOT``. When
    ``path`` does not lie under that directory, only its file name is returned.
    """

    base_dir = ARTIFACT_ROOT.parent
    try:
        relative = path.relative_to(base_dir)
    except ValueError:
        return path.name
    return str(relative)


def write_memo(summary_rows: list[SummaryRow]) -> None:
    """Write the Markdown memo for the primary runs to ``MEMO_MD``.

    The results table contains one line per summary row whose exclusion is
    "noncharter_nonvirtual" and whose scope is "all" or "elementary". The
    surrounding question, scope, interpretation and output-list text is fixed.
    """
    primary_scopes = {"all", "elementary"}

    def table_line(row: SummaryRow) -> str:
        # One Markdown table row: "| cell | cell | ... |".
        cells = [
            f"{row.subject}",
            f"{row.scope}",
            f"{row.schools}",
            fmt(row.ba_r),
            fmt(row.poverty_r),
            fmt(row.residual_vs_poverty_r),
            fmt(row.residual_vs_poverty_slope_pp_per_10_poverty),
            fmt(row.high_minus_low_residual),
        ]
        return "| " + " | ".join(cells) + " |"

    header = [
        "# Oregon BA+ Residuals and School Poverty Check",
        "",
        f"Date: {RUN_DATE}",
        "",
        "## Question",
        "",
        "After predicting school proficiency from ACS adult BA+ alone, are the remaining over/under-performance residuals still patterned by school-sourced `Students Experiencing Poverty`?",
        "",
        "This is a companion to the two-factor joint model. The joint model asks whether BA+ and poverty both add explanatory power in the same regression. This residual check asks whether high-poverty schools systematically sit below the BA+-only expectation line.",
        "",
        "## Scope",
        "",
        "- Oregon 2024-25 ELA, Math, and Science processed school files.",
        "- Student group: `Total Population (All Students)`.",
        "- Primary filter: non-charter, non-virtual schools.",
        "- School-level aggregation: grade rows are aggregated to one school result within each grade band; dataset-backed `All Grades` rows are not double-counted when grade rows exist.",
        "- Weighting: `Number of Participants`.",
        "- Residual: actual percent proficient minus predicted percent proficient from a participant-weighted BA+-only line.",
        "",
        "## Primary Results",
        "",
        "| Subject | Scope | Schools | BA+ r | Poverty r | Residual vs poverty r | Residual slope per +10 poverty | High-poverty minus low-poverty residual |",
        "| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: |",
    ]

    table = [
        table_line(row)
        for row in summary_rows
        if row.exclusion == "noncharter_nonvirtual" and row.scope in primary_scopes
    ]

    footer = [
        "",
        "## Read",
        "",
        "The residual pattern is clear: after a BA+-only model, high-poverty Oregon schools still tend to score below the BA+-based expectation. The effect is strongest in elementary ELA/Math and remains visible across all three subjects.",
        "",
        "This does not invalidate BA+. It says the simple BA+ association carries a mixture of community educational context and student hardship. School poverty captures an additional enrolled-student-composition signal that BA+ alone misses.",
        "",
        "A careful formulation is: adult BA+ is a strong community-context signal, but poverty-aware models and residual checks show that school poverty is not merely a duplicate measure. It explains important under- or over-performance relative to BA+-only expectations.",
        "",
        "## Outputs",
        "",
        f"- `{public_artifact_path(SUMMARY_CSV)}`",
        f"- `{public_artifact_path(DETAIL_CSV)}`",
        f"- `{public_artifact_path(FIGURE_PNG)}`",
    ]

    memo_text = "\n".join(header + table + footer) + "\n"
    MEMO_MD.write_text(memo_text, encoding="utf-8")


def main() -> None:
    """Run every subject / grade band / exclusion combination and write outputs.

    Combinations that yield fewer than 20 school records are skipped. The
    summary CSV, detail CSV, figure and memo are then written, and the
    relative path of each is printed.
    """
    detail: list[SchoolRecord] = []
    summary: list[SummaryRow] = []
    exclusions = ("noncharter_nonvirtual", "all_schools")

    for subject, path in DATASETS.items():
        for scope in GRADE_BANDS:
            for exclusion in exclusions:
                batch = load_school_records(subject, path, scope, exclusion)
                if len(batch) >= 20:
                    detail.extend(batch)
                    summary.append(summarize(batch))

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    write_csv(SUMMARY_CSV, summary)
    write_csv(DETAIL_CSV, detail)
    make_figure(detail)
    write_memo(summary)

    for written in (SUMMARY_CSV, DETAIL_CSV, FIGURE_PNG, MEMO_MD):
        print(f"Wrote {public_artifact_path(written)}")

# Run the pipeline only when this file is executed directly as a script.
if __name__ == "__main__":
    main()
