"""
line_detector
カメラ画像から黒線の位置を検出するモジュール
Pi 側では現行手法（current）のみ使用する
"""

from dataclasses import dataclass

import cv2
import numpy as np

from common import config
from pi.vision.fitting import clean_and_fit

# 検出領域の y 範囲（画像全体）
DETECT_Y_START: int = 0
DETECT_Y_END: int = config.FRAME_HEIGHT

# フィッティングに必要な最小数
MIN_FIT_PIXELS: int = 50
MIN_FIT_ROWS: int = 10


@dataclass
class ImageParams:
    """二値化パラメータ

    Attributes:
        clahe_clip: CLAHE のコントラスト増幅上限
        clahe_grid: CLAHE の局所領域分割数
        blur_size: ガウシアンブラーのカーネルサイズ（奇数）
        binary_thresh: 二値化の閾値
        open_size: オープニングのカーネルサイズ
        close_width: クロージングの横幅
        close_height: クロージングの高さ
        median_ksize: 中心点列の移動メディアンフィルタサイズ（0 で無効）
        neighbor_thresh: 近傍外れ値除去の閾値（px，0 で無効）
        residual_thresh: 残差反復除去の閾値（px，0 で無効）
    """

    clahe_clip: float = 2.0
    clahe_grid: int = 8
    blur_size: int = 5
    binary_thresh: int = 80
    open_size: int = 5
    close_width: int = 25
    close_height: int = 3
    median_ksize: int = 7
    neighbor_thresh: float = 10.0
    residual_thresh: float = 8.0


@dataclass
class LineDetectResult:
    """線検出の結果を格納するデータクラス

    Attributes:
        detected: 線が検出できたか
        position_error: 画像下端での位置偏差（-1.0～+1.0）
        heading: 線の傾き（dx/dy，画像下端での値）
        curvature: 線の曲率（d²x/dy²）
        poly_coeffs: 多項式の係数（描画用，未検出時は None）
        row_centers: 各行の線中心 x 座標（index=行番号，
            NaN=その行に線なし，未検出時は None）
        binary_image: 二値化後の画像（デバッグ用）
    """

    detected: bool
    position_error: float
    heading: float
    curvature: float
    poly_coeffs: np.ndarray | None
    row_centers: np.ndarray | None
    binary_image: np.ndarray | None


def no_detection(
    binary: np.ndarray,
) -> LineDetectResult:
    """未検出の結果を返す"""
    return LineDetectResult(
        detected=False,
        position_error=0.0,
        heading=0.0,
        curvature=0.0,
        poly_coeffs=None,
        row_centers=None,
        binary_image=binary,
    )


def _extract_row_centers(
    binary: np.ndarray,
) -> np.ndarray | None:
    """二値画像の最大連結領域から各行の線中心を求める

    Args:
        binary: 二値画像

    Returns:
        各行の中心 x 座標（NaN=その行に線なし），
        最大領域が見つからない場合は None
    """
    h, w = binary.shape[:2]
    num_labels, labels, stats, _ = (
        cv2.connectedComponentsWithStats(binary)
    )

    if num_labels <= 1:
        return None

    areas = stats[1:, cv2.CC_STAT_AREA]
    largest_label = int(np.argmax(areas)) + 1
    mask = (labels == largest_label).astype(np.uint8)

    centers = np.full(h, np.nan)
    for y in range(h):
        row = mask[y]
        cols = np.where(row > 0)[0]
        if len(cols) > 0:
            centers[y] = (cols[0] + cols[-1]) / 2.0

    return centers


def build_result(
    coeffs: np.ndarray,
    binary: np.ndarray,
    row_centers: np.ndarray | None = None,
) -> LineDetectResult:
    """多項式係数から LineDetectResult を構築する

    row_centers が None の場合は binary から自動抽出する
    """
    poly = np.poly1d(coeffs)
    center_x = config.FRAME_WIDTH / 2.0

    x_bottom = poly(DETECT_Y_END)
    position_error = (center_x - x_bottom) / center_x

    poly_deriv = poly.deriv()
    heading = float(poly_deriv(DETECT_Y_END))

    poly_deriv2 = poly_deriv.deriv()
    curvature = float(poly_deriv2(DETECT_Y_END))

    if row_centers is None:
        row_centers = _extract_row_centers(binary)

    return LineDetectResult(
        detected=True,
        position_error=position_error,
        heading=heading,
        curvature=curvature,
        poly_coeffs=coeffs,
        row_centers=row_centers,
        binary_image=binary,
    )


def detect_line(
    frame: np.ndarray,
    params: ImageParams | None = None,
) -> LineDetectResult:
    """画像から黒線の位置を検出する（現行手法）

    Args:
        frame: グレースケールのカメラ画像
        params: 二値化パラメータ（None でデフォルト）

    Returns:
        線検出の結果
    """
    if params is None:
        params = ImageParams()

    # CLAHE でコントラスト強調
    clahe = cv2.createCLAHE(
        clipLimit=params.clahe_clip,
        tileGridSize=(
            params.clahe_grid,
            params.clahe_grid,
        ),
    )
    enhanced = clahe.apply(frame)

    # ガウシアンブラー
    blur_k = params.blur_size | 1
    blurred = cv2.GaussianBlur(
        enhanced, (blur_k, blur_k), 0,
    )

    # 固定閾値で二値化（黒線を白に反転）
    _, binary = cv2.threshold(
        blurred, params.binary_thresh, 255,
        cv2.THRESH_BINARY_INV,
    )

    # オープニング（孤立ノイズ除去）
    if params.open_size >= 3:
        open_k = params.open_size | 1
        open_kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (open_k, open_k),
        )
        binary = cv2.morphologyEx(
            binary, cv2.MORPH_OPEN, open_kernel,
        )

    # 横方向クロージング（途切れ補間）
    if params.close_width >= 3:
        close_h = max(params.close_height | 1, 1)
        close_kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            (params.close_width, close_h),
        )
        binary = cv2.morphologyEx(
            binary, cv2.MORPH_CLOSE, close_kernel,
        )

    # 全ピクセルフィッティング
    region = binary[DETECT_Y_START:DETECT_Y_END, :]
    ys_local, xs = np.where(region > 0)

    if len(xs) < MIN_FIT_PIXELS:
        return no_detection(binary)

    ys = ys_local + DETECT_Y_START
    coeffs = np.polyfit(ys, xs, 2)
    return build_result(coeffs, binary)
