# Training hyperparameters.
batch_size = 16   # number of samples per batch
lr = 1e-4         # learning rate handed to the optimizer
max_epochs = 100  # number of training epochs

# Option 1: select GPUs through os.environ; per the original note, nothing
# further has to be set for GPU use with this approach.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Option 2: build a `device` object; anything that should live on the GPU is
# later moved there with .to(device).
if torch.cuda.is_available():
    device = torch.device("cuda:1")
else:
    device = torch.device("cpu")
__init__: 用于向类传入外部参数,同时定义样本集
__getitem__: 用于逐个读取样本集合中的元素,可以进行一定的变换,并返回训练/验证所需的数据
__len__: 用于返回数据集的样本数
import torch
from torchvision import datasets
# Build the training and validation datasets from their image folders; both
# share the same `data_transform`.
train_data = datasets.ImageFolder(root=train_path, transform=data_transform)
val_data = datasets.ImageFolder(root=val_path, transform=data_transform)
data_transform
# `data_transform` can apply transformations to the images, such as flipping
# and cropping; it can also be user-defined.
class MyDatast(Dataset):
    """Dataset pairing image files named in a list file with labels from a CSV.

    Image names are read from `image_list` (one per line); labels are looked
    up in `info_csv` by matching the 'Image_index' column against the name.
    """

    def __init__(self, data_dir, info_csv, image_list, transform=None):
        """
        Args:
            data_dir: path to image directory.
            info_csv: path to the csv file containing image indexes with corresponding labels
            image_list: path to the txt file contains image names to training/validation set
            transform: optional transform to be applied on a sample.
        """
        label_info = pd.read_csv(info_csv)
        # Read the list inside a context manager so the file handle is closed
        # instead of being left to the garbage collector.
        with open(image_list) as list_file:
            image_file = list_file.readlines()
        self.data_dir = data_dir
        self.image_file = image_file
        self.label_info = label_info
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index: the index of item
        Returns:
            image and its labels
        """
        # Lines still carry their trailing newline from readlines().
        image_name = self.image_file[index].strip('\n')
        raw_label = self.label_info.loc[self.label_info['Image_index'] == image_name]
        # NOTE(review): this selects the first CSV column of the matching
        # row(s) and returns it as a pandas Series — confirm that column 0
        # really holds the label and that callers expect a Series.
        label = raw_label.iloc[:, 0]
        image_path = os.path.join(self.data_dir, image_name)
        image = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        """Return the number of entries read from the image list file."""
        return len(self.image_file)
from torch.utils.data import DataLoader
# Training loader: shuffled, and the last incomplete batch is dropped.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)
# Validation loader: original order, every sample kept.
val_loader = DataLoader(
    val_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)
import matplotlib.pyplot as plt
# Pull one batch from the validation loader and display its first image.
images, labels = next(iter(val_loader))
print(images.shape)
# torch.Tensor.transpose() swaps exactly two dims, so the NumPy-style call
# transpose(1, 2, 0) raises TypeError on a tensor; permute() does the full
# axis reordering.
# assumes each image is laid out (channel, height, width), which imshow needs
# as (height, width, channel) — TODO confirm against data_transform
plt.imshow(images[0].permute(1, 2, 0))
plt.show()
处理的目的:
增强模型鲁棒性
扩充数据容量
# Flipping: p is the probability of applying the flip, so p=1 always flips.
new_im = transforms.RandomHorizontalFlip(p=1)(im)
new_im.save(os.path.join(outfile, '1_1.jpg'))
new_im = transforms.RandomVerticalFlip(p=1)(im)
new_im.save(os.path.join(outfile, '1_2.jpg'))

# Rotation: random rotation; per the torchvision docs a single number d draws
# the angle from the range [-d, +d] degrees.
new_im = transforms.RandomRotation(45)(im)

# Resizing: target size is given as (height, width) per the torchvision docs.
new_im = transforms.Resize((100, 200))(im)

# Cropping.
new_im = transforms.RandomCrop(100)(im)  # crop a random 100x100 region
new_im.save(os.path.join(outfile, '4_1.jpg'))
# Fixed: the transform is `CenterCrop`; the original `CencerCrop` does not
# exist and raised AttributeError.
new_im = transforms.CenterCrop(100)(im)  # crop the central 100x100 region
new_im.save(os.path.join(outfile, '4_2.jpg'))

# Color jitter: perturb brightness, contrast and saturation respectively.
new_im = transforms.ColorJitter(brightness=1)(im)
new_im = transforms.ColorJitter(contrast=1)(im)
new_im = transforms.ColorJitter(saturation=0.5)(im)
资料参考来源:1. Datawhale社区《深入浅出PyTorch教程》
2. 有三AI《PyTorch入门及实战》
3. 其他零散网络资源