#!/usr/bin/env python3
"""Reproduce the Oregon BA+ / school poverty two-factor model.

This script is the source for the Evidence Lab two-factor check:

    Percent Proficient ~ ACS BA+ + Students Experiencing Poverty

It intentionally uses only public repository-relative paths. It reads the
processed dashboard CSVs, aggregates assessment records to one school result
within each grade scope, and writes the model summary CSV and short run memo
used by the report.
"""

from __future__ import annotations

import csv
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

import numpy as np


def find_data_root(start: Path) -> Path:
    """Locate the closest processed-data directory above *start*.

    Every ancestor of *start* is visited nearest-first. At each one, two
    layouts are probed in order: ``<ancestor>/data/processed`` and then
    ``<ancestor>/site/data/processed``. A directory qualifies only if it
    contains the Math processed CSV, which serves as the marker file.

    In the published repository the match is typically ``site/data/processed``.
    An analysis checkout may match ``data/processed`` instead. Both hold
    mirrored processed files, so either is acceptable.

    Raises:
        FileNotFoundError: if no ancestor holds the marker file in either layout.
    """

    marker = "SchoolDataMathWithAddressesAndCensusSES.csv"
    layouts = (("data", "processed"), ("site", "data", "processed"))
    # start.parent is also the first entry of start.parents; dict.fromkeys
    # keeps nearest-first order while visiting each ancestor only once.
    for ancestor in dict.fromkeys((start.parent, *start.parents)):
        for parts in layouts:
            folder = ancestor.joinpath(*parts)
            if (folder / marker).exists():
                return folder
    raise FileNotFoundError(
        "Could not find processed dashboard data. Run from a repository checkout "
        "that contains data/processed or site/data/processed."
    )


# Absolute path of this script. Every other location below is derived from it,
# so the run does not depend on the current working directory.
SCRIPT_PATH = Path(__file__).resolve()
# Resolved at import time. Importing this module raises FileNotFoundError when
# no processed data directory exists above the script (see find_data_root).
DATA_ROOT = find_data_root(SCRIPT_PATH)
# parents[1] is the directory that contains the script's own folder, e.g.
# .../artifacts for a script at .../artifacts/scripts/<name>.py.
ARTIFACT_ROOT = SCRIPT_PATH.parents[1]
OUT_DIR = ARTIFACT_ROOT / "reports"
# Fixed stamp used in the output filenames and the memo's "Date:" line. It is
# not derived from the clock, so reruns overwrite the same two files.
RUN_DATE = "2026-05-19"

SUMMARY_CSV = OUT_DIR / f"oregon_ba_school_poverty_joint_model_{RUN_DATE}.csv"
MEMO_MD = OUT_DIR / f"oregon_ba_school_poverty_joint_model_memo_{RUN_DATE}.md"

# Processed dashboard CSV per subject. The key becomes ModelRow.subject, and the
# insertion order here sets the row order of the CSV and memo outputs.
DATASETS = {
    "ELA": DATA_ROOT / "SchoolDataWithAddressesAndCensusSES.csv",
    "Math": DATA_ROOT / "SchoolDataMathWithAddressesAndCensusSES.csv",
    "Science": DATA_ROOT / "SchoolDataScienceWithAddressesAndCensusSES.csv",
}

# Grade-band scopes. Each set is compared against grade_code() output in
# load_scope_rows. None disables grade filtering (the "all" scope). The "high"
# band covers grade 11 only.
GRADE_BANDS = {
    "all": None,
    "elementary": {"3", "4", "5"},
    "middle": {"6", "7", "8"},
    "high": {"11"},
}


