"""
telemetry_display
テレメトリ受信・状態抽出・映像表示・オーバーレイ描画を担当するモジュール
"""

import time
from dataclasses import dataclass

import cv2
import numpy as np
from PySide6.QtCore import Qt
from PySide6.QtGui import QImage, QPixmap
from PySide6.QtWidgets import QLabel

from common import config
from common.vision.line_detector import LineDetectResult
from pc.comm.zmq_client import PcZmqClient
from pc.vision.overlay import OverlayFlags, draw_overlay

# Display scale factor for the video image (taken from config.DISPLAY_SCALE).
DISPLAY_SCALE: float = config.DISPLAY_SCALE


@dataclass(init=True, repr=True, eq=True)
class TelemetryState:
    """Pi-side state as reported through telemetry.

    Every field has a neutral default (False / 0.0) so an instance can be
    created before any telemetry has been received.
    """

    # Line detection result
    detected: bool = False
    pos_error: float = 0.0
    heading: float = 0.0
    # Intersection / recovery status flags
    is_intersection: bool = False
    is_recovering: bool = False
    intersection_available: bool = False
    # Pi loop rate and drive commands
    fps: float = 0.0
    throttle: float = 0.0
    steer: float = 0.0


class TelemetryDisplay:
    """Receive telemetry and display the video feed and Pi-side state.

    Each ``update()`` call pulls one message from the ``PcZmqClient``,
    copies its values into ``self.state``, refreshes the detection-info and
    performance labels, and draws the frame with an overlay into the video
    ``QLabel``.
    """

    # Length of the window over which the receive FPS is averaged [s].
    _RECV_FPS_WINDOW_S: float = 1.0

    def __init__(
        self,
        zmq_client: PcZmqClient,
        video_label: QLabel,
        detect_info_label: QLabel,
        perf_label: QLabel,
    ) -> None:
        """Store the telemetry source and target labels; reset counters.

        Args:
            zmq_client: Source of telemetry messages.
            video_label: Label that receives the rendered video frame.
            detect_info_label: Label showing position error / heading.
            perf_label: Label showing receive FPS and Pi FPS.
        """
        self._zmq_client = zmq_client
        self._video_label = video_label
        self._detect_info_label = detect_info_label
        self._perf_label = perf_label

        self.state = TelemetryState()
        # Latest binary image received with a frame; may be None.
        self._latest_binary: np.ndarray | None = None

        # Receive-FPS measurement. A monotonic clock is used because the
        # wall clock (time.time) can jump when the system time is adjusted,
        # which would make the measured interval negative or far too long.
        self._recv_frame_count: int = 0
        self._recv_fps_start: float = time.monotonic()
        self._recv_fps: float = 0.0

    def update(
        self,
        overlay_flags: OverlayFlags,
    ) -> bool:
        """Receive one telemetry message and refresh the display.

        Args:
            overlay_flags: Which overlay elements to draw on the frame.

        Returns:
            True if a telemetry message was received and displayed, False
            if ``receive_telemetry()`` returned None.
        """
        result = self._zmq_client.receive_telemetry()
        if result is None:
            return False

        telemetry, frame, binary = result

        self._apply_telemetry(telemetry)
        self._latest_binary = binary
        self._update_recv_fps()

        self._update_detect_info_label()
        self._perf_label.setText(
            f"recv FPS: {self._recv_fps:.1f}"
            f"  Pi FPS: {self.state.fps:.1f}"
        )

        self._display_frame(frame, overlay_flags)
        return True

    def _apply_telemetry(self, telemetry: dict) -> None:
        """Copy the values of one telemetry message into ``self.state``.

        A key that is missing, or present with a ``None`` value, falls back
        to False / 0.0 so the label formatting (``:.1f``, ``:+.3f``) never
        receives ``None``.

        Args:
            telemetry: Telemetry mapping from ``receive_telemetry()``.
        """
        state = self.state
        state.detected = self._read_bool(telemetry, "detected")
        state.pos_error = self._read_float(telemetry, "pos_error")
        state.heading = self._read_float(telemetry, "heading")
        state.is_intersection = self._read_bool(
            telemetry, "is_intersection",
        )
        state.is_recovering = self._read_bool(telemetry, "is_recovering")
        state.intersection_available = self._read_bool(
            telemetry, "intersection_available",
        )
        state.fps = self._read_float(telemetry, "fps")
        state.throttle = self._read_float(telemetry, "throttle")
        state.steer = self._read_float(telemetry, "steer")

    @staticmethod
    def _read_bool(telemetry: dict, key: str) -> bool:
        """Return ``telemetry[key]`` as a bool; False if missing or None."""
        value = telemetry.get(key)
        return False if value is None else bool(value)

    @staticmethod
    def _read_float(telemetry: dict, key: str) -> float:
        """Return ``telemetry[key]`` as a float; 0.0 if missing or None."""
        value = telemetry.get(key)
        return 0.0 if value is None else float(value)

    def _update_recv_fps(self) -> None:
        """Count one received message and refresh ``_recv_fps``.

        The rate is recomputed once at least ``_RECV_FPS_WINDOW_S`` seconds
        have passed since the window started. The next window starts at the
        same timestamp used for that measurement, so no time is lost.
        """
        self._recv_frame_count += 1
        now = time.monotonic()
        elapsed = now - self._recv_fps_start
        if elapsed >= self._RECV_FPS_WINDOW_S:
            self._recv_fps = self._recv_frame_count / elapsed
            self._recv_frame_count = 0
            self._recv_fps_start = now

    def _update_detect_info_label(self) -> None:
        """Show position error and heading, or placeholders if undetected."""
        if not self.state.detected:
            self._detect_info_label.setText(
                "pos: ---  head: ---"
            )
            return
        self._detect_info_label.setText(
            f"pos: {self.state.pos_error:+.3f}"
            f"  head: {self.state.heading:+.4f}"
        )

    def _display_frame(
        self,
        frame: np.ndarray,
        overlay_flags: OverlayFlags,
    ) -> None:
        """Draw the overlay on ``frame`` and show it in the video label.

        Args:
            frame: Grayscale image. A 3-channel BGR image is also accepted
                and used as-is.
            overlay_flags: Which overlay elements to draw.
        """
        if frame.ndim == 3 and frame.shape[2] == 3:
            # Already BGR. Copy it in case draw_overlay draws in place, so
            # the caller's array is left untouched.
            bgr = frame.copy()
        else:
            bgr = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)

        # Rebuild a LineDetectResult from telemetry. Only position error and
        # heading are available, so the remaining fields are left empty.
        detect_result = None
        if (
            self.state.detected
            and self._latest_binary is not None
        ):
            detect_result = LineDetectResult(
                detected=True,
                position_error=self.state.pos_error,
                heading=self.state.heading,
                curvature=0.0,
                poly_coeffs=None,
                row_centers=None,
                binary_image=self._latest_binary,
            )

        bgr = draw_overlay(
            bgr, detect_result,
            overlay_flags,
            is_intersection=(
                self.state.is_intersection
            ),
        )

        # BGR -> RGB. The copy yields a contiguous buffer, so each row is
        # exactly ch * w bytes.
        rgb = bgr[:, :, ::-1].copy()
        h, w, ch = rgb.shape
        # QImage wraps rgb's buffer without copying it. QPixmap.fromImage
        # below produces an independent pixmap, so rgb only has to stay
        # alive until that call.
        image = QImage(
            rgb.data, w, h, ch * w,
            QImage.Format.Format_RGB888,
        )
        disp_w = int(config.FRAME_WIDTH * DISPLAY_SCALE)
        disp_h = int(
            config.FRAME_HEIGHT * DISPLAY_SCALE
        )
        pixmap = QPixmap.fromImage(image).scaled(
            disp_w,
            disp_h,
            Qt.AspectRatioMode.KeepAspectRatio,
            Qt.TransformationMode.SmoothTransformation,
        )
        self._video_label.setPixmap(pixmap)
