"""
intersection
十字路分類モデルの読み込みと推論を行うモジュール
学習済みモデルとスケーラを読み込み,
二値画像から十字路かどうかを判定する
"""
from pathlib import Path
import numpy as np
from common.json_utils import PARAMS_DIR
# モデル・スケーラの保存先
_MODEL_PATH: Path = PARAMS_DIR / "intersection_model.pkl"
_SCALER_PATH: Path = PARAMS_DIR / "intersection_scaler.pkl"
class IntersectionClassifier:
"""十字路分類器
学習済みモデルを読み込み,二値画像から
十字路かどうかを判定する
scikit-learn を遅延インポートして起動時間を短縮する
"""
def __init__(self) -> None:
self._model: object | None = None
self._scaler: object | None = None
self._available: bool = False
def load(self) -> None:
"""モデルとスケーラを読み込む(遅延呼び出し用)"""
if not _MODEL_PATH.exists():
print(
f" モデルが見つかりません: {_MODEL_PATH}"
)
return
if not _SCALER_PATH.exists():
print(
f" スケーラが見つかりません: {_SCALER_PATH}"
)
return
try:
import joblib
self._model = joblib.load(_MODEL_PATH)
self._scaler = joblib.load(_SCALER_PATH)
self._available = True
except Exception as e:
print(f" モデル読み込みエラー: {e}")
@property
def available(self) -> bool:
"""モデルが利用可能かどうか"""
return self._available
def predict(self, binary_image: np.ndarray) -> bool:
"""二値画像が十字路かどうかを判定する
Args:
binary_image: 40×30 の二値画像(0/255)
Returns:
十字路なら True
"""
if not self._available:
return False
flat = (binary_image.flatten() / 255.0).astype(
np.float32,
)
x = self._scaler.transform(flat.reshape(1, -1))
pred = self._model.predict(x)
return bool(pred[0] == 1)