2017-08-30 387 views
8

对于Tensorflow训练的LSTM模式,我已经结构化我的数据为tf.train.SequenceExample格式,并将其存储到TFRecord文件。我现在想使用新的DataSet API来生成生成填充批次用于培训。在the documentation有一个使用padded_batch的例子,但对于我的数据我无法弄清楚padded_shapes应该是什么值。如何使用DataSet API在Tensorflow中为tf.train.SequenceExample数据创建填充批次?

对于读TFrecord文件到我写了下面的Python代码批次:

import math 
import tensorflow as tf 
import numpy as np 
import struct 
import sys 
import array 

if(len(sys.argv) != 2): 
    print "Usage: createbatches.py [RFRecord file]" 
    sys.exit(0) 


vectorSize = 40 
inFile = sys.argv[1] 

def parse_function_dataset(example_proto): 
    sequence_features = { 
     'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize], 
              dtype=tf.float32), 
     'labels': tf.FixedLenSequenceFeature(shape=[], 
              dtype=tf.int64)} 

    _, sequence = tf.parse_single_sequence_example(example_proto, sequence_features=sequence_features) 

    length = tf.shape(sequence['inputs'])[0] 
    return sequence['inputs'], sequence['labels'] 

sess = tf.InteractiveSession() 

filenames = tf.placeholder(tf.string, shape=[None]) 
dataset = tf.contrib.data.TFRecordDataset(filenames) 
dataset = dataset.map(parse_function_dataset) 
# dataset = dataset.batch(1) 
dataset = dataset.padded_batch(4, padded_shapes=[None]) 
iterator = dataset.make_initializable_iterator() 

batch = iterator.get_next() 

# Initialize `iterator` with training data. 
training_filenames = [inFile] 
sess.run(iterator.initializer, feed_dict={filenames: training_filenames}) 

print(sess.run(batch)) 

代码工作得很好,如果我使用dataset = dataset.batch(1)(在这种情况下,不需要填充),但是当我使用padded_batch变种,我得到以下错误:

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: .

你能帮我搞清楚什么我应该通过对padded_shapes参数?

(我知道有很多的例子代码中使用线程和队列对于这一点,但我宁愿使用新的DataSet API为这个项目)

+0

谢谢Marijn!你的问题帮了我很多! –

回答

6

您需要通过形状的元组。 在你的情况,你应该通过

dataset = dataset.padded_batch(4, padded_shapes=([vectorSize],[None])) 

,或者尝试

dataset = dataset.padded_batch(4, padded_shapes=([None],[None])) 

检查这个code了解更多详情。我不得不调试这个方法来找出为什么它不适合我。

+0

谢谢!这很有道理。以下是我的例子:'padded_shapes =([None,vectorSize],[None])'。第一个张量是维度为vectorSize的矢量列表,第二个是具有整数标签的列表。 –

+0

就像补充一样,'padded_shapes'对嵌套结构的类型很敏感(如果数据集返回一个元组,padded_shapes也应该是一个元组,并且不是列表) – Conchylicultor

0

如果您当前的Dataset对象包含元组,则还可以指定每个填充元素的形状。

例如,我有一个(same_sized_images, Labels)数据集,每个标签具有不同的长度但排名相同。

def process_label(resized_img, label): 
    # Perfrom some tensor transformations 
    # ...... 

    return resized_img, label 

dataset = dataset.map(process_label) 
dataset = dataset.padded_batch(batch_size, 
           padded_shapes=([None, None, 3], 
               [None, None])) # my label has rank 2 
相关问题