@dataclass
class ModelRow:
    """Results for one subject x grade scope x school-filter model run.

    Field order defines the column order of the summary CSV (write_csv builds
    its header from ``__dataclass_fields__``). Reordering fields therefore
    changes the output file.
    """

    subject: str  # key from DATASETS, e.g. "ELA"
    scope: str  # key from GRADE_BANDS, e.g. "elementary"
    exclusion: str  # "noncharter_nonvirtual" or "all_schools"
    schools: int  # number of school-level observations in the fit
    participants: float  # total tested participants (sum of the regression weights)
    ba_r: float  # weighted correlation of BA+ with percent proficient
    poverty_r: float  # weighted correlation of poverty with percent proficient
    r2_ba_only: float  # weighted R^2 with BA+ as the only predictor
    r2_poverty_only: float  # weighted R^2 with poverty as the only predictor
    r2_joint: float  # weighted R^2 with both predictors
    delta_r2_ba_after_poverty: float  # r2_joint - r2_poverty_only
    delta_r2_poverty_after_ba: float  # r2_joint - r2_ba_only
    joint_beta_ba_std: float  # standardized BA+ coefficient in the joint model
    joint_beta_poverty_std: float  # standardized poverty coefficient in the joint model
    joint_slope_ba_per_10: float  # raw joint BA+ slope multiplied by 10
    joint_slope_poverty_per_10: float  # raw joint poverty slope multiplied by 10
    ba_poverty_r: float  # weighted correlation between the two predictors


def parse_num(value: object) -> float | None:
    try:
        text = str(value).strip().replace(",", "")
        if text == "":
            return None
        return float(text)
    except Exception:
        return None


def truthy(value: object) -> bool:
    """Interpret a CSV flag cell as a boolean.

    The cell is stringified, trimmed and lower-cased. Only "1", "true", "t",
    "yes" and "y" count as True; anything else, including blanks, is False.
    """
    normalized = str(value).strip().lower()
    return normalized in ("1", "true", "t", "yes", "y")


def is_charter(row: dict[str, str]) -> bool:
    """Return True when a CSV row describes a charter school.

    A non-blank "Is Charter School" flag decides the answer via truthy().
    Otherwise the row is a charter when "ODE School Type" reads "charter"
    (case-insensitive).
    """
    raw = row.get("Is Charter School")
    # csv.DictReader fills missing trailing fields with None. Treat that like a
    # blank cell so the fallback column is consulted, instead of letting
    # str(None) == "None" count as an explicit "not a charter".
    if raw is not None and str(raw).strip() != "":
        return truthy(raw)
    return str(row.get("ODE School Type", "")).strip().lower() == "charter"


def is_virtual(row: dict[str, str]) -> bool:
    """Return True when a CSV row describes a virtual school.

    A non-blank "Is Virtual School" flag decides the answer via truthy().
    Otherwise the row is virtual when "ODE Virtual Status" reads
    "focus virtual" or "full virtual" (case-insensitive).
    """
    raw = row.get("Is Virtual School")
    # csv.DictReader fills missing trailing fields with None. Treat that like a
    # blank cell so the fallback column is consulted, instead of letting
    # str(None) == "None" count as an explicit "not virtual".
    if raw is not None and str(raw).strip() != "":
        return truthy(raw)
    return str(row.get("ODE Virtual Status", "")).strip().lower() in {"focus virtual", "full virtual"}


def grade_code(value: str) -> str:
    """Normalize a "Grade Level" cell to the short code used by GRADE_BANDS.

    "All Grades" and the variants "all grade" / "all" become "all". Otherwise
    the word "grade" is removed. Purely numeric codes also lose any leading
    zeros, so "Grade 03", "Grade 3" and "3" all map to "3". Other remainders
    are returned trimmed and lower-cased. ``None`` is treated as blank.
    """
    text = (value or "").strip().lower()
    if text in {"all grades", "all grade", "all"}:
        return "all"
    code = text.replace("grade", "").strip()
    # Zero-padded labels ("03") would otherwise never match the unpadded
    # grade sets and silently empty the elementary/middle scopes.
    if code.isdecimal():
        return str(int(code))
    return code


def normalize_rate(value: float) -> float:
    """Express a single rate on the 0-100 percent scale.

    A value whose magnitude is at most 1.5 is assumed to be a fraction and is
    multiplied by 100. Larger magnitudes are assumed to already be percents.
    The decision is made per value, so a genuine percentage of 1.5 or less
    is scaled up as well.
    """
    if abs(value) > 1.5:
        return value
    return value * 100.0


def weighted_mean(values: np.ndarray, weights: np.ndarray) -> float:
    """Return the mean of *values* weighted by *weights*, as a Python float."""
    numerator = np.sum(values * weights)
    denominator = np.sum(weights)
    return float(numerator / denominator)


