"""操舵量計算モジュールのテスト"""

import numpy as np
import pytest

from common import config
from common.steering.base import SteeringOutput
from common.steering.pd_control import PdControl, PdParams
from common.steering.pursuit_control import (
    PursuitControl,
    PursuitParams,
)
from common.steering.ts_pd_control import (
    TsPdControl,
    TsPdParams,
)
from common.vision.line_detector import ImageParams


@pytest.fixture()
def _center_line_params() -> ImageParams:
    """Line-detection parameters adjusted for the synthetic test images.

    Uses the ``"current"`` detection method with small CLAHE grid, blur
    and morphology kernel sizes (presumably so that the thin synthetic
    lines are not erased — confirm against the detector implementation).
    """
    tuned = ImageParams(
        method="current",
        clahe_grid=2,
        blur_size=3,
        open_size=1,
        close_width=3,
        close_height=1,
    )
    return tuned


class TestSteeringBase:
    """Tests for the shared SteeringBase logic, exercised via PdControl."""

    def test_compute_returns_steering_output(
        self,
        straight_line_image: np.ndarray,
        _center_line_params: ImageParams,
    ) -> None:
        """compute returns a SteeringOutput instance."""
        ctrl = PdControl(
            image_params=_center_line_params,
        )
        output = ctrl.compute(straight_line_image)
        assert isinstance(output, SteeringOutput)

    def test_steer_within_range(
        self,
        straight_line_image: np.ndarray,
        _center_line_params: ImageParams,
    ) -> None:
        """steer stays within [-1.0, +1.0] over repeated calls."""
        ctrl = PdControl(
            image_params=_center_line_params,
        )
        for _ in range(10):
            output = ctrl.compute(straight_line_image)
            assert -1.0 <= output.steer <= 1.0

    def test_throttle_non_negative(
        self,
        straight_line_image: np.ndarray,
        _center_line_params: ImageParams,
    ) -> None:
        """throttle is never negative."""
        ctrl = PdControl(
            image_params=_center_line_params,
        )
        output = ctrl.compute(straight_line_image)
        assert output.throttle >= 0.0

    def test_rate_limiter_clamps_steer_change(
        self,
        _center_line_params: ImageParams,
    ) -> None:
        """The rate limiter bounds the change of steer between calls.

        The bound is checked over several consecutive frames, not only
        the first one. A limiter that clamped only the initial step away
        from 0 would fail this test.
        """
        max_rate = 0.05
        params = PdParams(
            kp=5.0, kh=5.0, kd=0.0,
            max_steer_rate=max_rate,
        )
        ctrl = PdControl(
            params=params,
            image_params=_center_line_params,
        )

        # Image with a line strongly offset to the left (columns 2-4).
        # With the high gains above, this is intended to request a
        # steer well beyond max_rate.
        h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH
        img = np.full((h, w), 200, dtype=np.uint8)
        img[:, 2:5] = 30

        # The first call starts from prev_steer = 0. Each subsequent
        # call may move at most max_rate away from the previous output.
        prev_steer = 0.0
        for _ in range(5):
            output = ctrl.compute(img)
            assert abs(output.steer - prev_steer) <= max_rate + 1e-9
            prev_steer = output.steer

    def test_no_line_returns_zero(
        self,
        blank_image: np.ndarray,
    ) -> None:
        """With no detectable line, throttle=0 and steer=0."""
        ctrl = PdControl()
        output = ctrl.compute(blank_image)
        assert output.throttle == 0.0
        assert output.steer == 0.0

    def test_last_detect_result_updated(
        self,
        straight_line_image: np.ndarray,
        _center_line_params: ImageParams,
    ) -> None:
        """last_detect_result is None initially and set after compute."""
        ctrl = PdControl(
            image_params=_center_line_params,
        )
        assert ctrl.last_detect_result is None
        ctrl.compute(straight_line_image)
        assert ctrl.last_detect_result is not None

    def test_reset_clears_state(
        self,
        straight_line_image: np.ndarray,
        _center_line_params: ImageParams,
    ) -> None:
        """reset clears the detection result and the previous steer."""
        ctrl = PdControl(
            image_params=_center_line_params,
        )
        ctrl.compute(straight_line_image)
        ctrl.reset()
        assert ctrl.last_detect_result is None
        assert ctrl._prev_steer == 0.0


class TestPdControl:
    """Tests for the PD steering controller."""

    def test_center_line_small_steer(
        self,
        straight_line_image: np.ndarray,
        _center_line_params: ImageParams,
    ) -> None:
        """A centered line yields only a small steer."""
        ctrl = PdControl(
            image_params=_center_line_params,
        )
        output = ctrl.compute(straight_line_image)
        assert abs(output.steer) < 0.3

    def test_left_line_steers_positive(
        self,
        _center_line_params: ImageParams,
    ) -> None:
        """A line on the left produces a positive steer."""
        h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH
        img = np.full((h, w), 200, dtype=np.uint8)
        img[:, 5:8] = 30  # line near the left edge

        ctrl = PdControl(
            params=PdParams(
                kp=1.0, kh=0.0, kd=0.0,
                max_steer_rate=1.0,
            ),
            image_params=_center_line_params,
        )
        output = ctrl.compute(img)
        assert output.steer > 0.0

    def test_speed_decreases_with_curvature(
        self,
        straight_line_image: np.ndarray,
        _center_line_params: ImageParams,
    ) -> None:
        """Throttle on a curved line is no higher than on a straight one.

        Checking only the ``max_throttle`` upper bound holds regardless
        of curvature. Comparing against a straight line is what exercises
        the curvature-dependent slowdown.
        """
        max_throttle = 0.5

        def make_ctrl() -> PdControl:
            # A fresh controller per measurement, so that no internal
            # state (e.g. the previous steer) carries over between the
            # two images.
            return PdControl(
                params=PdParams(
                    max_throttle=max_throttle, speed_k=0.3,
                ),
                image_params=_center_line_params,
            )

        # Build a parabola-shaped (curved) line. x is clamped into the
        # frame, so the ends of the parabola run along the right border.
        h, w = config.FRAME_HEIGHT, config.FRAME_WIDTH
        curved = np.full((h, w), 200, dtype=np.uint8)
        for y in range(h):
            x = int(w / 2 + 5 * (y / h - 0.5) ** 2 * w)
            x = max(0, min(w - 1, x))
            curved[y, max(0, x - 1):min(w, x + 2)] = 30

        curved_out = make_ctrl().compute(curved)
        straight_out = make_ctrl().compute(straight_line_image)

        assert 0.0 <= curved_out.throttle <= max_throttle
        assert curved_out.throttle <= straight_out.throttle


