2017-08-31 219 views
0

这是一个如何保存和恢复训练模型的例子。 希望这会对初学者有所帮助。tensorflow:如何保存/恢复训练有素的模型

生成1个带relu激活函数的隐层神经网络。 (听说relu已被证明比sigmoid好多了,特别是对于隐层数量众多的神经网络。)

训练数据显然是异或。

火车和保存 “tf_train_save.py”

import tensorflow as tf 
import numpy as np 

x = np.matrix([[0, 0], [0, 1], [1, 0], [1, 1]]) 
y = np.matrix([[0], [1], [1], [0]]) 

n_batch = x.shape[0] 
n_input = x.shape[1] 
n_hidden = 5 
n_classes = y.shape[1] 

X = tf.placeholder(tf.float32, [None, n_input], name="X") 
Y = tf.placeholder(tf.float32, [None, n_classes], name="Y") 

w_h = tf.Variable(tf.random_normal([n_input, n_hidden], stddev=0.01), tf.float32, name="w_h") 
w_o = tf.Variable(tf.random_normal([n_hidden, n_classes], stddev=0.01), tf.float32, name="w_o") 

l_h = tf.nn.relu(tf.matmul(X, w_h)) 
hypo = tf.nn.relu(tf.matmul(l_h, w_o), name="output") 

cost = tf.reduce_mean(tf.square(Y-hypo)) 
train = tf.train.GradientDescentOptimizer(0.1).minimize(cost) 

init = tf.global_variables_initializer() 

with tf.Session() as sess: 
    sess.run(init) 

    for epoch in range(1000): 
     for i in range(4): 
      sess.run(train, feed_dict = {X:x[i,:], Y:y[i,:]}) 

    result = sess.run([hypo, tf.floor(hypo+0.5)], feed_dict={X:x}) 

    print(*result[0]) 
    print(*result[1]) 

    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"]) 
    tf.train.write_graph(output_graph_def, "./logs/mp_logs", "test.pb", False) 

负荷 “tf_load.py”

import tensorflow as tf 
from tensorflow.python.platform import gfile 
import numpy as np 

x = np.matrix([[0, 0], [0, 1], [1, 0], [1, 1]]) 
y = np.matrix([[0], [1], [1], [0]]) 

with gfile.FastGFile("./logs/mp_logs/test.pb",'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
    tf.import_graph_def(graph_def, name='') 

with tf.Session() as sess: 
    X = sess.graph.get_tensor_by_name("X:0") 
    print(X) 
    output = sess.graph.get_tensor_by_name("output:0") 
    print(output) 

    tf.global_variables_initializer().run() 

    result = sess.run([output, tf.floor(output+0.5)], feed_dict={X:x}) 

    print(*result[0]) 
    print(*result[1]) 

会有更简单的方法?

+0

您的问题标题似乎不符合您的要求。假设题目问题,你的代码是否符合你的期望?我想知道加载脚本中的初始化。 –

+0

你的力量保存你的权重变量,因为你加载它们,所以你的代码是不正确的。看看这个https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model?rq=1 –

+0

@EricPlaton它的工作原理。我只是好奇,如果有更简单的方法。像...保存张量名称一样。 –

回答

0

您使用的是convert_variables_to_constants,所以您在训练方面确实很棒。对于路人来说,该API出现在v1.0中(如果我在跟踪API后没有弄错)。

在负载方面,我认为最小代码是一个命令更短。鉴于您已将所有变量转换为常量,因此在恢复时没有变量可以初始化。所以行:

tf.global_variables_initializer().run() 

什么都不做。从v1.3的docs开始:

但是,如果var_list为空,该函数仍会返回可以运行的Op。该操作只是没有效果。

加载脚本没有全局变量,并且因为tf.global_variables_initializer()等于tf.variables_initializer(tf.global_variables()),所以该操作是空操作。

+1

我期待恢复时不处理张量名称,如'输入'和'输出'。找不到例子。 我认为读取VGGish源代码是可能的。但我误解了它。他们只是做了一个定义图的函数,并在生成和恢复函数中使用它们。 猜猜我必须做同样的事情,一起处理图形文件和py文件 –

相关问题