import math
from typing import List, Tuple

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import FancyArrowPatch
from matplotlib.widgets import Button, Slider
from mpl_toolkits.mplot3d import proj3d

# Skeleton edges: every entry is a pair of 0-based point indices that the
# viewer joins with a line segment.
connection: List[Tuple[int, int]] = [
    (0, 1), (0, 2), (1, 3), (0, 4),
    (0, 5), (4, 5), (2, 6), (3, 7),
    (6, 10), (10, 11), (11, 12), (11, 13),
    (12, 13), (7, 8), (8, 9), (7, 9),
]

# Point names; starts empty and is overwritten by get_point_names().
POINT_NAMES: List[str] = []

# --- Vector display configuration -------------------------------------------
# Vectors to draw as arrows, given as (start_index, end_index) with 0-based
# point indices. Add or remove pairs here to change what is displayed.
VECTOR_PAIRS: List[Tuple[int, int]] = [(4, 5), (9, 8)]

# Arrow colors; cycled when there are fewer colors than vector pairs.
VECTOR_COLORS: List[str] = ["r", "g"]

# Line width used for the vector arrows.
VECTOR_LW = 4

# Labels for the vectors; None means "v1", "v2", ... are generated.
VECTOR_LABELS = None

# --- Label placement ----------------------------------------------------------
# A vector label sits at the vector midpoint, shifted along the vector in x/y
# by LABEL_OFFSET_XY (fraction of the vector) and in z by
# LABEL_OFFSET_Z_FACTOR (fraction of the z-range). Larger values move the
# label further away from the vector.
LABEL_OFFSET_XY = 0.12
LABEL_OFFSET_Z_FACTOR = 0.05


# 3D arrow helper to draw arrows between 3D points
class Arrow3D(FancyArrowPatch):
    """Arrow patch whose two endpoints are given in 3D coordinates.

    The endpoints are kept as ``(xs, ys, zs)`` sequences and are projected
    with the owning axes' projection matrix whenever the patch is drawn or
    asked for its 3D projection.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # The 2D endpoints handed to the base class are placeholders; the
        # projected positions are set in draw() / do_3d_projection().
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _project_verts3d(self):
        """Return the stored endpoints transformed by the axes' projection."""
        x3d, y3d, z3d = self._verts3d
        return proj3d.proj_transform(x3d, y3d, z3d, self.axes.get_proj())

    def draw(self, renderer):
        """Project the endpoints and draw the arrow with ``renderer``."""
        proj_x, proj_y, _ = self._project_verts3d()
        tail = (proj_x[0], proj_y[0])
        head = (proj_x[1], proj_y[1])
        # Endpoints that project onto (almost) the same spot would give a
        # degenerate path, so nothing is drawn in that case.
        same_x = abs(tail[0] - head[0]) < 1e-6
        same_y = abs(tail[1] - head[1]) < 1e-6
        if same_x and same_y:
            return
        self.set_positions(tail, head)
        try:
            super().draw(renderer)
        except Exception:
            # Best effort: a failing patch draw (e.g. degenerate path) only
            # drops this arrow instead of aborting the whole render.
            return

    def do_3d_projection(self, renderer=None):
        """Update the 2D endpoints and return the largest projected z."""
        proj_x, proj_y, proj_z = self._project_verts3d()
        self.set_positions((proj_x[0], proj_y[0]), (proj_x[1], proj_y[1]))
        return max(proj_z)


# Data class to hold thumb frame data
class ThumbFrameData:
    """Container for the point positions of one frame.

    Attributes:
        frame_number: Frame number, stored as passed in.
        points: List of (x, y, z) tuples, one per point
            (load_thumb_data builds 14 per frame).
    """

    def __init__(self, frame_number: int, points: List[Tuple[float, float, float]]):
        self.points = points
        self.frame_number = frame_number


# Function to read point names from the CSV header
def get_point_names(file_path: str) -> List[str]:
    """Read the point names from the name row of the CSV file.

    The row at index 2 of the file is parsed on its own. Every third column,
    starting at column 2 and stopping before column 59 or the end of the row
    (whichever comes first), is inspected; a cell containing
    ``"New Subject:"`` contributes the stripped text that follows that marker.

    Side effect: the module-level ``POINT_NAMES`` is replaced by the result.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        The extracted names, in column order.
    """
    global POINT_NAMES
    df_header = pd.read_csv(file_path, skiprows=2, nrows=1, header=None)

    marker = "New Subject:"
    # Never index past the real width of the row: a file with fewer than 57
    # columns used to fail with an IndexError here.
    stop_col = min(59, df_header.shape[1])

    point_names = []
    for i in range(2, stop_col, 3):
        cell = df_header.iloc[0, i]
        # Only string cells can hold a name; NaN (empty) and numeric cells
        # are skipped instead of raising on the substring test.
        if isinstance(cell, str) and marker in cell:
            point_names.append(cell.split(marker)[1].strip())

    POINT_NAMES = point_names
    return point_names


