import sys
sys.path.append('./Vnet/')
from promise2012.Vnet.layer import (conv3d, deconv3d, normalizationlayer, crop_and_concat, resnet_Add,
weight_xavier_init, bias_variable, save_images, conv2d)
import tensorflow as tf
import numpy as np
import cv2
import os
from tensorflow.contrib import slim
from math import e
import gc
from keras import backend as K
def conv_bn_relu_drop(x, kernal, phase, drop, image_z=None, height=None, width=None, scope=None):
    """3-D convolution + bias -> group normalization -> ReLU -> dropout.

    :param x: input tensor passed to ``conv3d``.
    :param kernal: kernel shape (kd, kh, kw, in_channels, out_channels); the
        product of the first four entries is used as fan-in.
    :param phase: forwarded as ``is_train`` to ``normalizationlayer``.
    :param drop: second positional argument of ``tf.nn.dropout`` (keep_prob in TF1).
    :param image_z: depth size forwarded to the normalization layer.
    :param height: height forwarded to the normalization layer.
    :param width: width forwarded to the normalization layer.
    :param scope: name scope; also the prefix of the weight/bias variable names,
        so it must be a string.
    :return: the activated, dropped-out tensor.
    """
    with tf.name_scope(scope):
        out_channels = kernal[-1]
        fan_in = kernal[0] * kernal[1] * kernal[2] * kernal[3]
        weights = weight_xavier_init(shape=kernal, n_inputs=fan_in, n_outputs=out_channels,
                                     activefunction='relu', variable_name=scope + 'conv_W')
        bias = bias_variable([out_channels], variable_name=scope + 'conv_B')
        normed = normalizationlayer(conv3d(x, weights) + bias, is_train=phase, height=height, width=width,
                                    image_z=image_z, norm_type='group', scope=scope)
        activated = tf.nn.relu(normed)
        return tf.nn.dropout(activated, drop)
def conv_bn_relu(x, kernal, phase, drop, image_z=None, height=None, width=None, scope=None):
    """3-D convolution + bias -> group normalization -> ReLU -> dropout.

    Despite its name this layer also applies dropout: it builds exactly the
    same graph (same variable names, same ops) as ``conv_bn_relu_drop``, to
    which it delegates. Pass ``drop=1`` (keep probability) for no dropout.

    :param kernal: kernel shape (kd, kh, kw, in_channels, out_channels).
    :param scope: name scope and variable-name prefix (must be a string).
    """
    return conv_bn_relu_drop(x=x, kernal=kernal, phase=phase, drop=drop, image_z=image_z, height=height,
                             width=width, scope=scope)
def down_sampling(x, kernal, phase, drop, image_z=None, height=None, width=None, scope=None):
    """Stride-2 3-D convolution + bias -> group normalization -> ReLU -> dropout.

    :param x: input tensor.
    :param kernal: kernel shape (kd, kh, kw, in_channels, out_channels).
    :param phase: forwarded as ``is_train`` to ``normalizationlayer``.
    :param drop: second positional argument of ``tf.nn.dropout`` (keep_prob in TF1).
    :param image_z: depth size forwarded to the normalization layer.
    :param height: height forwarded to the normalization layer.
    :param width: width forwarded to the normalization layer.
    :param scope: name scope and prefix of the 'W'/'B' variable names (must be a string).
    :return: the down-sampled feature map.
    """
    with tf.name_scope(scope):
        n_out = kernal[-1]
        n_in = kernal[0] * kernal[1] * kernal[2] * kernal[3]
        weights = weight_xavier_init(shape=kernal, n_inputs=n_in, n_outputs=n_out,
                                     activefunction='relu', variable_name=scope + 'W')
        bias = bias_variable([n_out], variable_name=scope + 'B')
        strided = conv3d(x, weights, 2) + bias  # third argument: stride 2
        normed = normalizationlayer(strided, is_train=phase, height=height, width=width, image_z=image_z,
                                    norm_type='group', scope=scope)
        return tf.nn.dropout(tf.nn.relu(normed), drop)
def deconv_relu(x, kernal, scope=None):
    """Transposed 3-D convolution + bias followed by a single ReLU.

    :param x: input tensor.
    :param kernal: transposed-conv kernel shape (kd, kh, kw, out_channels,
        in_channels): fan-in is taken from ``kernal[-1]`` and the bias size and
        fan-out from ``kernal[-2]``.
    :param scope: name scope and prefix of the 'W'/'B' variable names (must be a string).
    :return: the up-sampled, ReLU-activated tensor.
    """
    with tf.name_scope(scope):
        W = weight_xavier_init(shape=kernal, n_inputs=kernal[0] * kernal[1] * kernal[2] * kernal[-1],
                               n_outputs=kernal[-2], activefunction='relu', variable_name=scope + 'W')
        B = bias_variable([kernal[-2]], variable_name=scope + 'B')
        # The third argument (True) is passed through to layer.deconv3d unchanged.
        conv = deconv3d(x, W, True) + B
        # A single ReLU: the former second application was a no-op (relu is idempotent).
        conv = tf.nn.relu(conv)
        return conv
