diff --git a/main.py b/main.py index fd8d832..529420f 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ from pathlib import Path import csv import re +import argparse # .envファイルから環境変数を読み込む load_dotenv() @@ -18,8 +19,79 @@ detected_areas = [] current_image = None current_image_name = "" -all_results = [] # すべての画像の結果を保存するリスト +all_results = [] +expected_position = 1 +view_mode = 1 # 1: 正面, 2: 背面 +def validate_relative_position(center, pos_number, centers): + """ + 相対的な位置関係に基づいて点の妥当性をチェック + """ + x, y = center + global view_mode + + if pos_number == 1: + return True + + if pos_number == 2: + # Pos2 は Pos1 より左(背面の場合は右) + if view_mode == 1: # 正面 + return x < centers[0][0] + else: # 背面 + return x > centers[0][0] + + if pos_number == 3: + # Pos3 は Pos1,2 より下 + return y > centers[0][1] and y > centers[1][1] + + if pos_number == 4: + # Pos4 は Pos3 より左(背面の場合は右) + if view_mode == 1: # 正面 + return x < centers[2][0] + else: # 背面 + return x > centers[2][0] + + if pos_number == 5: + # Pos5 は Pos2 の近くで、より上 + return abs(x - centers[1][0]) < 50 and y < centers[1][1] + + if pos_number == 6: + # Pos6 は Pos1 の近く + return abs(x - centers[0][0]) < 50 + + if pos_number == 7: + # Pos7 は Pos6 の下 + return abs(x - centers[5][0]) < 50 and y > centers[5][1] + + if pos_number == 8: + # Pos8 は Pos2 の近く + return abs(x - centers[1][0]) < 50 + + if pos_number == 9: + # Pos9 は Pos8 より下 + return abs(x - centers[7][0]) < 50 and y > centers[7][1] + + if pos_number == 10: + # Pos10 は中央付近(制約を緩和) + return 330 < x < 380 # x座標の範囲のみをチェック + + if pos_number == 11: + # Pos11 は Pos10 より下 + return abs(x - centers[9][0]) < 50 and y > centers[9][1] + + if pos_number == 12: + # Pos12 は Pos9 の近くで下 + return abs(x - centers[8][0]) < 50 and y > centers[8][1] + + if pos_number == 13: + # Pos13 は Pos12 より下 + return abs(x - centers[11][0]) < 50 and y > centers[11][1] + + if pos_number == 14: + # Pos14 は Pos13 より右で Pos11 より下 + return x > centers[12][0] and y > centers[10][1] + + return False def region_growing(image, seed_point, threshold=60, max_distance=50): height, width = image.shape[:2] @@ -48,7 +120,6 @@ return segmented - def detect_ping_pong_ball(image, click_point, max_distance): x, y = click_point hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) @@ -73,50 +144,102 @@ print(f"No region detected. Returning clicked point: {click_point}") return click_point, 0 - def on_mouse_click(event, x, y, flags, param): - global current_image, clicked_points, detected_centers, detected_areas + global current_image, clicked_points, detected_centers, detected_areas, expected_position + if event == cv2.EVENT_LBUTTONDOWN: - clicked_points.append((x, y)) - print(f"Clicked point: ({x}, {y})") center, area = detect_ping_pong_ball(current_image, (x, y), max_distance=10) - detected_centers.append(center) - detected_areas.append(area) - - if center == (x, y): - print("Warning: Detected center is the same as clicked point.") + + # 位置の妥当性をチェック + if validate_relative_position(center, expected_position, detected_centers): + clicked_points.append((x, y)) + detected_centers.append(center) + detected_areas.append(area) + print(f"Position {expected_position} recorded at {center}") + expected_position += 1 + draw_detection() else: - print(f"Detected center differs from clicked point. Offset: ({center[0]-x}, {center[1]-y})") - - # 検出結果を描画 - draw_detection() + print(f"Invalid relative position for Pos{expected_position}. Please click the correct position.") + show_error_message(f"Invalid relative position for Pos{expected_position}") + elif event == cv2.EVENT_RBUTTONDOWN: if clicked_points: clicked_points.pop() detected_centers.pop() detected_areas.pop() + expected_position -= 1 print("Last click undone") draw_detection() +def show_error_message(message): + """エラーメッセージを画像に表示""" + img_copy = current_image.copy() + draw_detection() # 既存の点を描画 + + # エラーメッセージを画像上部に表示 + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 1 + thickness = 2 + color = (0, 0, 255) # 赤色 + + # テキストサイズを取得して中央に配置 + text_size = cv2.getTextSize(message, font, font_scale, thickness)[0] + text_x = (img_copy.shape[1] - text_size[0]) // 2 + text_y = 50 + + # 背景を描画 + cv2.rectangle(img_copy, + (text_x - 10, text_y - text_size[1] - 10), + (text_x + text_size[0] + 10, text_y + 10), + (255, 255, 255), + -1) + + # テキストを描画 + cv2.putText(img_copy, message, (text_x, text_y), font, font_scale, color, thickness) + cv2.imshow("Image", img_copy) def draw_detection(): + """検出結果を描画""" global current_image, clicked_points, detected_centers, detected_areas image_copy = current_image.copy() - for click, center, area in zip(clicked_points, detected_centers, detected_areas): - cv2.circle(image_copy, center, int(np.sqrt(area / np.pi)), (0, 255, 0), 2) + for i, center in enumerate(detected_centers): + # 点を描画 cv2.circle(image_copy, center, 2, (0, 0, 255), 3) - cv2.circle(image_copy, click, 2, (255, 0, 0), -1) # クリックした点を青で表示 + + # ラベルを描画 + label = str(i + 1) + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.5 + font_thickness = 1 + text_offset = (center[0] + 10, center[1] - 10) + + # 白い背景を付けてラベルを見やすくする + (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, font_thickness) + cv2.rectangle(image_copy, + (text_offset[0] - 2, text_offset[1] - text_height - 2), + (text_offset[0] + text_width + 2, text_offset[1] + 2), + (255, 255, 255), + -1) + + # ラベルのテキストを描画 + cv2.putText(image_copy, + label, + text_offset, + font, + font_scale, + (0, 0, 0), + font_thickness) + cv2.imshow("Image", image_copy) - def start_detection(): - global clicked_points, detected_centers, detected_areas + global clicked_points, detected_centers, detected_areas, expected_position clicked_points = [] detected_centers = [] detected_areas = [] + expected_position = 1 cv2.setMouseCallback("Image", on_mouse_click) - print("Detection started. Click on ping pong balls. Right-click to undo last click.") - + print("Detection started. Click on ping pong balls in order. Right-click to undo last click.") def save_results(): global clicked_points, detected_centers, current_image_name, detected_areas, all_results @@ -129,7 +252,6 @@ for center in detected_centers: print(center) - def save_all_results(): global all_results with open("ping-pong-detection.csv", "w", newline="") as csvfile: @@ -142,17 +264,24 @@ writer.writerow(result) print("All results saved to ping-pong-detection.csv") - def sort_key(filename): - # ファイル名からangleの前の数値を抽出 match = re.search(r"_(-?\d+)_angle", filename) if match: return int(match.group(1)) - return 0 # マッチしない場合は0を返す - + return 0 def main(): - global current_image, current_image_name + global current_image, current_image_name, view_mode + + # コマンドライン引数の解析 + parser = argparse.ArgumentParser(description='Ping pong ball detection') + parser.add_argument('view', type=int, choices=[1, 2], help='View mode (1: front, 2: back)') + args = parser.parse_args() + + view_mode = args.view + view_name = "front" if view_mode == 1 else "back" + print(f"Running in {view_name} view mode") + if not IMAGE_DIR: print("Error: IMAGE_DIR is not set in .env file") return @@ -193,6 +322,5 @@ cv2.destroyAllWindows() save_all_results() - if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/visual.py b/visual.py new file mode 100644 index 0000000..9ebc292 --- /dev/null +++ b/visual.py @@ -0,0 +1,133 @@ +import cv2 +import numpy as np +import os +from dotenv import load_dotenv +from pathlib import Path +import csv +import re + +# .envファイルから環境変数を読み込む +load_dotenv() + +# 画像ディレクトリのパスを取得 +IMAGE_DIR = os.getenv("IMAGE_DIR") + +def read_csv_data(csv_file): + """CSVファイルから座標データを読み取る""" + data = {} + with open(csv_file, 'r') as f: + reader = csv.DictReader(f) + for row in reader: + image_name = row['image_file_name'] + points = [] + i = 1 + while f'Pos{i}_center_x' in row: + center = ( + int(float(row[f'Pos{i}_center_x'])), + int(float(row[f'Pos{i}_center_y'])) + ) + points.append((i, center)) # インデックスと中心座標を保存 + i += 1 + data[image_name] = points + return data + +def sort_key(filename): + """ファイル名からangleの前の数値を抽出してソートキーとする""" + match = re.search(r"_(-?\d+)_angle", filename) + if match: + return int(match.group(1)) + return 0 + +def draw_points(image, points): + """画像に点とラベルを描画する""" + image_copy = image.copy() + # フォントの設定 + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.5 + font_thickness = 1 + + for idx, center in points: + # 中心位置を赤で描画 + cv2.circle(image_copy, center, 2, (0, 0, 255), 3) + + # ラベルを描画(少し右上にオフセット) + label = str(idx) + text_offset = (center[0] + 10, center[1] - 10) # 右上にオフセット + + # 白い背景を付けてラベルを見やすくする + (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, font_thickness) + cv2.rectangle(image_copy, + (text_offset[0] - 2, text_offset[1] - text_height - 2), + (text_offset[0] + text_width + 2, text_offset[1] + 2), + (255, 255, 255), + -1) + + # ラベルのテキストを描画 + cv2.putText(image_copy, + label, + text_offset, + font, + font_scale, + (0, 0, 0), + font_thickness) + + return image_copy + +def main(): + if not IMAGE_DIR: + print("Error: IMAGE_DIR is not set in .env file") + return + + image_dir = Path(IMAGE_DIR) + if not image_dir.is_dir(): + print(f"Error: {IMAGE_DIR} is not a valid directory") + return + + # CSVデータを読み込む + csv_data = read_csv_data('ping-pong-detection.csv') + + # 画像ファイルを取得し、angleの前の数値でソート + image_files = list(image_dir.glob("*.png")) + list(image_dir.glob("*.jpg")) + image_files.sort(key=lambda x: sort_key(x.name)) + + cv2.namedWindow("Image") + + current_image_index = 0 + while True: + if 0 <= current_image_index < len(image_files): + image_file = image_files[current_image_index] + image = cv2.imread(str(image_file)) + + if image is None: + print(f"Error: Unable to read image {image_file}") + current_image_index += 1 + continue + + image_name = image_file.stem + if image_name in csv_data: + # 点とラベルを描画 + image_with_points = draw_points(image, csv_data[image_name]) + cv2.imshow("Image", image_with_points) + print(f"Showing {image_file.name} ({current_image_index + 1}/{len(image_files)})") + else: + print(f"No data found for {image_name}") + cv2.imshow("Image", image) + + key = cv2.waitKey(0) & 0xFF + if key == ord('n'): # 次の画像 + current_image_index += 1 + elif key == ord('p'): # 前の画像 + current_image_index -= 1 + elif key == 27: # ESCで終了 + break + + # インデックスが範囲外になった場合の処理 + if current_image_index >= len(image_files): + current_image_index = 0 + elif current_image_index < 0: + current_image_index = len(image_files) - 1 + + cv2.destroyAllWindows() + +if __name__ == "__main__": + main() \ No newline at end of file