def weighted_corr(x: np.ndarray, y: np.ndarray, weights: np.ndarray) -> float:
    """Return the weighted Pearson correlation between *x* and *y*.

    The means, the covariance and both variances are weighted by *weights* and
    normalized by the weight total. NaN is returned when either weighted
    variance is zero or negative, e.g. for a constant series.
    """
    total = np.sum(weights)
    dx = x - weighted_mean(x, weights)
    dy = y - weighted_mean(y, weights)
    covariance = np.sum(weights * dx * dy) / total
    var_x = np.sum(weights * dx ** 2) / total
    var_y = np.sum(weights * dy ** 2) / total
    if var_x <= 0 or var_y <= 0:
        return math.nan
    return float(covariance / math.sqrt(var_x * var_y))


def weighted_fit(x: np.ndarray, y: np.ndarray, weights: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]:
    """Fit a weighted least-squares regression of *y* on *x* with an intercept.

    *x* may be 1-D (one predictor) or 2-D with one column per predictor. A
    column of ones is prepended for the intercept. Rows are scaled by
    sqrt(weights), so an ordinary ``lstsq`` solve minimizes the weighted
    residual sum of squares.

    Returns:
        (beta, pred, r2) where
        - beta holds the coefficients, intercept first;
        - pred holds the fitted values on the unscaled design;
        - r2 is the weighted R^2 about the weighted mean of *y*, or NaN when
          the weighted total sum of squares is not positive.
    """
    design_matrix = np.column_stack([np.ones(len(x)), x])
    root_w = np.sqrt(weights)
    coefficients = np.linalg.lstsq(design_matrix * root_w[:, None], y * root_w, rcond=None)[0]
    fitted = design_matrix @ coefficients
    center = weighted_mean(y, weights)
    residual_ss = float(np.sum(weights * (y - fitted) ** 2))
    total_ss = float(np.sum(weights * (y - center) ** 2))
    if total_ss <= 0:
        r2 = math.nan
    else:
        r2 = 1.0 - residual_ss / total_ss
    return coefficients, fitted, r2


def standardized_joint_beta(
    ba: np.ndarray,
    poverty: np.ndarray,
    y: np.ndarray,
    weights: np.ndarray,
) -> tuple[float, float]:
    """Return standardized joint-model coefficients ``(BA+ beta, poverty beta)``.

    Each series is converted to a weighted z-score, using the weighted mean
    and the weighted variance normalized by the weight total. The joint
    weighted model ``y_z ~ ba_z + poverty_z`` is then fitted with the same
    weights. The intercept is discarded.

    If any of the three series has zero or non-finite weighted variance, its
    z-score is undefined and ``(nan, nan)`` is returned instead of fitting on
    infinities. This mirrors weighted_corr's NaN convention.
    """

    def zscore(values: np.ndarray) -> np.ndarray | None:
        """Weighted z-score of *values*, or None when it cannot be formed."""
        mean = weighted_mean(values, weights)
        var = np.sum(weights * (values - mean) ** 2) / np.sum(weights)
        # A constant (or NaN-contaminated) series has no usable spread;
        # dividing by sqrt(0) would hand inf/nan to lstsq.
        if not (math.isfinite(var) and var > 0):
            return None
        return (values - mean) / math.sqrt(var)

    ba_z = zscore(ba)
    poverty_z = zscore(poverty)
    y_z = zscore(y)
    if ba_z is None or poverty_z is None or y_z is None:
        return math.nan, math.nan
    beta, _, _ = weighted_fit(np.column_stack([ba_z, poverty_z]), y_z, weights)
    return float(beta[1]), float(beta[2])


def _column_rate_scale(values: Iterable[float]) -> float:
    """Return the multiplier that puts one whole rate column on a 0-100 scale.

    The column is treated as fractions (multiplier 100.0) only when every value
    has magnitude <= 1.5, the same cutoff normalize_rate uses per value.
    Otherwise it is assumed to be percents already (multiplier 1.0).
    """
    magnitudes = [abs(value) for value in values]
    if magnitudes and max(magnitudes) <= 1.5:
        return 100.0
    return 1.0


