2017-08-07 121 views
0

我已经在Tensorflow中构建了一个模型,我已经训练了它。现在我想处理输出,所以我想将检查点,Meta和所有其他文件加载到tensorlow中。将经过训练的模型加载回张量流

我用下面的代码来训练模型:

# Logging 
merged = tf.summary.merge_all() 
train_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/train') 
test_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/test') 
validate_writer = tf.summary.FileWriter(FLAGS.summary_dir + '/validate') 
writer = tf.summary.FileWriter(FLAGS.summary_dir, sess.graph) 
saver = tf.train.Saver() # for storing the best network 

# Initialize variables 
init = tf.global_variables_initializer() 
sess.run(init) 

# Best validation accuracy seen so far 
bestValidation = -0.1 

# Training loop 
coord = tf.train.Coordinator() # coordinator for threads 
threads = tf.train.start_queue_runners(coord = coord, sess=sess) # start queue thread 

# Training loop 
for i in range(FLAGS.maxIter): 
    xTrain, yTrain = sess.run(data_batch) 
    sess.run(train_step, feed_dict={x_data: xTrain, y_target: np.transpose([yTrain])}) 
    summary = sess.run(merged, feed_dict={x_data: xTrain, y_target: np.transpose([yTrain])}) 
    train_writer.add_summary(summary, i) 
    if ((i + 1) % 10 == 0): 
     print("Iteration:", i + 1, "/", FLAGS.maxIter) 
     summary = sess.run(merged, feed_dict={x_data: dataTest.data, y_target: np.transpose([dataTest.target])}) 
     test_writer.add_summary(summary, i) 
     currentValidation, summary = sess.run([accuracy, merged], feed_dict={x_data: dataTest.data, 
                      y_target: np.transpose(
                       [dataTest.target])}) 
    validate_writer.add_summary(summary, i) 
    if (currentValidation > bestValidation and currentValidation <= 0.9): 
     bestValidation = currentValidation 
     saver.save(sess=sess, save_path=FLAGS.summary_dir + '/bestNetwork') 
     print("\tbetter network stored,", currentValidation, ">", bestValidation) 

coord.request_stop() # ask threads to stop 
coord.join(threads) # wait for threads to stop 

现在我想加载模型回Tensorflow。我希望能够做一些事情:

  • 使用我已经为训练和测试数据集创建的输出。
  • 将新数据加载到模型中,然后可以使用相同的权重生成新的输出。

我使用下面的代码回加载模型到tensorflow试过,但它不工作:

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph(FLAGS.summary_dir + '/bestNetwork.meta') 
    saver.restore(sess,tf.train.latest_checkpoint(FLAGS.summary_dir + '/checkpoint')) 

运行的代码时,我收到以下错误:

TypeError:期望的字节,找不到的类型

正如我已经说明的那样,我使用tf.train.import_meta_graph()函数加载了上一节中的元图,然后使用检查点部分加载了权重。那么,为什么这不起作用?

回答

0

您将模型保存为bestNetwork。试试这个:

saver.restore(sess,tf.train.latest_checkpoint(FLAGS.summary_dir + '/**bestNetwork**'))