2017-08-03

I'm trying to maximize GPU occupancy during training. I have variable-length sequences that I want to pack densely into fixed-length batches. Essentially, I want a short sequence to be followed immediately by the next sequence, and I want long sequences split so that they continue in the next batch. For example:

# Say batch size is 2 and desired sequence length is 4
s1 = [a, b, c, d, e, f]
s2 = [x, y, z]
s3 = [l, m, n, o]

# Resulting batches:
b1 = [[a, b, c, d]
      [x, y, z, l]]
b2 = [[e, f, _, _]
      [m, n, o, _]]
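The packing above can be sketched in plain Python (the function name `pack_batches` and the greedy stream assignment are my own illustration, not TensorFlow API):

```python
def pack_batches(sequences, batch_size, seq_length, pad="_"):
    # Greedily append each sequence to the currently shortest stream,
    # so a short sequence is followed immediately by the next one.
    streams = [[] for _ in range(batch_size)]
    for seq in sequences:
        min(streams, key=len).extend(seq)
    # Chunk every stream into fixed-length rows, padding the tails,
    # so long streams continue across consecutive batches.
    n_chunks = max((len(s) + seq_length - 1) // seq_length for s in streams)
    batches = []
    for c in range(n_chunks):
        rows = []
        for s in streams:
            row = s[c * seq_length:(c + 1) * seq_length]
            rows.append(row + [pad] * (seq_length - len(row)))
        batches.append(rows)
    return batches

s1 = ["a", "b", "c", "d", "e", "f"]
s2 = ["x", "y", "z"]
s3 = ["l", "m", "n", "o"]
b1, b2 = pack_batches([s1, s2, s3], batch_size=2, seq_length=4)
# b1 == [["a","b","c","d"], ["x","y","z","l"]]
# b2 == [["e","f","_","_"], ["m","n","o","_"]]
```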

Is there a simple way to do this in TensorFlow? My sequences come from tf.TextLineReader:

file_queue = tf.train.string_input_producer(['./example_text'])
reader = tf.TextLineReader()
key, sentence = reader.read(file_queue)
# convert the string to an int32 vector
sequence_tensor = to_sequence(sentence)

# what I wish I had:
batch = tf.fixed_length_batch_from_variable_length_sequences(
    sequence_tensor, batch_size, fixed_length)

Thanks in advance for any suggestions.

Answer


OK, I have a working example that is almost what I was hoping for. The code below produces batches the way I want, but it requires placeholders to pass data into and out of the TF session. I'd like to be able to build these batches entirely within the TF graph.

Hopefully I'm just being dense and there's an obvious solution someone can point out. Also, please forgive the camelCase.

import tensorflow as tf

def buildBatch(seqLength, batchSize):

    def lineToSequence(line):
        line = tf.expand_dims(line, axis=0)
        line = tf.sparse_tensor_to_dense(tf.string_split(line), '_')
        line = tf.concat([line, [['<GO>']]], 1)
        return line

    data = tf.contrib.data.TextLineDataset(['./exampleFile.txt'])
    data = data.map(lambda line: lineToSequence(line))
    iterator = data.make_initializable_iterator()

    # Grab lines from the file until the sequence length is met, then shave off any extra
    def getFixedLengthSequence(start):
        c = lambda s: tf.shape(s)[1] < seqLength  # while the sequence is too short
        b = lambda s: tf.concat([s, iterator.get_next()], 1)  # concatenate the next line
        sentences = tf.while_loop(c, b, [start], back_prop=False, parallel_iterations=1,
                                  shape_invariants=[tf.TensorShape([1, None])])

        clippedToLength = tf.expand_dims(sentences[0, :seqLength], axis=0)
        leftover = tf.expand_dims(sentences[0, seqLength:], axis=0)
        return clippedToLength, leftover

    # Placeholders pass in the start of each sequence (saved from the last batch)
    startOfThisBatch = [tf.placeholder(tf.string, shape=[1, None]) for i in range(batchSize)]
    # Capture what is left over from each sequence so it can start the next batch
    startOfNextBatch = [tf.TensorArray(tf.string, size=1) for i in range(batchSize)]

    # Build the batch
    thisBatch = []
    for i, seqStart in enumerate(startOfThisBatch):
        seq, leftover = getFixedLengthSequence(seqStart)
        thisBatch.append(seq)
        startOfNextBatch[i] = startOfNextBatch[i].write(0, leftover)
    thisBatch = tf.concat(thisBatch, axis=0)
    startOfNextBatch = [b.read(0) for b in startOfNextBatch]

    return thisBatch, startOfThisBatch, startOfNextBatch, iterator.initializer


def printBatch():
    sequenceLength = 10
    batchSize = 3

    batch, startOfThisBatch, startOfNextBatch, iteratorInit = buildBatch(sequenceLength, batchSize)
    # The very first batch starts with <GO> tokens
    batchStarts = [[['<GO>']]] * batchSize

    sv = tf.train.Supervisor()
    with sv.managed_session() as sess:
        sess.run(iteratorInit)
        for b in range(4):
            # Populate the feed dict with the beginning of each sequence in the batch
            feed = {}
            for i in range(batchSize):
                feed[startOfThisBatch[i]] = batchStarts[i]

            # Run the graph to get this batch and the starting sequences of the next batch
            out, batchStarts = sess.run([batch, startOfNextBatch], feed_dict=feed)

            print 'Batch', b, ':'
            for seq in out:
                print " ".join(seq)
            print

printBatch()

Result:

Batch 0 : 
<GO> A spokesman said the company has been affected by 
<GO> Having a little flexibility on that issue would go 
<GO> Long before the advent of e-commerce , Wal-Mart 's 

Batch 1 : 
the credit crunch in the United States . <GO> Abu 
a long way to putting together a final package . 
founder Sam Walton set out his vision for a successful 

Batch 2 : 
Dhabi is going ahead to build solar city and no 
<GO> Her back was torn open , her liver was 
retail operation : " We let folks know we 're 

Batch 3 : 
pollution city . <GO> Now it has 175 staging centers 
ruptured , one of her lungs had collapsed and the 
interested in them and that they 're vital to us-- 

Note that each sentence continues in the following batch. The example text file comes from the 1-billion word benchmark dataset and contains one sentence per line.
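The leftover-carrying mechanism above can be mirrored in plain Python as a generator (the name `streamBatches` is mine, not part of the code above): each row of the batch keeps its own buffer, refills it from the shared line iterator until seqLength tokens are available, and carries the surplus into the next batch, just as the placeholders do across session runs.

```python
def streamBatches(lines, seqLength, batchSize, go="<GO>"):
    it = iter(lines)
    buffers = [[go] for _ in range(batchSize)]  # every row starts with <GO>
    while True:
        batch = []
        for i in range(batchSize):
            # Refill this row's buffer one sentence at a time, like the while_loop.
            while len(buffers[i]) < seqLength:
                try:
                    buffers[i] += next(it).split() + [go]
                except StopIteration:
                    return  # out of input; drop the incomplete batch
            batch.append(buffers[i][:seqLength])
            buffers[i] = buffers[i][seqLength:]  # leftover starts the next batch
        yield batch

lines = ["a b c", "d e", "f g h i"]
batches = list(streamBatches(lines, seqLength=3, batchSize=2))
# batches[0] == [["<GO>", "a", "b"], ["<GO>", "d", "e"]]
```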