import cv2
import numpy as np
import os
from dotenv import load_dotenv
from pathlib import Path
import csv
import re

# Load environment variables from a .env file.
load_dotenv()

# Image directory path, taken from the IMAGE_DIR environment variable (None if unset).
IMAGE_DIR = os.getenv("IMAGE_DIR")

# Global state shared by the mouse callback, the drawing code and main().
# The three per-click lists are kept in lockstep: one entry per left click.
clicked_points = []  # (x, y) pixel coordinates of each left click on the current image
detected_centers = []  # detected centre (x, y) for each click
detected_areas = []  # detected region area in pixels for each click (0 if nothing was found)
current_image = None  # image currently shown, as returned by cv2.imread
current_image_name = ""  # file name (without extension) of the current image
all_results = []  # list holding one result row per finished image


def region_growing(image, seed_point, threshold=60, max_distance=50):
    """Grow a region of similar pixel values outward from a seed pixel.

    Starting at ``seed_point``, 8-connected neighbours are added while they
    (a) lie within ``max_distance`` pixels (Euclidean) of the seed, and
    (b) differ from the seed's value by at most ``threshold`` in every channel.

    Args:
        image: 2-D (single-channel) or H x W x C array.
        seed_point: ``(x, y)`` pixel coordinates of the seed.
        threshold: Maximum absolute per-channel difference from the seed value.
        max_distance: Maximum Euclidean distance (in pixels) from the seed.

    Returns:
        ``uint8`` mask of shape ``(H, W)``: 255 inside the region, 0 elsewhere.
        The seed pixel itself is always part of the region.

    Raises:
        ValueError: If ``seed_point`` lies outside the image (negative indices
            would otherwise silently wrap around).
    """
    height, width = image.shape[:2]
    seed_x, seed_y = seed_point
    if not (0 <= seed_x < width and 0 <= seed_y < height):
        raise ValueError(f"seed_point {seed_point} is outside the {width}x{height} image")

    segmented = np.zeros((height, width), dtype=np.uint8)

    # Only pixels with |dx|, |dy| <= floor(max_distance) can pass the distance
    # test, so all vectorised work is confined to that bounding box.
    # ``not max_distance >= 0`` also covers NaN (nothing but the seed qualifies).
    radius = 0 if not max_distance >= 0 else int(min(max_distance, max(height, width)))
    x0, x1 = max(seed_x - radius, 0), min(seed_x + radius + 1, width)
    y0, y1 = max(seed_y - radius, 0), min(seed_y + radius + 1, height)
    crop_h, crop_w = y1 - y0, x1 - x0

    # Colour test: widen to int so uint8 subtraction cannot wrap around.
    diff = np.abs(image[y0:y1, x0:x1].astype(int) - image[seed_y, seed_x].astype(int))
    color_ok = (diff <= threshold).reshape(crop_h, crop_w, -1).all(axis=2)

    # Distance test on the integer offsets from the seed.
    off_y, off_x = np.ogrid[y0 - seed_y:y1 - seed_y, x0 - seed_x:x1 - seed_x]
    near = np.sqrt(off_x ** 2 + off_y ** 2) <= max_distance

    # Flood fill over plain lists (much faster than per-element numpy access).
    candidate = (color_ok & near).tolist()
    inside = [[False] * crop_w for _ in range(crop_h)]
    start_x, start_y = seed_x - x0, seed_y - y0
    inside[start_y][start_x] = True
    stack = [(start_x, start_y)]

    # 8-neighbourhood
    neighbors = [(0, 1), (1, 0), (0, -1), (-1, 0), (1, 1), (-1, -1), (1, -1), (-1, 1)]
    while stack:
        x, y = stack.pop()
        for dx, dy in neighbors:
            nx, ny = x + dx, y + dy
            if 0 <= nx < crop_w and 0 <= ny < crop_h and candidate[ny][nx] and not inside[ny][nx]:
                inside[ny][nx] = True
                stack.append((nx, ny))

    # Write the region back into the full-size mask (slice is a view).
    segmented[y0:y1, x0:x1][np.array(inside, dtype=bool)] = 255
    return segmented


