import sys
sys.path.append('./Vnet/')
from promise2012.Vnet.layer import (conv3d, deconv3d, normalizationlayer, crop_and_concat, resnet_Add,
                        weight_xavier_init, bias_variable, save_images, conv2d)

import tensorflow as tf
import numpy as np
import cv2
import os
from tensorflow.contrib import slim
from math import e
import gc
from keras import backend as K
# from promise2012.Vnet.attention import PAM_Module
# import torch.nn as nn


def conv_bn_relu_drop(x, kernal, phase, drop, image_z=None, height=None, width=None, scope=None):
    """3-D convolution + bias, then normalizationlayer (norm_type='group'), ReLU and dropout.

    :param x: input tensor for ``conv3d``.
    :param kernal: kernel shape; ``kernal[0..3]`` form the Xavier fan-in, ``kernal[-1]`` is the output channel count.
    :param phase: forwarded to ``normalizationlayer`` as ``is_train``.
    :param drop: second positional argument of TF1 ``tf.nn.dropout`` (the keep probability).
    :param image_z: optional size forwarded to ``normalizationlayer``.
    :param height: optional size forwarded to ``normalizationlayer``.
    :param width: optional size forwarded to ``normalizationlayer``.
    :param scope: name scope; also the prefix of the ``conv_W`` / ``conv_B`` variable names (must be a string).
    :return: activated, dropped-out tensor.
    """
    with tf.name_scope(scope):
        fan_in = kernal[0] * kernal[1] * kernal[2] * kernal[3]
        out_channels = kernal[-1]
        weights = weight_xavier_init(shape=kernal, n_inputs=fan_in, n_outputs=out_channels,
                                     activefunction='relu', variable_name=scope + 'conv_W')
        bias = bias_variable([out_channels], variable_name=scope + 'conv_B')
        pre_norm = conv3d(x, weights) + bias
        normed = normalizationlayer(pre_norm, is_train=phase, height=height, width=width, image_z=image_z,
                                    norm_type='group', scope=scope)
        activated = tf.nn.relu(normed)  # (an ELU variant was tried previously)
        return tf.nn.dropout(activated, drop)


def conv_bn_relu(x, kernal, phase, drop, image_z=None, height=None, width=None, scope=None):
    """Conv3d + bias -> normalizationlayer (group) -> ReLU -> dropout; identical to ``conv_bn_relu_drop``.

    NOTE(review): despite the name, dropout with keep probability ``drop`` is still applied.

    :param scope: name scope and variable-name prefix (must be a string).
    :return: activated, dropped-out tensor.
    """
    return conv_bn_relu_drop(x, kernal, phase, drop, image_z=image_z, height=height, width=width, scope=scope)


def down_sampling(x, kernal, phase, drop, image_z=None, height=None, width=None, scope=None):
    """Down-sample with a 3-D convolution called with stride argument 2, then group norm, ReLU and dropout.

    :param x: input tensor for ``conv3d``.
    :param kernal: kernel shape; ``kernal[0..3]`` form the Xavier fan-in, ``kernal[-1]`` is the output channel count.
    :param phase: forwarded to ``normalizationlayer`` as ``is_train``.
    :param drop: keep probability passed to TF1 ``tf.nn.dropout``.
    :param image_z: optional size forwarded to ``normalizationlayer``.
    :param height: optional size forwarded to ``normalizationlayer``.
    :param width: optional size forwarded to ``normalizationlayer``.
    :param scope: name scope; also the prefix of the ``W`` / ``B`` variable names (must be a string).
    :return: activated, dropped-out tensor.
    """
    with tf.name_scope(scope):
        out_channels = kernal[-1]
        kernel_w = weight_xavier_init(shape=kernal,
                                      n_inputs=kernal[0] * kernal[1] * kernal[2] * kernal[3],
                                      n_outputs=out_channels,
                                      activefunction='relu',
                                      variable_name=scope + 'W')
        kernel_b = bias_variable([out_channels], variable_name=scope + 'B')
        strided = conv3d(x, kernel_w, 2) + kernel_b
        strided = normalizationlayer(strided, is_train=phase, height=height, width=width, image_z=image_z,
                                     norm_type='group', scope=scope)
        return tf.nn.dropout(tf.nn.relu(strided), drop)


def deconv_relu(x, kernal, scope=None):
    """Transposed 3-D convolution (``deconv3d``) + bias followed by a single ReLU.

    :param x: input tensor passed to ``deconv3d``.
    :param kernal: transposed-conv kernel shape. ``kernal[-2]`` is the bias size (hence the output channel
        count); ``kernal[0..2]`` and ``kernal[-1]`` form the Xavier fan-in.
    :param scope: name scope; also the prefix of the ``W`` / ``B`` variable names (must be a string).
    :return: ``relu(deconv3d(x, W, True) + B)``.
    """
    with tf.name_scope(scope):
        W = weight_xavier_init(shape=kernal, n_inputs=kernal[0] * kernal[1] * kernal[2] * kernal[-1],
                               n_outputs=kernal[-2], activefunction='relu', variable_name=scope + 'W')
        B = bias_variable([kernal[-2]], variable_name=scope + 'B')
        conv = deconv3d(x, W, True) + B
        # ReLU is idempotent, so the second tf.nn.relu that used to follow was a redundant graph op.
        conv = tf.nn.relu(conv)
        return conv


