基于度量学习的ReID代码实现(1)

Sampler.py 设置每个batch是N=p*k的采样

from __future__ import absolute_import
from collections import defaultdict  # 用于append字典内容
import numpy as np
from .data_manager import Market1501
import torch
from torch.utils.data.sampler import Sampler


class RandomIdentitySampler(Sampler):
    """
    随机采样n个id,对于每一个id
    随机采样 k个实例,所以 batch size的大小为n*k
    Args:
        data_source (Dataset): dataset to sample from.
        num_instances (int): number of instances per identity.
    """

    # 自定义的采样类,需继承Sampler类,并且实现以下三个方法
    def __init__(self, data_source, num_instances):
        self.data_source = data_source
        self.num_instances = num_instances
        # 写一个dic,dic的key是id,value是各id对应的图片序号
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)

            self.pids = list(self.index_dic.key())
            self.num_identities = len(self.pids)
            from IPython import embed
            embed()

    # iter返回的是一个epoch的数据,是一个list,,
    def __iter__(self):
        # 因为后来dataloader需要shuffle,所以在写采样的
        # 前面应该先对数据做一次打乱
        indices = torch.randperm(self.num_identities)  # 这个indeces指的是打乱的id列表

        # 共有p个id,从每一id的所有图片中随机选出k个,组成n = p*k 的batchsize
        # 最终的一个所有n=p*k的图片列表放进ret
        ret = []

        for i in indices:
            pid = self.pids[i]
            t = self.index_dic[pid]  # t就是当前属于这个pid的所有图片序号,t是个列表

            if len(t) < self.num_instances:
                replace = True
            else:
                replace = False

            # 从t(该pid对应的所有图片)里随机挑选k(num_instances)张图片
            t = np.random.choice(t, size=self.num_instances, replace=replace)  # replace表示如果只有两张图片,而硬要挑4张的类似情况
            ret.extend(t)
        return ret

    def __len__(self):
        return self.num_instances * self.num_instances


if __name__ == '__main__':
    dataset = Market1501(root='../../data')
    sampler = RandomIdentitySampler(dataset.train,num_instances=4)
    a = sampler.__init__()

你可能感兴趣的:(ReID行人重识别)