diff --git a/src/common/steering/base.py b/src/common/steering/base.py index 40eff1a..cbb365f 100644 --- a/src/common/steering/base.py +++ b/src/common/steering/base.py @@ -9,6 +9,13 @@ import numpy as np +from common.vision.line_detector import ( + ImageParams, + LineDetectResult, + detect_line, + reset_valley_tracker, +) + @dataclass class SteeringOutput: @@ -25,26 +32,78 @@ class SteeringBase(ABC): """操舵量計算の基底クラス - 全ての操舵量計算クラスはこのクラスを継承し, - compute メソッドを実装する + 線検出・レートリミッター・状態管理の共通ロジックを提供し, + サブクラスは _compute_from_result で操舵計算のみ実装する """ - @abstractmethod + def __init__( + self, + image_params: ImageParams | None = None, + ) -> None: + self.image_params: ImageParams = ( + image_params or ImageParams() + ) + self._prev_steer: float = 0.0 + self._last_result: LineDetectResult | None = None + def compute( self, frame: np.ndarray, ) -> SteeringOutput: """カメラ画像から操舵量を計算する + 線検出 → サブクラスの操舵計算 → レートリミッター + の共通フローを実行する + Args: - frame: BGR 形式のカメラ画像 + frame: グレースケールのカメラ画像 Returns: 計算された操舵量 """ + result = detect_line(frame, self.image_params) + self._last_result = result + + output = self._compute_from_result(result) + + # レートリミッター + max_rate = self._max_steer_rate() + delta = output.steer - self._prev_steer + delta = max(-max_rate, min(max_rate, delta)) + output.steer = self._prev_steer + delta + self._prev_steer = output.steer + + return output @abstractmethod - def reset(self) -> None: - """内部状態をリセットする + def _compute_from_result( + self, result: LineDetectResult, + ) -> SteeringOutput: + """線検出結果から操舵量を計算する - 自動操縦の開始時に呼び出される + サブクラスで操舵アルゴリズムを実装する. + レートリミッターは基底クラスが適用するため, + ここでは素の操舵量を返せばよい + + Args: + result: 線検出の結果 + + Returns: + 計算された操舵量(レートリミッター適用前) """ + + @abstractmethod + def _max_steer_rate(self) -> float: + """1フレームあたりの最大操舵変化量を返す""" + + def reset(self) -> None: + """内部状態をリセットする""" + self._prev_steer = 0.0 + self._last_result = None + reset_valley_tracker() + + @property + def last_detect_result( + self, + ) -> LineDetectResult | None: + """直近の線検出結果を取得する""" + return self._last_result diff --git a/src/common/steering/pd_control.py b/src/common/steering/pd_control.py index ec31ada..2579b30 100644 --- a/src/common/steering/pd_control.py +++ b/src/common/steering/pd_control.py @@ -13,8 +13,6 @@ from common.vision.line_detector import ( ImageParams, LineDetectResult, - detect_line, - reset_valley_tracker, ) @@ -46,38 +44,29 @@ params: PdParams | None = None, image_params: ImageParams | None = None, ) -> None: + super().__init__(image_params) self.params: PdParams = params or PdParams() - self.image_params: ImageParams = ( - image_params or ImageParams() - ) self._prev_error: float = 0.0 self._prev_time: float = 0.0 - self._prev_steer: float = 0.0 - self._last_result: LineDetectResult | None = None - def compute( - self, frame: np.ndarray, + def _compute_from_result( + self, result: LineDetectResult, ) -> SteeringOutput: - """カメラ画像から PD 制御で操舵量を計算する + """PD 制御で操舵量を計算する Args: - frame: グレースケールのカメラ画像 + result: 線検出の結果 Returns: 計算された操舵量 """ - p = self.params - - # 線検出 - result = detect_line(frame, self.image_params) - self._last_result = result - - # 線が検出できなかった場合は停止 if not result.detected: return SteeringOutput( throttle=0.0, steer=0.0, ) + p = self.params + # 位置偏差 + 傾きによる操舵量 error = ( p.kp * result.position_error @@ -100,12 +89,6 @@ # 操舵量のクランプ steer = max(-1.0, min(1.0, steer)) - # レートリミッター(急な操舵変化を制限) - delta = steer - self._prev_steer - max_delta = p.max_steer_rate - delta = max(-max_delta, min(max_delta, delta)) - steer = self._prev_steer + delta - # 速度制御(曲率連動) throttle = ( p.max_throttle @@ -116,23 +99,16 @@ # 状態の更新 self._prev_error = error self._prev_time = now - self._prev_steer = steer return SteeringOutput( throttle=throttle, steer=steer, ) + def _max_steer_rate(self) -> float: + return self.params.max_steer_rate + def reset(self) -> None: """内部状態をリセットする""" + super().reset() self._prev_error = 0.0 self._prev_time = 0.0 - self._prev_steer = 0.0 - self._last_result = None - reset_valley_tracker() - - @property - def last_detect_result( - self, - ) -> LineDetectResult | None: - """直近の線検出結果を取得する""" - return self._last_result diff --git a/src/common/steering/pursuit_control.py b/src/common/steering/pursuit_control.py index d576cf3..4e00ac9 100644 --- a/src/common/steering/pursuit_control.py +++ b/src/common/steering/pursuit_control.py @@ -14,8 +14,6 @@ from common.vision.line_detector import ( ImageParams, LineDetectResult, - detect_line, - reset_valley_tracker, ) @@ -53,37 +51,28 @@ params: PursuitParams | None = None, image_params: ImageParams | None = None, ) -> None: + super().__init__(image_params) self.params: PursuitParams = ( params or PursuitParams() ) - self.image_params: ImageParams = ( - image_params or ImageParams() - ) - self._prev_steer: float = 0.0 - self._last_result: LineDetectResult | None = None - def compute( - self, frame: np.ndarray, + def _compute_from_result( + self, result: LineDetectResult, ) -> SteeringOutput: - """カメラ画像から2点パシュートで操舵量を計算する + """2点パシュートで操舵量を計算する Args: - frame: グレースケールのカメラ画像 + result: 線検出の結果 Returns: 計算された操舵量 """ - p = self.params - - # 線検出 - result = detect_line(frame, self.image_params) - self._last_result = result - if not result.detected or result.row_centers is None: return SteeringOutput( throttle=0.0, steer=0.0, ) + p = self.params centers = result.row_centers # 有効な点(NaN でない行)を抽出 @@ -116,32 +105,14 @@ steer = p.k_near * near_err + p.k_far * far_err steer = max(-1.0, min(1.0, steer)) - # レートリミッター - delta = steer - self._prev_steer - max_delta = p.max_steer_rate - delta = max(-max_delta, min(max_delta, delta)) - steer = self._prev_steer + delta - # 速度制御(2点の x 差でカーブ度合いを判定) curve = abs(near_x - far_x) / center_x throttle = p.max_throttle - p.speed_k * curve throttle = max(0.0, throttle) - self._prev_steer = steer - return SteeringOutput( throttle=throttle, steer=steer, ) - def reset(self) -> None: - """内部状態をリセットする""" - self._prev_steer = 0.0 - self._last_result = None - reset_valley_tracker() - - @property - def last_detect_result( - self, - ) -> LineDetectResult | None: - """直近の線検出結果を取得する""" - return self._last_result + def _max_steer_rate(self) -> float: + return self.params.max_steer_rate diff --git a/src/common/steering/ts_pd_control.py b/src/common/steering/ts_pd_control.py index 30944c3..e4f2f41 100644 --- a/src/common/steering/ts_pd_control.py +++ b/src/common/steering/ts_pd_control.py @@ -16,8 +16,6 @@ from common.vision.line_detector import ( ImageParams, LineDetectResult, - detect_line, - reset_valley_tracker, ) @@ -53,39 +51,30 @@ params: TsPdParams | None = None, image_params: ImageParams | None = None, ) -> None: + super().__init__(image_params) self.params: TsPdParams = ( params or TsPdParams() ) - self.image_params: ImageParams = ( - image_params or ImageParams() - ) self._prev_error: float = 0.0 self._prev_time: float = 0.0 - self._prev_steer: float = 0.0 - self._last_result: LineDetectResult | None = None - def compute( - self, frame: np.ndarray, + def _compute_from_result( + self, result: LineDetectResult, ) -> SteeringOutput: - """カメラ画像から Theil-Sen PD 制御で操舵量を計算する + """Theil-Sen PD 制御で操舵量を計算する Args: - frame: グレースケールのカメラ画像 + result: 線検出の結果 Returns: 計算された操舵量 """ - p = self.params - - # 線検出 - result = detect_line(frame, self.image_params) - self._last_result = result - if not result.detected or result.row_centers is None: return SteeringOutput( throttle=0.0, steer=0.0, ) + p = self.params centers = result.row_centers # 有効な点(NaN でない行)を抽出 @@ -127,12 +116,6 @@ # 操舵量のクランプ steer = max(-1.0, min(1.0, steer)) - # レートリミッター - delta = steer - self._prev_steer - max_delta = p.max_steer_rate - delta = max(-max_delta, min(max_delta, delta)) - steer = self._prev_steer + delta - # 速度制御(傾きベース) throttle = p.max_throttle - p.speed_k * abs(slope) throttle = max(0.0, throttle) @@ -140,23 +123,16 @@ # 状態の更新 self._prev_error = error self._prev_time = now - self._prev_steer = steer return SteeringOutput( throttle=throttle, steer=steer, ) + def _max_steer_rate(self) -> float: + return self.params.max_steer_rate + def reset(self) -> None: """内部状態をリセットする""" + super().reset() self._prev_error = 0.0 self._prev_time = 0.0 - self._prev_steer = 0.0 - self._last_result = None - reset_valley_tracker() - - @property - def last_detect_result( - self, - ) -> LineDetectResult | None: - """直近の線検出結果を取得する""" - return self._last_result diff --git a/src/common/vision/detectors/valley.py b/src/common/vision/detectors/valley.py index 0368c5b..e29aa60 100644 --- a/src/common/vision/detectors/valley.py +++ b/src/common/vision/detectors/valley.py @@ -99,12 +99,27 @@ self._frames_lost = 0 -_valley_tracker = ValleyTracker() +_valley_tracker: ValleyTracker | None = None + + +def get_valley_tracker() -> ValleyTracker: + """モジュール内のデフォルト ValleyTracker を取得する + + 初回呼び出し時にインスタンスを生成する + + Returns: + ValleyTracker インスタンス + """ + global _valley_tracker + if _valley_tracker is None: + _valley_tracker = ValleyTracker() + return _valley_tracker def reset_valley_tracker() -> None: """谷検出の追跡状態をリセットする""" - _valley_tracker.reset() + if _valley_tracker is not None: + _valley_tracker.reset() def _find_row_valley( @@ -226,9 +241,19 @@ def detect_valley( - frame: np.ndarray, params: ImageParams, + frame: np.ndarray, + params: ImageParams, + tracker: ValleyTracker | None = None, ) -> LineDetectResult: - """案D: 谷検出+追跡型""" + """案D: 谷検出+追跡型 + + Args: + frame: グレースケールのカメラ画像 + params: 二値化パラメータ + tracker: 追跡状態(None でモジュール内デフォルトを使用) + """ + if tracker is None: + tracker = get_valley_tracker() h, w = frame.shape[:2] # 行ごとにガウシアン平滑化するため画像全体をブラー @@ -262,7 +287,7 @@ expected_w = 0.0 # 予測 x 座標 - predicted_x = _valley_tracker.predict_x( + predicted_x = tracker.predict_x( float(y), ) @@ -285,7 +310,7 @@ ) if len(centers_y) < MIN_FIT_ROWS: - coasted = _valley_tracker.coast( + coasted = tracker.coast( params.valley_coast_frames, ) if coasted is not None: @@ -308,7 +333,7 @@ ransac_iter=params.ransac_iter, ) if coeffs is None: - coasted = _valley_tracker.coast( + coasted = tracker.coast( params.valley_coast_frames, ) if coasted is not None: @@ -317,7 +342,7 @@ return no_detection(debug_binary) # EMA で平滑化 - smoothed = _valley_tracker.update( + smoothed = tracker.update( coeffs, params.valley_ema_alpha, ) diff --git a/src/common/vision/line_detector.py b/src/common/vision/line_detector.py index dbc9ec6..018c10d 100644 --- a/src/common/vision/line_detector.py +++ b/src/common/vision/line_detector.py @@ -8,11 +8,17 @@ reset_valley_tracker, DETECT_METHODS """ +from __future__ import annotations + from dataclasses import dataclass +from typing import TYPE_CHECKING import cv2 import numpy as np +if TYPE_CHECKING: + from collections.abc import Callable + from common import config from common.vision.fitting import clean_and_fit @@ -155,6 +161,42 @@ # ── 公開 API ────────────────────────────────────── +def _get_detector_registry() -> dict[ + str, + "Callable[[np.ndarray, ImageParams], LineDetectResult]", +]: + """検出手法の辞書を遅延構築して返す + + Returns: + 手法識別子と検出関数の辞書 + """ + from common.vision.detectors.blackhat import ( + detect_blackhat, + ) + from common.vision.detectors.current import ( + detect_current, + ) + from common.vision.detectors.dual_norm import ( + detect_dual_norm, + ) + from common.vision.detectors.robust import ( + detect_robust, + ) + from common.vision.detectors.valley import ( + detect_valley, + ) + return { + "current": detect_current, + "blackhat": detect_blackhat, + "dual_norm": detect_dual_norm, + "robust": detect_robust, + "valley": detect_valley, + } + + +_detector_registry: dict | None = None + + def detect_line( frame: np.ndarray, params: ImageParams | None = None, @@ -170,35 +212,18 @@ Returns: 線検出の結果 """ + global _detector_registry + if _detector_registry is None: + _detector_registry = _get_detector_registry() + if params is None: params = ImageParams() - method = params.method - if method == "blackhat": - from common.vision.detectors.blackhat import ( - detect_blackhat, - ) - return detect_blackhat(frame, params) - if method == "dual_norm": - from common.vision.detectors.dual_norm import ( - detect_dual_norm, - ) - return detect_dual_norm(frame, params) - if method == "robust": - from common.vision.detectors.robust import ( - detect_robust, - ) - return detect_robust(frame, params) - if method == "valley": - from common.vision.detectors.valley import ( - detect_valley, - ) - return detect_valley(frame, params) - - from common.vision.detectors.current import ( - detect_current, + detector = _detector_registry.get( + params.method, + _detector_registry["current"], ) - return detect_current(frame, params) + return detector(frame, params) def reset_valley_tracker() -> None: diff --git a/src/pi/main.py b/src/pi/main.py index 2c61ecc..06892c2 100644 --- a/src/pi/main.py +++ b/src/pi/main.py @@ -9,6 +9,7 @@ import dataclasses import time +from typing import Any from pi.camera.capture import CameraCapture from pi.comm.zmq_client import PiZmqClient @@ -229,6 +230,39 @@ zmq_client.stop() +def _safe_update_dataclass( + target: Any, + updates: dict[str, Any], +) -> None: + """dataclass のフィールドを型チェック付きで更新する + + Args: + target: 更新対象の dataclass インスタンス + updates: フィールド名と値の辞書 + """ + fields = { + f.name: f.type + for f in dataclasses.fields(target) + } + for key, value in updates.items(): + if key not in fields: + continue + expected = fields[key] + # int フィールドに float が来た場合は変換を許容 + if expected is int and isinstance(value, float): + value = int(value) + elif expected is float and isinstance(value, int): + value = float(value) + elif not isinstance(value, expected): + print( + f"Pi: パラメータ型エラー " + f"{key}: 期待={expected.__name__}, " + f"実際={type(value).__name__}" + ) + continue + setattr(target, key, value) + + def _apply_command( cmd: dict, pd_control: PdControl, @@ -253,42 +287,36 @@ for ctrl in [ pd_control, pursuit_control, ts_pd_control, ]: - current = ctrl.image_params - for key, value in ip.items(): - if hasattr(current, key): - setattr(current, key, value) + _safe_update_dataclass( + ctrl.image_params, ip, + ) # PD パラメータの更新 if "pd_params" in cmd: - sp = cmd["pd_params"] - current = pd_control.params - for key, value in sp.items(): - if hasattr(current, key): - setattr(current, key, value) + _safe_update_dataclass( + pd_control.params, cmd["pd_params"], + ) # Pursuit パラメータの更新 if "pursuit_params" in cmd: - pp = cmd["pursuit_params"] - current = pursuit_control.params - for key, value in pp.items(): - if hasattr(current, key): - setattr(current, key, value) + _safe_update_dataclass( + pursuit_control.params, + cmd["pursuit_params"], + ) # Theil-Sen PD パラメータの更新 if "steering_params" in cmd: - sp = cmd["steering_params"] - current = ts_pd_control.params - for key, value in sp.items(): - if hasattr(current, key): - setattr(current, key, value) + _safe_update_dataclass( + ts_pd_control.params, + cmd["steering_params"], + ) # 復帰パラメータの更新 if "recovery_params" in cmd: - rp = cmd["recovery_params"] - current = recovery.params - for key, value in rp.items(): - if hasattr(current, key): - setattr(current, key, value) + _safe_update_dataclass( + recovery.params, + cmd["recovery_params"], + ) # 十字路分類器の遅延読み込み if (