
I am using the tf.estimator API to predict punctuation. I trained the model on preprocessed data stored in TFRecords, using tf.train.shuffle_batch. Now I want to make predictions. I can feed static NumPy data into a tf.constant and return it from the input_fn. How can I run asynchronous predictions with the TensorFlow Estimator API?

But I am working with sequence data, and I need to feed one example at a time: each input depends on the previous output. I also want to be able to accept input via HTTP requests.

Every call to estimator.predict reloads the checkpoint and recreates the entire graph, which is slow and expensive, so I need a way to feed data to the input_fn dynamically.
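For illustration, this is roughly what the workable-but-slow baseline looks like (a minimal sketch built on the tf.constant approach mentioned above):

output = None
while True:
    x = get_numpy_data(output)  # next input depends on the previous output
    if x is None:
        break
    # Each predict() call rebuilds the whole graph and restores the
    # checkpoint from disk before producing a single prediction.
    output = next(estimator.predict(input_fn=lambda: tf.constant(x)))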

我当前的尝试大致是这样的:

feature_input = tf.placeholder(tf.int32, shape=[1, MAX_SUBSEQUENCE_LEN])
q = tf.FIFOQueue(1, tf.int32, shapes=[[1, MAX_SUBSEQUENCE_LEN]])
enqueue_op = q.enqueue(feature_input)

def input_fn():
    return q.dequeue()

estimator = tf.estimator.Estimator(model_fn, model_dir=model_file)
predictor = estimator.predict(input_fn=input_fn)
sess = tf.Session()
output = None

while True:
    x = get_numpy_data(output)  # next input depends on the previous output
    if x is None:
        break
    sess.run(enqueue_op, {feature_input: x})
    output = next(predictor)
    save_to_file(output)

sess.close()

But when I feed data through the input_fn this way, I get the following error:

ValueError: Input graph and Layer graph are not the same: Tensor("EmbedSequence/embedding_lookup:0", shape=(1, 200, 128), dtype=float32) is not from the passed-in graph.

How can I asynchronously push data into my existing graph and get predictions back one at a time?

Answer


It turns out that the main problem is that all tensors need to be created inside the input_fn, otherwise they are not added to the same graph. I needed to run an enqueue op, but there was no way to access anything returned from the input function.
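A minimal sketch of what goes wrong (my own illustration, not from the original post): Estimator.predict builds a brand-new tf.Graph and calls the input_fn inside it, so any tensor created beforehand lives in a different graph:

import tensorflow as tf

outside = tf.constant([1])  # created in the default graph

with tf.Graph().as_default() as g:  # roughly what Estimator.predict does internally
    inside = tf.constant([2])
    assert inside.graph is g
    assert outside.graph is not g  # mixing the two triggers the "not from the passed-in graph" error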

I ended up subclassing the Estimator class and writing a custom prediction function that lets me dynamically add data to a prediction queue and return the results:

# async_estimator.py 

import six 
import tensorflow as tf 
from tensorflow.python.estimator.estimator import Estimator 
from tensorflow.python.estimator.estimator import _check_hooks_type 
from tensorflow.python.estimator import model_fn as model_fn_lib 
from tensorflow.python.framework import ops 
from tensorflow.python.framework import random_seed 
from tensorflow.python.training import saver 
from tensorflow.python.training import training 


class AsyncEstimator(Estimator):

    def async_predictor(self,
                        dtype,
                        shape=None,
                        predict_keys=None,
                        hooks=None,
                        checkpoint_path=None):
        """Returns a tuple of functions: the first runs predictions on the model, the second cleans up.

        Args:
          dtype: the dtype of the input
          shape: the shape of the input placeholder (optional)
          predict_keys: list of `str`, name of the keys to predict. It is used if
            the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
            then the rest of the predictions will be filtered from the dictionary.
            If `None`, returns all.
          hooks: List of `SessionRunHook` subclass instances. Used for callbacks
            inside the prediction call.
          checkpoint_path: Path of a specific checkpoint to predict. If `None`,
            the latest checkpoint in `model_dir` is used.

        Returns:
          (predict, finish): tuple of functions

          predict: runs a single prediction and returns the results
            Args:
              x: NumPy array of input
            Returns:
              Evaluated value of the prediction

          finish: closes the session, allowing the program to exit

        Raises:
          ValueError: Could not find a trained model in model_dir.
          ValueError: If the batch length of predictions is not the same.
          ValueError: If there is a conflict between `predict_keys` and
            `predictions`. For example if `predict_keys` is not `None` but
            `EstimatorSpec.predictions` is not a `dict`.
        """
        hooks = _check_hooks_type(hooks)
        # Check that the model has been trained.
        if not checkpoint_path:
            checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError('Could not find trained model in model_dir: {}.'.format(
                self._model_dir))

        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            training.create_global_step(g)
            # Build the input placeholder, the queue ops, and the model inside
            # the same graph, so there is no graph mismatch at prediction time.
            input_placeholder = tf.placeholder(dtype=dtype, shape=shape)
            queue = tf.FIFOQueue(1, dtype, shapes=shape)
            enqueue_op = queue.enqueue(input_placeholder)
            features = queue.dequeue()
            estimator_spec = self._call_model_fn(features, None,
                                                 model_fn_lib.ModeKeys.PREDICT)
            predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
            # The MonitoredSession restores the checkpoint once and stays open
            # across predict() calls.
            mon_sess = training.MonitoredSession(
                session_creator=training.ChiefSessionCreator(
                    checkpoint_filename_with_path=checkpoint_path,
                    scaffold=estimator_spec.scaffold,
                    config=self._session_config),
                hooks=hooks)

            def predict(x):
                if mon_sess.should_stop():
                    raise StopIteration
                mon_sess.run(enqueue_op, {input_placeholder: x})
                preds_evaluated = mon_sess.run(predictions)
                if not isinstance(predictions, dict):
                    return preds_evaluated
                else:
                    preds = []
                    for i in range(self._extract_batch_length(preds_evaluated)):
                        preds.append({
                            key: value[i]
                            for key, value in six.iteritems(preds_evaluated)
                        })
                    return preds

            def finish():
                mon_sess.close()

            return predict, finish
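A note on how this works: the FIFOQueue has capacity 1, so each predict call enqueues exactly one example and the following mon_sess.run(predictions) dequeues it through the model. Because the placeholder, the queue, and the model are all created inside the same ops.Graph() block, the graph-mismatch ValueError from the question cannot occur, and the MonitoredSession keeps the restored weights in memory between calls.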

And here is rough code showing how to use it:

import tensorflow as tf
from async_estimator import AsyncEstimator


def doPrediction(model_fn, model_dir, max_seq_length):
    estimator = AsyncEstimator(model_fn, model_dir=model_dir)
    predict, finish = estimator.async_predictor(dtype=tf.int32, shape=(1, max_seq_length))
    output = None

    while True:
        # my input is dependent on the previous output
        x = get_numpy_data(output)
        if x is None:
            break
        output = predict(x)
        save_to_disk(output)

    finish()
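Since the question also asks about feeding input over HTTP: predict is just a Python closure holding an open session, so it can be called from a request handler. Here is a minimal sketch (my own illustration, assuming Python 3, a JSON request body shaped [1, max_seq_length], and predictions returned as a NumPy array):

import json
import numpy as np
from http.server import BaseHTTPRequestHandler, HTTPServer

# Assumes `predict` is the closure returned by async_predictor above.

class PredictHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        length = int(self.headers['Content-Length'])
        x = np.array(json.loads(self.rfile.read(length)), dtype=np.int32)
        output = predict(x)  # the session stays open between requests
        self.send_response(200)
        self.end_headers()
        self.wfile.write(json.dumps(np.asarray(output).tolist()).encode())

# The default HTTPServer is single-threaded, so requests (and therefore
# enqueue/dequeue pairs) are naturally serialized.
HTTPServer(('', 8000), PredictHandler).serve_forever()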

Note: this is a simple solution that fits my needs; it may need modification for other cases. It is working on TensorFlow 1.2.1.

Hopefully TF will officially adopt something like this, to make dynamic prediction with Estimator easier.