Newer
Older
RobotCar / src / pc / vision / line_detector.py
"""
line_detector
カメラ画像から黒線の位置を検出するモジュール
複数の検出手法を切り替えて使用できる
"""

from dataclasses import dataclass

import cv2
import numpy as np

from common import config

# 検出領域の y 範囲(画像全体)
DETECT_Y_START: int = 0
DETECT_Y_END: int = config.FRAME_HEIGHT

# フィッティングに必要な最小数
MIN_FIT_PIXELS: int = 50
MIN_FIT_ROWS: int = 10

# 検出手法の定義(キー: 識別子,値: 表示名)
DETECT_METHODS: dict[str, str] = {
    "current": "現行(CLAHE + 固定閾値)",
    "blackhat": "案A(Black-hat 中心)",
    "dual_norm": "案B(二重正規化)",
    "robust": "案C(最高ロバスト)",
}


@dataclass
class ImageParams:
    """画像処理パラメータ

    Attributes:
        method: 検出手法の識別子
        clahe_clip: CLAHE のコントラスト増幅上限
        clahe_grid: CLAHE の局所領域分割数
        blur_size: ガウシアンブラーのカーネルサイズ(奇数)
        binary_thresh: 二値化の閾値
        open_size: オープニングのカーネルサイズ
        close_width: クロージングの横幅
        close_height: クロージングの高さ
        blackhat_ksize: Black-hat のカーネルサイズ
        bg_blur_ksize: 背景除算のブラーカーネルサイズ
        adaptive_block: 適応的閾値のブロックサイズ
        adaptive_c: 適応的閾値の定数 C
        iso_close_size: 等方クロージングのカーネルサイズ
        dist_thresh: 距離変換の閾値
        min_line_width: 行ごと中心抽出の最小線幅
        ransac_thresh: RANSAC の外れ値判定閾値
        ransac_iter: RANSAC の反復回数
    """

    # 検出手法
    method: str = "current"

    # 現行手法パラメータ
    clahe_clip: float = 2.0
    clahe_grid: int = 8
    blur_size: int = 5
    binary_thresh: int = 80
    open_size: int = 5
    close_width: int = 25
    close_height: int = 3

    # 案A/C: Black-hat
    blackhat_ksize: int = 45

    # 案B: 背景除算
    bg_blur_ksize: int = 101

    # 案B/C: 適応的閾値
    adaptive_block: int = 51
    adaptive_c: int = 10

    # 案A/B/C: 後処理
    iso_close_size: int = 15
    dist_thresh: float = 3.0
    min_line_width: int = 3

    # 案C: RANSAC
    ransac_thresh: float = 5.0
    ransac_iter: int = 50


@dataclass
class LineDetectResult:
    """線検出の結果を格納するデータクラス

    Attributes:
        detected: 線が検出できたか
        position_error: 画像下端での位置偏差(-1.0~+1.0)
        heading: 線の傾き(dx/dy,画像下端での値)
        curvature: 線の曲率(d²x/dy²)
        poly_coeffs: 多項式の係数(描画用,未検出時は None)
        binary_image: 二値化後の画像(デバッグ用)
    """

    detected: bool
    position_error: float
    heading: float
    curvature: float
    poly_coeffs: np.ndarray | None
    binary_image: np.ndarray | None


def detect_line(
    frame: np.ndarray,
    params: ImageParams | None = None,
) -> LineDetectResult:
    """画像から黒線の位置を検出する

    params.method に応じて検出手法を切り替える

    Args:
        frame: BGR 形式のカメラ画像
        params: 画像処理パラメータ(None でデフォルト)

    Returns:
        線検出の結果
    """
    if params is None:
        params = ImageParams()

    method = params.method
    if method == "blackhat":
        return _detect_blackhat(frame, params)
    if method == "dual_norm":
        return _detect_dual_norm(frame, params)
    if method == "robust":
        return _detect_robust(frame, params)
    return _detect_current(frame, params)


# ── 検出手法の実装 ─────────────────────────────


def _detect_current(
    frame: np.ndarray, params: ImageParams,
) -> LineDetectResult:
    """現行手法: CLAHE + 固定閾値 + 全ピクセルフィッティング"""
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # CLAHE でコントラスト強調
    clahe = cv2.createCLAHE(
        clipLimit=params.clahe_clip,
        tileGridSize=(
            params.clahe_grid,
            params.clahe_grid,
        ),
    )
    enhanced = clahe.apply(gray)

    # ガウシアンブラー
    blur_k = params.blur_size | 1
    blurred = cv2.GaussianBlur(
        enhanced, (blur_k, blur_k), 0,
    )

    # 固定閾値で二値化(黒線を白に反転)
    _, binary = cv2.threshold(
        blurred, params.binary_thresh, 255,
        cv2.THRESH_BINARY_INV,
    )

    # オープニング(孤立ノイズ除去)
    if params.open_size >= 3:
        open_k = params.open_size | 1
        open_kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (open_k, open_k),
        )
        binary = cv2.morphologyEx(
            binary, cv2.MORPH_OPEN, open_kernel,
        )

    # 横方向クロージング(途切れ補間)
    if params.close_width >= 3:
        close_h = max(params.close_height | 1, 1)
        close_kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (params.close_width, close_h),
        )
        binary = cv2.morphologyEx(
            binary, cv2.MORPH_CLOSE, close_kernel,
        )

    # 全ピクセルフィッティング(従来方式)
    return _fit_all_pixels(binary)


