"""
train
十字路分類モデルの学習・評価スクリプト

複数モデルを 5-fold CV で比較し，結果をドキュメントに出力する
最良モデルを全データで再学習して params/ に保存する

使い方:
    $ cd src && python -m pc.data.train
"""

import sys
from datetime import datetime
from pathlib import Path

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    classification_report,
    f1_score,
)
from sklearn.model_selection import StratifiedKFold
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

from common.json_utils import PARAMS_DIR
from pc.data.dataset import load_dataset

# ── 定数 ──────────────────────────────────────────────

N_FOLDS: int = 5
RANDOM_STATE: int = 42

# モデル保存先
MODEL_PATH: Path = PARAMS_DIR / "intersection_model.pkl"
SCALER_PATH: Path = PARAMS_DIR / "intersection_scaler.pkl"

# ドキュメント出力先
_PROJECT_ROOT: Path = (
    Path(__file__).resolve().parent.parent.parent.parent
)
REPORT_PATH: Path = (
    _PROJECT_ROOT / "docs" / "03_TECH"
    / "TECH_06_十字路分類モデル評価.txt"
)


# ── モデル定義 ────────────────────────────────────────

def _build_models() -> list[tuple[str, object]]:
    """比較対象のモデル一覧を返す

    Returns:
        (名前, モデルインスタンス) のリスト
    """
    return [
        (
            "LogisticRegression",
            LogisticRegression(
                max_iter=1000,
                random_state=RANDOM_STATE,
            ),
        ),
        (
            "SVM_RBF",
            SVC(
                kernel="rbf",
                random_state=RANDOM_STATE,
            ),
        ),
        (
            "SVM_Linear",
            SVC(
                kernel="linear",
                random_state=RANDOM_STATE,
            ),
        ),
        (
            "RandomForest",
            RandomForestClassifier(
                n_estimators=100,
                random_state=RANDOM_STATE,
            ),
        ),
        (
            "MLP_1layer",
            MLPClassifier(
                hidden_layer_sizes=(64,),
                max_iter=500,
                random_state=RANDOM_STATE,
            ),
        ),
        (
            "MLP_2layer",
            MLPClassifier(
                hidden_layer_sizes=(128, 64),
                max_iter=500,
                random_state=RANDOM_STATE,
            ),
        ),
    ]


# ── 評価 ──────────────────────────────────────────────

def _evaluate_models(
    x: np.ndarray,
    y: np.ndarray,
) -> list[dict]:
    """全モデルを StratifiedKFold CV で評価する

    Args:
        x: 特徴量行列
        y: ラベル配列

    Returns:
        各モデルの評価結果辞書のリスト
    """
    skf = StratifiedKFold(
        n_splits=N_FOLDS,
        shuffle=True,
        random_state=RANDOM_STATE,
    )
    results: list[dict] = []

    for name, model in _build_models():
        f1_scores: list[float] = []

        for train_idx, test_idx in skf.split(x, y):
            x_train, x_test = x[train_idx], x[test_idx]
            y_train, y_test = y[train_idx], y[test_idx]

            # スケーリング
            scaler = StandardScaler()
            x_train_s = scaler.fit_transform(x_train)
            x_test_s = scaler.transform(x_test)

            model_clone = _clone_model(name)
            model_clone.fit(x_train_s, y_train)
            y_pred = model_clone.predict(x_test_s)

            f1 = f1_score(y_test, y_pred, average="macro")
            f1_scores.append(f1)

        results.append({
            "name": name,
            "f1_mean": float(np.mean(f1_scores)),
            "f1_std": float(np.std(f1_scores)),
            "f1_scores": f1_scores,
        })
        print(
            f"  {name:25s}  "
            f"F1={np.mean(f1_scores):.4f} "
            f"(±{np.std(f1_scores):.4f})"
        )

    return results


def _clone_model(name: str) -> object:
    """モデル名から新しいインスタンスを生成する

    Args:
        name: モデル名

    Returns:
        モデルインスタンス
    """
    for n, m in _build_models():
        if n == name:
            return m
    raise ValueError(f"不明なモデル: {name}")


# ── 最良モデルの再学習・保存 ──────────────────────────

def _train_best_and_save(
    best_name: str,
    x: np.ndarray,
    y: np.ndarray,
) -> str:
    """最良モデルを全データで再学習して保存する

    Args:
        best_name: 最良モデルの名前
        x: 全特徴量行列
        y: 全ラベル配列

    Returns:
        全データでの classification_report 文字列
    """
    scaler = StandardScaler()
    x_scaled = scaler.fit_transform(x)

    model = _clone_model(best_name)
    model.fit(x_scaled, y)
    y_pred = model.predict(x_scaled)
    report = classification_report(
        y, y_pred,
        target_names=["normal", "intersection"],
    )

    # 保存
    PARAMS_DIR.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, MODEL_PATH)
    joblib.dump(scaler, SCALER_PATH)

    return report


