2016-04-03 63 views
5

假设我有一个尺寸为BxWxHxD的张量。我想处理张量,使得我有一个新的BxWxHxD张量,其中只保留每个WxH片中的最大元素,其他所有值都为零。Tensorflow多维argmax

换句话说,我认为实现这一目标的最好方法是以某种方式在WxH切片上采用2D argmax,从而产生行和列的BxD索引张量,然后可以将其转换为单热BxWxHxD张量被用作面具。我如何完成这项工作?

回答

1

您可以使用以下函数作为起点。它计算每个批次和每个通道的最大元素索引。结果数组采用格式(批量大小,2,通道数量)。

def argmax_2d(tensor): 

    # input format: BxHxWxD 
    assert rank(tensor) == 4 

    # flatten the Tensor along the height and width axes 
    flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3])) 

    # argmax of the flat tensor 
    argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32) 

    # convert indexes into 2D coordinates 
    argmax_x = argmax // tf.shape(tensor)[2] 
    argmax_y = argmax % tf.shape(tensor)[2] 

    # stack and return 2D coordinates 
    return tf.stack((argmax_x, argmax_y), axis=1) 

def rank(tensor): 

    # return the rank of a Tensor 
    return len(tensor.get_shape())