2016-08-13 103 views
21

我想部署一个简单的TensorFlow模型,并在像Flask这样的REST服务中运行它。 在github上或这里找不到目前为止的好例子。TensorFlow REST前端,但不是TensorFlow服务

我不准备使用TF担任其他职位的建议,这是谷歌完美的解决方案,但它矫枉过正我与GRPC,巴泽尔,C++编码,protobuf的任务......

+1

缩小问题只是想在保存器加载模型后返回结果的Flask示例 – chro

回答

5

有不同的方法来做到这一点。纯粹地,使用张量流并不是非常灵活,然而相对简单。这种方法的缺点是您必须重新生成图形并在恢复模型的代码中初始化变量。有一种方法显示在tensorflow skflow/contrib learn这是更优雅,但是这似乎并没有功能的时刻和文档已过时。

我在github here上放了一个简短的例子,它展示了如何将GET或POST参数命名为瓶式REST部署的tensorflow模型。

主要代码然后在需要立足岗位字典的功能/ GET数据:

@app.route('/model', methods=['GET', 'POST']) 
@parse_postget 
def apply_model(d): 
    tf.reset_default_graph() 
    with tf.Session() as session: 
     n = 1 
     x = tf.placeholder(tf.float32, [n], name='x') 
     y = tf.placeholder(tf.float32, [n], name='y') 
     m = tf.Variable([1.0], name='m') 
     b = tf.Variable([1.0], name='b') 
     y = tf.add(tf.mul(m, x), b) # fit y_i = m * x_i + b 
     y_act = tf.placeholder(tf.float32, [n], name='y_') 
     error = tf.sqrt((y - y_act) * (y - y_act)) 
     train_step = tf.train.AdamOptimizer(0.05).minimize(error) 

     feed_dict = {x: np.array([float(d['x_in'])]), y_act: np.array([float(d['y_star'])])} 
     saver = tf.train.Saver() 
     saver.restore(session, 'linear.chk') 
     y_i, _, _ = session.run([y, m, b], feed_dict) 
    return jsonify(output=float(y_i)) 
3

github project显示恢复模型检查点并使用Flask的工作示例。

@app.route('/api/mnist', methods=['POST']) 
def mnist(): 
    input = ((255 - np.array(request.json, dtype=np.uint8))/255.0).reshape(1, 784) 
    output1 = simple(input) 
    output2 = convolutional(input) 
    return jsonify(results=[output1, output2]) 

在线demo看起来很快。

3

我不喜欢把数据/模型处理太多的代码在烧瓶宁静的文件。我通常有tf模型课等。 它可能是这样的:

# model init, loading data 
cifar10_recognizer = Cifar10_Recognizer() 
cifar10_recognizer.load('data/c10_model.ckpt') 

@app.route('/tf/api/v1/SomePath', methods=['GET', 'POST']) 
def upload(): 
    X = [] 
    if request.method == 'POST': 
     if 'photo' in request.files: 
      # place for uploading process workaround, obtaining input for tf 
      X = generate_X_c10(f) 

     if len(X) != 0: 
      # designing desired result here 
      answer = np.squeeze(cifar10_recognizer.predict(X)) 
      top3 = (-answer).argsort()[:3] 
      res = ([cifar10_labels[i] for i in top3], [answer[i] for i in top3]) 

      # you can simply print this to console 
      # return 'Prediction answer: {}'.format(res) 

      # or generate some html with result 
      return fk.render_template('demos/c10_show_result.html', 
             name=file, 
             result=res) 

    if request.method == 'GET': 
     # in html I have simple form to upload img file 
     return fk.render_template('demos/c10_classifier.html') 

cifar10_recognizer.predict(X)是简单FUNC,在TF会话中运行预测操作:

def predict(self, image): 
     logits = self.sess.run(self.model, feed_dict={self.input: image}) 
     return logits 

P.S.从文件中保存/恢复模型是一个非常漫长的过程,尽量避免这种情况,同时服务发布/获取请求