def detect_ping_pong_ball(image, click_point, max_distance, *, threshold=50, kernel_size=5):
    """Estimate the centre and area of the ball around a clicked pixel.

    The image is converted from BGR to HSV. A similar-coloured region is grown
    from the click (limited to ``max_distance`` pixels) and cleaned with a
    morphological opening followed by a closing. Its centroid is then taken
    from the image moments.

    Args:
        image: BGR image (as loaded by ``cv2.imread``).
        click_point: ``(x, y)`` pixel that was clicked.
        max_distance: Maximum distance (pixels) the region may extend from the click.
        threshold: Per-channel HSV tolerance passed to ``region_growing``.
        kernel_size: Side length of the square morphology kernel.

    Returns:
        ``((cx, cy), area)``: the rounded centroid and the number of non-zero
        mask pixels. Returns ``(click_point, 0)`` if the cleaned mask is empty.
    """
    x, y = click_point
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # Extract the similar-coloured region around the click (distance-limited).
    segmented = region_growing(hsv, (x, y), threshold=threshold, max_distance=max_distance)

    # Noise removal: opening drops specks, closing fills small holes.
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    segmented = cv2.morphologyEx(segmented, cv2.MORPH_OPEN, kernel)
    segmented = cv2.morphologyEx(segmented, cv2.MORPH_CLOSE, kernel)

    # Centroid from image moments. Round rather than truncate so the estimate
    # is not systematically biased toward the top-left by up to one pixel.
    M = cv2.moments(segmented)
    if M["m00"] > 0:
        cX = int(round(M["m10"] / M["m00"]))
        cY = int(round(M["m01"] / M["m00"]))
        area = cv2.countNonZero(segmented)
        print(f"Region detected: center=({cX}, {cY}), area={area}")
        return (cX, cY), area

    print(f"No region detected. Returning clicked point: {click_point}")
    return click_point, 0


def on_mouse_click(event, x, y, flags, param):
    """OpenCV mouse callback for the "Image" window.

    Left button: run ball detection around ``(x, y)``, record the click, the
    detected centre and the area, report the offset, and redraw.
    Right button: undo the most recent click (if there is one) and redraw.
    Any other event is ignored; ``flags`` and ``param`` are not used.
    """
    if event == cv2.EVENT_RBUTTONDOWN:
        if not clicked_points:
            return
        # The three histories are parallel, so drop the last entry of each.
        for history in (clicked_points, detected_centers, detected_areas):
            history.pop()
        print("Last click undone")
        draw_detection()
        return

    if event != cv2.EVENT_LBUTTONDOWN:
        return

    click = (x, y)
    clicked_points.append(click)
    print(f"Clicked point: ({x}, {y})")

    center, area = detect_ping_pong_ball(current_image, click, max_distance=10)
    detected_centers.append(center)
    detected_areas.append(area)

    if center != click:
        offset_x, offset_y = center[0] - x, center[1] - y
        print(f"Detected center differs from clicked point. Offset: ({offset_x}, {offset_y})")
    else:
        print("Warning: Detected center is the same as clicked point.")

    # Refresh the overlay with the new detection.
    draw_detection()


def draw_detection():
    """Show the current image in the "Image" window with all detections overlaid.

    For every recorded click, three shapes are drawn:
    - a green outline whose radius is that of a disc with the detected area;
    - a red dot at the detected centre;
    - a filled blue dot at the clicked point.

    Drawing happens on a copy, so ``current_image`` itself is left untouched.
    """
    canvas = current_image.copy()
    for clicked, centre, region_area in zip(clicked_points, detected_centers, detected_areas):
        radius = int(np.sqrt(region_area / np.pi))
        cv2.circle(canvas, centre, radius, (0, 255, 0), 2)
        cv2.circle(canvas, centre, 2, (0, 0, 255), 3)
        # Clicked point in blue (BGR colour order).
        cv2.circle(canvas, clicked, 2, (255, 0, 0), -1)
    cv2.imshow("Image", canvas)


def start_detection():
    """Reset the per-image click history and attach the mouse handler.

    The three module-level lists are rebound to fresh empty lists. Then
    ``on_mouse_click`` is registered as the mouse callback of the "Image" window.
    """
    global clicked_points, detected_centers, detected_areas
    clicked_points, detected_centers, detected_areas = [], [], []
    cv2.setMouseCallback("Image", on_mouse_click)
    print("Detection started. Click on ping pong balls. Right-click to undo last click.")


