import cv2
import numpy as np


# Mouse event
def mouseEvent(e, x, y, flags, param):
    """Mouse callback for the "image" window; a left-button drag sets the ROI.

    - Left button down: start a drag (``drag = 1``) anchored at the cursor.
    - Mouse move while dragging: follow the cursor with ``drag_end`` so the
      display loop can draw the rubber-band rectangle.
    - Left button up: end the drag, fix ``drag_end`` and rebuild ``roi_mask``.

    Args:
        e: OpenCV mouse event code.
        x, y: cursor position in window pixels.
        flags, param: required by the OpenCV callback signature; unused.
    """
    global drag, drag_start, drag_end
    point = (x, y)
    if e == cv2.EVENT_LBUTTONDOWN:
        drag, drag_start = 1, point
    elif e == cv2.EVENT_MOUSEMOVE:
        # Only track motion while the left button is held.
        if drag == 1:
            drag_end = point
    elif e == cv2.EVENT_LBUTTONUP:
        drag, drag_end = 0, point
        # The ROI mask is only recomputed once the drag is finished.
        make_roi_mask()


# Threshold slider for cauterized area detection
def on_thres_bar(val):
    """Trackbar callback for "threshold": refresh the cauterized-area mask.

    Args:
        val: new trackbar position passed by OpenCV. It is not used directly
            because detect_cauterized_area() reads the position back from the
            trackbar itself.
    """
    _ = val  # kept only to satisfy the OpenCV trackbar-callback signature
    detect_cauterized_area()


def make_roi_mask():
    """Rebuild the global ``roi_mask`` from the current drag rectangle.

    ``roi_mask`` becomes a uint8 array with the same height and width as
    ``img``. It holds 1 inside the filled rectangle spanned by ``drag_start``
    and ``drag_end``, and 0 everywhere else.
    """
    global roi_mask
    height, width = img.shape[:2]
    roi_mask = np.zeros((height, width), np.uint8)
    cv2.rectangle(roi_mask, drag_start, drag_end, color=1, thickness=cv2.FILLED)


# Cauterized area detection
def detect_cauterized_area():
    """Recompute the global ``cauterized_mask`` from the threshold trackbar.

    Uses ``cv2.inRange``, so a pixel of ``img`` is set to 255 when all three
    of its channel values lie in ``[0, threshold]``, and to 0 otherwise.
    The threshold is the current position of the "threshold" trackbar.
    """
    global cauterized_mask
    limit = cv2.getTrackbarPos("threshold", "image")
    # The same bound is used for every channel, so channel order (BGR) is irrelevant.
    lower = np.array([0] * 3)
    upper = np.array([limit] * 3)
    cauterized_mask = cv2.inRange(img, lower, upper)


def load_image(image_no):
    """Load ``../data/<image_no>.tiff`` and reset the ROI and detection state.

    Sets these module globals:
        img: the loaded image, resized so its width is 1280 px (height is
            scaled by the same factor, preserving the aspect ratio).
        drag: 0 (no drag in progress).
        drag_start, drag_end: default ROI corners, inset from the image edges.

    It then rebuilds ``roi_mask`` and ``cauterized_mask`` for the new image.

    Args:
        image_no: number of the image file to load.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
            returns None for missing or unreadable files). The globals are
            left untouched in that case.
    """
    global img, drag, drag_start, drag_end
    filename = f"../data/{image_no}.tiff"
    loaded = cv2.imread(filename)
    if loaded is None:
        # Fail before touching any global so the previous image stays usable.
        raise FileNotFoundError(f"could not read image: {filename}")
    proc_w = 1280
    scale = proc_w / loaded.shape[1]
    img = cv2.resize(loaded, None, fx=scale, fy=scale)
    drag = 0
    drag_start = (200, 100)
    drag_end = (img.shape[1] - 200, img.shape[0] - 150)
    make_roi_mask()
    detect_cauterized_area()


def _render_frame(image_no, show_overlay=True):
    """Return an annotated copy of ``img`` for display or saving.

    Reads the module globals ``img``, ``cauterized_mask``, ``roi_mask``,
    ``drag_start`` and ``drag_end``, plus the "threshold" trackbar of the
    "image" window.

    Args:
        image_no: image number printed in the caption.
        show_overlay: when True, pixels set in both ``cauterized_mask`` and
            ``roi_mask`` are painted (255, 0, 0), which is blue in BGR order.

    Returns:
        A new image with the ROI rectangle and caption text drawn on it.
        ``img`` itself is not modified.
    """
    masks = cv2.bitwise_and(cauterized_mask, roi_mask)
    area = cv2.countNonZero(masks)
    thres = cv2.getTrackbarPos("threshold", "image")
    frame = img.copy()
    if show_overlay:
        frame[masks > 0] = (255, 0, 0)
    cv2.rectangle(frame, drag_start, drag_end, (0, 255, 255), 1)
    cv2.putText(
        frame,
        f"{image_no}.tiff  Cauterized area={area}[px]",
        (10, 30),
        cv2.FONT_HERSHEY_COMPLEX,
        1.0,
        (100, 255, 255),
        2,
    )
    cv2.putText(
        frame,
        f"ROI{drag_start}-{drag_end}  Thres:{thres}",
        (800, 30),
        cv2.FONT_HERSHEY_COMPLEX,
        0.6,
        (100, 255, 255),
        1,
    )
    return frame


# Main routine
if __name__ == "__main__":
    # Initialize
    cv2.namedWindow("image", cv2.WINDOW_AUTOSIZE)
    cv2.createTrackbar("threshold", "image", 50, 255, on_thres_bar)
    cv2.setMouseCallback("image", mouseEvent)
    image_no = 55
    load_image(image_no)

    # Display loop. The detection overlay blinks: flush_state toggles every
    # 16 iterations of the loop.
    flush_state = 1
    flush_count = 0
    while True:
        disp = _render_frame(image_no, show_overlay=flush_state == 1)
        cv2.imshow("image", disp)
        key = cv2.waitKey(30)
        if key == 27:  # Esc
            cv2.destroyAllWindows()
            break
        if key == ord("a") and image_no < 79:
            image_no += 1
            # NOTE(review): 74 is skipped in both directions, presumably
            # because ../data/74.tiff does not exist. Confirm.
            if image_no == 74:
                image_no = 75
            load_image(image_no)
        if key == ord("z") and image_no > 51:
            image_no -= 1
            if image_no == 74:
                image_no = 73
            load_image(image_no)
        if key == ord("s"):
            # Always save with the overlay, regardless of the blink phase.
            # Saving `disp` dropped the detection whenever it was blinked off.
            cv2.imwrite(
                f"../data/{image_no}_detect.jpg",
                _render_frame(image_no, show_overlay=True),
            )

        flush_count += 1
        if flush_count > 15:
            flush_count = 0
            flush_state = 1 - flush_state
