import torch
# Toy input: a batch of N random feature maps with 3 channels, each HW x HW.
N, HW = 2, 7
x = torch.rand((N, 3, HW, HW))
x
tensor([[[[ 0.6801, 0.5986, 0.8342, 0.2059, 0.6529, 0.4588, 0.6079],
[ 0.3499, 0.9997, 0.4779, 0.6123, 0.8895, 0.6702, 0.1118],
[ 0.9417, 0.5027, 0.0630, 0.1088, 0.1800, 0.6639, 0.9210],
[ 0.2522, 0.1933, 0.2214, 0.0902, 0.9833, 0.2243, 0.4692],
[ 0.8330, 0.2169, 0.3008, 0.9063, 0.3030, 0.4961, 0.2058],
[ 0.1108, 0.7139, 0.1230, 0.0768, 0.0268, 0.1893, 0.5520],
[ 0.3987, 0.1723, 0.0756, 0.9247, 0.2617, 0.6532, 0.1511]],
[[ 0.1923, 0.0138, 0.9362, 0.3879, 0.8578, 0.2559, 0.3271],
[ 0.1985, 0.3664, 0.3374, 0.6199, 0.7864, 0.3920, 0.7427],
[ 0.8046, 0.9312, 0.7240, 0.3423, 0.4711, 0.4097, 0.7654],
[ 0.8009, 0.3712, 0.6248, 0.7377, 0.0233, 0.9360, 0.6116],
[ 0.4695, 0.6464, 0.0208, 0.2115, 0.8007, 0.7577, 0.9820],
[ 0.9249, 0.9200, 0.5269, 0.3906, 0.5382, 0.8067, 0.2442],
[ 0.6772, 0.6780, 0.3255, 0.9823, 0.6394, 0.4344, 0.0880]],
[[ 0.3615, 0.4707, 0.1852, 0.0465, 0.8819, 0.9937, 0.7102],
[ 0.0930, 0.9879, 0.3972, 0.6458, 0.3975, 0.1440, 0.4829],
[ 0.9814, 0.4748, 0.7973, 0.7196, 0.6132, 0.2092, 0.0649],
[ 0.8326, 0.6559, 0.2625, 0.3210, 0.0434, 0.4638, 0.6590],
[ 0.5413, 0.9833, 0.1283, 0.1576, 0.2311, 0.6617, 0.3430],
[ 0.7199, 0.9552, 0.3986, 0.9472, 0.5030, 0.6494, 0.8596],
[ 0.3445, 0.2613, 0.8283, 0.1728, 0.3771, 0.5291, 0.4734]]],
[[[ 0.8500, 0.6728, 0.6622, 0.2421, 0.7653, 0.0709, 0.4887],
[ 0.5043, 0.7861, 0.6012, 0.6661, 0.9236, 0.6521, 0.3341],
[ 0.5777, 1.0000, 0.5524, 0.2666, 0.2591, 0.1563, 0.1013],
[ 0.7084, 0.6471, 0.3055, 0.0547, 0.4499, 0.5782, 0.6310],
[ 0.4918, 0.8953, 0.1984, 0.3935, 0.4994, 0.7429, 0.4769],
[ 0.7932, 0.4119, 0.0737, 0.1912, 0.9368, 0.1328, 0.9625],
[ 0.1384, 0.9517, 0.1403, 0.2175, 0.0351, 0.4578, 0.0993]],
[[ 0.0378, 0.3744, 0.7194, 0.3785, 0.8246, 0.3476, 0.0014],
[ 0.4368, 0.5562, 0.9908, 0.4234, 0.9918, 0.4406, 0.1613],
[ 0.8269, 0.7115, 0.2828, 0.8004, 0.3766, 0.5500, 0.9736],
[ 0.2267, 0.8925, 0.5534, 0.7284, 0.3275, 0.4464, 0.7773],
[ 0.2549, 0.2889, 0.5091, 0.9417, 0.6562, 0.8813, 0.7422],
[ 0.6295, 0.9268, 0.1839, 0.8589, 0.6796, 0.3920, 0.2366],
[ 0.8217, 0.6012, 0.3639, 0.3125, 0.8596, 0.0460, 0.5015]],
[[ 0.6835, 0.0204, 0.4621, 0.9034, 0.9936, 0.2392, 0.1581],
[ 0.0362, 0.2661, 0.0505, 0.7764, 0.4404, 0.1929, 0.2910],
[ 0.8297, 0.8204, 0.8631, 0.9912, 0.2494, 0.7778, 0.0271],
[ 0.7450, 0.2234, 0.3558, 0.8840, 0.0821, 0.1914, 0.6607],
[ 0.1506, 0.7821, 0.4108, 0.4858, 0.0947, 0.2576, 0.6863],
[ 0.9268, 0.9442, 0.8276, 0.7365, 0.8599, 0.9713, 0.7455],
[ 0.4314, 0.4986, 0.2590, 0.4149, 0.9218, 0.0604, 0.6914]]]])
对每个样本沿通道维（dim=1）求均值：
x.mean(dim=1)  # per-sample average over the channel axis (dim=1) -> (N, HW, HW)
tensor([[[ 0.4113, 0.3610, 0.6519, 0.2134, 0.7975, 0.5695, 0.5484],
[ 0.2138, 0.7846, 0.4042, 0.6260, 0.6911, 0.4021, 0.4458],
[ 0.9092, 0.6362, 0.5281, 0.3903, 0.4214, 0.4276, 0.5838],
[ 0.6285, 0.4068, 0.3696, 0.3830, 0.3500, 0.5414, 0.5799],
[ 0.6146, 0.6155, 0.1499, 0.4251, 0.4449, 0.6385, 0.5102],
[ 0.5852, 0.8630, 0.3495, 0.4715, 0.3560, 0.5485, 0.5519],
[ 0.4735, 0.3705, 0.4098, 0.6933, 0.4261, 0.5389, 0.2375]],
[[ 0.5238, 0.3559, 0.6146, 0.5080, 0.8611, 0.2192, 0.2161],
[ 0.3258, 0.5361, 0.5475, 0.6220, 0.7853, 0.4285, 0.2621],
[ 0.7447, 0.8439, 0.5661, 0.6861, 0.2951, 0.4947, 0.3673],
[ 0.5600, 0.5876, 0.4049, 0.5557, 0.2865, 0.4053, 0.6897],
[ 0.2991, 0.6554, 0.3728, 0.6070, 0.4168, 0.6273, 0.6352],
[ 0.7832, 0.7610, 0.3618, 0.5956, 0.8255, 0.4987, 0.6482],
[ 0.4638, 0.6838, 0.2544, 0.3150, 0.6055, 0.1881, 0.4307]]])
# Locate the maximum of the channel-averaged map: flatten each sample's
# HW x HW map row-major and take the argmax, giving one flat index per sample.
points = x.mean(dim=1).view(N, HW * HW).argmax(dim=1)
points
# tensor([ 14, 4])
将最大值位置转成坐标:
# Convert each flat argmax index (row-major over an HW x HW map, from the
# .view(N, HW*HW) flattening) into a 2-D position.
# x_p is the row index (vertical axis); y_p is the column index (horizontal axis).
# Plain `/` on integer tensors performs true division in current PyTorch and
# would yield fractional rows (e.g. 4 / 7 -> 0.5714), so floor-divide explicitly.
x_p = torch.div(points, HW, rounding_mode="floor")
print(x_p)
y_p = torch.fmod(points, HW)
print(y_p)
# x坐标
tensor([ 2, 0]) # 注意坐标形式与位置的对应关系
# y坐标
tensor([ 0, 4])
# Joint coordinates, one (x, y) pair per sample, in the order F.grid_sample
# expects: the column index y_p is the horizontal (x) axis and the row index
# x_p is the vertical (y) axis. Stacking along dim=1 yields shape (N, 2) for
# any batch size instead of assuming exactly two samples.
z_p = torch.stack((y_p, x_p), dim=1).float()
z_p
tensor([[ 0., 2.],
[ 4., 0.]])
# Rescale pixel coordinates from [0, HW-1] to grid_sample's normalized [-1, 1]
# range. This is the align_corners=True mapping (pixel 0 -> -1,
# pixel HW-1 -> +1), so the sampling call must use align_corners=True too.
z_p = ((z_p + 1) - (HW + 1) / 2) / ((HW - 1) / 2)
# Insert two singleton dims: (N, 2) -> (N, 1, 1, 2), i.e. a 1 x 1 grid per sample.
grid = z_p[:, None, None]
grid
tensor([[[[-1.0000, -0.3333]]],
[[[ 0.3333, -1.0000]]]])
# Build a generic BOX x BOX crop window (3 x 3 here) centred at the origin,
# expressed as offsets in grid_sample's normalized coordinates.
# One pixel in normalized units under the align_corners=True mapping.
step = 2 / (HW - 1)
BOX_LEFT = 1  # window radius in pixels on each side of the centre
BOX = BOX_LEFT * 2 + 1  # window side length
# direct[i, j, 0] is the horizontal offset of column j; shape (BOX, BOX, 1).
direct = torch.linspace(-BOX_LEFT * step, BOX_LEFT * step, BOX).repeat(BOX, 1)[
    ..., None
]
# direct_trans[i, j, 0] is the vertical offset of row i.
direct_trans = direct.transpose(0, 1)
# Pair (x offset, y offset) per cell, then copy the window once per sample:
# shape (N, BOX, BOX, 2).
full = torch.cat((direct, direct_trans), dim=-1)[None].repeat(N, 1, 1, 1)
full
tensor([[[[-0.3333, -0.3333],
[ 0.0000, -0.3333],
[ 0.3333, -0.3333]],
[[-0.3333, 0.0000],
[ 0.0000, 0.0000],
[ 0.3333, 0.0000]],
[[-0.3333, 0.3333],
[ 0.0000, 0.3333],
[ 0.3333, 0.3333]]],
[[[-0.3333, -0.3333],
[ 0.0000, -0.3333],
[ 0.3333, -0.3333]],
[[-0.3333, 0.0000],
[ 0.0000, 0.0000],
[ 0.3333, 0.0000]],
[[-0.3333, 0.3333],
[ 0.0000, 0.3333],
[ 0.3333, 0.3333]]]])
# Move the generic window onto each sample's maximum: grid (N, 1, 1, 2)
# broadcasts over the BOX x BOX cells of full (N, BOX, BOX, 2). grid_sample
# expects the flow field in [-1, 1], so cells falling outside are clamped onto
# the border, which repeats edge pixels in the resulting crop.
full[...] = torch.clamp(full + grid, -1, 1)
full
tensor([[[[-1.0000, -0.6667],
[-1.0000, -0.6667],
[-0.6667, -0.6667]],
[[-1.0000, -0.3333],
[-1.0000, -0.3333], # 最大值坐标点
[-0.6667, -0.3333]],
[[-1.0000, 0.0000],
[-1.0000, 0.0000],
[-0.6667, 0.0000]]],
[[[ 0.0000, -1.0000],
[ 0.3333, -1.0000],
[ 0.6667, -1.0000]],
[[ 0.0000, -1.0000],
[ 0.3333, -1.0000], # 最大值坐标
[ 0.6667, -1.0000]],
[[ 0.0000, -0.6667],
[ 0.3333, -0.6667],
[ 0.6667, -0.6667]]]])
# Crop the feature map: sample x at the BOX x BOX positions stored in full,
# giving an (N, C, BOX, BOX) patch per sample.
# The coordinates were normalized with the align_corners=True convention
# (pixel 0 -> -1, pixel HW-1 -> +1). Since PyTorch 1.3, grid_sample defaults to
# align_corners=False, which would interpolate between pixels instead of
# returning them, so pass the flag explicitly to keep the two consistent.
torch.nn.functional.grid_sample(x, full, align_corners=True)
tensor([[[[ 0.3499, 0.3499, 0.9997],
[ 0.9417, 0.9417, 0.5027],
[ 0.2522, 0.2522, 0.1933]],
[[ 0.1985, 0.1985, 0.3664],
[ 0.8046, 0.8046, 0.9312],
[ 0.8009, 0.8009, 0.3712]],
[[ 0.0930, 0.0930, 0.9879],
[ 0.9814, 0.9814, 0.4748],
[ 0.8326, 0.8326, 0.6559]]],
[[[ 0.2421, 0.7653, 0.0709],
[ 0.2421, 0.7653, 0.0709],
[ 0.6661, 0.9236, 0.6521]],
[[ 0.3785, 0.8246, 0.3476],
[ 0.3785, 0.8246, 0.3476],
[ 0.4234, 0.9918, 0.4406]],
[[ 0.9034, 0.9936, 0.2392],
[ 0.9034, 0.9936, 0.2392],
[ 0.7764, 0.4404, 0.1929]]]])