resnet18的模型使用

输入图片格式为:(b,c,224,224)



修改输出层的输出维度:
import torchvision

resnet_model = torchvision.models.resnet18(pretrained=True)

for param in resnet_model.parameters():
    param.requires_grad = False

resnet_model.fc


class Net(nn.Module):
    def __init__(self, model):
        super(Net, self).__init__()
        # 取掉model的后1层
        self.resnet_layer = nn.Sequential(*list(model.children())[:-1])
        self.Linear_layer = nn.Linear(512, 11) #加上一层参数修改好的全连接层,例如修改为11层
 
    def forward(self, x):
        x = self.resnet_layer(x)
        x = x.view(x.size(0), -1)
        x = self.Linear_layer(x)
        return x
    
resnet_model = Net(resnet_model)

你可能感兴趣的:(pytorch,python,pytorch,深度学习)