torch.tensor(data,
dtype=None,
device=None,
requires_grad=False,
pin_memory=False)
功能:从data创建tensor
torch.from_numpy(ndarray)
功能:从numpy创建tensor,创建后的tensor与numpy共享内存,即一个修改后另一个也会随之改变
torch.zeros(size,
out=None,
dtype=None,
layout=torch.strided,
device=None,
requires_grad=False)
torch.zeros_like(input,
dtype=None,
layout=None,
device=None,
requires_grad=False)
torch.ones()与torch.zeros类似
torch.ones_like()与torch.zeros_like类似
torch.normal(mean,
std,
out=None)
功能:生成正态分布
torch.cat(tensors,
dim=0,
out=None)
功能:将张量按维度dim进行拼接
torch.stack(tensors,
dim=0,
out=None)
功能:在新创建的维度dim上进行拼接
torch.chunk(input,
chunks,
dim=0 )
功能:将张量按维度dim进行平均切分
返回值:张量列表
torch.split(tensor,
split_size_or_sections,
dim=0)
功能:将张量按维度dim进行切分
返回值:张量列表
torch.reshape(input,
shape)
功能:变换张量形状
torch.transpose(input,
dim0,
dim1)
功能:交换张量的两个维度
torch.t(input)
功能:2维张量转置,对矩阵而言,等价于torch.transpose(input,0,1)
torch.squeeze(input,
dim,
out=None)
功能:压缩长度为1的轴
torch.unsqueeze(input,
dim,
out=None)
功能:依据dim扩展维度
函数汇总(具体的使用在用到时去查阅官方文档即可)
torch.add()
torch.addcdiv()
torch.addcmul()
torch.sub()
torch.div()
torch.mul()
torch.log()
torch.log10()
torch.log2()
torch.exp()
torch.pow()
torch.abs()
torch.acos()
torch.cosh()
torch.cos()
torch.asin()
torch.atan()
torch.atan2()
详细介绍几个常用的运算函数
torch.add(input,
alpha=1,
other,
out=None)
功能:逐元素计算 input + alpha x other
torch.mm(tensor1, tensor2)
功能:(m x n) x (n x p) -> (m x p).
torch.bmm(tensor1, tensor2)
功能: (b x m x n) x (b x n x p) -> (b x m x p)
torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False,
sample=None,
batch_sample=None,
num_workers=0,
collate_fn=None,
pin_memory=Flase,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None)
torch.utils.data.Dataset()
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写。
1.1、transforms.CenterCrop()
功能:从图像中心裁剪图片
1.2、transforms.RandomCrop()
功能:从图中随机裁剪出尺寸为size的图片
2.1、transforms.RandomHorizontalFlip()
功能:依概率水平翻转
2.2、transforms.RandomVerticalFlip()
功能:依概率垂直翻转
2.3、transforms.RandomRotation()
功能:随机旋转图片
3.1、transforms.Pad()
功能:对图片边缘进行填充
3.2、transforms.ColorJitter()
功能:调整亮度、对比度、饱和度和色相
3.3、transforms.RandomGrayscale()
功能:依概率将图片转换为灰度图
import torch
import torch.nn as nn
import torchvision
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
如果有多张显卡,则
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
清除显存
torch.cuda.empty_cache()
torch.cuda.device_count() # 计算当前可见可用gpu数
torch.cuda.get_device_name() # 获取gpu名称
torch.cuda.manual_seed() # 为当前gpu设置随机种子
torch.cuda.manual_seed_all() # 为所有可见gpu设置随机种子