2017-10-15 75 views
0

我想看看我在我的网络中使用的图片都是OK的,所以我使用下面的代码保存的一群人:torchvision MNIST加载程序无法正常工作,或者我做错了什么?

train_set = dset.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=download) 

for it, (img, target) in enumerate(train_loader): 
    X = Variable(img) 
    tar = Variable(target) 
    X = X.view(batch_size, -1) 
    cur_img_batch = X.data.numpy() 
    cur_tar_batch = tar.data.numpy() 
    for i in range(batch_size): 
     cur_img = cur_img_batch[i] 
     im = Image.fromarray(cur_img.reshape((28, 28)).astype('uint8') * 255) 
     if cur_tar_batch[i] == 8: 
      im.save(test_img_dir + 'iter_' + str(it) + '_sample_' + str(i) + '.png') 

这不是最干净的代码,但它只是节省了一堆所有标记为“8”的图像。打开它们后,我发现其中大部分看起来像this,尽管它们中的一小部分完全是fine

我做错了什么?

+0

此行'cur_img.reshape((28,28))。astype('uint8')* 255'您是否将数据转换为整数后再乘以255? –

+0

当然!这是它 - 非常感谢:) –

+0

正确的行应该是:im = Image.fromarray((cur_img.reshape((28,28))* 255).astype('uint8')) –

回答

0

从评论:

的问题是在此行cur_img.reshape((28, 28)).astype('uint8') * 255,由255相乘,从而导致图像与0或255

更新的代码之前的归一化图像转换为整数:

​​3210
相关问题