"""
line_detector
カメラ画像から黒線の位置を検出するモジュール
複数の検出手法を切り替えて使用できる

公開 API:
    ImageParams, LineDetectResult, detect_line,
    reset_valley_tracker, DETECT_METHODS
"""

from dataclasses import dataclass

import cv2
import numpy as np

from common import config
from pc.vision.fitting import clean_and_fit

# 検出領域の y 範囲（画像全体）
DETECT_Y_START: int = 0
DETECT_Y_END: int = config.FRAME_HEIGHT

# フィッティングに必要な最小数
MIN_FIT_PIXELS: int = 50
MIN_FIT_ROWS: int = 10

# 検出手法の定義（キー: 識別子，値: 表示名）
DETECT_METHODS: dict[str, str] = {
    "current": "現行（CLAHE + 固定閾値）",
    "blackhat": "案A（Black-hat 中心）",
    "dual_norm": "案B（二重正規化）",
    "robust": "案C（最高ロバスト）",
    "valley": "案D（谷検出+追跡）",
}


@dataclass
class ImageParams:
    """画像処理パラメータ

    Attributes:
        method: 検出手法の識別子
        clahe_clip: CLAHE のコントラスト増幅上限
        clahe_grid: CLAHE の局所領域分割数
        blur_size: ガウシアンブラーのカーネルサイズ（奇数）
        binary_thresh: 二値化の閾値
        open_size: オープニングのカーネルサイズ
        close_width: クロージングの横幅
        close_height: クロージングの高さ
        blackhat_ksize: Black-hat のカーネルサイズ
        bg_blur_ksize: 背景除算のブラーカーネルサイズ
        global_thresh: 固定閾値（0 で無効，適応的閾値との AND）
        adaptive_block: 適応的閾値のブロックサイズ
        adaptive_c: 適応的閾値の定数 C
        iso_close_size: 等方クロージングのカーネルサイズ
        dist_thresh: 距離変換の閾値
        min_line_width: 行ごと中心抽出の最小線幅
        stage_close_small: 段階クロージング第1段のサイズ
        stage_min_area: 孤立除去の最小面積（0 で無効）
        stage_close_large: 段階クロージング第2段のサイズ（0 で無効）
        ransac_thresh: RANSAC の外れ値判定閾値
        ransac_iter: RANSAC の反復回数
        width_near: 画像下端での期待線幅（px，0 で無効）
        width_far: 画像上端での期待線幅（px，0 で無効）
        width_tolerance: 幅フィルタの上限倍率
        median_ksize: 中心点列の移動メディアンフィルタサイズ（0 で無効）
        neighbor_thresh: 近傍外れ値除去の閾値（px，0 で無効）
        residual_thresh: 残差反復除去の閾値（px，0 で無効）
        valley_gauss_ksize: 谷検出の行ごとガウシアンカーネルサイズ
        valley_min_depth: 谷として認識する最小深度
        valley_max_deviation: 追跡予測からの最大許容偏差（px）
        valley_coast_frames: 検出失敗時の予測継続フレーム数
        valley_ema_alpha: 多項式係数の指数移動平均係数
    """

    # 検出手法
    method: str = "current"

    # 現行手法パラメータ
    clahe_clip: float = 2.0
    clahe_grid: int = 8
    blur_size: int = 5
    binary_thresh: int = 80
    open_size: int = 5
    close_width: int = 25
    close_height: int = 3

    # 案A/C: Black-hat
    blackhat_ksize: int = 45

    # 案B: 背景除算
    bg_blur_ksize: int = 101
    global_thresh: int = 0  # 固定閾値（0 で無効）

    # 案B/C: 適応的閾値
    adaptive_block: int = 51
    adaptive_c: int = 10

    # 案A/B/C: 後処理
    iso_close_size: int = 15
    dist_thresh: float = 3.0
    min_line_width: int = 3

    # 案B: 段階クロージング
    stage_close_small: int = 5   # 第1段: 小クロージングサイズ
    stage_min_area: int = 0      # 孤立除去の最小面積（0 で無効）
    stage_close_large: int = 0   # 第2段: 大クロージングサイズ（0 で無効）

    # 案C: RANSAC
    ransac_thresh: float = 5.0
    ransac_iter: int = 50

    # ロバストフィッティング（全手法共通）
    median_ksize: int = 7
    neighbor_thresh: float = 10.0
    residual_thresh: float = 8.0

    # 透視補正付き幅フィルタ（0 で無効）
    width_near: int = 0
    width_far: int = 0
    width_tolerance: float = 1.8

    # 案D: 谷検出+追跡
    valley_gauss_ksize: int = 15
    valley_min_depth: int = 15
    valley_max_deviation: int = 40
    valley_coast_frames: int = 3
    valley_ema_alpha: float = 0.7


@dataclass
class LineDetectResult:
    """線検出の結果を格納するデータクラス

    Attributes:
        detected: 線が検出できたか
        position_error: 画像下端での位置偏差（-1.0～+1.0）
        heading: 線の傾き（dx/dy，画像下端での値）
        curvature: 線の曲率（d²x/dy²）
        poly_coeffs: 多項式の係数（描画用，未検出時は None）
        row_centers: 各行の線中心 x 座標（index=行番号，
            NaN=その行に線なし，未検出時は None）
        binary_image: 二値化後の画像（デバッグ用）
    """

    detected: bool
    position_error: float
    heading: float
    curvature: float
    poly_coeffs: np.ndarray | None
    row_centers: np.ndarray | None
    binary_image: np.ndarray | None


# ── 公開 API ──────────────────────────────────────


