"""
intersection
十字路分類モデルの読み込みと推論を行うモジュール
学習済みモデルとスケーラを読み込み,
二値画像から十字路かどうかを判定する
"""
from pathlib import Path
import joblib
import numpy as np
from common.json_utils import PARAMS_DIR
# モデル・スケーラの保存先
_MODEL_PATH: Path = PARAMS_DIR / "intersection_model.pkl"
_SCALER_PATH: Path = PARAMS_DIR / "intersection_scaler.pkl"
class IntersectionClassifier:
"""十字路分類器
学習済みモデルを読み込み,二値画像から
十字路かどうかを判定する
"""
def __init__(self) -> None:
self._model: object | None = None
self._scaler: object | None = None
self._available: bool = False
self._load()
def _load(self) -> None:
"""モデルとスケーラを読み込む"""
if not _MODEL_PATH.exists():
return
if not _SCALER_PATH.exists():
return
self._model = joblib.load(_MODEL_PATH)
self._scaler = joblib.load(_SCALER_PATH)
self._available = True
@property
def available(self) -> bool:
"""モデルが利用可能かどうか"""
return self._available
def predict(self, binary_image: np.ndarray) -> bool:
"""二値画像が十字路かどうかを判定する
Args:
binary_image: 40×30 の二値画像(0/255)
Returns:
十字路なら True
"""
if not self._available:
return False
flat = (binary_image.flatten() / 255.0).astype(
np.float32,
)
x = self._scaler.transform(flat.reshape(1, -1))
pred = self._model.predict(x)
return bool(pred[0] == 1)