使用 torchvision.datasets 加载数据:其中内置了很多常用数据集(例如下面用到的 MNIST)可供选择。
import torch
import torchvision
from torchvision import transforms, models
batch_size = 32

# Preprocessing applied to every MNIST sample: convert to a tensor, then
# normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    # One-element tuples (note the trailing comma): one statistic per image
    # channel. A bare (0.5) is just a parenthesized float.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Training split, downloaded into ./mn if it is not already there.
train_data = torchvision.datasets.MNIST(
    './mn', train=True, download=True, transform=transform)
data_loader_train = torch.utils.data.DataLoader(
    dataset=train_data, batch_size=batch_size, shuffle=True)

# Test split, stored in the same directory.
# NOTE(review): the test loader is shuffled as well; shuffle=False would
# make evaluation order deterministic -- confirm which one is intended.
test_data = torchvision.datasets.MNIST(
    './mn', train=False, download=True, transform=transform)
data_loader_test = torch.utils.data.DataLoader(
    dataset=test_data, batch_size=batch_size, shuffle=True)

next(iter(data_loader_train))  # fetch one batch to take a look at the data

# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_img(image_path, transform=None, max_size=None, shape=None):
    """Open an image, optionally resize it, and optionally convert it.

    Relies on the module-level names ``Image`` and ``device``.

    Args:
        image_path: Location of the image, handed to ``Image.open``.
        transform: Optional callable applied to the opened image. Its
            result must support ``unsqueeze(0)`` and ``.to(device)``.
        max_size: If truthy, both sides are scaled by the same factor so
            that the longer side becomes ``max_size``.
        shape: If truthy, the image is resized to exactly this size. This
            happens after the ``max_size`` rescale.

    Returns:
        When ``transform`` is given: the transformed image with a new
        leading dimension of size 1, moved to ``device``. Otherwise the
        (possibly resized) image object itself.
    """
    image = Image.open(image_path)
    if max_size:
        scale = max_size / max(image.size)
        # resize() expects a tuple of ints; clamp to 1 so an extreme
        # aspect ratio cannot produce a zero-length side.
        new_size = tuple(max(1, int(side * scale)) for side in image.size)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
        # filter under its current name.
        image = image.resize(new_size, Image.LANCZOS)
    if shape:
        image = image.resize(shape)
    if transform:
        # Only the transformed result has .to(); an untransformed image
        # is returned unchanged instead of raising AttributeError.
        image = transform(image).unsqueeze(0).to(device)
    return image
# Preprocessing for the content/style images: tensor conversion followed
# by per-channel normalization over three channels.
# NOTE(review): these look like the usual ImageNet channel statistics --
# confirm they match the model the images are fed to.
# This rebinds the module-level name `transform`.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Both images are loaded with the same transform and the same max_size.
content = load_img(
    image_path='image/content.jpg',
    transform=transform,
    max_size=400,
)
style = load_img(
    image_path='image/style.jpg',
    transform=transform,
    max_size=400,
)
使用 ImageFolder 加载自己的数据:这时需要把不同 label 的数据分别放到不同的文件夹中,ImageFolder 会根据文件夹自动为数据加上标签。
import os

from torchvision import datasets  # was misspelled "torchvison"

data_dir = './data'

# NOTE(review): input_size is not defined anywhere above this point in the
# file. 224 is a placeholder so the snippet runs -- set it to the input
# resolution your model actually expects.
input_size = 224

# Build the dataset from <data_dir>/train with random-crop and random-flip
# augmentation, then convert each image to a tensor.
all_imgs = datasets.ImageFolder(
    os.path.join(data_dir, "train"),
    transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
)
loader = torch.utils.data.DataLoader(
    all_imgs, batch_size=batch_size, shuffle=True)

# Keep only element 0 of the first batch (the images; ImageFolder samples
# are (image, label) pairs).
img = next(iter(loader))[0]
# Converter from a tensor back to a PIL image, used by imshow().
unloader = transforms.ToPILImage()


def imshow(tensor, title=None):
    """Display a tensor as an image on the current matplotlib figure.

    Args:
        tensor: Image tensor. It is copied to the CPU first, so the
            caller's tensor is left untouched; a leading dimension of
            size 1, if present, is squeezed away before conversion.
        title: Optional title for the plot; skipped when ``None``.
    """
    pil_image = unloader(tensor.cpu().clone().squeeze(0))
    plt.imshow(pil_image)
    if title is not None:
        plt.title(title)
    # Short pause after drawing.
    # NOTE(review): presumably here so the figure gets a chance to
    # refresh -- confirm.
    plt.pause(0.001)
plt.figure()
# Show the last image of the batch. For a full batch of 32 this is the
# same element as index 31, and it keeps working when the batch is
# smaller, where a hard-coded 31 would raise IndexError.
imshow(img[-1], title='image')