[1] Deep Learning with PyTorch: Tensor Definition and Tensor Creation
[2] Deep Learning with PyTorch: Tensor Operations: Concatenation, Splitting, Indexing, and Reshaping
[3] Deep Learning with PyTorch: Tensor Math Operations
[4] Deep Learning with PyTorch: Linear Regression
[5] Deep Learning with PyTorch: Computational Graphs and the Dynamic Graph Mechanism
[6] Deep Learning with PyTorch: autograd and Logistic Regression
[7] Deep Learning with PyTorch: DataLoader and Dataset (with an RMB Binary Classification Example)
[8] Deep Learning with PyTorch: Image Preprocessing with transforms
[9] Deep Learning with PyTorch: Image Augmentation with transforms (Cropping, Flipping, Rotation)
[10] Deep Learning with PyTorch: Image Operations with transforms and Custom Transform Methods
transforms.Pad(padding, fill=0, padding_mode='constant')
(1) Purpose: pads the borders of an image.
(2) Parameters:
padding: padding size.
I. If padding is a single number a, all four sides (left, right, top, bottom) are padded by a pixels.
II. If padding is (a, b), the left and right sides are padded by a pixels, and the top and bottom sides by b pixels.
III. If padding is (a, b, c, d), the left, top, right, and bottom sides are padded by a, b, c, and d pixels respectively.
padding_mode: padding mode; there are 4 modes:
I. constant: the padded pixels take the value given by fill.
II. edge: the padded pixels take the value of the image's edge pixels.
III. reflect: mirror padding in which the edge pixel itself is not mirrored, e.g. [1,2,3,4] --> [3,2,1,2,3,4,3,2];
to the left: 1 is not mirrored, so 2 and 3 are mirrored;
to the right: 4 is not mirrored, so 3 and 2 are mirrored.
IV. symmetric: mirror padding in which the edge pixel is also mirrored, e.g. [1,2,3,4] --> [2,1,1,2,3,4,4,3];
to the left: 1 and 2 are mirrored;
to the right: 4 and 3 are mirrored.
(The reflect/symmetric difference is reproduced in the NumPy sketch after the code example below.)
fill: the pixel value used for padding when padding_mode='constant', given as (R, G, B) or (Gray).
(3) Code example:
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 1 Pad
transforms.Pad(padding=32, fill=(255, 0, 0), padding_mode='constant'),
transforms.Pad(padding=(8, 64), fill=(255, 0, 0), padding_mode='constant'),
transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='constant'),
transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='symmetric'),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
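The difference between reflect and symmetric is easiest to see on the 1-D example from the parameter list above. The following short sketch uses NumPy's np.pad, which happens to use the same mode names, purely to reproduce that example; it is an illustration, not part of the torchvision code.
import numpy as np

x = np.array([1, 2, 3, 4])
# 'reflect': the edge element itself is not mirrored
print(np.pad(x, 2, mode='reflect'))    # [3 2 1 2 3 4 3 2]
# 'symmetric': the edge element is mirrored as well
print(np.pad(x, 2, mode='symmetric'))  # [2 1 1 2 3 4 4 3]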
transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
(1) Purpose: adjusts the brightness, contrast, saturation, and hue of an image.
(2) Parameters:
brightness: brightness adjustment factor; a factor > 1 brightens the image and a factor < 1 darkens it (see the short sketch after the code example below).
I. If brightness is a single number a, the factor is sampled uniformly from [max(0, 1-a), 1+a].
II. If brightness is (a, b), the factor is sampled uniformly from [a, b].
contrast: contrast parameter, same format as brightness; the lower the contrast, the grayer the image.
saturation: saturation parameter, same format as brightness; the lower the saturation, the duller the image.
hue: hue parameter.
I. If hue is a single number a, the factor is sampled from [-a, a]; note that a must satisfy 0 ≤ a ≤ 0.5.
II. If hue is (a, b), the factor is sampled from [a, b]; note that -0.5 ≤ a ≤ b ≤ 0.5.
(3) Code example:
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 2 ColorJitter
transforms.ColorJitter(brightness=0.5),
transforms.ColorJitter(contrast=0.5),
transforms.ColorJitter(saturation=0.5),
transforms.ColorJitter(hue=0.3),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
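As a quick check of the sampling rule described above, this tiny sketch computes the interval that a single-number brightness value a expands to. It is an illustration only (the helper name brightness_interval is made up), not torchvision code.
def brightness_interval(a):
    # a single value a expands to [max(0, 1 - a), 1 + a]
    return (max(0.0, 1 - a), 1 + a)

print(brightness_interval(0.5))  # (0.5, 1.5) -- the range used by ColorJitter(brightness=0.5) above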
transforms.Grayscale(num_output_channels)
(1) Purpose: converts an image to grayscale.
(2) Parameters:
num_output_channels: number of output channels; can only be 1 or 3.
(3) Code example:
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 3 Grayscale
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
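A short sketch of why num_output_channels=3 is chosen here: the Normalize(norm_mean, norm_std) step above uses 3-channel statistics, and a 3-channel grayscale image simply repeats the single gray channel. The image path below is hypothetical.
from PIL import Image
import torchvision.transforms as transforms

img = Image.open("demo.jpg").convert('RGB')               # hypothetical image path
gray_1 = transforms.Grayscale(num_output_channels=1)(img)
gray_3 = transforms.Grayscale(num_output_channels=3)(img)
print(gray_1.mode, gray_3.mode)                           # 'L' vs 'RGB' (with r == g == b)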
transforms.RandomGrayscale(p=0.1)
(1) Purpose: converts an image to grayscale with a given probability. Unlike Grayscale, there is no num_output_channels argument: the grayscale output keeps the same number of channels as the input, so an RGB input still yields 3 channels with r == g == b.
(2) Parameters:
p: probability that the image is converted to grayscale.
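The original notes give no code example for RandomGrayscale; the following sketch follows the same pattern as the other examples (norm_mean and norm_std as defined in the full script below):
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomGrayscale(p=0.5),   # converted to 3-channel grayscale with probability 0.5, left unchanged otherwise
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])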
transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0)
(1) Purpose: applies a random affine transformation to the image. An affine transformation is a 2-D linear transformation composed of five elementary operations: rotation, translation, scaling, shear, and flipping.
(2) Parameters:
degrees: rotation angle. The rotation is about the image center, and degrees must always be given; set degrees=0 if no rotation is wanted.
I. If degrees is a single number a, the rotation angle is sampled from (-a, a).
II. If degrees is (a, b), the rotation angle is sampled from (a, b).
translate: translation range. If translate=(a, b), a controls the width direction and b the height direction: the horizontal shift dx is sampled from -img_width * a < dx < img_width * a, and the vertical shift dy from -img_height * b < dy < img_height * b.
scale: scaling factor range, given as an interval from which the scale factor is randomly sampled.
fillcolor: fill color for the area left uncovered by the transformed image.
shear: shear angle setting; both horizontal and vertical shear are supported.
I. If shear is a single number a, shearing is applied only along the x axis, with the angle sampled from (-a, a).
II. If shear is (a, b), an x-axis shear angle is sampled from (a, b) and no y-axis shear is applied.
III. If shear is (a, b, c, d), the x-axis shear angle is sampled from (a, b) and the y-axis shear angle from (c, d), as in the shear=(0, 0, 0, 45) example below.
resample: resampling method; one of NEAREST, BILINEAR, BICUBIC (see the version note after the code example below).
(3) Code example:
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 4 Affine
transforms.RandomAffine(degrees=30),
transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), fillcolor=(255, 0, 0)), # degrees must always be given; set degrees=0 if no rotation is wanted
transforms.RandomAffine(degrees=0, scale=(0.7, 0.7)), # the area left uncovered is filled with black
transforms.RandomAffine(degrees=0, shear=(0, 0, 0, 45)), # shear along the y axis
transforms.RandomAffine(degrees=0, shear=90, fillcolor=(255, 0, 0)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
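Note that fillcolor and resample match older torchvision releases; in newer releases these arguments were renamed. A sketch of the equivalent call under that assumption (check the installed torchvision version before relying on it):
# newer torchvision: fill replaces fillcolor, interpolation replaces resample
transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), fill=(255, 0, 0))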
transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)
(1) Purpose: randomly erases (occludes) a rectangular region of the image.
(2) Parameters:
p: probability that the erasing is applied.
scale: range of the erased area as a fraction of the image area.
ratio: range of the aspect ratio of the erased region.
value: pixel value used to fill the erased region, e.g. (R, G, B) or (Gray).
(3) Code example:
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 5 Erasing
# RandomErasing operates on tensors, so ToTensor() must be applied before it
transforms.ToTensor(),
transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=(254/255, 0, 0)), # value is applied to a tensor here, so it must be scaled to the 0-1 range (divide by 255)
transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='1234'), # when value is an arbitrary string, the erased region is filled with random pixel values
transforms.Normalize(norm_mean, norm_std),
])
RandomErasing operates on tensors, so ToTensor() must be applied before it;
when value is an arbitrary string such as value='1234', the erased region is filled with random pixel values.
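Because RandomErasing works on tensors, it can also be tried outside a Compose. A minimal self-contained sketch, with a random tensor standing in for a real image:
import torch
import torchvision.transforms as transforms

img_tensor = torch.rand(3, 224, 224)   # C x H x W, values in [0, 1]
eraser = transforms.RandomErasing(p=1.0, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0)
erased = eraser(img_tensor)
print(erased.shape)                    # torch.Size([3, 224, 224]); a random rectangle has been set to 0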
transforms.Lambda(lambd)
(1) Purpose: applies a user-defined lambda as a transform.
(2) Parameters:
lambd: a lambda (anonymous) function, with the syntax:
lambda [arg1 [, arg2, ..., argn]]: expression
(3) Code example:
transforms.FiveCrop(112), # using transforms.FiveCrop(112) on its own raises an error; it must be combined with the Lambda on the next line
# in the lambda, the part before the colon is the function input (crops) and the part after it is the return value
transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])), # ToTensor() is applied here, so ToTensor() and Normalize must not be applied again afterwards
Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])) already applies ToTensor(), so ToTensor() and Normalize should not be applied again afterwards.
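For context, a slightly fuller sketch of how FiveCrop and Lambda fit into a Compose, and how the extra crop dimension is usually handled in the training loop. The view/mean handling in the comments follows the common pattern from the torchvision documentation; it is not shown in the original notes.
import torch
import torchvision.transforms as transforms

five_crop_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.FiveCrop(112),   # returns a tuple of 5 PIL crops
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
])
# Each sample now has shape [5, C, H, W] instead of [C, H, W]. In the training loop the
# crops are typically folded into the batch dimension before the forward pass:
#   bs, ncrops, c, h, w = inputs.shape
#   outputs = net(inputs.view(-1, c, h, w)).view(bs, ncrops, -1).mean(dim=1)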
transforms.RandomChoice([transforms1, transforms2, transforms3])
(1) Purpose: randomly selects one transform from a list and applies it (exactly one is executed).
(2) Code example:
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 1 RandomChoice
transforms.RandomChoice([transforms.RandomVerticalFlip(p=1), transforms.RandomHorizontalFlip(p=1)]),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)
(1) Purpose: applies this group of transforms with probability p (the group is executed as a whole or not at all).
(2) Code example:
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 2 RandomApply
transforms.RandomApply([transforms.RandomAffine(degrees=0, shear=45, fillcolor=(255, 0, 0)),
transforms.Grayscale(num_output_channels=3)], p=0.5),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
transforms.RandomOrder([transforms1, transforms2, transforms3])
(1) Purpose: shuffles the order of a group of transforms and applies all of them (the whole group is executed, in random order).
(2) Code example:
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 3 RandomOrder
transforms.RandomOrder([transforms.RandomRotation(15),
transforms.Pad(padding=32),
transforms.RandomAffine(degrees=0, translate=(0.01, 0.1), scale=(0.9, 1.1))]),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
transforms_methods_2.py
# -*- coding: utf-8 -*-
"""
# @file name : transforms_methods_2.py
# @brief : transforms methods, part 2
"""
import os
import numpy as np
import torch
import random
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert
def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

set_seed(1)  # set the random seed
# hyperparameter settings
MAX_EPOCH = 10
BATCH_SIZE = 1
LR = 0.01
log_interval = 10
val_interval = 1
rmb_label = {"1": 0, "100": 1}
# ============================ step 1/5 data ============================
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 1 Pad
# transforms.Pad(padding=32, fill=(255, 0, 0), padding_mode='constant'),
# transforms.Pad(padding=(8, 64), fill=(255, 0, 0), padding_mode='constant'),
# transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='constant'),
# transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='symmetric'),
# 2 ColorJitter
# transforms.ColorJitter(brightness=0.5),
# transforms.ColorJitter(contrast=0.5),
# transforms.ColorJitter(saturation=0.5),
# transforms.ColorJitter(hue=0.3),
# 3 Grayscale
# transforms.Grayscale(num_output_channels=3),
# 4 Affine
# transforms.RandomAffine(degrees=30),
# transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), fillcolor=(255, 0, 0)), # degrees must always be given; set degrees=0 if no rotation is wanted
# transforms.RandomAffine(degrees=0, scale=(0.7, 0.7)), # the area left uncovered is filled with black
# transforms.RandomAffine(degrees=0, shear=(0, 0, 0, 45)), # shear along the y axis
# transforms.RandomAffine(degrees=0, shear=90, fillcolor=(255, 0, 0)),
# 5 Erasing
# RandomErasing operates on tensors, so ToTensor() must come first (and the ToTensor() line near the end of this Compose should then be commented out)
# transforms.ToTensor(),
# transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=(254/255, 0, 0)), # value is applied to a tensor here, so it must be scaled to the 0-1 range (divide by 255)
# transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='1234'), # when value is an arbitrary string, the erased region is filled with random pixel values
# 1 RandomChoice
# transforms.RandomChoice([transforms.RandomVerticalFlip(p=1), transforms.RandomHorizontalFlip(p=1)]),
# 2 RandomApply
# transforms.RandomApply([transforms.RandomAffine(degrees=0, shear=45, fillcolor=(255, 0, 0)),
# transforms.Grayscale(num_output_channels=3)], p=0.5),
# 3 RandomOrder
# transforms.RandomOrder([transforms.RandomRotation(15),
# transforms.Pad(padding=32),
# transforms.RandomAffine(degrees=0, translate=(0.01, 0.1), scale=(0.9, 1.1))]),
transforms.ToTensor(), # comment out this line when using RandomErasing above
transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)
])
# build the MyDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
# ============================ step 5/5 training ============================
for epoch in range(MAX_EPOCH):
    for i, data in enumerate(train_loader):
        inputs, labels = data           # B C H W
        img_tensor = inputs[0, ...]     # C H W
        img = transform_invert(img_tensor, train_transform)
        plt.imshow(img)
        plt.show()
        plt.pause(0.5)
        plt.close()
A custom transforms method must satisfy two requirements:
(1) it accepts exactly one argument and returns exactly one value;
(2) pay attention to the upstream output and downstream input types, e.g. whether each step expects a PIL Image or a Tensor.
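A minimal skeleton that satisfies both points, for reference before the full AddPepperNoise example below. The MultiplyBrightness transform itself is just an illustrative assumption, not part of the original code.
import numpy as np
from PIL import Image

class MultiplyBrightness(object):
    """Scale pixel values by a constant factor (illustrative custom transform)."""
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, img):
        # exactly one input (a PIL Image, since it sits before ToTensor()) and one output
        arr = np.clip(np.array(img, dtype=np.float32) * self.factor, 0, 255)
        return Image.fromarray(arr.astype('uint8'))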
my_transforms.py
# -*- coding: utf-8 -*-
"""
# @file name : my_transforms.py
# @brief : a custom transforms method
"""
import os
import numpy as np
import torch
import random
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert
def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

set_seed(1)  # set the random seed
# hyperparameter settings
MAX_EPOCH = 10
BATCH_SIZE = 1
LR = 0.01
log_interval = 10
val_interval = 1
rmb_label = {"1": 0, "100": 1}
class AddPepperNoise(object):
    """Add salt-and-pepper noise to an image.

    Args:
        snr (float): signal-to-noise ratio
        p (float): probability of applying the transform
    """
    def __init__(self, snr, p=0.9):
        assert isinstance(snr, float) and isinstance(p, float)
        self.snr = snr
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): PIL Image
        Returns:
            PIL Image: PIL Image.
        """
        if random.uniform(0, 1) < self.p:
            img_ = np.array(img).copy()     # PIL Image -> ndarray
            h, w, c = img_.shape            # height, width, number of channels
            signal_pct = self.snr           # fraction of signal pixels
            noise_pct = (1 - self.snr)      # fraction of noise pixels
            # 0 = keep the original pixel, 1 = salt noise, 2 = pepper noise; (0, 1, 2) is used to build the mask
            mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[signal_pct, noise_pct/2., noise_pct/2.])
            mask = np.repeat(mask, c, axis=2)
            img_[mask == 1] = 255   # salt noise: white
            img_[mask == 2] = 0     # pepper noise: black
            return Image.fromarray(img_.astype('uint8')).convert('RGB')    # ndarray -> PIL Image
        else:
            return img
# ============================ step 1/5 data ============================
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
AddPepperNoise(0.9, p=0.5),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)
])
# build the MyDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
# ============================ step 5/5 training ============================
for epoch in range(MAX_EPOCH):
    for i, data in enumerate(train_loader):
        inputs, labels = data           # B C H W
        img_tensor = inputs[0, ...]     # C H W
        img = transform_invert(img_tensor, train_transform)
        plt.imshow(img)
        plt.show()
        plt.pause(0.5)
        plt.close()
RMB_data_augmentation.py
# -*- coding: utf-8 -*-
"""
# @file name : RMB_data_augmentation.py
# @brief : data augmentation experiment for the RMB classification model
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert
def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

set_seed()  # set the random seed
rmb_label = {"1": 0, "100": 1}
# hyperparameter settings
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1
# ============================ step 1/5 data ============================
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4),
transforms.RandomGrayscale(p=0.9),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
# build the MyDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
# ============================ step 2/5 model ============================
net = LeNet(classes=2)
net.initialize_weights()
# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()  # choose the loss function
# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # learning-rate decay schedule
# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()
for epoch in range(MAX_EPOCH):
    loss_mean = 0.
    correct = 0.
    total = 0.
    net.train()
    for i, data in enumerate(train_loader):
        # forward
        inputs, labels = data
        outputs = net(inputs)
        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        # update weights
        optimizer.step()
        # classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()
        # print training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.
    scheduler.step()  # update the learning rate
    # validate the model
    if (epoch+1) % val_interval == 0:
        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()
                loss_val += loss.item()
            valid_curve.append(loss_val)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))
train_x = range(len(train_curve))
train_y = train_curve
train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval  # valid_curve records one loss per epoch, so the record points are converted to iteration counts
valid_y = valid_curve
plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')
plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
# ============================ inference ============================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")
test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)
for i, data in enumerate(valid_loader):
    # forward
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)
    rmb = 1 if predicted.numpy()[0] == 0 else 100
    img_tensor = inputs[0, ...]  # C H W
    img = transform_invert(img_tensor, train_transform)
    plt.imshow(img)
    plt.title("LeNet got {} Yuan".format(rmb))
    plt.show()
    plt.pause(0.5)
    plt.close()
tools/my_dataset.py
# -*- coding: utf-8 -*-
"""
# @file name : dataset.py
# @brief : Dataset definitions for the datasets used in this series
"""
import os
import random
from PIL import Image
from torch.utils.data import Dataset
random.seed(1)
rmb_label = {"1": 0, "100": 1}
class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Dataset for the RMB denomination classification task
        :param data_dir: str, path to the dataset
        :param transform: torch.transform, preprocessing pipeline
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info stores all image paths and labels; samples are read by index in the DataLoader
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255
        if self.transform is not None:
            img = self.transform(img)   # apply the transforms here (conversion to tensor, etc.)
        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # iterate over classes
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                # iterate over images
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))
        return data_info
tools/common_tools.py
# -*- coding: utf-8 -*-
"""
# @file name : common_tools.py
# @brief : common utility functions
"""
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
def transform_invert(img_, transform_train):
    """
    Invert the transforms applied to img_ (de-normalize and convert back to a PIL image)
    :param img_: tensor
    :param transform_train: torchvision.transforms
    :return: PIL image
    """
    if 'Normalize' in str(transform_train):
        norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
        mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
        std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
        img_.mul_(std[:, None, None]).add_(mean[:, None, None])

    img_ = img_.transpose(0, 2).transpose(0, 1)  # C*H*W --> H*W*C
    if 'ToTensor' in str(transform_train):
        img_ = np.array(img_) * 255

    if img_.shape[2] == 3:
        img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
    elif img_.shape[2] == 1:
        img_ = Image.fromarray(img_.astype('uint8').squeeze())
    else:
        raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]))
    return img_