import os
import tensorflow as tf
from PIL import Image
import numpy as np
import pandas as pd

# Directory holding the source images, one sub-directory per class name.
# os.path.join keeps this portable (the old "\\image\\test" suffix only
# worked on Windows).
orig_picture = os.path.join(os.getcwd(), 'image', 'test')
# Directory for generated images (not referenced by create_record or
# read_and_decode).
gen_picture = os.path.join(os.getcwd(), 'image')
# Class names to recognize; each must be a sub-directory of orig_picture.
# NOTE: this is a set, and string hashing is randomized per interpreter run,
# so its iteration order is not stable -- sort it before deriving labels.
classes = {'0', '1'}
# Total number of samples (presumably the image count under orig_picture;
# not referenced by the functions in this file -- confirm with callers).
num_samples = 40


# Build the TFRecords dataset from the class sub-directories of orig_picture.
def create_record(filename="test.tfrecords"):
    """Serialize every image under ``orig_picture/<class>/`` into a TFRecords file.

    Each image is converted to 3-channel RGB and resized to 32x32, then
    written as a ``tf.train.Example`` with two features:

    * ``label``   -- int64 index of the image's class in ``sorted(classes)``
    * ``img_raw`` -- the raw pixel bytes from ``Image.tobytes()``
      (32 * 32 * 3 bytes, matching the reshape in ``read_and_decode``)

    Args:
        filename: Path of the TFRecords file to write. Defaults to
            ``"test.tfrecords"``.
    """
    # ``classes`` is a set of strings; its iteration order can differ between
    # interpreter runs (hash randomization), which would silently swap labels.
    # Sorting makes the class -> label mapping reproducible.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for index, name in enumerate(sorted(classes)):
            class_path = os.path.join(orig_picture, name)
            # Sorted so the record order is reproducible as well.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if not os.path.isfile(img_path):
                    continue  # skip nested directories
                with Image.open(img_path) as src:
                    # Force RGB so grayscale/RGBA/palette inputs still yield
                    # exactly 32*32*3 bytes. For grayscale records, convert
                    # to "L" instead and reshape to [32, 32, 1] when reading.
                    img = src.convert("RGB").resize((32, 32))
                img_raw = img.tobytes()  # raw pixel bytes
                example = tf.train.Example(
                    features=tf.train.Features(feature={
                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                    }))
                writer.write(example.SerializeToString())


# =======================================================================================
def read_and_decode(filename, is_batch):
    """Build TF1 queue-based input ops that read (image, label) from TFRecords.

    Args:
        filename: Path to a TFRecords file whose examples carry an int64
            ``label`` feature and a string ``img_raw`` feature.
        is_batch: If truthy, return shuffled mini-batches of 3 examples;
            otherwise return a single example.

    Returns:
        A tuple ``(img, label)``. ``img`` is uint8, reshaped to [32, 32, 3]
        (with a leading batch dimension of 3 when batched); ``label`` is int32.

    Note:
        The ops come from ``tf.train.string_input_producer`` and
        ``tf.train.shuffle_batch``, which TF documents as queue-runner based:
        queue runners must be started before the tensors are evaluated.
    """
    # File-name queue; with no num_epochs the file is cycled without limit.
    file_queue = tf.train.string_input_producer([filename])
    # Pull one serialized tf.train.Example from the queue.
    record_reader = tf.TFRecordReader()
    _, record = record_reader.read(file_queue)

    # Parse the serialized example against the schema the writer produced.
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string)
    }
    parsed = tf.parse_single_example(record, features=feature_spec)

    # Raw bytes -> flat uint8 vector -> 32x32 image with 3 channels.
    # Pixels stay uint8; any float scaling is left to the caller.
    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    image = tf.reshape(pixels, [32, 32, 3])
    target = tf.cast(parsed['label'], tf.int32)

    if not is_batch:
        return image, target

    batch_size = 3
    min_after_dequeue = 10
    # Queue holds the shuffle buffer plus headroom for a few batches.
    image_batch, target_batch = tf.train.shuffle_batch(
        [image, target],
        batch_size=batch_size,
        num_threads=3,
        capacity=min_after_dequeue + 3 * batch_size,
        min_after_dequeue=min_after_dequeue,
    )
    return image_batch, target_batch