# ThumbAnalysis / main.py
import math
from typing import List, Tuple

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import FancyArrowPatch
from matplotlib.widgets import Slider
from mpl_toolkits.mplot3d import proj3d

# Pairs of point indices (into a frame's list of 14 points) that
# visualize_thumb_data joins with a line segment, in drawing order.
connection = [
    (0, 1), (0, 2), (1, 3),
    (0, 4), (0, 5), (4, 5),
    (2, 6), (3, 7),
    (6, 10), (10, 11), (11, 12), (11, 13), (12, 13),
    (7, 8), (8, 9), (7, 9),
]
# Point labels; replaced by get_point_names() when a CSV header is read.
POINT_NAMES: List[str] = []


# 3D arrow helper to draw arrows between 3D points
class Arrow3D(FancyArrowPatch):
    """A FancyArrowPatch whose two endpoints are given in 3D data coordinates.

    The endpoints are kept in ``_verts3d`` as ``(xs, ys, zs)``, where each
    entry is indexable at 0 (start) and 1 (end).  Whenever the arrow is
    projected or drawn, the endpoints are pushed through the owning axes'
    projection matrix and the resulting 2D points become the patch positions.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Placeholder 2D positions; the real ones are computed on projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _project(self):
        """Project the stored 3D endpoints and update the 2D positions.

        Returns the projected z values of both endpoints.
        """
        xs_3d, ys_3d, zs_3d = self._verts3d
        proj_x, proj_y, proj_z = proj3d.proj_transform(
            xs_3d, ys_3d, zs_3d, self.axes.get_proj()
        )
        start = (proj_x[0], proj_y[0])
        end = (proj_x[1], proj_y[1])
        self.set_positions(start, end)
        return proj_z

    def draw(self, renderer):
        """Re-project the endpoints, then draw as a regular 2D arrow."""
        self._project()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        """Re-project the endpoints and return the larger projected z value.

        NOTE(review): matplotlib's 3D axes presumably use this value to order
        artists by depth.
        """
        depths = self._project()
        return max(depths)


# Data class to hold thumb frame data
class ThumbFrameData:
    """Hold one frame of data: its frame number and its 3D point coordinates."""

    def __init__(
        self,
        frame_number: int,
        points: List[Tuple[float, float, float]],
    ):
        # The caller's list is stored as-is (no copy); one (x, y, z) tuple
        # per point. load_thumb_data supplies 14 points by default.
        self.points = points
        self.frame_number = frame_number


# Function to read point names from the CSV header
def get_point_names(file_path: str) -> List[str]:
    """Read point names from the CSV header and store them in POINT_NAMES.

    The row read is the one after the first two lines of the file.  Every
    third cell starting at column 2 (columns 2, 5, 8, ..., at most up to
    column 56) is inspected.  Each cell that is a string containing
    "New Subject:" contributes the text after that marker, with whitespace
    stripped.

    Args:
        file_path: Path of the CSV file (anything ``pd.read_csv`` accepts).

    Returns:
        The extracted names in column order.  The same list is also assigned
        to the module-level ``POINT_NAMES``.
    """
    global POINT_NAMES
    marker = "New Subject:"
    df_header = pd.read_csv(file_path, skiprows=2, nrows=1, header=None)
    header = df_header.iloc[0]
    # Clamp to the columns that actually exist; indexing past the last one
    # would raise IndexError for headers narrower than 57 cells.
    stop = min(59, len(header))
    point_names = []
    for i in range(2, stop, 3):
        cell = header.iloc[i]
        # Empty cells come back as NaN and numeric cells as numbers; neither
        # can contain a name, and `in` would raise TypeError on them.
        if isinstance(cell, str) and marker in cell:
            point_names.append(cell.split(marker)[1].strip())
    POINT_NAMES = point_names
    return point_names


# Function to load thumb data from CSV
def load_thumb_data(file_path: str, n_points: int = 14) -> List[ThumbFrameData]:
    """Load per-frame point coordinates from a CSV file.

    ``get_point_names`` is called first, so the module-level ``POINT_NAMES``
    is refreshed from the same file.  Data rows start after the first five
    lines.  Column 0 holds the frame number, and the x, y, z values of each
    point occupy three consecutive columns starting at column 2.

    Args:
        file_path: Path of the CSV file.
        n_points: Number of points read per frame.  The default of 14 matches
            the indices used by ``connection`` and by the visualizer.

    Returns:
        One ThumbFrameData per data row, in file order.  Rows whose frame
        number cell is empty are skipped, because they cannot be converted
        with ``int()``.
    """
    get_point_names(file_path)
    n_cols = 2 + 3 * n_points
    # Skip the first 5 lines. Keep column 0 (frame number), column 1, and
    # the x/y/z triples of the first n_points points.
    df = pd.read_csv(file_path, skiprows=5, usecols=range(n_cols), header=None)
    df = df.dropna(subset=[0])

    data_list = []
    # A single to_numpy() conversion avoids building a Series per row, which
    # iterrows() would do.
    for row in df.to_numpy(dtype=float):
        coords = row[2:n_cols].reshape(n_points, 3).tolist()
        points = [(x, y, z) for x, y, z in coords]
        data_list.append(ThumbFrameData(int(row[0]), points))

    return data_list


# Function to visualize thumb data with a slider
def _finite_limits(values: List[float]) -> Tuple[float, float]:
    """Return (min, max) over the finite entries of *values*.

    NaN/inf entries are ignored for two reasons: ``min``/``max`` give
    order-dependent results when NaN is present, and matplotlib rejects NaN
    axis limits.  Returns (0.0, 1.0) when no finite value exists.
    """
    finite = [v for v in values if math.isfinite(v)]
    if not finite:
        return 0.0, 1.0
    return min(finite), max(finite)


def _segment(
    points: List[Tuple[float, float, float]], idx1: int, idx2: int
) -> Tuple[List[float], List[float], List[float]]:
    """Return ([x1, x2], [y1, y2], [z1, z2]) for the segment points[idx1] -> points[idx2]."""
    (x1, y1, z1), (x2, y2, z2) = points[idx1], points[idx2]
    return [x1, x2], [y1, y2], [z1, z2]


def visualize_thumb_data(data: List[ThumbFrameData]):
    """Show an interactive 3D plot of the frames, with a slider to pick a frame.

    For the selected frame, the plot contains:

    - a scatter marker for every point;
    - a blue line for each index pair in ``connection``;
    - one text label per point, taken from ``POINT_NAMES``;
    - a thick red arrow from point 4 to point 5.

    Axis limits cover the finite coordinates of all frames, so the view does
    not jump while scrubbing.  The function blocks in ``plt.show()`` and does
    nothing when *data* is empty.
    """
    if not data:
        return

    # Fixed axis limits spanning every frame.
    all_points = [p for frame in data for p in frame.points]
    x_min, x_max = _finite_limits([p[0] for p in all_points])
    y_min, y_max = _finite_limits([p[1] for p in all_points])
    z_min, z_max = _finite_limits([p[2] for p in all_points])

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection="3d")

    # Initial plot: frame 0.
    points = data[0].points
    scat = ax.scatter(
        [p[0] for p in points],
        [p[1] for p in points],
        [p[2] for p in points],
    )

    lines = [
        ax.plot(*_segment(points, idx1, idx2), "b-")[0]
        for idx1, idx2 in connection
    ]

    # zip() stops at the shorter sequence: points without a name get no label.
    texts = [
        ax.text(p[0], p[1], p[2], name, fontsize=8)
        for p, name in zip(points, POINT_NAMES)
    ]

    # Thick red arrow from point 4 to point 5.
    arrow = Arrow3D(
        *_segment(points, 4, 5),
        mutation_scale=20,
        lw=4,
        arrowstyle="-|>",
        color="r",
    )
    ax.add_artist(arrow)

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_zlim(z_min, z_max)
    ax.set_title(f"Frame {data[0].frame_number}")

    # Slider selecting the frame index.
    ax_slider = fig.add_axes([0.2, 0.02, 0.5, 0.03])
    slider = Slider(ax_slider, "Frame", 0, len(data) - 1, valinit=0, valstep=1)

    def update(val):
        """Redraw every artist for the frame selected on the slider."""
        frame = data[int(val)]
        pts = frame.points
        scat._offsets3d = (
            [p[0] for p in pts],
            [p[1] for p in pts],
            [p[2] for p in pts],
        )

        for line, (idx1, idx2) in zip(lines, connection):
            line.set_data_3d(*_segment(pts, idx1, idx2))

        # Text.set_position() uses only (x, y); set_position_3d() also moves
        # the label along z. zip() keeps this safe when there are fewer
        # labels than points.
        for text, p in zip(texts, pts):
            text.set_position_3d((p[0], p[1], p[2]))

        arrow._verts3d = _segment(pts, 4, 5)

        ax.set_title(f"Frame {frame.frame_number}")
        fig.canvas.draw_idle()

    slider.on_changed(update)

    plt.show()


# Main function to test loading and visualization
def main():
    """Load the hard-coded CSV file, print the frame count, and plot the frames."""
    frames = load_thumb_data("IwasakiThumbRightKapandji03.csv")
    print(f"Loaded {len(frames)} frames")
    if not frames:
        return
    visualize_thumb_data(frames)


if __name__ == "__main__":
    main()