import numpy as np
from glob import glob
import os
import cv2

class Movie2imgConverter:
    """Extract frames from endoscope movies and save them as PNG sequences.

    Every ``pass_num``-th frame of each movie is cropped and resized,
    undistorted with the loaded camera parameters, inpainted where the
    luma exceeds ``thresh_num``, and written to ``<out_dir>/<n>seq/``.

    ``args`` is an attribute container (e.g. an ``argparse.Namespace``)
    from which the following attributes are read: ``param_dir``,
    ``resized_width``, ``resized_height``, ``movie_dir``,
    ``endo_movie_extend``, ``out_dir``, ``pass_num`` and ``thresh_num``.
    """

    def __init__(self, args):
        """Store ``args`` and load the camera parameters from ``args.param_dir``.

        Loads ``dist_coeffs.npy`` and ``intrinsics_scaled.npy``; ``np.load``
        raises if either file is missing.
        """
        self.args = args
        self.dist_coeffs = np.load(os.path.join(self.args.param_dir, 'dist_coeffs.npy'))
        self.intrinsics_scaled = np.load(os.path.join(self.args.param_dir, 'intrinsics_scaled.npy'))
        # cv2.resize expects the target size as (width, height).
        self.resized_size = (self.args.resized_width, self.args.resized_height)

    def convert_endo_movie2img(self):
        """Convert every movie in ``args.movie_dir`` into a PNG sequence.

        Movies are matched by the suffix ``args.endo_movie_extend`` and
        processed in sorted filename order. Movie number ``n`` is written to
        ``<args.out_dir>/<n>seq/``. The image file counter is NOT reset
        between movies, so file names are unique across all sequences.

        Raises:
            RuntimeError: if the movie directory does not exist, contains
                no matching movie, or a movie cannot be opened.
        """
        src_dir = self.args.movie_dir
        if not os.path.exists(src_dir):
            print("動画のディレクトリパスがミスってます")
            raise RuntimeError("movie directory does not exist: {}".format(src_dir))

        # glob returns entries in arbitrary order; sort so that the mapping
        # from movie to sequence number is reproducible between runs.
        movie_name_list = sorted(glob(os.path.join(src_dir, "*" + self.args.endo_movie_extend)))

        if len(movie_name_list) == 0:
            print("対象のディレクトリに動画が見つかりません")
            raise RuntimeError("no movie files found in: {}".format(src_dir))

        file_name_num = 0
        for sequence_num, movie_name in enumerate(movie_name_list):
            cap = cv2.VideoCapture(movie_name)
            try:
                if not cap.isOpened():
                    print("ビデオキャプチャエラー")
                    raise RuntimeError("failed to open movie: {}".format(movie_name))

                print("start processing {}".format(os.path.basename(movie_name)))
                cur_frame = 0
                sum_frame = cap.get(cv2.CAP_PROP_FRAME_COUNT)

                out_subdir = os.path.join(self.args.out_dir, '{}seq'.format(sequence_num))
                os.makedirs(out_subdir, exist_ok=True)

                while True:
                    ret, frame = cap.read()
                    if not ret:
                        # End of stream (or read failure): print the final
                        # progress line and move on to the next movie.
                        print('{} / {}'.format(sum_frame, int(sum_frame)))
                        break

                    # Keep only every pass_num-th frame.
                    if (cur_frame % self.args.pass_num) == 0:
                        frame = self.trim_endo_movie(frame)
                        frame = cv2.undistort(frame, self.intrinsics_scaled, self.dist_coeffs)
                        frame = self.inpaint_endo_img(frame)

                        cv2.imwrite(os.path.join(out_subdir, '{:08}.png'.format(file_name_num)), frame)
                        file_name_num += 1

                    # Periodic progress report.
                    if (cur_frame % 300) == 0:
                        print('{} / {}'.format(cur_frame, int(sum_frame)))

                    cur_frame += 1
            finally:
                # Always free the capture handle, even when opening or
                # processing fails part-way through.
                cap.release()

    def trim_endo_movie(self, frame):
        """Crop a fixed region out of ``frame`` and resize it.

        The crop keeps rows 32..988 and columns 323..1598 (hard-coded), then
        resizes to ``self.resized_size`` with Lanczos interpolation.
        NOTE(review): the crop window presumably matches the endoscope view
        inside a specific recording layout -- confirm for other sources.
        """
        frame = frame[32:989, 323:1599, :]
        frame = cv2.resize(frame, self.resized_size, interpolation=cv2.INTER_LANCZOS4)
        return frame

    def inpaint_endo_img(self, frame):
        """Inpaint the bright regions of a BGR ``frame`` and return the result.

        Pixels whose Y (luma) value exceeds ``args.thresh_num`` form the
        mask, which is dilated with a 9x9 kernel before Navier-Stokes
        inpainting with radius 30.
        NOTE(review): presumably targets specular highlights -- confirm.
        """
        YUV = cv2.cvtColor(frame, cv2.COLOR_BGR2YUV)
        Y = YUV[:, :, 0]
        _, mask = cv2.threshold(Y, self.args.thresh_num, 255, cv2.THRESH_BINARY)
        # Grow the mask so the borders of bright spots are inpainted as well.
        kernel = np.ones((9, 9), np.uint8)
        dilation = cv2.dilate(mask, kernel)
        out = cv2.inpaint(frame, dilation, 30, cv2.INPAINT_NS)
        return out
