2017-04-02 238 views
-1

有定义损失函数如下程序:关于ValueError异常:如果`inputs`并不都具有相同的形状和D型或形状

def loss(hypes, decoded_logits, labels): 
"""Calculate the loss from the logits and the labels. 

Args: 
    logits: Logits tensor, float - [batch_size, NUM_CLASSES]. 
    labels: Labels tensor, int32 - [batch_size]. 

Returns: 
    loss: Loss tensor of type float. 
""" 
logits = decoded_logits['logits'] 
with tf.name_scope('loss'): 
    logits = tf.reshape(logits, (-1, 2)) 
    shape = [logits.get_shape()[0], 2] 
    epsilon = tf.constant(value=hypes['solver']['epsilon']) 
    # logits = logits + epsilon 
    labels = tf.to_float(tf.reshape(labels, (-1, 2))) 

    softmax = tf.nn.softmax(logits) + epsilon 

    if hypes['loss'] == 'xentropy': 
     cross_entropy_mean = _compute_cross_entropy_mean(hypes, labels, 
                 softmax) 
    elif hypes['loss'] == 'softF1': 
     cross_entropy_mean = _compute_f1(hypes, labels, softmax, epsilon) 

    elif hypes['loss'] == 'softIU': 
     cross_entropy_mean = _compute_soft_ui(hypes, labels, softmax, 
               epsilon) 



    reg_loss_col = tf.GraphKeys.REGULARIZATION_LOSSES 

    print('******'*10) 
    print('loss type ',hypes['loss']) 
    print('type ', type(tf.get_collection(reg_loss_col))) 
    print("Regression loss collection: {}".format(tf.get_collection(reg_loss_col))) 
    print('******'*10) 


    weight_loss = tf.add_n(tf.get_collection(reg_loss_col)) 

    total_loss = cross_entropy_mean + weight_loss 

    losses = {} 
    losses['total_loss'] = total_loss 
    losses['xentropy'] = cross_entropy_mean 
    losses['weight_loss'] = weight_loss 

return losses 

运行程序引发了以下错误信息

File "/home/ decoder/kitti_multiloss.py", line 86, in loss 
    name='reg_loss') 
    File "/devl /tensorflow/tf_0.12/lib/python3.4/site-packages/tensorflow/python/ops/math_ops.py", line 1827, in add_n 
    raise ValueError("inputs must be a list of at least one Tensor with the " 
ValueError: inputs must be a list of at least one Tensor with the same dtype and shape 

我检查了tf.add_n的功能,其实现如下。我的问题是,如何检查tf.add_n中的第一个参数tf.get_collection(reg_loss_col),并打印其信息以找出错误消息生成的原因?

def add_n(inputs, name=None): 
    """Adds all input tensors element-wise. 
    Args: 
    inputs: A list of `Tensor` objects, each with same shape and type. 
    name: A name for the operation (optional). 
    Returns: 
    A `Tensor` of same shape and type as the elements of `inputs`. 
    Raises: 
    ValueError: If `inputs` don't all have same shape and dtype or the shape 
    cannot be inferred. 
    """ 
    if not inputs or not isinstance(inputs, (list, tuple)): 
    raise ValueError("inputs must be a list of at least one Tensor with the " 
        "same dtype and shape") 
    inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) 
    if not all(isinstance(x, ops.Tensor) for x in inputs): 
    raise ValueError("inputs must be a list of at least one Tensor with the " 
        "same dtype and shape") 

回答

0

为什么你甚至需要进入add_n看到什么tf.get_collection(reg_loss_col)是什么?你可以有tmp = tf.get_collection(reg_loss_col)然后看它的类型。顺便说一句,它看起来像你没有任何规则化的损失在你的图中,在这种情况下,tf.get_collection(reg_loss_col)将返回一个空的列表。

在Python中查看对象的类型您可以使用内置函数type。例如看到tmp的类型:print type(tmp)

+0

嗨阿里,谢谢你的回复。哪个函数可以使我看到tmp = tf.get_collection(reg_loss_col)的类型?此外,在原始程序中,它有reg_loss_col = tf.GraphKeys.REGULARIZATION_LOSSES它是否代表正则化损失? – user785099

+0

已更新答案以显示对象的锄头检查类型。 'tf.GraphKeys.REGULARIZATION_LOSSES'是一个字符串,一个名称,并且通过调用'tf.get_collection()'您正在请求一个具有该名称的图节点。您需要在图表中定义损失。 – Ali

+0

http://stackoverflow.com/questions/37107223/how-to-add-regularizations-in-tensorflow可以帮助你了解什么'tf.GraphKeys.REGULARIZATION_LOSSES'是。 – Ali

相关问题