pytorch的nn.Linear缺省输入节点个数

Pytorch中的Linear不能像Keras中的Dense一样自动计算数据维度,在利用**x = x.view(x.size(0), -1)**延展后,还需要手动输入Linear的输入维度。可以利用x.shape查看输入到分类器前的数据维度,以实现nn.Linear缺省输入节点个数。例如,输入数据维度是[batch1,4000](一维数据+一维CNN),经过卷积conv1的处理后,将其延展为batch一维,利用x.shape[1]获取延展后的数据维度,即可作为Linear的输入维度。

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 =  nn.Sequential(
            nn.Conv1d(1,16,kernel_size=2,padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(16, 32, kernel_size=2, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2)
            )

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)  # (batch_size,size(1)*size(2))
        out = nn.Linear(x.shape[1],2)(x)# (batch_size,2)
        return out
        
if __name__ == '__main__':
    net = CNN()
    x = Variable(torch.FloatTensor(217, 1, 4000))
    y = net(x)
    print(y.data.shape)

你可能感兴趣的:(pytorch,深度学习,cnn)