def conv_sigmod(x, kernal, scope=None):
    """3-D convolution + bias followed by a sigmoid (used as the output layer).

    :param x: input tensor.
    :param kernal: kernel shape (kd, kh, kw, in_channels, out_channels).
    :param scope: name scope and prefix of the 'W'/'B' variable names (must be a string).
    :return: tensor with values in (0, 1).
    """
    with tf.name_scope(scope):
        channels_out = kernal[-1]
        weights = weight_xavier_init(shape=kernal, n_inputs=kernal[0] * kernal[1] * kernal[2] * kernal[3],
                                     n_outputs=channels_out, activefunction='sigmoid',
                                     variable_name=scope + 'W')
        bias = bias_variable([channels_out], variable_name=scope + 'B')
        logits = conv3d(x, weights) + bias
        return tf.nn.sigmoid(logits)
def _create_conv_net(X, image_z, image_width, image_height, image_channel, phase, drop, n_class=1):
    """Build the residual 3-D V-Net generator and return its sigmoid output map.

    Encoder: five levels with 16, 32, 64, 128 and 256 channels. Each level is a
    stack of ``conv_bn_relu_drop`` layers closed by a residual add. Levels are
    separated by stride-2 ``down_sampling`` convolutions.

    Decoder: four ``deconv_relu`` up-samplings. At each level the up-sampled
    tensor is concatenated with the matching encoder output, convolved, and
    residually added to the up-sampled tensor. At three of the four levels the
    encoder output first passes through ``py_conv4`` / ``py_conv2`` / ``py_conv1``.

    A final 1x1x1 ``conv_sigmod`` layer produces ``n_class`` output channels.

    :param X: input holding exactly ONE volume laid out as
        (batch, depth, height, width, channels). The batch is fixed to 1 because
        the ``py_conv*`` helpers reshape using the static batch size.
    :param image_z: depth of the volume.
    :param image_width: width of the volume.
    :param image_height: height of the volume.
    :param image_channel: number of input channels.
    :param phase: training flag forwarded to the normalization layers.
    :param drop: dropout keep probability forwarded to every conv block.
    :param n_class: number of output channels (default 1).
    :return: sigmoid output tensor with ``n_class`` channels.
    """
    print(image_z, image_width, image_height, image_channel, phase, drop)
    print(X, '--------------X')
    # Follow the placeholder's (depth, height, width) order. The previous
    # (depth, width, height) order silently relabelled the spatial axes for
    # non-square inputs.
    inputX = tf.reshape(X, [1, image_z, image_height, image_width, image_channel])

    # ----------------------------- encoder -----------------------------
    # level 1
    layer0 = conv_bn_relu_drop(x=inputX, kernal=(3, 3, 3, image_channel, 16), phase=phase, drop=drop,
                               scope='layer0')
    layer1 = conv_bn_relu_drop(x=layer0, kernal=(3, 3, 3, 16, 16), phase=phase, drop=drop,
                               scope='layer1')
    layer1 = resnet_Add(x1=layer0, x2=layer1)
    print(layer1, '==============layer1')
    # level 2
    down1 = down_sampling(x=layer1, kernal=(3, 3, 3, 16, 32), phase=phase, drop=drop, scope='down1')
    layer2 = conv_bn_relu_drop(x=down1, kernal=(3, 3, 3, 32, 32), phase=phase, drop=drop,
                               scope='layer2_1')
    layer2 = conv_bn_relu_drop(x=layer2, kernal=(3, 3, 3, 32, 32), phase=phase, drop=drop,
                               scope='layer2_2')
    layer2 = resnet_Add(x1=down1, x2=layer2)
    print(layer2, '==============layer2')
    # level 3
    down2 = down_sampling(x=layer2, kernal=(3, 3, 3, 32, 64), phase=phase, drop=drop, scope='down2')
    layer3 = conv_bn_relu_drop(x=down2, kernal=(3, 3, 3, 64, 64), phase=phase, drop=drop,
                               scope='layer3_1')
    layer3 = conv_bn_relu_drop(x=layer3, kernal=(3, 3, 3, 64, 64), phase=phase, drop=drop,
                               scope='layer3_2')
    layer3 = conv_bn_relu_drop(x=layer3, kernal=(3, 3, 3, 64, 64), phase=phase, drop=drop,
                               scope='layer3_3')
    layer3 = resnet_Add(x1=down2, x2=layer3)
    print(layer3, '==============layer3')
    # level 4
    down3 = down_sampling(x=layer3, kernal=(3, 3, 3, 64, 128), phase=phase, drop=drop, scope='down3')
    layer4 = conv_bn_relu_drop(x=down3, kernal=(3, 3, 3, 128, 128), phase=phase, drop=drop,
                               scope='layer4_1')
    layer4 = conv_bn_relu_drop(x=layer4, kernal=(3, 3, 3, 128, 128), phase=phase, drop=drop,
                               scope='layer4_2')
    layer4 = conv_bn_relu_drop(x=layer4, kernal=(3, 3, 3, 128, 128), phase=phase, drop=drop,
                               scope='layer4_3')
    layer4 = resnet_Add(x1=down3, x2=layer4)
    print(layer4, '==============layer4')
    # level 5 (bottleneck)
    down4 = down_sampling(x=layer4, kernal=(3, 3, 3, 128, 256), phase=phase, drop=drop, scope='down4')
    layer5 = conv_bn_relu_drop(x=down4, kernal=(3, 3, 3, 256, 256), phase=phase, drop=drop,
                               scope='layer5_1')
    layer5 = conv_bn_relu_drop(x=layer5, kernal=(3, 3, 3, 256, 256), phase=phase, drop=drop,
                               scope='layer5_2')
    layer5 = conv_bn_relu_drop(x=layer5, kernal=(3, 3, 3, 256, 256), phase=phase, drop=drop,
                               scope='layer5_3')
    layer5 = resnet_Add(x1=down4, x2=layer5)
    print(layer5, '==============layer5')

    # ----------------------------- decoder -----------------------------
    # The construction order below is kept as-is so that variable names (and
    # hence existing checkpoints) stay compatible.
    # up to level 4
    deconv1 = deconv_relu(x=layer5, kernal=(3, 3, 3, 128, 256), scope='deconv1')
    layer6 = crop_and_concat(py_conv4(layer4), deconv1)
    _, Z, H, W, _ = layer4.get_shape().as_list()
    layer6 = conv_bn_relu_drop(x=layer6, kernal=(3, 3, 3, 256, 128), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer6_1')
    layer6 = conv_bn_relu_drop(x=layer6, kernal=(3, 3, 3, 128, 128), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer6_2')
    layer6 = conv_bn_relu_drop(x=layer6, kernal=(3, 3, 3, 128, 128), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer6_3')
    layer6 = resnet_Add(x1=deconv1, x2=layer6)
    # up to level 3
    deconv2 = deconv_relu(x=layer6, kernal=(3, 3, 3, 64, 128), scope='deconv2')
    layer7 = crop_and_concat(layer3, deconv2)
    print(layer7, '=========================layer7')
    _, Z, H, W, _ = layer3.get_shape().as_list()
    layer7 = conv_bn_relu_drop(x=layer7, kernal=(3, 3, 3, 128, 64), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer7_1')
    layer7 = conv_bn_relu_drop(x=layer7, kernal=(3, 3, 3, 64, 64), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer7_2')
    layer7 = resnet_Add(x1=deconv2, x2=layer7)
    # up to level 2
    deconv3 = deconv_relu(x=layer7, kernal=(3, 3, 3, 32, 64), scope='deconv3')
    layer8 = crop_and_concat(py_conv2(layer2), deconv3)
    _, Z, H, W, _ = layer2.get_shape().as_list()
    layer8 = conv_bn_relu_drop(x=layer8, kernal=(3, 3, 3, 64, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer10_1')
    layer8 = conv_bn_relu_drop(x=layer8, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer10_2')
    layer8 = conv_bn_relu_drop(x=layer8, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer10_3')
    layer8 = resnet_Add(x1=deconv3, x2=layer8)
    # up to level 1
    deconv4 = deconv_relu(x=layer8, kernal=(3, 3, 3, 16, 32), scope='deconv4')
    layer9 = crop_and_concat(py_conv1(layer1), deconv4)
    _, Z, H, W, _ = layer1.get_shape().as_list()
    layer9 = conv_bn_relu_drop(x=layer9, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer11_1')
    layer9 = conv_bn_relu_drop(x=layer9, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer11_2')
    layer9 = conv_bn_relu_drop(x=layer9, kernal=(3, 3, 3, 32, 32), image_z=Z, height=H, width=W, phase=phase,
                               drop=drop, scope='layer11_3')
    layer9 = resnet_Add(x1=deconv4, x2=layer9)
    print(layer9, Z, '=========================layer9')

    # ----------------------------- output ------------------------------
    output_map = conv_sigmod(x=layer9, kernal=(1, 1, 1, 32, n_class), scope='output')
    print(output_map, '=========================output')
    return output_map
# CAM
def CAM_module(inputs):
    """Channel attention module (CAM, after DANet) for 5-D feature maps.

    The feature map is flattened to (B, N, C) with N = d1 * d2 * d3, and a
    (C x C) channel-affinity matrix is computed. DANet then subtracts each
    affinity from its row maximum before the softmax, so that less similar
    channels receive more weight. The resulting attention re-weights the
    channels, and the input is added back as a residual.

    Requires a fully static 5-D input shape (the batch size is used in reshape).

    :param inputs: 5-D tensor (batch, d1, d2, d3, C).
    :return: tensor of the same shape.
    """
    batchsize, d1, d2, d3, C = inputs.get_shape().as_list()
    n_positions = d1 * d2 * d3
    flat = tf.reshape(inputs, [batchsize, n_positions, -1])     # (B, N, C)
    channels_first = tf.transpose(flat, perm=[0, 2, 1])          # (B, C, N)
    energy = tf.matmul(channels_first, flat)                     # (B, C, C)
    # Row-wise max (the old tf.maximum(energy, -1) was an element-wise max
    # against the scalar -1, which zeroed almost all energies).
    energy_new = tf.reduce_max(energy, axis=-1, keepdims=True) - energy
    attention = tf.nn.softmax(energy_new)
    out = tf.transpose(tf.matmul(attention, channels_first), perm=[0, 2, 1])  # (B, N, C)
    out = tf.reshape(out, [batchsize, d1, d2, d3, C])
    return out + inputs
# SAM
def SAM_module(inputs):
    """Position (spatial) attention module (SAM/PAM, after DANet) for 5-D maps.

    1x1x1 convolutions produce query and key projections with max(C // 8, 1)
    channels, and a value projection with C channels. An (N x N) affinity over
    all N = d1 * d2 * d3 positions is softmax-normalised and used to aggregate
    the values. The input is added back as a residual.

    The affinity matrix has N**2 entries, so this is only practical on small
    feature maps. Requires a fully static 5-D input shape.

    :param inputs: 5-D tensor (batch, d1, d2, d3, C).
    :return: tensor of the same shape.
    """
    inputs_shape = inputs.get_shape().as_list()
    print(inputs_shape)
    batchsize, d1, d2, d3, C = inputs_shape
    n_positions = d1 * d2 * d3
    reduced = max(C // 8, 1)  # C // 8 would be 0 for C < 8
    query_weights = tf.Variable(tf.truncated_normal([1, 1, 1, C, reduced], dtype=tf.float32, stddev=0.1),
                                name='weights')
    value_weights = tf.Variable(tf.truncated_normal([1, 1, 1, C, C], dtype=tf.float32, stddev=0.1),
                                name='weights1')
    # Separate key projection. Reusing the query weights made key == query,
    # which reduced the affinity to a symmetric Q Q^T.
    key_weights = tf.Variable(tf.truncated_normal([1, 1, 1, C, reduced], dtype=tf.float32, stddev=0.1),
                              name='weights_key')
    strides = [1, 1, 1, 1, 1]
    query_conv = tf.nn.conv3d(inputs, query_weights, strides=strides, padding='VALID')
    key_conv = tf.nn.conv3d(inputs, key_weights, strides=strides, padding='VALID')
    value_conv = tf.nn.conv3d(inputs, value_weights, strides=strides, padding='VALID')
    proj_query = tf.reshape(query_conv, [batchsize, n_positions, -1])                            # (B, N, C')
    proj_key = tf.transpose(tf.reshape(key_conv, [batchsize, n_positions, -1]), perm=[0, 2, 1])  # (B, C', N)
    energy = tf.matmul(proj_query, proj_key)                                                     # (B, N, N)
    attention = tf.nn.softmax(energy)
    proj_value = tf.reshape(value_conv, [batchsize, n_positions, -1])                            # (B, N, C)
    out = tf.matmul(attention, proj_value)                                                       # (B, N, C)
    out = tf.reshape(out, [batchsize, d1, d2, d3, C])
    return out + inputs
def py_conv1(input):
    """Pyramid (multi-kernel) 3-D convolution with four branches.

    Steps:
    1. A 1x1x1 convolution projects the input.
    2. Four parallel convolutions are applied to that projection, with kernel
       extents 15x15x7, 13x13x5, 7x7x3 and 3x3x1 (in tf.nn.conv3d filter
       order: depth, height, width).
    3. The projection plus all branch outputs are summed.
    4. The sum is fused by a final 1x1x1 convolution.

    Every kernel keeps C channels, stride 1 and 'SAME' padding, and none has a
    bias. All kernels are fresh ``tf.Variable`` objects drawn from a truncated
    normal with stddev 0.1.

    NOTE(review): in _create_conv_net axis 1 of the feature map is the slice (z)
    axis, so the largest extent (15) runs along z and the smallest (7) along the
    last spatial axis. Confirm this is intended.

    :param input: 5-D tensor with a fully static shape (batch, d1, d2, d3, C).
    :return: tensor of the same shape.
    """
    dims = input.get_shape().as_list()
    b, d1, d2, d3, C = [dims[k] for k in range(5)]
    unit_stride = [1, 1, 1, 1, 1]

    def new_kernel(extent, var_name):
        return tf.Variable(tf.truncated_normal(list(extent) + [C, C], dtype=tf.float32, stddev=0.1),
                           name=var_name)

    entry_kernel = new_kernel((1, 1, 1), 'kernel_py1_1')
    branch_kernels = [new_kernel(extent, var_name) for extent, var_name in (
        ((15, 15, 7), 'kernel_py1_2_1'),
        ((13, 13, 5), 'kernel_py1_2_2'),
        ((7, 7, 3), 'kernel_py1_2_3'),
        ((3, 3, 1), 'kernel_py1_2_4'),
    )]
    fuse_kernel = new_kernel((1, 1, 1), 'kernel_py1_3')

    # The 1x1x1 projection is kept and also contributes to the sum.
    projected = tf.nn.conv3d(input, entry_kernel, strides=unit_stride, padding='SAME')
    branches = [tf.nn.conv3d(projected, k, strides=unit_stride, padding='SAME') for k in branch_kernels]
    summed = projected
    for branch in branches:
        summed = summed + branch
    fused = tf.nn.conv3d(summed, fuse_kernel, strides=unit_stride, padding='SAME')
    return tf.reshape(fused, [b, d1, d2, d3, C])
def py_conv2(input):
    """Pyramid (multi-kernel) 3-D convolution with three branches.

    Steps:
    1. A 1x1x1 convolution projects the input.
    2. Three parallel convolutions are applied to that projection, with kernel
       extents 13x13x5, 7x7x3 and 3x3x1 (in tf.nn.conv3d filter order: depth,
       height, width).
    3. The projection plus all branch outputs are summed.
    4. The sum is fused by a final 1x1x1 convolution.

    Every kernel keeps C channels, stride 1 and 'SAME' padding, and none has a
    bias.

    NOTE(review): axis 1 of the feature map is the slice (z) axis in
    _create_conv_net, so the first kernel extent runs along z. Confirm this is
    intended.

    :param input: 5-D tensor with a fully static shape (batch, d1, d2, d3, C).
    :return: tensor of the same shape.
    """
    dims = input.get_shape().as_list()
    b, d1, d2, d3, C = [dims[k] for k in range(5)]
    unit_stride = [1, 1, 1, 1, 1]

    def new_kernel(extent, var_name):
        return tf.Variable(tf.truncated_normal(list(extent) + [C, C], dtype=tf.float32, stddev=0.1),
                           name=var_name)

    entry_kernel = new_kernel((1, 1, 1), 'kernel_py2_1')
    branch_kernels = [new_kernel(extent, var_name) for extent, var_name in (
        ((13, 13, 5), 'kernel_py2_2_1'),
        ((7, 7, 3), 'kernel_py2_2_2'),
        ((3, 3, 1), 'kernel_py2_2_3'),
    )]
    fuse_kernel = new_kernel((1, 1, 1), 'kernel_py2_3')

    # The 1x1x1 projection is kept and also contributes to the sum.
    projected = tf.nn.conv3d(input, entry_kernel, strides=unit_stride, padding='SAME')
    branches = [tf.nn.conv3d(projected, k, strides=unit_stride, padding='SAME') for k in branch_kernels]
    summed = projected
    for branch in branches:
        summed = summed + branch
    fused = tf.nn.conv3d(summed, fuse_kernel, strides=unit_stride, padding='SAME')
    return tf.reshape(fused, [b, d1, d2, d3, C])
def py_conv3(input):
    """Pyramid (multi-kernel) 3-D convolution with two branches.

    Steps:
    1. A 1x1x1 convolution projects the input.
    2. Two parallel convolutions are applied to that projection, with kernel
       extents 15x15x7 and 13x13x5 (in tf.nn.conv3d filter order: depth,
       height, width).
    3. The projection plus both branch outputs are summed.
    4. The sum is fused by a final 1x1x1 convolution.

    Every kernel keeps C channels, stride 1 and 'SAME' padding, and none has a
    bias.

    NOTE(review): axis 1 of the feature map is the slice (z) axis in
    _create_conv_net, so the first kernel extent runs along z. Confirm this is
    intended.

    :param input: 5-D tensor with a fully static shape (batch, d1, d2, d3, C).
    :return: tensor of the same shape.
    """
    dims = input.get_shape().as_list()
    b, d1, d2, d3, C = [dims[k] for k in range(5)]
    unit_stride = [1, 1, 1, 1, 1]

    def new_kernel(extent, var_name):
        return tf.Variable(tf.truncated_normal(list(extent) + [C, C], dtype=tf.float32, stddev=0.1),
                           name=var_name)

    entry_kernel = new_kernel((1, 1, 1), 'kernel_py3_1')
    branch_kernels = [new_kernel(extent, var_name) for extent, var_name in (
        ((15, 15, 7), 'kernel_py3_2_1'),
        ((13, 13, 5), 'kernel_py3_2_2'),
    )]
    fuse_kernel = new_kernel((1, 1, 1), 'kernel_py3_3')

    # The 1x1x1 projection is kept and also contributes to the sum.
    projected = tf.nn.conv3d(input, entry_kernel, strides=unit_stride, padding='SAME')
    branches = [tf.nn.conv3d(projected, k, strides=unit_stride, padding='SAME') for k in branch_kernels]
    summed = projected
    for branch in branches:
        summed = summed + branch
    fused = tf.nn.conv3d(summed, fuse_kernel, strides=unit_stride, padding='SAME')
    return tf.reshape(fused, [b, d1, d2, d3, C])
def py_conv4(input):
    """Pyramid 3-D convolution with a single 3x3x1 branch.

    Steps:
    1. A 1x1x1 convolution projects the input.
    2. A 3x3x1 convolution (in tf.nn.conv3d filter order: depth, height,
       width) is applied to that projection.
    3. The projection and the branch output are summed.
    4. The sum is fused by a final 1x1x1 convolution.

    Every kernel keeps C channels, stride 1 and 'SAME' padding, and none has a
    bias.

    NOTE(review): axis 1 of the feature map is the slice (z) axis in
    _create_conv_net, so this kernel spans 3 slices x 3 rows x 1 column.
    Confirm this is intended.

    :param input: 5-D tensor with a fully static shape (batch, d1, d2, d3, C).
    :return: tensor of the same shape.
    """
    dims = input.get_shape().as_list()
    b, d1, d2, d3, C = [dims[k] for k in range(5)]
    unit_stride = [1, 1, 1, 1, 1]

    def new_kernel(extent, var_name):
        return tf.Variable(tf.truncated_normal(list(extent) + [C, C], dtype=tf.float32, stddev=0.1),
                           name=var_name)

    entry_kernel = new_kernel((1, 1, 1), 'kernel_py4_1')
    branch_kernel = new_kernel((3, 3, 1), 'kernel_py4_2_1')
    fuse_kernel = new_kernel((1, 1, 1), 'kernel_py4_3')

    # The 1x1x1 projection is kept and also contributes to the sum.
    projected = tf.nn.conv3d(input, entry_kernel, strides=unit_stride, padding='SAME')
    branch = tf.nn.conv3d(projected, branch_kernel, strides=unit_stride, padding='SAME')
    fused = tf.nn.conv3d(projected + branch, fuse_kernel, strides=unit_stride, padding='SAME')
    return tf.reshape(fused, [b, d1, d2, d3, C])
# Serve data by batches
def _next_batch(train_images, train_labels, batch_size, index_in_epoch):
    """Return the next mini-batch and the updated read position.

    Batches are taken sequentially from ``train_images``/``train_labels``.
    When the next batch would run past the end of the data, a new epoch starts:
    1. Both arrays are shuffled IN PLACE with the same permutation. The
       caller's arrays are reordered, so the whole new epoch follows the
       shuffled order. Previously only the local names were rebound, which
       shuffled just the first batch of each epoch.
    2. Reading restarts at 0.

    The returned batches are copies, so a later in-place shuffle cannot alter
    a batch the caller still holds.

    :param train_images: numpy array of samples; axis 0 indexes samples.
    :param train_labels: numpy array aligned with ``train_images`` along axis 0.
    :param batch_size: number of samples per batch.
    :param index_in_epoch: position returned by the previous call (0 initially).
    :return: (images_batch, labels_batch, new_index_in_epoch)
    :raises ValueError: if ``batch_size`` exceeds the number of samples.
    """
    num_examples = train_images.shape[0]
    if batch_size > num_examples:
        raise ValueError('batch_size (%d) exceeds the number of examples (%d)' % (batch_size, num_examples))
    start = index_in_epoch
    end = index_in_epoch + batch_size
    if end > num_examples:
        # All training data have been used: reorder randomly and start the next epoch.
        perm = np.arange(num_examples)
        np.random.shuffle(perm)
        train_images[...] = train_images[perm]
        train_labels[...] = train_labels[perm]
        start = 0
        end = batch_size
    return train_images[start:end].copy(), train_labels[start:end].copy(), end
class Vnet3dModule_64(object):
    """3-D V-Net segmentation model trained with an auxiliary adversarial loss.

    The generator is the network built by ``_create_conv_net``. A
    discriminator built by ``build_D`` (variables under the ``FCDiscriminator``
    scope) learns to separate ``X * Y_gt / 255`` (real) from ``X * Y_pred / 255``
    (fake). The generator minimises the Dice loss plus 0.01 times the
    adversarial "fool the discriminator" loss.

    :param image_height: number of height in the input image
    :param image_width: number of width in the input image
    :param image_depth: number of slices (depth) in the input volume
    :param channels: number of channels in the input image
    :param costname: name of the cost function; only "dice coefficient" is supported
    :param inference: if True, open a session and restore weights from ``model_path``
    :param model_path: checkpoint restored when ``inference`` is True
    """

    def __init__(self, image_height, image_width, image_depth, channels=1, costname="dice coefficient",
                 inference=False, model_path=None):
        self.image_width = image_width
        self.image_height = image_height
        self.image_depth = image_depth
        self.channels = channels
        # Volumes are fed as (batch, depth, height, width, channels).
        self.X = tf.placeholder("float", shape=[None, self.image_depth, self.image_height, self.image_width,
                                                self.channels])
        self.Y_gt = tf.placeholder("float", shape=[None, self.image_depth, self.image_height, self.image_width,
                                                   self.channels])
        self.weight = 1  # segmentation weight, meant to depend on the GT size (not used in this class)
        self.lr = tf.placeholder('float')
        self.phase = tf.placeholder(tf.bool)
        self.drop = tf.placeholder('float')
        self.is_training = tf.placeholder(tf.bool, name='is_training')
        self.Y_pred = _create_conv_net(self.X, self.image_depth, self.image_width, self.image_height,
                                       self.channels, self.phase, self.drop)
        self.accuracy = self.__get_accuracy("dice coefficient")

        # Discriminator on image * mask, for the predicted (fake) and ground-truth (real) masks.
        global_step_D = tf.Variable(tf.constant(0))
        lr_d = 0.001
        D_score_fake, _ = build_D(tf.multiply(self.X, self.Y_pred / 255))
        D_score_real, _ = build_D(tf.multiply(self.X, self.Y_gt / 255), reuse=True)
        Loss_D_fake = sig_loss(D_score_fake, False)
        Loss_D_real = sig_loss(D_score_real, True)
        self.Loss_D = Loss_D_fake + Loss_D_real
        tf.summary.scalar('loss_d', self.Loss_D)
        # Generator objective: Dice loss plus a small reward for fooling the discriminator.
        self.cost = self.__get_cost(costname) + 0.01 * sig_loss(D_score_fake, True)
        all_vars = tf.trainable_variables()
        d_vars = [var for var in all_vars if 'FCDiscriminator' in var.name]
        print("d_vars:", len(d_vars))
        self.train_op_D = update_optim(self.Loss_D, lr_d, d_vars, global_step_D)
        if inference:
            init = tf.global_variables_initializer()
            saver = tf.train.Saver()
            self.sess = tf.InteractiveSession()
            self.sess.run(init)
            saver.restore(self.sess, model_path)

    def _dice_coefficient(self):
        """Return the batch-mean soft Dice coefficient between ``Y_pred`` and ``Y_gt``.

        Both tensors are flattened per sample; ``smooth`` avoids 0/0 on empty masks.
        """
        Z, H, W, C = self.Y_gt.get_shape().as_list()[1:]
        print('Z', Z, 'H', H, 'W', W, 'C', C, 'gt_shape----', self.Y_gt)
        smooth = 1e-5
        pred_flat = tf.reshape(self.Y_pred, [-1, H * W * C * Z])
        true_flat = tf.reshape(self.Y_gt, [-1, H * W * C * Z])
        intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=1) + smooth  # 2 * |GT and prediction|
        denominator = tf.reduce_sum(pred_flat, axis=1) + tf.reduce_sum(true_flat, axis=1) + smooth
        return tf.reduce_mean(intersection / denominator)

    def __get_cost(self, cost_name):
        """Return the Dice loss ``1 - dice``.

        :raises ValueError: for any cost name other than "dice coefficient".
        """
        if cost_name != "dice coefficient":
            raise ValueError("unsupported cost function: %r" % (cost_name,))
        return 1 - self._dice_coefficient()

    def __get_accuracy(self, cost_name):
        """Return the Dice coefficient used as the accuracy metric.

        :raises ValueError: for any metric name other than "dice coefficient".
        """
        if cost_name != "dice coefficient":
            raise ValueError("unsupported accuracy metric: %r" % (cost_name,))
        return self._dice_coefficient()

    def _load_batch(self, batch_xs_path, batch_ys_path):
        """Read a batch of image/label volumes from slice folders and scale them to [0, 1].

        ``batch_xs_path[n][0]`` / ``batch_ys_path[n][0]`` are folders holding
        grayscale slices named ``0.png``, ``1.png``, ... . Slices are read in
        index order. The number read equals the count of non-``Thumbs.db``
        entries in the image folder, capped at ``image_depth``. Slots that are
        not read stay zero.

        :raises IOError: if a slice file cannot be read.
        """
        volume_shape = (self.image_depth, self.image_height, self.image_width, self.channels)
        slice_shape = (self.image_height, self.image_width, self.channels)
        batch_xs = np.zeros((len(batch_xs_path),) + volume_shape)  # images
        batch_ys = np.zeros((len(batch_ys_path),) + volume_shape)  # labels
        for num in range(len(batch_xs_path)):
            image_dir = batch_xs_path[num][0]
            label_dir = batch_ys_path[num][0]
            n_slices = sum(1 for name in os.listdir(image_dir) if name != 'Thumbs.db')
            for index in range(min(n_slices, self.image_depth)):
                image_file = image_dir + '/' + str(index) + '.png'
                label_file = label_dir + '/' + str(index) + '.png'
                image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
                label = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
                if image is None or label is None:
                    raise IOError('cannot read slice %s or %s' % (image_file, label_file))
                batch_xs[num, index] = np.reshape(image, slice_shape)
                batch_ys[num, index] = np.reshape(label, slice_shape)
        # Normalize from [0:255] => [0.0:1.0]
        batch_xs = np.multiply(batch_xs, 1.0 / 255.0)
        batch_ys = np.multiply(batch_ys, 1.0 / 255.0)
        return batch_xs, batch_ys

    def train(self, train_images, train_labels, train_images_val, train_labels_val, model_path, logs_path,
              learning_rate, dropout_conv, train_steps=50000, batch_size=1):
        """Alternate one generator step and one discriminator step per batch.

        ``train_images``/``train_labels`` are arrays whose rows hold, at index 0,
        the folder of an image / label volume (see ``_load_batch``).
        ``train_images_val``/``train_labels_val`` are accepted but currently unused.

        After step ``len_train_data`` the learning rate is
        ``learning_rate * 0.9 ** (i / len_train_data)``, always computed from the
        base rate. Checkpoints are written at step 0 and at selected multiples of
        ``len_train_data``, and once more after training.
        """
        # The generator step must not update the discriminator. Otherwise the
        # adversarial term in self.cost also trains D to call fakes real.
        g_vars = [var for var in tf.trainable_variables() if 'FCDiscriminator' not in var.name]
        train_op = tf.train.AdamOptimizer(self.lr).minimize(self.cost, var_list=g_vars)
        len_train_data = 5000
        checkpoint_steps = {0} | {len_train_data * k for k in (1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20)}
        init = tf.global_variables_initializer()
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=30)
        tf.summary.scalar("loss", self.cost)
        tf.summary.scalar("accuracy", self.accuracy)
        merged_summary_op = tf.summary.merge_all()
        sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
        graph = tf.get_default_graph()
        summary_writer = tf.summary.FileWriter(logs_path, graph=graph)
        # Lock the graph so that any op accidentally added inside the loop raises.
        # Graph.finalize() returns None, so it must not be passed as FileWriter's graph.
        graph.finalize()
        sess.run(init)
        index_in_epoch = 0
        for i in range(train_steps):
            batch_xs_path, batch_ys_path, index_in_epoch = _next_batch(train_images, train_labels, batch_size,
                                                                       index_in_epoch)
            batch_xs, batch_ys = self._load_batch(batch_xs_path, batch_ys_path)
            # Decay relative to the base rate. Reassigning learning_rate here
            # compounded the decay every step and drove it to ~0.
            current_lr = learning_rate
            if i > len_train_data:
                current_lr = learning_rate * 0.9 ** (i / len_train_data)
            print('train step:', str(i), '------------------------------------------')
            # Generator step.
            _, summary = sess.run([train_op, merged_summary_op],
                                  feed_dict={self.X: batch_xs,
                                             self.Y_gt: batch_ys,
                                             self.lr: current_lr,
                                             self.phase: 1,
                                             self.drop: dropout_conv})
            # Discriminator step. phase is fed like in the generator step, since
            # Y_pred is part of the discriminator's input.
            _, train_loss_d = sess.run([self.train_op_D, self.Loss_D],
                                       feed_dict={self.X: batch_xs, self.Y_gt: batch_ys, self.is_training: True,
                                                  self.phase: 1, self.lr: current_lr, self.drop: dropout_conv})
            if i in checkpoint_steps:
                save_path = saver.save(sess, model_path, global_step=i)
                print("Model saved in file:" + save_path)
            summary_writer.add_summary(summary, i)
        summary_writer.close()
        save_path = saver.save(sess, model_path)
        print("Model saved in file:", save_path)

    def prediction(self, test_images):
        """Segment one volume with the restored model (requires ``inference=True``).

        :param test_images: array reshaped to (d0, d1, d2, 1) with values in [0, 255].
        :return: uint8 array of shape (d0, d1, d2, 1), the sigmoid output scaled to [0, 255].
        """
        test_images = np.reshape(test_images, (test_images.shape[0], test_images.shape[1], test_images.shape[2], 1))
        test_images = test_images.astype(np.float64)  # np.float was removed in NumPy 1.24
        test_images = np.multiply(test_images, 1.0 / 255.0)
        y_dummy = test_images  # dummy ground truth; Y_pred itself does not use it
        pred = self.sess.run(self.Y_pred, feed_dict={self.X: [test_images],
                                                     self.Y_gt: [y_dummy],
                                                     self.phase: 1,
                                                     self.drop: 1})
        result = pred.astype(np.float32) * 255.
        result = np.clip(result, 0, 255).astype('uint8')
        result = np.reshape(result, (test_images.shape[0], test_images.shape[1], test_images.shape[2], 1))
        return result
# ---------------------------------------------------------------------------
# D_Net: discriminator used for the adversarial loss
# ---------------------------------------------------------------------------
def build_D(inputs, reuse=False):
    """Fully-convolutional 3-D discriminator returning per-pixel real/fake logits.

    Architecture:
    - Four stride-2 3x3x3 convolutions with leaky ReLU (64, 128, 256 and 512
      filters).
    - A stride-1 3x3x3 convolution with one output channel and no activation.
    - Depth slice 0 of that classifier map is bilinearly resized to the size
      of axes 2 and 3 of ``inputs``.

    NOTE(review): only the first depth slice of the classifier output is scored
    (the original code marked this as a "16-depth test"). Confirm this is intended.

    :param inputs: G's output or the label (multiplied with the image);
        5-D tensor (batch, depth, d2, d3, channels).
    :param reuse: pass True on the second call so that both calls share the
        ``FCDiscriminator`` variables.
    :return: (score, sigmoid) — logits of shape (batch, d2, d3, 1) and their sigmoid.
    """
    ndf = 64
    print(inputs, '------------inputs')
    with tf.variable_scope('FCDiscriminator'):
        conv1 = slim.conv3d(inputs, ndf, 3, 2, activation_fn=leaky_relu, reuse=reuse, scope='conv1')
        print(conv1, '------------conv1')
        conv2 = slim.conv3d(conv1, ndf * 2, 3, 2, activation_fn=leaky_relu, reuse=reuse, scope='conv2')
        print(conv2, '------------conv2')
        conv3 = slim.conv3d(conv2, ndf * 4, 3, 2, activation_fn=leaky_relu, reuse=reuse, scope='conv3')
        print(conv3, '------------conv3')
        conv4 = slim.conv3d(conv3, ndf * 8, 3, 2, activation_fn=leaky_relu, reuse=reuse, scope='conv4')
        print(conv4, '------------conv4')
        classifier = slim.conv3d(conv4, 1, 3, 1, activation_fn=None, reuse=reuse, scope='classifier')
        print(classifier, '------------classifier')
        # The never-fed tf.placeholder that used to be created here (and immediately
        # overwritten) has been removed.
        score = upsample(classifier[:, 0, :, :, :], tf.shape(inputs)[2:4])
        sigmoid = tf.nn.sigmoid(score)
        print(score, '------------score')
        print(sigmoid, '------------sigmoid')
        return score, sigmoid
def upsample(bottom, size):
    """Bilinearly resize a 4-D (batch, height, width, channels) tensor.

    :param bottom: tensor to resize.
    :param size: target (height, width), a 1-D int tensor or pair.
    :return: the resized tensor.
    """
    return tf.image.resize_bilinear(images=bottom, size=size)
def leaky_relu(x, alpha=0.2):
    """Leaky ReLU computed as the element-wise max of ``x`` and ``alpha * x``."""
    scaled = alpha * x
    return tf.maximum(x, scaled)
def update_optim(loss, decay_lr, var_list, global_step=None):
    """Build an Adam training op for ``loss`` restricted to ``var_list``.

    The op depends on every op in the ``UPDATE_OPS`` collection, so those run
    before each step.

    :param loss: scalar loss tensor.
    :param decay_lr: Adam learning rate (float or tensor).
    :param var_list: variables to update.
    :param global_step: optional counter variable that ``minimize`` increments
        by one per step.
    :return: the training op.
    """
    print(loss, '-----------------update opt input loss')
    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        adam = tf.train.AdamOptimizer(decay_lr)
        return adam.minimize(loss, var_list=var_list, global_step=global_step)
def update_lr(learning_rate, learning_rate_decay_steps, learning_rate_decay_rate, global_step):
    """Return a staircase exponentially-decayed learning-rate tensor.

    The rate is ``learning_rate * decay_rate ** floor(global_step / decay_steps)``.
    It follows ``global_step`` as that variable is incremented elsewhere.
    """
    return tf.train.exponential_decay(learning_rate=learning_rate,
                                      global_step=global_step,
                                      decay_steps=learning_rate_decay_steps,
                                      decay_rate=learning_rate_decay_rate,
                                      staircase=True)
def sig_loss(logits, true_label=False):
    """Mean sigmoid binary cross-entropy of ``logits`` against a constant target.

    :param logits: logit tensor.
    :param true_label: if truthy the target is all ones ("real"), otherwise all zeros ("fake").
    :return: scalar loss tensor.
    """
    if not true_label:
        labels = tf.zeros_like(logits, dtype=tf.float32)  # all-zero target
        print(labels, '-----------------labels in sig loss, false')
    else:
        labels = tf.ones_like(logits, dtype=tf.float32)  # all-one target
        print(labels, '-----------------labels in sig loss, true')
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(per_element)