2017-07-28 90 views
1

我正在试图在张量流中训练一个神经网络。我使用tf.train.batch_join()函数加载数据及其标签。我这样做:张量流中的tf.train.batch_join()函数如何工作?

image_batch, label_batch, image_batch_f = tf.train.batch_join(
     images_and_labels, batch_size=batch_size_placeholder, 
     #shapes=[(args.image_size, args.image_size, 3),()], enqueue_many=True, 
     shapes=[(args.image_height, args.image_width, 3),(), (args.image_height, args.image_width, 3)], enqueue_many=True, 
     capacity=4 * nrof_preprocess_threads * args.batch_size, 
     allow_smaller_final_batch=True) 
    image_batch = tf.identity(image_batch, 'image_batch') 
    image_batch = tf.identity(image_batch, 'input') 
    label_batch = tf.identity(label_batch, 'label_batch') 
    image_batch_f = tf.identity(image_batch_f, 'flipped_images_batch') 

在这里,我得到了三批数据。一批图像,一批标签和一批与图像批次中相同图像的翻转图像。我想提取一批图像和翻转图像的功能。以下各行通过网络传递批量数据。

# Build the inference graph 
    prelogits, _ = network.inference(image_batch, args.keep_probability, 
     phase_train=phase_train_placeholder, feature_dimension=args.embedding_size, 
     weight_decay=args.weight_decay) 


    features = tf.nn.l2_normalize(prelogits, 1, 1e-10, name='embeddings') 

    #getting the flipped embeddings 
    prelogits_f, _ = network.inference(image_batch_f,args.keep_probability, 
        phase_train=phase_train_placeholder,feature_dimension=args.embedding_size, 
        weight_decay=args.weight_decay,reuse=True) 
    features_flipped_images = tf.nn.l2_normalize(prelogits_f,1,1e-10,name='embeddings_f') 

为了获取这两个功能,我在features和features_flipped_images ops上运行了一个session.run()。这样的事情:

feed_dict = {phase_train_placeholder:False, batch_size_placeholder:batch_size} 
emb, emb_f = sess.run([features, features_flipped_images],feed_dict=feed_dict) 

我的问题是以下。我猜测,当我在功能上运行会话时,即batch_join函数将派发一批batch_size大小的图像。但是当我在features_flipped_images上执行session.run()时,该函数还会从batch_join函数中获取一批翻转的图像。在执行features_flipped_images时,batch_join函数是否派发一批新的翻转图像?或者它是在执行特征时生成的同一批翻转图像?如果没有,那我该怎么做?我想提取一批图像和一批翻转图像的特征。

回答

0

我的猜测是每次运行[features,features_flipped_images]只会得到同一批数据。让我们举个例子:

imgs_batch,labels_batch = tf.train.batch([img, label]...) 

那么,如果你想看到什么是批处理:

imgs_data, labels_data = sess.run([imgs_batch, labels_batch]) 

你看,当你运行sess.run([特点,features_flipped_images]很相似,。 )。我不认为你会得到两批,否则,imgs_data和labels_data不相互对应。

+0

我不确定它是否是同一批次,因为如果我将图像和翻转图像的功能连接起来并使用连接功能进行匹配,那么当它实际上应该改善匹配性能时,我的系统性能会明显下降。我还不清楚批量加载器。 –