2017-03-07 109 views
1

我在IRIS数据集上使用了MXNet,它具有4个特征,它将花分为'setosa','versicolor','virginica'。我的训练数据有89行。我的标签数据是89列的行向量。我将花名编码为数字-0,1,2,因为它似乎mx.io.NDArrayIter不接受带有字符串值的numpy ndarray。然后我试图使用mod.predict给出了比预期更多的列

re = mod.predict(test_iter)

我得到它具有形状的结果来预测14 * 10 为什么我收到10列时,我只有3个标签,我怎么这些结果到我的标签映射。预测的结果如下所示:

[[0.11760861 0.12082944 0.1207106 0.09154381 0.09155304 0.09155869 0.09154817 0.09155204 0.09154914 0.09154641] [0.1176083 0.12082954 0.12071151 0.09154379 0.09155323 0.09155825 0.0915481 0.09155164 0.09154923 0.09154641] [0.11760829 0.1208293 0.12071083 0.09154385 0.09155313 0.09155875 0.09154838 0.09155186 0.09154932 0.09154625] [0.11760861 0.12082901 0.12071037 0.09154388 0.09155303 0.09155875 0.09154829 0.09155209 0.09154959 0.09154641] [0.11760896 0.12082863 0.12070955 0.09154405 0.09155299 0.09155875 0.09154839 0.09155225 0.09154996 0.09154646] [0.1176089 0.1208287 0.1207095 0.09154407 0.09155297 0.09155 882 0.09154844 0.09155232 0.09154989 0.0915464] [0.11760896 0.12082864 0.12070941 0.09154408 0.09155297 0.09155882 0.09154844 0.09155234 0.09154993 0.09154642] [0.1176088 0.12082874 0.12070983 0.09154399 0.09155302 0.09155872 0.09154837 0.09155215 0.09154984 0.09154641] [0.11760852 0.12082904 0.12071032 0.09154394 0.09155304 0.09155876 0.09154835 0.09155209 0.09154959 0.09154631] [0.11760963 0.12082832 0.12070873 0.09154428 0.09155257 0.09155893 0.09154856 0.09155177 0.09155051 0.09154671] [0.11760966 0.12082829 0.12070868 0.09154429 0.09155258 0.09155892 0.09154858 0.0915518 0.09155052 0.09154672] [0.11760949 0.1208282 0.12070852 0.09154446 0.09155259 0.09155893 0.09154854 0.09155205 0.0915506 0.09154666] [0.11760952 0.12082817 0.12070853 0.0915444 0.091552 61 0.09155891 0.09154853 0.09155206 0.09155057 0.09154668] [0.1176096 0.1208283 0.12070892 0.09154423 0.09155267 0.09155882 0.09154859 0.09155172 0.09155044 0.09154676]]

+0

你有一个最小可重现的例子吗? –

回答

1

使用 “Y = mod.predict(val_iter,num_batch = 1)”,而不是“Y = MOD。预测(val_iter)“,那么你只能得到一个批量标签。例如,如果batch_size是10,那么你将只能得到10个标签。

+0

我的数据集只有3个标签 - 'setosa','versicolor','virginica'。为什么预测方法给我10个标签 –

+0

我猜你的模型使用了错误的num_hidden参数。对于虹膜数据集,我的模型将如下所示: net = mx.sym.Variable('data') net = mx.sym.FullyConnected(net,name ='fc2',num_hidden = 3) net = mx.sym .SoftmaxOutput(net,name ='softmax',multi_output = True,) #num_hidden = 3表示我们输出三个标签。如果你设置num_hidden = 10,那么你将得到10个标签输出。 –

+0

感谢您指出。我正在使用num_hidden的错误值 –