Newer
Older
RobotCar / src / common / vision / intersection.py
"""
intersection
十字路分類モデルの読み込みと推論を行うモジュール

学習済みモデルとスケーラを読み込み,
二値画像から十字路かどうかを判定する
"""

from pathlib import Path

import numpy as np

from common.json_utils import PARAMS_DIR

# モデル・スケーラの保存先
_MODEL_PATH: Path = PARAMS_DIR / "intersection_model.pkl"
_SCALER_PATH: Path = PARAMS_DIR / "intersection_scaler.pkl"


class IntersectionClassifier:
    """十字路分類器

    学習済みモデルを読み込み,二値画像から
    十字路かどうかを判定する
    scikit-learn を遅延インポートして起動時間を短縮する
    """

    def __init__(self) -> None:
        self._model: object | None = None
        self._scaler: object | None = None
        self._available: bool = False

    def load(self) -> None:
        """モデルとスケーラを読み込む(遅延呼び出し用)"""
        if not _MODEL_PATH.exists():
            print(
                f"  モデルが見つかりません: {_MODEL_PATH}"
            )
            return
        if not _SCALER_PATH.exists():
            print(
                f"  スケーラが見つかりません: {_SCALER_PATH}"
            )
            return
        try:
            import joblib

            self._model = joblib.load(_MODEL_PATH)
            self._scaler = joblib.load(_SCALER_PATH)
            self._available = True
        except Exception as e:
            print(f"  モデル読み込みエラー: {e}")

    @property
    def available(self) -> bool:
        """モデルが利用可能かどうか"""
        return self._available

    def predict(self, binary_image: np.ndarray) -> bool:
        """二値画像が十字路かどうかを判定する

        Args:
            binary_image: 40×30 の二値画像(0/255)

        Returns:
            十字路なら True
        """
        if not self._available:
            return False
        flat = (binary_image.flatten() / 255.0).astype(
            np.float32,
        )
        x = self._scaler.transform(flat.reshape(1, -1))
        pred = self._model.predict(x)
        return bool(pred[0] == 1)