PyTorch 实现对每个样本的 feature map 进行裁剪(使用 F.grid_sample)

# 生成batch N=2 ,宽高=HW=7的示例样本
import torch
# Example batch: N samples, 3 channels, square HW x HW feature maps,
# filled with uniform random values in [0, 1).
N, HW = 2, 7
x = torch.rand((N, 3, HW, HW))
x
tensor([[[[ 0.6801,  0.5986,  0.8342,  0.2059,  0.6529,  0.4588,  0.6079],
          [ 0.3499,  0.9997,  0.4779,  0.6123,  0.8895,  0.6702,  0.1118],
          [ 0.9417,  0.5027,  0.0630,  0.1088,  0.1800,  0.6639,  0.9210],
          [ 0.2522,  0.1933,  0.2214,  0.0902,  0.9833,  0.2243,  0.4692],
          [ 0.8330,  0.2169,  0.3008,  0.9063,  0.3030,  0.4961,  0.2058],
          [ 0.1108,  0.7139,  0.1230,  0.0768,  0.0268,  0.1893,  0.5520],
          [ 0.3987,  0.1723,  0.0756,  0.9247,  0.2617,  0.6532,  0.1511]],

         [[ 0.1923,  0.0138,  0.9362,  0.3879,  0.8578,  0.2559,  0.3271],
          [ 0.1985,  0.3664,  0.3374,  0.6199,  0.7864,  0.3920,  0.7427],
          [ 0.8046,  0.9312,  0.7240,  0.3423,  0.4711,  0.4097,  0.7654],
          [ 0.8009,  0.3712,  0.6248,  0.7377,  0.0233,  0.9360,  0.6116],
          [ 0.4695,  0.6464,  0.0208,  0.2115,  0.8007,  0.7577,  0.9820],
          [ 0.9249,  0.9200,  0.5269,  0.3906,  0.5382,  0.8067,  0.2442],
          [ 0.6772,  0.6780,  0.3255,  0.9823,  0.6394,  0.4344,  0.0880]],

         [[ 0.3615,  0.4707,  0.1852,  0.0465,  0.8819,  0.9937,  0.7102],
          [ 0.0930,  0.9879,  0.3972,  0.6458,  0.3975,  0.1440,  0.4829],
          [ 0.9814,  0.4748,  0.7973,  0.7196,  0.6132,  0.2092,  0.0649],
          [ 0.8326,  0.6559,  0.2625,  0.3210,  0.0434,  0.4638,  0.6590],
          [ 0.5413,  0.9833,  0.1283,  0.1576,  0.2311,  0.6617,  0.3430],
          [ 0.7199,  0.9552,  0.3986,  0.9472,  0.5030,  0.6494,  0.8596],
          [ 0.3445,  0.2613,  0.8283,  0.1728,  0.3771,  0.5291,  0.4734]]],


        [[[ 0.8500,  0.6728,  0.6622,  0.2421,  0.7653,  0.0709,  0.4887],
          [ 0.5043,  0.7861,  0.6012,  0.6661,  0.9236,  0.6521,  0.3341],
          [ 0.5777,  1.0000,  0.5524,  0.2666,  0.2591,  0.1563,  0.1013],
          [ 0.7084,  0.6471,  0.3055,  0.0547,  0.4499,  0.5782,  0.6310],
          [ 0.4918,  0.8953,  0.1984,  0.3935,  0.4994,  0.7429,  0.4769],
          [ 0.7932,  0.4119,  0.0737,  0.1912,  0.9368,  0.1328,  0.9625],
          [ 0.1384,  0.9517,  0.1403,  0.2175,  0.0351,  0.4578,  0.0993]],

         [[ 0.0378,  0.3744,  0.7194,  0.3785,  0.8246,  0.3476,  0.0014],
          [ 0.4368,  0.5562,  0.9908,  0.4234,  0.9918,  0.4406,  0.1613],
          [ 0.8269,  0.7115,  0.2828,  0.8004,  0.3766,  0.5500,  0.9736],
          [ 0.2267,  0.8925,  0.5534,  0.7284,  0.3275,  0.4464,  0.7773],
          [ 0.2549,  0.2889,  0.5091,  0.9417,  0.6562,  0.8813,  0.7422],
          [ 0.6295,  0.9268,  0.1839,  0.8589,  0.6796,  0.3920,  0.2366],
          [ 0.8217,  0.6012,  0.3639,  0.3125,  0.8596,  0.0460,  0.5015]],

         [[ 0.6835,  0.0204,  0.4621,  0.9034,  0.9936,  0.2392,  0.1581],
          [ 0.0362,  0.2661,  0.0505,  0.7764,  0.4404,  0.1929,  0.2910],
          [ 0.8297,  0.8204,  0.8631,  0.9912,  0.2494,  0.7778,  0.0271],
          [ 0.7450,  0.2234,  0.3558,  0.8840,  0.0821,  0.1914,  0.6607],
          [ 0.1506,  0.7821,  0.4108,  0.4858,  0.0947,  0.2576,  0.6863],
          [ 0.9268,  0.9442,  0.8276,  0.7365,  0.8599,  0.9713,  0.7455],
          [ 0.4314,  0.4986,  0.2590,  0.4149,  0.9218,  0.0604,  0.6914]]]])
