import argparse
import glob
from pathlib import Path

import cv2
import numpy as np

# # Load the color image
# names = ["SmTIAS001", "SmTIAS002", "SmTIAS003"]
# imgs = np.array([cv2.imread(f"{name}_white01.jpg") for name in names])
# names.append("mean")

# for i in range(len(names)):
#     if i < len(names) - 1:
#         img = imgs[i]
#     else:
#         img = np.mean(imgs, axis=0).astype("uint8")
#     # print("img shape", img.shape)

# Default standard deviation (in pixels) of the Gaussian blur applied to the
# normalized image before rendering the heatmap; overridable with --sigma.
sigma_default: float = 10.0


def calc_illum_dist(args, img, imgname):
    """Compute and display the relative illumination distribution of an image.

    The image is converted to grayscale and optionally cropped to ``args.roi``
    (clipped to the image bounds). Each ROI pixel is divided by the ROI mean to
    obtain a relative illumination ratio. A min-max normalized,
    Gaussian-blurred copy of the ROI is then rendered as a JET heatmap,
    downscaled by half, in an OpenCV window titled ``"<imgname> heatmap"``.

    Args:
        args: Parsed command-line namespace. Reads ``args.roi`` (``None`` or a
            sequence ``(x, y, w, h)``) and ``args.sigma`` (Gaussian blur sigma).
        img: BGR image, e.g. as returned by ``cv2.imread``.
        imgname: Label used in printed statistics and the window title.

    Returns:
        A float32 array with the image's height/width. It holds the relative
        illumination ratio inside the ROI and -1.0 everywhere else (suitable
        for saving with ``np.savez_compressed``).

    Raises:
        ValueError: If ``img`` is None or the ROI does not overlap the image.
    """
    if img is None:
        raise ValueError(f"{imgname}: image could not be loaded")

    # Convert the image to grayscale (luminance)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    height, width = gray.shape

    # Build the ROI (defaults to the whole image). Clip it to the image bounds
    # so a negative origin does not wrap around via Python's negative indexing.
    roi_x, roi_y, roi_w, roi_h = 0, 0, width, height
    if args.roi:
        roi_x, roi_y, roi_w, roi_h = args.roi
    x0, y0 = max(roi_x, 0), max(roi_y, 0)
    x1, y1 = min(roi_x + roi_w, width), min(roi_y + roi_h, height)
    if x0 >= x1 or y0 >= y1:
        raise ValueError(
            f"{imgname}: ROI {(roi_x, roi_y, roi_w, roi_h)} does not overlap "
            f"the image ({width}x{height})"
        )

    gray_roi = gray[y0:y1, x0:x1]
    vmin = gray_roi.min()
    vmax = gray_roi.max()
    vmean = gray_roi.mean()
    print(f"{imgname} value min,mean,max", vmin, vmean, vmax)

    # Compute the relative illumination distribution (pixel / ROI mean).
    # An all-black ROI has mean 0. Treat it as uniform (ratio 1.0), which is
    # what every other uniform ROI yields, instead of producing NaN/inf.
    if vmean > 0:
        ratio_roi = gray_roi / vmean
    else:
        ratio_roi = np.ones(gray_roi.shape, dtype=np.float64)
    print(
        f"{imgname} ratio min,mean,max",
        ratio_roi.min(),
        ratio_roi.mean(),
        ratio_roi.max(),
    )
    ratio = np.full(gray.shape, -1.0, dtype=np.float32)
    ratio[y0:y1, x0:x1] = ratio_roi

    # Min-max normalize the ROI to 0..255 for display. A flat ROI
    # (vmax == vmin) would divide by zero, so map it to 0 instead.
    value_range = float(vmax) - float(vmin)
    if value_range > 0:
        normalized = 255.0 * (gray_roi.astype(np.float64) - float(vmin)) / value_range
    else:
        normalized = np.zeros(gray_roi.shape, dtype=np.float64)
    normalized = cv2.GaussianBlur(
        normalized, ksize=(0, 0), sigmaX=args.sigma, sigmaY=args.sigma
    )
    # Round and clip rather than truncate, so the colormap input is valid uint8.
    normalized = np.clip(np.rint(normalized), 0, 255).astype(np.uint8)

    # Create a heatmap from the normalized grayscale image
    heatmap = cv2.applyColorMap(normalized, cv2.COLORMAP_JET)
    heatmap = cv2.resize(heatmap, None, fx=0.5, fy=0.5)

    cv2.imshow(f"{imgname} heatmap", heatmap)

    return ratio


def parse_arguments(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            parse ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``input_files`` (list of str; glob patterns
        allowed), ``roi`` (``None`` or ``[x, y, w, h]`` ints) and ``sigma``
        (float, defaulting to ``sigma_default``).

    Invalid values are reported through ``parser.error``, which prints usage
    and exits with status 2: a non-positive ``--sigma``, or a ``--roi`` with a
    negative origin or a non-positive width/height.
    """
    parser = argparse.ArgumentParser(description="Calculate illumination distribution")
    parser.add_argument("input_files", nargs="+", help="Input image files")
    parser.add_argument(
        "--roi",
        type=int,
        nargs=4,
        metavar=("X", "Y", "W", "H"),
        help="region of interest (x, y, w, h)",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        help=f"sigma of gaussian blur (default: {sigma_default})",
        default=sigma_default,
    )
    args = parser.parse_args(argv)

    # With ksize=(0, 0), cv2.GaussianBlur derives the kernel size from sigma,
    # so sigma must be strictly positive. `not > 0` also rejects NaN.
    if not args.sigma > 0:
        parser.error("--sigma must be positive")

    if args.roi is not None:
        roi_x, roi_y, roi_w, roi_h = args.roi
        if roi_x < 0 or roi_y < 0 or roi_w <= 0 or roi_h <= 0:
            parser.error("--roi requires x, y >= 0 and w, h > 0")

    return args


def main():
    """Compute and display illumination distributions for the given images.

    Each ``input_files`` argument is expanded as a glob pattern. If a pattern
    matches nothing but names an existing file, it is used literally; this
    covers filenames containing glob metacharacters such as ``[``. Matches
    are sorted and de-duplicated.

    When more than one image is loaded and all share the same shape, their
    pixel-wise mean is appended as an extra image named ``"mean"``. Heatmap
    windows stay open until a key is pressed.

    Raises:
        SystemExit: If no input file matched or an image could not be read.
    """
    # Build the input data from the command-line arguments
    args = parse_arguments()
    file_names = []
    for pattern in args.input_files:
        matches = sorted(glob.glob(pattern))
        if not matches and Path(pattern).is_file():
            matches = [pattern]
        file_names.extend(matches)
    # Drop duplicates (a file matched by several patterns) while keeping order
    file_names = list(dict.fromkeys(file_names))
    if not file_names:
        raise SystemExit("error: no input files matched")

    imgs = []
    for file_name in file_names:
        img = cv2.imread(file_name)
        if img is None:  # cv2.imread signals failure by returning None
            raise SystemExit(f"error: could not read image '{file_name}'")
        imgs.append(img)
    names = [Path(f).stem for f in file_names]

    # Create the mean image (only meaningful when all images have one shape)
    if len(imgs) > 1:
        if all(img.shape == imgs[0].shape for img in imgs):
            img_mean = np.rint(np.mean(np.stack(imgs), axis=0)).astype(np.uint8)
            imgs.append(img_mean)
            names.append("mean")
        else:
            print("warning: images differ in size; skipping the mean image")

    for img, name in zip(imgs, names):
        calc_illum_dist(args, img, name)

    cv2.waitKey(0)
    cv2.destroyAllWindows()


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
