
from promise2012.Vnet.CNN2 import Vnet3dModule_64 as CNN3D_2
from promise2012.Vnet.CNN3 import Vnet3dModule_64 as CNN3D_3
from promise2012.Vnet.CNN1 import Vnet3dModule_64 as CNN3D_1
# from promise2012.Vnet.util import convertMetaModelToPbModel
import numpy as np
import pandas as pd
import cv2
import os
import tensorflow as tf
import keras.backend.tensorflow_backend as KTF
import SimpleITK as sitk
from nipype.interfaces.ants import N4BiasFieldCorrection
import dicom
import gc


# Make CUDA enumerate GPUs in PCI bus order and expose only device "0" to this
# process, so the visible GPU is deterministic across machines.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # see issue #152
    "CUDA_VISIBLE_DEVICES": "0",
})

# allow_growth asks TensorFlow to allocate GPU memory incrementally rather
# than reserving it all when the session starts.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)
KTF.set_session(sess)  # register this session as the Keras backend session


def train():
    """Train the CNN1 3-D V-Net on the group-1 training split.

    Loads the image/mask tables for training
    (``promise12Vnet3dImage_train_group1.csv`` / ``promise12Vnet3dMask_train_group1.csv``)
    and validation (``promise12Vnet3dImage_val.csv`` / ``promise12Vnet3dMask_val.csv``).
    It shuffles each split with a single permutation so image row k still matches
    mask row k, then hands everything to ``CNN3D_1.train``.
    """
    def shuffle_together(images, masks):
        # One permutation indexes both arrays, keeping image/mask rows paired.
        order = np.arange(len(images))
        np.random.shuffle(order)
        return images[order], masks[order]

    # Training tables (masks are read before images).
    train_masks = pd.read_csv('promise12Vnet3dMask_train_group1.csv').values
    train_images = pd.read_csv('promise12Vnet3dImage_train_group1.csv').values

    # Validation tables.
    val_masks = pd.read_csv('promise12Vnet3dMask_val.csv').values
    val_images = pd.read_csv('promise12Vnet3dImage_val.csv').values

    # Shuffle the training split first, then the validation split.
    train_images, train_masks = shuffle_together(train_images, train_masks)
    val_images, val_masks = shuffle_together(val_images, val_masks)

    model = CNN3D_1(96, 96, 16, channels=1, costname="dice coefficient")
    # The two strings are presumably the model save path and the log directory.
    # The trailing numbers are forwarded positionally. Presumably they are the
    # learning rate, dropout keep-prob, iteration count and batch size.
    # TODO confirm against Vnet3dModule_64.train.
    model.train(train_images, train_masks, val_images, val_masks,
                "model\\cross_paper3/CNN1/1/CNN1.pd", "log\\",
                0.001, 0.5, 100000, 1)


# Patch size fed to the network: 16 slices of 96 x 96 pixels.
_PATCH_Z = 16
_PATCH_XY = 96
# In-plane size every volume is reshaped to (512 * 512 == 262144 voxels/slice).
_SLICE_XY = 512