def _detect_blackhat(
    frame: np.ndarray, params: ImageParams,
) -> LineDetectResult:
    """案A: Black-hat 中心型

    Black-hat 変換で背景より暗い構造を直接抽出し,
    固定閾値 + 距離変換 + 行ごと中心抽出で検出する
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Black-hat 変換(暗い構造の抽出)
    bh_k = params.blackhat_ksize | 1
    bh_kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (bh_k, bh_k),
    )
    blackhat = cv2.morphologyEx(
        gray, cv2.MORPH_BLACKHAT, bh_kernel,
    )

    # ガウシアンブラー
    blur_k = params.blur_size | 1
    blurred = cv2.GaussianBlur(
        blackhat, (blur_k, blur_k), 0,
    )

    # 固定閾値(Black-hat 後は線が白)
    _, binary = cv2.threshold(
        blurred, params.binary_thresh, 255,
        cv2.THRESH_BINARY,
    )

    # 等方クロージング + 距離変換マスク
    binary = _apply_iso_closing(
        binary, params.iso_close_size,
    )
    binary = _apply_dist_mask(
        binary, params.dist_thresh,
    )

    # 行ごと中心抽出 + フィッティング
    return _fit_row_centers(
        binary, params.min_line_width,
    )


def _detect_dual_norm(
    frame: np.ndarray, params: ImageParams,
) -> LineDetectResult:
    """案B: 二重正規化型

    背景除算で照明勾配を除去し,
    適応的閾値で局所ムラにも対応する二重防壁構成
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # 背景除算正規化
    bg_k = params.bg_blur_ksize | 1
    bg = cv2.GaussianBlur(
        gray, (bg_k, bg_k), 0,
    )
    normalized = (
        gray.astype(np.float32) * 255.0
        / (bg.astype(np.float32) + 1.0)
    )
    normalized = np.clip(
        normalized, 0, 255,
    ).astype(np.uint8)

    # 適応的閾値(ガウシアン,BINARY_INV)
    block = max(params.adaptive_block | 1, 3)
    binary = cv2.adaptiveThreshold(
        normalized, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        block, params.adaptive_c,
    )

    # 等方クロージング + 距離変換マスク
    binary = _apply_iso_closing(
        binary, params.iso_close_size,
    )
    binary = _apply_dist_mask(
        binary, params.dist_thresh,
    )

    # 行ごと中心抽出 + フィッティング
    return _fit_row_centers(
        binary, params.min_line_width,
    )


