记录一下,方便后边查阅。
torchvision是pytorch的一个图形库,它服务于深度学习Pytorch框架,主要用来构建计算机视觉模型。
下面是torchvision的构成[1]:
1.torchvision.datasets:一些加载数据的函数及常用的数据集接口;
2.torchvision.models:包含常用的模型结构,例如AlexNet,VGG,ResNet等;
3.torchvision.transforms:常用的一些图片变换,例如图片裁剪、选择等;
4.torchvison.utils:其他一些有用的方法
代码示例:
def get_train_dataset():
return dataset.FashionMNIST(
root='./data',
train=True,
download=True,
transform=getTransforms()
其中,
root:表示数据集下载保存位置
train:表示下载的数据集是不是训练集,True表示训练集,False表示测试集
download:表示数据集是否需要下载
transform:表示图片变换的一系列操作
root路径详解:
root='/':表示根目录下,如果你的代码保存在D盘,就是下载到D盘根目录下
root='./':表示当前文件夹下
root='':效果等同于'./'
root='./data':表示在当前文件夹下的data(如果没有,则会新建一个)文件夹下保存数据集
root='data':效果等同于'./data'
代码出自文献[2]:
import torch.utils.data
from torchvision import datasets as dataset
from torchvision import transforms
def getTransforms():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3018,))]
)
return transform
def get_train_dataset():
return dataset.FashionMNIST(
root='',
train=True,
download=True,
transform=getTransforms()
)
def get_test_dataset():
return dataset.FashionMNIST(
root='',
train=False,
download=True,
transform=getTransforms()
)
def get_train_loader(batch_size, shuffle=True):
return torch.utils.data.DataLoader(
dataset=get_train_dataset(),
batch_size=batch_size,
shuffle=shuffle
)
def get_test_loader(batch_size, shuffle=True):
return torch.utils.data.DataLoader(
dataset=get_test_dataset(),
batch_size=batch_size,
shuffle=shuffle
)
import torch.utils.data
from torchvision import datasets as dataset
from torchvision import transforms
def getTransforms():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3018,))]
)
return transform
def get_train_dataset():
return dataset.MNIST(
root='LeNetTest',
train=True,
download=True,
transform=getTransforms()
)
def get_test_dataset():
return dataset.MNIST(
root='LeNetTest',
train=False,
download=True,
transform=getTransforms()
)
def get_train_loader(batch_size, shuffle=True):
return torch.utils.data.DataLoader(
dataset=get_train_dataset(),
batch_size=batch_size,
shuffle=shuffle
)
def get_test_loader(batch_size, shuffle=True):
return torch.utils.data.DataLoader(
dataset=get_test_dataset(),
batch_size=batch_size,
shuffle=shuffle
)
第一次写,可能写得不是很好,希望大家多多包涵!
[1]:https://wenku.baidu.com/view/21bbc06bf4ec4afe04a1b0717fd5360cba1a8df6.html
[2]:https://blog.csdn.net/weixin_38878828/article/details/125614377