我正在实施WGAN,需要剪辑权重变量。如何以clip_by_weight的张量形式访问Keras图层中的权重变量?
我目前使用Tensorflow与Keras作为高层次的API。因此用Keras构建图层以避免手动创建和初始化变量。
问题是WGAN需要剪下重量变量,这可以用tf.clip_by_value(x, v0, v1)
来完成,一旦我得到了这些重量变量张量,但我不知道如何安全地得到它们。
一个可能的解决方案可能使用tf.get_collection()
来获取所有可训练变量。但我不知道如何获得体重变量没有偏差变量。
另一种解决方案是layer.get_weights()
,但它得到numpy
阵列,虽然我可以用numpy
的API剪辑他们用layer.set_weights()
设置它们,但是这可能需要CPU-GPU的公司,并且可能不是一个很好的选择,因为夹操作需要在每个训练阶段执行。
我所知道的唯一的方法是用确切变量名,我可以从TF较低级别的API或TensorBoard获得访问它们,但是这可能不是因为命名Keras的规则安全不能保证稳定。
是否有任何干净的方法只能在Tensorflow和Keras的W
上执行clip_by_value
?
谢谢期待你的答复。我已经在y = ax + b的玩具问题上尝试过了,它完美无缺! – soar0x48
不客气。 – indraforyou
使用Tensorflow实施时存在问题。使用model.fit时,它可以很好地工作,但是当使用TF命令使用sess.run(tran_step,...)时,似乎约束不起作用。有关如何解决它的任何想法? – soar0x48