def conv_sigmod(x, kernal, scope=None):
    """3-D convolution + bias followed by an element-wise sigmoid.

    :param x: input tensor for ``conv3d``.
    :param kernal: kernel shape; ``kernal[0..3]`` form the Xavier fan-in, ``kernal[-1]`` is the output channel count.
    :param scope: name scope; also the prefix of the ``W`` / ``B`` variable names (must be a string).
    :return: sigmoid-activated tensor.
    """
    with tf.name_scope(scope):
        n_out = kernal[-1]
        fan_in = kernal[0] * kernal[1] * kernal[2] * kernal[3]
        w = weight_xavier_init(shape=kernal, n_inputs=fan_in, n_outputs=n_out, activefunction='sigmoid',
                               variable_name=scope + 'W')
        b = bias_variable([n_out], variable_name=scope + 'B')
        logits = conv3d(x, w) + b
        return tf.nn.sigmoid(logits)


def _create_conv_net(X, image_z, image_width, image_height, image_channel, phase, drop, n_class=1):
    """Build the attention V-Net generator graph and return its sigmoid output map.

    Encoder: five residual stages (16/32/64/128/256 channels) joined by stride-2 ``down_sampling`` convs.
    ``SAM_module`` (spatial attention) is applied to encoder stages 2-4 and ``CAM_module`` (channel
    attention) to the bottleneck. Decoder: four ``deconv_relu`` up-sampling steps. Each step is concatenated
    with a ``py_conv*`` pyramid convolution of the matching encoder skip feature. Every concatenation except
    the last is passed through ``SAM_module`` before its residual conv stage.

    :param X: input tensor. It is reshaped to ``[1, image_z, image_width, image_height, image_channel]``,
        so it must contain exactly one volume (batch size 1).
    :param image_z: first spatial size used by that reshape.
    :param image_width: second spatial size used by that reshape.
    :param image_height: third spatial size used by that reshape.
    :param image_channel: channel count used by that reshape and by the first conv kernel.
    :param phase: forwarded to every ``conv_bn_relu_drop`` / ``down_sampling`` (normalization ``is_train``).
    :param drop: dropout keep probability forwarded to the same helpers.
    :param n_class: NOTE(review) currently unused; the output conv is hard-coded to 1 channel.
    :return: output of ``conv_sigmod`` with a 1x1x1 kernel and 1 output channel.
    """
    end_points_collections = 'GNet_End_Points'  # NOTE(review): assigned but never used here
    print(image_z, image_width, image_height, image_channel, phase, drop)
    print(X, '--------------X')
    # Forces batch size 1 and fixes the spatial layout to the argument order above.
    inputX = tf.reshape(X, [1, image_z, image_width, image_height, image_channel])
    # Vnet model
