Newer
Older
RobotCar / src / pc / steering / pursuit_control.py
"""
pursuit_control
2点パシュートによる操舵量計算モジュール
行中心点に Theil-Sen 直線近似を適用し,外れ値に強い操舵量を算出する
"""

from dataclasses import dataclass

import numpy as np

from common import config
from pc.steering.base import SteeringBase, SteeringOutput
from pc.vision.fitting import theil_sen_fit
from pc.vision.line_detector import (
    ImageParams,
    detect_line,
    reset_valley_tracker,
)


@dataclass
class PursuitParams:
    """2点パシュート制御のパラメータ

    Attributes:
        near_ratio: 近い目標点の位置(0.0=上端,1.0=下端)
        far_ratio: 遠い目標点の位置(0.0=上端,1.0=下端)
        k_near: 近い目標点の操舵ゲイン
        k_far: 遠い目標点の操舵ゲイン
        max_steer_rate: 1フレームあたりの最大操舵変化量
        max_throttle: 直線での最大速度
        speed_k: カーブ減速係数(2点の差に対する係数)
    """
    near_ratio: float = 0.8
    far_ratio: float = 0.3
    k_near: float = 0.5
    k_far: float = 0.3
    max_steer_rate: float = 0.1
    max_throttle: float = 0.4
    speed_k: float = 2.0


class PursuitControl(SteeringBase):
    """2点パシュートによる操舵量計算クラス

    行中心点から Theil-Sen 直線近似を行い,
    直線上の近い点と遠い点の偏差から操舵量を計算する
    """

    def __init__(
        self,
        params: PursuitParams | None = None,
        image_params: ImageParams | None = None,
    ) -> None:
        self.params: PursuitParams = (
            params or PursuitParams()
        )
        self.image_params: ImageParams = (
            image_params or ImageParams()
        )
        self._prev_steer: float = 0.0
        self._last_result = None
        self._last_pursuit_points: (
            tuple[tuple[float, float], tuple[float, float]]
            | None
        ) = None

    def compute(
        self, frame: np.ndarray,
    ) -> SteeringOutput:
        """カメラ画像から2点パシュートで操舵量を計算する

        Args:
            frame: グレースケールのカメラ画像

        Returns:
            計算された操舵量
        """
        p = self.params

        # 線検出
        result = detect_line(frame, self.image_params)
        self._last_result = result

        if not result.detected or result.row_centers is None:
            self._last_pursuit_points = None
            return SteeringOutput(
                throttle=0.0, steer=0.0,
            )

        centers = result.row_centers

        # 有効な点(NaN でない行)を抽出
        valid = ~np.isnan(centers)
        ys = np.where(valid)[0].astype(float)
        xs = centers[valid]

        if len(ys) < 2:
            self._last_pursuit_points = None
            return SteeringOutput(
                throttle=0.0, steer=0.0,
            )

        # Theil-Sen 直線近似
        slope, intercept = theil_sen_fit(ys, xs)

        center_x = config.FRAME_WIDTH / 2.0
        h = len(centers)

        # 直線上の 2 点の x 座標を取得
        near_y = h * p.near_ratio
        far_y = h * p.far_ratio
        near_x = slope * near_y + intercept
        far_x = slope * far_y + intercept

        # 目標点を保持(デバッグ表示用)
        self._last_pursuit_points = (
            (near_x, near_y),
            (far_x, far_y),
        )

        # 各点の偏差(正: 線が左にある → 右に曲がる)
        near_err = (center_x - near_x) / center_x
        far_err = (center_x - far_x) / center_x

        # 操舵量
        steer = p.k_near * near_err + p.k_far * far_err
        steer = max(-1.0, min(1.0, steer))

        # レートリミッター
        delta = steer - self._prev_steer
        max_delta = p.max_steer_rate
        delta = max(-max_delta, min(max_delta, delta))
        steer = self._prev_steer + delta

        # 速度制御(2点の x 差でカーブ度合いを判定)
        curve = abs(near_x - far_x) / center_x
        throttle = p.max_throttle - p.speed_k * curve
        throttle = max(0.0, throttle)

        self._prev_steer = steer

        return SteeringOutput(
            throttle=throttle, steer=steer,
        )

    def reset(self) -> None:
        """内部状態をリセットする"""
        self._prev_steer = 0.0
        self._last_result = None
        self._last_pursuit_points = None
        reset_valley_tracker()

    @property
    def last_detect_result(self):
        """直近の線検出結果を取得する"""
        return self._last_result

    @property
    def last_pursuit_points(
        self,
    ) -> (
        tuple[tuple[float, float], tuple[float, float]]
        | None
    ):
        """直近の2点パシュート目標点を取得する

        Returns:
            ((near_x, near_y), (far_x, far_y)) または None
        """
        return self._last_pursuit_points