2016-10-25 310 views
0

以下代码使用tensorflow库,与numpy库相比运行速度非常慢。我知道我正在调用一个函数,它使用python for循环中的tensorflow库(我将稍后与python多处理进行并行化),但代码的运行速度非常慢。tensorflow在python for循环内运行速度极慢

有人可以请帮助我如何让这段代码运行得更快吗?谢谢。


from math import * 
import numpy as np 
import sys 
from multiprocessing import Pool 
import tensorflow as tf 

def Trajectory_Fun(tspan, a, b, session=None, server=None): 
    # Open tensorflow session 
    if session==None: 
     if server==None: 
      sess = tf.Session() 
     else: 
      sess = tf.Session(server.target)  
    else: 
     sess = session 
    B = np.zeros(np.size(tspan), dtype=np.float64) 
    B[0] = b 
    for i, t in enumerate(tspan): 
     r = np.random.rand(1) 
     if r>a: 
      c = sess.run(tf.trace(tf.random_normal((4, 4), r, 1.0))) 
     else: 
      c = 0.0 # sess.run(tf.trace(tf.random_normal((4, 4), 0.0, 1.0))) 
     B[i] = c 
    # Close tensorflow session 
    if session==None: 
     sess.close() 
    return B 

def main(argv): 
    # Parameters 
    tspan = np.arange(0.0, 1000.0) 
    a = 0.1 
    b = 0.0 
    # Run test program 
    B = Trajectory_Fun(tspan, a, b, None, None) 
    print 'Done!' 

if __name__ == "__main__": 
    main(sys.argv[1:]) 
+0

您正在缓慢地调整session.run调用之间的Graph对象。你可以在第一个'sess.run'前添加所有的操作并调用'tf.get_default_graph()。finalize()' –

+0

@YaroslavBulatov感谢您的快速响应。正如你可能已经注意到的那样,我需要每个时间步长的变量c的值。请您再澄清一下,我可以如何将您的建议纳入我的上述代码中?我会很感激。谢谢。 – QED

+0

在循环开始之前做'a = tf.random_normal((4,4),0.0,1.0)',然后执行'sess.run(a)' –

回答

2

正如你的问题说,因为它创造了每运行几个新TensorFlow图节点这一计划将给予表现不佳。 TensorFlow中的基本假设是(大约)您将构建一次图形,然后多次调用sess.run()(的各个部分)。您第一次运行图形相对昂贵,因为TensorFlow必须构建各种数据结构并优化跨多个设备的图形执行。然而,TensorFlow缓存了这项工作,所以后续使用便宜得多。

通过构建一次图并使用(例如)tf.placeholder() op来提供每次迭代中更改的值,可以使该程序快得多。例如,下面应该做的伎俩:

B = np.zeros(np.size(tspan), dtype=np.float64) 
B[0] = b 

# Define the TensorFlow graph once and reuse it in each iteration of the for loop. 
r_placeholder = tf.placeholder(tf.float32, shape=[]) 
out_t = tf.trace(tf.random_normal((4, 4), r_placeholder, 1.0)) 

with tf.Session() as sess: 
    for i, t in enumerate(tspan): 
    r = np.random.rand(1) 
    if r > a: 
     c = sess.run(out_t, feed_dict={r_placeholder: r}) 
    else: 
     c = 0.0 
    B[i] = c 
    return B 

你可能使这更有效的利用TensorFlow循环和sess.run()使得更少的调用,但总的原则是一样的:重复使用相同的图形多次获得TensorFlow的好处。