本文旨在简单介绍wandb在卷积神经网络训练过程中的一些基础设置,可以快速入门并使用wandb记录自己的训练日志,方便后续的实验复现。如果有什么问题欢迎在评论区讨论。
登录注册账号:https://wandb.ai/
登陆注册账号后创建一个project,会得到一个账号相关联的key
可以认为该key是一个身份码,与自己注册的账号绑定
(学习教程:https://docs.wandb.ai/quickstart)
#在命令行对应环境中执行
pip install wandb
#运行下面代码后会要求输入key,即输入步骤一的key
wandb login
#如果想要换账号:
#wandb login --relogin
运行wandb login --relogin
代码会出现:
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:
此时随意输入40位数字即可覆盖之前的账号信息,以保证在其他电脑上使用时实验数据的安全。或者换一个账号的key进行登录
import os
os.environ['WANDB_MODE'] = 'offline'
调用上面语句可以在调试的时候不传数据去云端,但是运行日志会保留在本地?
此功能一般在调试bug的时候不想在云端保存训练数据时使用。
import wandb
wandb.init(project="test", entity="shuttle",name="test1")
其中的test对应wandb云端的project,entity对应的云端账号,name=tets1 对应的云端的运行run。
记录超参数一般在网络训练之前,大多数人的用法是直接记录argparse中的所有参数,见方式③
① 方式1
wandb.config = {
"learning_rate": 0.001,
"epochs": 100,
"batch_size": 128
}
wandb.config.update()
② 方式2
#记录训练的超参数
config = wandb.config
config = {
"learning_rate": opt.lr,
"epochs": opt.epoch,
"batch_size": opt.batch_size
}
#name为运行的空间名称,可以有同名
wandb.init(project="study",entity="shuttle",name="train_face",config=config)
后续添加超参数可以使用如下语句:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wandb.config.update({"device" : device , "model" : "Vgg16"})
③ 方式3
# CASIA_options()为项目执行参数表
options = CASIA_options()
opt = options.parse()
wandb.config.update(opt)
使用这种方法,要在前面将config放进init函数里面,即:
wandb.init(project="study",entity="shuttle",name="train_face",config=config)
④ 更多方法详情参考网址:https://docs.wandb.ai/guides/track/config
在正常打印日志的地方添加下列语句:
if(step%100 == 0):
print("step:{} all_step:{} loss:{}".format(step,int(train_size/4),loss))
wandb.log({"loss": loss,
"epoch":epoch})
wandb.log
输入的是字典类型
这里如果需要保存图片可以使用下列方法:
Img = wandb.Image(image, caption="epoch:{}".format(epoch) + string)
wandb.log({"epoch "+str(epoch): Img})
这里的image为需要输入到wandb的图片,epoch为训练的epoch数,string为自定义图片的名称。
下面字典的键值对中键表示的是wandb显示图片上面的标题,值就是图片。
在所有epoch运行完之后,添加下列语句即可保存模型并关闭wandb:
wandb.save('model.h5')
wandb.finish()
在获得模型后面加就可以了
Vgg_model = model.vgg16(pretrained=True,progress=True)
#获得训练过程中权值梯度的直方图
wandb.watch(Vgg_model)
随便在一个有wandb并login的主机新建py文件,运行,其中,project、entity、id需与项目保持一致。
import wandb
wandb.init(project="Net", entity="shuttle", id="fgor***z")
wandb.finish()
import wandb
api = wandb.Api()
run = api.run("shuttle/Net/fgor***z")
if run.state == "finished":
for i, row in run.history().iterrows()
print(row["_timestamp"], row["accuracy"])
使用wandb可以导出对应的run信息
wandb sync "wandb/run-20220428_220148-1yeps3ra/"
注意:这里得"wandb/run-20220428_220148-1yeps3ra/"地址需要修改为你电脑里的相对应的地址。"1yeps3ra"表示的是项目id地址。
这里主要是因为使用wandb训练时经常是出现网络中断而导致本地日志没有全部上传云端,所以在云端只有部分数据,可以使用这条命令进行同步,将云端数据补全。
wandb关于本地的一些日志设置方法可以参考下述链接:
https://zoeyuchao.github.io/2020/10/15/Weights-&-Biases%E4%BD%BF%E7%94%A8%E6%8C%87%E5%8D%97.html#3run
wandb对于视觉的处理:https://docs.wandb.ai/examples
参考链接:https://blog.csdn.net/qq_40507857/article/details/112791111