Pytorch nn.Linear() 使用示例

由于我之前一直用Keras,感觉Keras中的Dense()层非常好用,可以非常方便地进行全连接操作。这次需要用到Pytorch中的全连接层,一开始还不太会用,但是仔细研究后发现,其实两者殊途同归,方法很相似。

当我们需要将形如[batch_size, 18, 1, 1]的张量全连接操作后分别得到[batch_size, 18, 16][batch_size,1]的张量,该如何操作呢?

[batch_size, 18, 16]

class my_network(nn.Module):
    def __init__(self):
        super(my_network

你可能感兴趣的:(笔记,sys,python,深度学习,python,pytorch,机器学习,神经网络)