[torchsummary error] RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

Original code:

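import torch
import torch.nn as nn
import torchvision
from torchsummary import summary

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet18(weights=None)  # built on the CPU; torchvision < 0.13 uses pretrained=False
model.fc = nn.Linear(512, 10)

summary(model, input_size=[(3, 224, 224)], batch_size=256, device="cuda")  # dummy input is created on the GPU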

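The error is: RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

The cause: with device="cuda", torchsummary creates its dummy input as a CUDA tensor (torch.cuda.FloatTensor), but it does not move the model, so the resnet18 weights are still CPU tensors (torch.FloatTensor).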
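To confirm this kind of mismatch before calling `summary`, it can help to compare where the model's parameters live with where the input will be created. The sketch below assumes a small stand-in model (a single `nn.Conv2d`, not the resnet18 above) and a hypothetical helper `param_devices`, which is not part of PyTorch or torchsummary.

```python
import torch
import torch.nn as nn


def param_devices(model: nn.Module) -> set:
    """Hypothetical helper: the set of devices holding the model's parameters."""
    return {p.device for p in model.parameters()}


model = nn.Conv2d(3, 8, kernel_size=3)  # freshly built modules live on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x = torch.rand(2, 3, 32, 32, device=device)  # stand-in for torchsummary's dummy input

print("weights on:", param_devices(model))  # prints: weights on: {device(type='cpu')}
print("input on:  ", x.device)
if param_devices(model) != {x.device}:
    # On a GPU machine this branch is taken: the same situation as the error above.
    print("device mismatch between weights and input")
```

On a CPU-only machine both lines report the CPU; on a GPU machine the input reports `cuda:0` while the weights stay on the CPU.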

Solution

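import torch
import torch.nn as nn
import torchvision
from torchsummary import summary

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet18(weights=None)  # torchvision < 0.13 uses pretrained=False
model.fc = nn.Linear(512, 10)

model = model.to(device)  # adding this one line fixes it: weights and input now share a device

summary(model, input_size=[(3, 224, 224)], batch_size=256, device=device.type)  # "cuda" or "cpu"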
