最近用Pytorch需要用到两个输入,发现有很多朋友也没有找到解决方法,今天来分享一下我的做法。
**
**
def forward(self, x1, x2):
feature = self.conv1(x1)
out = torch.cat([feature, x2], dim=1)
out1 = self.layer1(out)
out2 = self.layer2(out1)
out = torch.cat([out1, out2], dim=1)
out3 = self.layer3(out)
out = torch.cat([out1, out2, out3], dim=1)
out = self.layer4(out)
out = out + x1
return out
只需要在原来forward(self,x)再加一个参数即可
**
**
def split_input(input, band=12):
# 输入为四维数据,第一维是batch,第二维才是band
input1 = input[:, band:band + 1]
# 注意这里要使用band:band+1才不会改变input的维度
input2 = np.delete(input, band, axis=1)
return input1, input2
这里的input是dataset返回的Inputs
**
**
summary(model.cuda(), [[1, 30, 30], [24, 30, 30]], batch_size=16)
**
**
dummy_input1 = torch.randn(16, 1, 30, 30)
dummy_input2 = torch.randn(16, 24, 30, 30)
with SummaryWriter(comment='RDNet') as w:
w.add_graph(model,(dummy_input1,dummy_input2))
以上测试均能成功。
注意目前tensorboardx还不支持pytorch1.4.0要回退到1.3.0,graph才能正常生成。