class AgeDB(data.Dataset):
    """Age-regression dataset backed by a DataFrame of image paths and ages.

    Each row of ``df`` must provide a ``'path'`` column (relative to
    ``data_dir``) and an ``'age'`` column. Optionally computes one loss
    weight per sample from the age histogram (inverse or square-root-inverse
    frequency), optionally smoothing the histogram first with LDS (a 1-D
    convolution with a kernel from ``get_lds_kernel_window``).

    Args:
        df: DataFrame with ``'path'`` and ``'age'`` columns.
        data_dir: Directory the ``'path'`` entries are relative to.
        img_size: Side length images are resized (and, for training, cropped) to.
        split: ``'train'`` enables random crop / horizontal flip augmentation;
            any other value uses resize-only preprocessing.
        reweight: One of ``'none'``, ``'inverse'``, ``'sqrt_inv'``.
        lds: Whether to smooth the label histogram before inverting it.
            Requires ``reweight != 'none'``.
        lds_kernel, lds_ks, lds_sigma: Forwarded to ``get_lds_kernel_window``.
        max_target: Number of integer age bins used for the histogram; ages
            at or above ``max_target - 1`` share the last bin. Defaults to 121
            (the value previously hard-coded).
    """

    def __init__(self, df, data_dir, img_size, split='train', reweight='none',
                 lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2,
                 max_target=121):
        self.df = df
        self.data_dir = data_dir
        self.img_size = img_size
        self.split = split
        # One weight per row of df, or None when re-weighting is disabled.
        self.weights = self._prepare_weights(
            reweight=reweight, max_target=max_target, lds=lds,
            lds_kernel=lds_kernel, lds_ks=lds_ks, lds_sigma=lds_sigma)

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(img, label, weight)``: the RGB image after ``get_transform()``,
            the age as a float32 array of shape ``(1,)``, and the sample
            weight as a float32 array of shape ``(1,)`` (``1.0`` when no
            re-weighting is configured).
        """
        # Out-of-range indices wrap around instead of raising.
        index = index % len(self.df)
        row = self.df.iloc[index]
        img_path = os.path.join(self.data_dir, row['path'])
        # Close the file deterministically: convert() returns a new, fully
        # loaded image, so the source handle is no longer needed afterwards.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        img = self.get_transform()(img)
        label = np.asarray([row['age']], dtype=np.float32)
        if self.weights is not None:
            weight = np.asarray([self.weights[index]], dtype=np.float32)
        else:
            weight = np.ones(1, dtype=np.float32)
        return img, label, weight

    def get_transform(self):
        """Return the preprocessing pipeline for ``self.split``.

        Every split resizes to ``(img_size, img_size)``, converts to a tensor
        and normalizes each channel with mean 0.5 / std 0.5. The training
        split additionally applies a random crop (16px padding) and a random
        horizontal flip between the resize and the tensor conversion.
        """
        steps = [transforms.Resize((self.img_size, self.img_size))]
        if self.split == 'train':
            steps += [
                transforms.RandomCrop(self.img_size, padding=16),
                transforms.RandomHorizontalFlip(),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([.5, .5, .5], [.5, .5, .5]),
        ]
        return transforms.Compose(steps)

    def _prepare_weights(self, reweight, max_target=121, lds=False, lds_kernel='gaussian', lds_ks=5, lds_sigma=2):
        """Compute per-sample loss weights from the age histogram.

        Ages are binned with ``min(max_target - 1, int(age))``. The per-bin
        counts are square-rooted (``'sqrt_inv'``) or clipped to ``[5, 1000]``
        (``'inverse'``), optionally convolved with the LDS kernel, and each
        sample gets ``1 / value`` of its bin. The weights are then rescaled
        so that they sum to the number of samples.

        Returns:
            A list with one weight per row of ``self.df``, or ``None`` when
            ``reweight == 'none'`` or the DataFrame is empty.

        Raises:
            AssertionError: on an unknown ``reweight`` value, or when ``lds``
                is requested together with ``reweight == 'none'``.
        """
        assert reweight in {'none', 'inverse', 'sqrt_inv'}
        assert not lds or reweight != 'none', \
            "Set reweight to \'sqrt_inv\' (default) or \'inverse\' when using LDS"
        if reweight == 'none':
            return None

        labels = self.df['age'].values
        if not len(labels):
            return None

        def to_bin(label):
            # Ages beyond the last bin are folded into it.
            return min(max_target - 1, int(label))

        counts = {b: 0 for b in range(max_target)}
        for label in labels:
            counts[to_bin(label)] += 1
        if reweight == 'sqrt_inv':
            counts = {k: np.sqrt(v) for k, v in counts.items()}
        elif reweight == 'inverse':
            # Clip counts so very rare / very frequent bins don't get extreme weights.
            counts = {k: np.clip(v, 5, 1000) for k, v in counts.items()}

        print(f"Using re-weighting: [{reweight.upper()}]")
        if lds:
            lds_kernel_window = get_lds_kernel_window(lds_kernel, lds_ks, lds_sigma)
            print(f'Using LDS: [{lds_kernel.upper()}] ({lds_ks}/{lds_sigma})')
            # Smooth the whole histogram (ordered by bin index) before inverting it.
            bin_values = convolve1d(
                np.asarray(list(counts.values())), weights=lds_kernel_window, mode='constant')
        else:
            bin_values = counts

        num_per_label = [bin_values[to_bin(label)] for label in labels]
        weights = [np.float32(1 / x) for x in num_per_label]
        # Rescale so the mean weight is 1 (sum equals the number of samples).
        scaling = len(weights) / np.sum(weights)
        return [scaling * x for x in weights]
这段代码定义了一个名为 `AgeDB` 的数据集类，该类继承自 `data.Dataset`。这个类为 AgeDB 数据集提供统一的结构，用于方便地加载、预处理图像，并返回图像数据及其对应的标签。
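以下是代码的详细解释：
- `__init__`：实例化 `AgeDB` 对象时调用此函数。它保存传入的 DataFrame、数据目录、图像尺寸和数据划分，并调用 `_prepare_weights` 函数为每个图像样本准备权重。
- `__len__` 函数：返回数据集中的样本数量。
- `__getitem__` 函数：按索引获取数据集中的一个样本，返回（图像, 年龄标签, 权重）。
- `get_transform` 函数：根据数据集的划分（训练或其他）返回图像的变换序列；训练集额外使用随机裁剪和随机水平翻转进行数据增强。
- `_prepare_weights` 函数：根据年龄标签的分布为数据集中的每个样本计算权重。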
这个 `AgeDB` 数据集类是为年龄估计任务设计的。它加载图像，并为每个图像提供对应的年龄标签。此外，该类还支持为每个样本提供权重，这在训练标签分布不均衡的数据集时非常有用：权重与样本所在年龄段的（平滑后）样本数成反比，因此样本较少的年龄段会获得更大的权重。这样在训练过程中，不同的样本具有不同的重要性，从而改善模型的性能。