我有一个关于通过tensorflow python api更新张量值的基本问题。更新tensorflow中的变量值
考虑代码片段:
x = tf.placeholder(shape=(None,10), ...)
y = tf.placeholder(shape=(None,), ...)
W = tf.Variable(randn(10,10), dtype=tf.float32)
yhat = tf.matmul(x, W)
现在让我们假设我想实现某种算法,用于迭代更新W的值(例如一些优化算法中)。这将涉及类似步:
for i in range(max_its):
resid = y_hat - y
W = f(W , resid) # some update
这里的问题是,W
在LHS是一个新的张量,而不是在yhat = tf.matmul(x, W)
使用W
!也就是说,创建了一个新变量,并且我的“模型”中使用的W
的值不会更新。
现在围绕这一方法是
for i in range(max_its):
resid = y_hat - y
W = f(W , resid) # some update
yhat = tf.matmul(x, W)
导致创建一个新的“模式”我的每个循环迭代的!
有没有一种更好的方式来实现这个(在python中),而不是为循环的每次迭代创建一大堆新模型 - 而是更新原始张量W
“in-place”可以这么说?
这似乎工作!谢谢。 – firdaus