2017-09-01 72 views
5

我想用model.fit()在一个python应用程序中并行训练一些不同的模型。使用的模型没有必要的共同点,它们在不同的时间在一个应用程序中启动。张量多线程/ keras

首先,我在一个单独的线程中,然后在主线程中启动一个model.fit(),没有问题。如果我现在要开始第二model.fit(),我得到了以下错误消息:

start_learn(self:) 
    tf_session = K.get_session() # this creates a new session since one doesn't exist already. 
    tf_graph = tf.get_default_graph() 

    keras_learn_thread.Learn(learning_data, model, self.env_cont, tf_session, tf_graph) 
    learning_results.start() 

Exception in thread Thread-1: 
tensorflow.python.framework.errors_impl.InvalidArgumentError: Node 'hidden_1/BiasAdd': Unknown input node 'hidden_1/MatMul' 

他们都充分利用的方法通过相同的代码行启动钍叫做类/方法是这样的:

def run(self): 
    tf_session = self.tf_session # take that from __init__() 
    tf_graph = self.tf_graph # take that from __init__() 

    with tf_session.as_default(): 
     with tf_graph.as_default(): 
      self.learn(self.learning_data, self.model, self.env_cont) 
      # now my learn method where model.fit() is located is being started 

我想我无论如何都必须指定一个新tf_session并为每个单个线程新tf_graph。但我不太确定。我会很高兴每一个简短的想法,因为我现在坐在这上太久了。

感谢

回答

0

我不知道,如果你固定您的问题,但是这看起来像另一个问题I recently answered

  • 您需要在 开始之前完成主线程中的图形创建。
  • 在keras的情况下,图形在第一次调用拟合或预测函数时被初始化。你可以通过调用一些模型的内部函数强制图表生成:

    model._make_predict_function() 
    model._make_test_function() 
    model._make_train_function() 
    

    如果还是不行,请尝试调用虚拟数据热身模型。

  • 完成图形创建后,请在您的主图上调用finalize(),以便它可以安全地与不同线程共享(这将使其成为只读)。

  • 完成图形还可以帮助您找到无意修改图形的其他位置。

希望能帮到你。