2016-12-14 63 views

回答

0

假设您正在使用一个softmax分类器来选择N个类作为网络的最后一层。伪代码可能看起来像这样,其中最后一层的批量大小为其第一维:

# computation graph 
predictions = argmax(softmax(final_layer)) 
matches = predictions == argmax(labels) # if one-hot encoded 

# later 
batch_matches = sess.run(matches, feed_dict={...}) 

for image, does_match in zip(batch_images, batch_matches): 
    if not does_match: 
    cv2.imwrite('mismatched.png', image)