0
我想看看我在我的网络中使用的图片都是OK的,所以我使用下面的代码保存的一群人:torchvision MNIST加载程序无法正常工作,或者我做错了什么?
train_set = dset.MNIST(root=root, train=True, transform=transforms.ToTensor(), download=download)
for it, (img, target) in enumerate(train_loader):
X = Variable(img)
tar = Variable(target)
X = X.view(batch_size, -1)
cur_img_batch = X.data.numpy()
cur_tar_batch = tar.data.numpy()
for i in range(batch_size):
cur_img = cur_img_batch[i]
im = Image.fromarray(cur_img.reshape((28, 28)).astype('uint8') * 255)
if cur_tar_batch[i] == 8:
im.save(test_img_dir + 'iter_' + str(it) + '_sample_' + str(i) + '.png')
这不是最干净的代码,但它只是节省了一堆所有标记为“8”的图像。打开它们后,我发现其中大部分看起来像this,尽管它们中的一小部分完全是fine。
我做错了什么?
此行'cur_img.reshape((28,28))。astype('uint8')* 255'您是否将数据转换为整数后再乘以255? –
当然!这是它 - 非常感谢:) –
正确的行应该是:im = Image.fromarray((cur_img.reshape((28,28))* 255).astype('uint8')) –