2016-07-24 59 views
1

我在 TensorFlow 中有一个类,其中包含权重和文档嵌入,我会用它来进行训练和验证。我的问题是:在 TensorFlow 会话中处理验证集时,能否只重用训练得到的权重而不重用嵌入,并让模型为验证集学习新的文档嵌入?代码片段如下。如何只重用 TensorFlow 中的部分变量?

class NewModel(object):
    """Model that owns a document-embedding matrix plus weights and biases.

    All variables are created with ``tf.get_variable``, so their full names
    (and whether they are shared) depend on the ``tf.variable_scope`` that is
    active when the constructor runs.
    """

    def __init__(self, is_training, vocabulary_size, embedding_size):
        """Create the placeholders and variables of the model.

        Args:
            is_training: Flag supplied by the caller; it is not read in the
                part of the constructor shown here.
            vocabulary_size: Number of rows of the embedding matrix.
            embedding_size: Number of columns of the embedding matrix.
        """
        self.X = tf.placeholder("float", [None, 300])
        self.doc_int = tf.placeholder(tf.int32, shape=[None])

        self.embeddings = tf.get_variable(
            "embedding",
            [vocabulary_size, embedding_size],
            initializer=tf.random_uniform_initializer(-0.1, 0.1),
        )
        # Select the embedding rows addressed by the ids fed into doc_int.
        self.embedval = tf.nn.embedding_lookup(self.embeddings, self.doc_int)
        # NOTE(review): weight_shape and bias_shape are not defined in this
        # snippet; presumably they are defined elsewhere - confirm.
        self.weights = tf.get_variable(
            "weights", weight_shape, initializer=tf.random_normal_initializer()
        )
        biases = tf.get_variable(
            "biases", bias_shape, initializer=tf.constant_initializer(0.0)
        )
        # Some neural network with optimiser and loss that will train weight and embeddings..

with tf.Graph().as_default(), tf.Session() as sess:

    initializer = tf.random_uniform_initializer()
    # First construction: reuse=None, so the variables under "foo" are created.
    with tf.variable_scope("foo", reuse=None, initializer=initializer):
        train = NewModel(is_training=True, vocabulary_size=4000,
                         embedding_size=50)
    # Second construction: reuse=True asks get_variable to return the
    # variables that already exist under "foo" instead of creating new ones.
    with tf.variable_scope("foo", reuse=True, initializer=initializer):
        valid = NewModel(is_training=False, vocabulary_size=1000, embedding_size=50)
    # Here is where I am confused. I want to use trained variable of weight but not embeddings and
    # want new embeddings to be trained for valid set.
    tf.initialize_all_variables().run()
    # will call some function to run epochs and stuff

也许使用不同的作用域名称会有所帮助,但我仍然需要一些这方面的建议。或者,是否可以在某处指定哪些变量需要重用?

回答

0

我也许会重新组织NewModel类。

class NewModel(object):
    """Model whose embeddings and weights are created by separate methods.

    Keeping the two creation steps apart lets the caller handle the
    embedding matrix independently of the weights and biases.
    """

    def __init__(self, vocabulary_size, embedding_size, initializer):
        """Create the input placeholders and remember the configuration.

        Args:
            vocabulary_size: Number of rows of the embedding matrix.
            embedding_size: Number of columns of the embedding matrix.
            initializer: Initializer used for the embeddings and, by
                default, for the weights.
        """
        self.X = tf.placeholder("float", [None, 300])
        self.doc_int = tf.placeholder(tf.int32, shape=[None])
        self.vocabulary_size = vocabulary_size
        self.embedding_size = embedding_size
        self.initializer = initializer

    def initialize_embeddings(self):
        """Create the embedding variable and its lookup op in scope "embed"."""
        # NOTE(review): calling this a second time re-enters "embed" by name
        # without reuse, which is expected to make get_variable raise rather
        # than re-initialise the embeddings - confirm against the
        # variable_scope documentation before relying on it.
        with tf.variable_scope("embed", initializer=self.initializer) as scope:
            self.embeddings = tf.get_variable(
                "embedding",
                [self.vocabulary_size, self.embedding_size],
                initializer=self.initializer,
            )
            self.embedval = tf.nn.embedding_lookup(self.embeddings, self.doc_int)
            scope.reuse_variables()

    def initialize_weights(self, weight_shape, biase_shape, initializer=None):
        """Create the weight and bias variables in scope "weight".

        Args:
            weight_shape: Shape of the "weights" variable.
            biase_shape: Shape of the "biases" variable.
            initializer: Initializer for the weights; when None, the
                initializer given to the constructor is used.
        """
        if initializer is None:
            initializer = self.initializer
        with tf.variable_scope("weight", initializer=initializer) as scope:
            self.weights = tf.get_variable(
                "weights", weight_shape, initializer=initializer
            )
            biases = tf.get_variable(
                "biases", biase_shape, initializer=tf.constant_initializer(0.0)
            )
            scope.reuse_variables()

    def train_network(self):
        """Placeholder for the training procedure."""
        # Some neural network with optimiser and loss that will train weight and embeddings..
        pass

    def validate_network(self):
        """Placeholder for the validation procedure."""
        # A function for the validation process
        pass

这样您就可以把嵌入的初始化与权重和偏置的初始化分开。这个新类的用法大致如下:

with tf.Graph().as_default():
    with tf.Session() as sess:
        uniform_init = tf.random_uniform_initializer()

        # Build one model instance that is used for both phases.
        model = NewModel(vocabulary_size=4000, embedding_size=50, initializer=uniform_init)

        # Weights and biases are set up first, then the embeddings.
        model.initialize_weights(weight_shape, biase_shape)
        model.initialize_embeddings()

        # Training phase.
        model.train_network()

        # The embeddings are initialised again ahead of validation.
        model.initialize_embeddings()
        model.validate_network()