pytorch 提供了torchvision.models接口, 可以轻松初始化一些常见模型, 还可以设置pretrained参数为True, 加载pytorch官方提供的预训练模型。
例如初始化一个resnet18模型:
model = torchvision.models.resnet18()
model = torchvision.models.resnet18(pretrained=True)
当我们想用一个变量来初始化模型, 比如‘resnet18’ 或者 ‘resnet50', 可以通过访问torchvision.models.__dict__接口来实现。
model = torchvision.models.__dict__[name](pretrained=True)