import numpy as np
import os.path as osp
import json
import cv2
from DicomProcessor import DicomProcessor

# Direction flags accepted by GUIController.change_cur_CT:
# GUI_LOAD_NEXT moves to the next slice, GUI_LOAD_PREV to the previous one.
GUI_LOAD_NEXT, GUI_LOAD_PREV = 1, 0


class GUIController:
    """OpenCV GUI for stepping through DICOM slices and picking points.

    Two windows are used: ``'menu'`` (back / next / export buttons, plus a
    WC/WW selector row with a status board) and ``'CT image'`` (the current
    slice). On the CT image:

    * left click records the point for the current slice in
      ``ijk_eso_centers``;
    * left double click records ``start_point``;
    * right click toggles a zoomed view around the clicked pixel.

    Keys: ``a``/``d`` go to the previous/next slice, and ``w``/``x``
    raise/lower the selected window parameter (WC or WW) by 1. Export writes
    the picked points into ``setting.json`` under ``args.output_dir`` and
    ends ``run``.
    """

    def __init__(self, args, setting_json):
        self.args = args
        self.setting_json = setting_json
        self.Is_continue = True
        self.Is_zoomed = False
        self.dicom_processor = DicomProcessor(args.dicom_dir, args)
        self.dicom_size = len(self.dicom_processor.dicom_file_list)
        self.CT_img_for_show = self.dicom_processor.CT_for_imshow
        self.cur_src_CT_img = self.CT_img_for_show
        self.src_CT_size = self.cur_src_CT_img.shape
        self.zoomed_CT_size = self.preprocess_zoomed_CT_size()
        self.zoom_point = [0, 0]

        # Menu button state
        self.target_color = (24, 185, 237)
        self.non_target_color = (220, 220, 220)

        self.window_button_state = {
            "pushed_wc": True,
            "pushed_ww": False
        }
        self.wc_width = None
        self.ww_width = None
        self.state_board = None
        self.button_height, self.button_width = 0, 0
        self.show_wc_str = "WC: {}".format(round(self.dicom_processor.dcm_wc))
        self.show_ww_str = "WW: {}".format(round(self.dicom_processor.dcm_ww))
        self.wc_str_color = self.target_color
        self.ww_str_color = self.non_target_color
        self.button_img = self.create_button_img()

        # NOTE: currently only supports images whose pixel width is the same in x and y.
        self.set_init_window()

        # Output variables
        self.dicom_index = 0
        self.ijk_eso_centers = [None] * self.dicom_size
        self.start_point = None

        self.GUI_imshow()

    def set_init_window(self):
        """Create the two named OpenCV windows used by the GUI."""
        cv2.namedWindow('menu')
        cv2.namedWindow('CT image')

    def preprocess_zoomed_CT_size(self):
        """Return the [height, width] of the crop shown when zoomed.

        Each side is the source side divided by ``args.magnification_ratio``,
        made odd (so the crop can be centred on the clicked pixel) and kept
        at least 1.
        """
        ratio = self.args.magnification_ratio
        zoomed_CT_size = []
        for side in self.CT_img_for_show.shape[:2]:
            side = int(side / ratio)
            if side % 2 == 0:
                side -= 1
            zoomed_CT_size.append(max(side, 1))
        return zoomed_CT_size

    # ------------------------------------------------------------------ menu image

    def _read_button(self, name, size=None):
        """Load ``./button_img/<name>.png``, optionally resized to ``size`` (w, h).

        Raises FileNotFoundError instead of letting cv2.resize fail on None.
        """
        path = "./button_img/" + name + ".png"
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError("button image not found: " + path)
        return img if size is None else cv2.resize(img, size)

    def _read_button_fit_height(self, name, height):
        """Load a button image scaled to ``height`` while keeping its aspect ratio."""
        img = self._read_button(name)
        scale = height / img.shape[0]
        return cv2.resize(img, (round(scale * img.shape[1]), round(scale * img.shape[0])))

    def _draw_state_text(self):
        """Draw the WC/WW strings onto ``self.state_board`` in their current colors."""
        for text, origin, color in ((self.show_wc_str, (20, 30), self.wc_str_color),
                                    (self.show_ww_str, (20, 60), self.ww_str_color)):
            cv2.putText(self.state_board, text, origin, cv2.FONT_HERSHEY_SIMPLEX, 1, color,
                        1, cv2.LINE_AA)

    def create_button_img(self):
        """Build and return the composite menu image.

        The top row holds the back / next / export buttons, each resized to
        (args.resized_button_width, args.resized_button_height). The bottom
        row holds the WC (pushed) and WW buttons scaled to the same height,
        followed by a black status board with the current WC/WW values.
        It also sets button_height/button_width, wc_width/ww_width and
        state_board, which the click handlers rely on.
        """
        button_size = (self.args.resized_button_width, self.args.resized_button_height)
        nav_buttons = [self._read_button(name, button_size) for name in ('back', 'next', 'export')]
        self.button_height, self.button_width = nav_buttons[0].shape[:2]
        top_row = cv2.hconcat(nav_buttons)

        # Window-adjustment buttons (added 2021/1/12).
        wc_pushed = self._read_button_fit_height('wc_pushed', self.args.resized_button_height)
        ww = self._read_button_fit_height('ww', self.args.resized_button_height)
        self.wc_width = wc_pushed.shape[1]
        self.ww_width = ww.shape[1]

        # The status board fills the rest of the bottom row.
        board_width = top_row.shape[1] - self.wc_width - self.ww_width
        self.state_board = np.zeros((ww.shape[0], board_width, 3), dtype=np.uint8)
        self._draw_state_text()

        bottom_row = np.hstack([wc_pushed, ww, self.state_board])
        return np.vstack([top_row, bottom_row])

    def update_state_board(self):
        """Redraw the WC/WW status board from the processor's current values."""
        self.state_board = np.zeros_like(self.state_board).astype(np.uint8)
        self.show_wc_str = "WC: {}".format(round(self.dicom_processor.dcm_wc))
        self.show_ww_str = "WW: {}".format(round(self.dicom_processor.dcm_ww))
        self._draw_state_text()
        self.button_img[self.button_height:, (self.wc_width + self.ww_width):, :] = self.state_board

    def _set_window_buttons(self, wc_name, ww_name):
        """Replace the WC/WW button images in the bottom row of the menu."""
        top = self.button_height
        ww_end = self.wc_width + self.ww_width
        self.button_img[top:, :self.wc_width, :] = self._read_button(wc_name, (self.wc_width, top))
        self.button_img[top:, self.wc_width:ww_end, :] = self._read_button(ww_name, (self.ww_width, top))

    def _select_window_param(self, select_wc):
        """Make WC (``select_wc=True``) or WW the parameter adjusted by w/x."""
        if select_wc:
            self._set_window_buttons("wc_pushed", "ww")
            self.wc_str_color, self.ww_str_color = self.target_color, self.non_target_color
        else:
            self._set_window_buttons("wc", "ww_pushed")
            self.wc_str_color, self.ww_str_color = self.non_target_color, self.target_color
        self.window_button_state["pushed_wc"] = select_wc
        self.window_button_state["pushed_ww"] = not select_wc
        self.update_state_board()

    # ------------------------------------------------------------------ slices / display

    def change_cur_CT(self, mode):
        """Move to the next/previous slice (GUI_LOAD_NEXT/GUI_LOAD_PREV) and redraw.

        Zoom is reset. Callers are responsible for the index bounds checks.
        """
        if mode == GUI_LOAD_NEXT:
            self.dicom_index += 1
        elif mode == GUI_LOAD_PREV:
            self.dicom_index -= 1
        print('index : {}'.format(self.dicom_index))
        self.CT_img_for_show = self.dicom_processor.get_CT_by_index(self.dicom_index)
        self.cur_src_CT_img = self.CT_img_for_show
        self.Is_zoomed = False
        self.GUI_imshow()

    def _reload_current_CT(self):
        """Re-render the current slice (e.g. after a WC/WW change), keeping any zoom."""
        self.cur_src_CT_img = self.dicom_processor.get_CT_by_index(self.dicom_index)
        if self.Is_zoomed:
            self._apply_zoom()
        else:
            self.CT_img_for_show = self.cur_src_CT_img

    def GUI_imshow(self):
        """Show the menu and CT images and pump the OpenCV event loop once."""
        cv2.imshow("menu", self.button_img)
        cv2.imshow("CT image", self.CT_img_for_show)
        cv2.waitKey(1)

    def encode_array2str(self, array):
        """Return the items of ``array`` joined by ``","`` (``""`` when empty)."""
        return ",".join(str(item) for item in array)

    # ------------------------------------------------------------------ menu clicks

    def menu_callbacks(self, event, x, y, flags, param):
        """Mouse callback for the 'menu' window (acts on left-button release)."""
        if event != cv2.EVENT_LBUTTONUP:
            return
        if y <= self.button_height:
            self._on_nav_row_click(x)
        else:
            self._on_window_row_click(x)

    def _on_nav_row_click(self, x):
        """Handle a click in the back / next / export row (half-open column ranges)."""
        width = self.button_width
        if 0 <= x < width:
            if self.dicom_index != 0:
                self.change_cur_CT(GUI_LOAD_PREV)
        elif width <= x < 2 * width:
            if self.dicom_index != (self.dicom_size - 1):
                self.change_cur_CT(GUI_LOAD_NEXT)
        elif 2 * width <= x < 3 * width:
            self._export()

    def _on_window_row_click(self, x):
        """Handle a click on the WC/WW buttons; clicks on the status board are ignored."""
        if 0 <= x < self.wc_width:
            if not self.window_button_state["pushed_wc"]:
                self._select_window_param(True)
        elif self.wc_width <= x < self.wc_width + self.ww_width:
            if not self.window_button_state["pushed_ww"]:
                self._select_window_param(False)

    def _export(self):
        """Write the picked points into setting.json and stop the main loop.

        Nothing is written if no start point has been chosen.
        TODO: show a GUI warning (not just a print) when start_point is None.
        """
        print("click export")
        if self.start_point is None:
            print("start pointが選ばれていません")
            return

        # Convert each picked ijk point to a vector relative to the center
        # (via the processor's ijk2relative_vec) and serialize it as "a,b,c".
        to_vec = self.dicom_processor.ijk2relative_vec
        self.setting_json["start_point_vec"] = self.encode_array2str(to_vec(self.start_point))
        padded_centers = [point for point in self.dicom_processor.linear_pad_list(self.ijk_eso_centers)
                          if point is not None]
        self.setting_json["center_points_vec"] = [self.encode_array2str(to_vec(point))
                                                  for point in padded_centers]
        with open(osp.join(self.args.output_dir, "setting.json"), "w", encoding="utf-8") as f:
            json.dump(self.setting_json, f, indent=4)
        # Only stop the loop once the file has actually been written.
        self.Is_continue = False
        print('exit')

    # ------------------------------------------------------------------ CT image clicks / zoom

    def CT_image_callbacks(self, event, x, y, flags, param):
        """Mouse callback for the 'CT image' window."""
        if event == cv2.EVENT_LBUTTONDBLCLK:
            i, j, k = self._mark_point(x, y, (0, 255, 0))
            print('start point i:{} j:{} k:{}'.format(i, j, k))
            self.start_point = np.array([i, j, k])
        elif event == cv2.EVENT_LBUTTONDOWN:
            i, j, k = self._mark_point(x, y, (0, 0, 255))
            print('clicked i:{} j:{} k:{}'.format(i, j, k))
            self.ijk_eso_centers[self.dicom_index] = np.array([i, j, k])
        elif event == cv2.EVENT_RBUTTONDOWN:
            if self.Is_zoomed:
                self.CT_img_for_show = self.cur_src_CT_img
                self.Is_zoomed = False
            else:
                self.zoom_function(x, y)

    def _mark_point(self, x, y, color):
        """Draw a marker at display pixel (x, y) and return its source (i, j, k)."""
        cv2.circle(self.CT_img_for_show, (x, y), 5, color, -1, cv2.LINE_AA)
        i, j = self._to_source_ij(x, y)
        k = self.dicom_processor.calc_k_on_ijk_coordinates(self.dicom_index)
        return i, j, k

    def _to_source_ij(self, x, y):
        """Map a display pixel to source-image (i, j) = (column, row) coordinates."""
        if not self.Is_zoomed:
            return x, y
        # x spans the width (axis 1), y spans the height (axis 0).
        zoom_i_ratio = self.src_CT_size[1] / self.zoomed_CT_size[1]
        zoom_j_ratio = self.src_CT_size[0] / self.zoomed_CT_size[0]
        return self.calc_non_zoomed_ij(x, y, zoom_i_ratio, zoom_j_ratio)

    def calc_non_zoomed_ij(self, i, j, zoom_i_ratio, zoom_j_ratio):
        """Map zoomed-view pixel (i, j) back to the source image given the zoom ratios."""
        non_zoomed_i, non_zoomed_j = (self.zoom_point[0] + (i / zoom_i_ratio),
                                      self.zoom_point[1] + (j / zoom_j_ratio))
        return non_zoomed_i, non_zoomed_j

    def _compute_zoom_point(self, x, y):
        """Return the [x, y] top-left corner of a zoom crop centred on (x, y).

        The corner is clamped on all sides so the crop stays inside the source image.
        """
        src_h, src_w = self.src_CT_size[:2]
        zoom_h, zoom_w = self.zoomed_CT_size
        left = int(x - ((zoom_w - 1) / 2))
        top = int(y - ((zoom_h - 1) / 2))
        left = max(0, min(left, src_w - zoom_w))
        top = max(0, min(top, src_h - zoom_h))
        return [left, top]

    def _apply_zoom(self):
        """Crop ``cur_src_CT_img`` at ``zoom_point`` and upscale it to the source size."""
        left, top = self.zoom_point
        zoom_h, zoom_w = self.zoomed_CT_size
        src_h, src_w = self.src_CT_size[:2]
        crop = self.cur_src_CT_img[top:(top + zoom_h), left:(left + zoom_w), :]
        # cv2.resize takes dsize as (width, height), not the (height, width) of .shape.
        self.CT_img_for_show = cv2.resize(crop, (src_w, src_h), interpolation=cv2.INTER_LANCZOS4)
        self.Is_zoomed = True

    def zoom_function(self, x, y):
        """Show a magnified view centred (as far as the borders allow) on (x, y)."""
        self.zoom_point = self._compute_zoom_point(x, y)
        self._apply_zoom()

    # ------------------------------------------------------------------ keyboard

    def key_function(self):
        """Poll the keyboard (50 ms) and dispatch a/d (slice) and w/x (window) keys."""
        key = cv2.waitKey(50)
        if key == ord('d') and self.dicom_index != (self.dicom_size - 1):
            self.change_cur_CT(GUI_LOAD_NEXT)
        elif key == ord('a') and self.dicom_index != 0:
            self.change_cur_CT(GUI_LOAD_PREV)
        # w/x window adjustment keys (added 2021/1/12)
        elif key == ord('w'):
            self._adjust_window(1)
        elif key == ord('x'):
            self._adjust_window(-1)

    def _adjust_window(self, delta):
        """Add ``delta`` to the selected WC or WW value and re-render the slice."""
        if self.window_button_state["pushed_wc"]:
            self.dicom_processor.dcm_wc += delta
            print("changed => wc: ", self.dicom_processor.dcm_wc)
        elif self.window_button_state["pushed_ww"]:
            # DICOM requires Window Width >= 1.
            self.dicom_processor.dcm_ww = max(1, self.dicom_processor.dcm_ww + delta)
            print("changed => ww: ", self.dicom_processor.dcm_ww)
        else:
            return
        # Refresh both the source and displayed images so zoom/unzoom use the new window.
        self._reload_current_CT()
        self.update_state_board()

    def run(self):
        """Install the mouse callbacks and run the GUI loop until export succeeds."""
        cv2.setMouseCallback('menu', self.menu_callbacks)
        cv2.setMouseCallback('CT image', self.CT_image_callbacks)

        while self.Is_continue:
            self.key_function()
            self.GUI_imshow()