2016-10-02 69 views
0

所以我有以下numpy数组。For loop评估精度不执行

  • X验证集,X_val:(47151,32,32,1)
  • Ý验证集(标签),y_val_dummy:(47151,5,10)
  • ý验证预测套组,y_pred: (47151,5,10)

当我运行代码时,它似乎需要永远。有人可以建议为什么?我相信这是一个代码效率问题。我似乎无法完成这个过程。

y_pred_list = model.predict(X_val) 
correct_preds = 0 
# Iterate over sample dimension 
for i in range(X_val.shape[0]):   
    pred_list_i = [y_pred_array[i] for y_pred in y_pred_array] 
    val_list_i = [y_val_dummy[i] for y_val in y_val_dummy] 
    matching_preds = [pred.argmax(-1) == val.argmax(-1) for pred, val in zip(pred_list_i, val_list_i)] 
    correct_preds = int(np.all(matching_preds)) 

total_acc = correct_preds/float(x_val.shape[0]) 
+0

不应该是'[y_pred [i] for y_pred in y_pred_array]'而不是类似的下一步? – Divakar

+0

@Divakar谢谢是的。哈哈。 – Ritchie

回答

0

你主要的问题是,你产生非常大的列表数量庞大的没有真正的理由

for i in range(X_val.shape[0]): 
    # this line generates a 47151 x 5 x 10 array every time   
    pred_list_i = [y_pred_array[i] for y_pred in y_pred_array] 

发生了什么事是迭代的第二numpy的数组迭代速度最慢(即最左边的),所以每个列表理解运行在47K条目上。

稍好将

for i in range(X_val.shape[0]):   
    pred_list_i = [y_pred for y_pred in y_pred_array[i]] 
    val_list_i = [y_val for y_val in y_val_dummy[i]] 
    matching_preds = [pred.argmax(-1) == val.argmax(-1) for pred, val in zip(pred_list_i, val_list_i)] 
    correct_preds = int(np.all(matching_preds)) 

但你仍然复制了很多阵列没有真正的目的。下面的代码应该这样做,没有无用的复制。

correct_preds = 0.0 
for pred, val in zip(y_pred_array, y_val_dummy): 
    correct_preds += all(p.argmax(-1) == v.argmax(-1) 
         for p, v in zip(pred, val)) 
total_accuracy = correct_preds/x_val.shape[0] 

这假设您的准确预测准确性是准确的。 您可以完全避免显式循环,只需拨打np.argmax即可,但您必须自行解决。