2017-04-06 70 views
1
model_dir = "no_regulation" 
print(model_dir) 
m = tf.contrib.learn.LinearClassifier(
    feature_columns=feature_columns, 
    optimizer=tf.train.FtrlOptimizer(
     learning_rate=3, 
     l1_regularization_strength=0, 
     l2_regularization_strength=0), 
    n_classes = n_classes, 
    model_dir=model_dir) 

def train_input_fn(): 
    print("Here!") 
    return input_fn(train.sample(50000), label_column = "course_index", categorical_columns = CATEGORICAL_COLUMNS) 

,如果我这样做时,它批处理50000个样品每10步,如何在tf.learn中使用input_fn进行批量训练?

for i in range(40): 
    for j in range(20): 
     m.fit(input_fn=train_input_fn, steps = 10) 
    m.evaluate(input_fn=eval_input_fn1, steps = 1, name="test1") 
    m.evaluate(input_fn=eval_input_fn2, steps = 1, name="test2") 

,这是否合理?如果我做m.fit(input_fn = train_input_fn,步= 1),每一个适合的通话将创建检查点,并会减慢培养了不少。我应该禁用检查点?如果是这样,怎么样?

回答

1

我发现的一种方法是使用m.partial_fit而不是fitpartial_fit不会触发CheckpointSaverHook

看起来虽然这evaluate呢。

相关问题