【无标题】python迭代器构造

最近在学习transreid 关于viT-pytorch中的这段代码并不是很理解,因此写这个博客进行总结

# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


# 迭代器构造
# isinstance(变量名,变量的类型)
# 用于判断一个变量是不是属于输入的变量类型
# repeat(element,n)将一个元素重复n遍,并返回一个迭代器
#  应该是迭代两次,用来下面的生成的图片的X,y,以及patch的x,y

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
to_2tuple = _ntuple(2)

在应用的时候,主要用在输入的尺寸上:

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding   图片切块分为patch
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

输入的尺寸为224,patch为16,要生成X,Y两个方向,所以n输入为2,即to_2tuple = _ntuple(2),输入的尺寸为int,而我们想要的是两个数据,即元组,元组第一个即img_size[0]为X或者Y, if isinstance(x, container_abcs.Iterable):用于判断一个变量是不是属于输入的变量类型

以及在以下代码中实现

class PatchEmbed_overlap(nn.Module):
    """ Image to Patch Embedding with overlapping patches
    """

    def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride_size_tuple = to_2tuple(stride_size)
        self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1  # python中“//”是一个算术运算符,表示整数除法,
        # 它可以返回商的整数部分(向下取整)   (224-16)//20+1=10+1=11
        self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1
        print('using stride: {}, and patch number is num_y{} * num_x{}'.format(stride_size, self.num_y, self.num_x))
        num_patches = self.num_x * self.num_y  # 总的patch数
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

你可能感兴趣的:(机器学习,python,pytorch,开发语言)