from torchvision.transforms import ToTensor#用于把图片转化为张量
import numpy as np#用于将张量转化为数组,进行除法
from torchvision.datasets import ImageFolder#用于导入图片数据集
means = [0,0,0]
std = [0,0,0]#初始化均值和方差
transform=ToTensor()#可将图片类型转化为张量,并把0~255的像素值缩小到0~1之间
dataset=ImageFolder("./data/train",transform=transform)#导入数据集的图片,并且转化为张量
num_imgs=len(dataset)#获取数据集的图片数量
for img,a in dataset:#遍历数据集的张量和标签
for i in range(3):#遍历图片的RGB三通道
# 计算每一个通道的均值和标准差
means[i] += img[i, :, :].mean()
std[i] += img[i, :, :].std()
mean=np.array(means)/num_imgs
std=np.array(std)/num_imgs#要使数据集归一化,均值和方差需除以总图片数量
print(mean,std)#打印出结果
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
):
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid file (used to check of corrupt files)
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
参数详解
root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
loader:表示数据集加载方式,通常默认加载方式即可。
is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件
返回的dataset都有以下三种属性:
self.classes:用一个 list 保存类别名称self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
self.imgs:保存(img-path, class) tuple的 list
from torchvision.transforms import ToTensor#用于把图片转化为张量
from torchvision.datasets import ImageFolder#用于导入图片数据集
transform=ToTensor()#可将图片类型转化为张量,并把0~255的像素值缩小到0~1之间
dataset=ImageFolder("./data/train",transform=transform)#导入数据集的图片,并且转化为张量
print('类别',dataset.classes)
print('类别索引',dataset.class_to_idx)
print('图片列表',dataset.imgs)
print('数据集大小',len(dataset))
功能:逐channel的对图像进行标准化(均值变为0,标准差变为1),可以加快模型的收敛
output = (input - mean) / std
mean:各通道的均值
std:各通道的标准差
from torchvision.transforms import ToTensor#用于把图片转化为张量
import numpy as np#用于将张量转化为数组,进行除法
from torchvision.datasets import ImageFolder#用于导入图片数据集
from torchvision import transforms
means = [0,0,0]
std = [0,0,0]#初始化均值和方差
transform=ToTensor()#可将图片类型转化为张量,并把0~255的像素值缩小到0~1之间
dataset=ImageFolder("./data/train",transform=transform)#导入数据集的图片,并且转化为张量
num_imgs=len(dataset)#获取数据集的图片数量
for img,a in dataset:#遍历数据集的张量和标签
for i in range(3):#遍历图片的RGB三通道
# 计算每一个通道的均值和标准差
means[i] += img[i, :, :].mean()
std[i] += img[i, :, :].std()
mean=np.array(means)/num_imgs
std=np.array(std)/num_imgs#要使数据集归一化,均值和方差需除以总图片数量
print('未归一化的结果',mean,std)#打印出结果
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=mean ,std=std)
# transforms.Normalize(mean=[0.53124684 ,0.44909474 ,0.3981565] ,std=[0.23371747 ,0.23109557 ,0.2231768 ])
])
means = [0,0,0]
std = [0,0,0]#初始化均值和方差
dataset=ImageFolder("./data/train",transform=transform)#导入数据集的图片,并且转化为张量
num_imgs=len(dataset)#获取数据集的图片数量
for img,a in dataset:#遍历数据集的张量和标签
for i in range(3):#遍历图片的RGB三通道
# 计算每一个通道的均值和标准差
means[i] += img[i, :, :].mean()
std[i] += img[i, :, :].std()
mean=np.array(means)/num_imgs
std=np.array(std)/num_imgs#要使数据集归一化,均值和方差需除以总图片数量
print('归一化后的结果',mean,std)#打印出结果