2017-03-02 90 views

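I have built a CNN for image classification. During training I saved several checkpoints. The data is fed into the network through a feed_dict. TensorFlow complains about a missing feed_dict while the graph is being restored.

Now I want to restore the model, but this fails and I don't know why. The important lines of the code are as follows: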

with tf.Graph().as_default():

    ....

    if checkpoint_dir is not None:
        checkpoint_saver = tf.train.Saver()
        session_hooks.append(tf.train.CheckpointSaverHook(
            checkpoint_dir,
            save_secs=flags.save_interval_secs,
            saver=checkpoint_saver))
    ....

    with tf.train.MonitoredTrainingSession(
            save_summaries_steps=flags.save_summaries_steps,
            hooks=session_hooks,
            config=tf.ConfigProto(
                log_device_placement=flags.log_device_placement)) as mon_sess:

        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
        if checkpoint and checkpoint.model_checkpoint_path:

            # restoring from the checkpoint file
            checkpoint_saver.restore(mon_sess, checkpoint.model_checkpoint_path)

            global_step_restore = checkpoint.model_checkpoint_path.split('/')[-1].split('-')[-1]
            print("Model restored from checkpoint: global_step = %s" % global_step_restore)

The line "checkpoint_saver.restore" throws an error:

Traceback (most recent call last):
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1022, in _do_call
    return fn(*args)
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 1004, in _run_fn
    status, run_metadata)
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\contextlib.py", line 66, in __exit__
    next(self.gen)
  File "C:\Program Files\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 469, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_images' with dtype float
    [[Node: input_images = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

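Does anyone know how to solve this? Why do I need a populated feed_dict just to restore the graph?

Thanks in advance!

Update:

This is the code of the Saver object's restore method: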

def restore(self, sess, save_path): 
    """Restores previously saved variables. 

    This method runs the ops added by the constructor for restoring variables. 
    It requires a session in which the graph was launched. The variables to 
    restore do not have to have been initialized, as restoring is itself a way 
    to initialize variables. 

    The `save_path` argument is typically a value previously returned from a 
    `save()` call, or a call to `latest_checkpoint()`. 

    Args:
        sess: A `Session` to use to restore the parameters.
        save_path: Path where parameters were previously saved.
    """
    if self._is_empty:
        return
    sess.run(self.saver_def.restore_op_name,
             {self.saver_def.filename_tensor_name: save_path})

What I don't understand: why is the graph executed right away? Am I using the wrong method? I only want to restore all trainable variables.

Name all variables and placeholders. Does this help? http://stackoverflow.com/questions/34793978/tensorflow-complaining-about-placeholder-after-model-restore – hars

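All variables are named. What is missing is the feed for my image tensor input. I think the problem is caused by using MonitoredTrainingSession together with feed_dict. MonitoredTrainingSession is intended for larger setups and may not be compatible with feed dictionaries?! I am trying to build a test case for my custom "training framework", so I want to keep the example model lightweight (using feed_dict instead of input queues). – monchi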

Answer


The problem was caused by a SessionRunHook that I use for logging the training progress:

Original hook:

import time
from datetime import datetime

import tensorflow as tf


class _LoggerHook(tf.train.SessionRunHook):
    """Logs loss and runtime."""

    def begin(self):
        self._step = -1

    def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        # `loss` and `FLAGS` are defined elsewhere in the training script.
        # The loss is requested on every run() call of the session.
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

    def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 5 == 0:
            num_examples_per_step = FLAGS.batch_size
            examples_per_sec = num_examples_per_step / duration
            sec_per_batch = float(duration)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), self._step, loss_value,
                                examples_per_sec, sec_per_batch))

Modified hook:

import time
from datetime import datetime

import tensorflow as tf


class _LoggerHook(tf.train.SessionRunHook):
    """Logs loss and runtime."""

    def __init__(self, flags, loss_op):
        self._flags = flags
        self._loss_op = loss_op
        self._start_time = time.time()

    def begin(self):
        self._step = 0

    def before_run(self, run_context):
        # Request nothing on the first run() call, so that call does not
        # depend on the input placeholder.
        if self._step == 0:
            run_args = None
        else:
            run_args = tf.train.SessionRunArgs(self._loss_op)

        return run_args

    def after_run(self, run_context, run_values):

        if self._step > 0:
            duration_n_steps = time.time() - self._start_time
            loss_value = run_values.results
            if self._step % self._flags.log_every_n_steps == 0:
                num_examples_per_step = self._flags.batch_size

                # Average time per step since the last log line.
                duration = duration_n_steps / self._flags.log_every_n_steps
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), self._step, loss_value,
                                    examples_per_sec, sec_per_batch))

                self._start_time = time.time()
        self._step += 1

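Explanation:

Logging now skips the first iteration. A MonitoredTrainingSession calls before_run of every hook on each run() call and adds the requested fetches to that call, so the original hook attached the loss op to the session.run issued by Saver.restore(..), and evaluating the loss requires the input_images placeholder. With the first call skipped, that session.run no longer needs a populated feed dictionary.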
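To make the mechanism easier to see without TensorFlow, here is a minimal toy model in plain Python. It is only a sketch of the behaviour described above, not TensorFlow's implementation: ToySession, PlaceholderNotFed, AlwaysLogLoss, SkipFirstCall and try_restore are hypothetical names invented for this illustration, and the "graph" is reduced to the single assumption that fetching "loss" needs a value for "input_images".

```python
class PlaceholderNotFed(Exception):
    """Stands in for TensorFlow's InvalidArgumentError in this toy model."""


class ToySession:
    """Toy stand-in for a hooked session: every run() call consults the hooks."""

    def __init__(self, hooks):
        self._hooks = hooks

    def run(self, fetch, feed_dict=None):
        feed_dict = feed_dict or {}
        # Each hook may ask for an extra fetch to be added to this call.
        extra = [hook.before_run() for hook in self._hooks]
        fetches = [fetch] + [f for f in extra if f is not None]
        # Assumption of the toy graph: only "loss" depends on the placeholder.
        for f in fetches:
            if f == "loss" and "input_images" not in feed_dict:
                raise PlaceholderNotFed("input_images")
        for hook in self._hooks:
            hook.after_run()
        return fetch


class AlwaysLogLoss:
    """Mirrors the original hook: requests the loss on every call."""

    def before_run(self):
        return "loss"

    def after_run(self):
        pass


class SkipFirstCall:
    """Mirrors the modified hook: requests nothing on the first call."""

    def __init__(self):
        self._step = 0

    def before_run(self):
        return None if self._step == 0 else "loss"

    def after_run(self):
        self._step += 1


def try_restore(hook):
    """Simulates Saver.restore: one run() call that feeds only the filename."""
    session = ToySession([hook])
    try:
        session.run("restore_op", {"filename": "model.ckpt-100"})
        return "restored"
    except PlaceholderNotFed as err:
        return "missing feed for %s" % err


print(try_restore(AlwaysLogLoss()))  # missing feed for input_images
print(try_restore(SkipFirstCall()))  # restored
```

Note that in this sketch, as in the modified hook, the fix relies on the restore being the first run() call the hook sees. If something else ran through the session before the restore, the hook would request the loss again and the restore call would need the placeholder once more.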