对比自己的模型结构和预训练加载的模型结构是否一致

import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from torchvision.models import vgg16
from model_fcn8s import VGG

# 期望的FCN模型实例
struct = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]
expected_model = VGG(num_classes=21, struct=struct)

# 加载预训练的VGG16模型
pretrained_vgg16 = vgg16(pretrained=True)
pretrained_vgg16 = pretrained_vgg16.features

# 对比期望的模型结构和加载的模型结构
for expected_param, loaded_param in zip(expected_model.parameters(), pretrained_vgg16.parameters()):
    assert expected_param.shape == loaded_param.shape, "Parameter shape mismatch!"

print("Model structure matches the expected structure.")

你可能感兴趣的:(Experiment,pytorch)