使用数据集:宝可梦【皮卡丘:234,超梦:239;杰尼龟:223;小火龙:238;妙蛙种子:234】
思想:数据较少,可以利用迁移学习来获得好的效果。
4 steps
class NumbersDataset(Dataset)
def __init__(self,training=True) #分为训练和测试数据集
if training:
self.samples=list(range(1,1001))
else
self.samples=list(range(1001,1501))
def __len__(self):
return len(self.samples)
def__getitem__(self,idx):
return self.sample[idx]
宝可梦数据预处理
Image Resize:224x224 for ResNet18
Data Argumentation: Rotate,Crop #增加数据集
Normalize:Mean,std
宝可梦数据集代码的整体结构
class Pokemon(Dataset):
def __init__(self, root, resize, mode):{...}
def load_csv(self, filename):{...
return images, labels}
def __len__(self):{
return len(self.images)}
def denormalize(self, x_hat):{...
return x}
def __getitem__(self, idx):{...
return img, label}
初始化__init__
def __init__(self, root, resize, mode):
super(Pokemon, self).__init__()
self.root = root
self.resize = resize
'''首先创建文件夹对应的label映射表'''
self.name2label = {} # "sq...":0
for name in sorted(os.listdir(os.path.join(root))): #导入
if not os.path.isdir(os.path.join(root, name)): #过滤文件
continue
self.name2label[name] = len(self.name2label.keys()) #用文件名做label
# print(self.name2label)
'''导入csv[image,label]表'''
self.images, self.labels = self.load_csv('images.csv')
'''数据集分为6:2:2进行训练,验证,测试'''
if mode=='train': # 60% = 【0%->60%】
self.images = self.images[:int(0.6*len(self.images))]
self.labels = self.labels[:int(0.6*len(self.labels))]
elif mode=='val': # 20% = 【60%->80%】
self.images =
self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
self.labels =
self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
else: # 20% = 【80%->100%】
self.images = self.images[int(0.8*len(self.images)):]
self.labels = self.labels[int(0.8*len(self.labels)):]
加载csv如果不存在csv[images,labels]就创建
'''1将路径存成一维的表格'''
self.images, self.labels = self.load_csv('images.csv')#csv存在就导入
def load_csv(self, filename):
if not os.path.exists(os.path.join(self.root, filename)):#csv不存在就创建
images = [] #存入images数组
for name in self.name2label.keys():
# 类似pokemon\\mewtwo\\00001.png保存在images里
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
#print(len(images), images) # 1167, 'pokemon\\bulbasaur\\00000000.png'
'''1.1将images表打乱并和label写入二维表格csv'''
random.shuffle(images)
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images: # 'pokemon\\bulbasaur\\00000000.png'
name = img.split(os.sep)[-2]
label = self.name2label[name]
# 存储格式:【'pokemon\\bulbasaur\\00000000.png', 0】
writer.writerow([img, label])
#print('writen into csv file:', filename)
'''2加载csv file'''
images, labels = [], [] #分别存入images和labels数组
with open(os.path.join(self.root, filename)) as f:
reader = csv.reader(f)
for row in reader:
# 'pokemon\\bulbasaur\\00000000.png', 0
img, label = row #每一行
label = int(label)
images.append(img)
labels.append(label)
assert len(images) == len(labels) #让数据和label数组长度一样
获取数组长度__len__ &特定数组元素__len__
#from torchvision import transforms
#from PIL import Image
def __getitem__(self, idx):
# idx~[0~len(images)]
# self.images, self.labels
# [img,label]=[0,'pokemon\\bulbasaur\\00000000.png']
img, label = self.images[idx], self.labels[idx]
'''tf功能:把path变成指定的数据类型'''
tf = transforms.Compose([
lambda x:Image.open(x).convert('RGB'), #图片路径 => 图片数据
'''重置图像大小'''
transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
'''对图片进行随机旋转'''
transforms.RandomRotation(15),
'''对图片进行中心裁剪'''
transforms.CenterCrop(self.resize),
'''图片转为tensor类型'''
transforms.ToTensor(),
'''进行标准化服从N[μ,σ²],归一化至[0-1],即先减均值,再除以标准差'''
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
'''把[path,label]变成tensor的int类型'''
img = tf(img)
label = torch.tensor(label)
return img, label
为了满足可视化,将标准化的x还原
def denormalize(self, x_hat):
# x_hat = (x-mean)/std
# x = x_hat*std = mean
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x: [c, h, w]
# mean: [3] => [3, 1, 1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
x = x_hat * std + mean
return x
Inherit from base class
可视化图片:
'''加载一张'''
import visdom
viz = visdom.Visdom()
db = Pokemon('pokemon', 224, 'train')
x,y = next(iter(db))
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
viz.images(db.denormalize(x), win='sample_x', opts=dict(title='batch'))
'''加载多张'''
from torch.utils.data import Dataset, DataLoader
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
for x,y in loader:
viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
time.sleep(10)
'''totensor的方法二'''
tf = transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor(),
])
db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
loader = DataLoader(db, batch_size=32, shuffle=True)