import tensorflow as tf
from tensorflow.python.framework import dtypes
import numpy as np
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib
def freeze_from_checkpoint():
    """Freeze the latest checkpoint into a single .pb GraphDef.

    Combines the graph structure (.pbtxt) with the latest checkpoint's
    variable values and writes a frozen, self-contained protobuf.

    Side effects: reads the checkpoint directory and graph_node.pbtxt,
    writes the frozen model to ``output_name``. Returns None.
    """
    # Most recent checkpoint in the training output directory.
    path = tf.train.latest_checkpoint(r"D:\Result_RE_Revenge101_checkpoint\\")
    input_graph_path = r"D:\Result_RE_Revenge101_checkpoint\graph_node.pbtxt"  # graph definition (text format)
    # Node(s) to keep; everything not reachable from these is pruned.
    output_nodes = "generator1/decoder_1/Tanh"
    # BUG FIX: TensorFlow op/tensor names use forward slashes, not backslashes.
    # The originals (r"save\restore_all", r"save\Const:0") are not valid node names.
    restore_op = "save/restore_all"
    filename_tensor = "save/Const:0"
    output_name = r"D:\Result_RE_Revenge101_checkpoint\pruning101_step11999.pb"  # frozen model destination
    # Positional args: input_graph, input_saver, input_binary, input_checkpoint,
    # output_node_names, restore_op_name, filename_tensor_name, output_graph,
    # clear_devices, initializer_nodes.
    freeze_graph.freeze_graph(input_graph_path, "", False, path, output_nodes,
                              restore_op, filename_tensor, output_name, True, "")
def optimize_frozen_file():
    """Optimize a frozen graph for inference and save the result.

    The optimization pass performs (per optimize_for_inference_lib):
    - Removing training-only operations like checkpoint saving.
    - Stripping out parts of the graph that are never reached.
    - Removing debug operations like CheckNumerics.
    - Folding batch normalization ops into the pre-calculated weights.
    - Fusing common operations into unified versions.

    Note: don't use a placeholder as a training switch, otherwise the
    batch-normalization folding step will error out.

    Side effects: reads the frozen .pb, writes the optimized .pb.
    Returns None.
    """
    input_graph_def = tf.GraphDef()
    # NOTE(review): this path differs from the one freeze_from_checkpoint()
    # writes — confirm the intended frozen-model location before running.
    frozen_graph_filename = "D:/result201910072_gpu_checkpoint/frozen_model.pb"
    with tf.gfile.Open(frozen_graph_filename, "rb") as f:
        input_graph_def.ParseFromString(f.read())
    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        ["input_image"],                  # input node names
        ["generator1/decoder_1/Tanh"],    # output node names
        dtypes.float32.as_datatype_enum)
    # BUG FIX: protobuf serialization is binary — the original opened the file
    # in text mode ("w") and never closed the handle. Use "wb" + context manager.
    with tf.gfile.FastGFile('D:/result201910111_gpu_checkpoint/OptimizedGraph.pb', "wb") as f:
        f.write(output_graph_def.SerializeToString())
def load_graph():
    """Load the optimized frozen graph from disk into a new tf.Graph.

    Imported ops keep the default "import/" name prefix (e.g.
    "import/input_image:0").

    :return: a tf.Graph containing the imported GraphDef.
    """
    frozen_filename = "D:/result201910111_gpu_checkpoint/OptimizedGraph.pb"
    # Deserialize the protobuf bytes into a GraphDef.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_filename, "rb") as handle:
        serialized = handle.read()
    graph_def.ParseFromString(serialized)
    # Import into a fresh graph so existing default-graph state is untouched.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def)
    return graph
def childs(t, d=0):
    """Print tensor ``t`` and, recursively, the inputs of its producing op.

    Each level of the traversal is indented with ``d`` dashes, giving a
    simple ASCII tree of the tensor's ancestry.

    :param t: a tensor-like object exposing ``.name`` and ``.op.inputs``
    :param d: current recursion depth (number of leading dashes)
    """
    indent = '-' * d
    print(indent, t.name)
    # Walk each input of the op that produced this tensor, one level deeper.
    for parent in t.op.inputs:
        childs(parent, d + 1)
if __name__ == '__main__':
    # Pipeline: freeze the latest checkpoint, then optimize the frozen graph.
    freeze_from_checkpoint()
    optimize_frozen_file()
    # Example inference with the optimized graph (disabled):
    # graph = load_graph()
    # x = graph.get_tensor_by_name("import/input_image:0")
    # pred = graph.get_tensor_by_name("import/generator1/decoder_1/Tanh:0")
    # with tf.Session(graph=graph) as sess:
    #     input_data = img
    #     y = sess.run(pred, feed_dict={x: input_data})
    #     print(y)