1
我已经预训练了网络,并且我正在试图仅仅获取它的一部分(子图)tf图以及变量和保存对象。在tensorflow中提取子图
这就是我正在做它:
subgraph = tf.graph_util.extract_sub_graph(default_graph, list of nodes to preserve)
tf.reset_default_graph()
tf.import_graph_def(subgraph)
然而,这将删除所有变量(当我打电话reset_default_graph)。即使如果我明确地将变量的操作节点(仅“变量”类型操作)添加到“要保留的节点列表”中。
如何在保留变量值的同时保留较大图的子图? 这是另外一些新的节点“保留列表”的问题吗?
图形节点和变量之间的关系仍然不清楚,教程仅仅提到创建变量会在图形中创建一些操作(节点)。
变量是一组连接的操作。通常,它以ops:“variable [variable]”,“assign [assign]”,“read [identity]”(第一部分是名称,方括号用于类型)以及整套初始化操作符为主。问题在于图导出会以一种不被认为是变量的方式来削减变量结构。选择所有必需的操作非常麻烦 - 而且不是最聪明的方法。 – Pietrko
是的,的确如此。如果您查看extract_sub_graph的函数接口(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/graph_util_impl.py#L110),注意到它只是简单的函数,没有任何智能处理对于变量,“选择所有需要的操作”可能仍然是你最好的选择。好消息是我认为你可以编写一个简单的函数(使用graph_def作为输入)来自动执行这个繁琐的选择变量相关节点的过程。 –
好吧,我希望我可以避免这种情况,也许存在一些干净而快速的方式来使用现有的API。谢谢。 – Pietrko