2017-07-24 83 views
0

我已经使用张量流建立了多个DNN和conVNN,我现在可以达到很好的准确性。现在我的问题是如何在实际的例子中使用这个训练好的网络。 对于计算机视觉而言,我如何使用模型来分类新图片?我可以生成像convNN.exe那样的图像作为输入参数,通过分类结果出来吗?如何使用Tensorflow

回答

1

一旦你训练的模型,您应该保存它通过添加类似的代码

builder = saved_model_builder.SavedModelBuilder(export_path) 
builder.add_meta_graph_and_variables(
     sess, [tag_constants.SERVING], 
     signature_def_map={ 
      'predict_images': 
       prediction_signature, 
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 
       classification_signature, 
     }, 
     legacy_init_op=legacy_init_op) 
builder.save() 

然后某处,你可以使用Tensorflow serving通过运行

使用高性能的C++服务器,以满足您的模型
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server \ 
    --port=9000 --model_name=mnist \ 
    --model_base_path=/tmp/mnist_model/ 

修改您的模型的代码,当然。你需要实现一个客户端;有一个MNIST here的例子。客户端的胆量会是这样的:

def do_inference(hostport, work_dir, concurrency, num_tests): 
    """Tests PredictionService with concurrent requests. 
    Args: 
    hostport: Host:port address of the PredictionService. 
    work_dir: The full path of working directory for test data set. 
    concurrency: Maximum number of concurrent requests. 
    num_tests: Number of test images to use. 
    Returns: 
    The classification error rate. 
    Raises: 
    IOError: An error occurred processing test data set. 
    """ 
    test_data_set = mnist_input_data.read_data_sets(work_dir).test 
    host, port = hostport.split(':') 
    channel = implementations.insecure_channel(host, int(port)) 
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel) 
    result_counter = _ResultCounter(num_tests, concurrency) 
    for _ in range(num_tests): 
    request = predict_pb2.PredictRequest() 
    request.model_spec.name = 'mnist' 
    request.model_spec.signature_name = 'predict_images' 
    image, label = test_data_set.next_batch(1) 
    request.inputs['images'].CopyFrom(
     tf.contrib.util.make_tensor_proto(image[0], shape=[1, image[0].size])) 
    result_counter.throttle() 
    result_future = stub.Predict.future(request, 5.0) # 5 seconds 
    result_future.add_done_callback(
     _create_rpc_callback(label[0], result_counter)) 
    return result_counter.get_error_rate() 


def main(_): 
    if FLAGS.num_tests > 10000: 
    print('num_tests should not be greater than 10k') 
    return 
    if not FLAGS.server: 
    print('please specify server host:port') 
    return 
    error_rate = do_inference(FLAGS.server, FLAGS.work_dir, 
          FLAGS.concurrency, FLAGS.num_tests) 
    print('\nInference error rate: %s%%' % (error_rate * 100)) 

if __name__ == '__main__': 
    tf.app.run() 

这是在Python,当然,但我们没有理由,如果你想创建一个二进制可执行文件,你不能用另一种语言(如进入或C++) 。