2017-03-08 66 views
4

我有3张
X形状(1, c, h, w),假设(1, 20, 40, 50)
Fx形状(num, w, N),假设(1000, 50, 10)
Fy形状(num, N, h),假设(1000, 10, 40)MATMUL不同等级

我想要做的就是Fy * (X * Fx)*意味着matmul
X * Fx形状(num, c, h, N),假设(1000, 20, 40, 10)
Fy * (X * Fx)形状(num, c, N, N),假设(1000, 20, 10, 10)

我使用tf.tiletf.expand_dims
但我认为它使用了大量的内存(tile复制数据吧?),并缓慢
试图找到更好的办法那速度更快,占用内存小,以实现

# X: (1, c, h, w) 
# Fx: (num, w, N) 
# Fy: (num, N, h) 

X = tf.tile(X, [tf.shape(Fx)[0], 1, 1, 1]) # (num, c, h, w) 
Fx_ex = tf.expand_dims(Fx, axis=1) # (num, 1, w, N) 
Fx_ex = tf.tile(Fx_ex, [1, c, 1, 1]) # (num, c, w, N) 
tmp = tf.matmul(X, Fxt_ex) # (num, c, h, N) 

Fy_ex = tf.expand_dims(Fy, axis=1) # (num, 1, N, h) 
Fy_ex = tf.tile(Fy_ex, [1, c, 1, 1]) # (num, c, N, h) 
res = tf.matmul(Fy_ex, tmp) # (num, c, N, N) 

回答

2

的情况为mythical einsum,我想:

>>> import numpy as np 
>>> X = np.random.rand(1, 20, 40, 50) 
>>> Fx = np.random.rand(100, 50, 10) 
>>> Fy = np.random.rand(100, 10, 40) 
>>> np.einsum('nMh,uchw,nwN->ncMN', Fy, X, Fx).shape 
(100, 20, 10, 10) 

它应该在tfnumpy几乎相同(在一些tf版本中,我看到使用大写索引是不允许的)。尽管如此,如果你以前从未见过符号,这肯定会超过一个不可读的正则表达式。

+0

是的,我从来没有见过这个,有点难以理解,想了解这个想法 – xxi

+0

awwwwwesome,它巨大的速度提高,非常感谢 – xxi

0

对于otherone可能感兴趣
我觉得@phg的答案也许工作
但在我的情况下numhw是动态的,即None
所以tf.einsum在tensorflow R1.0会引发错误,因为有更多的比一个张量

幸好一个None形状,有一个issuepull request
似乎可以处理的情况,有一个以上的None SHA PE
需要从源(主分支)建立
后,我重新构建tensorflow

BTW我将报告结果,在tf.einsum只接受小写

报告
是,最新版本张量流(主枝)接受动态形状为tf.einsum
而且使用后速度大幅提升tf.einsum,真是太棒了