对样本按通道求均值
x.mean(dim=1)  # average over the channel dimension (dim 1)

tensor([[[ 0.4113,  0.3610,  0.6519,  0.2134,  0.7975,  0.5695,  0.5484],
         [ 0.2138,  0.7846,  0.4042,  0.6260,  0.6911,  0.4021,  0.4458],
         [ 0.9092,  0.6362,  0.5281,  0.3903,  0.4214,  0.4276,  0.5838],
         [ 0.6285,  0.4068,  0.3696,  0.3830,  0.3500,  0.5414,  0.5799],
         [ 0.6146,  0.6155,  0.1499,  0.4251,  0.4449,  0.6385,  0.5102],
         [ 0.5852,  0.8630,  0.3495,  0.4715,  0.3560,  0.5485,  0.5519],
         [ 0.4735,  0.3705,  0.4098,  0.6933,  0.4261,  0.5389,  0.2375]],

        [[ 0.5238,  0.3559,  0.6146,  0.5080,  0.8611,  0.2192,  0.2161],
         [ 0.3258,  0.5361,  0.5475,  0.6220,  0.7853,  0.4285,  0.2621],
         [ 0.7447,  0.8439,  0.5661,  0.6861,  0.2951,  0.4947,  0.3673],
         [ 0.5600,  0.5876,  0.4049,  0.5557,  0.2865,  0.4053,  0.6897],
         [ 0.2991,  0.6554,  0.3728,  0.6070,  0.4168,  0.6273,  0.6352],
         [ 0.7832,  0.7610,  0.3618,  0.5956,  0.8255,  0.4987,  0.6482],
         [ 0.4638,  0.6838,  0.2544,  0.3150,  0.6055,  0.1881,  0.4307]]])
# Locate the maximum of the channel-averaged map for every sample.
# The map is flattened to (N, HW*HW), so each entry of `points` is a flat
# index into one sample's HW x HW grid.
channel_mean = x.mean(dim=1)
points = channel_mean.view(N, HW * HW).argmax(dim=1)
points

# tensor([ 14,   4])
将最大值位置转成坐标:
# Convert each flat argmax index into 2-D indices of the HW x HW map.
# Floor division is requested explicitly: on current PyTorch, `/` between
# integer tensors is true division and would produce fractional values
# (e.g. 4 / 7 -> 0.5714) instead of the integer row index.
# NOTE: x_p is the row index (vertical) and y_p is the column index
# (horizontal) of the maximum.
x_p = torch.div(points, HW, rounding_mode="floor")
print(x_p)
y_p = torch.fmod(points, HW)
print(y_p)

# x坐标
tensor([ 2,  0]) # 注意坐标形式与位置的对应关系
# y坐标
tensor([ 0,  4])

# Joint coordinates, one (horizontal, vertical) pair per sample.
# F.grid_sample expects the horizontal coordinate first, so y_p (the column
# index) is placed before x_p (the row index).
# torch.stack builds the (batch, 2) tensor without hard-coding the batch size.
z_p = torch.stack((y_p, x_p), dim=1).float()
z_p

tensor([[ 0.,  2.],
        [ 4.,  0.]])
# Rescale the pixel indices to the range [-1, 1]:
# index 0 maps to -1 and index HW-1 maps to +1.
centre = (HW + 1) / 2
half_extent = (HW - 1) / 2
z_p = ((z_p + 1) - centre) / half_extent
# Insert two singleton dims so the centre has shape (batch, 1, 1, 2).
grid = z_p[:, None, None, :]
grid

