import tensorflow as tf
from tensorflow.python.framework import dtypes
from PIL import Image
import numpy as np
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

# # load one image
# test_image_dir = "D:/test13/pb_test/20180626030939.jpg"  # test image path for testing whether the model optimization works
# img = Image.open(test_image_dir)
# img_ndarray = np.array(img, dtype='uint8')
#
# print(img_ndarray.shape)
# img = img_ndarray.reshape((1, 256, 256, 3))
# print(img)


def freeze_from_checkpoint(
        checkpoint_dir="D:/resultAREinProcess10_gpu_checkpoint/",
        input_graph_path="D:/resultAREinProcess10_gpu_checkpoint/graph_node.pbtxt",
        output_node_names="generator1/decoder_1/Tanh",
        output_graph_path="D:/resultAREinProcess10_gpu_checkpoint/AREinProcess2_step8100.pb"):
    """Freeze the latest checkpoint in `checkpoint_dir` into a single .pb file.

    Combines the graph definition (`input_graph_path`, a text .pbtxt) with the
    variable values from the most recent checkpoint, writing a self-contained
    GraphDef to `output_graph_path`.

    :param checkpoint_dir: directory containing the TF checkpoint files.
    :param input_graph_path: path to the graph definition (.pbtxt, text format).
    :param output_node_names: comma-separated names of the output node(s) to keep.
    :param output_graph_path: destination path for the frozen .pb model.
    :raises FileNotFoundError: if no checkpoint exists in `checkpoint_dir`.
    """
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    # latest_checkpoint returns None when the directory has no checkpoint;
    # fail fast here instead of letting freeze_graph raise a confusing error.
    if checkpoint_path is None:
        raise FileNotFoundError("no checkpoint found in %s" % checkpoint_dir)
    freeze_graph.freeze_graph(
        input_graph=input_graph_path,
        input_saver="",
        input_binary=False,               # the .pbtxt graph is text, not binary
        input_checkpoint=checkpoint_path,
        output_node_names=output_node_names,
        restore_op_name="save/restore_all",
        filename_tensor_name="save/Const:0",
        output_graph=output_graph_path,
        clear_devices=True,               # strip device pins so the model loads anywhere
        initializer_nodes="")


def optimize_frozen_file(
        frozen_graph_filename="D:/result201910072_gpu_checkpoint/frozen_model.pb",
        optimized_graph_filename="D:/result201910111_gpu_checkpoint/OptimizedGraph.pb"):
    """Optimize a frozen GraphDef for inference and save it to a new .pb file.

    The optimization pass performs, among others:
    - Removing training-only operations like checkpoint saving.
    - Stripping out parts of the graph that are never reached.
    - Removing debug operations like CheckNumerics.
    - Folding batch normalization ops into the pre-calculated weights.
    - Fusing common operations into unified versions.

    NOTE: don't use a placeholder as a training switch, otherwise folding
    batch normalization will fail.

    :param frozen_graph_filename: path of the frozen model produced by freeze_graph.
    :param optimized_graph_filename: destination path for the optimized model.
    :return: None; the optimized graph is written to `optimized_graph_filename`.
    """
    input_graph = tf.GraphDef()
    with tf.gfile.Open(frozen_graph_filename, "rb") as f:
        input_graph.ParseFromString(f.read())

    output_graph = optimize_for_inference_lib.optimize_for_inference(
        input_graph,
        ["input_image"],                  # input node(s) of the graph
        ["generator1/decoder_1/Tanh"],    # output node(s) of the graph
        dtypes.float32.as_datatype_enum)

    # "wb" (binary) is required: SerializeToString() returns bytes, and the
    # original text-mode "w" could corrupt the protobuf. Using `with` also
    # guarantees the output file is flushed and closed.
    with tf.gfile.GFile(optimized_graph_filename, "wb") as out:
        out.write(output_graph.SerializeToString())


def load_graph(frozen_filename="D:/result201910111_gpu_checkpoint/OptimizedGraph.pb"):
    """Load a serialized GraphDef from disk into a fresh tf.Graph.

    Node names are imported under the default "import/" prefix, so a tensor
    is retrieved as e.g. "import/input_image:0".

    :param frozen_filename: path to the frozen/optimized .pb model.
    :return: a tf.Graph containing the imported graph definition.
    """
    with tf.gfile.GFile(frozen_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)
    return graph


def childs(t, d=0):
    """Recursively print the input tree of tensor `t`.

    Each tensor's name is printed indented by `d` dashes, then every input
    of its producing op is visited one level deeper.
    """
    indent = '-' * d
    print(indent, t.name)
    for parent in t.op.inputs:
        childs(parent, d + 1)


if __name__ == '__main__':

    # Step 1: freeze the latest checkpoint into a single .pb file.
    # Steps 2-3 (optimize + smoke-test inference) are kept below as a
    # commented-out usage example; uncomment them to run the full pipeline.
    freeze_from_checkpoint()
    # optimize_frozen_file()
    #
    # graph = load_graph()
    # x = graph.get_tensor_by_name("import/input_image:0")
    #
    # pred = graph.get_tensor_by_name("import/generator1/decoder_1/Tanh:0")

    # with tf.Session(graph=graph) as sess:
    #     input_data = img
    #     y = sess.run(pred, feed_dict={x: input_data})
    #     print(y)
