pytorch计算模型的top_k分类准确率

记录一下计算top_k分类准确率的自定义函数accuracy(),使用时直接复制调用即可。

参数:

output:模型的输出,即模型对不同类别的评分。shape为[batch_size, num_classes]

target:真实的类别标签。shape为[batch_size, ]

topk:需要计算top_k准确率中的k值,元组类型。默认为(1, 5),即函数返回top1和top5的分类准确率

import torch


def accuracy(output, target, topk=(1, 5)):
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k / batch_size)
    return res


if __name__ == '__main__':
    output = torch.randint(low=0, high=6, size=[8, 10])
    target = torch.ones(8, dtype=torch.long)
    print(accuracy(output, target))

 

你可能感兴趣的:(pytorch)