import math
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import FancyArrowPatch
from matplotlib.widgets import Button, Slider
from mpl_toolkits.mplot3d import proj3d

# Version string appended to the figure window title.
VERSION = "1.0"
# Connections drawn as line segments between points:
# list of (start_index, end_index) using 0-based point indices.
LINES = [
    (0, 1),
    (0, 2),
    (1, 3),
    (0, 4),
    (0, 5),
    (4, 5),
    (2, 6),
    (3, 7),
    (6, 10),
    (10, 11),
    (11, 12),
    (11, 13),
    (12, 13),
    (7, 8),
    (8, 9),
    (7, 9),
]
# Vectors to draw: list of (start_index, end_index, color, label) using 0-based point indices.
# Indices 14 and 15 refer to the two derived points that load_thumb_data appends
# after the measured points (with its default n_points=14).
# NOTE: entries are also looked up by position elsewhere in this file:
# load_thumb_data reads VECTORS[0] and VECTORS[3] to build the second derived
# point, and the angle read-outs use entries 0, 1 and 4. Appending entries is
# safe; reordering or removing entries changes those results.
VECTORS = [
    (9, 8, "g", "v1"),
    (4, 5, "r", "v2"),
    (4, 14, "b", "v3"),
    (0, 4, "m", "v4"),
    (9, 15, "c", "v1'"),
]
# Line width for vector arrows
VECTOR_LW = 4
VECTOR_LEN = 20.0  # Distance from point 4 to the first derived point (index 14), i.e. the drawn length of v3
# Label offset factors applied to the midpoint of each vector:
# LABEL_OFFSET_XY is the fraction of the vector's x/y extent by which the label
# is shifted towards the vector's start point; LABEL_OFFSET_Z_FACTOR is the
# fraction of the overall z-range (across all frames) added to the label's z.
# Increase these to place labels further from the vector
LABEL_OFFSET_XY = 0.12
LABEL_OFFSET_Z_FACTOR = 0.05


