Pytorch -- sensitivity 计算

Pytorch -- sensitivity 敏感度计算

1. sensitivity是一种局部性的指标,表达  正确识别正类个数 / 正类总个数
              - Sensitivity/TPR = TP / (TP + FN)  
2. specificity同理,不同之处为,正确识别负类个数 / 负类总个数
              - Specificity/TNR = TN / (TN + FP) 
  • 1、代码如下:
def sensitivity(output, target, sensi):
    '''
        这里类别数为3
        
        传入参数:
        sensi = np.array([-1] * 3) (首次,后面变为sensitivity的值)
        output --> tensor(80,3) 从outputs, _ = net(inputs)中获取
        target --> tensor(80)
        
        返回值:
        sensitivity --> np.array
    '''
    # 取得到分类分数最大的值,返回第一维度是value,第二维度是index
    _, pred = output.max(1) 
    # 将 pred 展开成 one-hot编码形式
    pre_mask = torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.)
    # 将 target 也展开成 one-hot编码形式
    tar_mask = torch.zeros(output.size()).scatter_(1, target.data.cpu().view(-1, 1), 1.)
    # 计算 acc 的one-hot编码形式
    acc_mask = pre_mask * tar_mask
    # 计算 sensitivity
    sensitivity = acc_mask.sum(0) / tar_mask.sum(0)
    # 转换成numpy()
    sensitivity = sensitivity.numpy()

    if sensi[0] != -1 : #不是第一次计算sensivity, 计算求平均值
        sensitivity  = (sensitivity + sensi) / 2

    return sensitivity
  • 2、详细的具体解析
Batch_size = 80

print(output)
tensor([[-0.0082, -0.1216,  0.0823],
        [ 0.0433, -0.1183, -0.0050],
         ..................................... ,
        [ 0.0682, -0.1924,  0.0039]],device='cuda:0')
:softmax计算得到值,3分类,故有3个值

print(target)
tensor([1, 2, ... ,1], device='cuda:0')
:目标标签值

print(output.max(1))
torch.return_types.max(
        values=tensor([ 0.0823,  0.0433, ...  0.0682],
            device='cuda:0'),
        indices=tensor([2, 0, ... 0], device='cuda:0'))max(1) --> values 对应ouput每一行中最大值,indices 下标 

_, pred  = output.max(1)
print(pred)
tensor([2, 0, ... 0], device='cuda:0')
:取得预测的下标值

print(ouput.size())
torch.Size([80, 3])
print(target.size())
torch.Size([80])
:类似numpy的shape

print(pred.eq(target))
tensor([0, 0, ... 0], device='cuda:0')
:值相同为1,不同为0

print(pred.eq(target).sum())
tensor(21, device='cuda:0')
:将所有的值相加

print(pred.eq(target).sum().item())
21
:取出tensor里面的值

print(pred.cpu())
tensor([2, 0, ... 0])
:少了" device='cuda:0' " 应该是转移到了cpu中

print(pred.cpu().view(-1,1))
tensor([[2],
        [0],
        ...
        [0]])
:由180列,变成801列,view(-1,1)表示张量维度,-1表缺省,但可推断值

print(pred_mask)
tensor([[0., 0., 1.],
        [1., 0., 0.],
        .......
        [1., 0., 0.]])
:转换成one-hot编码形式

print(pred_mask.sum(0))
tensor([32., 11., 37.])sum(0)0表示以行为基本单位,列项相加
  • 3、解析代码
    print(output)
    print(target)
    _, pred  = output.max(1)
    print(output.max(1))
    print(pred) 
    print(output.size())
    print(target.size())
    print(pred.eq(target))
    print(pred.eq(target).sum())
    print(pred.eq(target).sum().item())
    print(pred.cpu())
    print(pred.cpu().view(-1, 1))
    print(torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.))
    pred_mask = torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.)
    print(pred_mask.sum(0))
    tar_mask = torch.zeros(output.size()).scatter_(1, target.data.cpu().view(-1, 1), 1.)
    print(tar_mask)
    acc_mask = pred_mask * tar_mask
    print(acc_mask)
  • 4、scatter_()函数具体解析
    https://www.cnblogs.com/daremosiranaihana/p/12538512.html
    注:scatter() 与 scatter_() 的区别在于 后者直接修改源数据

你可能感兴趣的:(Pytorch学习,Pytorch,Sensitivity,敏感度计算,代码实现)