SwinTransformer中SW-MSA中attn_mask生成逻辑纪录

input_resolution = (12, 12)
window_size = 6
shift_size = 3

生成部分的源码如下:

		if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            # 对 [H, W] 大小进行分区
            # 分区的目的在于,shift之后,进行window划分时,
            # 一个window内包含多个区域,可能彼此不相临,需要进行标号区分
            # 数字相同表示在shift之前区域相邻
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            
            """
            # 也可以按如下分区
            # 数字相同表示在shift之前区域相邻
            h_slices = (slice(0, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.shift_size),
                        slice(-self.shift_size, None))
            """
            
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

			# windows 划分
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # attn_mask 计算
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

window实际划分时不是按照中线四等份划分的
更新:SwinTransformer中下采样非四等份中线划分,window划分应该是按照中线四等份划分的。

你可能感兴趣的:(计算机幻觉,深度瞎搞,图像分类,SwinTransformer,pytorch,python,transformer)