2017-07-29 62 views
0

我在lua中有以下代码writtein。从torch删除项目。传感器

我想从scores及其相应的分数得到N个最高分数的索引。

它看起来像我将不得不迭代从scores删除当前的最大值,并再次检索最大值,但找不到一个合适的方式来做到这一点。

nqs=dataset['question']:size(1); 
scores=torch.Tensor(nqs,noutput); 
qids=torch.LongTensor(nqs); 
for i=1,nqs,batch_size do 
    xlua.progress(i, nqs) 
    r=math.min(i+batch_size-1,nqs); 
    scores[{{i,r},{}}],qids[{{i,r}}]=forward(i,r); 
-- print(scores) 
end 

tmp,pred=torch.max(scores,2); 

回答

1

我希望我没有误会,因为你展示的代码(尤其是福尔循环)并没有真正似乎相关要你想做的事情。无论如何,这是我该怎么做。

sr=scores:view(-1,scores:size(1)*scores:size(2)) 
val,id=sr:sort() 
--val is a row vector with the values stored in increasing order 
--id will be the corresponding index in sr 
--now you can slice val and id from the end to find the N values you want, then you can recover the original index in the scores matrix simply with 
col=(index-1)%scores:size(2)+1 
row=math.ceil(index/scores:size(2)) 

希望这有助于。

+0

您能否详细说明“从最后找到N个值的切片val和id”部分? – ytrewq

+0

我喜欢val {{{1},{val:size(2)-N + 1,val:size(2)}}]我喜欢''并且与'id'相同,因为'N'最大的元素是在排序张量的末尾。 – Ash

+0

请注意,这不会解决重复的问题(我的意思是如果'scores'包含,* eg *,例如,100的最大值的两倍),但我认为这不是问题,因为它不是问题在你的问题中提到。 – Ash