实验基于论文: Class-Balanced Loss Based on Effective Number of Samples
论文解读:https://blog.csdn.net/weixin_41735859/article/details/105637597
Class-balanced-loss代码地址:https://github.com/vandit15/Class-balanced-loss-pytorch
resnet18代码参考链接:https://blog.csdn.net/sunqiande88/article/details/80100891
论文中通过公式 $n = n_i \mu^i$($i$ 为类索引,$n_i$ 为第 $i$ 类的原始样本数,$\mu \in (0,1)$)制作长尾 CIFAR-10 数据集。以下代码以不均匀比例 100 为例,即 $\mu = (1/100)^{1/9}$。论文作者已制作好的数据集,我们也可以通过科学上网点击该谷歌云链接下载。
loadcifar.py
import torch
import torch.utils.data as Data
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
def unpickle(file):
    """Load a pickled CIFAR batch file and return the stored object.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. The file is decoded with ``encoding='bytes'``,
        so 8-bit strings written by Python 2 come back as ``bytes`` (callers
        in this file index the result with keys such as ``b'data'``).
    """
    import pickle

    # NOTE: pickle can execute arbitrary code while loading; only call this
    # on trusted dataset files.
    with open(file, 'rb') as fo:
        # Return directly instead of binding a local named ``dict``, which
        # shadowed the builtin.
        return pickle.load(fo, encoding='bytes')
# Read the CIFAR-10 batches from the source files.
# train=True : long-tailed training subset, data [N, 3072] and labels [N]
#              (the original note gives N = 12406 for imbalance ratio 100)
# train=False: the test batch, data [10000, 3072] and labels [10000]
def get_data(train=False, root='data/cifar-10-batches-py', imbalance_ratio=100):
    """Load CIFAR-10 arrays, sub-sampling the training set into a long tail.

    For the training split, class ``i`` keeps at most
    ``floor(5000 * mu ** i)`` samples, where
    ``mu = (1 / imbalance_ratio) ** (1 / 9)``. Within each class the earliest
    samples (in file order) are kept, and the overall sample order is
    preserved.

    Args:
        train: If true, load the five training batches and sub-sample them;
            otherwise load the test batch unchanged.
        root: Directory containing the ``data_batch_*`` and ``test_batch``
            files.
        imbalance_ratio: Ratio between the sample limit of class 0 and the
            limit of class 9. Must be positive; ``1`` applies the same limit
            of 5000 to every class.

    Returns:
        A ``(data, labels)`` pair. For the training split both are NumPy
        arrays (``data`` has 3072 columns). For the test split ``data`` and
        ``labels`` are returned exactly as stored in the batch file.

    Raises:
        ValueError: If ``imbalance_ratio`` is not positive.
    """
    if imbalance_ratio <= 0:
        raise ValueError('imbalance_ratio must be positive')

    if not train:
        batch = unpickle(root + '/test_batch')
        return batch[b'data'], batch[b'labels']

    batches = [unpickle(root + '/data_batch_' + str(i)) for i in range(1, 6)]
    data = np.concatenate([batch[b'data'] for batch in batches])
    labels = np.concatenate([batch[b'labels'] for batch in batches])

    num_classes = 10
    max_per_class = 5000
    # n = n_i * mu^i: per-class decay factor of the long-tailed distribution.
    decay = (1 / imbalance_ratio) ** (1 / (num_classes - 1))

    # Collect the indices to keep and slice once at the end; growing the
    # output with np.concatenate per sample copies the whole array each time.
    kept = []
    for cls in np.arange(num_classes):
        limit = int(np.floor(max_per_class * decay ** cls))
        kept.append(np.flatnonzero(labels == cls)[:limit])
    # Sorting restores the original sample order across classes.
    kept = np.sort(np.concatenate(kept))

    return data[kept], labels[kept]
# Image preprocessing pipeline; Compose chains several transforms together.
# (Original author's note: for color images, the color channels are not
# stationary.)
transform = transforms.Compose([
    # ToTensor takes a PIL.Image (RGB) or a numpy.ndarray (H x W x C), maps
    # values from the 0..255 range to 0..1 and converts the result to a Tensor.
    transforms.ToTensor(),
    # With mean 0.5 and std 0.5 per channel, Normalize maps [0, 1] to [-1, 1].
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# Convert a label into a torch.LongTensor.
def target_transform(label):
    """Return ``label`` as a 64-bit integer torch tensor.

    Args:
        label: A value accepted by ``np.array`` (e.g. an int or a sequence
            of ints).

    Returns:
        A tensor of dtype ``torch.long`` built from a NumPy copy of ``label``.
    """
    return torch.from_numpy(np.array(label)).to(dtype=torch.long)
class Cifar10_Dataset(Data.Dataset):
    """Custom dataset for loading the CIFAR-10 arrays; inherits data.Dataset.

    The images come from ``get_data`` as flat rows and are stored as
    ``[N, height, width, channels]`` so each one can be turned into a PIL
    image for preprocessing.

    Args:
        train: If true, load the training split into ``train_data`` /
            ``train_labels``; otherwise load the test split into
            ``test_data`` / ``test_labels``.
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, train=True, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        if self.train:
            # Training split.
            images, self.train_labels = get_data(train)
            self.train_data = self._to_hwc(images)
        else:
            # Test split.
            images, self.test_labels = get_data()
            self.test_data = self._to_hwc(images)

    @staticmethod
    def _to_hwc(flat):
        """Reshape flat rows to [N, 3, 32, 32] and reorder to [N, 32, 32, 3]."""
        # -1 infers the sample count instead of hard-coding it per split.
        return flat.reshape((-1, 3, 32, 32)).transpose((0, 2, 3, 1))

    def __getitem__(self, index):
        """Return the preprocessed ``(image, target)`` pair at ``index``."""
        if self.train:
            img, label = self.train_data[index], self.train_labels[index]
        else:
            img, label = self.test_data[index], self.test_labels[index]
        img = Image.fromarray(img)
        # Image preprocessing.
        if self.transform is not None:
            img = self.transform(img)
        # Label preprocessing; fall back to the raw label so ``target`` is
        # always bound, even when no target_transform was supplied.
        target = label
        if self.target_transform is not None:
            target = self.target_transform(label)
        return img, target

    def __len__(self):
        """Return the number of samples in the active split."""
        return len(self.train_data if self.train else self.test_data)
if __name__ == '__main__':
    # Build the training split and report how many samples it holds.
    train_data = Cifar10_Dataset(
        train=True, transform=transform, target_transform=target_transform)
    print('size of train_data:{}'.format(len(train_data)))
    # Same for the test split.
    test_data = Cifar10_Dataset(
        train=False, transform=transform, target_transform=target_transform)
    print('size of test_data:{}'.format(len(test_data)))
第二步:定义损失函数
第三步:训练