2017-04-05 79 views
0

正如标题所述,我试图从tensorflow中的矩阵中提取每行最高的n个元素,并将结果存储在稀疏Tensor中。将tf.nn.top_n的输出转换为稀疏矩阵

我已经能够使用tf.nn.top_n提取索引和值,但索引不遵循tf.SparseTensor所要求的约定。

具体来说,tf.nn.top_n返回与所得值矩阵(行xn)具有相同形状的col索引矩阵,而tf.SparseTensor想要一个(#非零x 2)矩阵,每行一行非零元素和包含行和列索引的列。

这些值可能是一个类似的问题,从而需要一个非零元素列表而不是值矩阵。

如何在这些索引记法方案之间快速转换?

回答

2

这是有点模块化算术。这是一个可以在矩阵上工作的例子,虽然可以循环更多的轴。

import tensorflow as tf 

def slices_to_dims(slice_indices): 
    """ 
    Args: 
    slice_indices: An [N, k] Tensor mapping to column indices. 
    Returns: 
    An index Tensor with shape [N * k, 2], corresponding to indices suitable for 
    passing to SparseTensor. 
    """ 
    slice_indices = tf.cast(slice_indices, tf.int64) 
    num_rows = tf.shape(slice_indices, out_type=tf.int64)[0] 
    row_range = tf.range(num_rows) 
    item_numbers = slice_indices * num_rows + tf.expand_dims(row_range, axis=1) 
    item_numbers_flat = tf.reshape(item_numbers, [-1]) 
    return tf.stack([item_numbers_flat % num_rows, 
        item_numbers_flat // num_rows], axis=1) 

用法示例:

dense_shape = [5, 7] 
dense_matrix = tf.random_normal(shape=dense_shape) 
top_values, top_indices = tf.nn.top_k(dense_matrix, k=2) 
sparse_indices = slices_to_dims(top_indices) 
sparse_tensor = tf.sparse_reorder(tf.SparseTensor(
    indices=sparse_indices, 
    values=tf.reshape(top_values, [-1]), 
    dense_shape=dense_shape)) 
densified_top = tf.sparse_tensor_to_dense(sparse_tensor) 
with tf.Session() as session: 
    sparse_top, dense_original, dense_selected = session.run(
     [sparse_tensor, dense_matrix, densified_top]) 
    print(dense_original) 
    print(dense_selected) 
    print(sparse_top) 

打印:

[[ 1.44056129 -1.01790774 -0.2795608 2.34854746 -2.27528405 -0.62035948 
    3.36598897] 
[ 0.7114948 -0.42564821 -0.93446779 -0.25373486 -0.51730365 0.72331643 
    -0.75625718] 
[-0.6501748 -0.92748415 -0.95409006 -0.07157528 0.80637723 -0.32177576 
    -1.4516511 ] 
[-1.081038 -0.67226124 -1.19455576 0.44537872 -0.69019234 -0.61539739 
    0.15328468] 
[ 0.43032476 -0.11295394 0.83491379 -0.67906654 0.20325914 -0.0155068 
    0.52107805]] 
[[ 0.   0.   0.   2.34854746 0.   0. 
    3.36598897] 
[ 0.7114948 0.   0.   0.   0.   0.72331643 
    0.  ] 
[ 0.   0.   0.   -0.07157528 0.80637723 0.   0.  ] 
[ 0.   0.   0.   0.44537872 0.   0. 
    0.15328468] 
[ 0.   0.   0.83491379 0.   0.   0. 
    0.52107805]] 
SparseTensorValue(indices=array([[0, 3], 
     [0, 6], 
     [1, 0], 
     [1, 5], 
     [2, 3], 
     [2, 4], 
     [3, 3], 
     [3, 6], 
     [4, 2], 
     [4, 6]]), values=array([ 2.34854746, 3.36598897, 0.7114948 , 0.72331643, -0.07157528, 
     0.80637723, 0.44537872, 0.15328468, 0.83491379, 0.52107805], dtype=float32), dense_shape=array([5, 7])) 
+0

就像一个魅力!我担心这样的方法会导致太多的开销,但看起来相当活泼。谢谢! – zergylord