# ── ドキュメント出力 ──────────────────────────────────

def _write_report(
    results: list[dict],
    best_name: str,
    n_samples: int,
    n_intersection: int,
    n_normal: int,
    full_report: str,
) -> None:
    """評価結果をドキュメントとして出力する

    Args:
        results: 各モデルの評価結果
        best_name: 最良モデルの名前
        n_samples: 全サンプル数
        n_intersection: intersection のサンプル数
        n_normal: normal のサンプル数
        full_report: 全データでの classification_report
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")

    # 結果テーブル
    sorted_results = sorted(
        results, key=lambda r: r["f1_mean"], reverse=True,
    )
    table_lines: list[str] = []
    table_lines.append(
        f"    {'モデル':25s}  {'F1(平均)':>10s}"
        f"  {'F1(標準偏差)':>14s}"
    )
    table_lines.append(f"    {'─' * 55}")
    for r in sorted_results:
        marker = " ← 採用" if r["name"] == best_name else ""
        table_lines.append(
            f"    {r['name']:25s}"
            f"  {r['f1_mean']:10.4f}"
            f"  {r['f1_std']:14.4f}{marker}"
        )
    table_str = "\n".join(table_lines)

    # fold 詳細
    fold_lines: list[str] = []
    for r in sorted_results:
        scores_str = ", ".join(
            f"{s:.4f}" for s in r["f1_scores"]
        )
        fold_lines.append(
            f"    ・{r['name']}: [{scores_str}]"
        )
    fold_str = "\n".join(fold_lines)

    # classification_report のインデント
    report_indented = "\n".join(
        f"        {line}" for line in full_report.splitlines()
    )

    content = f"""\
========================================================================
十字路分類モデル評価 (Intersection Classifier Evaluation)
========================================================================


1. 概要 (Overview)
------------------------------------------------------------------------

    1-0. 目的

    十字路（intersection）と通常区間（normal）を分類する
    二値画像分類モデルの比較評価結果を記録する．

    1-1. 評価日時

    ・実施日時: {timestamp}

    1-2. データセット

    ・入力: 40×30 二値画像（1200 特徴量，0.0/1.0）
    ・全サンプル数: {n_samples}
        - intersection: {n_intersection}
        - normal: {n_normal}
    ・クラス比率: intersection:normal\
 = {n_intersection}:{n_normal}


2. 評価方法 (Evaluation Method)
------------------------------------------------------------------------

    2-1. 交差検証

    ・手法: Stratified {N_FOLDS}-Fold Cross-Validation
    ・指標: マクロ平均 F1 スコア
    ・前処理: StandardScaler（fold ごとに fit）
    ・乱数シード: {RANDOM_STATE}


3. 評価結果 (Results)
------------------------------------------------------------------------

    3-1. モデル比較（F1 スコア降順）

{table_str}

    3-2. 各 Fold の F1 スコア

{fold_str}

    3-3. 採用モデル

    ・モデル: {best_name}
    ・保存先: params/intersection_model.pkl
    ・スケーラ: params/intersection_scaler.pkl

    3-4. 全データでの分類レポート（再学習後）

{report_indented}
"""

    REPORT_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(
        REPORT_PATH, "w", encoding="utf-8",
    ) as f:
        f.write(content)

    print(f"\nドキュメント出力: {REPORT_PATH}")


# ── メイン ────────────────────────────────────────────

def main() -> None:
    """学習・評価のメインフロー"""
    print("=== 十字路分類モデルの学習・評価 ===\n")

    # データ読み込み
    print("データ読み込み中...")
    try:
        x, y = load_dataset()
    except FileNotFoundError as e:
        print(f"エラー: {e}")
        sys.exit(1)

    n_samples = len(y)
    n_intersection = int(np.sum(y == 1))
    n_normal = int(np.sum(y == 0))
    print(
        f"  サンプル数: {n_samples}"
        f" (intersection={n_intersection},"
        f" normal={n_normal})\n"
    )

    # CV 評価
    print(f"{N_FOLDS}-Fold CV 評価中...")
    results = _evaluate_models(x, y)

    # 最良モデル選定
    best = max(results, key=lambda r: r["f1_mean"])
    best_name = best["name"]
    print(
        f"\n最良モデル: {best_name}"
        f" (F1={best['f1_mean']:.4f})\n"
    )

    # 全データで再学習・保存
    print("全データで再学習・保存中...")
    full_report = _train_best_and_save(best_name, x, y)
    print(f"  モデル保存: {MODEL_PATH}")
    print(f"  スケーラ保存: {SCALER_PATH}")

    # ドキュメント出力
    _write_report(
        results, best_name,
        n_samples, n_intersection, n_normal,
        full_report,
    )

    print("\n完了")


if __name__ == "__main__":
    main()
