PyTorch -- computing sensitivity
1. Sensitivity is a per-class metric: the number of correctly identified positives divided by the total number of positives.
- Sensitivity/TPR = TP / (TP + FN)
2. Specificity is the analogous metric for negatives: the number of correctly identified negatives divided by the total number of negatives.
- Specificity/TNR = TN / (TN + FP)
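As a quick sanity check, the two formulas can be verified on a small binary confusion matrix (the counts below are made up for illustration):

```python
# Hypothetical binary confusion matrix: TP=40, FN=10, TN=35, FP=15
TP, FN, TN, FP = 40, 10, 35, 15

sensitivity = TP / (TP + FN)  # correctly identified positives / all positives
specificity = TN / (TN + FP)  # correctly identified negatives / all negatives

print(sensitivity)  # 0.8
print(specificity)  # 0.7
```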
import torch
import numpy as np

def sensitivity(output, target, sensi):
    '''
    Per-class sensitivity for a 3-class problem.
    Arguments:
        sensi  -- np.array([-1] * 3) on the first call; on later calls,
                  the sensitivity values accumulated so far
        output -- tensor(80, 3), taken from outputs, _ = net(inputs)
        target -- tensor(80)
    Returns:
        sensitivity -- np.array
    '''
    _, pred = output.max(1)
    # one-hot masks for the predictions and the targets
    pred_mask = torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.)
    tar_mask = torch.zeros(output.size()).scatter_(1, target.data.cpu().view(-1, 1), 1.)
    acc_mask = pred_mask * tar_mask                   # 1 only where prediction == target
    sensitivity = acc_mask.sum(0) / tar_mask.sum(0)   # per-class TP / (TP + FN)
    sensitivity = sensitivity.numpy()
    if sensi[0] != -1:
        sensitivity = (sensitivity + sensi) / 2       # crude running average with previous batches
    return sensitivity
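A quick end-to-end check of the function on a hand-built batch (the function body is repeated here so the snippet is self-contained; the batch, its labels, and the expected per-class values are made up for illustration):

```python
import torch
import numpy as np

def sensitivity(output, target, sensi):
    # mirrors the function above
    _, pred = output.max(1)
    pred_mask = torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.)
    tar_mask = torch.zeros(output.size()).scatter_(1, target.cpu().view(-1, 1), 1.)
    acc_mask = pred_mask * tar_mask
    sens = (acc_mask.sum(0) / tar_mask.sum(0)).numpy()
    if sensi[0] != -1:
        sens = (sens + sensi) / 2
    return sens

# toy batch of 6 samples, 3 classes, built so the result is easy to check
output = torch.tensor([[2.0, 0.1, 0.1],   # predicted 0, target 0: correct
                       [0.1, 2.0, 0.1],   # predicted 1, target 1: correct
                       [0.1, 2.0, 0.1],   # predicted 1, target 1: correct
                       [2.0, 0.1, 0.1],   # predicted 0, target 2: wrong
                       [0.1, 0.1, 2.0],   # predicted 2, target 2: correct
                       [2.0, 0.1, 0.1]])  # predicted 0, target 1: wrong
target = torch.tensor([0, 1, 1, 2, 2, 1])

sensi = sensitivity(output, target, np.array([-1.0] * 3))
print(sensi)  # class 0: 1/1, class 1: 2/3, class 2: 1/2
```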
Batch_size = 80
print(output)
tensor([[-0.0082, -0.1216,  0.0823],
        [ 0.0433, -0.1183, -0.0050],
        ...,
        [ 0.0682, -0.1924,  0.0039]], device='cuda:0')
: the raw network outputs (logits, before softmax); 3 classes, so 3 values per sample
print(target)
tensor([1, 2, ... ,1], device='cuda:0')
: the target labels
print(output.max(1))
torch.return_types.max(
values=tensor([ 0.0823, 0.0433, ... 0.0682],
device='cuda:0'),
indices=tensor([2, 0, ... 0], device='cuda:0'))
: max(1) --> values holds the maximum of each row of output; indices holds their column positions
_, pred = output.max(1)
print(pred)
tensor([2, 0, ... 0], device='cuda:0')
: the predicted class indices
print(output.size())
torch.Size([80, 3])
print(target.size())
torch.Size([80])
: analogous to numpy's shape
print(pred.eq(target))
tensor([0, 0, ... 0], device='cuda:0')
: 1 where the values match, 0 where they differ
print(pred.eq(target).sum())
tensor(21, device='cuda:0')
: sums all the values
print(pred.eq(target).sum().item())
21
: extracts the Python scalar from the tensor
print(pred.cpu())
tensor([2, 0, ... 0])
: the " device='cuda:0' " tag is gone -- the tensor has been moved to the CPU
print(pred.cpu().view(-1,1))
tensor([[2],
[0],
...
[0]])
: reshaped from 1 row x 80 columns into 80 rows x 1 column; view(-1, 1) gives the new shape, where -1 means that dimension is inferred
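The inferred dimension can be seen on a small tensor (a minimal sketch with made-up values):

```python
import torch

t = torch.tensor([2, 0, 1, 0])
col = t.view(-1, 1)   # -1: let PyTorch infer this dimension (here 4)
print(col.shape)      # torch.Size([4, 1])
```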
print(pred_mask)
tensor([[0., 0., 1.],
[1., 0., 0.],
.......
[1., 0., 0.]])
: converted to one-hot form
print(pred_mask.sum(0))
tensor([32., 11., 37.])
: in sum(0), the 0 means summing over dim 0 (down the rows), i.e. per-column totals
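The difference between sum(0) and sum(1) can be seen on a tiny mask (values made up for illustration):

```python
import torch

m = torch.tensor([[0., 0., 1.],
                  [1., 0., 0.],
                  [1., 0., 0.]])
print(m.sum(0))  # per-column totals: tensor([2., 0., 1.])
print(m.sum(1))  # per-row totals:    tensor([1., 1., 1.])
```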
Putting all the probes together into one script:

print(output)
print(target)
_, pred = output.max(1)
print(output.max(1))
print(pred)
print(output.size())
print(target.size())
print(pred.eq(target))
print(pred.eq(target).sum())
print(pred.eq(target).sum().item())
print(pred.cpu())
print(pred.cpu().view(-1, 1))
print(torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.))
pred_mask = torch.zeros(output.size()).scatter_(1, pred.cpu().view(-1, 1), 1.)
print(pred_mask.sum(0))
tar_mask = torch.zeros(output.size()).scatter_(1, target.data.cpu().view(-1, 1), 1.)
print(tar_mask)
acc_mask = pred_mask * tar_mask
print(acc_mask)
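The elementwise product pred_mask * tar_mask is the key step: a 1 survives only where the prediction one-hot and the target one-hot agree, so acc_mask marks the correct predictions per class (a minimal sketch with made-up masks):

```python
import torch

pred_mask = torch.tensor([[0., 0., 1.],    # sample 0 predicted class 2
                          [1., 0., 0.]])   # sample 1 predicted class 0
tar_mask  = torch.tensor([[0., 0., 1.],    # sample 0 is class 2 (match)
                          [0., 1., 0.]])   # sample 1 is class 1 (mismatch)

acc_mask = pred_mask * tar_mask  # 1 only where prediction and target agree
print(acc_mask.sum().item())     # 1.0 -- one correct prediction
```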
- 4. A detailed walk-through of the scatter_() function:
https://www.cnblogs.com/daremosiranaihana/p/12538512.html
Note: the difference between scatter() and scatter_() is that the latter modifies the source tensor in place
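The in-place vs out-of-place distinction, and the one-hot trick used above, in a minimal sketch (indices made up):

```python
import torch

pred = torch.tensor([2, 0, 1])
base = torch.zeros(3, 3)

# scatter() is out of place: it returns a new tensor, base is untouched
onehot = base.scatter(1, pred.view(-1, 1), 1.)
print(base.sum())    # tensor(0.) -- base is still all zeros

# scatter_() is in place: it writes the 1s into base itself
base.scatter_(1, pred.view(-1, 1), 1.)
print(base)          # now the same one-hot matrix as onehot
```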