2017-02-10

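Block diagonal matrix in Tensorflow

Suppose I have several tensors A_i of different shapes [N_i, N_i]. Is it possible in TensorFlow to create a block-diagonal matrix with these matrices on the diagonal? The only way I can think of right now is to build it entirely by myself, by stacking the matrices and padding them with tf.zeros.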
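The manual approach described in the question (surround each block with zeros, then stack) can be sketched outside TensorFlow to make the bookkeeping concrete. This is a minimal NumPy sketch, not TensorFlow code; block_diagonal_manual is a hypothetical name, and it assumes plain unbatched 2D blocks.

```python
import numpy as np


def block_diagonal_manual(matrices):
    """Hypothetical sketch: build a block-diagonal matrix from explicit zeros.

    Each block becomes one horizontal band: zeros | block | zeros.
    The bands are then stacked vertically.
    """
    matrices = [np.asarray(m, dtype=np.float32) for m in matrices]
    total_cols = sum(m.shape[1] for m in matrices)
    bands = []
    offset = 0  # column at which the current block starts
    for m in matrices:
        n_rows, n_cols = m.shape
        # Zeros covering the columns used by earlier blocks.
        left = np.zeros((n_rows, offset), dtype=m.dtype)
        # Zeros covering the columns used by later blocks.
        right = np.zeros((n_rows, total_cols - offset - n_cols), dtype=m.dtype)
        bands.append(np.concatenate([left, m, right], axis=1))
        offset += n_cols
    # Stack the bands along the row axis.
    return np.concatenate(bands, axis=0)


result = block_diagonal_manual([[[1.]], [[1., 2.], [3., 4.]]])
print(result)
```

The same column-offset arithmetic is what a TensorFlow version has to reproduce with tf.zeros/tf.concat or tf.pad, where the block widths may only be known at run time.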

Answer


I agree it would be nice to have a C++ op for this. In the meantime, here's what I do (getting the static shape information right is a bit fiddly):

import tensorflow as tf

def block_diagonal(matrices, dtype=tf.float32):
    r"""Constructs block-diagonal matrices from a list of batched 2D tensors.

    Args:
      matrices: A list of Tensors with shape [..., N_i, M_i] (i.e. a list of
        matrices with the same batch dimension).
      dtype: Data type to use. The Tensors in `matrices` must match this dtype.
    Returns:
      A matrix with the input matrices stacked along its main diagonal, having
      shape [..., \sum_i N_i, \sum_i M_i].

    """
    matrices = [tf.convert_to_tensor(matrix, dtype=dtype) for matrix in matrices]
    # Accumulate the static (graph-construction-time) shape of the result.
    blocked_rows = tf.Dimension(0)
    blocked_cols = tf.Dimension(0)
    batch_shape = tf.TensorShape(None)
    for matrix in matrices:
        full_matrix_shape = matrix.get_shape().with_rank_at_least(2)
        batch_shape = batch_shape.merge_with(full_matrix_shape[:-2])
        blocked_rows += full_matrix_shape[-2]
        blocked_cols += full_matrix_shape[-1]
    # Total number of columns, computed dynamically at run time.
    ret_columns_list = []
    for matrix in matrices:
        matrix_shape = tf.shape(matrix)
        ret_columns_list.append(matrix_shape[-1])
    ret_columns = tf.add_n(ret_columns_list)
    # Pad each matrix with zero columns on the left (columns of earlier
    # blocks) and on the right (columns of later blocks), then stack the
    # padded blocks along the row axis.
    row_blocks = []
    current_column = 0
    for matrix in matrices:
        matrix_shape = tf.shape(matrix)
        row_before_length = current_column
        current_column += matrix_shape[-1]
        row_after_length = ret_columns - current_column
        row_blocks.append(tf.pad(
            tensor=matrix,
            paddings=tf.concat(
                [tf.zeros([tf.rank(matrix) - 1, 2], dtype=tf.int32),
                 [(row_before_length, row_after_length)]],
                axis=0)))
    blocked = tf.concat(row_blocks, -2)
    # tf.pad with tensor-valued paddings loses the static shape; restore it.
    blocked.set_shape(batch_shape.concatenate((blocked_rows, blocked_cols)))
    return blocked

For example:

blocked_tensor = block_diagonal(
    [tf.constant([[1.]]),
     tf.constant([[1., 2.], [3., 4.]])])

with tf.Session(): 
    print(blocked_tensor.eval()) 

prints:

[[ 1.  0.  0.]
 [ 0.  1.  2.]
 [ 0.  3.  4.]]
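The function above targets the TensorFlow 1.x graph API (tf.Dimension, tf.Session). For square blocks, which is the case in the question, the linear-operator API offers a shorter route. This is a minimal sketch assuming TensorFlow 2.x with eager execution and every block square; block_diagonal_v2 is a hypothetical name, and it does not replace the static-shape handling of block_diagonal.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def block_diagonal_v2(matrices):
    """Hypothetical helper: place square matrices along the main diagonal."""
    # Wrap each square [N_i, N_i] tensor as a LinearOperator.
    operators = [tf.linalg.LinearOperatorFullMatrix(m) for m in matrices]
    # LinearOperatorBlockDiag represents the block-diagonal operator;
    # to_dense() materializes it as an ordinary Tensor, zeros included.
    return tf.linalg.LinearOperatorBlockDiag(operators).to_dense()


blocked = block_diagonal_v2(
    [tf.constant([[1.]]), tf.constant([[1., 2.], [3., 4.]])])
print(blocked.numpy())
```

Materializing with to_dense() stores all the zeros; if the result is only used in products, keeping the operator itself avoids that.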

Thanks Allen, it works like a charm! – Fork2