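# 从 vgg16-397923af.pth 中读取的权重数值，应与通过 model.load_state_dict 加载预训练权重后模型中对应参数的数值一致。
# (The weights read from vgg16-397923af.pth should equal the model's corresponding parameters after loading them with model.load_state_dict.)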

import numpy as np
import torch
import torchvision
from torchvision import models
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from src import fcn_resnet50, resnet50
from model_fcn8s import VGG, fcn_vgg16

# Build the VGG backbone and, optionally, load the torchvision VGG16 checkpoint
# into it, then verify that the loaded parameters really equal the checkpoint.
pretrain_backbone = True
# Presumably (num_convs, out_channels) per stage; this matches the VGG16
# layout. TODO: confirm against model_fcn8s.VGG.
struct = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]
backbone = VGG(num_classes=21, struct=struct)

if pretrain_backbone:
    weights_dict = torch.load("vgg16-397923af.pth", map_location="cpu")
    for name, param in weights_dict.items():
        print(f"{name}: {param}")

    # With strict=False, keys that do not line up are silently skipped. Keep
    # the returned report so skipped weights are visible instead of being
    # mistaken for loaded ones. (Shape mismatches on shared keys still raise.)
    load_result = backbone.load_state_dict(weights_dict, strict=False)
    if load_result.missing_keys:
        print(f"Missing keys (left at their initial values): {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Unexpected checkpoint keys (ignored): {load_result.unexpected_keys}")

    print("Loaded pretrained model parameters:")
    model_state = backbone.state_dict()
    for name, param in model_state.items():
        print(f"{name}: {param}")

    # Check the "values are identical" expectation programmatically rather
    # than by comparing two (truncated) tensor printouts by eye.
    shared = [name for name in weights_dict if name in model_state]
    mismatched = [
        name for name in shared
        if not torch.equal(model_state[name], weights_dict[name])
    ]
    print(
        f"{len(shared) - len(mismatched)}/{len(weights_dict)} checkpoint "
        f"tensors were loaded into the model with identical values"
    )
    if mismatched:
        print(f"Value mismatch for: {mismatched}")

# 你可能感兴趣的:(Experiment,pytorch)