2017-05-31 350 views
2

因此,我在创建的多标签数据集(约20000个样本)上训练了一个深度神经网络。我切换SOFTMAX乙状结肠和尝试(使用亚当优化器),以尽量减少:用于多标签分类的不平衡数据集

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred) 

我结束了与这位国王预测(漂亮“恒”):

Prediction for Im1 : [ 0.59275776 0.08751075 0.37567005 0.1636796 0.42361438 0.08701646 0.38991812 0.54468459 0.34593087 0.82790571] 

Prediction for Im2 : [ 0.52609032 0.07885984 0.45780018 0.04995904 0.32828355 0.07349177 0.35400775 0.36479294 0.30002621 0.84438241] 

Prediction for Im3 : [ 0.58714485 0.03258472 0.3349618 0.03199361 0.54665488 0.02271551 0.43719986 0.54638696 0.20344526 0.88144571] 

起初,我还以为我只需要为每个班级找到一个门槛值。

但我注意到,例如,在我的20000个样本中,第一类出现约10800,所以是0.54的比率,它是我每次预测的值。所以我认为我需要找到解决tuis“不平衡数据集”问题的方法。

我想减少我的数据集(Undersampling)每个班级的发生次数相同,但只有26个样本对应我的一个班级...这会让我失去很多样本...

我读过关于过度取样或关于更多的惩罚更少的类是罕见的,但没有真正理解它是如何工作的。

有人可以分享一些关于这些方法的解释吗?

在实践中,在Tensorflow上,是否有函数可以帮助实现这个功能?

其他建议?

谢谢:)

PS:Neural Network for Imbalanced Multi-Class Multi-Label Classification这篇文章提出了同样的问题,但没有答案!

+0

为什么不使用您拥有的所有样本,并使用该不平衡数据来使用异常检测算法? – Gabriel

+0

如果我理解的很好,你的建议是在我的(9)班(在我的数据集中“很好”代表)上训练我的网络,然后在我的“代表性很差”的班级上训练另一个网络(就像在这个二进制分类上做的那样类)? –

+1

不,我建议使用算法来检测非常小的数字,这对于绝大多数数据来说是不同的。他们通常被称为异常检测算法,因为通常当您尝试检测异常时,您有很多“好”样本但很少“异常”样本。然而,这些算法通常用于在两个类别之间进行分类。所以也许这对你不好,但可能是更复杂的分类过程的一部分 – Gabriel

回答

1

那么,在一个班级中有10000个样本,而在一个难得的班级中只有26个样本,这确实是一个问题。

但是,对我而言,您所看到的更像是“输出甚至不会看到输入”,因此网络只会学习您的输出分布。

要调试这个,我会创建一个缩减集(仅用于此调试目的),比如说每个类有26个样本,然后尝试严重过度配合。如果你得到正确的预测,我的想法是错误的。但是,如果网络甚至无法检测到那些欠采样的过载样本,那么确实这是一个架构/实现问题,而不是由于图表分布(然后您需要修复,但它不会像当前结果那么糟糕)。

+0

我认为这可能是来自我的网络的一个问题,但对于单标签分类(如MNIST)(当我使用Softmax时)它可以正常工作。但是我仍然会尝试为每班的26个样本进行过度训练!谢谢您的回答 ! –

+1

好吧,非常明确,你是对的..不幸的是我!但是,正如我之前所说的,完全相同的架构用于在MNIST数据集和我创建的数据集(多类单标签)上学习和执行得非常好! 唯一改变的是我用Sigmoid取代了Softmax。 –

1

你的问题不是班级不平衡,而只是缺少数据。对于几乎任何真正的机器学习任务,26个样本被认为是非常小的数据集。通过确保每个小班将至少有一个来自每个班级的样本(这导致了一些样本将比另一个更频繁地使用,但是谁在乎)的情况,可以容易地处理班级不平衡。

但是,如果仅存在26个样本,此方法(以及其他方法)将很快导致过拟合。这个问题可以通过某种形式的数据增强得到部分解决,但是仍然有太少的样本来构建合理的东西。

所以,我的建议是收集更多的数据。

+0

26不是我的数据集的大小,而是整个数据集中的一个类的出现次数(即20000个样本)。感谢你们“确保每个小班将至少有一个班级的样本”。这跟Oersampling是一样的吗? :) –