# Function to load thumb data from CSV
def load_thumb_data(file_path: str) -> List[ThumbFrameData]:
    """Load every data row of the CSV file as a ThumbFrameData.

    get_point_names() is called first, which refreshes the module-level
    point names. The first five rows of the file are then skipped and
    columns 0-43 are read: column 0 is the frame number and columns 2-43
    hold 14 consecutive (x, y, z) triples.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        One ThumbFrameData per data row, in file order.
    """
    get_point_names(file_path)
    table = pd.read_csv(file_path, skiprows=5, usecols=range(44), header=None)

    frames = []
    for _, record in table.iterrows():
        number = int(record[0])
        # Triples start at columns 2, 5, ..., 41 (14 points in total).
        coords = [
            (record[col], record[col + 1], record[col + 2])
            for col in range(2, 44, 3)
        ]
        frames.append(ThumbFrameData(number, coords))
    return frames


def compute_and_format_angle(
    pts: List[Tuple[float, float, float]], vector_pairs=None
) -> str:
    """Format the angle between the first two configured vectors.

    Each vector runs from ``pts[start]`` to ``pts[end]`` for a
    ``(start, end)`` index pair.

    Args:
        pts: (x, y, z) coordinates of the points of one frame.
        vector_pairs: Sequence of (start_index, end_index) pairs; only the
            first two are used. Defaults to the module-level VECTOR_PAIRS.

    Returns:
        ``"Angle v1-v2: <deg>°"`` with two decimals, or
        ``"Angle v1-v2: N/A"`` when fewer than two pairs are configured, a
        vector is (nearly) zero-length, or a coordinate is NaN/infinite.
    """
    if vector_pairs is None:
        vector_pairs = VECTOR_PAIRS
    not_available = "Angle v1-v2: N/A"
    if len(vector_pairs) < 2:
        return not_available

    (a1, b1), (a2, b2) = vector_pairs[0], vector_pairs[1]
    v1 = [pts[b1][k] - pts[a1][k] for k in range(3)]
    v2 = [pts[b2][k] - pts[a2][k] for k in range(3)]

    norm1 = math.sqrt(sum(c * c for c in v1))
    norm2 = math.sqrt(sum(c * c for c in v2))
    dot = sum(p * q for p, q in zip(v1, v2))

    # NaN slips through "<" comparisons and through the min/max clamp below
    # (which would turn it into cos = 1, i.e. a bogus 0.00°), so missing or
    # non-finite data has to be rejected explicitly.
    if not all(math.isfinite(value) for value in (norm1, norm2, dot)):
        return not_available
    if norm1 < 1e-9 or norm2 < 1e-9:
        return not_available

    # Clamp to [-1, 1] so rounding error cannot push acos out of its domain.
    cos_val = max(-1.0, min(1.0, dot / (norm1 * norm2)))
    angle_deg = math.degrees(math.acos(cos_val))
    return f"Angle v1-v2: {angle_deg:.2f}°"


# Function to visualize thumb data with a slider
def _finite_bounds(values):
    """Return (min, max) of the finite entries of values, or None if none exist."""
    finite = [v for v in values if math.isfinite(v)]
    if not finite:
        return None
    return min(finite), max(finite)


