在用PyTorch搭建深度学习模型时,常常遇到一些不知道该如何使用的函数,在网上查到资料弄懂之后,过段时间又忘了,所以以后再遇到不懂的函数就放在这儿,方便后续查询,就把这当成自己的API函数手册吧。由于PyTorch常与Numpy相结合,所以也把Numpy函数记录在这儿。
1.查看PyTorch版本
import torch
print(torch.__version__)
2.创建张量
x=torch.rand(1,3,3,3) #随机初始化张量 4维张量 [batch channel H W]
3.打印模型参数量
安装torchstat:pip install torchstat
示例:
from torchstat import stat
import torchvision.models as models
model = models.resnet34()
stat(model, (3, 224, 224))
4.深度可分离卷积
使用nn.Conv2d的groups
参数实现分组卷积
利用1x1卷积改变通道数
示例:
conv1 = nn.Conv2d(in_channels=3, out_channels=3,
kernel_size=3, stride=1, padding=1, groups=3, bias=False)
注意in_channels=out_channels=groups
5.保存和加载整个模型时的注意事项:
如果是训练时用的GPU训练 则预测时的输入也要使用GPU:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_weight_path = "./save_weights/FullModel.pth" # 直接加载整个模型和参数 不需要重新定义模型
model = torch.load(model_weight_path)
model(img.to(device)) #注意这里的输入也要由GPU计算
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).cpu().numpy() #注意这里使用cpu转换成numpy
参考:pytorch模型的保存和加载
1.items()方法
items() 方法的遍历:items() 方法把字典中每对 key 和 value 组成一个元组,并把这些元组放在列表中返回。
d = {'one': 1, 'two': 2, 'three': 3}
print(d.items())
#输出:dict_items([('one', 1), ('two', 2), ('three', 3)])
for key,value in d.items():#当两个参数时
print(key + ':' + str(value))
#输出:one:1 two:2 three:3
for i in d.items():#当参数只有一个时
print(i)
#输出:('one', 1) ('two', 2) ('three', 3)
1.argparse的使用
argparse是命令行参数解析器
使用方法:
import argparse #导入命令行参数解析器包
parser = argparse.ArgumentParser(
description=__doc__) #创建命令行解析器
#添加命令行参数
parser.add_argument('--root-dir',default='F:/datasets/MPGCCLASS', help='根目录') #添加根目录
args = parser.parse_args() #解析命令
args = vars(args) #为了方便使用,转化为字典形式
print(args)
2.tqdm的使用 可以显示进度
import tqdm
示例:for i in tqdm(imgnames, desc='正在执行......'):
pass
3.将python的输出信息存到文件中,同时控制台照常显示
定义一个类:
import sys
class Logger(object):
def __init__(self, filename='default.log', stream=sys.stdout):
self.terminal = stream
self.log = open(filename, 'a')
def write(self, message):
self.terminal.write(message)
self.log.write(message)
def flush(self):
pass
sys.stdout = Logger('./train.out', sys.stdout)
sys.stderr = Logger('./train.err', sys.stderr)
# 示例:
for i in range(100):
print(i)