2016-12-01 65 views
2

我有一个关于通过tensorflow python api更新张量值的基本问题。更新tensorflow中的变量值

考虑代码片段:

x = tf.placeholder(shape=(None,10), ...) 
y = tf.placeholder(shape=(None,), ...) 
W = tf.Variable(randn(10,10), dtype=tf.float32) 
yhat = tf.matmul(x, W) 

现在让我们假设我想实现某种算法,用于迭代更新W的值(例如一些优化算法中)。这将涉及类似步:

for i in range(max_its): 
    resid = y_hat - y 
    W = f(W , resid) # some update 

这里的问题是,W在LHS是一个新的张量,而不是在yhat = tf.matmul(x, W)使用W!也就是说,创建了一个新变量,并且我的“模型”中使用的W的值不会更新。

现在围绕这一方法是

for i in range(max_its): 
    resid = y_hat - y 
    W = f(W , resid) # some update 
    yhat = tf.matmul(x, W) 

导致创建一个新的“模式”我的每个循环迭代的!

有没有一种更好的方式来实现这个(在python中),而不是为循环的每次迭代创建一大堆新模型 - 而是更新原始张量W“in-place”可以这么说?

回答

2

变量有一个赋值方法。尝试:W.assign(f(W,resid))

+0

这似乎工作!谢谢。 – firdaus

0

@ aarbelle的简洁答案是正确的,我会扩展一下以防有人需要更多信息。最后2行用于更新W.

x = tf.placeholder(shape=(None,10), ...) 
y = tf.placeholder(shape=(None,), ...) 
W = tf.Variable(randn(10,10), dtype=tf.float32) 
yhat = tf.matmul(x, W) 

... 

for i in range(max_its): 
    resid = y_hat - y 
    update = W.assign(f(W , resid)) # do not forget to initialize tf variables. 
    # "update" above is just a tf op, you need to run the op to update W. 
    sess.run(update)