2017-08-14 116 views
0

我正在使用Keras和Tensorflow。 因为我想创建LSTM-CRF model,我定义使用tf.contrib.crf.crf_log_likelihood我自己的损失函数:如何将损失函数中的变量存储到实例变量中

def loss(self, y_true, y_pred): 
    sequence_lengths = ... # calc from y_true 
    log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(y_pred, y_true, sequence_lengths) 
    loss = tf.reduce_mean(-log_likelihood) 
    self.transition_params = transition_params 

    return loss 

如你所知,CRF需要在预测相变PARAMS。所以我将transition_params存储为实例变量,self.transition_params

问题是,self.transition_params在minibatch期间从未更新过。根据我的观察,编译模型时似乎只存储一次。

有没有什么方法可以将损失函数中的变量存储到Keras中的实例变量中?

回答

2

问题是错误的函数签名tf.contrib.crf.crf_log_likelihood,您需要通过transition_params与您当前的转换参数。以下更改将解决相同的问题。

log_likelihood, transition_params = 
    tf.contrib.crf.crf_log_likelihood(y_pred, y_true, sequence_lengths, 
    transition_params=self.transition_params)