2017-08-29 103 views

How do I change the MNIST tutorial to use TFRecords instead of the odd format the tutorial downloads from the web? Tensorflow MNIST TFRecord

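I used build_image_data.py from the Inception model to create TFRecords containing 200x200 RGB images, and I intend to train on them with a 1080Ti, but I can't find any good examples of how to load TFRecords and feed them into a convolutional neural network.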

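Take a look at [this guide](https://www.tensorflow.org/programmers_guide/datasets); it has examples showing how to load TFRecord files and get tensors from the data. After that it's just a matter of passing that data to the network as input, in place of whatever input the network gets now. – GPhilo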

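@GPhilo I have my dataset available as "images: Images. 4D tensor of size [batch_size, FLAGS.image_size, image_size, 3]. labels: 1-D integer Tensor of [FLAGS.batch_size].", but I don't see a function in tf.estimator.inputs that accepts what I've loaded. – Eejin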

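tf.estimator.inputs has convenience functions for converting data that is not yet in tensor format into something the network can use. You need to write your own 'input_fn'. I'm not familiar with this high-level API, but from the [Estimator docs](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) I think you need to define an 'input_fn' that returns a dictionary '{'images':your_image_tensor,'labels':your_label_tensor}'. – GPhilo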
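
For reference, here is a minimal sketch of the 'input_fn' shape discussed in these comments. It assumes TensorFlow 2.x-style `tf.data` and data small enough to hold in NumPy arrays. To my knowledge an Estimator 'input_fn' returns a `(features, labels)` pair, or a dataset yielding such pairs, rather than one dictionary holding both. `make_input_fn`, the `"images"` key, and the placeholder arrays are illustrative names of mine, not part of any API.

```python
import numpy as np
import tensorflow as tf


def make_input_fn(images, labels, batch_size=32, shuffle=True):
    """Return a zero-argument input_fn yielding (features, labels) batches.

    Hypothetical helper: `images` is [N, H, W, 3], `labels` is [N].
    """
    def input_fn():
        # Features go in a dict; labels are the second element of the pair.
        dataset = tf.data.Dataset.from_tensor_slices(({"images": images}, labels))
        if shuffle:
            dataset = dataset.shuffle(buffer_size=len(labels))
        return dataset.batch(batch_size)

    return input_fn


# Placeholder data with the shapes quoted in the comment above.
fake_images = np.zeros([8, 200, 200, 3], dtype=np.float32)
fake_labels = np.arange(8, dtype=np.int32)
input_fn = make_input_fn(fake_images, fake_labels, batch_size=4, shuffle=False)
```

The closure keeps 'input_fn' argument-free, which is the form the Estimator training call expects as far as I know.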

Answer

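I did something similar to what you are planning. I also used the same script to build the image data. My code for reading the data and training on it is: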

import tensorflow as tf 

height = 28 
width = 28 

tfrecords_train_filename = 'train-00000-of-00001' 
tfrecords_test_filename = 'test-00000-of-00001' 


def read_and_decode(filename_queue):
    """Read one serialized Example from the queue and decode it to (image, label)."""
    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(filename_queue)

    # Feature keys as written by build_image_data.py.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/colorspace': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/channels': tf.FixedLenFeature([], tf.int64),
            'image/class/label': tf.FixedLenFeature([], tf.int64),
            'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/format': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/filename': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value='')
        })

    image_buffer = features['image/encoded']
    image_label = tf.cast(features['image/class/label'], tf.int32)

    with tf.name_scope('decode_jpeg'):
        # Decode the JPEG bytes into a uint8 RGB tensor.
        image = tf.image.decode_jpeg(image_buffer, channels=3)

        # Convert to float32 in [0, 1], then collapse RGB to a single channel.
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.image.rgb_to_grayscale(image)

    # Fix the static shape; every record must already be height x width.
    image = tf.reshape(image, [height, width, 1])

    return image, image_label


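def inputs(filename, batch_size, num_epochs):
    """Build shuffled batches of (images, labels) from one TFRecord file."""
    if not num_epochs:
        num_epochs = None  # None means cycle through the file indefinitely

    with tf.name_scope('input'):
        # A non-None num_epochs creates a local variable, so the session must
        # also run tf.local_variables_initializer() before starting the queues.
        filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
        image, label = read_and_decode(filename_queue)

        # Shuffle the examples and collect them into batch_size batches.
        images, sparse_labels = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, num_threads=2,
            capacity=1000 + 3 * batch_size,
            min_after_dequeue=1000)

        return images, sparse_labels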

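image, label = inputs(filename=tfrecords_train_filename, batch_size=200, num_epochs=None)
# Flatten each 28x28x1 image into a 784-vector for the linear model.
image = tf.reshape(image, [-1, 784])
# build_image_data.py numbers classes from 1 (0 is reserved for background), so shift to 0..9.
label = tf.one_hot(label - 1, 10)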

# Create the model 
x = tf.placeholder(tf.float32, [None, 784]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
y = tf.matmul(x, W) + b 
y_ = tf.placeholder(tf.float32, [None, 10]) 

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) 

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 

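with tf.Session() as sess:
    # Local variables are only needed when num_epochs is set, but initializing them is harmless.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    try:
        for i in range(1000):
            # Pull a batch from the input queue, then feed it to the placeholders.
            img, lbl = sess.run([image, label])
            sess.run(train_step, feed_dict={x: img, y_: lbl})

        # Accuracy on one more batch from the training file, not on held-out data.
        img, lbl = sess.run([image, label])
        print(sess.run(accuracy, feed_dict={x: img, y_: lbl}))
    finally:
        # Stop the queue-runner threads even if training raised.
        coord.request_stop()
        coord.join(threads)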

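This is a super simple model for classifying MNIST. Still, I think it also works as an extensible answer to how to train from TFRecord files. It does not yet handle evaluation data, because that would need more coordination work.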
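
The guide linked in the comments describes the `tf.data` API, which can read the same records without queue runners and makes a separate evaluation input straightforward. Below is a minimal sketch of that approach, assuming TensorFlow 2.x and records written by build_image_data.py with 28x28 JPEGs. `FEATURE_SPEC`, `parse_example`, and `make_dataset` are names I made up for illustration, and the `- 1` label shift simply mirrors the code above.

```python
import tensorflow as tf

# Only the two features the model needs; other keys in the record are ignored.
FEATURE_SPEC = {
    "image/encoded": tf.io.FixedLenFeature([], tf.string),
    "image/class/label": tf.io.FixedLenFeature([], tf.int64),
}


def parse_example(serialized, height=28, width=28):
    """Decode one serialized Example into a flat grayscale image and a label."""
    features = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    image = tf.io.decode_jpeg(features["image/encoded"], channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)  # scale to [0, 1]
    image = tf.image.rgb_to_grayscale(image)
    # Assumes every image is already height x width.
    image = tf.reshape(image, [height * width])
    # Shift 1-based labels to 0-based, as in the answer's `label - 1`.
    label = tf.cast(features["image/class/label"], tf.int32) - 1
    return image, label


def make_dataset(filename, batch_size, training):
    """Build a batched dataset; shuffle and repeat only when training."""
    dataset = tf.data.TFRecordDataset([filename]).map(parse_example)
    if training:
        dataset = dataset.shuffle(1000).repeat()
    return dataset.batch(batch_size)
```

With this layout, evaluation would be a second call such as `make_dataset(tfrecords_test_filename, 200, training=False)`, which makes a single pass over the test file with no shuffling.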