pytorch使用argmax argsoftmax

argmax的目的在于求取最大值所在的索引,这个函数是不可导
采用soft argmax替代,可以求导,其公式如下所示:
在这里插入图片描述
假设输入的Tensor如下图所示:

tensor([0.1000, 0.3000, 0.6000, 2.1000, 0.5500], dtype=torch.float64)

我们将其经过上述公式可得:

np.sum(np.exp(data)/np.sum(np.exp(data)) * np.array([0,1,2,3,4]))    # E = p*index
'''output:
2.5694236670240085
'''

而最大之所在的位置应该是3。
从上面看到位置计算不够准确,一个原因就是最大值的概率不够大,或者说增大相对最大值而减弱其他值的影响就可以得到更加准确的位置坐标。
在这里插入图片描述
可以看到,上式与softmax的期望只有一个差别,即给向量的每个元素乘以beta。
pytorch使用argmax argsoftmax_第1张图片
输出的坐标为2.99,即为3,且这种寻找极值所在位置(坐标)的方法是可微的。常用于图像特征点位置的提取。

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np


class SpatialSoftmax(torch.nn.Module):
    def __init__(self, height, width, channel, device,temperature=None, data_format='NCHW'):
        super(SpatialSoftmax, self).__init__()
        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel
        self.device=device
        if temperature:
            self.temperature = Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.

        pos_x, pos_y = np.meshgrid(
            np.linspace(-1., 1., self.height),
            np.linspace(-1., 1., self.width)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)

    def forward(self, feature):
        # Output:
        #   (N, C*2) x_0 y_0 ...
        if self.data_format == 'NHWC':
            feature = feature.transpose(1, 3).tranpose(2, 3).view(-1, self.height * self.width)
        else:
            feature = feature.view(-1, self.height * self.width)

        softmax_attention = F.softmax(feature, dim=-1)
        self.pos_x=self.pos_x.to(self.device)
        self.pos_y= self.pos_y.to(self.device)
        softmax_attention=softmax_attention.to(self.device)
        expected_x = torch.sum(self.pos_x* softmax_attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * softmax_attention, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel * 2)

        return feature_keypoints


if __name__ == '__main__':
    data = torch.zeros([1, 3, 3, 3])
    data[0, 0, 0, 1] = 10
    data[0, 1, 1, 1] = 10
    data[0, 2, 1, 2] = 10
    layer = SpatialSoftmax(3, 3, 3, temperature=1)
    print(layer(data))

你可能感兴趣的:(argmax)