2017-06-21 174 views
1

我正在训练tensorflow中的卷积模型。在训练了大约70个时期的模型之后,花了近1.5个小时,我无法保存模型。它给了我ValueError: GraphDef cannot be larger than 2GB。我发现随着训练的进行,图形中的节点数量不断增加。Tensorflow:随着训练的进行,图中的节点数量不断增加

在时代0,3,6,9处,图中节点的数量分别是7214,7238,7262,7286。当我使用with tf.Session() as sess:时,不是将会话作为sess = tf.Session()传递,而是分别在时期0,3,6,9处的节点数为3982,4006,4030,4054。

this答案,据说随着节点被添加到图中,它可以超过其最大尺寸。我需要帮助了解节点数量如何在我的图表中继续上升。

def runModel(data): 
    ''' 
    Defines cost, optimizer functions, and runs the graph 
    ''' 
    X, y,keep_prob = modelInputs((755, 567, 1),4) 
    logits = cnnModel(X,keep_prob) 
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y), name="cost") 
    optimizer = tf.train.AdamOptimizer(.0001).minimize(cost) 
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1), name="correct_pred") 
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy') 

    sess = tf.Session() 
    sess.run(tf.global_variables_initializer()) 
    saver = tf.train.Saver() 
    for e in range(12): 
     batch_x, batch_y = data.next_batch(30) 
     x = tf.reshape(batch_x, [30, 755, 567, 1]).eval(session=sess) 
     batch_y = tf.one_hot(batch_y,4).eval(session=sess) 
     sess.run(optimizer, feed_dict={X: x, y: batch_y,keep_prob:0.5}) 
     if e%3==0: 
      n = len([n.name for n in tf.get_default_graph().as_graph_def().node]) 
      print("No.of nodes: ",n,"\n") 
      current_cost = sess.run(cost, feed_dict={X: x, y: batch_y,keep_prob:1.0}) 
      acc = sess.run(accuracy, feed_dict={X: x, y: batch_y,keep_prob:1.0}) 
      print("At epoch {epoch:>3d}, cost is {a:>10.4f}, accuracy is {b:>8.5f}".format(epoch=e, a=current_cost, b=acc)) 

什么原因导致节点的数量增加:

我用下面的代码训练我的模型?

+0

也许你可以在每一步获得新节点的名称,并查看它们是哪个节点?也许这只是每次被复制的输入节点,我不知道......你使用的是什么版本的tf? – gdelab

+0

@gdelab我正在使用'1.0.1',每个时代的节点数似乎都增加了8! – dpk

+0

是的,但是你可以在每一步获得八个新的节点名称吗?也许他们可以帮助理解新节点的创建地点...... – gdelab

回答

2

您正在训练循环中创建新节点。特别是,您打电话tf.reshapetf.one_hot,其中每个创建一个(或多个)节点。您可以:

  • 使用占位符作为输入在图的外部创建这些节点,然后仅在循环中对它们进行评估。
  • 对这些操作不使用TensorFlow,而是使用NumPy或等效操作。

我会推荐第二个,因为在使用TensorFlow进行数据准备时似乎没有任何好处。你可以有这样的事情:

import numpy as np 
# ... 
    x = np.reshape(batch_x, [30, 755, 567, 1]) 
    # ... 
    # One way of doing one-hot encoding with NumPy 
    classes_arr = np.arange(4).reshape([1] * batch_y.ndims + [-1]) 
    batch_y = (np.expand_dims(batch_y, -1) == classes_arr).astype(batch_y.dtype) 
    # ... 

PD:我也建议在withcontext manager使用tf.Session(),以确保其close()方法在最后被调用(除非您想以后使用同一个会话保持)。

相关问题