2016-12-08 57 views
0

我想创建一个热图,描绘具有不同粒度的卷积神经网络的分类准确性。在我对CNN的特殊实现中,我只使用奇数大小的过滤器,但绘制的热图将绘图元素放置在奇数和偶数位置。什么是忽略偶数位置的正确方法,并将这些元素仅绘制在奇数位置?matplotlib热图 - 绘制奇数指数

search_height = range(1, 5, 2) # [1,3] 
search_width = range(1, 5, 2) # [1,3] 

预分配一个数组来保存精度值:

我通过定义内核的宽度和高度,我很感兴趣的开始。我认为这可能是问题的一部分,因为它在偶数索引中存储值?

grid_accuracy = np.empty((len(search_height), len(search_width))) # 2x2 array 

然后我得到的网络不同的内核尺寸的精度,并将其存储在阵列中:

for i, h in enumerate(search_height): 
    for j, w in enumerate(search_width): 
     cur_test_acc = main(batch_size=200, num_epochs=100, k_height=h, k_width=w) 
     grid_accuracy[i, j] = cur_test_acc 

最后我绘制热图与存储的精度值:

plt.imshow(grid_accuracy, cmap = plt.cm.hot, interpolation='none') 
plt.grid(True) 
plt.xlabel('Kernel Width') 
plt.ylabel('Kernel Height') 
plt.xticks(search_width) 
plt.yticks(search_height) 

问题是,我最终得到一个如下所示的情节:

enter image description here

目前它似乎使用grid_accuracy的水平和垂直索引作为元素的位置。

我真正想要的只是一个2x2网格,其中每个单元格的值是相应内核宽度/高度的精度。轴蜱应该是我手动定义(最好与水平轴蜱情节以上)的高度和宽度:

enter image description here

回答

0

您需要设置刻度标签而不仅仅是自己的立场:

import numpy as np 
from matplotlib import pyplot as plt 

plt.imshow(np.random.randn(2, 2), interpolation='none') 
plt.xticks([0, 1], [1, 3]) 
plt.yticks([0, 1], [1, 3]) 

您可以使用plt.tick_params得到你要寻找的格式:

plt.tick_params(length=0, labelbottom=False, labeltop=True) 
+0

谢谢。实际上,我有更多的过滤器大小,而不仅仅是[1,3]。有没有办法自动设置标签,所以我不需要手动输入全部? – Simon

+0

另外,我如何将x标题移动到图上方? – Simon

+1

您可以将'range(1,5,2)'作为标签参数传递给'plt.ticks'(并且tick位置只是'range(grid.shape [1])'和'range(grid.shape [ 0])'分别为x和y轴)。我已经展示了如何使用'tick_params'将x刻度标签移动到顶部。 –