import cv2
import os
import csv
import re
import numpy as np
import pandas as pd
import joblib
import math
import lightgbm
import xgboost
from util.ears_ai import EarsAI
from util.calc_ste_position import CalcStethoscopePosition
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import argparse
from modules.EARSForDL.model import RegressionResNet  # モデル定義をインポート
import pickle
import matplotlib.pyplot as plt


def normalize_quadrilateral_with_point(points, extra_point):
    """Centre, de-rotate and scale four torso keypoints plus one extra point.

    The four keypoints are centred on their mean. They are rotated so that the
    left-shoulder -> right-shoulder vector lies along +x, and then divided by the
    largest "edge" length.

    Args:
        points: eight coordinates (anything reshapeable to (4, 2)) in the order
            left shoulder, right shoulder, left hip, right hip, each as (x, y).
        extra_point: (x, y) of an additional point transformed the same way.

    Returns:
        Array of the four keypoints followed by ``extra_point`` in normalised
        coordinates; shape (5, 2) for the inputs described above.

    Note:
        The "edges" are consecutive points in input order (LS->RS->LH->RH->LS).
        Two of them (RS->LH, RH->LS) are therefore diagonals, not the sides of
        the torso quadrilateral. If every keypoint coincides, the scale is 0 and
        the result is NaN/inf.
    """
    quad = points.reshape(-1, 2)
    stacked = np.vstack([quad, extra_point])
    shifted = stacked - np.mean(quad, axis=0)

    shoulder_dx, shoulder_dy = shifted[1] - shifted[0]
    theta = np.arctan2(shoulder_dy, shoulder_dx)

    # Rotate by -theta so the shoulder line becomes horizontal.
    cos_t = np.cos(-theta)
    sin_t = np.sin(-theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    aligned = np.dot(shifted, rotation.T)

    corners = aligned[:4]
    edge_lengths = np.linalg.norm(np.roll(corners, -1, axis=0) - corners, axis=1)
    return aligned / np.max(edge_lengths)


def normalize_quadrilateral_with_point_average_rotation(points, extra_point):
    """Centre, de-rotate and scale four torso keypoints plus one extra point.

    This works like ``normalize_quadrilateral_with_point``, except for the rotation.
    The rotation angle is the mean direction of the shoulder line
    (left -> right shoulder) and the hip line (left -> right hip), rather than the
    shoulder line alone.

    Args:
        points: eight coordinates (anything reshapeable to (4, 2)) in the order
            left shoulder, right shoulder, left hip, right hip, each as (x, y).
        extra_point: (x, y) of an additional point (the stethoscope position
            in ``process_images``) transformed the same way.

    Returns:
        Array of the four keypoints followed by ``extra_point`` in normalised
        coordinates; shape (5, 2) for the inputs described above.

    Note:
        The scale is the largest distance between consecutive points in input
        order (LS->RS->LH->RH->LS). Two of those segments are diagonals. This is
        kept unchanged because the regression models are presumably trained on
        this exact normalisation (TODO confirm). If every keypoint coincides, the
        scale is 0 and the result is NaN/inf.
    """
    quad = points.reshape(-1, 2)
    all_points = np.vstack([quad, extra_point])
    centered_points = all_points - np.mean(quad, axis=0)

    left_shoulder, right_shoulder, left_hip, right_hip = centered_points[:4]

    shoulder_angle = calculate_rotation_angle(left_shoulder, right_shoulder)
    hip_angle = calculate_rotation_angle(left_hip, right_hip)

    # Circular mean of the two directions. A plain (a + b) / 2 breaks when the
    # angles straddle the +/-pi seam: 3.1 and -3.1 would average to ~0 instead
    # of ~pi. That happens whenever the right-minus-left vectors point roughly
    # along -x. When |a - b| < pi this equals the arithmetic mean.
    average_angle = np.arctan2(
        np.sin(shoulder_angle) + np.sin(hip_angle),
        np.cos(shoulder_angle) + np.cos(hip_angle),
    )

    # Rotate by -average_angle so the mean torso axis becomes horizontal.
    rotation_matrix = np.array(
        [[np.cos(-average_angle), -np.sin(-average_angle)], [np.sin(-average_angle), np.cos(-average_angle)]]
    )
    rotated_points = np.dot(centered_points, rotation_matrix.T)

    corners = rotated_points[:4]
    max_edge_length = np.max(np.linalg.norm(np.roll(corners, -1, axis=0) - corners, axis=1))
    return rotated_points / max_edge_length


def calculate_rotation_angle(point1, point2):
    """Return the direction (radians, in [-pi, pi]) of the vector point1 -> point2.

    Both points are indexable (x, y) pairs; the angle is measured from +x.
    """
    return np.arctan2(point2[1] - point1[1], point2[0] - point1[0])


def video_to_frames(video_path, output_dir):
    """Write every frame of *video_path* to ``<output_dir>/<n>-frame.png`` (n starts at 1).

    Args:
        video_path: path of the video to decode with OpenCV.
        output_dir: destination directory; created if missing.

    Raises:
        IOError: if OpenCV cannot open the video.

    NOTE(review): existing files in *output_dir* are not removed. Frames left
    over from a previous, longer video will also be picked up by
    ``process_images``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise IOError(f"動画ファイルを開けませんでした: {video_path}")

        frame_num = 0
        while True:
            success, frame = video.read()
            if not success:  # end of stream (or a decode failure)
                break
            frame_num += 1
            cv2.imwrite(os.path.join(output_dir, f"{frame_num}-frame.png"), frame)
    finally:
        # Release the capture even if decoding or writing raises.
        video.release()

    print(f"全てのフレームを {output_dir} に保存しました。")


def lgb_load_model(model_path):
    """Unpickle and return the object stored at *model_path*.

    This script uses it for the LightGBM regressors.

    Security: ``pickle.load`` can execute arbitrary code while loading. Only
    pass model files from a trusted source.
    """
    with open(model_path, mode="rb") as fh:
        return pickle.load(fh)


def xg_load_model(model_path):
    """Unpickle and return the object stored at *model_path*.

    This script uses it for the XGBoost regressors.

    Security: ``pickle.load`` can execute arbitrary code while loading. Only
    pass model files from a trusted source.
    """
    with open(model_path, mode="rb") as fh:
        return pickle.load(fh)


def CNN_load_model(model_path, device, resnet_depth=18):
    """Build a ``RegressionResNet`` and load the state dict saved at *model_path*.

    Args:
        model_path: checkpoint file holding a state dict for ``RegressionResNet``.
        device: torch device the weights are mapped to and the model is moved to.
        resnet_depth: forwarded to ``RegressionResNet`` (presumably the ResNet
            variant, e.g. 18; TODO confirm against the model definition).

    Returns:
        The model on *device*, switched to evaluation mode.

    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    network = RegressionResNet(resnet_depth)
    weights = torch.load(model_path, map_location=device)
    network.load_state_dict(weights)
    network.to(device)
    network.eval()  # disable dropout / use running batch-norm statistics
    return network


def predict(model, data):
    """Return ``model.predict(data)``.

    This is a thin adapter so sklearn-style estimators share one call site.
    """
    predict_fn = model.predict
    return predict_fn(data)


def calculate_distance(point1, point2):
    """Return the Euclidean distance between two 2-D points (indices 0 and 1 are used)."""
    dx = point1[0] - point2[0]
    dy = point1[1] - point2[1]
    return math.sqrt(dx ** 2 + dy ** 2)


def preprocess_image(image_path):
    """Load *image_path* as a normalised float tensor of shape (1, 3, 224, 224).

    The image is converted to RGB and resized to 224x224. It is then scaled to
    [0, 1] and normalised with the standard ImageNet per-channel mean/std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )
    with Image.open(image_path) as source:
        rgb = source.convert("RGB")
    tensor = pipeline(rgb)
    # Add the batch dimension expected by the model.
    return tensor.unsqueeze(0)


def cnn_predict(model, image_tensor, device):
    """Run *model* on *image_tensor* (moved to *device*) without autograd.

    Returns:
        The first element of the output batch as a NumPy array.
    """
    with torch.no_grad():
        batch_output = model(image_tensor.to(device))
    as_numpy = batch_output.cpu().numpy()
    return as_numpy[0]


def _detect_keypoints(images_dir, png_files, ears_ai, pose_overlay_dir, stethoscope_overlay_dir):
    """Run EarsAI pose + stethoscope detection per frame; return one row dict per readable frame.

    Each frame's pose and stethoscope overlay images are saved under the same
    file name in the two overlay directories.
    """
    rows = []
    for image_file_name in png_files:
        image_path = os.path.join(images_dir, image_file_name)
        print(f"Processing image: {image_path}")

        frame = cv2.imread(image_path)
        if frame is None:
            print(f"Failed to load image: {image_path}")
            continue

        pose_overlay_img, left_shoulder, right_shoulder, left_hip, right_hip = ears_ai.pose_detect(frame, None)
        stethoscope_overlay_img, stethoscope_x, stethoscope_y = ears_ai.ssd_detect(frame, None)

        # Overlays are converted RGB->BGR before writing. EarsAI presumably returns
        # RGB images (TODO confirm).
        pose_overlay_path = os.path.join(pose_overlay_dir, image_file_name)
        cv2.imwrite(pose_overlay_path, cv2.cvtColor(pose_overlay_img, cv2.COLOR_RGB2BGR))
        print(f"Saved pose overlay image: {pose_overlay_path}")

        stethoscope_overlay_path = os.path.join(stethoscope_overlay_dir, image_file_name)
        cv2.imwrite(stethoscope_overlay_path, cv2.cvtColor(stethoscope_overlay_img, cv2.COLOR_RGB2BGR))
        print(f"Saved stethoscope overlay image: {stethoscope_overlay_path}")

        # Keypoints are read as [1] -> x, [0] -> y, i.e. they look like (row, col)
        # pairs. Key order defines the CSV column order.
        rows.append(
            {
                "image_file_name": image_file_name,
                "left_shoulder_x": left_shoulder[1],
                "left_shoulder_y": left_shoulder[0],
                "right_shoulder_x": right_shoulder[1],
                "right_shoulder_y": right_shoulder[0],
                "left_hip_x": left_hip[1],
                "left_hip_y": left_hip[0],
                "right_hip_x": right_hip[1],
                "right_hip_y": right_hip[0],
                "stethoscope_x": stethoscope_x,
                "stethoscope_y": stethoscope_y,
            }
        )
    return rows


def _add_model_predictions(rows, calc_position, csv_path):
    """Add affine / XGBoost / LightGBM stethoscope estimates to *rows* (in place) and write the CSV.

    A detected stethoscope position of exactly (0, 0) is treated as "not
    detected". Such rows reuse the previous row's estimates; every estimate
    starts at 180 until the first detection.
    """
    prediction_columns = [
        "conv_stethoscope_x",
        "conv_stethoscope_y",
        "Xgboost_stethoscope_x",
        "Xgboost_stethoscope_y",
        "lightGBM_stethoscope_x",
        "lightGBM_stethoscope_y",
    ]
    fieldnames = list(rows[0].keys()) + prediction_columns

    # Model features: the normalised keypoints followed by the normalised
    # stethoscope, two columns (x, y) per point.
    input_columns = [
        "left_shoulder_x",
        "left_shoulder_y",
        "right_shoulder_x",
        "right_shoulder_y",
        "left_hip_x",
        "left_hip_y",
        "right_hip_x",
        "right_hip_y",
        "stethoscope_x",
        "stethoscope_y",
    ]
    keypoint_names = ("left_shoulder", "right_shoulder", "left_hip", "right_hip")

    lgb_model_x = lgb_load_model("./models/lgb_stethoscope_calc_x_best_model.pkl")
    lgb_model_y = lgb_load_model("./models/lgb_stethoscope_calc_y_best_model.pkl")
    xg_model_x = xg_load_model("./models/xg_stethoscope_calc_x_best_model.pkl")
    xg_model_y = xg_load_model("./models/xg_stethoscope_calc_y_best_model.pkl")

    # Estimates carried forward into frames without a detection.
    previous = dict.fromkeys(prediction_columns, 180)

    with open(csv_path, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for row in rows:
            stethoscope_x = float(row["stethoscope_x"])
            stethoscope_y = float(row["stethoscope_y"])

            if stethoscope_x == 0 and stethoscope_y == 0:
                row.update(previous)
            else:
                keypoint_coords = [
                    [float(row[f"{name}_x"]), float(row[f"{name}_y"])] for name in keypoint_names
                ]
                # calc_affine gets float32 (4, 2) points; normalisation uses float64.
                source_points = np.array(keypoint_coords, dtype=np.float32)
                conv_stethoscope = calc_position.calc_affine(source_points, stethoscope_x, stethoscope_y)
                row["conv_stethoscope_x"], row["conv_stethoscope_y"] = conv_stethoscope

                normalized_points = normalize_quadrilateral_with_point_average_rotation(
                    np.array(keypoint_coords).reshape(-1), np.array([stethoscope_x, stethoscope_y])
                )
                # Feature i is point i // 2, coordinate i % 2 (0 = x, 1 = y).
                row_convert = {
                    column: normalized_points[i // 2, i % 2] for i, column in enumerate(input_columns)
                }
                input_data = pd.DataFrame([row_convert])[input_columns]

                row["lightGBM_stethoscope_x"] = int(predict(lgb_model_x, input_data)[0])
                row["lightGBM_stethoscope_y"] = int(predict(lgb_model_y, input_data)[0])
                row["Xgboost_stethoscope_x"] = int(predict(xg_model_x, input_data)[0])
                row["Xgboost_stethoscope_y"] = int(predict(xg_model_y, input_data)[0])

                previous = {column: row[column] for column in prediction_columns}

            writer.writerow(row)


def _render_position_outputs(csv_path, original_images_dir, draw_trajectory):
    """Visualise the estimates stored in *csv_path* and encode the rendered frames into videos.

    Writes under the fixed directory ``images/body/results``:
      * ``marked_images/``: each original frame with the four torso keypoints and
        the detected stethoscope drawn as filled cyan circles;
      * ``conv/``, ``cnn/``, ``lgbm/``: the body template with the affine,
        XGBoost and LightGBM estimate (plus the path so far when
        *draw_trajectory* is true);
      * MP4 videos of those directories.

    Raises:
        IOError: if the body template ``images/body/BodyF.png`` cannot be read.
    """
    df = pd.read_csv(csv_path)

    body_template_path = "images/body/BodyF.png"
    body_image = cv2.imread(body_template_path)
    if body_image is None:
        raise IOError(f"Failed to load body template image: {body_template_path}")

    # Fixed location, independent of the results directory that holds the CSV.
    render_dir = "images/body/results"
    marked_images_dir = os.path.join(render_dir, "marked_images")
    conv_dir = os.path.join(render_dir, "conv")
    cnn_dir = os.path.join(render_dir, "cnn")
    lgbm_dir = os.path.join(render_dir, "lgbm")
    for directory in (render_dir, marked_images_dir, conv_dir, cnn_dir, lgbm_dir):
        os.makedirs(directory, exist_ok=True)

    # (output dir, x column, y column, BGR colour). The "cnn" output shows the
    # XGBoost estimate because the CNN regressor is currently disabled.
    tracks = (
        (conv_dir, "conv_stethoscope_x", "conv_stethoscope_y", (0, 255, 0)),
        (cnn_dir, "Xgboost_stethoscope_x", "Xgboost_stethoscope_y", (255, 0, 0)),
        (lgbm_dir, "lightGBM_stethoscope_x", "lightGBM_stethoscope_y", (0, 0, 255)),
    )
    trajectories = {track[0]: [] for track in tracks}
    marker_columns = (
        ("left_shoulder_x", "left_shoulder_y"),
        ("right_shoulder_x", "right_shoulder_y"),
        ("left_hip_x", "left_hip_y"),
        ("right_hip_x", "right_hip_y"),
        ("stethoscope_x", "stethoscope_y"),
    )

    for _, row in df.iterrows():
        image_file_name = row["image_file_name"]
        original_image_path = os.path.join(original_images_dir, image_file_name)
        original_image = cv2.imread(original_image_path)
        if original_image is None:
            print(f"Failed to load image: {original_image_path}")
            continue

        # Cyan markers (BGR 255, 255, 0) on the raw frame.
        for x_col, y_col in marker_columns:
            cv2.circle(original_image, (int(row[x_col]), int(row[y_col])), 10, (255, 255, 0), -1)
        cv2.imwrite(os.path.join(marked_images_dir, image_file_name), original_image)

        for out_dir, x_col, y_col, colour in tracks:
            point = (int(row[x_col]), int(row[y_col]))
            trajectory = trajectories[out_dir]
            trajectory.append(point)
            canvas = body_image.copy()
            if draw_trajectory and len(trajectory) > 1:
                # cv2.polylines expects int32 point arrays.
                cv2.polylines(canvas, [np.array(trajectory, dtype=np.int32)], False, colour, 2)
            cv2.circle(canvas, point, 10, colour, -1)
            cv2.imwrite(os.path.join(out_dir, image_file_name), canvas)

    create_video_from_images(conv_dir, os.path.join(render_dir, "conv_video_with_trajectory.mp4"))
    create_video_from_images(cnn_dir, os.path.join(render_dir, "cnn_video_with_trajectory.mp4"))
    create_video_from_images(lgbm_dir, os.path.join(render_dir, "lgbm_video_with_trajectory.mp4"))

    create_video_from_images(conv_dir, os.path.join(render_dir, "conv_video_without_trajectory.mp4"), False)
    create_video_from_images(cnn_dir, os.path.join(render_dir, "cnn_video_without_trajectory.mp4"), False)
    create_video_from_images(lgbm_dir, os.path.join(render_dir, "lgbm_video_without_trajectory.mp4"), False)

    create_video_from_images(marked_images_dir, os.path.join(render_dir, "marked_video.mp4"))


def process_images(base_dir, draw_trajectory=True):
    """Run detection, position estimation and rendering over the PNG frames in *base_dir*.

    Steps:
      1. For each ``*.png`` frame (sorted by the first number in its name), run
         EarsAI pose and stethoscope detection and save both overlay images.
      2. Add affine / XGBoost / LightGBM stethoscope estimates and write all
         rows to ``<dirname(base_dir)>/results/results.csv``.
      3. Draw the keypoints on the original frames and the estimates on the
         body template under ``images/body/results``, then encode MP4 videos.

    Args:
        base_dir: directory holding the frames. Results go next to it, under the
            *parent* directory (``os.path.dirname``). A bare ``"images"``
            therefore yields ``"results"`` in the current working directory.
        draw_trajectory: if true, draw each method's accumulated path on the
            body-template frames.

    Undetected frames (stethoscope at exactly (0, 0)) reuse the previous
    frame's estimates.
    """
    ears_ai = EarsAI()
    calc_position = CalcStethoscopePosition()
    results_dir = os.path.join(os.path.dirname(base_dir), "results")
    csv_path = os.path.join(results_dir, "results.csv")
    pose_overlay_dir = os.path.join(results_dir, "pose_overlay_image")
    stethoscope_overlay_dir = os.path.join(results_dir, "stethoscope_overlay_image")

    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(pose_overlay_dir, exist_ok=True)
    os.makedirs(stethoscope_overlay_dir, exist_ok=True)

    # Numeric order ("2-frame" before "10-frame"); every name must contain a digit.
    png_files = sorted(
        (f for f in os.listdir(base_dir) if f.lower().endswith(".png")),
        key=lambda name: int(re.search(r"(\d+)", name).group(1)),
    )

    rows = _detect_keypoints(base_dir, png_files, ears_ai, pose_overlay_dir, stethoscope_overlay_dir)
    if not rows:
        print("No data to write to CSV.")
        return

    _add_model_predictions(rows, calc_position, csv_path)
    print(f"Processed and saved results to: {csv_path}")

    _render_position_outputs(csv_path, base_dir, draw_trajectory)


def create_video_from_images(image_dir, output_path, with_trajectory=True):
    """Encode the ``*.png`` frames in *image_dir* into an MP4 (mp4v, 30 fps).

    Frames are ordered by the first number in their file names. The video size
    is taken from the first frame.

    Args:
        image_dir: directory with the frames; each name must contain a digit.
        output_path: destination video file.
        with_trajectory: when false, each frame is composited with the body
            template ``images/body/BodyF.png``. Pixels whose grayscale value is
            at most 10 are replaced by the template's pixels.

    Raises:
        IOError: if *with_trajectory* is false and the body template cannot be read.

    NOTE(review): the compositing does not appear to remove the coloured
    trajectory lines. The track colours used in this script have grayscale
    values above 10, so they stay in the mask. Confirm the intended effect.
    """
    images = [img for img in os.listdir(image_dir) if img.endswith(".png")]
    images.sort(key=lambda x: int(re.search(r"(\d+)", x).group()))

    if not images:
        print(f"No images found in {image_dir}")
        return

    first_path = os.path.join(image_dir, images[0])
    first_frame = cv2.imread(first_path)
    if first_frame is None:
        print(f"Failed to load image: {first_path}")
        return
    height, width = first_frame.shape[:2]

    background = None
    if not with_trajectory:
        # Read once instead of once per frame.
        background_path = "images/body/BodyF.png"
        background = cv2.imread(background_path)
        if background is None:
            raise IOError(f"Failed to load background image: {background_path}")

    video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), 30, (width, height))
    try:
        for image in images:
            image_path = os.path.join(image_dir, image)
            img = cv2.imread(image_path)
            if img is None:
                print(f"Failed to load image: {image_path}")
                continue
            if background is not None:
                mask = cv2.threshold(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 10, 255, cv2.THRESH_BINARY)[1]
                foreground = cv2.bitwise_and(img, img, mask=mask)
                # Use a new name so the cached background is not overwritten.
                backdrop = cv2.bitwise_and(background, background, mask=cv2.bitwise_not(mask))
                img = cv2.add(foreground, backdrop)
            video.write(img)
    finally:
        # Always finalise the container; no GUI windows are opened here, so
        # destroyAllWindows() is unnecessary (and raises on headless OpenCV builds).
        video.release()

    print(f"Created video: {output_path}")


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Process video and generate results.")
    for flag, options in (
        ("--video_path", {"default": "./video/Test3-1.mp4", "help": "Path to the input video file"}),
        ("--output_dir", {"default": "images", "help": "Directory to save output images and results"}),
        ("--draw_trajectory", {"action": "store_true", "help": "Draw trajectory in the output video"}),
    ):
        cli.add_argument(flag, **options)
    cli_args = cli.parse_args()

    # 1) Split the video into numbered PNG frames inside output_dir.
    video_to_frames(cli_args.video_path, cli_args.output_dir)

    # 2) Detect, estimate and render from those frames. --draw_trajectory
    #    defaults to False here, whereas process_images() itself defaults to True.
    process_images(cli_args.output_dir, cli_args.draw_trajectory)
