分类网络预测结果代码解析

分类网络到底预测了什么

以前没搞懂分类网络中,一张图片经过了神经网络后怎么就变成图片类别的,现在研究出了一点自己的体会,分享给大家,纯属原创。

分类网络基本套路代码

outputs = net(inputs)
        loss = criterion(outputs, targets)
        # loss is variable , if add it(+=loss) directly, there will be a bigger ang bigger graph.
        test_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct+=predicted.eq(targets.data).cpu().sum()
        

这里的net是分类网络,有很多种,当我们把一张图片归一化后输入到网络中,这张图片就开始被各钟各样的卷积核给卷积,卷积的操作我就不在这解释了。
torch.Size([100, 3, 32, 32])
torch.Size([100, 128])
torch.Size([100, 10])
tensor(-2.2718, device=‘cuda:0’, grad_fn=)
原本输入了100张图片的数据,每张图片都是彩色(3表示彩色)每张图片都是32×32的大小,输入到网络中经过卷积层卷积,变成了一个个的矩阵,矩阵维度是四维。
再经过连接层后变成了二维矩阵,也就是上面的最后一个torch size,100是100张图片,10是连接层中最后一层的输出通道,这个输出通道自己根据要判断的图片类别而设置,每个通道能经过加权求和得到一个数字,这里也就是得到10个数字。
到此,原本的一张彩色图片,经过网络后变成了10个数字(设置图片类别数是10),这10个数字代表对应类别的概率,比如一张图片它是0类别(用0到9的数字分别表示10类)那么在10个通道中的第一个通道就是判断为0类别的概率。
我们只需用torch.max这个代码就能得到一张图片输出后的10个数字中,哪个数字最大,那个最大的数字代表的通道就是预测的类别,比如第一个通道的数字最大,那么就对应0类别。

到此就实现了网络预测图片是哪一类别了,当然准不准要另说。

学生党一枚,如有不对的地方请给我留言,谢谢!

你可能感兴趣的:(研究一下)