import csv
import os
import sys

import cv2
import numpy as np


class LumenProfiler:
    """Detect the dark lumen region in each frame of a movie and score its roundness.

    Typical use: ``load_movie`` -> ``profiling`` -> ``draw`` / ``csv_output``.
    """

    # Frame rate used for time stamps when the movie does not report one.
    DEFAULT_FPS = 30.0

    def __init__(self):
        self.frames = []  # frames read by load_movie (as returned by cap.read())
        self.frame_count = 0
        self.fps = self.DEFAULT_FPS
        self.results = []  # one dict per analysed frame, filled by profiling
        # Number of items processed by the current load_movie/profiling run.
        self.progress_count = 0

    def load_movie(self, filename):
        """Read every frame of *filename* into ``self.frames``.

        Also stores the movie's frame rate in ``self.fps``, falling back to
        ``DEFAULT_FPS`` when the container does not report a usable value.

        Raises:
            OSError: if the movie cannot be opened.
        """
        self.frames = []
        self.frame_count = 0
        self.progress_count = 0
        cap = cv2.VideoCapture(filename)
        try:
            if not cap.isOpened():
                raise OSError(f"cannot open movie: {filename}")
            fps = cap.get(cv2.CAP_PROP_FPS)
            # CAP_PROP_FPS can be 0 (or NaN) when unknown; `fps > 0` rejects both.
            self.fps = fps if fps > 0 else self.DEFAULT_FPS
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                self.frames.append(frame)
                self.progress_count += 1
        finally:
            # Release the capture even if reading fails part-way.
            cap.release()
        self.frame_count = len(self.frames)

    def profiling(self, area_ratio, sigma, step=1):
        """Analyse every *step*-th loaded frame and store the results in ``self.results``.

        Each result dict holds: "idx" (frame index), "frame", "mask" (lumen
        mask), "circle_level" (0..1), "contour" (largest contour or None),
        "ratio" and "sigma" (the parameters used) and "threshold" (the
        brightness threshold derived from *area_ratio*).
        """
        self.results = []
        self.progress_count = 0
        for idx in range(0, len(self.frames), step):
            frame = self.frames[idx]
            mask, thres = self._lumen_mask_and_threshold(frame, area_ratio, sigma)
            circle_level, contour = self.calc_circle_level(mask)
            self.results.append(
                {
                    "idx": idx,
                    "frame": frame,
                    "mask": mask,
                    "circle_level": circle_level,
                    "contour": contour,
                    "ratio": area_ratio,
                    "sigma": sigma,
                    "threshold": thres,
                }
            )
            self.progress_count += 1

    def draw(self, rid, show_mask=False):
        """Return a copy of result *rid*'s frame annotated with its measurements.

        Args:
            rid: index into ``self.results`` (not the frame index).
            show_mask: if True, also paint the lumen mask pixels red.
        """
        result = self.results[rid]
        disp = result["frame"].copy()
        etime = result["idx"] / self.fps
        lines = [
            "frame%4d  time %.3fs" % (result["idx"], etime),
            "area ratio=%.1f %%" % (result["ratio"] * 100),
            "circle level=%.1f %%" % (result["circle_level"] * 100),
        ]
        for row, text in enumerate(lines):
            cv2.putText(
                disp,
                text,
                (10, 25 + 25 * row),
                cv2.FONT_HERSHEY_TRIPLEX,
                0.7,
                (255, 0, 0),
                1,
            )
        if show_mask:
            disp[result["mask"] > 0] = (0, 0, 255)
        # The contour is None when the mask is empty; drawContours cannot take None.
        if result["contour"] is not None:
            cv2.drawContours(disp, [result["contour"]], 0, (0, 255, 255), 3)

        return disp

    def lumen_mask(self, frame, area_ratio, sigma=5.0, min_area=200):
        """Return a uint8 mask (0/255) of the dark region taken as the lumen.

        The threshold is the lowest brightness (HSV V) level whose cumulative
        histogram count exceeds ``area_ratio`` of all pixels. The
        Gaussian-blurred V image is inverse-thresholded at that level.
        Among the connected components with area > *min_area*, the one whose
        centroid is nearest the image centre is kept. The mask is all zero
        when no component qualifies.
        """
        mask, _ = self._lumen_mask_and_threshold(frame, area_ratio, sigma, min_area)
        return mask

    def _lumen_mask_and_threshold(self, frame, area_ratio, sigma=5.0, min_area=200):
        """Compute the lumen mask (see ``lumen_mask``) and the threshold used for it."""
        # Brightness image (V channel of HSV).
        val_img = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)[:, :, 2]

        # Threshold from the histogram.
        hist = cv2.calcHist([val_img], [0], None, [256], [0, 256])
        thres = self._threshold_from_histogram(hist, val_img.size, area_ratio)

        # Candidate mask: pixels of the blurred image at or below the threshold.
        blurred = cv2.GaussianBlur(val_img, (25, 25), sigma)
        mask = cv2.threshold(blurred, thres, 255, cv2.THRESH_BINARY_INV)[1]

        # Keep the sufficiently large component nearest the image centre.
        _, labels, stats, centroids = cv2.connectedComponentsWithStats(mask)
        center = (frame.shape[1] / 2, frame.shape[0] / 2)
        target_label = self._select_component(
            stats[:, cv2.CC_STAT_AREA], centroids, center, min_area
        )

        selected_mask = np.zeros(mask.shape, np.uint8)
        # Label 0 is the background; never report it as the lumen.
        if target_label > 0:
            selected_mask[labels == target_label] = 255
        return selected_mask, thres

    @staticmethod
    def _threshold_from_histogram(hist, total, area_ratio):
        """Return the first bin whose cumulative count exceeds ``total * area_ratio``, else -1."""
        cumulative = np.cumsum(np.asarray(hist).ravel())
        above = np.flatnonzero(cumulative > total * area_ratio)
        return int(above[0]) if above.size else -1

    @staticmethod
    def _select_component(areas, centroids, center, min_area):
        """Return the label nearest *center* among labels >= 1 with area > *min_area*; 0 if none.

        Ties go to the lowest label.
        """
        areas = np.asarray(areas)[1:]
        if areas.size == 0:
            return 0
        offsets = np.asarray(centroids, dtype=float)[1:] - np.asarray(center, dtype=float)
        dists = np.where(areas > min_area, np.linalg.norm(offsets, axis=1), np.inf)
        best = int(np.argmin(dists))  # argmin returns the first minimum on ties
        return best + 1 if np.isfinite(dists[best]) else 0

    def calc_circle_level(self, mask):
        """Return ``(circularity, contour)`` for the largest external contour of *mask*.

        Circularity is ``4*pi*area / perimeter**2``, which is 1.0 for a
        perfect circle. Returns ``(0.0, None)`` when the mask has no contour.
        """
        # [-2] selects the contour list on both OpenCV 3 (3 returns) and 4 (2 returns).
        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        if len(contours) == 0:
            return 0.0, None
        contour = max(contours, key=cv2.contourArea)
        area = cv2.contourArea(contour)
        perimeter = cv2.arcLength(contour, True)
        circle_level = 4.0 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0.0
        return circle_level, contour

    def csv_output(self, filename):
        """Write one CSV row per entry of ``self.results`` to *filename*.

        Columns: frame index, time in seconds, area ratio in percent,
        brightness threshold, circle level in percent.
        """
        with open(filename, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(
                ["frame", "time(s)", "area ratio(%)", "threshold", "circle level(%)"]
            )
            writer.writerows(
                [
                    result["idx"],
                    result["idx"] / self.fps,
                    result["ratio"] * 100,
                    result.get("threshold", 0),
                    result["circle_level"] * 100,
                ]
                for result in self.results
            )

    def bgr2rgb(self, img, size=None):
        """Return *img* converted from BGR to RGB, resized first to *size* (width, height) if given."""
        if size is not None:
            img = cv2.resize(img, size)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


# def func():
#             cv2.putText(
#                 hsv,
#                 "threshold=%d" % (thres),
#                 (10, 30),
#                 cv2.FONT_HERSHEY_TRIPLEX,
#                 0.7,
#                 (0, 0, 0),
#                 1,
#             )


if __name__ == "__main__":
    movie_file = sys.argv[1] if len(sys.argv) > 1 else "bs_sample_20250212.mp4"
    output_csv = os.path.join("output", "analysis.csv")

    # Analysis
    lp = LumenProfiler()
    lp.load_movie(movie_file)
    lp.profiling(0.08, 2)
    if not lp.results:
        # Nothing to write or display; draw(0) would raise IndexError.
        sys.exit(f"no frames could be read from {movie_file}")
    # open() does not create missing directories.
    os.makedirs(os.path.dirname(output_csv), exist_ok=True)
    lp.csv_output(output_csv)

    # Display: cycle through the results until ESC (27) is pressed.
    rid = 0
    while True:
        cv2.imshow("frame", lp.draw(rid))
        if cv2.waitKey(30) & 0xFF == 27:
            break
        rid = (rid + 1) % len(lp.results)
    cv2.destroyAllWindows()
