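About the ValueError exception: If `inputs` don't all have same shape and dtype or the shape cannot be inferred
I have a loss function defined as follows: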
def loss(hypes, decoded_logits, labels):
    """Calculate the loss from the logits and the labels.
    Args:
      logits: Logits tensor, float - [batch_size, NUM_CLASSES].
      labels: Labels tensor, int32 - [batch_size].
    Returns:
      loss: Loss tensor of type float.
    """
    logits = decoded_logits['logits']
    with tf.name_scope('loss'):
        logits = tf.reshape(logits, (-1, 2))
        shape = [logits.get_shape()[0], 2]
        epsilon = tf.constant(value=hypes['solver']['epsilon'])
        # logits = logits + epsilon
        labels = tf.to_float(tf.reshape(labels, (-1, 2)))
        softmax = tf.nn.softmax(logits) + epsilon
        if hypes['loss'] == 'xentropy':
            cross_entropy_mean = _compute_cross_entropy_mean(hypes, labels,
                                                             softmax)
        elif hypes['loss'] == 'softF1':
            cross_entropy_mean = _compute_f1(hypes, labels, softmax, epsilon)
        elif hypes['loss'] == 'softIU':
            cross_entropy_mean = _compute_soft_ui(hypes, labels, softmax,
                                                  epsilon)
        reg_loss_col = tf.GraphKeys.REGULARIZATION_LOSSES
        print('******'*10)
        print('loss type ',hypes['loss'])
        print('type ', type(tf.get_collection(reg_loss_col)))
        print("Regression loss collection: {}".format(tf.get_collection(reg_loss_col)))
        print('******'*10)
        weight_loss = tf.add_n(tf.get_collection(reg_loss_col))
        total_loss = cross_entropy_mean + weight_loss
        losses = {}
        losses['total_loss'] = total_loss
        losses['xentropy'] = cross_entropy_mean
        losses['weight_loss'] = weight_loss
    return losses
Running the program raises the following error:
  File "/home/ decoder/kitti_multiloss.py", line 86, in loss
    name='reg_loss')
  File "/devl /tensorflow/tf_0.12/lib/python3.4/site-packages/tensorflow/python/ops/math_ops.py", line 1827, in add_n
    raise ValueError("inputs must be a list of at least one Tensor with the "
ValueError: inputs must be a list of at least one Tensor with the same dtype and shape
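I looked at `tf.add_n`, whose implementation is shown below. My question is: how can I inspect the first argument passed to `tf.add_n`, that is `tf.get_collection(reg_loss_col)`, and print its contents to find out why this error is raised?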
def add_n(inputs, name=None):
  """Adds all input tensors element-wise.
  Args:
    inputs: A list of `Tensor` objects, each with same shape and type.
    name: A name for the operation (optional).
  Returns:
    A `Tensor` of same shape and type as the elements of `inputs`.
  Raises:
    ValueError: If `inputs` don't all have same shape and dtype or the shape
    cannot be inferred.
  """
  if not inputs or not isinstance(inputs, (list, tuple)):
    raise ValueError("inputs must be a list of at least one Tensor with the "
                     "same dtype and shape")
  inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
  if not all(isinstance(x, ops.Tensor) for x in inputs):
    raise ValueError("inputs must be a list of at least one Tensor with the "
                     "same dtype and shape")
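Hi Ali, thanks for your reply. Which function would let me see the type of `tmp = tf.get_collection(reg_loss_col)`? Also, the original program has `reg_loss_col = tf.GraphKeys.REGULARIZATION_LOSSES`; does that stand for the regularization loss? – user785099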
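One way to see what the collection holds is to print its container type, its length, and then each element's name, dtype, and shape. The helper below is only a sketch: `describe_collection` and `FakeTensor` are hypothetical names introduced here, and `FakeTensor` merely stands in for a real tensor so the example runs without TensorFlow.

```python
def describe_collection(items):
    """Summarize a collection list: container type and length, then one line per item."""
    lines = ["type={} len={}".format(type(items).__name__, len(items))]
    for item in items:
        # Read attributes defensively, since the list may hold non-tensor objects.
        name = getattr(item, "name", "<no name>")
        dtype = getattr(item, "dtype", "<no dtype>")
        get_shape = getattr(item, "get_shape", None)
        shape = get_shape() if callable(get_shape) else "<no shape>"
        lines.append("{}: dtype={} shape={}".format(name, dtype, shape))
    return lines


class FakeTensor(object):
    """Hypothetical stand-in exposing the attributes read above (not a TensorFlow class)."""

    def __init__(self, name, dtype, shape):
        self.name = name
        self.dtype = dtype
        self._shape = shape

    def get_shape(self):
        return self._shape


empty_report = describe_collection([])
full_report = describe_collection([FakeTensor("fc1/weight_loss:0", "float32", ())])

for line in empty_report + full_report:
    print(line)
```

In the question's code this would be called as `describe_collection(tf.get_collection(reg_loss_col))` just before `tf.add_n`. A first line of `type=list len=0` would mean the collection is empty.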
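Updated the answer to show how to check the type of an object. `tf.GraphKeys.REGULARIZATION_LOSSES` is a string, a name, and by calling `tf.get_collection()` you are requesting the graph nodes stored under that name. You need to define the loss in your graph. – Ali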
http://stackoverflow.com/questions/37107223/how-to-add-regularizations-in-tensorflow may help you understand what `tf.GraphKeys.REGULARIZATION_LOSSES` is. – Ali
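To connect these comments back to the error: if nothing in the graph registers a regularization loss, `tf.get_collection(reg_loss_col)` returns an empty list, and an empty list fails the first guard in the `add_n` excerpt quoted in the question. That is most likely what happens here, although the question does not show the printed collection. A minimal sketch of that guard in isolation, with no TensorFlow dependency; `check_add_n_inputs` and `sum_or_default` are hypothetical names introduced here, not TensorFlow functions:

```python
ERROR_MESSAGE = ("inputs must be a list of at least one Tensor with the "
                 "same dtype and shape")


def check_add_n_inputs(inputs):
    """Mirror of the first guard in the quoted add_n excerpt."""
    # An empty list is falsy, so `not inputs` is True and the guard fires.
    if not inputs or not isinstance(inputs, (list, tuple)):
        raise ValueError(ERROR_MESSAGE)


def sum_or_default(inputs, add_n, default):
    """Call add_n only when there is something to add; otherwise return default."""
    if not inputs:
        return default
    return add_n(inputs)


try:
    check_add_n_inputs([])  # what an empty collection looks like to add_n
    empty_outcome = "accepted"
except ValueError as err:
    empty_outcome = "ValueError: {}".format(err)

check_add_n_inputs([0.5, 0.25])  # a non-empty list passes this first guard

print(empty_outcome)
print(sum_or_default([], add_n=sum, default=0.0))
print(sum_or_default([0.5, 0.25], add_n=sum, default=0.0))
```

In the question's code the equivalent call would be `sum_or_default(tf.get_collection(reg_loss_col), tf.add_n, tf.constant(0.0))`. That only avoids the exception. The more fundamental fix is to register regularization losses in the graph, as the linked question describes.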