import cv2
import numpy as np
import os
from dotenv import load_dotenv
from pathlib import Path
import csv
import re
import argparse

# Load variables from a .env file into the process environment before reading them.
load_dotenv()

# Directory that holds the images to annotate; None when IMAGE_DIR is not set.
IMAGE_DIR = os.getenv("IMAGE_DIR")

# --- Mutable state shared between main() and the OpenCV mouse callback ---
# Per-image click data; on_mouse_click appends to / pops from all three together.
clicked_points, detected_centers, detected_areas = [], [], []
# Image currently on screen and its file stem (used as the first CSV column).
current_image, current_image_name = None, ""
# Finished rows, one per image, written to CSV by save_all_results().
all_results = []
# 1-based position number that the next left-click is validated against.
expected_position = 1
# 1: front view, 2: back view (set from the command line in main()).
view_mode = 1

def validate_relative_position(center, pos_number, centers, *, near_tol=50,
                               center_x_range=(330, 380), mode=None):
    """Check whether ``center`` is a plausible location for position ``pos_number``.

    Each position (1-14) is checked against the already accepted centers of
    earlier positions (e.g. Pos3 must lie below Pos1 and Pos2). Larger y means
    lower in the image. Left/right checks for Pos2 and Pos4 are mirrored
    between the front and back views.

    Args:
        center: ``(x, y)`` pixel coordinate of the candidate point.
        pos_number: 1-based position the candidate is supposed to be.
        centers: Accepted centers for positions ``1 .. pos_number - 1`` in order,
            i.e. ``centers[k]`` is position ``k + 1``.
        near_tol: Horizontal tolerance in pixels (exclusive) for the "near" /
            "directly below" checks. The default keeps the original value of 50.
        center_x_range: Open interval ``(lo, hi)`` of absolute x values accepted
            for Pos10. The default keeps the original (330, 380).
        mode: 1 for the front view, any other value for the back view. ``None``
            (default) uses the module-level ``view_mode``.

    Returns:
        True if the candidate satisfies the constraint for ``pos_number``;
        False otherwise, including for any ``pos_number`` outside 1-14.
    """
    x, y = center

    def left_of(ref):
        # The back view is mirrored horizontally, so "left" becomes larger x.
        # The view is looked up lazily, only for checks that need it.
        view = view_mode if mode is None else mode
        return x < ref[0] if view == 1 else x > ref[0]

    def near_x(ref):
        return abs(x - ref[0]) < near_tol

    def below_near(ref):
        return near_x(ref) and y > ref[1]

    if pos_number == 1:
        return True  # first point: nothing to compare against
    if pos_number == 2:
        return left_of(centers[0])  # left of Pos1
    if pos_number == 3:
        return y > centers[0][1] and y > centers[1][1]  # below Pos1 and Pos2
    if pos_number == 4:
        return left_of(centers[2])  # left of Pos3
    if pos_number == 5:
        return near_x(centers[1]) and y < centers[1][1]  # near Pos2 and above it
    if pos_number == 6:
        return near_x(centers[0])  # near Pos1 (x only)
    if pos_number == 7:
        return below_near(centers[5])  # below Pos6
    if pos_number == 8:
        return near_x(centers[1])  # near Pos2 (x only)
    if pos_number == 9:
        return below_near(centers[7])  # below Pos8
    if pos_number == 10:
        # Relaxed constraint: absolute x band only, no relative check.
        lo, hi = center_x_range
        return lo < x < hi
    if pos_number == 11:
        return below_near(centers[9])  # below Pos10
    if pos_number == 12:
        return below_near(centers[8])  # near Pos9 and below it
    if pos_number == 13:
        return below_near(centers[11])  # below Pos12
    if pos_number == 14:
        # Right of Pos13 and below Pos11.
        # NOTE(review): unlike Pos2/Pos4 this is not mirrored for the back
        # view — confirm that is intended.
        return x > centers[12][0] and y > centers[10][1]

    return False

