在之前的文章中已经说明了卷积前后的坐标关系:
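已知卷积(或池化)后的坐标 $p$,求其在卷积前的感受野范围:$[\,p\cdot s-\mathrm{pad},\; p\cdot s-\mathrm{pad}+k-1\,]$,其中 $k$ 为卷积核尺寸,$s$ 为步长,$\mathrm{pad}$ 为填充;结果需截断到输入的有效坐标范围 $[0, N-1]$ 内。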
在语义分割模型中,我们希望对 SoftMax 输出做均值池化,按各窗口的平均概率从大到小排序,为每个类别提取前 k 个 patch(尺寸如 32×32)。
import torch
import torch.nn as nn
import torch.nn.functional as F
class genpatch(nn.Module):
    """Select the top-k highest-probability patches per class from segmentation logits.

    ``forward`` applies a softmax over the channel axis, then average-pools the
    probabilities with a ``kernel x kernel`` window (stride ``stride``, padding
    ``padding``). For each of the two classes it takes the ``topk`` pooled positions
    with the largest mean probability. Each position is mapped back to its receptive
    field in the input, and the raw logit patch is cut out of ``infeat``.

    Args:
        orisize: Spatial size (height == width) of the classifier feature map that is
            passed to ``forward``. It sizes ``grid_xy`` and the ``max`` attribute.
        device: Device on which the helper tensors are created (``None`` -> default).
        kernel: Pooling window / patch side length (default 32).
        stride: Pooling stride (default 2).
        padding: Pooling padding (default 0).
        topk: Number of patches selected per class and per image (default 2).
    """

    # Output list names in channel order: channel 0 -> 'unchg', channel 1 -> 'chg'.
    _CLASS_KEYS = ('unchg', 'chg')

    def __init__(self, orisize=256, device=None, kernel=32, stride=2, padding=0, topk=2):
        super(genpatch, self).__init__()
        # Kept for backward compatibility with code that reads these attributes.
        self.zero = torch.tensor(0).to(device)
        self.max = torch.tensor(orisize - 1).to(device)
        self.stride = stride
        self.padding = padding
        self.kernel = kernel
        self.topk = topk
        self.avgpool = nn.AvgPool2d(kernel_size=(kernel, kernel), stride=stride, padding=padding)
        # The grid must match the *pooled* map (whose flat indices topk returns),
        # not the input map: pooled = floor((orisize + 2*pad - kernel) / stride) + 1.
        pooled = max((orisize + 2 * padding - kernel) // stride + 1, 0)
        grid_x = torch.arange(pooled, device=device).repeat(pooled)             # column = idx % pooled
        grid_y = torch.arange(pooled, device=device).repeat_interleave(pooled)  # row = idx // pooled
        # Shape (2, 1, pooled*pooled); [0] holds x (column), [1] holds y (row).
        self.grid_xy = torch.stack([grid_x, grid_y], dim=0).unsqueeze(1).float()

    def _receptive_range(self, p, size):
        """Return the inclusive input range covered by pooled index ``p``, clamped to [0, size - 1]."""
        start = p * self.stride - self.padding
        end = start + self.kernel - 1
        return max(0, start), min(size - 1, end)

    def forward(self, infeat, device=None):
        """Return the top-k receptive-field patches of ``infeat`` for each class.

        Args:
            infeat: Classifier logits indexed as (batch, channel, row, column). Channel 0
                is reported under 'unchg' and channel 1 under 'chg'.
            device: Device for the returned ``pointXY`` tensors (``None`` -> default).

        Returns:
            ``{'unchg': [...], 'chg': [...]}``. Each list entry is a dict with:
            ``'patch'``: shape (1, C, kernel, kernel), cut from the raw ``infeat``;
            ``'provalue'``: mean softmax probability of the window, shape (1,);
            ``'pointXY'``: int32 ``[[x0, x1], [y0, y1]]``, inclusive column/row bounds;
            ``'c'``: the batch index (despite its name; the class is given by the list key);
            ``'i'``: rank within the top-k, where 0 is the best.
            Windows clipped by the border (only possible when padding > 0) are skipped.
        """
        output = F.softmax(infeat, dim=1)
        pooled = self.avgpool(output)
        height, width = infeat.shape[-2], infeat.shape[-1]
        pooled_w = pooled.shape[-1]
        # torch.topk raises if k exceeds the number of candidate windows.
        k = min(self.topk, pooled.shape[-2] * pooled_w)
        PatchDict = {key: [] for key in self._CLASS_KEYS}
        for c, key in enumerate(self._CLASS_KEYS):
            # (B, 1, P*P): flattened mean probability of class c for every window.
            scores = torch.flatten(pooled[:, c:c + 1, :, :], start_dim=2, end_dim=3)
            values, indices = torch.topk(scores, k=k, dim=2, largest=True, sorted=True)
            for i in range(infeat.shape[0]):
                for j in range(k):
                    # Flattening is row-major: flat index = row * pooled_w + col.
                    py, px = divmod(int(indices[i, 0, j]), pooled_w)
                    x0, x1 = self._receptive_range(px, width)
                    y0, y1 = self._receptive_range(py, height)
                    oxy = torch.tensor([[x0, x1], [y0, y1]], dtype=torch.int32, device=device)
                    # Rows are dim 2 and columns are dim 3, so slice y first, then x.
                    patchOri = infeat[i, :, y0:y1 + 1, x0:x1 + 1].unsqueeze(0)
                    # Compare the spatial dims (not the channel dim) with the kernel size.
                    if patchOri.shape[-2] != self.kernel or patchOri.shape[-1] != self.kernel:
                        continue
                    PatchDict[key].append({'patch': patchOri, 'provalue': values[i, :, j],
                                           'pointXY': oxy, 'c': i, 'i': j})
        return PatchDict