Newer
Older
Ping-Pone-Detection / visual.py
import cv2
import numpy as np
import os
from dotenv import load_dotenv
from pathlib import Path
import csv
import re

# .envファイルから環境変数を読み込む
load_dotenv()

# 画像ディレクトリのパスを取得
IMAGE_DIR = os.getenv("IMAGE_DIR")

def read_csv_data(csv_file):
    """CSVファイルから座標データを読み取る"""
    data = {}
    with open(csv_file, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            image_name = row['image_file_name']
            points = []
            i = 1
            while f'Pos{i}_center_x' in row:
                center = (
                    int(float(row[f'Pos{i}_center_x'])),
                    int(float(row[f'Pos{i}_center_y']))
                )
                points.append((i, center))  # インデックスと中心座標を保存
                i += 1
            data[image_name] = points
    return data

def sort_key(filename):
    """ファイル名からangleの前の数値を抽出してソートキーとする"""
    match = re.search(r"_(-?\d+)_angle", filename)
    if match:
        return int(match.group(1))
    return 0

def draw_points(image, points):
    """画像に点とラベルを描画する"""
    image_copy = image.copy()
    # フォントの設定
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_thickness = 1
    
    for idx, center in points:
        # 中心位置を赤で描画
        cv2.circle(image_copy, center, 2, (0, 0, 255), 3)
        
        # ラベルを描画(少し右上にオフセット)
        label = str(idx)
        text_offset = (center[0] + 10, center[1] - 10)  # 右上にオフセット
        
        # 白い背景を付けてラベルを見やすくする
        (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, font_thickness)
        cv2.rectangle(image_copy, 
                     (text_offset[0] - 2, text_offset[1] - text_height - 2),
                     (text_offset[0] + text_width + 2, text_offset[1] + 2),
                     (255, 255, 255),
                     -1)
        
        # ラベルのテキストを描画
        cv2.putText(image_copy,
                   label,
                   text_offset,
                   font,
                   font_scale,
                   (0, 0, 0),
                   font_thickness)
    
    return image_copy

def main():
    if not IMAGE_DIR:
        print("Error: IMAGE_DIR is not set in .env file")
        return

    image_dir = Path(IMAGE_DIR)
    if not image_dir.is_dir():
        print(f"Error: {IMAGE_DIR} is not a valid directory")
        return

    # CSVデータを読み込む
    csv_data = read_csv_data('ping-pong-detection.csv')

    # 画像ファイルを取得し、angleの前の数値でソート
    image_files = list(image_dir.glob("*.png")) + list(image_dir.glob("*.jpg"))
    image_files.sort(key=lambda x: sort_key(x.name))

    cv2.namedWindow("Image")

    current_image_index = 0
    while True:
        if 0 <= current_image_index < len(image_files):
            image_file = image_files[current_image_index]
            image = cv2.imread(str(image_file))
            
            if image is None:
                print(f"Error: Unable to read image {image_file}")
                current_image_index += 1
                continue

            image_name = image_file.stem
            if image_name in csv_data:
                # 点とラベルを描画
                image_with_points = draw_points(image, csv_data[image_name])
                cv2.imshow("Image", image_with_points)
                print(f"Showing {image_file.name} ({current_image_index + 1}/{len(image_files)})")
            else:
                print(f"No data found for {image_name}")
                cv2.imshow("Image", image)

        key = cv2.waitKey(0) & 0xFF
        if key == ord('n'):  # 次の画像
            current_image_index += 1
        elif key == ord('p'):  # 前の画像
            current_image_index -= 1
        elif key == 27:  # ESCで終了
            break

        # インデックスが範囲外になった場合の処理
        if current_image_index >= len(image_files):
            current_image_index = 0
        elif current_image_index < 0:
            current_image_index = len(image_files) - 1

    cv2.destroyAllWindows()

if __name__ == "__main__":
    main()