Newer
Older
RobotCar / tests / test_fitting.py
"""fitting モジュールのテスト"""

import numpy as np
import pytest

from pc.vision.fitting import (
    MIN_FIT_ROWS,
    clean_and_fit,
    ransac_polyfit,
    theil_sen_fit,
)


class TestTheilSenFit:
    """theil_sen_fit のテスト"""

    def test_linear_data(self) -> None:
        """直線データから正しい slope と intercept を復元できる"""
        y = np.arange(10, dtype=float)
        # x = 2.0 * y + 5.0
        x = 2.0 * y + 5.0
        slope, intercept = theil_sen_fit(y, x)
        assert slope == pytest.approx(2.0, abs=1e-6)
        assert intercept == pytest.approx(5.0, abs=1e-6)

    def test_with_outlier(self) -> None:
        """外れ値が1つあっても正しい傾きを推定できる"""
        y = np.arange(11, dtype=float)
        x = 1.0 * y + 3.0
        # 1点を大きく外す
        x[5] = 100.0
        slope, intercept = theil_sen_fit(y, x)
        assert slope == pytest.approx(1.0, abs=0.2)
        assert intercept == pytest.approx(3.0, abs=1.0)

    def test_two_points(self) -> None:
        """2点でも傾きを計算できる"""
        y = np.array([0.0, 10.0])
        x = np.array([5.0, 15.0])
        slope, intercept = theil_sen_fit(y, x)
        assert slope == pytest.approx(1.0, abs=1e-6)
        assert intercept == pytest.approx(5.0, abs=1e-6)

    def test_single_point(self) -> None:
        """1点しかない場合は slope=0, intercept=median(x)"""
        y = np.array([5.0])
        x = np.array([10.0])
        slope, intercept = theil_sen_fit(y, x)
        assert slope == 0.0
        assert intercept == pytest.approx(10.0)

    def test_horizontal_line(self) -> None:
        """水平な線(slope=0)を正しく推定できる"""
        y = np.arange(10, dtype=float)
        x = np.full(10, 7.0)
        slope, intercept = theil_sen_fit(y, x)
        assert slope == pytest.approx(0.0, abs=1e-6)
        assert intercept == pytest.approx(7.0, abs=1e-6)


class TestRansacPolyfit:
    """ransac_polyfit のテスト"""

    def test_clean_quadratic(self) -> None:
        """ノイズなしの2次曲線を正しくフィットできる"""
        ys = np.arange(20, dtype=float)
        # x = 0.1 * y^2 - 2.0 * y + 10.0
        xs = 0.1 * ys**2 - 2.0 * ys + 10.0
        coeffs = ransac_polyfit(ys, xs, 2, 50, 5.0)
        assert coeffs is not None
        assert coeffs[0] == pytest.approx(0.1, abs=0.01)
        assert coeffs[1] == pytest.approx(-2.0, abs=0.1)

    def test_with_outliers(self) -> None:
        """30% の外れ値があっても正しくフィットできる"""
        rng = np.random.default_rng(42)
        ys = np.arange(30, dtype=float)
        xs = 0.05 * ys**2 + 3.0
        # 30% を大きく外す
        outlier_idx = rng.choice(30, 9, replace=False)
        xs[outlier_idx] += rng.uniform(50, 100, 9)
        coeffs = ransac_polyfit(ys, xs, 2, 100, 5.0)
        assert coeffs is not None
        assert coeffs[0] == pytest.approx(0.05, abs=0.02)

    def test_too_few_points(self) -> None:
        """点が不足している場合は None を返す"""
        ys = np.array([1.0, 2.0])
        xs = np.array([3.0, 4.0])
        assert ransac_polyfit(ys, xs, 2, 50, 5.0) is None


class TestCleanAndFit:
    """clean_and_fit のテスト"""

    def test_basic_fit(self) -> None:
        """正常なデータでフィッティングできる"""
        ys = np.arange(15, dtype=float)
        xs = 0.5 * ys + 10.0
        coeffs = clean_and_fit(
            ys, xs,
            median_ksize=0,
            neighbor_thresh=0.0,
        )
        assert coeffs is not None
        # 2次の係数はほぼ 0,1次はほぼ 0.5
        assert coeffs[-2] == pytest.approx(0.5, abs=0.1)

    def test_too_few_points(self) -> None:
        """MIN_FIT_ROWS 未満のデータは None を返す"""
        n = MIN_FIT_ROWS - 1
        ys = np.arange(n, dtype=float)
        xs = np.arange(n, dtype=float)
        assert clean_and_fit(
            ys, xs, median_ksize=0, neighbor_thresh=0.0,
        ) is None

    def test_neighbor_filter_removes_outlier(self) -> None:
        """近傍フィルタが外れ値を除去できる"""
        ys = np.arange(20, dtype=float)
        xs = np.full(20, 15.0)
        xs[10] = 100.0  # 大きな外れ値
        coeffs = clean_and_fit(
            ys, xs,
            median_ksize=0,
            neighbor_thresh=5.0,
        )
        assert coeffs is not None
        # 外れ値除去後,x ≈ 15.0 の直線になる
        poly = np.poly1d(coeffs)
        assert poly(10) == pytest.approx(15.0, abs=2.0)

    def test_residual_removal(self) -> None:
        """残差除去が外れ値を取り除ける"""
        ys = np.arange(20, dtype=float)
        xs = 1.0 * ys + 5.0
        xs[3] = 80.0
        xs[17] = -50.0
        coeffs = clean_and_fit(
            ys, xs,
            median_ksize=0,
            neighbor_thresh=0.0,
            residual_thresh=10.0,
        )
        assert coeffs is not None
        poly = np.poly1d(coeffs)
        assert poly(10) == pytest.approx(15.0, abs=3.0)