2017-03-03 105 views
0

我的代码是如下:我不能保存,并与tensorflow占位恢复一个变量,由于

import tensorflow as tf 
import numpy as np 
def add_layer(input): 
    v2 = tf.Variable(tf.random_normal([2, 2], dtype=tf.float32, name='v2')) 
    tf.add_to_collection('h0_v2',v2) 
    output=tf.matmul(input,v2) 
    return output 
x1=tf.placeholder(tf.float32) 
outputs=add_layer(x) 
tf.add_to_collection('outputs', outputs) 
saver = tf.train.Saver() 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    x1=np.random.random([2, 2]) 
    print(sess.run(outputs,feed_dict={x:x1})) 
    save_path = saver.save(sess, './model.ckpt') 
    print("model saved in file:", save_path) 

,然后另一码是拼命地跑:

import tensorflow as tf 
import numpy as np 
sess = tf.Session() 
saver = tf.train.import_meta_graph('./model.ckpt.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
x2=np.random.random([2, 2]) 
print(sess.run(tf.get_collection('outputs',feed_dict={x:x2}))) 
print('model is loaded') 
sess.close() 

然后电脑告诉我'x'没有定义,我不知道什么是错的。

+0

但第一代码无故障运行? – CrisH

回答

0

我会说这样做:

import tensorflow as tf 
import numpy as np 
sess = tf.Session() 
saver = tf.train.Saver() 
saver.restore(sess, './model.ckpt') 
x2=np.random.random([2, 2]) 
print(sess.run(tf.get_collection('outputs',feed_dict={x:x2}))) 
print('model is loaded') 
sess.close() 

我发现这对tensorflow的网站。希望能帮助到你。

+0

非常感谢! – quan

+0

它现在在工作吗? :) – CrisH

+0

是的,但是当我运行它两次,错误将被抛出,说:ValueError:至少有两个变量具有相同的名称:变量/ Adadelta_1。 – quan

0

我找到解决问题的方法:

import tensorflow as tf 
import numpy as np 
def add_layer(input): 
    #v1 = tf.Variable(np.random.random([2, 2]), dtype=tf.float32, name='v1') 
    v2 = tf.Variable(tf.random_normal([2, 2], dtype=tf.float32, name='v2')) 
    tf.add_to_collection('h0_v2',v2) 
    output=tf.matmul(input,v2) 
    return output 
x=tf.placeholder(tf.float32) 
outputs=add_layer(x) 
saver = tf.train.Saver() 
sess = tf.Session() 
saver = tf.train.import_meta_graph('./model.ckpt.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
x2=np.random.random([2, 2]) 
print(sess.run(outputs,feed_dict={x:x2})) 
print('model is loaded') 
sess.close() 
相关问题