2017-05-31 72 views
0

我是新手到tensorflow,我试图获得张量中最大值的索引。下面是代码:沿多个维度的Tensorflow argmax

def select(input_layer): 

    shape = input_layer.get_shape().as_list() 

    rel = tf.nn.relu(input_layer) 
    print (rel) 
    redu = tf.reduce_sum(rel,3) 
    print (redu) 

    location2 = tf.argmax(redu, 1) 
    print (location2) 

sess = tf.InteractiveSession() 
I = tf.random_uniform([32, 3, 3, 5], minval = -541, maxval = 23, dtype = tf.float32) 
matI, matO = sess.run([I, select(I, 3)]) 
print(matI, matO) 

这里是输出:

Tensor("Relu:0", shape=(32, 3, 3, 5), dtype=float32) 
Tensor("Sum:0", shape=(32, 3, 3), dtype=float32) 
Tensor("ArgMax:0", shape=(32, 3), dtype=int64) 
... 

由于尺寸= 1在argmax功能的Tensor("ArgMax:0") = (32,3)形状。有没有办法在应用argmax之前得到argmax输出张量大小= (32,)而不是做reshape

+0

这有什么错了'tf.reshape(热度,[32,-1])'? ['tf.argmax'](https://www.tensorflow.org/api_docs/python/tf/argmax)只会沿着一个轴减少 – martianwars

回答

1

您可能不希望输出大小为(32,),因为当沿着几个方向argmax时,通常希望所有缩小的维度都具有最大值的坐标。在你的情况下,你会想要一个大小为(32,2)的输出。

你可以做一个二维argmax这样的:

import numpy as np 
import tensorflow as tf 

x = np.zeros((10,9,8)) 
# pick a random position for each batch image that we set to 1 
pos = np.stack([np.random.randint(9,size=10), np.random.randint(8,size=10)]) 

posext = np.concatenate([np.expand_dims([i for i in range(10)], axis=0), pos]) 
x[tuple(posext)] = 1 

a = tf.argmax(tf.reshape(x, [10, -1]), axis=1) 
pos2 = tf.stack([a // 8, tf.mod(a, 8)]) # recovered positions, one per batch image 

sess = tf.InteractiveSession() 
# check that the recovered positions are as expected 
assert (pos == pos2.eval()).all(), "it did not work" 
+0

感谢您的快速响应,但问题是我正在使用tensorflow版本0.9并且tf.stack在该版本中不存在。你有替代品吗? – Maystro

+0

对不起,我没有版本0.9,tensorflow.org只有0.10以上的文档。如果存在连接,可以将其与'expand_dims'结合起来以模拟'stack'。 – user1735003