上一篇分析了 SummaryWriter 的主要函数,本篇借助 tensorboardX 官方 demo 解释 add_graph 的用法。
本文选取 demo 中具有代表性的示例,完整 demo 参见:https://github.com/lanpa/tensorboardX/blob/master/examples/demo_graph.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable
from tensorboardX import SummaryWriter
# All-ones dummy input of shape (1, 3): a batch of one sample with three
# features, which is what the nn.Linear(3, 5) model below consumes.
dummy_input = torch.ones((1, 3))
class LinearInLinear(nn.Module):
    """Minimal model: one ``nn.Linear(3, 5)`` wrapped inside an outer module.

    Used in this demo to visualize, via ``add_graph``, a layer nested inside
    another module.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``l`` is kept unchanged: it determines the
        # parameter names (``l.weight`` / ``l.bias``).
        self.l = nn.Linear(in_features=3, out_features=5)

    def forward(self, x):
        """Apply the inner linear layer to ``x`` (last dimension must be 3)."""
        inner = self.l
        return inner(x)
with SummaryWriter(comment='LinearInLinear2') as writer:
    # Pass the flag by keyword: a bare positional ``False`` hides that it is
    # add_graph's ``verbose`` switch.
    writer.add_graph(LinearInLinear(), dummy_input, verbose=False)
class MutipleInput(nn.Module):
    """Model whose ``forward`` takes two tensors that are summed, then fed
    through a single ``nn.Linear(3, 5)``.

    The (misspelled) class name is kept as-is because the demo refers to it.
    """

    def __init__(self):
        super().__init__()
        self.Linear_1 = nn.Linear(in_features=3, out_features=5)

    def forward(self, x, y):
        """Return ``Linear_1(x + y)``; ``x`` and ``y`` must be addable."""
        combined = x + y
        return self.Linear_1(combined)
model_m = MutipleInput()
with SummaryWriter(comment='MutipleInput') as writer:
    # forward() takes two tensors, so the dummy inputs are supplied as a
    # tuple: x = ones(1, 3), y = zeros(1, 3).
    writer.add_graph(
        model_m,
        (torch.ones(1, 3), torch.zeros(1, 3)),
        verbose=True,
    )
def conv3x3(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved (kernel 3, padding 1);
    ``stride`` is forwarded unchanged to the convolution.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **conv_options)
class BasicBlock(nn.Module):
    """Residual block: two (3x3 conv -> BatchNorm) stages plus a shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the block input to form the
            shortcut branch. Without it the shortcut is the raw input, which
            only matches the main branch when ``stride == 1`` and
            ``inplanes == planes``; otherwise ``out += residual`` fails with a
            shape mismatch.
    """

    # Class attribute not referenced inside this class; kept for callers.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Previously accepted but silently discarded, so a projection shortcut
        # could never be used. Stored here (registered as a submodule when it
        # is an nn.Module) and applied in forward().
        self.downsample = downsample
        # Stored for reference only; forward() does not read it.
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x``.

        ``x`` must be 4-D (BatchNorm2d requires it), presumably
        (N, inplanes, H, W).
        """
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when a downsample module was supplied, so that
        # it has the same shape as the main branch before the addition.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = F.relu(out)
        return out
# Random input: batch of 1, 3 channels, 224x224 spatial size.
dummy_input = torch.rand((1, 3, 224, 224))
with SummaryWriter(comment='basicblock') as writer:
    # inplanes == planes == 3 with stride 1, so the identity shortcut already
    # matches the main branch and no downsample module is needed.
    model = BasicBlock(3, 3)
    writer.add_graph(model, (dummy_input,), verbose=True)
官方 demo 中的如下代码片段:
# Use torch.rand rather than the legacy constructor torch.Tensor(1, 3, 224, 224),
# which returns uninitialized memory (arbitrary values, possibly NaN/inf).
dummy_input = torch.rand(1, 3, 224, 224)
with SummaryWriter(comment='vgg19') as writer:
    model = torchvision.models.vgg19()
    # NOTE(review): in the author's environment this call raised
    # "AssertionError: Only output_size=[1, 1] is supported". That likely
    # comes from VGG19's adaptive average pooling (output size != 1x1) going
    # through an older PyTorch ONNX export path used to build the graph;
    # confirm against the installed torch/tensorboardX versions.
    writer.add_graph(model, (dummy_input,))
with SummaryWriter(comment='resnet18') as writer:
    model = torchvision.models.resnet18()
    writer.add_graph(model, (dummy_input,))
无法运行,报错信息如下:
assert output_size == [1, 1], "Only output_size=[1, 1] is supported"
AssertionError: Only output_size=[1, 1] is supported
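仅仅是添加一个 graph,为什么会和 output 有关?其实这里的 output_size 并不是模型的输出,而是 nn.AdaptiveAvgPool2d 的输出尺寸。当时的 tensorboardX 借助 PyTorch 的 ONNX 导出流程来构建计算图,而旧版 PyTorch 的 ONNX 转换只支持 output_size=[1, 1] 的 adaptive_avg_pool2d(即全局平均池化)。torchvision 中 VGG19 的 avgpool 是 AdaptiveAvgPool2d((7, 7)),因此会触发这个断言;ResNet18 使用的是 AdaptiveAvgPool2d((1, 1)),不受此限制。
所以这很可能确实是版本不一致导致的,升级 PyTorch 与 tensorboardX 后通常即可正常运行。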
仍将持续更新,感谢关注!