def region_growing(image, seed_point, threshold=60, max_distance=50):
    """Grow a region from ``seed_point`` over pixels similar to the seed pixel.

    A pixel joins the region when all three conditions hold:
    - it is 8-connected to the region;
    - every channel differs from the *seed* pixel's value by at most
      ``threshold``;
    - its Euclidean distance from the seed is at most ``max_distance``.

    Args:
        image: Array indexed as ``image[y, x]``. It may be 2-D (single channel)
            or 3-D (channels last).
        seed_point: ``(x, y)`` pixel coordinate to start from.
        threshold: Maximum per-channel absolute difference from the seed value.
        max_distance: Maximum Euclidean distance (pixels) from the seed.

    Returns:
        ``uint8`` mask of shape ``(height, width)``. Region pixels are 255 and
        all others 0. A seed outside the image yields an all-zero mask.
    """
    height, width = image.shape[:2]
    segmented = np.zeros((height, width), dtype=np.uint8)

    sx, sy = int(seed_point[0]), int(seed_point[1])
    if not (0 <= sx < width and 0 <= sy < height):
        # Previously this raised IndexError, or (for negative coordinates)
        # silently seeded from the opposite image edge. Report "no region".
        return segmented

    # Only pixels within max_distance can ever be accepted, so restrict the
    # work to the bounding window around the seed. Non-finite distances fall
    # back to the whole image.
    if max_distance < width + height:
        reach = max(0, int(max_distance))
    else:
        reach = width + height
    x0, x1 = max(0, sx - reach), min(width, sx + reach + 1)
    y0, y1 = max(0, sy - reach), min(height, sy + reach + 1)

    # Decide eligibility for every pixel in the window at once. Casting to
    # int prevents uint8 wrap-around in the subtraction. This replaces
    # several small NumPy calls per visited pixel in the flood loop.
    seed_color = image[sy, sx].astype(int)
    color_ok = np.abs(image[y0:y1, x0:x1].astype(int) - seed_color) <= threshold
    if color_ok.ndim == 3:
        color_ok = color_ok.all(axis=2)  # every channel must be within threshold
    ys, xs = np.ogrid[y0:y1, x0:x1]
    within = np.sqrt((xs - sx) ** 2 + (ys - sy) ** 2) <= max_distance
    allowed = np.zeros((height, width), dtype=bool)
    allowed[y0:y1, x0:x1] = color_ok & within

    # 8-connected flood fill from the seed over eligible pixels; the seed
    # itself is always part of the region.
    neighbors = ((0, 1), (1, 0), (0, -1), (-1, 0), (1, 1), (-1, -1), (1, -1), (-1, 1))
    segmented[sy, sx] = 255
    stack = [(sx, sy)]
    while stack:
        x, y = stack.pop()
        for dx, dy in neighbors:
            nx, ny = x + dx, y + dy
            if (0 <= nx < width and 0 <= ny < height
                    and allowed[ny, nx] and not segmented[ny, nx]):
                segmented[ny, nx] = 255
                stack.append((nx, ny))

    return segmented

def detect_ping_pong_ball(image, click_point, max_distance):
    """Refine a click to the centroid of the similarly coloured blob around it.

    The BGR image is converted to HSV and grown from the click with
    ``region_growing`` (threshold 50, limited to ``max_distance`` pixels). The
    mask is then cleaned with a 5x5 morphological open followed by a close.

    Args:
        image: BGR image as returned by ``cv2.imread``.
        click_point: ``(x, y)`` pixel that was clicked.
        max_distance: Search radius passed on to ``region_growing``.

    Returns:
        ``((cx, cy), area)``. The centroid is truncated to int and ``area`` is
        the non-zero pixel count. If nothing survives the cleanup, returns
        ``(click_point, 0)``.
    """
    px, py = click_point
    hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    # NOTE(review): hue is circular in HSV, but region_growing compares channel
    # values linearly — hues near the wrap point may be split; confirm if relevant.
    mask = region_growing(hsv_image, (px, py), threshold=50, max_distance=max_distance)

    # Opening removes speckles, then closing fills small holes.
    structuring = np.ones((5, 5), np.uint8)
    for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        mask = cv2.morphologyEx(mask, operation, structuring)

    moments = cv2.moments(mask)
    if not moments["m00"] > 0:
        print(f"No region detected. Returning clicked point: {click_point}")
        return click_point, 0

    center_x = int(moments["m10"] / moments["m00"])
    center_y = int(moments["m01"] / moments["m00"])
    pixel_count = cv2.countNonZero(mask)
    print(f"Region detected: center=({center_x}, {center_y}), area={pixel_count}")
    return (center_x, center_y), pixel_count

