PyTorch 深度学习网络调试方法

可以用如下代码段在 eval（推理）模式下对网络做一次前向传播，以检查模型结构和输出尺寸是否符合预期：
以 DenseNet 为例：

if __name__ == "__main__":
    import torch

    # Choose the device once and move both the model and the input there.
    # Fall back to CPU so the smoke test also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): densenet121 is not defined in this snippet; it is assumed
    # to be imported/defined elsewhere in the file (e.g. torchvision-style).
    model = densenet121(pretrained=False, progress=True, num_classes=40).to(device)
    # eval() switches layers such as BatchNorm/Dropout to inference behaviour.
    model.eval()
    print(model)

    # Dummy batch: one 3-channel 224x224 input (avoid shadowing builtin `input`).
    x = torch.randn(1, 3, 224, 224, device=device)
    # No autograd graph is needed for a forward-pass/shape check; skipping it
    # saves memory and time.
    with torch.no_grad():
        y = model(x)
    print(y.size())

以 SKNet 为例（读取本地图片 img.jpg 作为输入）：

if __name__ == '__main__':
    from PIL import Image
    from torchvision import transforms
    import torch

    # Choose the device once; fall back to CPU when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Close the file handle once the RGB copy has been made.
    with Image.open('img.jpg') as raw:
        img = raw.convert('RGB')
    # ToTensor yields a (C, H, W) float tensor; add a leading batch dimension.
    # torch.autograd.Variable is deprecated since PyTorch 0.4: plain tensors
    # carry autograd state, so the tensor is used directly.
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)

    # NOTE(review): SKNet50 is not defined in this snippet; it is assumed to be
    # defined elsewhere in the file.
    model = SKNet50().to(device)
    # Evaluate in inference mode. In training mode, BatchNorm can reject a
    # batch of one, which is presumably why the image used to be stacked twice.
    model.eval()
    with torch.no_grad():
        pred = model(x)
    print(pred)

你可能感兴趣的：深度学习