Pytorch学习笔记

Pytorch学习笔记

目录

Pytorch学习笔记

1. nn.moduleList 和Sequential用法和实例

1.1、nn.Sequential():模型建立方式

2. Pytorch基本操作

expand()扩展维度的

contiguous()

torch.ge,torch.gt,torch.le逐元素比较

 3. Pytorch常用工具

Pytorch可视化工具

统计模型参数量与FLOPs


1. nn.moduleList 和Sequential用法和实例

    https://blog.csdn.net/e01528/article/details/84397174

建立nn.Sequential()对象,必须小心确保一个块的输出大小与下一个块的输入大小匹配。基本上,它的行为就像一个nn.Module。

1.1、nn.Sequential():模型建立方式

第一种写法:
nn.Sequential()对象.add_module(层名,层class的实例)

net1 = nn.Sequential()

net1.add_module('conv', nn.Conv2d(3, 3, 3))

net1.add_module('batchnorm', nn.BatchNorm2d(3))

net1.add_module('activation_layer', nn.ReLU())

第二种写法:
nn.Sequential(*多个层class的实例)

net2 = nn.Sequential(

        nn.Conv2d(3, 3, 3),

        nn.BatchNorm2d(3),

        nn.ReLU()

        )

第三种写法:
nn.Sequential(OrderedDict([*多个(层名,层class的实例)]))

from collections import OrderedDict

net3= nn.Sequential(OrderedDict([

          ('conv', nn.Conv2d(3, 3, 3)),

          ('batchnorm', nn.BatchNorm2d(3)),

          ('activation_layer', nn.ReLU())

        ]))

2. Pytorch基本操作


expand()扩展维度的

扩展某个相同元素维度,将如x.shape=(2,2,1)的第3维度扩展为x.shape=(2,2,3),一般用在Tensor维度不匹配时,可使用该方法进行扩展.

import torch
x=torch.randn(2,2,1)
print(x)
y=x.expand(2,2,3)
print(y)

输出:

tensor([[[ 0.0608],
         [ 2.2106]],
 
        [[-1.9287],
         [ 0.8748]]])
tensor([[[ 0.0608,  0.0608,  0.0608],
         [ 2.2106,  2.2106,  2.2106]],
 
        [[-1.9287, -1.9287, -1.9287],
         [ 0.8748,  0.8748,  0.8748]]])

contiguous()

返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。可以通过is_contiguous查看张量内存是否连续。调用view之前最好先contiguous

x.contiguous().view() 

返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。可以通过is_contiguous查看张量内存是否连续。

import torch
a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
print(a.is_contiguous)
print(a.contiguous().view(4,3))

输出:


tensor([[  1,   2,   3],
        [  4,   8,  12],
        [  1,   2,   3],
        [  4,   8,  12]])

torch.ge,torch.gt,torch.le逐元素比较

https://blog.csdn.net/jacke121/article/details/86360959


 3. Pytorch常用工具

Pytorch可视化工具

https://www.jianshu.com/p/46eb3004beca

https://github.com/miaoshuyu/pytorch-tensorboardx-visualization

统计模型参数量与FLOPs

https://mp.weixin.qq.com/s/2HFulqqNdqp3b6XXa7qPgA

PyTorch-OpCounter GitHub 地址: https://github.com/Lyken17/pytorch-OpCounte 

对于 torchvision 中自带的模型,Flops 统计通过以下几行代码就能完成:

from torchvision.models import resnet50
from thop import profile

model = resnet50()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))

 

你可能感兴趣的:(学习笔记,Pytorch)