2017-08-27 189 views
1

我有一些基本的功能,取入的图像的URL,并通过VGG-16 CNN其转换:tensorflow多对图像特征提取

def convert_url(_id, url): 
    im = get_image(url) 
    return _id, np.squeeze(sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im})) 

我有一个大组URL(〜60,000)我想在其上执行此功能。每次迭代需要一秒以上的时间,这太慢了。我想通过并行使用多个进程来加速它。没有共享状态可以担心,所以多线程的通常陷阱不是问题。

但是,我不确定如何实际使tensorflow与多处理程序包一起工作。我知道你不能将tensorflow session传递给Pool变量。所以不是,我想初始化的session多个实例:

def init(): 
    global sess; 
    sess = tf.Session() 

但是当我真正启动过程中,它只是挂起无限期:

with Pool(processes=3,initializer=init) as pool: 
    results = pool.starmap(convert_url, list(id_img_dict.items())[0:5]) 

注意,tensorflow图被全局定义。我认为这是正确的做法,但我不确定:

input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image') 
scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor) 
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5) 
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0) 

arg_scope = vgg_arg_scope() 
with slim.arg_scope(arg_scope): 
    _, end_points = vgg_16(scaled_input_tensor, is_training=False) 
saver = tf.train.Saver() 
saver.restore(sess, checkpoint_file) 

任何人都可以帮助我实现这个工作吗?多谢。

+0

如果你在使用在线文件,你应该看看使用异步,这应该会产生很大的加速。 –

回答

1

忘掉python的正常多线程工具并使用tensorflow.contrib.data.Dataset。尝试如下所示。

urls = ['img1.jpg', 'img2.jpg', ...] 
batch_size = 16 
n_batches = len(urls) // batch_size # do something more elegant for remainder 


def load_img(url): 
    image = tf.read_file(url, name='image_data') 
    image = tf.image.decode_jpeg(image, channels=3, name='image') 
    return image 


def preprocess(img_tensor): 
    img_tensor = (tf.cast(img_tensor, tf.float32)/255 - 0.5)*2 
    img_tensor.set_shape((256, 256, 3)) # whatever shape 
    return img_tensor 


dataset = tf.contrib.data.Dataset.from_tensor_slices(urls) 
dataset = dataset.map(load_img).map(preprocess) 

preprocessed_images = dataset.batch(
    batch_size).make_one_shot_iterator().get_next() 


arg_scope = vgg_arg_scope() 
with slim.arg_scope(arg_scope): 
    _, end_points = vgg_16(preprocessed_images, is_training=False) 
    output = end_points['vgg_16/fc7'] 


results = [] 

with tf.Session() as sess: 
    tf.train.Saver().restore(sess, checkpoint_file) 
    for i in range(n_batches): 
     batch_results = sess.run(output) 
     results.extend(batch_results) 
     print('Done batch %d/%d' % (i+1, n_batches)) 
+0

欣赏响应!如果我在本地保存所有文件,这似乎是一个好方法。但是,我只在线存储URL到jpeg或png文件。显然,实际使用'urllib'或'requests'来获取图像文件本身是微不足道的,但是这会削弱使用'Datasets'来并行化图形的价值吗? – anon

+0

我明白了。在这种情况下,你可以很大程度上忽略这个答案。我没有做过很多正常的python多线程,但我想象如果你并行化数据提取(单独从tensorflow),你应该能够使用上面的代码,批量大于1来显着提高速度,例如,获取16张图片,当所有图片加载时,都会提供给tf.sess。 – DomJack

+0

我宁愿让每个线程都负责数据获取和tensorflow转换,因为这是一个更清晰的解决方案。你的解决方案是可行的,但我想知道通过多处理可以一次提取16张图像多快。也就是说,即使我连续抓住这16个图像,我也可以通过并行运行VGG,这样就可以工作。尽管如此,我仍然坚持有人知道如何做我最初想要的东西! – anon