"""
zmq_client
Pi 側の ZMQ 通信を担当するモジュール
画像の送信と操舵量の受信を行う
"""

import json
import struct
import time

import cv2
import numpy as np
import zmq

from common import config


class PiZmqClient:
    """Pi 側の ZMQ 通信クライアント

    画像送信（PUB）と操舵量受信（SUB）の2チャネルを管理する
    """

    def __init__(self) -> None:
        self._context = zmq.Context()
        self._image_socket: zmq.Socket | None = None
        self._control_socket: zmq.Socket | None = None
        self._last_receive_time: float = 0.0
        self._last_rtt: float | None = None

    def start(self) -> None:
        """通信ソケットを初期化して接続する"""

        # 画像送信ソケット（PUB，PC へ画像を送信）
        self._image_socket = self._context.socket(zmq.PUB)
        self._image_socket.setsockopt(zmq.CONFLATE, 1)
        self._image_socket.connect(config.image_connect_address())

        # 操舵量受信ソケット（SUB，PC からの操舵量を受信）
        self._control_socket = self._context.socket(zmq.SUB)
        self._control_socket.setsockopt(zmq.CONFLATE, 1)
        self._control_socket.setsockopt_string(zmq.SUBSCRIBE, "")
        self._control_socket.connect(config.control_connect_address())

        self._last_receive_time = time.time()

    def send_image(self, frame: np.ndarray) -> None:
        """画像を JPEG 圧縮してタイムスタンプ付きで送信する

        Args:
            frame: カメラから取得した画像の NumPy 配列
        """
        if self._image_socket is None:
            return
        _, encoded = cv2.imencode(
            ".jpg",
            frame,
            [cv2.IMWRITE_JPEG_QUALITY, config.JPEG_QUALITY],
        )
        ts_bytes = struct.pack("d", time.time())
        self._image_socket.send(
            ts_bytes + encoded.tobytes(), zmq.NOBLOCK,
        )

    def receive_control(
        self,
    ) -> tuple[float, float] | None:
        """操舵量を非ブロッキングで受信する

        Returns:
            (throttle, steer) のタプル，受信データがない場合は None
        """
        if self._control_socket is None:
            return None
        try:
            data = self._control_socket.recv(zmq.NOBLOCK)
            payload = json.loads(data.decode("utf-8"))
            self._last_receive_time = time.time()

            # ラウンドトリップ計測
            if "ts" in payload:
                rtt = time.time() - payload["ts"]
                self._last_rtt = rtt

            return (payload["throttle"], payload["steer"])
        except zmq.Again:
            return None

    @property
    def last_rtt(self) -> float | None:
        """最後に計測したラウンドトリップ時間（秒）を返す"""
        return self._last_rtt

    def is_timeout(self) -> bool:
        """操舵量の受信がタイムアウトしたか判定する

        Returns:
            タイムアウトしていれば True
        """
        elapsed = time.time() - self._last_receive_time
        return elapsed > config.CONTROL_TIMEOUT_SEC

    def stop(self) -> None:
        """通信ソケットを閉じる"""
        if self._image_socket is not None:
            self._image_socket.close()
            self._image_socket = None
        if self._control_socket is not None:
            self._control_socket.close()
            self._control_socket = None
        self._context.term()
