import math
import sys
from typing import Iterable, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.widgets import Slider
# Edges of the skeleton drawn between points. Each (a, b) pair joins the point
# at index a to the point at index b within one frame's point list. Keep the
# order stable: line artists are matched to these pairs by position.
connection: List[Tuple[int, int]] = [
    (0, 1), (0, 2), (1, 3),
    (0, 4), (0, 5), (4, 5),
    (2, 6), (3, 7),
    (6, 10), (10, 11), (11, 12), (11, 13), (12, 13),
    (7, 8), (8, 9), (7, 9),
]

# Point labels; starts empty and is rebound by get_point_names() when a CSV
# header is read.
POINT_NAMES: List[str] = []
# Data class to hold thumb frame data
class ThumbFrameData:
def __init__(
self,
frame_number: int,
points: List[Tuple[float, float, float]],
):
self.frame_number = frame_number
self.points = points # List of (x, y, z) tuples for 14 points
def get_point_names(file_path: str) -> List[str]:
    """Read point (marker) names from the CSV header and store them in POINT_NAMES.

    After skipping two lines, one row is read. Every third column starting at
    column 2 (the first column of each x/y/z triplet) is checked for a cell of
    the form ``"New Subject:<name>"``. The scan covers columns 2..58, clamped
    to the columns the row actually has.

    Args:
        file_path: Path to the CSV file.

    Returns:
        The extracted names in column order. The same list is also bound to the
        module-level ``POINT_NAMES``.
    """
    global POINT_NAMES
    prefix = "New Subject:"
    df_header = pd.read_csv(file_path, skiprows=2, nrows=1, header=None)
    # A fixed upper bound of 59 raised IndexError on narrower files; never
    # index past the columns that were parsed.
    last_col = min(59, df_header.shape[1])
    point_names = []
    for i in range(2, last_col, 3):
        cell = df_header.iloc[0, i]
        # Empty cells parse as NaN and numeric cells as numbers; only strings
        # can carry a name, and `in` on a number would raise TypeError.
        if isinstance(cell, str) and prefix in cell:
            point_names.append(cell.partition(prefix)[2].strip())
    POINT_NAMES = point_names
    return point_names
def load_thumb_data(file_path: str, num_points: int = 14) -> List[ThumbFrameData]:
    """Load every frame of point coordinates from a CSV file.

    Also calls ``get_point_names`` on the same file, which refreshes the
    module-level ``POINT_NAMES``.

    The first five lines are skipped as header lines. In each remaining row:
    - column 0 is the frame number;
    - column 1 is not used here;
    - columns ``2 + 3*i``, ``3 + 3*i`` and ``4 + 3*i`` are the x, y and z of
      point ``i``.

    Args:
        file_path: Path to the CSV file.
        num_points: Number of (x, y, z) triplets to read per row. Defaults to
            14, which reads columns 0-43 as before.

    Returns:
        One ``ThumbFrameData`` per data row, in file order.

    Raises:
        ValueError: If a frame-number cell cannot be converted to ``int``
            (for example, an empty cell parsed as NaN).
    """
    get_point_names(file_path)
    n_cols = 2 + 3 * num_points
    df = pd.read_csv(file_path, skiprows=5, usecols=range(n_cols), header=None)
    data_list: List[ThumbFrameData] = []
    # itertuples yields plain tuples and avoids building a Series per row as
    # iterrows does. Positions equal column labels because usecols is a range
    # starting at 0.
    for row in df.itertuples(index=False, name=None):
        points = [
            (row[2 + i * 3], row[3 + i * 3], row[4 + i * 3])
            for i in range(num_points)
        ]
        data_list.append(ThumbFrameData(int(row[0]), points))
    return data_list
def _finite_range(values: Iterable[float]) -> Optional[Tuple[float, float]]:
    """Return ``(min, max)`` of the finite entries in *values*, or None if there are none."""
    finite = [v for v in values if math.isfinite(v)]
    if not finite:
        return None
    return min(finite), max(finite)


def _split_xyz(
    points: List[Tuple[float, float, float]],
) -> Tuple[List[float], List[float], List[float]]:
    """Split a list of (x, y, z) tuples into separate x, y and z lists."""
    return (
        [p[0] for p in points],
        [p[1] for p in points],
        [p[2] for p in points],
    )


def visualize_thumb_data(data: List[ThumbFrameData]):
    """Show an interactive 3D view of the frames, with a slider to step between them.

    Each frame's points are drawn as a scatter plot and joined by the edges in
    the module-level ``connection``. Points are labelled with names from the
    module-level ``POINT_NAMES``, paired positionally; points beyond the
    available names get no label. Axis limits are fixed to the extent of all
    frames, so the view does not jump while scrubbing.

    Args:
        data: Frames to display. Returns immediately when empty.
    """
    if not data:
        return

    # Limits over every frame. NaN or infinite coordinates are ignored
    # because set_*lim raises ValueError for them.
    all_points = [p for frame in data for p in frame.points]
    x_range = _finite_range(p[0] for p in all_points)
    y_range = _finite_range(p[1] for p in all_points)
    z_range = _finite_range(p[2] for p in all_points)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection="3d")

    # Initial plot: frame 0.
    points = data[0].points
    x, y, z = _split_xyz(points)
    scat = ax.scatter(x, y, z)

    # One line artist per edge, kept in the same order as `connection`.
    lines = []
    for idx1, idx2 in connection:
        p1, p2 = points[idx1], points[idx2]
        (line,) = ax.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], "b-")
        lines.append(line)

    # zip stops at the shorter input, so there may be fewer labels than points.
    texts = [
        ax.text(p[0], p[1], p[2], name, fontsize=8)
        for p, name in zip(points, POINT_NAMES)
    ]

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    if x_range is not None:
        ax.set_xlim(*x_range)
    if y_range is not None:
        ax.set_ylim(*y_range)
    if z_range is not None:
        ax.set_zlim(*z_range)
    ax.set_title(f"Frame {data[0].frame_number}")

    ax_slider = plt.axes([0.2, 0.02, 0.5, 0.03])
    slider = Slider(ax_slider, "Frame", 0, len(data) - 1, valinit=0, valstep=1)

    def update(val):
        """Redraw every artist for the frame selected by the slider."""
        frame = data[int(slider.val)]
        pts = frame.points
        scat._offsets3d = _split_xyz(pts)
        for line, (idx1, idx2) in zip(lines, connection):
            p1, p2 = pts[idx1], pts[idx2]
            line.set_data_3d([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]])
        # Iterate over the labels that exist rather than over all points
        # (indexing texts[i] overflowed when names were missing).
        # set_position_3d updates z too; Text.set_position only moves x and y.
        for text, p in zip(texts, pts):
            text.set_position_3d((p[0], p[1], p[2]))
        ax.set_title(f"Frame {frame.frame_number}")
        fig.canvas.draw_idle()

    slider.on_changed(update)
    plt.show()
def main(file_path: str = "IwasakiThumbRightKapandji03.csv") -> None:
    """Load the CSV at *file_path*, print the frame count, and open the viewer.

    Args:
        file_path: CSV file to load. Defaults to
            ``IwasakiThumbRightKapandji03.csv``, resolved against the current
            working directory, so ``main()`` with no arguments behaves as before.
    """
    data = load_thumb_data(file_path)
    print(f"Loaded {len(data)} frames")
    if data:
        visualize_thumb_data(data)


if __name__ == "__main__":
    # An optional first command-line argument overrides the default CSV path.
    main(*sys.argv[1:2])