"""fitting モジュールのテスト"""
import numpy as np
import pytest
from common.vision.fitting import (
MIN_FIT_ROWS,
clean_and_fit,
ransac_polyfit,
theil_sen_fit,
)
class TestTheilSenFit:
"""theil_sen_fit のテスト"""
def test_linear_data(self) -> None:
"""直線データから正しい slope と intercept を復元できる"""
y = np.arange(10, dtype=float)
# x = 2.0 * y + 5.0
x = 2.0 * y + 5.0
slope, intercept = theil_sen_fit(y, x)
assert slope == pytest.approx(2.0, abs=1e-6)
assert intercept == pytest.approx(5.0, abs=1e-6)
def test_with_outlier(self) -> None:
"""外れ値が1つあっても正しい傾きを推定できる"""
y = np.arange(11, dtype=float)
x = 1.0 * y + 3.0
# 1点を大きく外す
x[5] = 100.0
slope, intercept = theil_sen_fit(y, x)
assert slope == pytest.approx(1.0, abs=0.2)
assert intercept == pytest.approx(3.0, abs=1.0)
def test_two_points(self) -> None:
"""2点でも傾きを計算できる"""
y = np.array([0.0, 10.0])
x = np.array([5.0, 15.0])
slope, intercept = theil_sen_fit(y, x)
assert slope == pytest.approx(1.0, abs=1e-6)
assert intercept == pytest.approx(5.0, abs=1e-6)
def test_single_point(self) -> None:
"""1点しかない場合は slope=0, intercept=median(x)"""
y = np.array([5.0])
x = np.array([10.0])
slope, intercept = theil_sen_fit(y, x)
assert slope == 0.0
assert intercept == pytest.approx(10.0)
def test_horizontal_line(self) -> None:
"""水平な線(slope=0)を正しく推定できる"""
y = np.arange(10, dtype=float)
x = np.full(10, 7.0)
slope, intercept = theil_sen_fit(y, x)
assert slope == pytest.approx(0.0, abs=1e-6)
assert intercept == pytest.approx(7.0, abs=1e-6)
class TestRansacPolyfit:
"""ransac_polyfit のテスト"""
def test_clean_quadratic(self) -> None:
"""ノイズなしの2次曲線を正しくフィットできる"""
ys = np.arange(20, dtype=float)
# x = 0.1 * y^2 - 2.0 * y + 10.0
xs = 0.1 * ys**2 - 2.0 * ys + 10.0
coeffs = ransac_polyfit(ys, xs, 2, 50, 5.0)
assert coeffs is not None
assert coeffs[0] == pytest.approx(0.1, abs=0.01)
assert coeffs[1] == pytest.approx(-2.0, abs=0.1)
def test_with_outliers(self) -> None:
"""30% の外れ値があっても正しくフィットできる"""
rng = np.random.default_rng(42)
ys = np.arange(30, dtype=float)
xs = 0.05 * ys**2 + 3.0
# 30% を大きく外す
outlier_idx = rng.choice(30, 9, replace=False)
xs[outlier_idx] += rng.uniform(50, 100, 9)
coeffs = ransac_polyfit(ys, xs, 2, 100, 5.0)
assert coeffs is not None
assert coeffs[0] == pytest.approx(0.05, abs=0.02)
def test_too_few_points(self) -> None:
"""点が不足している場合は None を返す"""
ys = np.array([1.0, 2.0])
xs = np.array([3.0, 4.0])
assert ransac_polyfit(ys, xs, 2, 50, 5.0) is None
class TestCleanAndFit:
"""clean_and_fit のテスト"""
def test_basic_fit(self) -> None:
"""正常なデータでフィッティングできる"""
ys = np.arange(15, dtype=float)
xs = 0.5 * ys + 10.0
coeffs = clean_and_fit(
ys, xs,
median_ksize=0,
neighbor_thresh=0.0,
)
assert coeffs is not None
# 2次の係数はほぼ 0,1次はほぼ 0.5
assert coeffs[-2] == pytest.approx(0.5, abs=0.1)
def test_too_few_points(self) -> None:
"""MIN_FIT_ROWS 未満のデータは None を返す"""
n = MIN_FIT_ROWS - 1
ys = np.arange(n, dtype=float)
xs = np.arange(n, dtype=float)
assert clean_and_fit(
ys, xs, median_ksize=0, neighbor_thresh=0.0,
) is None
def test_neighbor_filter_removes_outlier(self) -> None:
"""近傍フィルタが外れ値を除去できる"""
ys = np.arange(20, dtype=float)
xs = np.full(20, 15.0)
xs[10] = 100.0 # 大きな外れ値
coeffs = clean_and_fit(
ys, xs,
median_ksize=0,
neighbor_thresh=5.0,
)
assert coeffs is not None
# 外れ値除去後,x ≈ 15.0 の直線になる
poly = np.poly1d(coeffs)
assert poly(10) == pytest.approx(15.0, abs=2.0)
def test_residual_removal(self) -> None:
"""残差除去が外れ値を取り除ける"""
ys = np.arange(20, dtype=float)
xs = 1.0 * ys + 5.0
xs[3] = 80.0
xs[17] = -50.0
coeffs = clean_and_fit(
ys, xs,
median_ksize=0,
neighbor_thresh=0.0,
residual_thresh=10.0,
)
assert coeffs is not None
poly = np.poly1d(coeffs)
assert poly(10) == pytest.approx(15.0, abs=3.0)