pytorch模型输入两个参数

最近用Pytorch需要用到两个输入,发现有很多朋友也没有找到解决方法,今天来分享一下我的做法。
**

forward:

**

    def forward(self, x1, x2):
        feature = self.conv1(x1)
        out = torch.cat([feature, x2], dim=1)

        out1 = self.layer1(out)
        out2 = self.layer2(out1)

        out = torch.cat([out1, out2], dim=1)
        out3 = self.layer3(out)

        out = torch.cat([out1, out2, out3], dim=1)
        out = self.layer4(out)

        out = out + x1

        return out

只需要在原来forward(self,x)再加一个参数即可
**

get_data中可以用下列方法分成input1和input2

**

def split_input(input, band=12):
    # 输入为四维数据,第一维是batch,第二维才是band
    input1 = input[:, band:band + 1]
    # 注意这里要使用band:band+1才不会改变input的维度
    input2 = np.delete(input, band, axis=1)
    return input1, input2

这里的input是dataset返回的Inputs
**

torchsummary的测试

**

summary(model.cuda(), [[1, 30, 30], [24, 30, 30]], batch_size=16)

**

tensorboard的测试

**

    dummy_input1 = torch.randn(16, 1, 30, 30)
    dummy_input2 = torch.randn(16, 24, 30, 30)

    with SummaryWriter(comment='RDNet') as w:
        w.add_graph(model,(dummy_input1,dummy_input2))

以上测试均能成功。
注意目前tensorboardx还不支持pytorch1.4.0要回退到1.3.0,graph才能正常生成。

你可能感兴趣的:(深度学习,机器学习,pytorch)