我使用的是张量流1.3.0后端的keras 2.0.8。Keras Tensorflow - 从多线程预测的异常
我在类init中加载模型,然后用它来预测多线程。
import tensorflow as tf
from keras import backend as K
from keras.models import load_model
class CNN:
def __init__(self, model_path):
self.cnn_model = load_model(model_path)
self.session = K.get_session()
self.graph = tf.get_default_graph()
def query_cnn(self, data):
X = self.preproccesing(data)
with self.session.as_default():
with self.graph.as_default():
return self.cnn_model.predict(X)
我初始化CNN一次和query_cnn方法从多个线程会发生。
的例外,我在我的日志得到的是:
File "/home/*/Similarity/CNN.py", line 43, in query_cnn
return self.cnn_model.predict(X)
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 913, in predict
return self.model.predict(x, batch_size=batch_size, verbose=verbose)
File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1713, in predict
verbose=verbose, steps=steps)
File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1269, in _predict_loop
batch_outs = f(ins_batch)
File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py", line 2273, in __call__
**self.session_kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 895, in run
run_metadata_ptr)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1124, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1321, in _do_run
options, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1340, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: PruneForTargets: Some target nodes not found: group_deps
的代码工作正常,大多数时候,它可能是一些问题与多线程。
我该如何解决?
我应该保存图形变量后最终确定吗? –
问题是图形创建[不是线程安全的](https://www.tensorflow.org/versions/r0.12/api_docs/python/framework/core_graph_data_structures),所以你需要完成图形创建一个线程,然后启动其他线程。 'finalize()'会使图形成为只读 –
它对你有用吗? –