2017-09-30 46 views
0

我在一个tf.train.SequenceExample中放入了一组固定长度和可变长度的特征。批量大小大于1时tensorflow数据集API不能稳定工作

context_features 
    length,   scalar,     tf.int64 
    site_code_raw,  scalar,     tf.string 
    Date_Local_raw, scalar,     tf.string 
    Time_Local_raw, scalar,     tf.string 
Sequence_features 
    Orig_RefPts,  [#batch, #RefPoints, 4] tf.float32 
    tgt_location,  [#batch, 3]    tf.float32 
    tgt_val   [#batch, 1]    tf.float32 

#RefPoints的值对于不同的序列实例是可变的。我将其值存储在context_features中的length功能中。其余功能具有固定大小。

这里是我用来阅读&解析数据代码:

def read_batch_DatasetAPI(
    filenames, 
    batch_size = 20, 
    num_epochs = None, 
    buffer_size = 5000): 

    dataset = tf.contrib.data.TFRecordDataset(filenames) 
    dataset = dataset.map(_parse_SeqExample1) 
    if (buffer_size is not None): 
     dataset = dataset.shuffle(buffer_size=buffer_size) 
    dataset = dataset.repeat(num_epochs) 
    dataset = dataset.batch(batch_size) 
    iterator = dataset.make_initializable_iterator() 
    next_element = iterator.get_next() 

    # next_element contains a tuple of following tensors 
    # length,   scalar,     tf.int64 
    # site_code_raw,  scalar,     tf.string 
    # Date_Local_raw, scalar,     tf.string 
    # Time_Local_raw, scalar,     tf.string 
    # Orig_RefPts,  [#batch, #RefPoints, 4] tf.float32 
    # tgt_location,  [#batch, 3]    tf.float32 
    # tgt_val   [#batch, 1]    tf.float32 

    return iterator, next_element 

def _parse_SeqExample1(in_SeqEx_proto): 

    # Define how to parse the example 
    context_features = { 
     'length': tf.FixedLenFeature([], dtype=tf.int64), 
     'site_code': tf.FixedLenFeature([], dtype=tf.string), 
     'Date_Local': tf.FixedLenFeature([], dtype=tf.string), 
     'Time_Local': tf.FixedLenFeature([], dtype=tf.string) #, 
    } 

    sequence_features = { 
     "input_features": tf.VarLenFeature(dtype=tf.float32), 
     'tgt_location_features': tf.FixedLenSequenceFeature([3], dtype=tf.float32), 
     'tgt_val_feature': tf.FixedLenSequenceFeature([1], dtype=tf.float32) 
    }               

    context, sequence = tf.parse_single_sequence_example(
     in_SeqEx_proto, 
     context_features=context_features, 
     sequence_features=sequence_features) 

    # distribute the fetched context and sequence features into tensors 
    length = context['length'] 
    site_code_raw = context['site_code'] 
    Date_Local_raw = context['Date_Local'] 
    Time_Local_raw = context['Time_Local'] 

    # reshape the tensors according to the dimension definition above 
    Orig_RefPts = sequence['input_features'].values 
    Orig_RefPts = tf.reshape(Orig_RefPts, [-1, 4]) 
    tgt_location = sequence['tgt_location_features'] 
    tgt_location = tf.reshape(tgt_location, [-1]) 
    tgt_val = sequence['tgt_val_feature'] 
    tgt_val = tf.reshape(tgt_val, [-1]) 

    return length, site_code_raw, Date_Local_raw, Time_Local_raw, \ 
     Orig_RefPts, tgt_location, tgt_val 

当我打电话read_batch_DatasetAPIbatch_size = 1(见下面的代码),它可以处理所有(20万左右)序列的例子单没有任何问题。但是,如果我将batch_size更改为大于1的任何数字,则在取出320到700个序列示例之后,它会停止而没有任何错误消息。我不知道如何解决这个问题。任何帮助表示赞赏!

# the iterator to get the next_element for one sample (in sequence) 
iterator, next_element = read_batch_DatasetAPI(
    in_tf_FWN, # the file name of the tfrecords containing ~200,000 Sequence Examples 
    batch_size = 1, # works when it is 1, doesn't work if > 1 
    num_epochs = 1, 
    buffer_size = None) 

# tf session initialization 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 

## reset the iterator to the beginning 
sess.run(iterator.initializer) 

try: 
    step = 0 

    while (True): 

     # get the next batch data 
     length, site_code_raw, Date_Local_raw, Time_Local_raw, \ 
     Orig_RefPts, tgt_location, tgt_vale = sess.run(next_element) 

     step = step + 1 

except tf.errors.OutOfRangeError: 
    # Task Done (all SeqExs have been visited) 
    print("closing ", in_tf_FWN) 

except ValueError as err: 
    print("Error: {}".format(err.args)) 

except Exception as err: 
    print("Error: {}".format(err.args)) 

回答

0

我看到一些帖子(Example 1Example 2)提新dataset功能from_generatorhttps://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/data/Dataset#from_generator)。我不确定如何使用它来解决我的问题。任何人都知道如何去做,请将其作为新的答案发布。谢谢!

这是我的当前诊断和解决我的问题:

的序列长度(#RefPoints)的变化所引起的问题。 dataset.map(_parse_SeqExample1)只适用于#RefPoints碰巧在批次中相同。这就是为什么如果batch_size是1,它总是有效,但如果它大于1,它在某个时候失败了。

我发现datasetpadded_batch函数,它可以将可变长度填充到批处理中的最大长度。作了一些变更,暂时解决我的问题(我猜from_generator将是真正解决了我的情况):

  1. _parse_SeqExample1功能,return语句改为

    return tf.tuple([length, site_code_raw, Date_Local_raw, Time_Local_raw, \ Orig_RefPts, tgt_location, tgt_val])

  2. read_batch_DatasetAPI功能,声明

    dataset = dataset.batch(batch_size)

    改为

    dataset = dataset.padded_batch(batch_size, padded_shapes=( tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([None, 4]), tf.TensorShape([3]), tf.TensorShape([1]) ) )

  3. 最后,从改变fetch语句

    length, site_code_raw, Date_Local_raw, Time_Local_raw, \ Orig_RefPts, tgt_location, tgt_vale = sess.run(next_element)

注意: 我不知道为什么,这只适用于当前的tf-nightly-gpu版本而不是tensorflow-gpu v1.3。