def save_results():
    """Append the current image's clicks and detected centres to ``all_results``.

    The appended row is ``[image_name, click_x, click_y, center_x, center_y, ...]``,
    with one group of four values per recorded click. Detected areas are not
    included. The detected centres are also printed, one per line.
    """
    row = [current_image_name]
    for click, centre in zip(clicked_points, detected_centers):
        row += (click[0], click[1], centre[0], centre[1])
    all_results.append(row)

    print(f"Results for {current_image_name} added to all_results")
    print("Detected centers:")
    for centre in detected_centers:
        print(centre)


def save_all_results(filename="ping-pong-detection.csv"):
    """Write every row collected in ``all_results`` to a CSV file.

    The header is ``image_file_name``, followed by ``Pos{i}_click_x``,
    ``Pos{i}_click_y``, ``Pos{i}_center_x`` and ``Pos{i}_center_y`` for as many
    positions as the longest row holds. If ``all_results`` is empty, a
    header-only file is written instead of raising ``ValueError``.

    Args:
        filename: Destination path. Defaults to ``ping-pong-detection.csv`` in
            the current working directory.
    """
    # Each row is [name, click_x, click_y, center_x, center_y, ...]: four values
    # per position after the name. Computed before opening the file so that a
    # failure cannot leave it truncated.
    longest_row = max((len(row) for row in all_results), default=1)
    header = ["image_file_name"]
    for i in range(1, (longest_row - 1) // 4 + 1):
        header.extend([f"Pos{i}_click_x", f"Pos{i}_click_y", f"Pos{i}_center_x", f"Pos{i}_center_y"])

    with open(filename, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(header)
        writer.writerows(all_results)
    print(f"All results saved to {filename}")


def sort_key(filename):
    """Return the integer that precedes ``_angle`` in *filename*, for sorting.

    For example, ``shot_-30_angle.png`` gives ``-30``. The first
    ``_<int>_angle`` occurrence is used. Names without that pattern give 0.
    """
    found = re.search(r"_(-?\d+)_angle", filename)
    return 0 if found is None else int(found.group(1))


def main():
    """Annotate every image in ``IMAGE_DIR`` interactively, then write the CSV.

    Images (``.png``, ``.jpg`` or ``.jpeg``, matched case-insensitively) are
    shown one at a time in the "Image" window. They are ordered by ``sort_key``
    and then by file name. Keys while an image is shown:

    * ``f``   -- record this image's clicks via ``save_results`` and move on.
    * ``ESC`` -- stop immediately; the current image is *not* recorded.

    After the last image, or on ESC, the window is destroyed and
    ``save_all_results`` writes the CSV. If ``IMAGE_DIR`` is unset or invalid,
    or the directory has no images, an error is printed and the function
    returns without writing anything.
    """
    global current_image, current_image_name
    if not IMAGE_DIR:
        print("Error: IMAGE_DIR is not set in .env file")
        return

    image_dir = Path(IMAGE_DIR)
    if not image_dir.is_dir():
        print(f"Error: {IMAGE_DIR} is not a valid directory")
        return

    # Match extensions case-insensitively so e.g. "IMG.PNG" is also found on
    # case-sensitive file systems. Breaking ties on the file name keeps the
    # order deterministic when several files share an angle number (or have none).
    image_suffixes = {".png", ".jpg", ".jpeg"}
    image_files = sorted(
        (
            path
            for path in image_dir.iterdir()
            if path.is_file() and path.suffix.lower() in image_suffixes
        ),
        key=lambda path: (sort_key(path.name), path.name),
    )
    if not image_files:
        # Return before save_all_results() so an existing CSV is not replaced
        # by an empty one.
        print(f"Error: no .png/.jpg/.jpeg images found in {IMAGE_DIR}")
        return

    cv2.namedWindow("Image")

    for image_file in image_files:
        current_image = cv2.imread(str(image_file))
        current_image_name = image_file.stem
        if current_image is None:
            print(f"Error: Unable to read image {image_file}")
            continue

        cv2.imshow("Image", current_image)
        print(f"Current image: {image_file.name}")

        start_detection()

        while True:
            key = cv2.waitKey(1) & 0xFF
            if key == ord("f"):  # 'f': finish this image and record its results
                save_results()
                break
            if key == 27:  # ESC: quit without recording the current image
                cv2.destroyAllWindows()
                save_all_results()
                return

    cv2.destroyAllWindows()
    save_all_results()


# Start the interactive session only when run as a script, not when imported.
if __name__ == "__main__":
    main()
