2016-08-15 215 views
0

我有一个Theano tensor3(即,3维阵列)x移基于偏移矢量tensor3元素的位置

[[[ 0 1 2 3] 
    [ 4 5 6 7] 
    [ 8 9 10 11]] 

[[12 13 14 15] 
    [16 17 18 19] 
    [20 21 22 23]]] 

以及一个Theano向量(即,一维阵列)y,我们将参考作为“偏移”向量,因为它指定了期望的偏移:

[2, 1] 

欲转移基于矢量yx元素的位置,从而使输出如下(在第二维进行变速):

[[[ a b c d] 
    [ e f g h] 
    [ 0 1 2 3]] 

[[ i j k l] 
    [12 13 14 15] 
    [16 17 18 19]]] 

其中ab,...,l可以是任何数目。

例如,一个有效的输出可能是:

[[[ 0 0 0 0] 
    [ 0 0 0 0] 
    [ 0 1 2 3]] 

[[ 0 0 0 0] 
    [12 13 14 15] 
    [16 17 18 19]]] 

另一种有效的输出可能是:

[[[ 4 5 6 7] 
    [ 8 9 10 11] 
    [ 0 1 2 3]] 

[[20 21 22 23] 
    [12 13 14 15] 
    [16 17 18 19]]] 

我知道的功能theano.tensor.roll(x, shift, axis=None)的,但是shift只能拿一个标量作为输入,即它移动所有具有相同偏移量的元素。

例如,代码:

import theano.tensor 
from theano import shared 
import numpy as np 

x = shared(np.arange(24).reshape((2,3,4))) 
print('theano.tensor.roll(x, 2, axis=1).eval(): \n{0}'. 
     format(theano.tensor.roll(x, 2, axis=1).eval())) 

输出:

theano.tensor.roll(x, 2, axis=1).eval(): 
[[[ 4 5 6 7] 
    [ 8 9 10 11] 
    [ 0 1 2 3]] 

[[16 17 18 19] 
    [20 21 22 23] 
    [12 13 14 15]]] 

这不是我想要的。

如何根据偏移矢量移动tensor3元素的位置? (注意在这个例子中提供的代码中,为了方便起见,tensor3是一个共享变量,但在我的实际代码中它将是一个符号变量)

回答

0

我找不到任何专用的功能,所以我简单地结束了使用theano.scan

import theano 
import theano.tensor 

from theano import shared 
import numpy as np 

y = shared(np.array([2,1])) 
x = shared(np.arange(24).reshape((2,3,4))) 
print('x.eval():\n{0}\n'.format(x.eval())) 

def shift_and_reverse_row(matrix, y):  
    ''' 
    Shift and reverse the matrix in the direction of the first dimension (i.e., rows) 
    matrix: matrix 
    y: scalar 
    ''' 
    new_matrix = theano.tensor.zeros_like(matrix) 
    new_matrix = theano.tensor.set_subtensor(new_matrix[:y,:], matrix[y-1::-1,:]) 
    return new_matrix 

new_x, updates = theano.scan(shift_and_reverse_row, outputs_info=None, 
          sequences=[x, y[::-1]]) 
new_x = new_x[:, ::-1, :] 
print('new_x.eval(): \n{0}'.format(new_x.eval())) 

输出:

x.eval(): 
[[[ 0 1 2 3] 
    [ 4 5 6 7] 
    [ 8 9 10 11]] 

[[12 13 14 15] 
    [16 17 18 19] 
    [20 21 22 23]]] 

new_x.eval(): 
[[[ 0 0 0 0] 
    [ 0 0 0 0] 
    [ 0 1 2 3]] 

[[ 0 0 0 0] 
    [12 13 14 15] 
    [16 17 18 19]]]