2011-03-08 76 views
2

我有一个在numpy中乘以矩阵的具体问题。 下面是一个例子:numpy tensordot相关问题

P=np.arange(30).reshape((-1,3)) 
array([[ 0, 1, 2], 
    [ 3, 4, 5], 
    [ 6, 7, 8], 
    [ 9, 10, 11], 
    [12, 13, 14], 
    [15, 16, 17], 
    [18, 19, 20], 
    [21, 22, 23], 
    [24, 25, 26], 
    [27, 28, 29]]) 

欲通过它的转置,以便获得3×3矩阵的每一行, 例如用于第一行相乘的每一行:

P[0]*P[0][:,np.newaxis] 
array([[0, 0, 0], 
    [0, 1, 2], 
    [0, 2, 4]]) 

和存储结果在3 d矩阵M:

M=np.zeros((10,3,3)) 
for i in range(10): 
    M[i] = P[i]*P[i][:,np.newaxis] 

我觉得可能是一个办法做到这一点不循环,可能与张量点,但无法找到它。

有人有想法吗?

回答

3

这只是简单的像这样:

In []: P= arange(30).reshape(-1, 3) 
In []: P[:, :, None]* P[:, None, :] 
Out[]: 
array([[[ 0, 0, 0], 
     [ 0, 1, 2], 
     [ 0, 2, 4]], 
     [[ 9, 12, 15], 
     [ 12, 16, 20], 
     [ 15, 20, 25]], 
     [[ 36, 42, 48], 
     [ 42, 49, 56], 
     [ 48, 56, 64]], 
     #... 
     [[729, 756, 783], 
     [756, 784, 812], 
     [783, 812, 841]]])  
In []: P[1]* P[1][:, None] 
Out[]: 
array([[ 9, 12, 15], 
     [12, 16, 20], 
     [15, 20, 25]]) 
+1

太棒了,正是我在找的东西,有没有一种简单的方法来理解我应该用None来扩展哪个索引? – 2011-03-08 16:58:29

+0

@Andrea Z:对于这样的方法论,我会先试着弄清楚'numpy'广播是如何运作的。然后,只是意识到你正在寻找一个'形状=(10,3,3)',或多或少'自然'地发现'(P [:,,无] * P [:,无, :])。shape ==(10,3,3)'。也许不是一个最好的描述,但我主要的观点是;熟悉广播如何工作。现在随意问'关于广播的内部'的其他问题!谢谢 – eat 2011-03-08 18:24:31

1

因为我喜欢stride_tricks,所以我会用它。我相信还有其他的方法。

更改数组的步幅和形状,以便将其展开为3D。你可以很容易地用P的“转置”版本做同样的事情,但是在这里我只是重新塑造它,让广播规则将它拉伸到另一个维度。

P=np.arange(30).reshape((-1,3)) 
astd = numpy.lib.stride_tricks.as_strided 
its = P.itemsize 
M = astd(P,(10,3,3),(its*3,its,0))*P.reshape((10,1,3)) 

我要添加引用this post因为它是stride_tricks.as_strided一个很好的详细说明。

+0

+1 stride_tricks真的很棒,但是太复杂了! – 2011-03-08 17:00:03

0

这部分解决了使用tensordot()问题,

from numpy import arange,tensordot 

P = arange(30).reshape((-1,3)) 

i = 3 

T = tensordot(P,P,0)[:,:,i,:] 

print T[i] 
print tensordot(P[i],P[i],0) 

T包含了所有你想要的(及以上)的产品,它只是一个解压的问题他们。

+0

尽管OP推测使用'tensordot(。)'的​​可能性,但他的代码中没有任何内容需要它。您的解决方案似乎在执行时间和内存消耗方面都非常昂贵。你有没有制定任何基准来比较你的解决方案与其他的?谢谢 – eat 2011-03-08 14:45:13

+0

@eat:真的,这是昂贵的记忆方式。在这个例子中,对于一个10乘3的数组,执行时间大约是使用'tensordot()'的两倍。随着阵列尺寸的增加,它会变得越来越糟糕。 – lafras 2011-03-08 20:21:38

+1

我没有明确指出这一点,但由于完全不同的原因,tensordot可能并不是这样做的方式。如果你把'P'视为2级张量,那么只有3个选项存在'P'与其自身的乘积,1)所有的指标都取消给你一个0级张量(一个标量),2)1一组索引取消,你留下一个2级张量(矩阵),或3)他们中没有一个取消,你剩下一个4级张量。 OP想要一个三级张量作为答案,除非你分级四级产品,否则'tensordot()'不会给你。 – lafras 2011-03-08 20:21:55