"""
intersection
十字路分類モデルの読み込みと推論を行うモジュール

学習済みモデルとスケーラを読み込み，
二値画像から十字路かどうかを判定する
"""

from pathlib import Path

import numpy as np

from common.json_utils import PARAMS_DIR

# モデル・スケーラの保存先
_MODEL_PATH: Path = PARAMS_DIR / "intersection_model.pkl"
_SCALER_PATH: Path = PARAMS_DIR / "intersection_scaler.pkl"


class IntersectionClassifier:
    """十字路分類器

    学習済みモデルを読み込み，二値画像から
    十字路かどうかを判定する
    scikit-learn を遅延インポートして起動時間を短縮する
    """

    def __init__(self) -> None:
        self._model: object | None = None
        self._scaler: object | None = None
        self._available: bool = False

    def load(self) -> None:
        """モデルとスケーラを読み込む（遅延呼び出し用）"""
        if not _MODEL_PATH.exists():
            return
        if not _SCALER_PATH.exists():
            return
        import joblib

        self._model = joblib.load(_MODEL_PATH)
        self._scaler = joblib.load(_SCALER_PATH)
        self._available = True

    @property
    def available(self) -> bool:
        """モデルが利用可能かどうか"""
        return self._available

    def predict(self, binary_image: np.ndarray) -> bool:
        """二値画像が十字路かどうかを判定する

        Args:
            binary_image: 40×30 の二値画像（0/255）

        Returns:
            十字路なら True
        """
        if not self._available:
            return False
        flat = (binary_image.flatten() / 255.0).astype(
            np.float32,
        )
        x = self._scaler.transform(flat.reshape(1, -1))
        pred = self._model.predict(x)
        return bool(pred[0] == 1)
