Few-shot classification(小样本分类)是机器学习和人工智能的一个子领域,解决的问题是在训练数据非常有限的情况下,学习对新样本进行分类。在传统的监督学习中,模型需要在包含大量标记样本的数据集上进行训练,每个类别都有丰富的标记样本。然而,在实际应用中,获得如此大量的标记数据可能会非常困难或昂贵。
目前网上关于 few-shot 入门的资料十分少,博主之前对 episode 等概念也不太理解,在看了一些资料和代码后才逐渐明白小样本模型是如何训练的。为此,博主首先对其中的数据集加载部分进行了总结,希望能对各位读者有所启发。本文假设数据集的目录结构如下:
- data_name
--- images
----- folder_name1
------- img1.png
------- img2.png
----- folder_name2
--- meta
----- classes.txt
----- fsl_train.txt
----- fsl_test.txt
----- fsl_train_class.txt
----- fsl_test_class.txt
其中 folder_name1、folder_name2 是图像文件夹的名字,通常是类别名称,也可能是类别编号(从 1 开始的数字,例如 1~N,N 为类别数)。
classes.txt 中每行一个类别名称,列出了数据集的全部类别;如果数据集没有提供该文件,需要自己构建。当图像文件夹以编号命名时,第 i 行的类别对应编号为 i 的文件夹。文件内容大致如下:
class_name1
class_name2
class_name3
生成的文件包括:fsl_train.txt,fsl_test.txt,fsl_train_class.txt,fsl_test_class.txt文件
该代码目前支持的情况有:图像文件夹以类别名称命名,或以从 1 开始的数字编号命名(参数 is_imgs_id);图像文件名为纯数字(按数值大小排序),或为任意名称(按字典序排序)(参数 is_img_name_num)。
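文件大致内容为:fsl_train.txt 每行是一张图像相对于 images 目录的路径,格式为 folder_name/img_name,例如:
folder_name1/img1.png
folder_name1/img2.png
fsl_train_class.txt 每行是一个训练类别的文件夹名称,例如:
folder_name1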
代码为:
def make_file(img_root_path, names, path, is_num):
    """Write the relative path of every image found in the given class folders.

    Each output line has the form ``<folder>/<image file name>``, i.e. the
    image path relative to ``img_root_path`` using ``/`` as the separator
    (the format expected by the annotation loader, which splits on ``/``).

    :param img_root_path: root directory holding one sub-folder per class
    :param names: folder names (class names or numeric ids) to include
    :param path: output text file to (over)write
    :param is_num: whether image file names are numeric (e.g. ``12.jpg``);
        if True they are sorted by numeric value instead of lexicographically
    """
    with open(path, "w", encoding="utf-8") as f:
        for name in names:
            folder = str(name)
            img_names = os.listdir(os.path.join(img_root_path, folder))
            if is_num:
                # "10.jpg" must come after "9.jpg": sort by the integer stem.
                img_names = sorted(img_names, key=lambda s: int(s.split('.')[0]))
            else:
                img_names = sorted(img_names)
            for img_name in img_names:
                # Build the relative path directly. Stripping the root with
                # str.replace fails silently (writing absolute paths) when the
                # root has a trailing separator or os.sep is not "/".
                f.write(f"{folder}/{img_name}\n")
def generate_split_dataset(data_root, train_num, is_imgs_id, is_img_name_num):
    """Randomly split the classes of a dataset into few-shot train/test sets.

    Reads ``<data_root>/meta/classes.txt`` (via ``list_from_file``) and writes
    four files into ``<data_root>/meta``:

    * ``fsl_train_class.txt`` / ``fsl_test_class.txt`` - one image folder name
      per line for the training / test classes;
    * ``fsl_train.txt`` / ``fsl_test.txt`` - one ``<folder>/<image>`` path per
      line for every image of those classes (see ``make_file``).

    :param data_root: dataset root directory
    :param train_num: number of classes used for training; the rest go to test
    :param is_imgs_id: whether the image folders are named by 1-based class
        index (line ``i`` of classes.txt -> folder ``i + 1``) instead of by
        class name
    :param is_img_name_num: whether image file names are numeric
    :return: None
    """
    meta_dir = os.path.join(data_root, "meta")
    class_list = list_from_file(os.path.join(meta_dir, "classes.txt"))
    num_classes = len(class_list)

    # Image folder name for each entry of classes.txt: the 1-based index when
    # folders are numbered (change the offset here if yours start at 0),
    # otherwise the class name itself.
    if is_imgs_id:
        folder_names = [i + 1 for i in range(num_classes)]
    else:
        folder_names = list(class_list)

    # Sample 0-based positions so both naming schemes index the same range.
    # (Sampling ids 1..N and looking them up in a 0-based mapping raised
    # KeyError for the last class and never selected the first one.)
    train_idx = set(random.sample(range(num_classes), train_num))
    test_idx = [i for i in range(num_classes) if i not in train_idx]

    # Sorted so that the generated files have a stable order.
    train_class_name = sorted(folder_names[i] for i in train_idx)
    test_class_name = sorted(folder_names[i] for i in test_idx)

    for save_name, cls_names in (("fsl_train_class.txt", train_class_name),
                                 ("fsl_test_class.txt", test_class_name)):
        with open(os.path.join(meta_dir, save_name), "w", encoding="utf-8") as f:
            f.writelines(f"{cls_name}\n" for cls_name in cls_names)

    # Image lists, one "folder/img_name" per line.
    img_root_path = os.path.join(data_root, "images")
    make_file(img_root_path, train_class_name,
              os.path.join(meta_dir, "fsl_train.txt"), is_img_name_num)
    make_file(img_root_path, test_class_name,
              os.path.join(meta_dir, "fsl_test.txt"), is_img_name_num)
BaseFewShotDataset 是小样本数据集的抽象基类:它读取类别名称文件(或类别列表),调用子类实现的 load_annotations 加载标注,并按类别记录样本下标,以便按类别抽取 support/query 样本。代码为:
import copy
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Mapping, Optional, Sequence, Union
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
import os.path as osp
from PIL import Image
import torch
from util import tools
from mmpretrain.evaluation import Accuracy
class BaseFewShotDataset(Dataset, metaclass=ABCMeta):
    """Abstract base class for few-shot classification datasets.

    Subclasses implement :meth:`load_annotations`, returning a list of dicts
    with keys ``img_prefix``, ``img_info`` (holding ``filename``) and
    ``gt_label`` (a numpy scalar class index). The base class indexes the
    samples by class so that episodes can draw shots per class through
    :meth:`sample_shots_by_class_id`.

    Args:
        pipeline: callable applied to the loaded RGB PIL image.
        data_prefix: dataset root directory.
        classes: path to a class-name file, or a list/tuple of class names.
        ann_file: annotation file path, read by ``load_annotations``.
    """

    def __init__(self,
                 pipeline,
                 data_prefix: str,
                 classes: Optional[Union[str, List[str]]] = None,
                 ann_file: Optional[str] = None) -> None:
        super().__init__()
        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.pipeline = pipeline
        self.CLASSES = self.get_classes(classes)
        self.data_infos = self.load_annotations()
        # class index -> indices into self.data_infos carrying that label
        self.data_infos_class_dict = {i: [] for i in range(len(self.CLASSES))}
        for idx, data_info in enumerate(self.data_infos):
            self.data_infos_class_dict[data_info['gt_label'].item()].append(idx)

    def load_image_from_file(self, info_dict):
        """Open ``img_prefix/filename`` from ``info_dict`` and return it as RGB."""
        img_prefix = info_dict['img_prefix']
        img_name = info_dict['img_info']['filename']
        img_file = osp.join(img_prefix, f"{img_name}")
        # convert() returns a new, fully loaded image, so the file handle can
        # be closed deterministically when the with-block exits.
        with Image.open(img_file) as img:
            return img.convert('RGB')

    @abstractmethod
    def load_annotations(self):
        """Return the list of per-sample info dicts (implemented by subclasses)."""
        pass

    @property
    def class_to_idx(self) -> Mapping:
        """Mapping from class name to its index in ``CLASSES``."""
        return {_class: i for i, _class in enumerate(self.CLASSES)}

    def prepare_data(self, idx: int) -> Dict:
        """Load sample ``idx`` and return ``{"img": ..., "gt_label": tensor}``."""
        # Copy so that a subclass hook cannot mutate the cached annotation.
        results = copy.deepcopy(self.data_infos[idx])
        imgs_data = self.load_image_from_file(results)
        data = {
            "img": self.pipeline(imgs_data),
            "gt_label": torch.tensor(self.data_infos[idx]['gt_label'])
        }
        return data

    def sample_shots_by_class_id(self, class_id: int,
                                 num_shots: int) -> List[int]:
        """Draw ``num_shots`` distinct sample indices of class ``class_id``.

        Uses the global ``np.random`` state, so callers can make sampling
        reproducible by seeding numpy.

        Raises:
            ValueError: if the class has fewer than ``num_shots`` samples.
        """
        all_shot_ids = self.data_infos_class_dict[class_id]
        if num_shots > len(all_shot_ids):
            raise ValueError(
                f'Class {class_id} has only {len(all_shot_ids)} samples, '
                f'cannot draw {num_shots} of them without replacement.')
        return np.random.choice(
            all_shot_ids, num_shots, replace=False).tolist()

    def __len__(self) -> int:
        return len(self.data_infos)

    def __getitem__(self, idx: int) -> Dict:
        return self.prepare_data(idx)

    @classmethod
    def get_classes(cls,
                    classes: Union[Sequence[str],
                                   str] = None) -> Sequence[str]:
        """Return class names from a file path or a list/tuple of names.

        Raises:
            ValueError: for any other type (including ``None``).
        """
        if isinstance(classes, str):
            class_names = tools.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')
        return class_names
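UniversalFewShotDataset 继承自 BaseFewShotDataset,主要作用是从标注文件(如 fsl_train.txt)中逐行读取类别名和图像文件名,生成每个样本的信息,并根据 subset 选择训练或测试时的数据预处理流程。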
代码如下:
from datasets.base import BaseFewShotDataset
from typing_extensions import Literal
from typing import List, Optional, Sequence, Union
from util import tools
import os
import os.path as osp
import numpy as np
import torchvision.transforms as transforms
class UniversalFewShotDataset(BaseFewShotDataset):
    """Few-shot dataset backed by a ``<class_name>/<file name>`` list file.

    Every non-empty line of ``ann_file`` names one image located at
    ``<data_prefix>/images/<class_name>/<file name>``. The ``subset`` selects
    the torchvision preprocessing: random augmentation for ``'train'``, a
    deterministic resize + center crop for ``'test'`` and ``'val'``.

    Args:
        img_size: output image side length.
        subset: ``'train'``, ``'test'`` or ``'val'`` (or a list whose first
            element selects the pipeline).
        *args, **kwargs: forwarded to ``BaseFewShotDataset`` (``data_prefix``,
            ``classes``, ``ann_file``).
        file_format: optional, stored as ``self.file_format``.

    Raises:
        ValueError: if ``subset`` is empty or contains an unknown name.
    """

    def __init__(self,
                 img_size,
                 subset: Literal['train', 'test', 'val'] = 'train',
                 *args,
                 file_format: Optional[str] = None,
                 **kwargs):
        if isinstance(subset, str):
            subset = [subset]
        if not subset:
            raise ValueError('subset must not be empty')
        for subset_ in subset:
            if subset_ not in ('train', 'test', 'val'):
                raise ValueError(f'Unsupported subset {subset_!r}')
        self.subset = subset
        # Previously read from an undefined name (NameError on every
        # construction); now an optional keyword argument.
        self.file_format = file_format
        # ImageNet normalization statistics
        norm_params = {'mean': [0.485, 0.456, 0.406],
                       'std': [0.229, 0.224, 0.225]}
        if subset[0] == 'train':
            pipeline = transforms.Compose([
                transforms.RandomResizedCrop(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.ToTensor(),
                transforms.Normalize(**norm_params)
            ])
        else:
            # 'test' and 'val' share the deterministic evaluation pipeline
            # (previously 'val' left `pipeline` unbound).
            pipeline = transforms.Compose([
                transforms.Resize(size=int(img_size * 1.15)),
                transforms.CenterCrop(size=img_size),
                transforms.ToTensor(),
                transforms.Normalize(**norm_params)
            ])
        # Pass pipeline positionally so extra positional args bind to
        # data_prefix, classes, ann_file instead of colliding with pipeline.
        super().__init__(pipeline, *args, **kwargs)

    def get_classes(
            self,
            classes: Optional[Union[Sequence[str], str]] = None) -> Sequence[str]:
        """Read class names from a file path; list/tuple handled by the base."""
        if isinstance(classes, str):
            return tools.list_from_file(classes)
        return super().get_classes(classes)

    def load_annotations(self) -> List:
        """Parse ``ann_file`` into the info dicts used by the base class."""
        data_infos = []
        # Computed once: the property rebuilds the dict on every access.
        class_to_idx = self.class_to_idx
        with open(self.ann_file, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                # Split only on the first "/" so file names may contain "/".
                class_name, filename = line.split('/', 1)
                gt_label = class_to_idx[class_name]
                data_infos.append({
                    'img_prefix':
                    osp.join(self.data_prefix, 'images', class_name),
                    'img_info': {
                        'filename': filename
                    },
                    'gt_label':
                    np.array(gt_label, dtype=np.int64)
                })
        return data_infos
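EpisodicDataset 将上面的数据集包装成按 episode 组织的数据集:每个 episode 随机选取 num_ways 个类别,每个类别抽取 num_shots 张图像作为 support 集、num_queries 张图像作为 query 集,数据集中的每个元素就是一个 episode。代码如下: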
import numpy as np
from torch import Tensor
from torch.utils.data import Dataset,DataLoader
from functools import partial
import os.path as osp
from typing import Mapping
from util import tools
import json
class EpisodicDataset:
    """Wrap an image dataset so that every item is one few-shot episode.

    All ``num_episodes`` episodes are generated once in ``__init__``. Each
    episode picks ``num_ways`` classes and, for every one of them, splits a
    single ``dataset.sample_shots_by_class_id`` draw into ``num_shots``
    support and ``num_queries`` query indices.

    Args:
        dataset: dataset exposing ``CLASSES``, ``__len__``, ``__getitem__``
            and ``sample_shots_by_class_id(class_id, num)``.
        num_episodes: number of episodes, i.e. ``len(self)``.
        num_ways: classes per episode (N in N-way K-shot).
        num_shots: support samples per class (K).
        num_queries: query samples per class.
        episodes_seed: seed given to ``tools.local_numpy_seed`` so the episode
            list is reproducible.

    Raises:
        ValueError: if ``num_ways`` exceeds the number of classes.
    """

    def __init__(self,
                 dataset: Dataset,
                 num_episodes: int,
                 num_ways: int,
                 num_shots: int,
                 num_queries: int,
                 episodes_seed: int):
        if num_ways > len(dataset.CLASSES):
            # Fail early instead of an IndexError halfway through generation.
            raise ValueError(
                f'num_ways ({num_ways}) exceeds the number of classes '
                f'({len(dataset.CLASSES)}).')
        self.dataset = dataset
        self.num_ways = num_ways
        self.num_shots = num_shots
        self.num_queries = num_queries
        self.num_episodes = num_episodes
        self._len = len(self.dataset)
        self.CLASSES = dataset.CLASSES
        self.episodes_seed = episodes_seed
        self.episode_idxes, self.episode_class_ids = \
            self.generate_episodic_idxes()

    def generate_episodic_idxes(self):
        """Generate support/query sample indices and class ids per episode.

        Returns:
            (episode_idxes, episode_class_ids): ``episode_idxes[e]`` is
            ``{'support': [...], 'query': [...]}`` and ``episode_class_ids[e]``
            lists the class ids sampled for episode ``e``.
        """
        episode_idxes, episode_class_ids = [], []
        class_ids = list(range(len(self.CLASSES)))
        # The fixed seed only makes the episodes reproducible; this context
        # can be dropped if that is not needed. NOTE(review): presumably
        # local_numpy_seed seeds np.random temporarily — confirm in util.tools.
        with tools.local_numpy_seed(self.episodes_seed):
            for _ in range(self.num_episodes):
                np.random.shuffle(class_ids)
                # sample classes (slicing copies, so later shuffles don't alter it)
                sampled_cls = class_ids[:self.num_ways]
                episode_class_ids.append(sampled_cls)
                episodic_support_idx = []
                episodic_query_idx = []
                # sample instances of each class; the first num_shots become
                # support, the remainder query, so the two never overlap
                for cls_id in sampled_cls:
                    shots = self.dataset.sample_shots_by_class_id(
                        cls_id, self.num_shots + self.num_queries)
                    episodic_support_idx.extend(shots[:self.num_shots])
                    episodic_query_idx.extend(shots[self.num_shots:])
                episode_idxes.append({
                    'support': episodic_support_idx,
                    'query': episodic_query_idx
                })
        return episode_idxes, episode_class_ids

    def __getitem__(self, idx: int):
        """Return the loaded support and query samples of episode ``idx``."""
        support_data = [self.dataset[i] for i in self.episode_idxes[idx]['support']]
        query_data = [self.dataset[i] for i in self.episode_idxes[idx]['query']]
        return {
            'support_data': support_data,
            'query_data': query_data
        }

    def __len__(self):
        return self.num_episodes

    def evaluate(self, *args, **kwargs):
        """Delegate evaluation to the wrapped dataset."""
        return self.dataset.evaluate(*args, **kwargs)

    def get_episode_class_ids(self, idx: int):
        """Return the class ids sampled for episode ``idx``."""
        return self.episode_class_ids[idx]
配置文件除了 JSON,也可以采用 YAML 等其他格式,这里以 JSON 为例(下例中训练时每个 episode 为 10-way 5-shot,每个类别另取 5 张图像作为 query):
{
"train":{
"num_episodes":2000,
"num_ways":10,
"num_shots":5,
"num_queries":5,
"episodes_seed":1001,
"per_gpu_batch_size":1,
"per_gpu_workers": 8,
"epoches": 160,
"dataset":{
"name": "vireo_172",
"img_size": 224,
"data_prefix":"/home/gaoxingyu/dataset/vireo-172/",
"classes":"/home/gaoxingyu/dataset/vireo-172/meta/fsl_train_class.txt",
"ann": "/home/gaoxingyu/dataset/vireo-172/meta/fsl_train.txt"
}
}
}
读取配置、构建数据集和 DataLoader 的代码如下(其中 logger、local_rank、collate、worker_init_fn 需在别处定义):
# Load the experiment configuration.
with open("config.json", 'r', encoding='utf-8') as f:
    # json.load reads the handle directly instead of rebinding `f` to a str.
    configs = json.load(f)
logger.info(f"Experiment Setting:{configs}")

# Create datasets
train_cfg = configs['train']
dataset_cfg = train_cfg['dataset']
## train_dataset: image-level dataset over the training classes
train_food_dataset = UniversalFewShotDataset(
    data_prefix=dataset_cfg['data_prefix'],
    subset="train",
    classes=dataset_cfg['classes'],
    img_size=dataset_cfg['img_size'],
    ann_file=dataset_cfg['ann'])
# Episode-level wrapper: each item is one num_ways-way num_shots-shot episode.
train_dataset = EpisodicDataset(dataset=train_food_dataset,
                                num_episodes=train_cfg['num_episodes'],
                                num_ways=train_cfg['num_ways'],
                                num_shots=train_cfg['num_shots'],
                                num_queries=train_cfg['num_queries'],
                                episodes_seed=train_cfg['episodes_seed'])

## train dataloader
# With shuffle=True, call train_sampler.set_epoch(epoch) at the start of every
# epoch; otherwise DistributedSampler yields the same order each epoch.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, rank=local_rank, shuffle=True)
train_samper = train_sampler  # backward-compatible alias for the old misspelled name
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=train_cfg['per_gpu_batch_size'],
    sampler=train_sampler,
    num_workers=train_cfg['per_gpu_workers'],
    collate_fn=partial(collate, samples_per_gpu=1),
    worker_init_fn=worker_init_fn,
    drop_last=True
)
for data in train_data_loader:
    print(data)