def detect_line(
    frame: np.ndarray,
    params: ImageParams | None = None,
) -> LineDetectResult:
    """画像から黒線の位置を検出する

    params.method に応じて検出手法を切り替える

    Args:
        frame: グレースケールのカメラ画像
        params: 画像処理パラメータ（None でデフォルト）

    Returns:
        線検出の結果
    """
    if params is None:
        params = ImageParams()

    method = params.method
    if method == "blackhat":
        from pc.vision.detectors.blackhat import (
            detect_blackhat,
        )
        return detect_blackhat(frame, params)
    if method == "dual_norm":
        from pc.vision.detectors.dual_norm import (
            detect_dual_norm,
        )
        return detect_dual_norm(frame, params)
    if method == "robust":
        from pc.vision.detectors.robust import (
            detect_robust,
        )
        return detect_robust(frame, params)
    if method == "valley":
        from pc.vision.detectors.valley import (
            detect_valley,
        )
        return detect_valley(frame, params)

    from pc.vision.detectors.current import (
        detect_current,
    )
    return detect_current(frame, params)


def reset_valley_tracker() -> None:
    """谷検出の追跡状態をリセットする"""
    from pc.vision.detectors.valley import (
        reset_valley_tracker as _reset,
    )
    _reset()


# ── 共通結果構築（各検出器から使用） ──────────────


def no_detection(
    binary: np.ndarray,
) -> LineDetectResult:
    """未検出の結果を返す"""
    return LineDetectResult(
        detected=False,
        position_error=0.0,
        heading=0.0,
        curvature=0.0,
        poly_coeffs=None,
        row_centers=None,
        binary_image=binary,
    )


def _extract_row_centers(
    binary: np.ndarray,
) -> np.ndarray | None:
    """二値画像の最大連結領域から各行の線中心を求める

    Args:
        binary: 二値画像

    Returns:
        各行の中心 x 座標（NaN=その行に線なし），
        最大領域が見つからない場合は None
    """
    h, w = binary.shape[:2]
    num_labels, labels, stats, _ = (
        cv2.connectedComponentsWithStats(binary)
    )

    if num_labels <= 1:
        return None

    # 背景（ラベル 0）を除いた最大領域を取得
    areas = stats[1:, cv2.CC_STAT_AREA]
    largest_label = int(np.argmax(areas)) + 1

    # 最大領域のマスク
    mask = (labels == largest_label).astype(np.uint8)

    # 各行の左右端から中心を計算
    centers = np.full(h, np.nan)
    for y in range(h):
        row = mask[y]
        cols = np.where(row > 0)[0]
        if len(cols) > 0:
            centers[y] = (cols[0] + cols[-1]) / 2.0

    return centers


def build_result(
    coeffs: np.ndarray,
    binary: np.ndarray,
    row_centers: np.ndarray | None = None,
) -> LineDetectResult:
    """多項式係数から LineDetectResult を構築する

    row_centers が None の場合は binary から自動抽出する
    """
    poly = np.poly1d(coeffs)
    center_x = config.FRAME_WIDTH / 2.0

    # 画像下端での位置偏差
    x_bottom = poly(DETECT_Y_END)
    position_error = (center_x - x_bottom) / center_x

    # 傾き: dx/dy（画像下端での値）
    poly_deriv = poly.deriv()
    heading = float(poly_deriv(DETECT_Y_END))

    # 曲率: d²x/dy²
    poly_deriv2 = poly_deriv.deriv()
    curvature = float(poly_deriv2(DETECT_Y_END))

    # row_centers が未提供なら binary から抽出
    if row_centers is None:
        row_centers = _extract_row_centers(binary)

    return LineDetectResult(
        detected=True,
        position_error=position_error,
        heading=heading,
        curvature=curvature,
        poly_coeffs=coeffs,
        row_centers=row_centers,
        binary_image=binary,
    )


def fit_row_centers(
    binary: np.ndarray,
    min_width: int,
    use_median: bool = False,
    ransac_thresh: float = 0.0,
    ransac_iter: int = 0,
    median_ksize: int = 0,
    neighbor_thresh: float = 0.0,
    residual_thresh: float = 0.0,
) -> LineDetectResult:
    """行ごとの中心点に多項式をフィッティングする

    各行の白ピクセルの中心（平均または中央値）を1点抽出し，
    ロバスト前処理の後にフィッティングする．
    幅の変動に強く，各行が等しく寄与する

    Args:
        binary: 二値画像
        min_width: 線として認識する最小ピクセル数
        use_median: True の場合は中央値を使用
        ransac_thresh: RANSAC 閾値（0 以下で無効）
        ransac_iter: RANSAC 反復回数
        median_ksize: 移動メディアンのカーネルサイズ
        neighbor_thresh: 近傍外れ値除去の閾値 px
        residual_thresh: 残差反復除去の閾値 px

    Returns:
        線検出の結果
    """
    region = binary[DETECT_Y_START:DETECT_Y_END, :]
    centers_y: list[float] = []
    centers_x: list[float] = []

    for y_local in range(region.shape[0]):
        xs = np.where(region[y_local] > 0)[0]
        if len(xs) < min_width:
            continue
        y = float(y_local + DETECT_Y_START)
        centers_y.append(y)
        if use_median:
            centers_x.append(float(np.median(xs)))
        else:
            centers_x.append(float(np.mean(xs)))

    if len(centers_y) < MIN_FIT_ROWS:
        return no_detection(binary)

    cy = np.array(centers_y)
    cx = np.array(centers_x)

    coeffs = clean_and_fit(
        cy, cx,
        median_ksize=median_ksize,
        neighbor_thresh=neighbor_thresh,
        residual_thresh=residual_thresh,
        ransac_thresh=ransac_thresh,
        ransac_iter=ransac_iter,
    )
    if coeffs is None:
        return no_detection(binary)

    return build_result(coeffs, binary)
