图像分类如何得到每一类的预测概率?(结合python代码)

要得到每一类的预测概率,首先通过torch.eq判断每个图片预测的准不准确,循环每个预测结果,得到没个结果对应的标签,如果准确,在该标签类的正确数量加一,在该类的总的数量加一。最后输出该类正确的数量除以该类总的数量就得到了该类的预测概率了。

# 查看单类准确率
        classes = ('0', '1', '2', '3','4')
        N_CLASSES = 5
        class_correct = list(0. for i in range(N_CLASSES))
        class_total = list(0. for i in range(N_CLASSES))
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                # print(val_labels.shape)
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                c = torch.eq(predict_y, val_labels.to(device)).squeeze()
                size = int(val_labels.shape[0])
                for i in range(size):
                    label = val_labels[i]
                    class_correct[label] += c[i].item()
                    class_total[label] += 1

                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        for i in range(N_CLASSES):
            print('Accuracy of %5s : %2d %%' % (
                classes[i], 100 * class_correct[i] / class_total[i]))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

若该分类任务存在类间分类,每一类差距很小,想要使预测结果处于相邻类就算分类正确时,则需要先将val_loader的batch_size设置为1,再通过一系列if语句实现该效果。

# 查看单类准确率
        classes = ('0', '1', '2', '3','4')
        N_CLASSES = 5
        class_correct = list(0. for i in range(N_CLASSES))
        class_total = list(0. for i in range(N_CLASSES))
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                
                labels = val_labels.numpy()
                predict = predict_y.cpu().numpy()
                
                if labels == 0:
                    if predict==0 or predict==1:
                        c = True
                    else:
                        c = False
                elif labels == 4:
                    if predict==3 or predict==4:
                        c = True
                    else:
                        c = False
                else:
                    if predict==labels-1 or predict==labels or predict==labels+1:
                        c = True
                    else:
                        c = False
                
                size = int(val_labels.shape[0])
                for i in range(size):
                    label = val_labels[i]
                    
                    class_correct[label] += c
                    class_total[label] += 1

                acc += c
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        for i in range(N_CLASSES):
            print('Accuracy of %5s : %2d %%' % (
                classes[i], 100 * class_correct[i] / class_total[i]))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

你可能感兴趣的:(python,分类,人工智能)