
Batching of graph_cnn in TensorFlow

I would like to use graph_cnn (Defferrard et al. 2016) on inputs with a varying number of nodes. The author provides example code (see graph_cnn). Below is what I believe to be the crucial part of the code:

def chebyshev5(self, x, L, Fout, K):
    N, M, Fin = x.get_shape()
    N, M, Fin = int(N), int(M), int(Fin)
    # Rescale Laplacian and store as a TF sparse tensor. Copy to not modify the shared L.
    L = scipy.sparse.csr_matrix(L)
    L = graph.rescale_L(L, lmax=2)
    L = L.tocoo()
    indices = np.column_stack((L.row, L.col))
    L = tf.SparseTensor(indices, L.data, L.shape)
    L = tf.sparse_reorder(L)
    # Transform to Chebyshev basis
    x0 = tf.transpose(x, perm=[1, 2, 0])  # M x Fin x N
    x0 = tf.reshape(x0, [M, Fin*N])  # M x Fin*N
    x = tf.expand_dims(x0, 0)  # 1 x M x Fin*N
    def concat(x, x_):
        x_ = tf.expand_dims(x_, 0)  # 1 x M x Fin*N
        return tf.concat([x, x_], axis=0)  # K x M x Fin*N
    if K > 1:
        x1 = tf.sparse_tensor_dense_matmul(L, x0)
        x = concat(x, x1)
    for k in range(2, K):
        x2 = 2 * tf.sparse_tensor_dense_matmul(L, x1) - x0  # M x Fin*N
        x = concat(x, x2)
        x0, x1 = x1, x2
    x = tf.reshape(x, [K, M, Fin, N])  # K x M x Fin x N
    x = tf.transpose(x, perm=[3, 1, 2, 0])  # N x M x Fin x K
    x = tf.reshape(x, [N*M, Fin*K])  # N*M x Fin*K
    # Filter: Fin*Fout filters of order K, i.e. one filterbank per feature pair.
    W = self._weight_variable([Fin*K, Fout], regularization=False)
    x = tf.matmul(x, W)  # N*M x Fout
    return tf.reshape(x, [N, M, Fout])  # N x M x Fout
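Note that in this code the whole batch shares one graph, so a single sparse L of size M x M serves every sample. Because tf.sparse_tensor_dense_matmul can only multiply a 2-D sparse matrix by a 2-D dense one, the batch is folded into the column dimension; the key two lines are

x0 = tf.reshape(tf.transpose(x, perm=[1, 2, 0]), [M, Fin*N])  # M x (Fin*N)
x1 = tf.sparse_tensor_dense_matmul(L, x0)                     # M x (Fin*N)

which apply the shared operator to all N samples in one 2-D multiplication. It is exactly this trick that breaks when L differs per sample.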

Essentially, I think what it does can be simplified into something like the following (each term is strictly the Chebyshev polynomial T_k(L)·x per the recurrence in the code, but powers of L convey the idea):

return concat{(L^k x) for k = 0 to K-1} * W

where

x is the input, of size N x M x Fin (size variable in any batch);

L is an array of operators on x, one per sample, each of size M x M (size variable in any batch);

W is the neural-network weight to be optimized, of size Fin x K x Fout;

N is the number of samples in a batch (fixed for any batch);

M is the number of nodes in the graph (variable in any batch);

Fin is the number of input features (fixed for any batch);

Fout is the number of output features (fixed for any batch);

K is a constant representing the number of steps (hops) in the graph.

For a single sample, the above code works. However, since both x and L have variable length for each sample in a batch, I don't know how to make it work for a batch of samples.
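To make the intended computation concrete, here is a minimal NumPy sketch of the single-sample case as I understand it (cheb_filter_single is a hypothetical name, and L is assumed to be already rescaled):

import numpy as np

def cheb_filter_single(x, L, W, K):
    # x: (M, Fin) node features; L: (M, M) rescaled Laplacian; W: (Fin*K, Fout) weights
    M, Fin = x.shape
    xs = [x]                                  # T_0(L) x = x
    if K > 1:
        xs.append(L @ x)                      # T_1(L) x = L x
    for _ in range(2, K):
        xs.append(2 * (L @ xs[-1]) - xs[-2])  # Chebyshev recurrence
    stacked = np.stack(xs, axis=-1)           # (M, Fin, K), the M x Fin x K layout from above
    return stacked.reshape(M, Fin * K) @ W    # (M, Fout)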

Answer


tf.matmul currently (v1.4) supports batch matrix multiplication only over the lowest two dims of dense tensors. If either of the input tensors is sparse, it raises a dimension-mismatch error, and tf.sparse_tensor_dense_matmul cannot be applied to batched inputs either. My current solution is therefore to move all the L-preparation steps out of the function, pass L in as a dense tensor (shape: [N, M, M]), and use tf.matmul to perform the batched matrix multiplication.
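The preparation that moves out of the function might look like the following sketch (my own illustration, not from the original answer; prepare_batch_L is a hypothetical helper, graph.rescale_L is from the author's repo, and every Laplacian within one batch is assumed to share the same M):

import numpy as np
import scipy.sparse
# `graph` is the lib/graph module from the author's repo (provides rescale_L)

def prepare_batch_L(laplacians, lmax=2):
    # laplacians: list of N scipy sparse M x M Laplacians (same M within this batch)
    rescaled = []
    for L in laplacians:
        L = scipy.sparse.csr_matrix(L)      # copy so the shared L is not modified
        L = graph.rescale_L(L, lmax=lmax)   # same rescaling as in chebyshev5
        rescaled.append(L.toarray())        # densify so tf.matmul can batch over N
    return np.stack(rescaled, axis=0)       # [N, M, M]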

Here is my modified code:

'''
chebyshev5_batch
Purpose:
    perform the graph filtering on the given layer
Args:
    x: the batch of inputs for the given layer,
        dense tensor, size: [N, M, Fin]
    L: the batch of sorted Laplacians of the given layer (tf.Tensor),
        in dense format, size: [N, M, M]
    Fout: the number of output features on the given layer
    K: the filter size, i.e. the number of hops, on the given layer
    lyr_num: the index of the original Laplacian layer (starting from 0)
Output:
    y: the filtered output from the given layer
'''
def chebyshev5_batch(x, L, Fout, K, lyr_num):
    N, M, Fin = x.get_shape()
    N, M, Fin = int(N), int(M), int(Fin)  # concrete ints for the tf.reshape calls below
    # The Laplacian rescaling and sparse-tensor preparation from the original
    # chebyshev5 (scipy.sparse.csr_matrix, graph.rescale_L, tf.SparseTensor,
    # tf.sparse_reorder) is done outside this function; L arrives here already
    # rescaled, dense, and batched.

    def expand_concat(orig, new):
        new = tf.expand_dims(new, 0)  # 1 x N x M x Fin
        return tf.concat([orig, new], axis=0)  # (shape(orig)[0] + 1) x N x M x Fin

    # L:    N x M x M
    # x0:   N x M x Fin
    # L*x0: N x M x Fin

    x0 = x  # N x M x Fin
    stk_x = tf.expand_dims(x0, axis=0)  # 1 x N x M x Fin (eventually K x N x M x Fin, if K > 1)

    if K > 1:
        x1 = tf.matmul(L, x0)  # N x M x Fin; tf.matmul batches over the leading dim
        stk_x = expand_concat(stk_x, x1)
    for kk in range(2, K):
        # Chebyshev recurrence x_k = 2 L x_{k-1} - x_{k-2}, as in chebyshev5 above
        x2 = 2 * tf.matmul(L, x1) - x0  # N x M x Fin
        stk_x = expand_concat(stk_x, x2)
        x0 = x1
        x1 = x2

    # now stk_x has the shape of K x N x M x Fin;
    # transpose to the shape of N x M x Fin x K
    ## source positions       1  2  3  0
    stk_x_transp = tf.transpose(stk_x, perm=[1, 2, 3, 0])
    stk_x_forMul = tf.reshape(stk_x_transp, [N*M, Fin*K])


    #W = self._weight_variable([Fin*K, Fout], regularization=False) 
    W_initial = tf.truncated_normal_initializer(0, 0.1) 
    W = tf.get_variable('weights_L_'+str(lyr_num), [Fin*K, Fout], tf.float32, initializer=W_initial) 
    tf.summary.histogram(W.op.name, W) 

    y = tf.matmul(stk_x_forMul, W) 
    y = tf.reshape(y, [N, M, Fout]) 
    return y
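For completeness, a call might look like this (a hypothetical sketch; the sizes are made up, and every sample within one batch must share the same M):

N, M, Fin, Fout, K = 32, 100, 8, 16, 5
x_ph = tf.placeholder(tf.float32, [N, M, Fin])  # batch of node features
L_ph = tf.placeholder(tf.float32, [N, M, M])    # batch of dense, rescaled Laplacians
y = chebyshev5_batch(x_ph, L_ph, Fout, K, lyr_num=0)  # [N, M, Fout]

with the L_ph feed built, for example, by something like the prepare_batch_L sketch above.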