2017-10-20 273 views
0

我有以下代码:Tensorflow Dataset.from_tensor_slices时间太长

data = np.load("data.npy") 
print(data) # Makes sure the array gets loaded in memory 
dataset = tf.contrib.data.Dataset.from_tensor_slices((data)) 

文件"data.npy"为3.3 GB。用numpy读取文件需要几秒钟,但是接下来创建tensorflow数据集对象的那一行需要很长时间才能执行。这是为什么?它在底下做了什么?

回答

2

引用此answer

一个npznp.load只返回一个文件加载器,而不是实际的数据。这是一个'懒惰的加载程序',只有在访问时加载特定的数组。

这就是为什么它很快。

编辑1:扩大多一点这样的回答,从tensorflow's documentation另一句名言:

如果所有输入数据存放在内存中,最简单的方法来创建他们Dataset是转换他们到tf.Tensor对象并使用Dataset.from_tensor_slices()

这适用于小数据集,但浪费内存---因为数组内容将被复制多次---并可能运行到tf.GraphDef协议缓冲区的2GB限制。

该链接还显示如何有效地做到这一点。

+0

如果我尝试打印'data',以便确保它实际上被加载,它仍然需要几秒钟,而'Dataset'需要几分钟。 – niko

+0

它不一定打印所有数据。打印与否并不能确保它“实际上被加载”。我不是tensorflow方面的专家,只是看着'from_tensor_slices'循环遍历整个数据集的代码(并且速度相当慢),这肯定会*加载所有数据。海事组织这可能可以加快,但公平我没有尝试过。在某些情况下,如果您的计算机在内存中占用3.3GB的空间,您可能只需要投入更多硬件。 – Iguananaut

+0

我更新了我的答案,给你更多的细节。 –