向量拉直

# Reshape pool5 to 2-D: the first dimension keeps its size (pool5.size(0)),
# and -1 lets view() infer the second dimension from the remaining elements.
# NOTE(review): presumably pool5 is a (N, C, H, W) feature map — confirm upstream.
pool5_flat = pool5.view(pool5.size(0), -1)

【1518, 512, 7, 7】 → 【1518, 25088】(后三维展平为一维:512 × 7 × 7 = 25088)

# -*- coding: utf-8 -*-  维度对应问题

import torch
import torch.nn as nn

# Random demo input: 4 rows of 25088 features each, matching the width
# expected by the first Linear layer of the network below.
a = torch.rand((4, 25088))

class Net(nn.Module):
    """Two-layer fully connected head applied to flattened feature vectors.

    The network is ``Linear -> ReLU -> Linear -> ReLU``. With the default
    arguments it maps a ``(batch, 25088)`` input to a ``(batch, 10)`` output.

    Args:
        in_features: Width of each flattened input vector.
        hidden_features: Width of the hidden layer.
        num_classes: Width of the output layer.
    """

    def __init__(
        self,
        in_features: int = 25088,
        hidden_features: int = 4096,
        num_classes: int = 10,
    ) -> None:
        # Must be the real dunder constructor: nn.Module.__init__ has to run
        # before any sub-module is assigned as an attribute.
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_features, num_classes),
            # NOTE(review): this trailing ReLU clamps the outputs to >= 0.
            # It is kept to match the original; confirm it is intended if the
            # outputs are later fed to a loss that expects raw logits.
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result.

        Args:
            x: Tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor with the last dimension replaced by ``num_classes``.
        """
        return self.layer(x)

# Build the network, push the demo batch through it, and print the output
# shape to check that the dimensions line up.
net = Net()
b = net(a)
print(b.size())

你可能感兴趣的:(向量拉直)