import argparse
import glob
from pathlib import Path

import cv2
import numpy as np

import config


def calc_illum_dist(img, roi_pos):
    """Compute the relative illuminance distribution inside a rectangular ROI.

    The image is converted to grayscale (used as a luminance proxy) and each
    ROI pixel is divided by the ROI's mean gray level, so 1.0 means "as bright
    as the ROI average".

    Args:
        img: BGR image (as loaded by ``cv2.imread``).
        roi_pos: Two ``(x, y)`` pixel corners ``[(x1, y1), (x2, y2)]``. The
            ROI is ``img[y1:y2, x1:x2]``, i.e. the second corner is exclusive.

    Returns:
        tuple: ``(ratio_roi, anno)`` where ``ratio_roi`` is a float array of
        ``gray / mean(gray)`` over the ROI, and ``anno`` is a copy of ``img``
        with the ROI outlined in red.

    Raises:
        ValueError: If the ROI is empty or its mean gray level is zero, where
            the ratio would be undefined (NaN/inf).
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # ROI corners (top-left inclusive, bottom-right exclusive).
    roi_x1, roi_y1 = roi_pos[0][0], roi_pos[0][1]
    roi_x2, roi_y2 = roi_pos[1][0], roi_pos[1][1]
    gray_roi = gray[roi_y1:roi_y2, roi_x1:roi_x2]
    if gray_roi.size == 0:
        # Happens e.g. when the corners are swapped or lie outside the image.
        raise ValueError(f"empty ROI for corners {roi_pos!r}")

    # Outline the ROI on a copy so the caller's image stays untouched.
    anno = img.copy()
    cv2.rectangle(anno, (roi_x1, roi_y1), (roi_x2, roi_y2), (0, 0, 255), 2)

    # Relative illuminance: each pixel relative to the ROI mean.
    vmean = gray_roi.mean()
    if vmean == 0:
        # An all-black ROI would give 0/0 = NaN for every pixel.
        raise ValueError("mean gray level of the ROI is zero")
    ratio_roi = gray_roi / vmean

    return ratio_roi, anno


def detect_landmarks(img):
    """Locate the alignment landmarks in ``img`` by template matching.

    For template ``i`` of ``config.corner_imgs_file``, a search window is cut
    from the image:

    - It sits at the right edge for odd ``i`` and at the left edge for even ``i``.
    - Its size is ``config.corner_search_size`` (fractions of the image width
      and height).
    - Its top is at ``config.corner_search_pos[i]`` (a fraction of the image
      height).

    The best ``TM_CCOEFF_NORMED`` match inside the window is taken.

    Args:
        img: BGR image to search.

    Returns:
        tuple: ``(landmarks, annotation)``. ``landmarks`` holds one ``(x, y)``
        pixel point per template, at the vertical centre of the match.
        Horizontally it is the match's right edge for the first two templates
        and its left edge for the rest. ``annotation`` is a copy of ``img``
        with every search window drawn in green and every match in red.

    Raises:
        FileNotFoundError: If a template image cannot be read.
    """
    landmarks = []
    annotation = img.copy()
    img_h, img_w = img.shape[0], img.shape[1]
    for i, template_file in enumerate(config.corner_imgs_file):
        template = cv2.imread(template_file)
        if template is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise FileNotFoundError(f"cannot read template image '{template_file}'")
        tmpl_h, tmpl_w = template.shape[0], template.shape[1]

        # Search window for this landmark.
        roi_w = int(config.corner_search_size[0] * img_w)
        roi_h = int(config.corner_search_size[1] * img_h)
        roi_x = img_w - roi_w if i % 2 == 1 else 0
        roi_y = int(config.corner_search_pos[i] * img_h)
        roi = img[roi_y : roi_y + roi_h, roi_x : roi_x + roi_w]

        res = cv2.matchTemplate(roi, template, cv2.TM_CCOEFF_NORMED)
        _, _, _, max_loc = cv2.minMaxLoc(res)
        # Window coordinates -> image coordinates of the match's top-left.
        match_x = max_loc[0] + roi_x
        match_y = max_loc[1] + roi_y

        cv2.rectangle(
            annotation,
            (roi_x, roi_y),
            (roi_x + roi_w, roi_y + roi_h),
            (0, 255, 0),
            2,
        )
        cv2.rectangle(
            annotation,
            (match_x, match_y),
            (match_x + tmpl_w, match_y + tmpl_h),
            (0, 0, 255),
            2,
        )
        # NOTE(review): the first two templates anchor on the match's right
        # edge, the others on its left edge — presumably tied to the marker
        # artwork; confirm against the template images.
        landmarks.append(
            (
                match_x + (tmpl_w if i < 2 else 0),
                match_y + tmpl_h // 2,
            )
        )
    return landmarks, annotation


def registration(img, target, src):
    """Warp ``img`` so that the points ``src`` land on the points ``target``.

    A homography mapping ``src`` onto ``target`` is estimated with RANSAC
    (reprojection threshold 5.0 px) and applied to ``img``. The output keeps
    the width and height of ``img``.

    Args:
        img: Image to warp.
        target: Sequence of ``(x, y)`` destination points.
        src: Sequence of ``(x, y)`` points in ``img``, matched one-to-one with
            ``target``. OpenCV requires at least four correspondences.

    Returns:
        The warped image, with the same size as ``img``.

    Raises:
        RuntimeError: If ``cv2.findHomography`` could not estimate a
            homography (it returns ``None``), e.g. for degenerate points.
    """
    src_pts = np.asarray(src, dtype=np.float32)
    dst_pts = np.asarray(target, dtype=np.float32)
    homography, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
    if homography is None:
        raise RuntimeError("failed to estimate a homography from the given landmarks")
    height, width = img.shape[0], img.shape[1]
    return cv2.warpPerspective(img, homography, (width, height))


def resize_show(name, img):
    """Display ``img`` in window ``name``, shrunk to fit ``config.display_size``.

    The aspect ratio is preserved, and images that already fit are shown at
    their original size (never enlarged).
    """
    # Largest uniform factor that fits both dimensions, capped at 1.0.
    factor = min(
        config.display_size[0] / img.shape[1],
        config.display_size[1] / img.shape[0],
        1.0,
    )
    cv2.imshow(name, cv2.resize(img, None, fx=factor, fy=factor))


def show_heatmap(name, dist):
    """Display a 2-D distribution as a smoothed JET-colormap heatmap.

    ``dist`` is min-max scaled to 0..255 and blurred with a Gaussian of
    standard deviation ``config.sigma``. It is then colour-mapped and shown
    via ``resize_show`` in window ``name``.

    If ``dist`` has no spread (max == min, or NaN extremes), a uniform map is
    shown instead of dividing by zero.
    """
    vmin = dist.min()
    vmax = dist.max()
    span = vmax - vmin
    if span > 0:
        normalized = 255.0 * (dist - vmin) / span
    else:
        # 0/0 would give NaN, whose cast to uint8 is undefined.
        normalized = np.zeros_like(dist, dtype=np.float64)
    normalized = cv2.GaussianBlur(
        normalized, ksize=(0, 0), sigmaX=config.sigma, sigmaY=config.sigma
    )
    # Blurring keeps values inside 0..255, so the uint8 cast is safe.
    normalized = normalized.astype("uint8")

    heatmap = cv2.applyColorMap(normalized, cv2.COLORMAP_JET)
    resize_show(name, heatmap)


def main():
    """Compute, display and save the relative illuminance distribution.

    Steps:

    1. Load every file matched by the glob patterns in ``config.input_files``.
    2. Detect landmarks in each image and register every image (and its
       annotation) onto the first image.
    3. Average the aligned images and show the averaged landmark annotations.
    4. For every image, and then for the mean image, compute the illuminance
       ratio inside the ROI spanned by landmarks 0 and 3 of the first image
       and show it as a heatmap with the ROI outline.
    5. Print the mean image's ratio range and save that ratio to
       ``config.output_file``.
    6. Wait for a key press, then close all windows.

    Raises:
        FileNotFoundError: If no file matches ``config.input_files`` or a
            matched file cannot be read as an image.
    """
    # Sort each pattern's matches: glob order is arbitrary, and the first
    # image is the registration reference.
    file_names = [
        f for pattern in config.input_files for f in sorted(glob.glob(pattern))
    ]
    if not file_names:
        raise FileNotFoundError(f"no input files match {config.input_files!r}")
    names = [Path(f).stem for f in file_names]
    imgs = []
    for file in file_names:
        img = cv2.imread(file)
        if img is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise FileNotFoundError(f"cannot read image '{file}'")
        imgs.append(img)

    # Landmark detection for alignment.
    landmarks = []
    annotations = []
    for img in imgs:
        lm, anno = detect_landmarks(img)
        landmarks.append(lm)
        annotations.append(anno)

    # Align every image (and its annotation) onto the first one.
    for i in range(1, len(imgs)):
        imgs[i] = registration(imgs[i], landmarks[0], landmarks[i])
        annotations[i] = registration(annotations[i], landmarks[0], landmarks[i])

    # Mean of the aligned images.
    img_mean = np.mean(np.array(imgs), axis=0).astype("uint8")

    # Overlay of all landmark annotations.
    anno_mean = np.mean(np.array(annotations), axis=0).astype("uint8")
    resize_show("mean annotation", anno_mean)

    # Illuminance distribution over the reference ROI, for every input image
    # and finally the mean image. The mean is identified by position, not by
    # name, so an input file called "mean.*" cannot overwrite the output.
    roi_pos = [landmarks[0][0], landmarks[0][3]]
    all_names = [*names, "mean"]
    all_imgs = [*imgs, img_mean]
    mean_index = len(all_imgs) - 1
    for i, (name, img) in enumerate(zip(all_names, all_imgs)):
        ratio_roi, anno = calc_illum_dist(img, roi_pos)
        show_heatmap(f"{name} heatmap", ratio_roi)
        resize_show(f"{name} anno_roi", anno)
        if i == mean_index:
            vmin, vmax, _, _ = cv2.minMaxLoc(ratio_roi)
            print(f"ratio min,max={vmin:.3f},{vmax:.3f}")
            np.savez_compressed(config.output_file, ratio_roi)
            print(f"'{config.output_file}' saved.")

    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == "__main__":
    # Script entry point: run the full pipeline only when executed directly.
    main()
