pytorch之网络结构可视化和参数统计

网络结构可视化

import torch
import torchvision.models as models
from tensorboardX import SummaryWriter

model = models.resnet50().cuda()
dummy_input = torch.randn(1, 3, 224, 224).to("cuda:0")
with SummaryWriter(comment='Resnet50') as w:
    w.add_graph(model, ((dummy_input,)))

会在当前路径生成runs文件夹,使用tensorboard命令打开

网络参数量统计(附inference time)

import torch
import torchvision.models as models
import torchsummary
import time


model = models.resnet50().cuda()
torchsummary.summary(model, input_size=(3, 224, 244))
# inference time
inputs = torch.randn(1, 3, 224, 224).cuda()
end = time.time()
y = model(inputs)
print("inference time:{}".format(time.time() - end))

你可能感兴趣的:(pytorch)