2017-06-06 121 views
-1

我有三个数组,X,YZ。如果Z的对应元素为真,我想要放入resX的元素;否则,我会放入一个来自Y的元素。其中()需要1到2个位置参数,但有3个被给出

我实现这样的:

X = tf.constant([[1, 2], [3, 4]]) 
Y = tf.constant([[5, 6], [7, 8]]) 
Z = tf.constant([[True, False], [False, True]], tf.bool) 
res = tf.where(Z, X, Y) 
print(res.eval()) 

不过,我得到这个错误:

TypeError: where() takes from 1 to 2 positional arguments but 3 were given 

我看着tf.where的definiton从here和我的使用似乎罚款。

任何想法可能是什么问题?

+0

你可以试试'tf.where(Z,X = X,Y = Y)' – pramod

+0

您的代码工作正常TensorFlow 1.0.1,所以我很好奇:这你使用TF版本? – npf

回答

相关问题