Pytorch学习笔记--Pytorch常用函数总结1

目录​​​​​​​

1-torch.randn()函数 

2-set()函数和sorted()函数

3-DataLoader()函数和Dataset类

4-.t()函数

5-最大池化(max_pool2d)和平均池化(avg_pool2d)函数


​​​​​​​

1-torch.randn()函数 

import torch

batch_size = 1
seq_len = 3
input_size = 4

inputs = torch.randn(seq_len, batch_size, input_size) 
torch.randn()函数的作用是生成一组均值为0,方差为1(即标准正态分布)的随机数。示例如下: 
import torch

print(torch.randn(3, 2, 3, 3))

​​​​​​​Pytorch学习笔记--Pytorch常用函数总结1_第1张图片

torch.randn(seq_len, batch_size, input_size):第一个参数seq_len表示序列长度,示例中序列长度为3;第二个参数batch_size表示批大小,示例中批大小为2;第三个参数input_size为输入向量的维度,示例中为(3, 3)。(在RNN中可理解成:示例中,共有3个序列,每个序列分为2批,每批的维度为3*3。)

#####################################

#####################################

2-set()函数和sorted()函数

self.country_list = list(sorted(set(self.countries))) # set()去重,删除重复的数据; sorted()排序

set()函数用于删除重复的数据元素;sorted()用于元素的排序,示例如下:

a = ['china', 'china', 'japan']
print(list(set(a)))
print(list(sorted(set(a))))

Pytorch学习笔记--Pytorch常用函数总结1_第2张图片

由于‘c’ < 'j',所以‘china’排在‘japan’前面。

#####################################

#####################################

3-DataLoader()函数和Dataset类

from torchvision import datasets
from torch.utils.data import DataLoader, Dataset

batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(), #将shape为(H, W, C)的img转为shape为(C, H, W)的tensor,将每一个数值归一化到[0,1]
    transforms.Normalize((0.1307, ), (0.3081, )) #按通道进行数据标准化
])

train_dataset = datasets.MNIST(root = '../dataset/mnist/', train = True, download = True, transform = transform)

train_loader = DataLoader(train_dataset, shuffle = True, batch_size = batch_size)

test_dataset = datasets.MNIST(root = '../dataset/mnist/', train = False, download = True, transform = transform)

test_loader = DataLoader(test_dataset, shuffle = False, batch_size = batch_size)

​​​​​​​DataLoader()函数导入的数据集为Dataset类型,shuffle表示是否打乱数据集。

#####################################

#####################################

4-.t()函数

.t()函数的作用是将Tensor转置,示例如下:

import torch

input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

print(input)
print(input.t())

Pytorch学习笔记--Pytorch常用函数总结1_第3张图片

#####################################

#####################################

5-最大池化(max_pool2d)和平均池化(avg_pool2d)函数

import torch
import torch.nn.functional as F

input = torch.tensor([[[1, 2, 3, 1], [4, 5, 6, 1], [7, 8, 9, 1]]]).unsqueeze(0).float() # unsqueeze(0)在第0个维度前增加一个维度
print(input.size())
output = F.max_pool2d(input, kernel_size = (1, 4))
print(output)

max_pool2d():最大池化操作。根据设定的核大小,选取最大的元素值。示例中,核大小是(1,4),可理解为挑选出每行最大的元素值。

需要说明的是:unsqueeze(0)的作用是在第0维度前扩展一个维度,所以input的size为(1, 1, 3, 4)。

###

import torch
import torch.nn.functional as F

input = torch.randn(1, 1, 4, 4)
print(input.size())
print(input)
output = F.avg_pool2d(input, kernel_size = (2, 2))
print(output)

Pytorch学习笔记--Pytorch常用函数总结1_第4张图片

 avg_pool2d():平均池化操作。根据设定的核大小,计算得到核内元素的平均值。

池化的作用:降维;​​​​​​​抑制噪声,降低信息冗余;提升模型的尺度不变性、旋转不变形;降低模型计算量;防止过拟合。

你可能感兴趣的:(pytorch,python,深度学习)