首先是导入包因为使用的是pytorch框架
所以倒入torch
相关包,summary
是可以获得自己搭建模型的参数、各层特征图大小、以及各层的参数所占内存的包作用效果如p2
;
安装方法:pip install torchsummary
'''
导入包
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
定义类Class
以及def super
,这些是类的继承最基础的知识啦如果不懂原理就按模版记下即可;接着开始搭建层,这里采用nn.Sequential
,相当于一个大容器
可以放入任意量的网络层
在p1
中放入一个卷积层;接着进入线性层
依然使用nn.Sequential
;
class Net(nn.Module):
def __init__(self, num_classes=10):
super(Net, self).__init__()
self.fetures = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64,
kernel_size=3, stride=1, padding=1))
self.classify = nn.Sequential(nn.Linear(32 * 32 * 64, 20),
nn.Linear(20, num_classes))
定义好网络层
就可以定义层之间的计算过程啦首先进入卷积层接着需要将卷积层的形状从四维变成二维,在这里使用了view函数
,接着传入线性层得到return
def forward(self, x):
x = self.fetures(x)
x = x.view(x.size(0), -1)
x = self.classify(x)
return x
实例化网络
;假设输入大小为(10, 3, 32,32)
,将输入传入网络就得到输出结果的尺寸啦!其中10代表每一次输入的图像张数;3是通道数3, 32, 32
为输入图片的宽高。调用summary
检查网络结构,此时只需输入(3, 32, 32)
即可因为summary
中只需输入通道数以及宽高即可。
Modle = Net()
input = torch.ones([10, 3, 32, 32])
result = Modle(input)
print(result.shape)
summary(Modle.to("cuda"), (3, 32, 32))
完整代码如下,如果运行后出现错误可以在评论区里写下你的看法和建议:
'''
Aouther:LiuZhenming
Time:2022-09-25
'''
# 导入包
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
# 定义类和函数
class Net(nn.Module):
def __init__(self, num_classes=10):
super(Net, self).__init__()
self.fetures = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64,
kernel_size=3, stride=1, padding=1))
self.classify = nn.Sequential(nn.Linear(32 * 32 * 64, 20),
nn.Linear(20, num_classes))
# 定义网络层
def forward(self, x):
x = self.fetures(x)
x = x.view(x.size(0), -1)
x = self.classify(x)
return x
# 实例化网络
Modle = Net()
input = torch.ones([10, 3, 32, 32])
result = Modle(input)
print(result.shape)
summary(Modle.to("cuda"), (3, 32, 32))
torch .Size([10,10])
–––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
Layer (type) Output Shape Param #==============================================================
Conv2d-1 [-1,64,32,32] 1,792
Linear-3 [-1,10] 210==============================================================
Total params: 1, 312, 742
Trainable params: 1, 312, 742
Non-trainable params: 0–––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––––
Input size (MB) : 0.01
Forward/ backward pass size (MB) : 0.50