class TestPursuitControl:
    """Tests for the two-point pursuit steering controller."""

    def test_center_line_small_steer(
        self,
        straight_line_image: np.ndarray,
        _center_line_params: ImageParams,
    ) -> None:
        """A centered line yields only a small steer."""
        result = PursuitControl(
            image_params=_center_line_params,
        ).compute(straight_line_image)
        assert abs(result.steer) < 0.3

    def test_no_line_returns_zero(
        self,
        blank_image: np.ndarray,
    ) -> None:
        """With no detectable line the vehicle stops (throttle=0, steer=0)."""
        result = PursuitControl().compute(blank_image)
        assert result.throttle == 0.0
        assert result.steer == 0.0


class TestTsPdControl:
    """Tests for the Theil-Sen based PD steering controller."""

    def test_center_line_small_steer(
        self,
        straight_line_image: np.ndarray,
        _center_line_params: ImageParams,
    ) -> None:
        """A centered line yields only a small steer."""
        result = TsPdControl(
            image_params=_center_line_params,
        ).compute(straight_line_image)
        assert abs(result.steer) < 0.3

    def test_no_line_returns_zero(
        self,
        blank_image: np.ndarray,
    ) -> None:
        """With no detectable line the vehicle stops (throttle=0, steer=0)."""
        result = TsPdControl().compute(blank_image)
        assert result.throttle == 0.0
        assert result.steer == 0.0

    def test_reset_clears_derivative_state(
        self,
        straight_line_image: np.ndarray,
        _center_line_params: ImageParams,
    ) -> None:
        """reset zeroes the state used by the derivative term."""
        controller = TsPdControl(
            image_params=_center_line_params,
        )
        controller.compute(straight_line_image)
        controller.reset()
        # Both the previous error and the previous timestamp are cleared.
        assert controller._prev_error == 0.0
        assert controller._prev_time == 0.0


class TestSafeUpdateDataclass:
    """Tests for ``_safe_update_dataclass``.

    ``pi.main`` depends on picamera2 and therefore cannot be imported
    here. The same logic is reproduced below and checked against the
    dataclasses from the ``common`` package. NOTE(review): the copy must
    be kept in sync with ``pi.main`` by hand.
    """

    @staticmethod
    def _safe_update_dataclass(
        target: object,
        updates: dict,
    ) -> None:
        """Mirror of ``pi.main._safe_update_dataclass``.

        Applies ``updates`` to the dataclass instance ``target`` in place:

        - keys that are not fields of ``target`` are ignored;
        - a float written to an int field is converted with ``int()``;
        - an int written to a float field is converted with ``float()``;
        - any other value whose type does not match the current field
          value's type is ignored.
        """
        import dataclasses
        known = {field.name for field in dataclasses.fields(target)}
        for name, raw in updates.items():
            if name not in known:
                continue
            wanted = type(getattr(target, name))
            if wanted is int and isinstance(raw, float):
                coerced = int(raw)
            elif wanted is float and isinstance(raw, int):
                coerced = float(raw)
            elif isinstance(raw, wanted):
                coerced = raw
            else:
                continue
            setattr(target, name, coerced)

    def test_updates_valid_fields(self) -> None:
        """A field is updated when the value has the correct type."""
        params = PdParams()
        self._safe_update_dataclass(params, {"kp": 2.0})
        assert params.kp == 2.0

    def test_ignores_unknown_fields(self) -> None:
        """Keys that are not fields are ignored."""
        params = PdParams()
        kp_before = params.kp
        self._safe_update_dataclass(params, {"unknown_field": 999})
        assert params.kp == kp_before

    def test_rejects_wrong_type(self) -> None:
        """A value of a mismatched type does not update the field."""
        params = PdParams()
        kp_before = params.kp
        self._safe_update_dataclass(params, {"kp": "not_a_number"})
        assert params.kp == kp_before

    def test_int_to_float_conversion(self) -> None:
        """An int written to a float field is converted to float."""
        params = PdParams()
        self._safe_update_dataclass(params, {"kp": 2})
        assert params.kp == 2.0
        assert isinstance(params.kp, float)

    def test_float_to_int_conversion(self) -> None:
        """A float written to an int field is converted to int."""
        params = ImageParams()
        self._safe_update_dataclass(params, {"binary_thresh": 100.0})
        assert params.binary_thresh == 100
        assert isinstance(params.binary_thresh, int)