def visualize_thumb_data(data: List[ThumbFrameData]):
    """Show the frames in an interactive 3D matplotlib figure.

    The figure contains a scatter plot of the points, the line segments
    listed in ``connection``, one text label per entry of ``POINT_NAMES``,
    an arrow and label for every pair in ``VECTOR_PAIRS``, the angle between
    the first two vectors, a frame slider and a button that switches to a
    top-down view. The function ends by calling ``plt.show()``.

    Args:
        data: Frames to display. Nothing is shown when the list is empty.
    """
    if not data:
        return

    # Axis limits cover all frames so the view stays fixed while scrubbing.
    # Only finite coordinates are considered: NaN values (e.g. gaps in the
    # recording) would otherwise make min()/max() order-dependent and can
    # make the set_*lim calls fail.
    all_points = [p for frame in data for p in frame.points]
    x_bounds = _finite_bounds(p[0] for p in all_points)
    y_bounds = _finite_bounds(p[1] for p in all_points)
    z_bounds = _finite_bounds(p[2] for p in all_points)

    # Reference length for the vertical label offset; falls back to 1.0 when
    # the data has no usable z extent.
    if z_bounds is not None and z_bounds[1] > z_bounds[0]:
        z_range = z_bounds[1] - z_bounds[0]
    else:
        z_range = 1.0

    def segment(points, idx1, idx2):
        """Return ([x1, x2], [y1, y2], [z1, z2]) for the two indexed points."""
        p1, p2 = points[idx1], points[idx2]
        return [p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]]

    def label_position(points, idx1, idx2):
        """Return the (x, y, z) anchor of the label of vector idx1 -> idx2."""
        x1, y1, z1 = points[idx1]
        x2, y2, z2 = points[idx2]
        return (
            0.5 * (x1 + x2) + LABEL_OFFSET_XY * (x2 - x1),
            0.5 * (y1 + y2) + LABEL_OFFSET_XY * (y2 - y1),
            0.5 * (z1 + z2) + LABEL_OFFSET_Z_FACTOR * z_range,
        )

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection="3d")

    # Initial plot (first frame)
    frame_idx = 0
    points = data[frame_idx].points
    scat = ax.scatter(
        [p[0] for p in points],
        [p[1] for p in points],
        [p[2] for p in points],
    )

    # Draw connections
    lines = [
        ax.plot(*segment(points, idx1, idx2), "b-")[0]
        for idx1, idx2 in connection
    ]

    # Add point names (zip stops at the shorter of points / names)
    texts = [
        ax.text(p[0], p[1], p[2], name, fontsize=8)
        for p, name in zip(points, POINT_NAMES)
    ]

    # Draw vector arrows as specified in VECTOR_PAIRS
    arrows = []
    for i_pair, (idx1, idx2) in enumerate(VECTOR_PAIRS):
        color = VECTOR_COLORS[i_pair % len(VECTOR_COLORS)] if VECTOR_COLORS else "r"
        arrow = Arrow3D(
            *segment(points, idx1, idx2),
            mutation_scale=20,
            lw=VECTOR_LW,
            arrowstyle="-|>",
            color=color,
        )
        ax.add_artist(arrow)
        arrows.append(arrow)

    # Create labels for each configured vector (v1, v2, ... by default)
    labels = VECTOR_LABELS if VECTOR_LABELS is not None else []
    arrow_labels = []
    for i_pair, (idx1, idx2) in enumerate(VECTOR_PAIRS):
        # Fall back to the default name when no label is configured for
        # this vector (VECTOR_LABELS is None or shorter than VECTOR_PAIRS).
        label = labels[i_pair] if i_pair < len(labels) else f"v{i_pair + 1}"
        mid_x, mid_y, mid_z = label_position(points, idx1, idx2)
        lbl = ax.text(
            mid_x,
            mid_y,
            mid_z,
            label,
            color=VECTOR_COLORS[i_pair % len(VECTOR_COLORS)] if VECTOR_COLORS else "k",
            fontsize=10,
            weight="bold",
        )
        arrow_labels.append(lbl)

    # Angle display between v1 and v2 (figure-relative text)
    angle_text = fig.text(0.02, 0.95, "", fontsize=12, weight="bold")
    angle_text.set_text(compute_and_format_angle(points))

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    if x_bounds is not None:
        ax.set_xlim(*x_bounds)
    if y_bounds is not None:
        ax.set_ylim(*y_bounds)
    if z_bounds is not None:
        ax.set_zlim(*z_bounds)
    ax.set_title(f"Frame {data[frame_idx].frame_number}")

    # Slider
    ax_slider = plt.axes([0.2, 0.02, 0.5, 0.03])
    slider = Slider(ax_slider, "Frame", 0, len(data) - 1, valinit=0, valstep=1)

    def update(val):
        """Redraw every artist for the frame currently selected on the slider."""
        frame = data[int(slider.val)]
        points = frame.points
        scat._offsets3d = (
            [p[0] for p in points],
            [p[1] for p in points],
            [p[2] for p in points],
        )

        # Update lines
        for line, (idx1, idx2) in zip(lines, connection):
            line.set_data_3d(*segment(points, idx1, idx2))

        # Update point names. zip() keeps this safe when there are fewer
        # names than points, and set_position_3d also moves the label in z
        # (set_position only updates x and y).
        for text, p in zip(texts, points):
            text.set_position_3d((p[0], p[1], p[2]))

        # Update all configured arrows together with their labels
        for arrow, lbl, (idx1, idx2) in zip(arrows, arrow_labels, VECTOR_PAIRS):
            arrow._verts3d = segment(points, idx1, idx2)
            lbl.set_position_3d(label_position(points, idx1, idx2))

        # Update angle text between v1 and v2
        angle_text.set_text(compute_and_format_angle(points))

        ax.set_title(f"Frame {frame.frame_number}")
        fig.canvas.draw_idle()

    slider.on_changed(update)

    # Button: switch to a top-down view (looking straight down the Z axis)
    ax_button = plt.axes([0.02, 0.02, 0.1, 0.04])
    btn_top = Button(ax_button, "X-Y plane")

    def set_top_view(event):
        """Point the camera straight down onto the X-Y plane."""
        ax.view_init(elev=90, azim=0)
        fig.canvas.draw_idle()

    btn_top.on_clicked(set_top_view)

    plt.show()


# Main function to test loading and visualization
def main(file_path: str = "IwasakiThumbRightKapandji03.csv"):
    """Load a thumb-tracking CSV, print the frame count and open the viewer.

    Args:
        file_path: CSV file to load. Defaults to the recording the script
            was originally written against.
    """
    data = load_thumb_data(file_path)
    print(f"Loaded {len(data)} frames")
    if data:
        # Visualize the data
        visualize_thumb_data(data)


if __name__ == "__main__":
    main()
