Pytorch入门十三 || 编写卷积神经网络时自动求出卷积后flatten向量的维度

Pytorch入门十三 || 编写卷积神经网络时自动求出flatten向量的维度

# 在模型定义的类中加入函数 _get_conv_out()
# 此处shape是传入的图片大小,如(3,84,84) 大小为84*84的彩色图像
def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1 ,*shape))   # 在前面多加一个通道,表示1个batch,conv是卷积层
        return int(np.prod(o.size())),

完整代码如下

import torch
import torch.nn as nn
import numpy as np

class DQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()

        self.conv = nn.Sequential(		# 卷积层,得到64个通道的特征图,大小不清楚
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)	# 调用函数得到卷积后64个特征图flatten后的大小
        
        self.fc = nn.Sequential(	# 线性层
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    # 自动计算特征图的大小
    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))   # 在前面多加一个通道,表示1个batch
        return int(np.prod(o.size()))	# np.prod()函数返回(batch,c,w,h)这四个数相乘的结果

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)

# 使用
net = DQN((3,84,84),4)	# 第一个参数为元组,表示图像的shape信息

你可能感兴趣的:(深度学习,pytorch,神经网络)