2017-03-01

I have the following code. How many parameters are associated with the BatchNormalization layer?

import keras

x = keras.layers.Input(batch_shape=(None, 4096))
hidden = keras.layers.Dense(512, activation='relu')(x)
hidden = keras.layers.BatchNormalization()(hidden)
hidden = keras.layers.Dropout(0.5)(hidden)
predictions = keras.layers.Dense(80, activation='sigmoid')(hidden)
# Keras 2+ names these keyword arguments inputs/outputs (Keras 1 used input/output)
mlp_model = keras.models.Model(inputs=[x], outputs=[predictions])
mlp_model.summary()

And this is the model summary:

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_3 (InputLayer)             (None, 4096)          0
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 512)           2097664     input_3[0][0]
____________________________________________________________________________________________________
batchnormalization_1 (BatchNorma (None, 512)           2048        dense_1[0][0]
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 512)           0           batchnormalization_1[0][0]
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 80)            41040       dropout_1[0][0]
====================================================================================================
Total params: 2,140,752
Trainable params: 2,139,728
Non-trainable params: 1,024
____________________________________________________________________________________________________
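
The Dense rows of the summary can be checked by hand, and subtracting them from the total isolates the BatchNormalization count. This is a small Python sketch of that arithmetic; the helper name dense_params is hypothetical, and it assumes both Dense layers use a bias (the Keras default).

```python
# Reproduce the Dense rows of the summary by hand.
# Assumption: each Dense layer has a bias (Keras default), so it holds
# inputs * units kernel weights plus one bias per unit.

def dense_params(inputs: int, units: int) -> int:
    """Kernel (inputs x units) plus bias (units)."""
    return inputs * units + units

dense_1 = dense_params(4096, 512)   # summary reports 2097664
dense_2 = dense_params(512, 80)     # summary reports 41040
total = 2_140_752                   # "Total params" from the summary

# Input and Dropout layers hold no weights, so the remainder
# belongs to the BatchNormalization layer.
bn = total - dense_1 - dense_2
print(dense_1, dense_2, bn)  # 2097664 41040 2048
```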

The input to the BatchNormalization (BN) layer has size 512. According to the Keras documentation, the output of the BN layer has the same shape as its input, i.e. 512.

So how many parameters are associated with the BN layer?

Answers

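Batch normalization in Keras implements this paper (Ioffe and Szegedy, 2015, "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift").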

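As you can read there, for batch normalization to work during training, the layer needs to keep track of the distribution of each normalized dimension. To do so, since you are in mode=0 (feature-wise normalization) by default, it computes 4 parameters for each feature of the previous layer's output. These parameters make sure that information is properly propagated forward and backpropagated.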

So 4 * 512 = 2048, which should answer your question.
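
To make the four per-feature vectors concrete, here is a minimal pure-Python sketch of feature-wise batch normalization. It is an illustration, not Keras's actual implementation: the class name TinyBatchNorm is hypothetical, and the momentum=0.99 / epsilon=1e-3 defaults are assumed to mirror Keras's. In a real layer gamma and beta are trained by backpropagation (not shown here), while the moving statistics are updated by an exponential moving average during training and replace the batch statistics at inference.

```python
import math

class TinyBatchNorm:
    """Hypothetical feature-wise batch norm: 4 vectors of length num_features."""

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3):
        # Scale and shift: trainable in a real layer (this sketch never trains them).
        self.gamma = [1.0] * num_features
        self.beta = [0.0] * num_features
        # Running statistics: updated by averaging, not by gradients (non-trainable).
        self.moving_mean = [0.0] * num_features
        self.moving_variance = [1.0] * num_features
        self.momentum = momentum
        self.epsilon = epsilon

    def param_count(self):
        vectors = (self.gamma, self.beta, self.moving_mean, self.moving_variance)
        return sum(len(v) for v in vectors)

    def __call__(self, batch, training):
        n = len(batch)
        out = [row[:] for row in batch]
        for j in range(len(self.gamma)):
            column = [row[j] for row in batch]
            if training:
                # Normalize with the statistics of the current batch...
                mean = sum(column) / n
                var = sum((x - mean) ** 2 for x in column) / n
                # ...and fold them into the running averages for inference.
                m = self.momentum
                self.moving_mean[j] = m * self.moving_mean[j] + (1 - m) * mean
                self.moving_variance[j] = m * self.moving_variance[j] + (1 - m) * var
            else:
                mean, var = self.moving_mean[j], self.moving_variance[j]
            scale = self.gamma[j] / math.sqrt(var + self.epsilon)
            for i in range(n):
                out[i][j] = (batch[i][j] - mean) * scale + self.beta[j]
        return out

bn = TinyBatchNorm(512)
print(bn.param_count())  # 2048 = 4 * 512

toy = TinyBatchNorm(1, epsilon=0.0)
y = toy([[1.0], [3.0]], training=True)
print(y)  # [[-1.0], [1.0]]
```

The toy batch has mean 2 and variance 1, so with gamma = 1 and beta = 0 the outputs are simply the centered values.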

These 2048 parameters are actually [gamma weights, beta weights, moving_mean (non-trainable), moving_variance (non-trainable)], each with 512 elements (the size of the BN layer's input).
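
Since neither the Dense layers nor Dropout have non-trainable weights, this breakdown also explains the "Non-trainable params: 1,024" line of the summary: it is exactly the moving mean and moving variance. A short Python sketch of the tally; the weight names are taken from the answer above, and the dictionary is just an illustration.

```python
# Per-feature weight vectors of the BN layer for a 512-unit input.
# Each entry: (number of elements, trainable?)
features = 512
bn_weights = {
    "gamma": (features, True),              # learned scale
    "beta": (features, True),               # learned shift
    "moving_mean": (features, False),       # running statistic
    "moving_variance": (features, False),   # running statistic
}

total = sum(size for size, _ in bn_weights.values())
trainable = sum(size for size, is_trainable in bn_weights.values() if is_trainable)
non_trainable = total - trainable
print(total, trainable, non_trainable)  # 2048 1024 1024
```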