import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
import matplotlib.pyplot as plt
class SpatialSoftmax(torch.nn.Module):
    """Spatial soft-argmax: reduce each feature map to one expected (x, y) point.

    For every channel of every sample, a softmax is taken over the
    ``height * width`` activations and used as weights to average a fixed
    coordinate grid spanning ``[-1, 1]`` on both axes. ``x`` runs along the
    width (column) axis and ``y`` along the height (row) axis.

    Args:
        height: Number of rows (H) of the incoming feature maps.
        width: Number of columns (W) of the incoming feature maps.
        channel: Number of channels (C) of the incoming feature maps.
        temperature: Divisor applied to the activations before the softmax.
            When ``None`` a learnable ``Parameter`` initialised to 1 is
            created; otherwise the given value is stored and used as-is.
        data_format: ``'NHWC'`` for channels-last input; any other value is
            treated as ``'NCHW'``.
        debug: When true, ``forward`` prints its input and intermediates.
    """

    def __init__(self, height, width, channel, temperature=None, data_format='NCHW', debug=False):
        super(SpatialSoftmax, self).__init__()
        self.height = height
        self.width = width
        self.channel = channel
        self.data_format = data_format
        self.debug = debug

        if temperature is None:
            self.temperature = Parameter(torch.ones(1))
        else:
            self.temperature = temperature

        # np.meshgrid defaults to 'xy' indexing: the FIRST vector runs along
        # the columns and the result has shape (len(second), len(first)).
        # Passing the width axis first therefore yields (height, width) grids
        # whose row-major flattening lines up with the flattened feature maps,
        # also when height != width.
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1., 1., self.width),
            np.linspace(-1., 1., self.height)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        # Buffers (not parameters): they follow .to()/.cuda() and are saved in
        # the state dict, but are never trained.
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)

    def forward(self, feature):
        """Return the expected coordinates of every channel.

        Args:
            feature: Tensor laid out as (N, C, H, W), or (N, H, W, C) when
                ``data_format == 'NHWC'``.

        Returns:
            Tensor of shape (N, C * 2) ordered ``x_0, y_0, x_1, y_1, ...``.
        """
        if self.debug:
            print("input:\n{}".format(feature))

        if self.data_format == 'NHWC':
            # Bring channels-last input to (N, C, H, W) first.
            feature = feature.permute(0, 3, 1, 2)
        # Flatten to N*C "images" of H*W values. reshape (rather than view)
        # also accepts the non-contiguous tensor produced by permute.
        feature = feature.reshape(-1, self.height * self.width)

        softmax_attention = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * softmax_attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * softmax_attention, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        # (N*C, 2) -> (N, C*2), interleaving x and y per channel.
        feature_keypoints = expected_xy.view(-1, self.channel * 2)

        if self.debug:
            print("softmax_attention:\n{}".format(softmax_attention))
            print("self.pos_x:\n{}\nself.pos_y:\n{}".format(self.pos_x, self.pos_y))
            print("expected_x:\n{}\nexpected_y:\n{}".format(expected_x, expected_y))
            print("expected_xy:\n{}".format(expected_xy))
            print("feature_keypoints:\n{}".format(feature_keypoints))

        return feature_keypoints
if __name__ == '__main__':
    # Visual sanity check: six samples, each with a short horizontal bar on
    # row 10 that moves four columns to the right from one sample to the next.
    #
    # Tiny hand-checkable alternative, kept for debugging:
    #   data = torch.zeros([3, 3, 3, 3])
    #   data[0, 0, 0, 1] = 10
    #   data[0, 1, 1, 1] = 10
    #   data[0, 2, 1, 2] = 10
    #   layer = SpatialSoftmax(3, 3, 3, temperature=3, debug=True)
    #   layer(data)
    num_samples, num_channels, side = 6, 3, 28
    bar_row, bar_len = 10, 4

    feature_from_conv = torch.zeros(num_samples, num_channels, side, side)
    for n in range(num_samples):
        first_col = n * bar_len
        feature_from_conv[n, :, bar_row, first_col:first_col + bar_len] = 1

    # Show the first channel of every sample in a 2x3 grid.
    for n, sample in enumerate(feature_from_conv):
        plt.subplot(2, 3, n + 1)
        plt.tight_layout()
        plt.imshow(sample[0], interpolation='none')
        plt.title("Feature of: {}".format(n))
        plt.xticks([])
        plt.yticks([])
    plt.show()

    # Run the layer and display the resulting (N, C*2) keypoint matrix.
    layer = SpatialSoftmax(side, side, num_channels, debug=False)
    feature_points = layer(feature_from_conv).detach().numpy()
    plt.imshow(feature_points)
    plt.title("Feature Points")
    plt.xticks([])
    plt.yticks([])
    plt.show()