0

我想修改以下keras均方误差损失(MSE),以便只计算稀疏损失。如何在Keras中实现稀疏均方误差损失

def mean_squared_error(y_true, y_pred): return K.mean(K.square(y_pred - y_true), axis=-1)

我的输出y是一个3通道图像,其中,所述第三信道是非零在只有那些损失要计算的像素。任何想法如何修改上述计算稀疏损失?

回答

2

这不是你正在寻找确切的损失,但我希望它会给你一个提示,写你的函数:

def masked_mse(mask_value): 
    def f(y_true, y_pred): 
     mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx()) 
     masked_squared_error = K.square(mask_true * (y_true - y_pred)) 
     masked_mse = (K.sum(masked_squared_error, axis=-1)/
         K.sum(mask_true, axis=-1)) 
     return masked_mse 
    f.__name__ = 'Masked MSE (mask_value={})'.format(mask_value) 
    return f 

的函数计算在预测输出的所有值的MSE损失,除了那些在真实输出中的相应值等于掩蔽值(例如-1)的元素。

有两点需要注意: - 计算平均值的分母必须是非屏蔽值的数量,而不是阵列的 尺寸时,这就是为什么我不使用K.mean(masked_squared_error, axis=1),我 而不是手动平均。 - 掩码值必须是有效的数字(即np.nannp.inf不会执行此作业),这意味着您必须调整数据以使其不包含mask_value

在此示例中,目标输出始终为[1, 1, 1, 1],但某些预测值会逐渐被屏蔽。

y_pred = K.constant([[ 1, 1, 1, 1], 
        [ 1, 1, 1, 3], 
        [ 1, 1, 1, 3], 
        [ 1, 1, 1, 3], 
        [ 1, 1, 1, 3], 
        [ 1, 1, 1, 3]]) 
y_true = K.constant([[ 1, 1, 1, 1], 
        [ 1, 1, 1, 1], 
        [-1, 1, 1, 1], 
        [-1,-1, 1, 1], 
        [-1,-1,-1, 1], 
        [-1,-1,-1,-1]]) 

true = K.eval(y_true) 
pred = K.eval(y_pred) 
loss = K.eval(masked_mse(-1)(y_true, y_pred)) 

for i in range(true.shape[0]): 
    print(true[i], pred[i], loss[i], sep='\t') 

预期输出是:

[ 1. 1. 1. 1.] [ 1. 1. 1. 1.] 0.0 
[ 1. 1. 1. 1.] [ 1. 1. 1. 3.] 1.0 
[-1. 1. 1. 1.] [ 1. 1. 1. 3.] 1.33333 
[-1. -1. 1. 1.] [ 1. 1. 1. 3.] 2.0 
[-1. -1. -1. 1.] [ 1. 1. 1. 3.] 4.0 
[-1. -1. -1. -1.] [ 1. 1. 1. 3.] nan