2016-11-22 99 views
1

我使用exportertensorflow.contrib.session_bundle救了我的模型:出口Tensorflow型号不保留占位符形状

x = tf.placeholder(tf.float32, (None,) + (100, 200) + (1,)) 
.... 
saver = tf_saver.Saver(sharded=True) 
model_exporter = exporter.Exporter(saver) 
model_exporter.init(
    sess.graph.as_graph_def(), 
    named_graph_signatures={ 
     'inputs': exporter.generic_signature({'images': x}), 
     'outputs': exporter.generic_signature({'classes': y})}) 

,然后我加载模型回(session_bundletensorflow.contrib.session_bundle):

sess, meta_graph_def = session_bundle.load_session_bundle_from_path(input) 

但是,当我检查对应于输入x的占位符张量时,我看不到任何形状信息:

> sess.graph.get_tensor_by_name(input_name) 
<tf.Tensor 'Placeholder:0' shape=<unknown> dtype=float32> 

这是设计还是有造成形状丢失的缺陷?

回答

0

下面是来自同事的答案:

“的exporter.generic_signature呼叫(构建named_graph_signatures时)填充generic_signature的地图这里定义:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/manifest.proto#L69

在地图的值是一个TensorBinding,这本身就是张量名称,见https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/manifest.proto#L20

因此,预计形状不会保留,名称应该足以确定张量。

+0

对,但我正在从张贴名称,我[读出protobuf](https://gist.github.com/cancan101/31df34ca5dd971338cad8ca85bc1d8e2),然后调用'get_tensor_by_name'这确实解析名称为实际张量对象。 –

+0

嗨,亚历克斯,它似乎只要张力信息可以检索使用的名称,这应该工作。你能否详细说明具体问题可能是什么? – Neal

+0

我仍然不遵循为什么输入张量的形状保持不变的输入张量的形状。 –