我是新手到tensorflow
,我试图获得张量中最大值的索引。下面是代码:沿多个维度的Tensorflow argmax
def select(input_layer):
shape = input_layer.get_shape().as_list()
rel = tf.nn.relu(input_layer)
print (rel)
redu = tf.reduce_sum(rel,3)
print (redu)
location2 = tf.argmax(redu, 1)
print (location2)
sess = tf.InteractiveSession()
I = tf.random_uniform([32, 3, 3, 5], minval = -541, maxval = 23, dtype = tf.float32)
matI, matO = sess.run([I, select(I, 3)])
print(matI, matO)
这里是输出:
Tensor("Relu:0", shape=(32, 3, 3, 5), dtype=float32)
Tensor("Sum:0", shape=(32, 3, 3), dtype=float32)
Tensor("ArgMax:0", shape=(32, 3), dtype=int64)
...
由于尺寸= 1在argmax
功能的Tensor("ArgMax:0") = (32,3)
形状。有没有办法在应用argmax
之前得到argmax
输出张量大小= (32,)
而不是做reshape
?
这有什么错了'tf.reshape(热度,[32,-1])'? ['tf.argmax'](https://www.tensorflow.org/api_docs/python/tf/argmax)只会沿着一个轴减少 – martianwars