"""
pursuit_control
2点パシュートによる操舵量計算モジュール
行中心点に Theil-Sen 直線近似を適用し，外れ値に強い操舵量を算出する
"""

from dataclasses import dataclass

import numpy as np

from common import config
from common.steering.base import SteeringBase, SteeringOutput
from common.vision.fitting import theil_sen_fit
from common.vision.line_detector import (
    ImageParams,
    LineDetectResult,
)


@dataclass
class PursuitParams:
    """2点パシュート制御のパラメータ

    Attributes:
        near_ratio: 近い目標点の位置（0.0=上端，1.0=下端）
        far_ratio: 遠い目標点の位置（0.0=上端，1.0=下端）
        k_near: 近い目標点の操舵ゲイン
        k_far: 遠い目標点の操舵ゲイン
        max_steer_rate: 1フレームあたりの最大操舵変化量
        max_throttle: 直線での最大速度
        speed_k: カーブ減速係数（2点の差に対する係数）
    """
    near_ratio: float = 0.8
    far_ratio: float = 0.3
    k_near: float = 0.5
    k_far: float = 0.3
    max_steer_rate: float = 0.1
    max_throttle: float = 0.4
    speed_k: float = 2.0


class PursuitControl(SteeringBase):
    """2点パシュートによる操舵量計算クラス

    行中心点から Theil-Sen 直線近似を行い，
    直線上の近い点と遠い点の偏差から操舵量を計算する
    """

    def __init__(
        self,
        params: PursuitParams | None = None,
        image_params: ImageParams | None = None,
    ) -> None:
        super().__init__(image_params)
        self.params: PursuitParams = (
            params or PursuitParams()
        )

    def _compute_from_result(
        self, result: LineDetectResult,
    ) -> SteeringOutput:
        """2点パシュートで操舵量を計算する

        Args:
            result: 線検出の結果

        Returns:
            計算された操舵量
        """
        if not result.detected or result.row_centers is None:
            return SteeringOutput(
                throttle=0.0, steer=0.0,
            )

        p = self.params
        centers = result.row_centers

        # 有効な点（NaN でない行）を抽出
        valid = ~np.isnan(centers)
        ys = np.where(valid)[0].astype(float)
        xs = centers[valid]

        if len(ys) < 2:
            return SteeringOutput(
                throttle=0.0, steer=0.0,
            )

        # Theil-Sen 直線近似
        slope, intercept = theil_sen_fit(ys, xs)

        center_x = config.FRAME_WIDTH / 2.0
        h = len(centers)

        # 直線上の 2 点の x 座標を取得
        near_y = h * p.near_ratio
        far_y = h * p.far_ratio
        near_x = slope * near_y + intercept
        far_x = slope * far_y + intercept

        # 各点の偏差（正: 線が左にある → 右に曲がる）
        near_err = (center_x - near_x) / center_x
        far_err = (center_x - far_x) / center_x

        # 操舵量
        steer = p.k_near * near_err + p.k_far * far_err
        steer = max(-1.0, min(1.0, steer))

        # 速度制御（2点の x 差でカーブ度合いを判定）
        curve = abs(near_x - far_x) / center_x
        throttle = p.max_throttle - p.speed_k * curve
        throttle = max(0.0, throttle)

        return SteeringOutput(
            throttle=throttle, steer=steer,
        )

    def _max_steer_rate(self) -> float:
        return self.params.max_steer_rate