def _tile_origins(depth):
    """Yield ``(a, b, c, z0, y0, x0, tag)`` for every patch of the eight tilings.

    The tilings are the plain grid plus copies shifted by half a patch along z
    (8 slices) and/or in-plane (48 px). ``a``/``b``/``c`` are the grid indices
    within a tiling and ``(z0, y0, x0)`` is the patch's start voxel. ``tag`` is
    the three-letter code used in the log lines: N/Y per axis (z, y, x) for
    "not shifted"/"shifted". Patches are yielded z-tiling first, then per ``a``
    the four in-plane tilings in the order NN, NY, YY, YN.

    NOTE(review): 512 is not a multiple of 96, so the unshifted grid stops at
    480 and the shifted one at 432. Pixels with y or x >= 480 are never
    predicted, and neither are trailing slices that do not fill a full patch.
    """
    half_z = _PATCH_Z // 2
    half_xy = _PATCH_XY // 2
    xy_shifts = ((0, 0, 'NN'), (0, half_xy, 'NY'),
                 (half_xy, half_xy, 'YY'), (half_xy, 0, 'YN'))
    for z_off, z_tag in ((0, 'N'), (half_z, 'Y')):
        for a in range((depth - z_off) // _PATCH_Z):
            z0 = a * _PATCH_Z + z_off
            for y_off, x_off, yx_tag in xy_shifts:
                for b in range((_SLICE_XY - y_off) // _PATCH_XY):
                    for c in range((_SLICE_XY - x_off) // _PATCH_XY):
                        yield (a, b, c, z0,
                               b * _PATCH_XY + y_off, c * _PATCH_XY + x_off,
                               z_tag + yx_tag)


def _normalize_patch(patch, slope, src_range, dst_range):
    """Window a raw patch into uint8 network input of shape ``patch.shape + (1,)``.

    Intensities are mapped linearly, so ``src_range[0]`` maps to ``dst_range[0]``
    with the given ``slope``. They are then clipped to ``dst_range`` and
    truncated to uint8.
    """
    scaled = (patch.astype('float') - float(src_range[0])) * slope + float(dst_range[0])
    np.clip(scaled, dst_range[0], dst_range[1], out=scaled)
    return scaled.astype('uint8')[..., np.newaxis]


def _accumulate_prediction(output, prediction, z0, y0, x0, kernel):
    """Post-process one patch prediction and add it into ``output`` in place.

    Each of the ``_PATCH_Z`` predicted slices goes through four steps:
      1. It is morphologically closed with ``kernel``.
      2. Values <= 200 are zeroed.
      3. It is cast to uint8 and divided by 8.
      4. It is added to the matching output window.
    Surviving values are 201..255, so each contribution is 25..31. A voxel is
    covered by at most 8 patches (2 z-tilings x 4 in-plane tilings), so the
    uint8 sum is at most 248 and cannot overflow.
    """
    for index in range(_PATCH_Z):
        # assumes each prediction slice is 2-D (96, 96) -- TODO confirm
        closed = cv2.morphologyEx(prediction[index], cv2.MORPH_CLOSE, kernel)
        kept = np.where(closed <= 200, 0, closed).astype(np.uint8)
        output[z0 + index, y0:y0 + _PATCH_XY, x0:x0 + _PATCH_XY] += kept // 8


def predict():
    """Run tiled CNN1 inference and write a texture mask for the selected case(s).

    For each processed case id ``i`` the function does the following:
      * Reads ``D:/Ktest/<i>/delay_isorotopic<i>.raw`` as int16.
      * Reads ``D:/Ktest/<i>/0228OLDBoneIsoNew<i>.raw`` as uint8.
      * Zeroes the image wherever the bone mask is 255. It assumes the mask
        holds only 0/255; other values would scale the intensity.
      * Reshapes the result to (slices, 512, 512).
      * Predicts every 16x96x96 patch of the eight half-shifted tilings. It
        skips patches with fewer than 30 non-zero voxels after windowing.
      * Writes the accumulated uint8 map to ``D:/Ktest/<i>/texture_test.raw``.
    Voxels whose sum is < 100 are zeroed before writing. Each patch adds at most
    31, so at least 4 overlapping confident patches are needed to survive.
    """
    CNN3d = CNN3D_1(96, 96, 16, inference=True,
                    model_path="model\\cross_paper3/CNN1/1/CNN1.pd")
    # Alternative models previously run through the same pipeline:
    #   CNN3D_2(96, 96, 16, inference=True, model_path="model\\cross_paper3/CNN2/1/CNN2.pd")
    #   CNN3D_3(96, 96, 16, inference=True, model_path="model\\cross_paper3/CNN3/1/CNN3.pd")

    src_range = [50, 350]  # intensity window mapped onto dst_range
    dst_range = [0, 255]
    slope = (float(dst_range[1]) - float(dst_range[0])) \
        / (float(src_range[1]) - float(src_range[0]))
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))

    # Case-id groups previously run separately; arr holds the same ids:
    #   [23, 45], [22, 37, 48], [6, 8, 35, 43], [12, 20, 36, 38, 51],
    #   [19, 33, 40, 47, 49], [15, 17, 21, 24, 27, 29, 31, 32, 39, 41, 42, 44, 46, 53]
    arr = [23, 45, 46, 47, 48, 22, 37, 6, 8, 35, 43, 12, 20, 36, 38, 51, 19, 33, 40, 49, 15, 17, 21, 24,
           27, 29, 31, 32, 39, 41, 42, 44, 53]

    # NOTE(review): only the first case id is processed. This keeps the
    # behaviour of an earlier `if i_ == 0` guard. Drop `[:1]` to run every case.
    for i in arr[:1]:
        raw = np.fromfile('D:/Ktest/' + str(i) + '/delay_isorotopic' + str(i) + '.raw', dtype=np.int16)  # DICOM image
        bone = np.fromfile('D:/Ktest/' + str(i) + '/0228OLDBoneIsoNew' + str(i) + '.raw', dtype=np.uint8)  # bone
        print(raw.shape, 'case------', str(i), '---------start')

        # Invert the bone mask (0 -> 1, 255 -> 0) so multiplication removes bone.
        bone[bone == 0] = 1
        bone[bone == 255] = 0
        volume = np.reshape(raw * bone, (bone.shape[0] // (_SLICE_XY * _SLICE_XY), _SLICE_XY, _SLICE_XY))
        output = np.zeros(shape=volume.shape, dtype=np.uint8)

        for a, b, c, z0, y0, x0, tag in _tile_origins(volume.shape[0]):
            batch_xs = _normalize_patch(
                volume[z0:z0 + _PATCH_Z, y0:y0 + _PATCH_XY, x0:x0 + _PATCH_XY],
                slope, src_range, dst_range)
            nonzero = np.count_nonzero(batch_xs)
            if nonzero < 30:
                # Almost empty patch (background / masked bone): skip inference.
                print('zero:', a, b, c, '-------', str(nonzero))
                continue
            print(a, b, c, tag + '-case:', str(i))
            _accumulate_prediction(output, CNN3d.prediction(batch_xs), z0, y0, x0, kernel)

        output[output < 100] = 0
        output.tofile('D:/Ktest/' + str(i) + '/texture_test.raw')
        del output
        del bone
        del raw
        del volume
        gc.collect()


if __name__ == "__main__":
    # Train first, then run tiled inference. The guard keeps imports of this
    # module (e.g. to reuse predict()) from kicking off training and inference.
    train()
    predict()

