2017-07-04 97 views
0

我有一个在Keras中训练过的模型,我使用它的中间层输出,使用K.function()。有没有办法将这个对象保存为张量流图?我想在张量流服务中使用这个对象,但我没有看到冻结Keras的方法K.function()对象冷冻Keras K.function()tensorflow图

+0

不知道我是否理解正确,但如果你想要的是冻结图层,你可以使用每个图层的可训练属性,如[这里]所述(https://github.com/fchollet/keras/issues/4471)和[这里](https://keras.io/getting-started/faq/) – Mathias

+0

这不是我要找的。我正在将它冻结成张量流图,以便我可以在张量流服务中使用它。不是keras模型或图层 –

+0

看看这个:https://stackoverflow.com/questions/43434292/benchmark-keras-model-using-tensforflow-benchmark – marcopah

回答

0

我最终发现它,认为我应该在这里发布它,以防将来任何人需要这样做。我必须从K.function()对象中获取节点名称。

encoder = K.function([...]) # define your Function object 
encoder_tensor_names = [t.name for t in encoder.outputs] 
encoder_node_names = [tn.replace(':0', '') for tn in encoder_tensor_names] # node names are tensor names without :0 
graph_def = tf.graph_util.convert_variables_to_constants(
    K.get_session(), 
    K.get_session().graph.as_graph_def(), 
    encoder_node_names 
) 

然后按照代码https://stackoverflow.com/a/44044405/5453184写出来的二进制文件(或文本)文件。它的优点在于tensorflow足够聪明,只保留前馈操作所需的图形部分,在我的情况下,它也冻结了字嵌入,所以我不再需要单独的单词矢量文件。