##################################################################################################################

    # ---- Encoder stage 1 (16 channels): two convs, residual add of the first conv's output ----
    layer0 = conv_bn_relu_drop(x=inputX, kernal=(3, 3, 3, image_channel, 16), phase=phase, drop=drop,
                               scope='layer0')
    layer1 = conv_bn_relu_drop(x=layer0, kernal=(3, 3, 3, 16, 16), phase=phase, drop=drop,
                               scope='layer1')
    layer1 = resnet_Add(x1=layer0, x2=layer1)
    print(layer1, '==============layer1')
    # down sampling1 (16 -> 32 channels)
    down1 = down_sampling(x=layer1, kernal=(3, 3, 3, 16, 32), phase=phase, drop=drop, scope='down1')
    # layer_pa = SAM_module(layer1)
    # ---- Encoder stage 2 (32 channels) ----
    layer2 = conv_bn_relu_drop(x=down1, kernal=(3, 3, 3, 32, 32), phase=phase, drop=drop,
                               scope='layer2_1')
    layer2 = conv_bn_relu_drop(x=layer2, kernal=(3, 3, 3, 32, 32), phase=phase, drop=drop,
                               scope='layer2_2')
    layer2 = resnet_Add(x1=down1, x2=layer2)
    print(layer2, '==============layer2')
    # down sampling2 (32 -> 64 channels)
    down2 = down_sampling(x=layer2, kernal=(3, 3, 3, 32, 64), phase=phase, drop=drop, scope='down2')
    # Spatially-attended stage-2 features, used later as a skip connection (via py_conv2).
    layer_pa2 = SAM_module(layer2)
    # ---- Encoder stage 3 (64 channels) ----
    layer3 = conv_bn_relu_drop(x=down2, kernal=(3, 3, 3, 64, 64), phase=phase, drop=drop,
                               scope='layer3_1')
    layer3 = conv_bn_relu_drop(x=layer3, kernal=(3, 3, 3, 64, 64), phase=phase, drop=drop,
                               scope='layer3_2')
    layer3 = conv_bn_relu_drop(x=layer3, kernal=(3, 3, 3, 64, 64), phase=phase, drop=drop,
                               scope='layer3_3')
    layer3 = resnet_Add(x1=down2, x2=layer3)
    print(layer3, '==============layer3')
    # down sampling3 (64 -> 128 channels)
    down3 = down_sampling(x=layer3, kernal=(3, 3, 3, 64, 128), phase=phase, drop=drop, scope='down3')
    # Spatially-attended stage-3 features, skip connection via py_conv3.
    layer_pa3 = SAM_module(layer3)
    # ---- Encoder stage 4 (128 channels) ----
    layer4 = conv_bn_relu_drop(x=down3, kernal=(3, 3, 3, 128, 128), phase=phase, drop=drop,
                               scope='layer4_1')
    layer4 = conv_bn_relu_drop(x=layer4, kernal=(3, 3, 3, 128, 128), phase=phase, drop=drop,
                               scope='layer4_2')
    layer4 = conv_bn_relu_drop(x=layer4, kernal=(3, 3, 3, 128, 128), phase=phase, drop=drop,
                               scope='layer4_3')
    layer4 = resnet_Add(x1=down3, x2=layer4)
    print(layer4, '==============layer4')
    # down sampling4 (128 -> 256 channels)
    down4 = down_sampling(x=layer4, kernal=(3, 3, 3, 128, 256), phase=phase, drop=drop, scope='down4')
    # Spatially-attended stage-4 features, skip connection via py_conv4.
    layer_pa4 = SAM_module(layer4)
    # ---- Bottleneck, stage 5 (256 channels) ----
    layer5 = conv_bn_relu_drop(x=down4, kernal=(3, 3, 3, 256, 256), phase=phase, drop=drop,
                               scope='layer5_1')
    layer5 = conv_bn_relu_drop(x=layer5, kernal=(3, 3, 3, 256, 256), phase=phase, drop=drop,
                               scope='layer5_2')
    layer5 = conv_bn_relu_drop(x=layer5, kernal=(3, 3, 3, 256, 256), phase=phase, drop=drop,
                               scope='layer5_3')
    layer5 = resnet_Add(x1=down4, x2=layer5)
    print(layer5, '==============layer5')

    # Channel attention on the bottleneck.
    layer_ca = CAM_module(layer5)
    # print(layer_pa, '==============layer_pa')
    # ---- Decoder step 1: up-sample the attended bottleneck (bias size 128 = output channels) ----
    # deconv1 = deconv_relu(x=layer5, kernal=(3, 3, 3, 128, 256), scope='deconv1')
    # deconv1 = deconv_relu(x=layer_pa, kernal=(3, 3, 3, 128, 256), scope='deconv1')
    deconv1 = deconv_relu(x=layer_ca, kernal=(3, 3, 3, 128, 256), scope='deconv1')

    # Skip: pyramid conv of attended stage-4 features (128 ch) concatenated with deconv1 (128 ch) -> 256 ch.
    # layer6 = crop_and_concat(layer4, deconv1)
    layer6 = crop_and_concat(py_conv4(layer_pa4), deconv1)
    layer6 = SAM_module(layer6)
    # The normalization layers in this stage get the spatial sizes of encoder stage 4.
    _, Z, H, W, _ = layer4.get_shape().as_list()
    layer6 = conv_bn_relu_drop(x=layer6, kernal=(3, 3, 3, 256, 128), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer6_1')
    layer6 = conv_bn_relu_drop(x=layer6, kernal=(3, 3, 3, 128, 128), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer6_2')
    layer6 = conv_bn_relu_drop(x=layer6, kernal=(3, 3, 3, 128, 128), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer6_3')
    layer6 = resnet_Add(x1=deconv1, x2=layer6)
    # ---- Decoder step 2: up-sample to 64 channels ----
    deconv2 = deconv_relu(x=layer6, kernal=(3, 3, 3, 64, 128), scope='deconv2')
    # Skip: pyramid conv of attended stage-3 features (64 ch) + deconv2 (64 ch) -> 128 ch.
    layer7 = crop_and_concat(py_conv3(layer_pa3), deconv2)
    # layer7 = crop_and_concat(layer_pa3, deconv2)
    print(layer7, '=========================layer7')
    layer7 = SAM_module(layer7)
    _, Z, H, W, _ = layer3.get_shape().as_list()
    layer7 = conv_bn_relu_drop(x=layer7, kernal=(3, 3, 3, 128, 64), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer7_1')
    layer7 = conv_bn_relu_drop(x=layer7, kernal=(3, 3, 3, 64, 64), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer7_2')
    layer7 = resnet_Add(x1=deconv2, x2=layer7)
    # ---- Decoder step 3: up-sample to 32 channels ----
    deconv3 = deconv_relu(x=layer7, kernal=(3, 3, 3, 32, 64), scope='deconv3')
    # Skip: pyramid conv of attended stage-2 features (32 ch) + deconv3 (32 ch) -> 64 ch.
    # layer8 = crop_and_concat(layer2, deconv3)
    layer8 = crop_and_concat(py_conv2(layer_pa2), deconv3)
    layer8 = SAM_module(layer8)
    _, Z, H, W, _ = layer2.get_shape().as_list()
    layer8 = conv_bn_relu_drop(x=layer8, kernal=(3, 3, 3, 64, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer10_1')
    layer8 = conv_bn_relu_drop(x=layer8, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer10_2')
    layer8 = conv_bn_relu_drop(x=layer8, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer10_3')
    layer8 = resnet_Add(x1=deconv3, x2=layer8)
    # ---- Decoder step 4: up-sample to 16 channels ----
    deconv4 = deconv_relu(x=layer8, kernal=(3, 3, 3, 16, 32), scope='deconv4')
    # Skip: pyramid conv of (non-attended) stage-1 features (16 ch) + deconv4 (16 ch) -> 32 ch; no SAM here.
    layer9 = crop_and_concat(py_conv1(layer1), deconv4)
    # layer9 = crop_and_concat(layer_pa, deconv4)
    _, Z, H, W, _ = layer1.get_shape().as_list()
    layer9 = conv_bn_relu_drop(x=layer9, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer11_1')
    layer9 = conv_bn_relu_drop(x=layer9, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer11_2')
    layer9 = conv_bn_relu_drop(x=layer9, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer11_3')
    layer9 = resnet_Add(x1=deconv4, x2=layer9)
    print(layer9, Z, '=========================layer9')
    # ---- Output: 1x1x1 conv + sigmoid; n_class is ignored (output hard-coded to 1 channel) ----
    # output_map = conv_sigmod(x=layer9, kernal=(1, 1, 1, 32, n_class), scope='output')
    output_map = conv_sigmod(x=layer9, kernal=(1, 1, 1, 32, 1), scope='output')
    print(output_map, '=========================output')
    #################################################################################################################

    return output_map


# CAM
def CAM_module(inputs):
    """Channel self-attention (DANet-style CAM) over a 5-D tensor, added residually to the input.

    The input is flattened to (batch, N, C), where N is the product of the three spatial dims. A C x C channel
    affinity ``energy`` is computed and turned into ``rowmax(energy) - energy``. That is softmax-normalised
    along its last axis and used to re-mix the channels.

    :param inputs: 5-D tensor whose static shape is fully known (batch, d1, d2, d3, C).
    :return: tensor with the same shape as ``inputs``.
    """
    inputs_shape = inputs.get_shape().as_list()
    batchsize, height, width, depth, C = inputs_shape[0], inputs_shape[1], inputs_shape[2], inputs_shape[3], \
                                         inputs_shape[4]
    n_voxels = width * height * depth

    flat = tf.reshape(inputs, [batchsize, n_voxels, -1])         # (B, N, C)
    proj_query = tf.transpose(flat, perm=[0, 2, 1])              # (B, C, N)
    proj_key = flat                                              # (B, N, C)
    energy = tf.matmul(proj_query, proj_key)                     # (B, C, C)
    # Per-row maximum (reduction over the last axis), as in DANet's CAM. The element-wise tf.maximum(energy, -1)
    # used before is not a reduction: it compared each entry with the scalar -1, so most logits collapsed to 0.
    energy_new = tf.reduce_max(energy, axis=-1, keepdims=True) - energy

    attention = tf.nn.softmax(energy_new)                        # (B, C, C)
    proj_value = tf.transpose(flat, perm=[0, 2, 1])              # (B, C, N)

    out = tf.transpose(tf.matmul(attention, proj_value), perm=[0, 2, 1])  # (B, N, C)
    out = tf.reshape(out, [batchsize, height, width, depth, C])
    out = out + inputs
    return out


# SAM
def SAM_module(inputs):
    """Spatial (position) self-attention over every voxel of a 5-D tensor, added back onto the input.

    The query and key maps both come from one 1x1x1 projection to ``C // 8`` channels (variable ``weights``).
    The value map comes from a 1x1x1 ``C -> C`` projection (variable ``weights1``). With N voxels, the N x N
    affinity is softmax-normalised along its last axis and used to mix the value vectors.

    NOTE(review): query and key share one filter, so the affinity matrix is symmetric; DANet's position attention
    uses separate projections. Confirm this is intended before changing it, since it would alter checkpoints.
    NOTE(review): memory grows as N^2, with N the product of the three spatial dims.

    :param inputs: 5-D tensor with a fully known static shape (batch, d1, d2, d3, C).
    :return: tensor with the same shape as ``inputs``.
    """
    dims = inputs.get_shape().as_list()
    print(dims)
    batchsize, height, width, depth, C = dims[0], dims[1], dims[2], dims[3], dims[4]
    n_positions = width * height * depth
    unit_strides = [1, 1, 1, 1, 1]

    # Same creation order and names as before, so checkpoint variable names keep matching.
    qk_kernel = tf.Variable(tf.truncated_normal([1, 1, 1, C, C // 8], dtype=tf.float32, stddev=0.1),
                            name='weights')
    v_kernel = tf.Variable(tf.truncated_normal([1, 1, 1, C, C], dtype=tf.float32, stddev=0.1), name='weights1')

    query_maps = tf.nn.conv3d(inputs, qk_kernel, strides=unit_strides, padding='VALID')
    key_maps = tf.nn.conv3d(inputs, qk_kernel, strides=unit_strides, padding='VALID')
    value_maps = tf.nn.conv3d(inputs, v_kernel, strides=unit_strides, padding='VALID')

    queries = tf.reshape(query_maps, [batchsize, n_positions, -1])                              # (B, N, C//8)
    keys_t = tf.transpose(tf.reshape(key_maps, [batchsize, n_positions, -1]), perm=[0, 2, 1])  # (B, C//8, N)
    affinity = tf.nn.softmax(tf.matmul(queries, keys_t))                                        # (B, N, N)
    values = tf.reshape(value_maps, [batchsize, n_positions, -1])                               # (B, N, C)

    attended = tf.reshape(tf.matmul(affinity, values), [batchsize, height, width, depth, C])
    return attended + inputs


def py_conv1(input):
    """Pyramid convolution, level 1: 1x1x1 projection, four large-kernel branches, 1x1x1 fusion.

    The projection output is summed with all branch outputs and then fused by a final 1x1x1 conv. Every
    conv has stride 1 and 'SAME' padding, so the output has the input's shape.
    NOTE(review): kernel extents follow tf.nn.conv3d's [axis1, axis2, axis3, in, out] order. The smallest
    extent (7/5/3/1) therefore lands on input axis 3, while the network's input reshape puts the slice (z)
    axis at axis 1. Confirm this is intended.

    :param input: 5-D tensor with a fully known static shape.
    :return: tensor with the same shape as ``input``.
    """
    shape = input.get_shape().as_list()
    batchsize, height, width, depth, C = shape[0], shape[1], shape[2], shape[3], shape[4]
    unit_strides = [1, 1, 1, 1, 1]

    def _kernel(extent, name):
        return tf.Variable(tf.truncated_normal(list(extent) + [C, C], dtype=tf.float32, stddev=0.1), name=name)

    # Variables are created in the original order with the original names (checkpoint compatibility).
    project_w = _kernel((1, 1, 1), 'kernel_py1_1')
    branch_ws = [_kernel((15, 15, 7), 'kernel_py1_2_1'),
                 _kernel((13, 13, 5), 'kernel_py1_2_2'),
                 _kernel((7, 7, 3), 'kernel_py1_2_3'),
                 _kernel((3, 3, 1), 'kernel_py1_2_4')]
    fuse_w = _kernel((1, 1, 1), 'kernel_py1_3')

    # keep the 1x1x1 convolution layer (its output is also part of the sum below)
    projected = tf.nn.conv3d(input, project_w, strides=unit_strides, padding='SAME')
    branches = [tf.nn.conv3d(projected, w, strides=unit_strides, padding='SAME') for w in branch_ws]
    merged = projected
    for branch in branches:
        merged = merged + branch
    fused = tf.nn.conv3d(merged, fuse_w, strides=unit_strides, padding='SAME')

    return tf.reshape(fused, [batchsize, height, width, depth, C])


def py_conv2(input):
    """Pyramid convolution, level 2: 1x1x1 projection, three branches (13x13x5, 7x7x3, 3x3x1), 1x1x1 fusion.

    The projection output is summed with all branch outputs and then fused by a final 1x1x1 conv. Every
    conv has stride 1 and 'SAME' padding, so the output has the input's shape.
    NOTE(review): kernel extents follow tf.nn.conv3d's [axis1, axis2, axis3, in, out] order. Confirm the
    smallest extent is meant for axis 3.

    :param input: 5-D tensor with a fully known static shape.
    :return: tensor with the same shape as ``input``.
    """
    shape = input.get_shape().as_list()
    batchsize, height, width, depth, C = shape[0], shape[1], shape[2], shape[3], shape[4]
    unit_strides = [1, 1, 1, 1, 1]

    def _kernel(extent, name):
        return tf.Variable(tf.truncated_normal(list(extent) + [C, C], dtype=tf.float32, stddev=0.1), name=name)

    # Variables are created in the original order with the original names (checkpoint compatibility).
    project_w = _kernel((1, 1, 1), 'kernel_py2_1')
    branch_ws = [_kernel((13, 13, 5), 'kernel_py2_2_1'),
                 _kernel((7, 7, 3), 'kernel_py2_2_2'),
                 _kernel((3, 3, 1), 'kernel_py2_2_3')]
    fuse_w = _kernel((1, 1, 1), 'kernel_py2_3')

    # keep the 1x1x1 convolution layer (its output is also part of the sum below)
    projected = tf.nn.conv3d(input, project_w, strides=unit_strides, padding='SAME')
    branches = [tf.nn.conv3d(projected, w, strides=unit_strides, padding='SAME') for w in branch_ws]
    merged = projected
    for branch in branches:
        merged = merged + branch
    fused = tf.nn.conv3d(merged, fuse_w, strides=unit_strides, padding='SAME')

    return tf.reshape(fused, [batchsize, height, width, depth, C])


def py_conv3(input):
    """Pyramid convolution, level 3: 1x1x1 projection, two branches (15x15x7, 13x13x5), 1x1x1 fusion.

    The projection output is summed with both branch outputs and then fused by a final 1x1x1 conv. Every
    conv has stride 1 and 'SAME' padding, so the output has the input's shape.
    NOTE(review): kernel extents follow tf.nn.conv3d's [axis1, axis2, axis3, in, out] order. Confirm the
    smallest extent is meant for axis 3.

    :param input: 5-D tensor with a fully known static shape.
    :return: tensor with the same shape as ``input``.
    """
    shape = input.get_shape().as_list()
    batchsize, height, width, depth, C = shape[0], shape[1], shape[2], shape[3], shape[4]
    unit_strides = [1, 1, 1, 1, 1]

    def _kernel(extent, name):
        return tf.Variable(tf.truncated_normal(list(extent) + [C, C], dtype=tf.float32, stddev=0.1), name=name)

    # Variables are created in the original order with the original names (checkpoint compatibility).
    project_w = _kernel((1, 1, 1), 'kernel_py3_1')
    branch_ws = [_kernel((15, 15, 7), 'kernel_py3_2_1'),
                 _kernel((13, 13, 5), 'kernel_py3_2_2')]
    fuse_w = _kernel((1, 1, 1), 'kernel_py3_3')

    # keep the 1x1x1 convolution layer (its output is also part of the sum below)
    projected = tf.nn.conv3d(input, project_w, strides=unit_strides, padding='SAME')
    branches = [tf.nn.conv3d(projected, w, strides=unit_strides, padding='SAME') for w in branch_ws]
    merged = projected
    for branch in branches:
        merged = merged + branch
    fused = tf.nn.conv3d(merged, fuse_w, strides=unit_strides, padding='SAME')

    return tf.reshape(fused, [batchsize, height, width, depth, C])


def py_conv4(input):
    """Pyramid convolution, level 4: 1x1x1 projection, one 3x3x1 branch, 1x1x1 fusion.

    The projection output is added to the branch output and then fused by a final 1x1x1 conv. Every conv
    has stride 1 and 'SAME' padding, so the output has the input's shape.

    :param input: 5-D tensor with a fully known static shape.
    :return: tensor with the same shape as ``input``.
    """
    shape = input.get_shape().as_list()
    batchsize, height, width, depth, C = shape[0], shape[1], shape[2], shape[3], shape[4]
    unit_strides = [1, 1, 1, 1, 1]

    def _kernel(extent, name):
        return tf.Variable(tf.truncated_normal(list(extent) + [C, C], dtype=tf.float32, stddev=0.1), name=name)

    # Variables are created in the original order with the original names (checkpoint compatibility).
    project_w = _kernel((1, 1, 1), 'kernel_py4_1')
    branch_w = _kernel((3, 3, 1), 'kernel_py4_2_1')
    fuse_w = _kernel((1, 1, 1), 'kernel_py4_3')

    # keep the 1x1x1 convolution layer (its output is also part of the sum below)
    projected = tf.nn.conv3d(input, project_w, strides=unit_strides, padding='SAME')
    branch = tf.nn.conv3d(projected, branch_w, strides=unit_strides, padding='SAME')
    fused = tf.nn.conv3d(projected + branch, fuse_w, strides=unit_strides, padding='SAME')

    return tf.reshape(fused, [batchsize, height, width, depth, C])


# Serve data by batches
def _next_batch(train_images, train_labels, batch_size, index_in_epoch):
    """Return the next mini-batch and the updated read position.

    When the next batch would run past the end of the data, ``train_images`` and ``train_labels`` are shuffled
    IN PLACE with one shared permutation (so pairs stay aligned), and reading restarts at position 0. Trailing
    samples that do not fill a whole batch are skipped for that epoch.

    :param train_images: numpy array of samples along axis 0 (must be writable; shuffled in place).
    :param train_labels: numpy array aligned with ``train_images``; shuffled with the same permutation.
    :param batch_size: samples per batch; must not exceed the number of samples.
    :param index_in_epoch: position returned by the previous call (0 to start).
    :return: ``(images_batch, labels_batch, new_index_in_epoch)``; the batches are copies.
    """
    num_examples = train_images.shape[0]
    start = index_in_epoch
    index_in_epoch += batch_size

    # when all training data has been used, reorder it randomly and start a new epoch
    if index_in_epoch > num_examples:
        assert batch_size <= num_examples
        perm = np.arange(num_examples)
        np.random.shuffle(perm)
        # Write the new order back into the caller's arrays. The caller passes the same arrays on every call,
        # so shuffling only local copies would affect just the first batch of each epoch.
        train_images[...] = train_images[perm]
        train_labels[...] = train_labels[perm]
        start = 0
        index_in_epoch = batch_size
    end = index_in_epoch
    # Copies, so a later in-place shuffle cannot change batches the caller still holds.
    return train_images[start:end].copy(), train_labels[start:end].copy(), index_in_epoch


class Vnet3dModule_64(object):
    """
        3D attention V-Net segmentation model trained with an auxiliary adversarial (discriminator) loss.

        Volumes are fed as (batch, depth, height, width, channels). The generator graph built by
        ``_create_conv_net`` reshapes its input to batch size 1, so one volume is processed per step.

        :param image_height: number of height in the input image
        :param image_width: number of width in the input image
        :param image_depth: number of depth in the input image
        :param channels: number of channels in the input image
        :param costname: name of the cost function. Default (and only supported value) is "dice coefficient"
        :param inference: if True, open a session, initialise variables and restore ``model_path``
        :param model_path: checkpoint restored when ``inference`` is True
    """

    def __init__(self, image_height, image_width, image_depth, channels=1, costname="dice coefficient", inference=False,
                 model_path=None):
        self.image_width = image_width
        self.image_height = image_height
        self.image_depth = image_depth
        self.channels = channels

        self.X = tf.placeholder("float", shape=[None, self.image_depth, self.image_height, self.image_width,
                                                self.channels])
        self.Y_gt = tf.placeholder("float", shape=[None, self.image_depth, self.image_height, self.image_width,
                                                   self.channels])
        self.weight = 1  # segmentation weight, meant to depend on GT size (not used in this class)

        self.lr = tf.placeholder('float')
        self.phase = tf.placeholder(tf.bool)
        self.drop = tf.placeholder('float')  # dropout keep probability
        self.is_training = tf.placeholder(tf.bool, name='is_training')

        # _create_conv_net reshapes X to [1, <2nd arg>, <3rd arg>, <4th arg>, channels]. Passing height before
        # width makes that reshape match the placeholder layout (depth, height, width). The former width-first
        # order scrambled the rows of non-square slices, and the X * Y_pred products below could not broadcast.
        self.Y_pred = _create_conv_net(self.X, self.image_depth, self.image_height, self.image_width, self.channels,
                                       self.phase, self.drop)

        self.accuracy = self.__get_accuracy("dice coefficient")

        global_step_D = tf.Variable(tf.constant(0))
        lr_d = 0.001

        # The discriminator sees the image masked by the predicted (fake) or ground-truth (real) segmentation.
        # NOTE(review): both masks are divided by 255 even though Y_pred is a sigmoid output and train() feeds
        # Y_gt already scaled to [0, 1]. Kept unchanged so trained models behave the same.
        D_score_fake, D_sigmoid_fake = build_D(tf.multiply(self.X, self.Y_pred/255))
        D_score_real, D_sigmoid_real = build_D(tf.multiply(self.X, self.Y_gt/255), reuse=True)
        Loss_D_fake = sig_loss(D_score_fake, False)
        Loss_D_real = sig_loss(D_score_real, True)
        self.Loss_D = Loss_D_fake + Loss_D_real
        tf.summary.scalar('loss_d', self.Loss_D)

        # Generator objective: segmentation loss plus a small adversarial term rewarding "real" scores on Y_pred.
        self.cost = self.__get_cost(costname) + 0.01 * sig_loss(D_score_fake, True)
        all_vars = tf.trainable_variables()
        d_vars = [var for var in all_vars if 'FCDiscriminator' in var.name]
        print("d_vars:", len(d_vars))
        self.train_op_D = update_optim(self.Loss_D, lr_d, d_vars, global_step_D)

        if inference:
            init = tf.global_variables_initializer()
            saver = tf.train.Saver()
            self.sess = tf.InteractiveSession()
            self.sess.run(init)
            saver.restore(self.sess, model_path)

    def _dice_coefficient(self):
        """Build the batch-mean soft Dice coefficient between ``self.Y_pred`` and ``self.Y_gt``."""
        Z, H, W, C = self.Y_gt.get_shape().as_list()[1:]
        print('Z', Z, 'H', H, 'W', W, 'C', C, 'gt_shape----', self.Y_gt)
        smooth = 1e-5
        pred_flat = tf.reshape(self.Y_pred, [-1, H * W * C * Z])
        true_flat = tf.reshape(self.Y_gt, [-1, H * W * C * Z])
        intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=1) + smooth  # 2 * sum(GT * prediction)
        denominator = tf.reduce_sum(pred_flat, axis=1) + tf.reduce_sum(true_flat, axis=1) + smooth  # GT + prediction
        return tf.reduce_mean(intersection / denominator)

    def __get_cost(self, cost_name):
        """Return the segmentation loss tensor (1 - Dice).

        :raises ValueError: if ``cost_name`` is not "dice coefficient".
        """
        if cost_name == "dice coefficient":
            return 1 - self._dice_coefficient()
        raise ValueError("Unsupported cost name: %r" % (cost_name,))

    def __get_accuracy(self, cost_name):
        """Return the Dice coefficient tensor used as the accuracy metric.

        :raises ValueError: if ``cost_name`` is not "dice coefficient".
        """
        if cost_name == "dice coefficient":
            return self._dice_coefficient()
        raise ValueError("Unsupported accuracy name: %r" % (cost_name,))

    def _load_volume_batch(self, batch_xs_path, batch_ys_path):
        """Read image/label PNG slices into (batch, depth, height, width, channels) float arrays scaled to [0, 1].

        ``batch_xs_path[num][0]`` and ``batch_ys_path[num][0]`` are directories containing ``0.png``, ``1.png``,
        ... read as grayscale. At most ``self.image_depth`` slices are read, one per directory entry other than
        ``Thumbs.db``. Slices that are not read stay zero.
        """
        volume_shape = (self.image_depth, self.image_height, self.image_width, self.channels)
        slice_shape = (self.image_height, self.image_width, self.channels)
        # np.zeros rather than np.empty, so unread slices are black instead of uninitialised memory.
        batch_xs = np.zeros((len(batch_xs_path),) + volume_shape)  # images
        batch_ys = np.zeros((len(batch_ys_path),) + volume_shape)  # labels
        for num in range(len(batch_xs_path)):  # loop over the batch
            index = 0  # slice (depth) index
            for entry in os.listdir(batch_xs_path[num][0]):  # one slice per directory entry
                if index >= self.image_depth:  # previously a hard-coded 16
                    break
                if entry == 'Thumbs.db':
                    continue
                image = cv2.imread(batch_xs_path[num][0] + '/' + str(index) + '.png', cv2.IMREAD_GRAYSCALE)
                label = cv2.imread(batch_ys_path[num][0] + '/' + str(index) + '.png', cv2.IMREAD_GRAYSCALE)
                batch_xs[num, index, :, :, :] = np.reshape(image, slice_shape)
                batch_ys[num, index, :, :, :] = np.reshape(label, slice_shape)
                index += 1

        # Normalize from [0:255] => [0.0:1.0]
        batch_xs = np.multiply(batch_xs, 1.0 / 255.0)
        batch_ys = np.multiply(batch_ys, 1.0 / 255.0)
        return batch_xs, batch_ys

    def train(self, train_images, train_labels, train_images_val, train_labels_val, model_path, logs_path,
              learning_rate, dropout_conv, train_steps=50000, batch_size=1):
        """Alternately train the generator and the discriminator, writing summaries and checkpoints.

        :param train_images: array passed to ``_next_batch``; each entry's ``[0]`` is an image-slice directory.
        :param train_labels: matching array of label-slice directories.
        :param train_images_val: currently unused.
        :param train_labels_val: currently unused.
        :param model_path: checkpoint path prefix for ``saver.save``.
        :param logs_path: TensorBoard log directory.
        :param learning_rate: base generator learning rate. After ``len_train_data`` steps the rate used is
            ``learning_rate * 0.9 ** (step / len_train_data)``.
        :param dropout_conv: dropout keep probability fed to ``self.drop``.
        :param train_steps: number of training iterations.
        :param batch_size: samples per step.
        """
        # Generator step: only non-discriminator variables. Without var_list, Adam would also move the
        # discriminator towards scoring Y_pred as real, undermining the adversarial signal.
        g_vars = [var for var in tf.trainable_variables() if 'FCDiscriminator' not in var.name]
        train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cost, var_list=g_vars)

        len_train_data = 5000  # steps per decay period and checkpoint interval
        last_checkpoint_period = 18

        init = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=30)
        tf.summary.scalar("loss", self.cost)
        tf.summary.scalar("accuracy", self.accuracy)
        merged_summary_op = tf.summary.merge_all()
        sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
        graph = tf.get_default_graph()
        summary_writer = tf.summary.FileWriter(logs_path, graph=graph)
        # Lock the graph so any node accidentally added during training raises an error.
        # (Graph.finalize() returns None, so its result must not be passed to FileWriter as the graph.)
        graph.finalize()
        sess.run(init)
        index_in_epoch = 0
        for i in range(train_steps):
            batch_xs_path, batch_ys_path, index_in_epoch = _next_batch(train_images, train_labels, batch_size,
                                                                       index_in_epoch)
            batch_xs, batch_ys = self._load_volume_batch(batch_xs_path, batch_ys_path)

            # Decay relative to the base rate. Updating learning_rate in place compounded on every step
            # and drove it to ~0 within a few hundred steps.
            if i > len_train_data:
                current_lr = learning_rate * 0.9 ** (i / len_train_data)
            else:
                current_lr = learning_rate

            print('train step:', str(i), '------------------------------------------')

            _, summary = sess.run([train_op, merged_summary_op],  # train the generator
                                  feed_dict={self.X: batch_xs,
                                             self.Y_gt: batch_ys,
                                             self.lr: current_lr,
                                             self.phase: 1,
                                             self.drop: dropout_conv})

            # Train the discriminator. phase is fed too, because Loss_D depends on Y_pred, whose normalization
            # layers receive self.phase (feeding it is harmless if unused).
            _, train_loss_d = sess.run([self.train_op_D, self.Loss_D],
                                       feed_dict={self.X: batch_xs, self.Y_gt: batch_ys, self.is_training: True,
                                                  self.phase: 1, self.lr: current_lr, self.drop: dropout_conv})

            # Checkpoint at step 0 and every len_train_data steps, up to len_train_data * 18.
            if i % len_train_data == 0 and i <= len_train_data * last_checkpoint_period:
                save_path = saver.save(sess, model_path, global_step=i)
                print("Model saved in file:" + save_path)

            summary_writer.add_summary(summary, i)
        summary_writer.close()

        save_path = saver.save(sess, model_path)
        print("Model saved in file:", save_path)

    def prediction(self, test_images):
        """Segment one volume with the restored model (requires ``inference=True`` at construction).

        :param test_images: array with three leading dims, reshaped to (d0, d1, d2, 1) and divided by 255.
        :return: uint8 array of shape (d0, d1, d2, 1): the predicted map scaled by 255 and clipped to [0, 255].
        """
        test_images = np.reshape(test_images, (test_images.shape[0], test_images.shape[1], test_images.shape[2], 1))
        # np.float64 instead of the removed np.float alias (same dtype)
        test_images = test_images.astype(np.float64)
        test_images = np.multiply(test_images, 1.0 / 255.0)
        y_dummy = test_images  # Y_gt must be fed; Y_pred is computed from X only
        pred = self.sess.run(self.Y_pred, feed_dict={self.X: [test_images],
                                                     self.Y_gt: [y_dummy],
                                                     self.phase: 1,
                                                     self.drop: 1})
        result = pred.astype(np.float32) * 255.
        result = np.clip(result, 0, 255).astype('uint8')
        result = np.reshape(result, (test_images.shape[0], test_images.shape[1], test_images.shape[2], 1))
        return result


'''
    D_Net
'''


def build_D(inputs, reuse=False):
    '''
    Fully-convolutional discriminator.

    inputs: G's output or label (a 5-D tensor, as required by slim.conv3d).
    reuse: passed to every slim.conv3d so that a second call shares the first call's variables.

    Returns (score, sigmoid). score holds the classifier logits of the first slice along axis 1, bilinearly
    resized to the input's axis-2/axis-3 size; sigmoid is tf.nn.sigmoid(score).
    '''
    ndf = 64
    with tf.variable_scope('FCDiscriminator'):
        # Four stride-2 3x3x3 convs with leaky ReLU; the shape comments are for the author's example input size.
        conv1 = slim.conv3d(inputs, ndf, 3, 2, activation_fn=leaky_relu, reuse=reuse, scope='conv1')
        print(conv1, '------------conv1')  # (1, 16, 32, 32, 64)
        conv2 = slim.conv3d(conv1, ndf * 2, 3, 2, activation_fn=leaky_relu, reuse=reuse, scope='conv2')
        print(conv2, '------------conv2')  # (1, 8, 16, 16, 128)
        conv3 = slim.conv3d(conv2, ndf * 4, 3, 2, activation_fn=leaky_relu, reuse=reuse, scope='conv3')
        print(conv3, '------------conv3')  # (1, 4, 8, 8, 256)
        conv4 = slim.conv3d(conv3, ndf * 8, 3, 2, activation_fn=leaky_relu, reuse=reuse, scope='conv4')
        print(conv4, '------------conv4')  # (1, 2, 4, 4, 512)
        classifier = slim.conv3d(conv4, 1, 3, 1, activation_fn=None, reuse=reuse, scope='classifier')
        print(classifier, '------------classifier')  # classifier shape: (1, 2, 4, 4, 1)
        # Only the first slice along axis 1 is scored (a 16-depth test setup), then resized to the input's size.
        # (A tf.placeholder previously created here was overwritten immediately and has been removed.)
        score = upsample(classifier[:, 0, :, :, :], tf.shape(inputs)[2:4])
        sigmoid = tf.nn.sigmoid(score)
        print(score, '------------score')
        print(sigmoid, '------------sigmoid')
    return score, sigmoid
# upsample() input layout: (batch, height, width, channels)


def upsample(bottom, size):
    """Bilinearly resize ``bottom`` (4-D: batch, height, width, channels) to ``size`` = (height, width)."""
    return tf.image.resize_bilinear(bottom, size=size)


def leaky_relu(x, alpha=0.2):
    """Leaky ReLU as the element-wise maximum of ``x`` and ``alpha * x``."""
    scaled = alpha * x
    return tf.maximum(x, scaled)


def update_optim(loss, decay_lr, var_list, global_step=None):
    """Build an Adam training op minimising ``loss`` over ``var_list``.

    The op is created under control dependencies on the UPDATE_OPS collection, so those ops run first.
    If ``global_step`` is given, minimize() increments it each time the returned op runs.
    """
    print(loss, '-----------------update opt input loss')
    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        adam = tf.train.AdamOptimizer(decay_lr)
        step_op = adam.minimize(loss, var_list=var_list, global_step=global_step)
    return step_op


def update_lr(learning_rate, learning_rate_decay_steps, learning_rate_decay_rate, global_step):
    '''
    Staircase exponential decay tensor:
    learning_rate * learning_rate_decay_rate ** floor(global_step / learning_rate_decay_steps).
    '''
    return tf.train.exponential_decay(learning_rate=learning_rate,
                                      global_step=global_step,
                                      decay_steps=learning_rate_decay_steps,
                                      decay_rate=learning_rate_decay_rate,
                                      staircase=True)


def sig_loss(logits, true_label=False):
    """Mean sigmoid (binary) cross-entropy of ``logits`` against an all-ones or all-zeros target.

    :param logits: discriminator logits.
    :param true_label: truthy -> target is all ones ("real"); falsy -> all zeros ("fake").
    :return: scalar mean loss tensor.
    """
    if not true_label:
        labels = tf.zeros_like(logits, dtype=tf.float32)  # all-zeros target
        print(labels, '-----------------labels in sig loss, false')
    else:
        labels = tf.ones_like(logits, dtype=tf.float32)  # all-ones target
        print(labels, '-----------------labels in sig loss, true')
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(per_element)

