2017-04-05 96 views
2

我正在尝试在keras中使用theano后端实现梯度范数的正则化术语improved WGAN training。基本上我想基于它是有多远从1由于自定义丢失函数,Keras抛出DisconnectedInputError

我实现这样一个自定义的损失,惩罚梯度的L2范数:

def get_gradient_norm(model, y_pred): 
    weights = model.trainable_weights 
    gradients = model.optimizer.get_gradients(K.mean(y_pred), weights) 
    acc = None 
    for g in gradients: 
     s = K.sum(K.square(g)) 
     if acc == None: 
      acc = s 
     else: 
      acc = s + acc 
    return K.sqrt(acc) 

def make_w_reg_loss(model): 
    lvar = K.variable(lamb, name="Lambda") 

    def foo(y_true, y_pred): 
     gnorm = get_gradient_norm(model, y_pred) 
     return lvar * K.square(gnorm - 1) 

return foo 

[...] 

critic.compile(loss=make_w_reg_loss(critic), optimizer=RMSprop(learn_rate)) 

它抛出一个DisconnectedInputError一次训练过程中尝试尝试获取我自定义丢失函数的渐变。

为什么?

用一些标准损失工作替换损失。这个错误是关于我定义的损失函数的。

看到这个要点我尝试a minimal not-working example

编辑:

所以我想我知道如何使它现在的工作。 首先,我只是随机添加这个词来我的损失,直接从富返回之前(y_true,y_pred):

K.mean(y_pred) - K.mean(y_pred) 

显然是一个常量零,如果我只能用这个词作为我的损失怎么办得到零。 但是,如果我将这个“常量零”添加到我的正则化损失中,它突然正常工作。我从正规化中获得了非零的损失,并且许多train_on_batch的优化确实也减少了损失。

所以这是一个奇怪的问题,theano有点过分投掷异常?我的问题仍然存在:为什么它会抛出原始代码。由于添加一个固定的零项修复它,它对我来说看起来像一个错误?

回答

0

我真的很想在keras中实现这个改进的wgan,我很惊讶地看到你是如何解决你的“问题”的。您是否验证过您的wgan-gp损失是否按预期运行的实验性实验? 它应该很容易检查,它是一个非常稳定的训练,使您可以使用非常深的鉴别器;) 我想做你做的同样的工作,但是与tensorflow后端,我会尝试看看你的代码和代码在这里:keras improved wgan

我会很高兴听到您的更新,我会再次写在这里,只要我有一个wgan-gp工作代码在keras/tensorflow! P.S.上面的链接正在执行张量流代码中的所有过程,迫使使用tf训练功能。我真的很喜欢你的方法,在那里我们可以简单地定义一个keras损失,使用我们所有通常的keras高级API进行训练;)

编辑:从你的代码看,你完全可以用K后端,所以你的代码应该轻松运行tensorflow后端。您是否尝试更改后端以检查问题/错误是否与Theano真正相关?

第二编辑:您正在计算梯度w.r.t的权重,但在wgan-gp纸张中,梯度损失是从梯度w.r.t开始计算生成样本和实际样本之间的平均样本。这会带来非常不同的结果。 在下面的链接,你可以找到一个非常好的改进wgan损失的实施,对theano是可能工作过: https://github.com/farizrahman4u/keras-contrib/

+0

我张贴的代码是一个残酷的削减版本,肯定不能正确实现什么,这是只是为了展示这个问题。 我的真实代码通过在真实样本和假样本之间传递插值数据点来实现采样。目前为止我只测试了玩具的例子,但他们看起来很有希望。但是,更多的“真实”工作让我望而却步,所以我无法测试出更复杂的数据集。 –

+0

我没有测试tensorflow,没有在这里安装它,因为最终损失函数包含更多的术语,异常问题不是真正的问题。它只是让我困惑。 我猜你发布的wgan实现可能是由具有更多keras经验的人编写的,并且有更好的文档记录。当我回过头来看时,我可能会用到那个,因为它似乎是在GPU上实现插值部分,我是用CPU来完成的。凉! –

+0

我最终在调试中失去了12个小时,试图修改我链接的代码,以便作为单独的梯度损失损失工作(而不是集成到鉴别器中),并且我很快卡在墙上“tensorflow获得无损”类型的错误。我突然记起你的修复,而且,我也在修复。没有你的修复,如果我通过model.summary()可视化模型,没有输入层。用你简单的修复,突然输入层显示为输入(并且梯度损失损失不起作用) –

相关问题