2017-10-10 67 views
0

我遇到了我的numpy阵列,它的大小(29000,200,1024)(7Go)有问题。它是我数据集图像的特征。Tensorflow:tf.gather不适用于大阵列

一旦加载,我的函数就会收到建立当前批量索引的索引。 不幸的是,使用:

tf.gather(array, indices) 

冻结。虽然打印例如数组[0]立即工作。 我试图改变我的numpy数组与convert_to_tensor,所以我可以直接使用array_tensor(indice)但是再次,convert_to_tensor导致内存限制错误。

任何解决方法?

非常感谢您

回答

1

直接传递numpy的阵列到TF运建设API将其转换为tf.constant运算包含在OP定义数据,所以你内嵌了整个事情变成GraphDef,受2GB GraphDef限制。

要避免这种情况,请创建var=tf.Variable(my_placeholder)并通过运行var.initializer, feed_dict={my_placeholder: np_array}来初始化此变量。这将numpy数组数据直接放入变量存储中。

+0

非常感谢您的全力支持。 不幸的是,我没有我的会话处理程序呢。我可以从图表中稍后从我的图表中获取我的变量,但是我没有占位符,不再需要Feed字典......任何想法? tyvm提前 – Mickey

+0

将占位符保存到全局变量中? –

+0

感谢这似乎工作!不幸的是,打开numpy数组并将其放入张量会导致OOM错误。我需要以另一种方式进行批量生产。 – Mickey