2017-10-04 58 views
2

我是一个有点困惑如何在keras使用fit_generator fit_generator。如何在keras

例如让说:

  • 我们有10000个数据点
  • 我们要为10个时代
  • 与512

批量运行使用fit我们只是

  • x, y = load_data() 
    model.fit(x=x, y=y, batch_size=512, epochs=10) 
    

    其中load_data加载所有数据。

    现在该怎么做同样的fit_generator

    它,我不清楚它是如何使用fit_generator时处理。如果我有以下发生器:

    def data_generator(): 
        for x, y in load_data_per_line(): 
         yield x, y 
    

    在每次它yields一个数据点上方的发电机。和:

    def data_generator_2(): 
        x_output = [] 
        y_output = [] 
        i = 0 
        for x, y in load_data_per_line(): 
         x_output[i] = x 
         y_output[i] = y 
         i = i + 1 
         if i == batch_size: 
          yield x_output, y_output 
          i = 0 
          x_output = [] 
          y_output = [] 
    

    在上述发电机每次它yields批量大小的数据点(512在这种情况下)。

    为了达到相同fit但使用fit_generator

    model.fit_generator(data_generator(), steps_per_epoch=10000/512, epochs=10) 
    

    model.fit_generator(data_generator_2(), steps_per_epoch=10000/512, epochs=10) 
    

    或者两者都是错误的(fit_generatordata_generator S)?如果其中任何一个是正确的,那么是否保证所有数据点都将被处理并且被顺序处理?

    任何了解是有用

  • 回答

    2

    发生器2几乎是确定的,但它应更好地返回numpy的数组:

    yield np.asarray(x_output),np.asarray(y_output) 
    

    此外,它应为无穷大:

    while True: 
    
        #the code inside to loop infinitely 
    

    第一个将不会返回批次,并会失败。

    你可能会在steps_per_epoch一个问题,因为10000是不是512的倍数,您需要整数步骤。您可以在发电机内检查if i == 10000:并通过一个较小的批次作为最后一批。

    那么你已经有了(10000 //512) + (10000 % 512)步骤或批次。

    所有批次将按顺序读取,但keras自动洗牌,这些批次的内容,请使用suffle=False。如果你使用多线程(不是这种情况),那么你需要创建线程安全的生成器或使用keras Sequence

    +0

    只是一个好奇心,所以在这种情况下,最后一批将不会有512的大小,这是好的吗? – titipata

    +1

    这很好,只要你不让你的发生器尝试读取超过允许的值。 –

    +0

    感谢您的详细解答。因为拟合生成器的目的是训练一个有大量数据的模型,为什么它假设我必须知道数据点的数量?如果因为任何原因我不知道数据点的确切数目会发生什么?如何设置这种情况下的步骤? –