pytorch修改预训练的网络的输出

pytorch修改预训练的网络的输出

pytorch下载预训练的网络

torchvision模块里有提供models模块可以下载欲训练的网络

from torchvision import models
net = models.resnet18(pretrained = True)

修改网络的最后一层

num_fc_ftr = net.fc.in_features

这句是get到网络的最后一层数,原来的resnet18是1000输出,这里改成我需要的3个输出

net.fc = torch.nn.Linear(num_fc_ftr, 3)

你可能感兴趣的:(深度学习)