目录
Pytorch学习笔记
1. nn.moduleList 和Sequential用法和实例
1.1、nn.Sequential():模型建立方式
2. Pytorch基本操作
expand()扩展维度的
contiguous()
torch.ge,torch.gt,torch.le逐元素比较
3. Pytorch常用工具
Pytorch可视化工具
统计模型参数量与FLOPs
https://blog.csdn.net/e01528/article/details/84397174
建立nn.Sequential()对象,必须小心确保一个块的输出大小与下一个块的输入大小匹配。基本上,它的行为就像一个nn.Module。
第一种写法:
nn.Sequential()对象.add_module(层名,层class的实例)
net1 = nn.Sequential()
net1.add_module('conv', nn.Conv2d(3, 3, 3))
net1.add_module('batchnorm', nn.BatchNorm2d(3))
net1.add_module('activation_layer', nn.ReLU())
第二种写法:
nn.Sequential(*多个层class的实例)
net2 = nn.Sequential(
nn.Conv2d(3, 3, 3),
nn.BatchNorm2d(3),
nn.ReLU()
)
第三种写法:
nn.Sequential(OrderedDict([*多个(层名,层class的实例)]))
from collections import OrderedDict
net3= nn.Sequential(OrderedDict([
('conv', nn.Conv2d(3, 3, 3)),
('batchnorm', nn.BatchNorm2d(3)),
('activation_layer', nn.ReLU())
]))
扩展某个相同元素维度,将如x.shape=(2,2,1)的第3维度扩展为x.shape=(2,2,3),一般用在Tensor维度不匹配时,可使用该方法进行扩展.
import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)
输出:
tensor([[[ 0.0608],
[ 2.2106]],
[[-1.9287],
[ 0.8748]]])
tensor([[[ 0.0608, 0.0608, 0.0608],
[ 2.2106, 2.2106, 2.2106]],
[[-1.9287, -1.9287, -1.9287],
[ 0.8748, 0.8748, 0.8748]]])
返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。可以通过is_contiguous查看张量内存是否连续。调用view之前最好先contiguous
x.contiguous().view()
返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。可以通过is_contiguous查看张量内存是否连续。
import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous)
print(a.contiguous().view(4,3))
输出:
tensor([[ 1, 2, 3],
[ 4, 8, 12],
[ 1, 2, 3],
[ 4, 8, 12]])
https://blog.csdn.net/jacke121/article/details/86360959
https://www.jianshu.com/p/46eb3004beca
https://github.com/miaoshuyu/pytorch-tensorboardx-visualization
https://mp.weixin.qq.com/s/2HFulqqNdqp3b6XXa7qPgA
PyTorch-OpCounter GitHub 地址: https://github.com/Lyken17/pytorch-OpCounte
对于 torchvision 中自带的模型,Flops 统计通过以下几行代码就能完成:
from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))