What is TFRecord
PS: This passage is excerpted from http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html
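One approach to storing records is to convert whatever data you have into a format that TensorFlow supports, which makes it easier to mix and match datasets and network architectures. The recommended format is a TFRecords file, which contains tf.train.Example protocol buffers (each of which contains Features as a field). You write a small program that gets your data, fills it into an Example protocol buffer, serializes the protocol buffer to a string, and writes that string to a TFRecords file using the tf.python_io.TFRecordWriter class. tensorflow/g3doc/how_tos/reading_data/convert_to_records.py is one such example.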
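The "fill an Example, then serialize it" step described above can be tried in isolation, without a session or a file. The following is a minimal sketch, assuming TensorFlow is installed; the four-byte payload and the label value are placeholders chosen for illustration, not data from the original scripts.

```python
import tensorflow as tf

# Placeholder payload; in the scripts below this would be the raw image bytes.
image_raw = b'\x00\x01\x02\x03'
label = 13

# Fill an Example protocol buffer with one bytes feature and one int64 feature.
example = tf.train.Example(features=tf.train.Features(feature={
    'image_raw': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image_raw])),
    'label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[label])),
}))

# Serialize to the byte string that a TFRecord writer stores as one record.
serialized = example.SerializeToString()

# Parse it back with the plain protobuf API to check the round trip.
parsed = tf.train.Example.FromString(serialized)
print(parsed.features.feature['label'].int64_list.value[0])  # 13
print(parsed.features.feature['image_raw'].bytes_list.value[0] == image_raw)  # True
```

Parsing with `FromString` is only a convenient way to inspect a record by hand; inside a TensorFlow graph the parsing is done by tf.parse_single_example instead.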
To read data from a TFRecords file, use tf.TFRecordReader together with the tf.parse_single_example parser. The parse_single_example op decodes Example protocol buffers into tensors. The MNIST example uses the data produced by convert_to_records; see tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py.
Code
adjust_pic.py
Simply resizes an image.

    import tensorflow as tf


    def resize(img_data, width, high, method=0):
        # tf.image.resize_images expects size as [new_height, new_width];
        # method=0 selects bilinear interpolation.
        return tf.image.resize_images(img_data, [high, width], method)

pic2tfrecords.py
Saves images to a TFRecord file.

    import tensorflow as tf
    import adjust_pic as ap

    # Path of the TFRecord file to write
    SAVE_PATH = 'data/dataset.tfrecords'


    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


    def load_data(datafile, width, high, method=0, save=False):
        writer = tf.python_io.TFRecordWriter(SAVE_PATH)

        with open(datafile, 'r') as train_list, tf.Session() as sess:
            for line in train_list:
                # Each line has the form "<image path> <label>"
                tmp = line.strip().split(' ')
                img_path = tmp[0]
                label = int(tmp[1])

                # Read the JPEG file as bytes ('rb', because it is binary data)
                image = tf.gfile.FastGFile(img_path, 'rb').read()
                # Decode to a uint8 tensor; channels=3 keeps the shape
                # consistent with the [high, width, 3] reshape in the reader
                image = tf.image.decode_jpeg(image, channels=3)
                # Convert to float32 values in [0, 1]
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
                # Resize to the target size
                image = ap.resize(image, width, high, method)
                # Evaluate the tensor to get a numpy array
                image = sess.run(image)

                # Raw float32 bytes of the image
                image_raw = image.tobytes()

                example = tf.train.Example(features=tf.train.Features(feature={
                    'image_raw': _bytes_feature(image_raw),
                    'label': _int64_feature(label),
                }))

                # Write the serialized Example as one record
                writer.write(example.SerializeToString())

        writer.close()


    load_data('train_list.txt_bak', 224, 224)

tfrecords2data.py
Reads records from the TFRecord file and saves them as images.

    import tensorflow as tf

    # Path of the TFRecord file to read
    SAVE_PATH = 'data/dataset.tfrecords'


    def load_data(width, high):
        reader = tf.TFRecordReader()
        filename_queue = tf.train.string_input_producer([SAVE_PATH])

        # Read one serialized Example from the file queue
        _, serialized_example = reader.read(filename_queue)

        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })

        # The writer stored float32 pixels, so decode the bytes as float32
        image = tf.decode_raw(features['image_raw'], tf.float32)
        image = tf.reshape(image, [high, width, 3])
        # Convert back to uint8 and encode as JPEG. These ops are built once,
        # outside the loop, so the graph does not grow on every iteration.
        image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
        jpeg = tf.image.encode_jpeg(image)
        label = tf.cast(features['label'], tf.int64)

        with tf.Session() as sess:
            # Start the threads that fill the input queue
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # train_list.txt_bak lists two images, so read two records
            for i in range(2):
                label_value, jpeg_value = sess.run([label, jpeg])
                with tf.gfile.GFile('pic_%d.jpg' % label_value, 'wb') as f:
                    f.write(jpeg_value)

            # Stop the queue threads cleanly
            coord.request_stop()
            coord.join(threads)


    load_data(224, 224)

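Because pic2tfrecords.py stores the pixels as raw float32 bytes, the reader has to reinterpret those bytes with the same dtype before reshaping. The following numpy-only sketch shows that round trip; the random array is just a stand-in for a real resized image and is not part of the original scripts.

```python
import numpy as np

# Stand-in for a resized image: float32 values in [0, 1), shape (height, width, 3).
image = np.random.rand(224, 224, 3).astype(np.float32)

# This is what the writer puts into the 'image_raw' feature.
image_raw = image.tobytes()

# Reading the bytes as uint8 yields 4x as many elements,
# because each float32 value occupies 4 bytes.
as_uint8 = np.frombuffer(image_raw, dtype=np.uint8)

# Reading them with the original dtype and reshaping restores the image.
restored = np.frombuffer(image_raw, dtype=np.float32).reshape(224, 224, 3)

print(as_uint8.size, restored.size)     # 602112 150528
print(np.array_equal(image, restored))  # True
```

The shape is not stored in the bytes either, so the reader has to know the height, width, and channel count that the writer used.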
The contents of train_list.txt_bak are as follows:
image_1093.jpg 13
image_0805.jpg 10
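Each line of the list file is an image path followed by an integer label. A small sketch of parsing that format on its own is shown here; `parse_list_line` is a hypothetical helper, not a function from the original scripts, and splitting on the last run of whitespace is an assumption made so that paths containing spaces would still parse.

```python
def parse_list_line(line):
    """Split one '<image path> <label>' line into (path, int label)."""
    # Split once from the right so only the final field is treated as the label.
    path, label = line.strip().rsplit(None, 1)
    return path, int(label)


lines = ["image_1093.jpg 13", "image_0805.jpg 10"]
records = [parse_list_line(l) for l in lines]
print(records)  # [('image_1093.jpg', 13), ('image_0805.jpg', 10)]
```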
Reposted from: http://blog.csdn.net/xueyingxue001/article/details/68943650