tensor([[[[-1.0000, -0.3333]]],
        [[[ 0.3333, -1.0000]]]])
# Build a generic crop window of relative offsets; here it is 3 x 3.
BOX_LEFT = 1                # number of pixels on each side of the centre
BOX = 2 * BOX_LEFT + 1      # side length of the window
step = 2 / (HW - 1)         # distance between neighbouring pixels in normalised units
reach = BOX_LEFT * step
offsets = torch.linspace(-reach, reach, BOX)
# direct has shape (BOX, BOX, 1); direct[i, j, 0] == offsets[j]
direct = offsets.view(1, BOX, 1).repeat(BOX, 1, 1)
# direct_trans[i, j, 0] == offsets[i]
direct_trans = direct.transpose(1, 0)
# full has shape (N, BOX, BOX, 2): the same offset window copied for every sample
full = torch.cat([direct, direct_trans], dim=2).unsqueeze(0).repeat(N, 1, 1, 1)
full


tensor([[[[-0.3333, -0.3333],
          [ 0.0000, -0.3333],
          [ 0.3333, -0.3333]],

         [[-0.3333,  0.0000],
          [ 0.0000,  0.0000],
          [ 0.3333,  0.0000]],

         [[-0.3333,  0.3333],
          [ 0.0000,  0.3333],
          [ 0.3333,  0.3333]]],


        [[[-0.3333, -0.3333],
          [ 0.0000, -0.3333],
          [ 0.3333, -0.3333]],

         [[-0.3333,  0.0000],
          [ 0.0000,  0.0000],
          [ 0.3333,  0.0000]],

         [[-0.3333,  0.3333],
          [ 0.0000,  0.3333],
          [ 0.3333,  0.3333]]]])
# Move the generic window onto each sample's maximum location (broadcast
# add of the per-sample centre), then clamp both coordinates to [-1, 1] so
# every sampling location stays inside the feature map. Windows that touch
# the border therefore repeat the edge row/column.
full.add_(grid).clamp_(min=-1, max=1)
full

tensor([[[[-1.0000, -0.6667],
          [-1.0000, -0.6667],
          [-0.6667, -0.6667]],

         [[-1.0000, -0.3333],
          [-1.0000, -0.3333],  # 最大值坐标点
          [-0.6667, -0.3333]],

         [[-1.0000,  0.0000],
          [-1.0000,  0.0000],
          [-0.6667,  0.0000]]],


        [[[ 0.0000, -1.0000],
          [ 0.3333, -1.0000],
          [ 0.6667, -1.0000]],

         [[ 0.0000, -1.0000],
          [ 0.3333, -1.0000], # 最大值坐标
          [ 0.6667, -1.0000]],

         [[ 0.0000, -0.6667],
          [ 0.3333, -0.6667],
          [ 0.6667, -0.6667]]]])
# Crop the feature map by sampling x at the window locations stored in `full`.
# The grid was normalised so that the centres of the first and last pixels
# map to -1 and +1, which is the align_corners=True convention. The default
# on current PyTorch is align_corners=False, which would sample between
# pixels and return interpolated values instead of the original entries.
torch.nn.functional.grid_sample(x, full, align_corners=True)

tensor([[[[ 0.3499,  0.3499,  0.9997],
          [ 0.9417,  0.9417,  0.5027],
          [ 0.2522,  0.2522,  0.1933]],

         [[ 0.1985,  0.1985,  0.3664],
          [ 0.8046,  0.8046,  0.9312],
          [ 0.8009,  0.8009,  0.3712]],

         [[ 0.0930,  0.0930,  0.9879],
          [ 0.9814,  0.9814,  0.4748],
          [ 0.8326,  0.8326,  0.6559]]],


        [[[ 0.2421,  0.7653,  0.0709],
          [ 0.2421,  0.7653,  0.0709],
          [ 0.6661,  0.9236,  0.6521]],

         [[ 0.3785,  0.8246,  0.3476],
          [ 0.3785,  0.8246,  0.3476],
          [ 0.4234,  0.9918,  0.4406]],

         [[ 0.9034,  0.9936,  0.2392],
          [ 0.9034,  0.9936,  0.2392],
          [ 0.7764,  0.4404,  0.1929]]]])

你可能感兴趣的:(DL,tools)