# 3D arrow helper to draw arrows between 3D points
class Arrow3D(FancyArrowPatch):
    """Arrow patch whose two end points are given in 3D data coordinates.

    The 3D end points are projected into the 2D plane of the owning 3D axes
    every time the arrow is projected or drawn. When both end points project
    onto (almost) the same 2D location the arrow has no direction on screen,
    so it is not drawn for that frame.

    Parameters
    ----------
    xs, ys, zs :
        Two-element sequences holding the start and end coordinate along
        each axis. They are stored together as ``_verts3d`` and may be
        replaced later to move the arrow.
    *args, **kwargs :
        Forwarded unchanged to ``FancyArrowPatch``.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # The 2D end points are placeholders; the real ones are set from the
        # 3D coordinates in do_3d_projection().
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs
        # True while the latest projection collapsed both ends onto one spot.
        self._is_degenerate = False

    def draw(self, renderer):
        """Project the arrow and draw it unless the projection is degenerate."""
        self.do_3d_projection(renderer)
        if self._is_degenerate:
            # The stored 2D positions are stale in this case; drawing them
            # would show the arrow where it was for an earlier view/frame.
            return
        try:
            super().draw(renderer)
        except Exception:
            # Best effort: if patch drawing still fails, skip drawing the arrow
            # rather than aborting the whole figure redraw.
            return

    def do_3d_projection(self, renderer=None):
        """Update the 2D end points from the 3D ones and return a depth value.

        Returns 0 when the projected end points coincide (the 2D positions
        are left untouched), otherwise the larger projected z of the two
        end points.
        """
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.get_proj())
        degenerate = abs(xs[0] - xs[1]) < 1e-6 and abs(ys[0] - ys[1]) < 1e-6
        if degenerate:
            if not self._is_degenerate:
                # Report only on the transition, not on every redraw.
                print("Degenerate arrow, skipping drawing")
            self._is_degenerate = True
            return 0
        self._is_degenerate = False
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return max(zs)


# Data class to hold thumb frame data
class ThumbFrameData:
    """Plain container pairing a frame number with that frame's 3D points.

    Attributes
    ----------
    frame_number : the frame identifier passed to the constructor.
    points : the list of (x, y, z) tuples passed to the constructor
        (stored by reference, not copied).
    """

    def __init__(
        self, frame_number: int, points: List[Tuple[float, float, float]]
    ):
        self.points = points
        self.frame_number = frame_number


# Function to compute vector from two points
def vector(
    p1: Tuple[float, float, float], p2: Tuple[float, float, float]
) -> Tuple[float, float, float]:
    """Return the displacement from *p1* to *p2* as an (x, y, z) tuple.

    Only the first three components of each point are used.
    """
    dx = p2[0] - p1[0]
    dy = p2[1] - p1[1]
    dz = p2[2] - p1[2]
    return (dx, dy, dz)


# Function to normalize a vector
def normalize(v: Tuple[float, float, float]) -> Tuple[float, float, float]:
    """Return *v* scaled to unit length.

    A vector whose length is below 1e-9 cannot be normalized reliably, so
    the zero vector is returned for it instead of dividing by (almost) zero.
    """
    # Named ``length`` so the builtin ``len`` is not shadowed.
    length = math.sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])
    if length < 1e-9:
        return (0.0, 0.0, 0.0)
    return (v[0] / length, v[1] / length, v[2] / length)


# Function to compute the cross product of two 3D vectors (result is not normalized)
def outer_product(
    v1: Tuple[float, float, float], v2: Tuple[float, float, float]
) -> Tuple[float, float, float]:
    """Return the cross product ``v1 x v2`` as an (x, y, z) tuple.

    The result is returned as computed; it is not scaled to unit length.
    """
    a_x, a_y, a_z = v1[0], v1[1], v1[2]
    b_x, b_y, b_z = v2[0], v2[1], v2[2]
    return (
        a_y * b_z - a_z * b_y,
        a_z * b_x - a_x * b_z,
        a_x * b_y - a_y * b_x,
    )


# Function to compute inner product of two vectors
def dot(v1: Tuple[float, float, float], v2: Tuple[float, float, float]) -> float:
    """Return the inner product of the first three components of *v1* and *v2*."""
    x_term = v1[0] * v2[0]
    y_term = v1[1] * v2[1]
    z_term = v1[2] * v2[2]
    return x_term + y_term + z_term


# Function to load thumb data from CSV
def load_thumb_data(
    file_path: str, n_points: int = 14
) -> Tuple[List[str], List[ThumbFrameData]]:
    """Load per-frame 3D point data from a CSV file and add two derived points.

    File layout as read by this function:

    * The first 2 lines of the file are skipped by pandas.
    * The next row (row 0 of the frame) carries the point names: the cell at
      column ``2 + 3 * i`` is expected to contain ``"New Subject:<name>"``.
    * Rows 1 and 2 are ignored (presumably axis/unit headers -- confirm
      against the exporter).
    * Every later row is one frame: column 0 is the frame number and columns
      ``2 + 3 * i`` .. ``4 + 3 * i`` are the x, y, z of point ``i``.
      Column 1 is read but not used.

    Two derived points are appended to every frame after the measured ones:

    * the point ``VECTOR_LEN`` away from point 4 along the unit normal of
      the plane through points 0, 4 and 5;
    * the start of ``VECTORS[0]`` plus the component of ``VECTORS[0]`` that
      is perpendicular to ``VECTORS[3]``.

    Parameters
    ----------
    file_path : path of the CSV file to read.
    n_points : number of measured points per frame. Must be large enough
        for the hard-coded point indices (0, 4, 5) and the indices used by
        ``VECTORS[0]`` and ``VECTORS[3]``.

    Returns
    -------
    (point_names, frames)
        ``point_names`` has ``n_points + 2`` entries aligned with each
        frame's points; ``frames`` has one ``ThumbFrameData`` per data row.
    """
    df = pd.read_csv(
        file_path, skiprows=2, usecols=range(n_points * 3 + 2), header=None
    )

    # End-point indices of the configured vectors the second derived point
    # depends on (v1 = VECTORS[0], v4 = VECTORS[3]).
    v1_start, v1_end = VECTORS[0][0], VECTORS[0][1]
    v4_start, v4_end = VECTORS[3][0], VECTORS[3][1]

    data_list = []
    point_names = []
    for row_idx, row in df.iterrows():
        if row_idx == 0:
            for i in range(n_points):
                cell = row[2 + i * 3]
                if isinstance(cell, str) and "New Subject:" in cell:
                    name = cell.split("New Subject:")[1].strip()
                else:
                    # Keep names aligned with points even when a header cell
                    # is missing or malformed.
                    name = f"P{i + 1}"
                point_names.append(name)
            # Names for the two derived points appended to every frame.
            point_names.append(f"P{n_points + 1}")
            point_names.append(f"P{n_points + 2}")

        if row_idx > 2:
            frame_number = int(row[0])
            points = []
            for i in range(n_points):
                x = float(row[2 + i * 3])
                y = float(row[3 + i * 3])
                z = float(row[4 + i * 3])
                points.append((x, y, z))

            # First derived point: offset from point 4 along the unit normal
            # of the plane spanned by (point 0 -> point 4) and
            # (point 5 -> point 4).
            normal = normalize(
                outer_product(
                    vector(points[0], points[4]), vector(points[5], points[4])
                )
            )
            points.append(
                (
                    points[4][0] + normal[0] * VECTOR_LEN,
                    points[4][1] + normal[1] * VECTOR_LEN,
                    points[4][2] + normal[2] * VECTOR_LEN,
                )
            )

            # Second derived point: remove from v1 its component along v4
            # (i.e. project v1 onto the plane perpendicular to v4) and apply
            # the remainder at the start point of v1.
            v1 = vector(points[v1_start], points[v1_end])
            v4 = vector(points[v4_start], points[v4_end])
            v4_sq = dot(v4, v4)
            proj_l = dot(v1, v4) / v4_sq if v4_sq > 1e-9 else 0
            rejection = (
                v1[0] - proj_l * v4[0],
                v1[1] - proj_l * v4[1],
                v1[2] - proj_l * v4[2],
            )
            base = points[v1_start]
            points.append(
                (base[0] + rejection[0], base[1] + rejection[1], base[2] + rejection[2])
            )

            data_list.append(ThumbFrameData(frame_number, points))

    return point_names, data_list


# Function to compute the angle between two entries of VECTORS
def compute_and_format_angle(
    pts: List[Tuple[float, float, float]], vec1: int, vec2: int
) -> str:
    """Format the angle between two configured vectors for display.

    Parameters
    ----------
    pts : points of one frame; the vectors' end points are looked up in it.
    vec1, vec2 : indices into ``VECTORS`` selecting the two vectors.

    Returns
    -------
    ``"Angle <l1>-<l2> = <deg> deg"`` with two decimals, where the labels
    come from ``VECTORS``. If either vector has (almost) zero length the
    text ends in ``"N/A"``. If an index is outside ``VECTORS`` its label is
    shown as ``"?"`` and the text ends in ``": N/A"``.
    """
    n_vectors = len(VECTORS)
    indices = (vec1, vec2)
    # Validate BEFORE touching VECTORS so a bad index yields "N/A" rather
    # than an IndexError. Negative indices count from the end, as usual.
    valid = [-n_vectors <= idx < n_vectors for idx in indices]
    labels = [VECTORS[idx][3] if ok else "?" for idx, ok in zip(indices, valid)]
    hstr = f"Angle {labels[0]}-{labels[1]} "
    if not all(valid):
        return hstr + ": N/A"

    (a1, b1, _, _) = VECTORS[vec1]
    (a2, b2, _, _) = VECTORS[vec2]
    v1 = vector(pts[a1], pts[b1])
    v2 = vector(pts[a2], pts[b2])
    norm1 = math.sqrt(dot(v1, v1))
    norm2 = math.sqrt(dot(v2, v2))
    if norm1 < 1e-9 or norm2 < 1e-9:
        return hstr + "N/A"
    # Clamp to [-1, 1] so rounding error cannot push acos out of its domain.
    cos_val = max(-1.0, min(1.0, dot(v1, v2) / (norm1 * norm2)))
    angle_deg = math.degrees(math.acos(cos_val))
    return hstr + f"= {angle_deg:.2f} deg"


# Function to visualize thumb data with a slider
def visualize_thumb_data(
    data: List[ThumbFrameData], POINT_NAMES: Optional[List[str]] = None
):
    """Show an interactive 3D view of the frames in *data*.

    The figure contains a scatter of the frame's points, the segments listed
    in ``LINES``, one arrow plus label per entry of ``VECTORS``, two angle
    read-outs (``VECTORS`` entries 0 vs 1 and 4 vs 1), a frame slider
    (1-based) and a "v4 view" button that re-orients the camera. The call
    blocks in ``plt.show()`` until the window is closed.

    Parameters
    ----------
    data : frames to display; nothing is shown when it is empty.
    POINT_NAMES : labels paired with the points by position. Points without
        a corresponding name are simply left unlabeled. ``None`` means no
        labels.
    """
    if not data:
        return
    point_names = [] if POINT_NAMES is None else POINT_NAMES

    # Axis limits cover every frame so the axes stay fixed while scrubbing.
    all_points = [p for frame in data for p in frame.points]
    x_all = [p[0] for p in all_points]
    y_all = [p[1] for p in all_points]
    z_all = [p[2] for p in all_points]
    x_min, x_max = min(x_all), max(x_all)
    y_min, y_max = min(y_all), max(y_all)
    z_min, z_max = min(z_all), max(z_all)

    fig = plt.figure(figsize=(10, 10))
    fig.canvas.manager.set_window_title("ThumbAnalysis ver " + VERSION)
    ax = fig.add_subplot(111, projection="3d")

    def split_xyz(pts):
        # Separate x, y and z coordinate lists for a list of points.
        return [p[0] for p in pts], [p[1] for p in pts], [p[2] for p in pts]

    def segment_xyz(p1, p2):
        # Per-axis [start, end] coordinate pairs for the segment p1 -> p2.
        return [p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]]

    def vector_coordinates(
        p1: Tuple[float, float, float], p2: Tuple[float, float, float]
    ) -> Tuple[List[List[float]], List[float]]:
        # Arrow end points plus the label position: the midpoint, shifted
        # towards p1 in x/y and raised by a fraction of the global z-range.
        x1, y1, z1 = p1
        x2, y2, z2 = p2
        mid_x = 0.5 * (x1 + x2) + LABEL_OFFSET_XY * (x1 - x2)
        mid_y = 0.5 * (y1 + y2) + LABEL_OFFSET_XY * (y1 - y2)
        mid_z = 0.5 * (z1 + z2) + LABEL_OFFSET_Z_FACTOR * (z_max - z_min)
        midp = [mid_x, mid_y, mid_z]
        return [[x1, x2], [y1, y2], [z1, z2]], midp

    # Initial plot (first frame)
    points = data[0].points
    scat = ax.scatter(*split_xyz(points))

    # Draw connections
    lines = [
        ax.plot(*segment_xyz(points[idx1], points[idx2]), "b-")[0]
        for idx1, idx2 in LINES
    ]

    # Add point names
    texts = [
        ax.text(p[0], p[1], p[2], name, fontsize=8)
        for p, name in zip(points, point_names)
    ]

    # Draw vector arrows as specified in VECTORS
    arrows = []
    arrow_labels = []
    for idx1, idx2, color, name in VECTORS:
        arrow_pts, mid_pts = vector_coordinates(points[idx1], points[idx2])
        a = Arrow3D(
            *arrow_pts,
            mutation_scale=20,
            lw=VECTOR_LW,
            arrowstyle="-|>",
            color=color,
        )
        ax.add_artist(a)
        arrows.append(a)

        lbl = ax.text(
            *mid_pts,
            name,
            color=color,
            fontsize=10,
            weight="bold",
        )
        arrow_labels.append(lbl)

    # Angle display: (VECTORS index, VECTORS index) per read-out line.
    angle_pairs = ((0, 1), (4, 1))
    angle_texts = [
        fig.text(
            0.02,
            y_pos,
            compute_and_format_angle(points, vec1, vec2),
            fontsize=12,
            weight="bold",
        )
        for y_pos, (vec1, vec2) in zip((0.95, 0.92), angle_pairs)
    ]

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_zlim(z_min, z_max)
    ax.set_title(f"Frame {data[0].frame_number}")

    # Slider (1-based frame position)
    ax_slider = plt.axes([0.2, 0.02, 0.5, 0.03])
    slider = Slider(ax_slider, "Frame", 1, len(data), valinit=1, valstep=1)

    def update(val):
        frame = data[int(slider.val) - 1]
        points = frame.points
        scat._offsets3d = split_xyz(points)

        # Update lines
        for line, (idx1, idx2) in zip(lines, LINES):
            line.set_data_3d(*segment_xyz(points[idx1], points[idx2]))

        # Update texts. zip() stops at the shorter sequence, so points
        # without a label are skipped. set_position_3d is required here:
        # the plain set_position only updates x and y and leaves z stale.
        for text, p in zip(texts, points):
            text.set_position_3d(p)

        # Update all configured arrows
        for arrow, label, (idx1, idx2, _, _) in zip(arrows, arrow_labels, VECTORS):
            arrow_pts, mid_pts = vector_coordinates(points[idx1], points[idx2])
            arrow._verts3d = arrow_pts
            label.set_position_3d(mid_pts)

        # Update angle text
        for angle_text, (vec1, vec2) in zip(angle_texts, angle_pairs):
            angle_text.set_text(compute_and_format_angle(points, vec1, vec2))

        ax.set_title(f"Frame {frame.frame_number}")
        fig.canvas.draw_idle()

    slider.on_changed(update)

    # Button: switch view to direction from points[0] towards points[4]
    ax_button = plt.axes([0.02, 0.02, 0.1, 0.04])
    btn_top = Button(ax_button, "v4 view")

    def set_top_view(points):
        # Viewing direction: from points[0] to points[4]
        forward = vector(points[0], points[4])
        # Azimuth and elevation of that direction
        azim = math.degrees(math.atan2(forward[1], forward[0]))
        elev = math.degrees(
            math.atan2(forward[2], math.sqrt(forward[0] ** 2 + forward[1] ** 2))
        )

        # Roll is derived from the direction points[14] -> points[4]
        # (points[14] is the first derived point).
        target_left = normalize(vector(points[14], points[4]))
        # Up vector (Z-axis)
        up = (0, 0, 1)
        # forward x up
        right = normalize(outer_product(forward, up))
        # up x right
        # NOTE(review): up x (forward x up) is the horizontal component of
        # forward, not a sideways direction -- confirm this is intended.
        left = outer_product(up, right)
        cos_roll = dot(left, target_left)
        sin_roll = dot(right, target_left)
        roll = math.degrees(math.atan2(sin_roll, cos_roll))
        ax.view_init(elev=elev, azim=azim, roll=roll)

    def btn_top_on_clicked(event):
        set_top_view(data[int(slider.val) - 1].points)
        fig.canvas.draw_idle()

    btn_top.on_clicked(btn_top_on_clicked)

    plt.show()


# Main function to test loading and visualization
def main(data_file: str = "IwasakiThumbRightKapandji03.csv"):
    """Load *data_file* and open the interactive viewer if it contains frames.

    Parameters
    ----------
    data_file : path of the CSV file to load. Defaults to the sample
        recording name used by this script.
    """
    names, data = load_thumb_data(data_file)
    print(f"Loaded {len(data)} frames")
    if data:
        # Visualize the data
        visualize_thumb_data(data, names)


if __name__ == "__main__":
    main()