def _detect_robust(
    frame: np.ndarray, params: ImageParams,
) -> LineDetectResult:
    """案C: 最高ロバスト型

    Black-hat + 適応的閾値の二重正規化に加え,
    RANSAC で外れ値を除去する最もロバストな構成
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Black-hat 変換
    bh_k = params.blackhat_ksize | 1
    bh_kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (bh_k, bh_k),
    )
    blackhat = cv2.morphologyEx(
        gray, cv2.MORPH_BLACKHAT, bh_kernel,
    )

    # 適応的閾値(BINARY: Black-hat 後は線が白)
    block = max(params.adaptive_block | 1, 3)
    binary = cv2.adaptiveThreshold(
        blackhat, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        block, -params.adaptive_c,
    )

    # 等方クロージング + 距離変換マスク
    binary = _apply_iso_closing(
        binary, params.iso_close_size,
    )
    binary = _apply_dist_mask(
        binary, params.dist_thresh,
    )

    # 行ごと中央値抽出 + RANSAC フィッティング
    return _fit_row_centers(
        binary, params.min_line_width,
        use_median=True,
        ransac_thresh=params.ransac_thresh,
        ransac_iter=params.ransac_iter,
    )


# ── 共通処理 ───────────────────────────────────


def _apply_iso_closing(
    binary: np.ndarray, size: int,
) -> np.ndarray:
    """等方クロージングで穴を埋める

    Args:
        binary: 二値画像
        size: カーネルサイズ

    Returns:
        クロージング後の二値画像
    """
    if size < 3:
        return binary
    k = size | 1
    kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (k, k),
    )
    return cv2.morphologyEx(
        binary, cv2.MORPH_CLOSE, kernel,
    )


def _apply_dist_mask(
    binary: np.ndarray, thresh: float,
) -> np.ndarray:
    """距離変換で中心部のみを残す

    Args:
        binary: 二値画像
        thresh: 距離の閾値(ピクセル)

    Returns:
        中心部のみの二値画像
    """
    if thresh <= 0:
        return binary
    dist = cv2.distanceTransform(
        binary, cv2.DIST_L2, 5,
    )
    _, mask = cv2.threshold(
        dist, thresh, 255, cv2.THRESH_BINARY,
    )
    return mask.astype(np.uint8)


def _fit_all_pixels(
    binary: np.ndarray,
) -> LineDetectResult:
    """全白ピクセルに多項式をフィッティングする

    従来方式.全ピクセルを等しく扱うため,
    陰で幅が広がった行がフィッティングを支配する弱点がある

    Args:
        binary: 二値画像

    Returns:
        線検出の結果
    """
    region = binary[DETECT_Y_START:DETECT_Y_END, :]
    ys_local, xs = np.where(region > 0)

    if len(xs) < MIN_FIT_PIXELS:
        return _no_detection(binary)

    ys = ys_local + DETECT_Y_START
    coeffs = np.polyfit(ys, xs, 2)
    return _build_result(coeffs, binary)


def _fit_row_centers(
    binary: np.ndarray,
    min_width: int,
    use_median: bool = False,
    ransac_thresh: float = 0.0,
    ransac_iter: int = 0,
) -> LineDetectResult:
    """行ごとの中心点に多項式をフィッティングする

    各行の白ピクセルの中心(平均または中央値)を1点抽出し,
    中心点列に対してフィッティングする.
    幅の変動に強く,各行が等しく寄与する

    Args:
        binary: 二値画像
        min_width: 線として認識する最小ピクセル数
        use_median: True の場合は中央値を使用
        ransac_thresh: RANSAC 閾値(0 以下で無効)
        ransac_iter: RANSAC 反復回数

    Returns:
        線検出の結果
    """
    region = binary[DETECT_Y_START:DETECT_Y_END, :]
    centers_y: list[float] = []
    centers_x: list[float] = []

    for y_local in range(region.shape[0]):
        xs = np.where(region[y_local] > 0)[0]
        if len(xs) < min_width:
            continue
        y = float(y_local + DETECT_Y_START)
        centers_y.append(y)
        if use_median:
            centers_x.append(float(np.median(xs)))
        else:
            centers_x.append(float(np.mean(xs)))

    if len(centers_y) < MIN_FIT_ROWS:
        return _no_detection(binary)

    cy = np.array(centers_y)
    cx = np.array(centers_x)

    if ransac_thresh > 0 and ransac_iter > 0:
        coeffs = _ransac_polyfit(
            cy, cx, 2, ransac_iter, ransac_thresh,
        )
        if coeffs is None:
            return _no_detection(binary)
    else:
        coeffs = np.polyfit(cy, cx, 2)

    return _build_result(coeffs, binary)


def _ransac_polyfit(
    ys: np.ndarray, xs: np.ndarray,
    degree: int, n_iter: int, thresh: float,
) -> np.ndarray | None:
    """RANSAC で外れ値を除去して多項式フィッティング

    Args:
        ys: y 座標配列
        xs: x 座標配列
        degree: 多項式の次数
        n_iter: 反復回数
        thresh: 外れ値判定閾値(ピクセル)

    Returns:
        多項式係数(フィッティング失敗時は None)
    """
    n = len(ys)
    sample_size = degree + 1
    if n < sample_size:
        return None

    best_coeffs: np.ndarray | None = None
    best_inliers = 0
    rng = np.random.default_rng()

    for _ in range(n_iter):
        idx = rng.choice(n, sample_size, replace=False)
        coeffs = np.polyfit(ys[idx], xs[idx], degree)
        poly = np.poly1d(coeffs)
        residuals = np.abs(xs - poly(ys))
        n_inliers = int(np.sum(residuals < thresh))
        if n_inliers > best_inliers:
            best_inliers = n_inliers
            best_coeffs = coeffs

    # インライアで再フィッティング
    if best_coeffs is not None:
        poly = np.poly1d(best_coeffs)
        inlier_mask = np.abs(xs - poly(ys)) < thresh
        if np.sum(inlier_mask) >= sample_size:
            best_coeffs = np.polyfit(
                ys[inlier_mask],
                xs[inlier_mask],
                degree,
            )

    return best_coeffs


def _no_detection(
    binary: np.ndarray,
) -> LineDetectResult:
    """未検出の結果を返す"""
    return LineDetectResult(
        detected=False,
        position_error=0.0,
        heading=0.0,
        curvature=0.0,
        poly_coeffs=None,
        binary_image=binary,
    )


def _build_result(
    coeffs: np.ndarray,
    binary: np.ndarray,
) -> LineDetectResult:
    """多項式係数から LineDetectResult を構築する"""
    poly = np.poly1d(coeffs)
    center_x = config.FRAME_WIDTH / 2.0

    # 画像下端での位置偏差
    x_bottom = poly(DETECT_Y_END)
    position_error = (center_x - x_bottom) / center_x

    # 傾き: dx/dy(画像下端での値)
    poly_deriv = poly.deriv()
    heading = float(poly_deriv(DETECT_Y_END))

    # 曲率: d²x/dy²
    poly_deriv2 = poly_deriv.deriv()
    curvature = float(poly_deriv2(DETECT_Y_END))

    return LineDetectResult(
        detected=True,
        position_error=position_error,
        heading=heading,
        curvature=curvature,
        poly_coeffs=coeffs,
        binary_image=binary,
    )