我有一些基本的功能,取入的图像的URL,并通过VGG-16 CNN其转换:tensorflow多对图像特征提取
def convert_url(_id, url):
im = get_image(url)
return _id, np.squeeze(sess.run(end_points['vgg_16/fc7'], feed_dict={input_tensor: im}))
我有一个大组URL(〜60,000)我想在其上执行此功能。每次迭代需要一秒以上的时间,这太慢了。我想通过并行使用多个进程来加速它。没有共享状态可以担心,所以多线程的通常陷阱不是问题。
但是,我不确定如何实际使tensorflow与多处理程序包一起工作。我知道你不能将tensorflow session
传递给Pool变量。所以不是,我想初始化的session
多个实例:
def init():
global sess;
sess = tf.Session()
但是当我真正启动过程中,它只是挂起无限期:
with Pool(processes=3,initializer=init) as pool:
results = pool.starmap(convert_url, list(id_img_dict.items())[0:5])
注意,tensorflow图被全局定义。我认为这是正确的做法,但我不确定:
input_tensor = tf.placeholder(tf.float32, shape=(None,224,224,3), name='input_image')
scaled_input_tensor = tf.scalar_mul((1.0/255), input_tensor)
scaled_input_tensor = tf.subtract(scaled_input_tensor, 0.5)
scaled_input_tensor = tf.multiply(scaled_input_tensor, 2.0)
arg_scope = vgg_arg_scope()
with slim.arg_scope(arg_scope):
_, end_points = vgg_16(scaled_input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
任何人都可以帮助我实现这个工作吗?多谢。
如果你在使用在线文件,你应该看看使用异步,这应该会产生很大的加速。 –