python 训练神经网络时的小工具

1.删除当前文件夹下的所有jpg/png

import os
files= os.listdir()

for pic in files:  # 遍历文件夹

    if pic.endswith(".png"):
        os.remove(pic)
    elif pic.endswith(".jpg"):
        os.remove(pic)

2.mac下安装XGBoost

macOS终端输入:

 conda install py-xgboost

python文件直接import:

import xgboost as xgb

3. 实验数据写入csv文件保存

可以写一个函数:

import csv
import codecs
def wcsv(fname,vgs,vds,a1s,b1s,a2s,b2s):
    # fname为“xx.csv” 即csv文件的名字
    # vgs,vds,a1s,b1s,a2s,b2s为要保存的变量
    header = ['vg', "vd", "a1", "b1", "a2","b2"]
    # header为csv文件的表头
    f=codecs.open(fname,'w+','utf-8')
    writer=csv.writer(f)
    writer.writerow(header)
    # 逐行写入变量数据
    for vg,vd,a1,b1,a2,b2 in zip(vgs,vds,a1s,b1s,a2s,b2s):
        writer.writerow((vg,vd,a1,b1,a2,b2))

4.循环在一张图上将多条曲线

import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
for i in tvg.keys(): 
    # 取列表中某条曲线的点
    x = tvg[i]
    y = gids[i]
    ax.plot(x, y) 
plt.savefig(f"origin_{index}.jpg")
#保存图片

5.为网络参数设置不同的学习率

在optimizer里直接设置(以sgd为例子):

  optimizer = torch.optim.SGD([
    # 直接改每一项的lr即可
        {'params': wnn.net1.parameters(), 'lr':1e-3},
        {'params': wnn.net2.parameters(), 'lr': 1e-3},
        {'params': wnn.a1.parameters(), 'lr': 1e-3},
        {'params': wnn.a2.parameters(), 'lr': 1e-3},
        {'params': wnn.b1.parameters(), 'lr': 1e-5},
        {'params': wnn.b2.parameters(), 'lr': 1e-3}, ],
        weight_decay=0.0005,
    )

 

6.为网络参数设置自适应学习率

torch.optim.lr_scheduler 提供了几种方法来根据 epoch 的数量调整学习率。 torch.optim.lr_scheduler.ReduceLROnPlateau 允许基于一些验证测量来降低动态学习率。

以LambdaLR为例:

# 学习率调整策略
lambda1= lambda epoch: 0.95 ** epoch
# lr_lambda=[]用于填写optimizer对应优化参数的优化策略
# 此例为八个参数都选择lambda1的递减策略
scheduler=torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=[lambda1,lambda1,lambda1,lambda1,lambda1,lambda1,lambda1,lambda1])
 

记得在优化循环中加上scheduler.step() !! 

7.模型的保存与重用

保存训练好的模型:

torch.save(model, './model.pt')

重新调用该模型:

model=torch.load('./model.pt')

8.一个优化器同时优化多个网络

可以参考5中的写法,也可以使用 itertools.chain:(以Adam为例)

optimizer=torch.optim.Adam(itertools.chain(
    wnn.net1.parameters(),
    wnn.net2.parameters(),
    wnn.a1.parameters(),
    wnn.a2.parameters(),
    wnn.b1.parameters(),
    wnn.b2.parameters(),
    wnn.bias.parameters()
    ), 
    lr=wnn.plr)

你可能感兴趣的:(学习笔记,python,python,开发语言)