def load_scope_rows(
    subject: str,
    path: Path,
    grade_band: str,
    omit_charter_virtual: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Read one subject file and collapse it to school-level observations.

    Args:
        subject: Subject label. Not used by the loader itself; kept for call
            symmetry with run_one.
        path: Processed dashboard CSV to read.
        grade_band: Key of GRADE_BANDS selecting which grade rows to keep.
        omit_charter_virtual: Drop rows flagged as charter or virtual schools.

    Returns:
        Four arrays ``(y, ba, poverty, weights)`` with one entry per school:
        - y: percent proficient;
        - ba: participant-weighted BA+ rate;
        - poverty: participant-weighted poverty rate;
        - weights: total tested participants.
        Both rates are on a 0-100 scale.
    """
    wanted = GRADE_BANDS[grade_band]
    staged: list[tuple[str, str, str, float, float, float, float]] = []

    # Keep only Total Population rows. This avoids mixing demographic subgroup
    # rows with all-student results and matches the report's stated scope.
    with path.open(newline="", encoding="utf-8-sig") as handle:
        for row in csv.DictReader(handle):
            if row.get("Student Group") != "Total Population (All Students)":
                continue
            if omit_charter_virtual and (is_charter(row) or is_virtual(row)):
                continue

            grade = grade_code(row.get("Grade Level", ""))
            if wanted is not None and grade not in wanted:
                continue

            participants = parse_num(row.get("Number of Participants", ""))
            proficient = parse_num(row.get("Number Proficient", ""))
            ba = parse_num(row.get("ACS_Ed_BA_or_Higher_Rate", ""))
            poverty = parse_num(row.get("Students Experiencing Poverty", ""))
            if None in (participants, proficient, ba, poverty):
                continue
            if participants <= 0:
                continue

            # Rates are staged raw; their scale is decided per column below.
            staged.append(
                (
                    row.get("District ID", ""),
                    row.get("School ID", ""),
                    grade,
                    participants,
                    proficient,
                    ba,
                    poverty,
                )
            )

    # Decide each rate column's scale once rather than per value. Scaling
    # per value turns a genuine 1% into 100% while its neighbours stay put,
    # which distorts every correlation, R^2 and beta. A uniform per-column
    # rescale leaves those statistics unchanged and only affects the raw
    # per-10-point slopes.
    ba_scale = _column_rate_scale(entry[5] for entry in staged)
    poverty_scale = _column_rate_scale(entry[6] for entry in staged)

    # Math and ELA usually provide grade-specific records; Science includes
    # data-backed "All Grades" records. If both are present for a school, use
    # the grade-specific records and drop the pre-assembled All Grades row so
    # the same students are not counted twice.
    grade_specific_by_school: dict[tuple[str, str], bool] = {}
    for district_id, school_id, grade, *_ in staged:
        key = (district_id, school_id)
        grade_specific_by_school[key] = grade_specific_by_school.get(key, False) or grade != "all"

    # Collapse selected grade rows to one school-level observation. SES fields
    # are averaged by tested participants because the outcome is also built from
    # participant counts.
    schools: dict[tuple[str, str], dict[str, float]] = {}
    for district_id, school_id, grade, participants, proficient, ba, poverty in staged:
        key = (district_id, school_id)
        if grade_band == "all" and grade == "all" and grade_specific_by_school.get(key, False):
            continue
        item = schools.setdefault(
            key,
            {
                "participants": 0.0,
                "proficient": 0.0,
                "ba_wsum": 0.0,
                "poverty_wsum": 0.0,
            },
        )
        item["participants"] += participants
        item["proficient"] += proficient
        item["ba_wsum"] += ba * ba_scale * participants
        item["poverty_wsum"] += poverty * poverty_scale * participants

    y_values: list[float] = []
    ba_values: list[float] = []
    poverty_values: list[float] = []
    weights: list[float] = []
    for item in schools.values():
        participants = item["participants"]
        if participants <= 0:
            continue
        weights.append(participants)
        y_values.append(100.0 * item["proficient"] / participants)
        ba_values.append(item["ba_wsum"] / participants)
        poverty_values.append(item["poverty_wsum"] / participants)

    return (
        np.array(y_values, dtype=float),
        np.array(ba_values, dtype=float),
        np.array(poverty_values, dtype=float),
        np.array(weights, dtype=float),
    )


def run_one(subject: str, path: Path, grade_band: str, exclusion: str) -> ModelRow | None:
    """Fit the BA+ / poverty models for one subject, grade band and school filter.

    *exclusion* equal to "noncharter_nonvirtual" drops charter and virtual
    schools; any other value keeps all schools. Returns None when the scope
    has fewer than 20 schools or no participant weight.

    Otherwise the returned ModelRow holds:
    - weighted correlations;
    - single-predictor and joint R^2, plus the added R^2 of each predictor
      over the other;
    - standardized joint betas;
    - raw joint slopes multiplied by 10.
    """
    y, ba, poverty, weights = load_scope_rows(
        subject, path, grade_band, exclusion == "noncharter_nonvirtual"
    )
    school_count = len(y)
    total_weight = float(np.sum(weights))
    if school_count < 20 or total_weight <= 0:
        return None

    # Weighted least squares: each school counts in proportion to its tested
    # participants, keeping the report's student-weighted interpretation.
    r2_ba = weighted_fit(ba.reshape(-1, 1), y, weights)[2]
    r2_poverty = weighted_fit(poverty.reshape(-1, 1), y, weights)[2]
    joint_coef, _, r2_joint = weighted_fit(np.column_stack([ba, poverty]), y, weights)
    std_ba, std_poverty = standardized_joint_beta(ba, poverty, y, weights)

    return ModelRow(
        subject=subject,
        scope=grade_band,
        exclusion=exclusion,
        schools=school_count,
        participants=total_weight,
        ba_r=weighted_corr(ba, y, weights),
        poverty_r=weighted_corr(poverty, y, weights),
        r2_ba_only=float(r2_ba),
        r2_poverty_only=float(r2_poverty),
        r2_joint=float(r2_joint),
        # Added R^2: explanatory power gained by adding one predictor once
        # the other is already in the model.
        delta_r2_ba_after_poverty=float(r2_joint - r2_poverty),
        delta_r2_poverty_after_ba=float(r2_joint - r2_ba),
        joint_beta_ba_std=std_ba,
        joint_beta_poverty_std=std_poverty,
        joint_slope_ba_per_10=float(joint_coef[1] * 10.0),
        joint_slope_poverty_per_10=float(joint_coef[2] * 10.0),
        ba_poverty_r=weighted_corr(ba, poverty, weights),
    )


def fmt(value: float, digits: int = 3) -> str:
    if value is None or not math.isfinite(value):
        return ""
    return f"{value:.{digits}f}"


def public_artifact_path(path: Path) -> str:
    """Return a stable relative path for memo text, avoiding local user paths."""

    try:
        return str(path.relative_to(ARTIFACT_ROOT.parent))
    except ValueError:
        return path.name


def write_csv(rows: Iterable[ModelRow]) -> None:
    """Write *rows* to SUMMARY_CSV, one line per model, replacing any old file.

    Columns follow the ModelRow field declaration order. The output directory
    is created first if it does not already exist.
    """
    columns = list(ModelRow.__dataclass_fields__)
    target = SUMMARY_CSV
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", newline="", encoding="utf-8") as stream:
        writer = csv.DictWriter(stream, fieldnames=columns)
        writer.writeheader()
        writer.writerows({column: getattr(model, column) for column in columns} for model in rows)


def _memo_results_table(title: str, rows: Iterable[ModelRow]) -> list[str]:
    """Render a "## <title>" section: heading, blank line, table header, one line per row."""
    lines = [
        f"## {title}",
        "",
        "| Subject | Schools | BA+ r | Poverty r | Joint R^2 | BA+ beta | Poverty beta | BA+ delta R^2 after poverty | Poverty delta R^2 after BA+ |",
        "| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |",
    ]
    for row in rows:
        lines.append(
            f"| {row.subject} | {row.schools} | {fmt(row.ba_r)} | {fmt(row.poverty_r)} | "
            f"{fmt(row.r2_joint)} | {fmt(row.joint_beta_ba_std)} | {fmt(row.joint_beta_poverty_std)} | "
            f"{fmt(row.delta_r2_ba_after_poverty)} | {fmt(row.delta_r2_poverty_after_ba)} |"
        )
    return lines


def write_memo(rows: list[ModelRow]) -> None:
    """Write the Markdown run memo to MEMO_MD, replacing any old file.

    Two result tables are rendered from the "noncharter_nonvirtual" rows: the
    "all" grade scope and the "elementary" scope. Rows appear in the order
    they occur in *rows*. The output directory is created if missing, so the
    memo can be written even when write_csv has not run first.
    """
    primary = [row for row in rows if row.exclusion == "noncharter_nonvirtual"]
    all_primary = [row for row in primary if row.scope == "all"]
    elem_primary = [row for row in primary if row.scope == "elementary"]

    lines = [
        "# Oregon BA+ and School Poverty Joint Model Check",
        "",
        f"Date: {RUN_DATE}",
        "",
        "## Question",
        "",
        "Have we tested whether ACS adult BA+ still predicts proficiency after adding the school-sourced `Students Experiencing Poverty` measure to the same model?",
        "",
        "Short answer: earlier Oregon work did something very close in broader four-factor and residual-peer models. This artifact runs the narrower two-predictor diagnostic directly:",
        "",
        "`Percent Proficient ~ ACS BA+ + Students Experiencing Poverty`",
        "",
        "## Scope",
        "",
        "- Oregon 2024-25 ELA, Math, and Science processed school files.",
        "- Student group: `Total Population (All Students)`.",
        "- Primary filter: non-charter, non-virtual schools.",
        "- Sensitivity filter: all schools.",
        "- School-level aggregation: grade rows are aggregated to one school result within each grade band; dataset-backed `All Grades` rows are not double-counted when grade rows exist.",
        "- Weighting: `Number of Participants`.",
        "- Outcome: percent proficient.",
        "- Predictors: ACS tract/block-group adult BA+ rate and ODE school-sourced `Students Experiencing Poverty`.",
        "",
    ]
    lines += _memo_results_table("Primary Results: All Tested Grades", all_primary)
    lines.append("")
    lines += _memo_results_table("Primary Results: Elementary Scope", elem_primary)

    # NOTE(review): the "Read" narrative below is static text; it is not
    # checked against the computed rows and must be revisited if results change.
    lines += [
        "",
        "## Read",
        "",
        "School poverty is very strong, especially for ELA and Science, but BA+ generally remains positive with meaningful incremental explanatory power after school poverty is included.",
        "",
        "If the school-sourced poverty measure captured nearly all of the relevant socioeconomic context, then the BA+ standardized beta and drop-one delta R^2 would fall near zero. In these Oregon models, they do not.",
        "",
        "That does not make the ACS school-site measure perfect. It still says Oregon's tract BA+ measure is capturing context that is not fully reducible to the school poverty field. The school poverty field also captures a lot that BA+ does not, so the best current interpretation is complementary signal rather than replacement.",
        "",
        "## Caveats",
        "",
        "- This is observational and school-level. It does not identify individual family SES effects.",
        "- `Students Experiencing Poverty` is a school-sourced Oregon field and should not be read as an individual-student poverty measure.",
        "- BA+ is measured from ACS geography attached to the school location, not a true attendance-zone crosswalk.",
        "- High school and mixed-grade results remain more vulnerable to catchment-area mismatch than elementary results.",
        "- Standardized betas compare variables on common weighted standard-deviation units; delta R^2 is the added explanatory power from adding one predictor after the other.",
        "",
        "## Output",
        "",
        f"- `{public_artifact_path(SUMMARY_CSV)}`",
    ]

    MEMO_MD.parent.mkdir(parents=True, exist_ok=True)
    MEMO_MD.write_text("\n".join(lines) + "\n", encoding="utf-8")


def main() -> None:
    """Fit every subject x grade band x school filter, then write the CSV and memo.

    Scopes that run_one rejects (returns None) are skipped. Both output paths
    are printed in repository-relative form once written.
    """
    exclusions = ("noncharter_nonvirtual", "all_schools")
    fitted = (
        run_one(subject, path, grade_band, exclusion)
        for subject, path in DATASETS.items()
        for grade_band in GRADE_BANDS
        for exclusion in exclusions
    )
    rows: list[ModelRow] = [row for row in fitted if row is not None]

    write_csv(rows)
    write_memo(rows)
    for written in (SUMMARY_CSV, MEMO_MD):
        print(f"Wrote {public_artifact_path(written)}")


# Fit models and write outputs only when run as a script. Importing the module
# does not call main(), although DATA_ROOT is still resolved at import time.
if __name__ == "__main__":
    main()
