SpatialSoftmax implementation

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
import matplotlib.pyplot as plt


class SpatialSoftmax(torch.nn.Module):
    """Spatial soft-argmax: reduce each feature map to its expected (x, y) position.

    Every H*W feature map is turned into a probability distribution with a
    temperature-scaled softmax; the layer then returns the expectation of a
    fixed coordinate grid under that distribution. Coordinates are normalised
    to [-1, 1], with x running along the width axis and y along the height
    axis.

    Args:
        height: Height (H) of each feature map.
        width: Width (W) of each feature map.
        channel: Number of feature maps (C) per sample.
        temperature: Divisor applied to the features before the softmax.
            When ``None`` a learnable scalar ``Parameter`` initialised to 1
            is created; otherwise the given value is stored and used as-is.
        data_format: ``'NHWC'`` for channel-last input; any other value is
            treated as ``'NCHW'``.
        debug: When true, ``forward`` prints its input and intermediates.
    """

    def __init__(self, height, width, channel, temperature=None, data_format='NCHW', debug=False):
        super(SpatialSoftmax, self).__init__()
        self.height = height
        self.width = width
        self.channel = channel
        self.data_format = data_format
        self.debug = debug

        if temperature is None:
            self.temperature = Parameter(torch.ones(1))
        else:
            self.temperature = temperature

        # np.meshgrid defaults to 'xy' indexing, so passing (width, height)
        # yields arrays of shape (height, width): pos_x varies along the width
        # axis and pos_y along the height axis. Flattened row-major, they line
        # up with the (H, W) -> H*W flattening used in forward(), including
        # for non-square feature maps.
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1., 1., self.width),
            np.linspace(-1., 1., self.height),
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        # Buffers (not parameters): they follow the module across devices but
        # are never trained.
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)

    def forward(self, feature):
        """Compute the expected keypoint coordinates of each feature map.

        Args:
            feature: Tensor laid out according to ``data_format``; expected to
                be (N, C, H, W) or (N, H, W, C) with H, W, C matching the
                constructor arguments.

        Returns:
            Tensor of shape (N, C*2) ordered ``x_0, y_0, x_1, y_1, ...``.
        """
        if self.debug:
            print("input:\n{}".format(feature))
        if self.data_format == 'NHWC':
            # Bring channel-last input to NCHW before flattening.
            feature = feature.permute(0, 3, 1, 2)
        # Flatten to N*C rows of H*W values. reshape (rather than view) is
        # required because the permuted tensor is not contiguous.
        feature = feature.reshape(-1, self.height * self.width)

        softmax_attention = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * softmax_attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * softmax_attention, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        # (N*C, 2) -> (N, C*2): the (x, y) pairs of one sample sit side by side.
        feature_keypoints = expected_xy.view(-1, self.channel * 2)
        if self.debug:
            print("softmax_attention:\n{}".format(softmax_attention))
            print("self.pos_x:\n{}\nself.pos_y:\n{}".format(self.pos_x, self.pos_y))
            print("expected_x:\n{}\nexpected_y:\n{}".format(expected_x, expected_y))
            print("expected_xy:\n{}".format(expected_xy))
            print("feature_keypoints:\n{}".format(feature_keypoints))
        return feature_keypoints
    
    
if __name__ == '__main__':
    # Demo: build six synthetic 28x28 feature maps whose bright stripe moves
    # left to right, show them, then show the keypoints the layer extracts.
    #
    # Smaller debugging example kept for reference:
    #   data = torch.zeros([3,3,3,3])
    #   data[0,0,0,1] = 10
    #   data[0,1,1,1] = 10
    #   data[0,2,1,2] = 10
    #   layer = SpatialSoftmax(3, 3, 3, temperature=3, debug=True)
    #   layer(data)
    num_samples = 6
    stripe_width = 4
    feature_from_conv = torch.zeros(num_samples, 3, 28, 28)
    for idx in range(num_samples):
        # Sample idx gets a 4-pixel stripe on row 10, shifted right by 4*idx.
        start = idx * stripe_width
        feature_from_conv[idx, :, 10, start:start + stripe_width] = 1

    # Plot the first channel of every sample in a 2x3 grid.
    for idx, sample in enumerate(feature_from_conv):
        plt.subplot(2, 3, idx + 1)
        plt.tight_layout()
        plt.imshow(sample[0], interpolation='none')
        plt.title("Feature of: {}".format(idx))
        plt.xticks([])
        plt.yticks([])
    plt.show()

    # Run the layer and display the resulting (N, C*2) keypoint matrix.
    layer = SpatialSoftmax(28, 28, 3, debug=False)
    keypoints = layer(feature_from_conv)
    feature_points = keypoints.detach().numpy()
    plt.imshow(feature_points)
    plt.title("Feature Points")
    plt.xticks([])
    plt.yticks([])
    plt.show()

你可能感兴趣的:(Deep,Learning,python,深度学习,人工智能)