我有一个张量,它是[100 X 16 X 16]
。我想获得张量的对角线元素以获得形状[100 X 16]
的张量。我尝试了以下内容:Tensorflow:如何获得多个矩阵的对角线(批处理模式)
#sum_cov
是[100 X 16 X 16]
和diagonal_elements
预计为[100 X 16]
。
diagonal_elements = tf.diag_part(sum_cov)
不过,我得到以下错误:
Input must have even rank <= 6, input rank is 3 for 'DiagPart'
可有人请告诉我如何实现这一目标?
看起来像'tf.matrix_diag_part'做你想要的。 –