def on_mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback for the "Image" window.

    Left button: the click is refined with ``detect_ping_pong_ball`` (search
    radius 10 px) and validated as position ``expected_position``. If valid,
    it is recorded and the expected position advances; otherwise an error is
    printed and shown on the image.

    Right button: undo the most recently recorded position, if there is one.

    ``flags`` and ``param`` are part of the callback signature and are unused.
    """
    global current_image, clicked_points, detected_centers, detected_areas, expected_position

    if event == cv2.EVENT_LBUTTONDOWN:
        center, area = detect_ping_pong_ball(current_image, (x, y), max_distance=10)

        # Reject points that do not fit the expected layout.
        if not validate_relative_position(center, expected_position, detected_centers):
            print(f"Invalid relative position for Pos{expected_position}. Please click the correct position.")
            show_error_message(f"Invalid relative position for Pos{expected_position}")
            return

        clicked_points.append((x, y))
        detected_centers.append(center)
        detected_areas.append(area)
        print(f"Position {expected_position} recorded at {center}")
        expected_position += 1
        draw_detection()

    elif event == cv2.EVENT_RBUTTONDOWN and clicked_points:
        # The three lists are kept parallel, so drop the last entry of each.
        for history in (clicked_points, detected_centers, detected_areas):
            history.pop()
        expected_position -= 1
        print("Last click undone")
        draw_detection()

def show_error_message(message):
    """Display ``message`` as a red banner over the current detections.

    The banner is drawn on top of the annotated image produced by
    ``draw_detection``, so the points recorded so far stay visible while the
    error is shown.

    Args:
        message: Text to show, horizontally centred with its baseline at y=50.
    """
    # Bug fix: the banner used to be drawn on a bare copy of current_image and
    # shown right after draw_detection(), which replaced (hid) every point.
    img_copy = draw_detection(show=False)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    thickness = 2
    color = (0, 0, 255)  # red (BGR)

    (text_w, text_h), _ = cv2.getTextSize(message, font, font_scale, thickness)
    # Clamp so a message wider than the image starts at the left edge rather
    # than at a negative (off-screen) x.
    text_x = max(0, (img_copy.shape[1] - text_w) // 2)
    text_y = 50

    # White box behind the text for readability.
    cv2.rectangle(img_copy,
                  (text_x - 10, text_y - text_h - 10),
                  (text_x + text_w + 10, text_y + 10),
                  (255, 255, 255),
                  -1)

    cv2.putText(img_copy, message, (text_x, text_y), font, font_scale, color, thickness)
    cv2.imshow("Image", img_copy)


def draw_detection(show=True):
    """Draw every detected center with its 1-based position label.

    Args:
        show: When True (default), display the result in the "Image" window,
            as before.

    Returns:
        The annotated copy of ``current_image``; the original is not modified.
    """
    image_copy = current_image.copy()

    # Label style is the same for every point, so set it up once.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_thickness = 1

    for label_no, center in enumerate(detected_centers, start=1):
        # Red dot at the detected center.
        cv2.circle(image_copy, center, 2, (0, 0, 255), 3)

        label = str(label_no)
        text_offset = (center[0] + 10, center[1] - 10)

        # White box behind the label so it stays readable on any background.
        (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, font_thickness)
        cv2.rectangle(image_copy,
                      (text_offset[0] - 2, text_offset[1] - text_height - 2),
                      (text_offset[0] + text_width + 2, text_offset[1] + 2),
                      (255, 255, 255),
                      -1)

        cv2.putText(image_copy, label, text_offset, font, font_scale, (0, 0, 0), font_thickness)

    if show:
        cv2.imshow("Image", image_copy)
    return image_copy

def start_detection():
    """Reset the per-image click state and attach the mouse handler.

    Call once per image. Clears the recorded points, centers and areas, sets
    the expected position back to 1, and registers ``on_mouse_click`` on the
    "Image" window.
    """
    global clicked_points, detected_centers, detected_areas, expected_position
    clicked_points, detected_centers, detected_areas = [], [], []
    expected_position = 1

    cv2.setMouseCallback("Image", on_mouse_click)
    print("Detection started. Click on ping pong balls in order. Right-click to undo last click.")

def save_results():
    """Append the current image's points to ``all_results`` as one CSV row.

    The row is ``[current_image_name, click_x, click_y, center_x, center_y, ...]``,
    with four values per recorded position. Areas are not included. The
    detected centers are also echoed to stdout.
    """
    coords = [value
              for click, center in zip(clicked_points, detected_centers)
              for value in (click[0], click[1], center[0], center[1])]
    all_results.append([current_image_name] + coords)

    print(f"Results for {current_image_name} added to all_results")
    print("Detected centers:")
    for recorded in detected_centers:
        print(recorded)

def save_all_results():
    """Write every row of ``all_results`` to ``ping-pong-detection.csv``.

    The header is ``image_file_name`` followed by four columns per position
    (click x/y, detected center x/y). It is sized for the row with the most
    positions. When ``all_results`` is empty (e.g. ESC before finishing any
    image) a header-only file is written instead of raising ``ValueError``.
    """
    # Rows look like [name, click_x, click_y, center_x, center_y, ...].
    # Compute the header before opening the file, so a failure cannot leave
    # a truncated CSV behind.
    longest = max((len(row) for row in all_results), default=1)
    header = ["image_file_name"]
    for i in range(1, (longest - 1) // 4 + 1):
        header.extend([f"Pos{i}_click_x", f"Pos{i}_click_y", f"Pos{i}_center_x", f"Pos{i}_center_y"])

    with open("ping-pong-detection.csv", "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(header)
        writer.writerows(all_results)
    print("All results saved to ping-pong-detection.csv")

def sort_key(filename):
    """Return the signed integer that directly precedes ``_angle`` in *filename*.

    For example ``"shot_-30_angle.png"`` gives ``-30``. Names without the
    ``_<int>_angle`` pattern map to 0.
    """
    found = re.search(r"_(-?\d+)_angle", filename)
    return int(found.group(1)) if found else 0

def main():
    """Run the interactive annotation loop over every image in IMAGE_DIR.

    Usage: ``python <script> {1,2}``, where 1 = front view and 2 = back view.
    The view changes the left/right checks in ``validate_relative_position``.

    Keys while an image is shown:
        f   -- store the points clicked on this image and move to the next one
        ESC -- stop immediately; points clicked on the current image are not stored

    Stored results are written to ``ping-pong-detection.csv`` when the loop
    ends. This happens whether it ends by finishing all images, by ESC, or
    through an exception such as Ctrl+C in the terminal.
    """
    global current_image, current_image_name, view_mode

    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description='Ping pong ball detection')
    parser.add_argument('view', type=int, choices=[1, 2], help='View mode (1: front, 2: back)')
    args = parser.parse_args()

    view_mode = args.view
    view_name = "front" if view_mode == 1 else "back"
    print(f"Running in {view_name} view mode")

    if not IMAGE_DIR:
        print("Error: IMAGE_DIR is not set in .env file")
        return

    image_dir = Path(IMAGE_DIR)
    if not image_dir.is_dir():
        print(f"Error: {IMAGE_DIR} is not a valid directory")
        return

    # Match extensions case-insensitively: glob("*.png") misses "*.PNG" on
    # case-sensitive filesystems. Order by the number before "_angle", then
    # by name so ties are deterministic.
    image_files = sorted(
        (p for p in image_dir.iterdir()
         if p.is_file() and p.suffix.lower() in (".png", ".jpg", ".jpeg")),
        key=lambda p: (sort_key(p.name), p.name),
    )
    if not image_files:
        print(f"No .png/.jpg images found in {IMAGE_DIR}")

    cv2.namedWindow("Image")
    try:
        for image_file in image_files:
            current_image = cv2.imread(str(image_file))
            if current_image is None:
                print(f"Error: Unable to read image {image_file}")
                continue
            current_image_name = image_file.stem

            cv2.imshow("Image", current_image)
            print(f"Current image: {image_file.name}")

            start_detection()

            # Poll the keyboard; mouse callbacks are dispatched inside waitKey.
            while True:
                key = cv2.waitKey(1) & 0xFF
                if key == ord("f"):  # 'f' key to finish and save results
                    save_results()
                    break
                if key == 27:  # ESC key to exit; cleanup happens in `finally`
                    return
    finally:
        # Single cleanup path: normal completion, ESC, or an exception.
        cv2.destroyAllWindows()
        save_all_results()

if __name__ == "__main__":
    main()