PyTorch_view

Code

import torch

# 4-D tensor of shape (1, 512, 1, 1).
x = torch.randn(1, 512, 1, 1)
# Keep the first dimension and let -1 infer the rest (512 * 1 * 1 = 512).
output = x.view(x.size(0), -1)
print(output.size())

torch.Size([1, 512])

Reversing the two arguments puts the inferred dimension first:

import torch

x = torch.randn(1, 512, 1, 1)
# -1 is inferred as 512; the second dimension is x.size(0) = 1.
output = x.view(-1, x.size(0))
print(output.size())

torch.Size([512, 1])
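With a first dimension of 1, the two calls above differ only in shape, so it is easy to mistake `view(-1, x.size(0))` for a transpose. The sketch below is illustrative rather than part of the original post. It uses a small 2 x 3 tensor (an assumed example shape) to show that `view` regroups elements in their existing row-major order and shares memory with the original tensor.

```python
import torch

# Small 2 x 3 tensor so the element order is easy to follow.
x = torch.arange(6).view(2, 3)        # [[0, 1, 2], [3, 4, 5]]

# view keeps the row-major element order and only regroups it.
reshaped = x.view(-1, x.size(0))      # shape (3, 2): [[0, 1], [2, 3], [4, 5]]
transposed = x.t()                    # shape (3, 2): [[0, 3], [1, 4], [2, 5]]
same_as_transpose = torch.equal(reshaped, transposed)
print(reshaped.shape, same_as_transpose)

# A view shares storage with the tensor it was created from,
# so writing through the view also changes x.
reshaped[0, 0] = 100
print(x[0, 0].item())
```

This prints `torch.Size([3, 2]) False` followed by `100`. When an actual transpose is needed, `x.t()` or `x.permute(...)` is the usual choice, not `view`.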

References

<1>

torch=1.7.1+cu101
torchvision=0.8.2
torchaudio=0.7.2
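To compare a local setup against the versions listed above, the installed package versions can be read without importing the packages. This is a minimal sketch using the standard library's `importlib.metadata` (Python 3.8+). The helper name `installed_versions` is hypothetical and not from the original post.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names=("torch", "torchvision", "torchaudio")):
    """Return a dict mapping each package name to its version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # package is not installed in this environment
    return found


for name, ver in installed_versions().items():
    print(f"{name}={ver}")
```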
