2017-07-17 76 views
1

我有一个keras模型建立如下(TF 1.2.1):Keras训练环路如何过滤损失值?

import tensorflow.contrib.keras as keras 

model = keras.models.Sequential() 

... 

model.compile(loss=keras.losses.mean_squared_error, 
       optimizer=keras.optimizers.Adam(lr=1e-4)) 

model.summary() 


Layer (type)     Output Shape    Param # 
================================================================= 
conv2d_1 (Conv2D)   (None, 29, 29, 64)  6336  
_________________________________________________________________ 
conv2d_2 (Conv2D)   (None, 13, 13, 128)  204928  
_________________________________________________________________ 
conv2d_3 (Conv2D)   (None, 11, 11, 256)  295168  
_________________________________________________________________ 
conv2d_4 (Conv2D)   (None, 5, 5, 256)   590080  
_________________________________________________________________ 
flatten_1 (Flatten)   (None, 6400)    0   
_________________________________________________________________ 
dense_1 (Dense)    (None, 2)     12802  
================================================================= 
Total params: 1,109,314 
Trainable params: 1,109,314 
Non-trainable params: 0 

的输出是一个简单的浮动向量和它收敛如期望。损失是均方误差。示例输出:

18/100 [====>.........................] - ETA: 30s - loss: 31.5118 
19/100 [====>.........................] - ETA: 29s - loss: 30.7577 
20/100 [=====>........................] - ETA: 29s - loss: 29.7815 
21/100 [=====>........................] - ETA: 28s - loss: 29.0535 
22/100 [=====>........................] - ETA: 28s - loss: 28.1963 
23/100 [=====>........................] - ETA: 28s - loss: 27.3314 
24/100 [======>.......................] - ETA: 28s - loss: 26.7219 
25/100 [======>.......................] - ETA: 28s - loss: 25.9702 
26/100 [======>.......................] - ETA: 27s - loss: 25.4181 
27/100 [=======>......................] - ETA: 27s - loss: 25.0638 
28/100 [=======>......................] - ETA: 27s - loss: 24.6081 
29/100 [=======>......................] - ETA: 26s - loss: 24.0928 

的损失似乎在逐渐下降。然而,当我看实际损失([email protected]_batch_end)它不是那么顺利:

25.473383 
28.051779 
20.519075 
13.204493 
20.74946 
21.246254 
25.611149 
13.194682 
13.268744 
15.408422 
17.183851 
11.232637 
14.493115 
10.196851 

我试图在Keras源代码挖,但不明白什么是引擎盖下发生。凯拉如何过滤实际损失?源代码中的哪些地方可以找到它?

谢谢!

回答

0

因此,progbar中实际显示的是打印时在给定历元中执行的所有批次的损失的平均值。 (平均2批后的前2个,3个时代后的前3个平均值等)。因此 - 您可以在n-th时代之后获取打印后的数值,并在第一个n损失值上取平均值。你可以在Progbar的定义中阅读here

+0

明白了。谢谢! – Wilco

相关问题