2017-08-07 97 views
0

我有一个迭代器函数,它产生一批特征和标签作为numpy数组的元组。如何将numpy数组的迭代器馈送到tensorflow估计器/可评估

高清batch_iter(): 为...: 产量(np_features,np_labels)

,然后我试着喂张量估计像

# the cnn_model_fn will print out shapes of various tensor when 
# constructing the model 
classifier = learn.Estimator(
    model_fn=cnn_model_fn, model_dir="/tmp/convnet_model") 
for train_data, train_labels in batch_iter(): 
    classifier.fit(
     input_fn=lambda: (tf.constant(train_data), tf.constant(train_labels)), 
     steps=1, 
     monitors=[logging_hook]) 

的(注释)日志看起来像

conv1 shape (100, 16, 20, 32) 
pool1 shape (100, 8, 10, 32) 
conv2 shape (100, 8, 10, 64) 
pool2 shape (100, 4, 5, 64) 
onehot label shape (100, 5) 
INFO:tensorflow:Create CheckpointSaverHook. 
INFO:tensorflow:Saving checkpoints for 1 into /tmp/convnet_model/model.ckpt. # checkpoint is saved in every iteration 
INFO:tensorflow:step = 1, loss = 1618.76 
INFO:tensorflow:Loss for final step: 1618.76. 
conv1 shape (100, 16, 20, 32) # the model_fn is called in every iteration 
pool1 shape (100, 8, 10, 32) 
conv2 shape (100, 8, 10, 64) 
pool2 shape (100, 4, 5, 64) 
onehot label shape (100, 5) 
INFO:tensorflow:Create CheckpointSaverHook. 
INFO:tensorflow:Restoring parameters from /tmp/convnet_model/model.ckpt-1 # checkpoint is restored in every iteration 
INFO:tensorflow:Saving checkpoints for 2 into /tmp/convnet_model/model.ckpt. 
INFO:tensorflow:step = 2, loss = 69370.6 
INFO:tensorflow:Loss for final step: 69370.6. 
conv1 shape (100, 16, 20, 32) 
pool1 shape (100, 8, 10, 32) 
conv2 shape (100, 8, 10, 64) 
pool2 shape (100, 4, 5, 64) 
onehot label shape (100, 5) 
INFO:tensorflow:Create CheckpointSaverHook. 
INFO:tensorflow:Restoring parameters from /tmp/convnet_model/model.ckpt-2 
INFO:tensorflow:Saving checkpoints for 3 into /tmp/convnet_model/model.ckpt. 
INFO:tensorflow:step = 3, loss = 289303.0 
INFO:tensorflow:Loss for final step: 289303.0. 
... 

批处理被读取并且循环迭代时损失会下降。但是,似乎每次迭代都会保存并恢复检查点,并在每次迭代中调用model_fn。所以我觉得这是不对的。

将迭代器提供给Estimator/Evaluable的正确方法是什么?

回答

1

in your input_fn you can use tf.contrib.training.python_input