img should be PIL Image. Got lt;class #39;torch.Tensor#39;gt;(img 应该是 PIL Image.得到了 lt;class torch.Tensorgt;)
问题描述
我正在尝试遍历加载器以检查它是否正常工作,但是给出了以下错误:
I'm trying to iterate through a loader to check if it's working, however the below error is given:
TypeError: img 应该是 PIL Image.得到了
我已经尝试添加 transforms.ToTensor() 和 transforms.ToPILImage() 并且它给了我一个错误要求相反.即,使用 ToPILImage(),它将要求张量,反之亦然.
I've tried adding both transforms.ToTensor() and transforms.ToPILImage() and it gives me an error asking for the opposite. i.e, with ToPILImage(), it will ask for tensor, and vice versa.
# Imports here
%matplotlib inline
import matplotlib.pyplot as plt
from torch import nn, optim
import torch.nn.functional as F
import torch
from torchvision import transforms, datasets, models
import seaborn as sns
import pandas as pd
import numpy as np
data_dir = 'flowers'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'
#Creating transform for training set
train_transforms = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
#Creating transform for test set
test_transforms = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
#transforming for all data
train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
test_data = datasets.ImageFolder(test_dir, transform = test_transforms)
valid_data = datasets.ImageFolder(valid_dir, transform = test_transforms)
#Creating data loaders for test and training sets
trainloader = torch.utils.data.DataLoader(train_data, batch_size = 32,
shuffle = True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
images, labels = next(iter(trainloader))
它应该允许我在运行 plt.imshow(images[0]) 后简单地看到图像,如果它工作正常.
It should allow me to simply see the image once I run plt.imshow(images[0]), if its working correctly.
推荐答案
transforms.RandomHorizontalFlip() 适用于 PIL.Images,而不是 torch.Tensor代码>.在上面的代码中,您在 transforms.RandomHorizontalFlip() 之前应用 transforms.ToTensor(),这会产生张量.
transforms.RandomHorizontalFlip() works on PIL.Images, not torch.Tensor. In your code above, you are applying transforms.ToTensor() prior to transforms.RandomHorizontalFlip(), which results in tensor.
但是,根据官方 pytorch 文档这里、
But, as per the official pytorch documentation here,
transforms.RandomHorizontalFlip() 水平翻转给定的 PIL以给定的概率随机图像.
transforms.RandomHorizontalFlip() horizontally flip the given PIL Image randomly with a given probability.
因此,只需更改上面代码中的转换顺序,如下所示:
So, just change the order of your transformation in above code, like below:
train_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
这篇关于img 应该是 PIL Image.得到了 <class 'torch.Tensor'>的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
本文标题为:img 应该是 PIL Image.得到了 <class 'torch.Te
基础教程推荐
- 由Python将MP3转换为MIDI(类型错误:无法加载插件:mtg-Melodia:Melodia) 2022-01-01
- 将 x 轴刻度更改为自定义字符串 2022-01-01
- 尝试制作WhatsApp机器人 2022-01-01
- 用 Python 编写 Fortran 无格式文件 2022-01-01
- 在 Celery 工作人员中捕获 Heroku SIGTERM 以优雅地关 2022-01-01
- pyserial - 可以从线程 a 写入串行端口,是否阻塞从线程 b 读取? 2022-01-01
- 使用生成器和迭代器时 Python 多循环失败 2022-01-01
- Discord.py 缺少必需的参数 2022-01-01
- numpy float:比算术运算中内置的慢 10 倍? 2022-01-01
- 与常规 dict 相比,Python manager.dict() 非常慢 2022-01-01
