import numpy as np
import cv2
import pydicom

# (x, y) positions of left-button releases, appended by mouse_callbacks.
xy_s = list()
# (x, y) positions of right-button releases, appended by mouse_callbacks;
# the script below reads center_xy[0] as the circle centre.
center_xy = list()
# Slice number of the DICOM file to open; it is zero-padded to three
# digits to form the file name suffix ("Image016").
dicom_num = 16

def mouse_callbacks(event, x, y, flags, param):
    """OpenCV mouse callback that records clicked positions.

    A left-button release prints the coordinates and appends ``(x, y)`` to
    the module-level ``xy_s`` list. A right-button release prints
    "RBUTTONUP" and appends ``(x, y)`` to ``center_xy``. Other events are
    ignored. ``flags`` and ``param`` are part of the OpenCV callback
    signature and are not used.
    """
    point = (x, y)
    if event == cv2.EVENT_LBUTTONUP:
        print("x: {}  y: {}".format(x, y))
        xy_s.append(point)
    elif event == cv2.EVENT_RBUTTONUP:
        print("RBUTTONUP")
        center_xy.append(point)


def write_circle(xy_s, mat):
    """Draw a filled radius-4 dot at every point of *xy_s* onto *mat*.

    *mat* is modified in place and also returned. Each point must be an
    integer ``(x, y)`` pair, as required by ``cv2.circle``. The colour
    (0, 0, 200) is red when *mat* is a BGR image.
    """
    red = (0, 0, 200)
    for point in xy_s:
        cv2.circle(mat, point, 4, red, thickness=-1)  # -1 => filled disc
    return mat

# Interactive thin-plate-spline warp of one CT slice.
#   * left-click 4 points and right-click a centre, then press 'a' to move
#     the first 4 clicked points onto a 15 mm circle around the centre;
#   * or right-click a centre and press 'b' to move a 5 mm circle around it
#     onto a 15 mm circle;
#   * press 'q' or Esc to quit.
# Each warp writes before<N>.png (annotated view) and after<N>.png (result).


def _circle_points(center, radius):
    """Return the 4 points at distance *radius* from *center*.

    The order is (x, y + r), (x + r, y), (x, y - r), (x - r, y). In image
    coordinates (y grows downward) that is below, right, above, left.
    Integer inputs give integer tuples, which is what ``cv2.circle`` needs.
    """
    cx, cy = center
    return [(cx, cy + radius), (cx + radius, cy),
            (cx, cy - radius), (cx - radius, cy)]


def _border_points(rows, cols):
    """Return 8 float32 (x, y) anchors on the frame of a rows x cols image.

    These are the 4 corners and 4 edge midpoints. They are matched to
    themselves in the TPS fit so the frame stays pinned.
    """
    half_col_num = cols // 2
    half_row_num = rows // 2
    return np.array([[half_col_num, rows], [cols, rows],
                     [cols, half_row_num], [cols, 0],
                     [half_col_num, 0], [0, 0],
                     [0, half_row_num], [0, rows]], dtype=np.float32)


def _warp_to_circle(src_points, center, spacing_dx):
    """Warp CT_img so *src_points* land on a circle of radius 15 / spacing_dx px.

    Args:
        src_points: 4 integer (x, y) points, in the order _circle_points uses.
        center: integer (x, y) centre of the target circle.
        spacing_dx: first value of ds.PixelSpacing. DICOM gives pixel
            spacing in mm, so the target radius corresponds to 15 mm.

    Side effects: draws the target points (green) on show_CT, writes
    "before<N>.png" and "after<N>.png", and shows the warp in window "dist".
    """
    radius15_int = int(15.0 / spacing_dx)
    target_points = _circle_points(center, radius15_int)
    for p in target_points:
        cv2.circle(show_CT, p, 4, (0, 200, 0), thickness=-1)
    cv2.imwrite("before{}.png".format(dicom_num), show_CT)

    # Use the real slice size rather than assuming 512 x 512.
    rows, cols = CT_img.shape[:2]
    fixed_points = _border_points(rows, cols)
    src = np.vstack((np.array(src_points, dtype=np.float32),
                     fixed_points)).reshape((1, -1, 2))
    target = np.vstack((np.array(target_points, dtype=np.float32),
                        fixed_points)).reshape((1, -1, 2))
    # Point i of src corresponds to point i of target.
    matches = [cv2.DMatch(i, i, 0) for i in range(src.shape[1])]

    tps = cv2.createThinPlateSplineShapeTransformer()
    # NOTE(review): (target, src) order is intentional, presumably because
    # warpImage maps output pixels back into the input (backward warping);
    # verify before swapping.
    tps.estimateTransformation(target, src, matches)
    out = tps.warpImage(CT_img)
    cv2.imshow("dist", out)
    cv2.imwrite("after{}.png".format(dicom_num), out)


# Zero-padded slice id, e.g. "016".
a = "{:03}".format(dicom_num)
ds = pydicom.dcmread(r'C:\Users\Planck\Desktop\3Dsyokudo\Image' + a)
CT_row = ds.pixel_array
nonzero_px = CT_row[CT_row != 0]
if nonzero_px.size:
    # Shift non-zero pixels so the smallest non-zero value becomes 0;
    # pixels that were exactly 0 stay 0.
    CT_row = np.where(CT_row == 0, 0, CT_row - np.min(nonzero_px))

# Scale to 0..255 for display; guard against an all-zero slice (max == 0).
peak = np.max(CT_row)
if peak > 0:
    tmp = np.array(255 * (CT_row / peak), dtype=np.uint8)
else:
    tmp = np.zeros(CT_row.shape, dtype=np.uint8)
CT_img = cv2.cvtColor(tmp, cv2.COLOR_GRAY2BGR)
show_CT = CT_img.copy()  # annotated copy; CT_img stays clean for warping
cv2.imshow("a", show_CT)
cv2.setMouseCallback('a', mouse_callbacks)

while True:
    cv2.imshow("a", show_CT)
    key = cv2.waitKey(1)
    if key in (27, ord('q')):  # Esc or 'q' ends the session
        break
    if key not in (ord('a'), ord('b')):
        continue
    if not center_xy:
        print("right-click the circle centre first")
        continue

    center = center_xy[0]
    # NOTE(review): DICOM PixelSpacing is (row spacing, column spacing), so
    # this first value is the vertical spacing; identical for square pixels.
    spacing_dx = ds.PixelSpacing[0]

    if key == ord('a'):
        if len(xy_s) < 4:
            print("left-click 4 points first ({} so far)".format(len(xy_s)))
            continue
        # Use a slice so the global click list is never rebound; the mouse
        # callback keeps appending to it.
        src_points = xy_s[:4]
    else:
        src_radius = int(5 / spacing_dx)
        src_points = _circle_points(center, src_radius)

    write_circle(src_points, show_CT)
    _warp_to_circle(src_points, center, spacing_dx)

cv2.destroyAllWindows()
