2016-12-27 350 views
1

我已经预训练了网络,并且我正在试图仅仅获取它的一部分(子图)tf图以及变量和保存对象。在tensorflow中提取子图

这就是我正在做它:

subgraph = tf.graph_util.extract_sub_graph(default_graph, list of nodes to preserve) 
tf.reset_default_graph() 
tf.import_graph_def(subgraph) 

然而,这将删除所有变量(当我打电话reset_default_graph)。即使如果我明确地将变量的操作节点(仅“变量”类型操作)添加到“要保留的节点列表”中。

如何在保留变量值的同时保留较大图的子图? 这是另外一些新的节点“保留列表”的问题吗?

图形节点和变量之间的关系仍然不清楚,教程仅仅提到创建变量会在图形中创建一些操作(节点)。

回答

0

我觉得你在做什么看起来不错。正如你所说的,一个变量只是一个简单的操作(图中的一个节点),用于输出特定值的张量。您应该能够将变量节点添加到列表中以保留它们,就像您已经在做的那样。你可以使用print(sess.graph_def)来确保你提供的名字是正确的吗?

+0

变量是一组连接的操作。通常,它以ops:“variable [variable]”,“assign [assign]”,“read [identity]”(第一部分是名称,方括号用于类型)以及整套初始化操作符为主。问题在于图导出会以一种不被认为是变量的方式来削减变量结构。选择所有必需的操作非常麻烦 - 而且不是最聪明的方法。 – Pietrko

+0

是的,的确如此。如果您查看extract_sub_graph的函数接口(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/graph_util_impl.py#L110),注意到它只是简单的函数,没有任何智能处理对于变量,“选择所有需要的操作”可能仍然是你最好的选择。好消息是我认为你可以编写一个简单的函数(使用graph_def作为输入)来自动执行这个繁琐的选择变量相关节点的过程。 –

+0

好吧,我希望我可以避免这种情况,也许存在一些干净而快速的方式来使用现有的API。谢谢。 – Pietrko