pytorch训练算法的框架

写了一个 PyTorch 的训练框架，主要是为了方便训练使用，目标是验证模型对 EEG 小波图（时频图）的分类效果。觉得好用请点个赞、给个反馈；也欢迎交流机械专业转码的经验——我断断续续写了一年多，踩了不少坑，一直没什么成果。

# coding=utf8
import torch
import torch.nn as nn
import os
import numpy as np
from utils import save_data_df, save_dict
# from Separate_convolution_test.data_loader.data_loader_with_all_men_2 import DataLoaderX, KZDDataset_pic
from Separate_convolution_test.data_loader.data_loader_with_all_men_3 import DataLoaderX, KZDDataset_pic
from loss_fun.loss_function import loss_with_flo, EarlyStopping, seed_torch
from einops import rearrange, reduce, repeat
import warnings
from torch.optim import lr_scheduler
from torchvision.utils import make_grid
import tensorboardX
import time
import gc
import glob
import shutil
import datetime
import pandas as pd
from sklearn.metrics import cohen_kappa_score


def train(train_list, Net, learn_rate, batch_size,
          val_batch_size, data_saving_path, patience=100,
          experment_lable="", epoch_num=300, randseed=44, **kwargs):
    """Train ``Net`` on one cross-validation fold of an EEG-scalogram dataset.

    Runs a full train/validation loop, logs metrics and first-epoch input
    images to TensorBoard, checkpoints the best models, and applies early
    stopping on validation accuracy.

    Args:
        train_list: Image file paths handed to ``KZDDataset_pic``.
        Net: Already-constructed ``torch.nn.Module`` to train (moved to CUDA).
        learn_rate: Initial learning rate.
        batch_size: Mini-batch size of the training loader.
        val_batch_size: Mini-batch size of the validation loader.
        data_saving_path: Root directory for model/log output.
        patience: Early-stopping patience in epochs (on validation accuracy).
        experment_lable: Free-form tag appended to the output directory names.
        epoch_num: Maximum number of epochs.
        randseed: Seed forwarded to ``seed_torch`` for reproducibility.
        **kwargs: Must contain ``kwargs["kwargs"]``, a config dict with the
            required keys ``dataset_name``, ``trip_lable``, ``K``, ``Ki`` and
            optional keys ``label_smoothing``, ``biza``, ``weight_decay``,
            ``T_max``, ``eta``, ``model_save``, ``scheduler``, ``optimizer``,
            ``Loss_weight``, ``add_noise``, ``val_add_noise``, ``noise_std``,
            ``step_size``, ``milestones``.

    Returns:
        dict: Per-fold result record (best accuracies/losses and their epochs).

    Raises:
        ValueError: If the configured optimizer name is not recognized.
    """
    seed_torch(randseed)

    now_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

    cfg = kwargs["kwargs"]  # experiment configuration dict

    # BUG FIX: the original aliased the module-level `save_dict`, so every
    # fold mutated the same shared object and stale values leaked between
    # calls. Take a shallow copy instead.
    tmp_save_data_dict = dict(save_dict)

    tmp_save_data_dict["数据集名称"] = cfg["dataset_name"]
    tmp_save_data_dict["被试标签"] = cfg["trip_lable"]
    tmp_save_data_dict["实际训练epoch"] = epoch_num
    tmp_save_data_dict["交叉验证折数K"] = cfg["K"]
    tmp_save_data_dict["交叉验证批次Ki"] = cfg["Ki"]
    tmp_save_data_dict["实验时间"] = now_time
    K = cfg["K"]
    Ki = cfg["Ki"]

    # Optional hyper-parameters with their historical defaults (replaces the
    # ten copy-pasted `if key in ... else` chains of the original).
    label_smoothing = cfg.get('label_smoothing', 0.3)
    biza = cfg.get('biza', 0)
    weight_decay = cfg.get('weight_decay', 0.12)
    T_max = cfg.get('T_max', 80)            # CosineAnnealingLR half-period
    eta = cfg.get('eta', 0.001)             # eta_min = lr * eta
    model_save = cfg.get('model_save', 10)  # periodic checkpoint interval
    scheduler = cfg.get('scheduler', None)
    optimizer = cfg.get('optimizer', "Adam")
    Loss_weight = cfg.get('Loss_weight', [0.5, 0.5])
    add_noise = cfg.get('add_noise', False)
    val_add_noise = cfg.get('val_add_noise', False)
    noise_std = cfg.get('noise_std', 0.05)

    train_dataset = KZDDataset_pic(im_list=train_list, ki=Ki, K=K, type='train')  # training split
    val_dataset = KZDDataset_pic(im_list=train_list, ki=Ki, K=K, type='val')      # validation split

    train_loader = DataLoaderX(dataset=train_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               pin_memory=True,
                               drop_last=True)
    val_loader = DataLoaderX(dataset=val_dataset,
                             batch_size=val_batch_size,
                             shuffle=False,
                             pin_memory=True)

    lr = learn_rate
    time_pre = time.time()
    device = torch.device('cuda')  # training assumes a CUDA device is present
    net = Net.to(device)

    loss_func = loss_with_flo(label_smoothing=label_smoothing, biza=biza, weight=Loss_weight)

    if optimizer == "Adam":
        optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer == "AdamW":
        optimizer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=True)
    elif optimizer == "SGD":
        optimizer = torch.optim.SGD(net.parameters(), momentum=0.9, lr=lr, weight_decay=weight_decay, nesterov=True)
    elif optimizer == "RMSprop":
        optimizer = torch.optim.RMSprop(net.parameters(), momentum=0.9, lr=lr, weight_decay=weight_decay)
    else:
        # Fail fast: the original silently kept the string and crashed later
        # at optimizer.zero_grad() with an opaque AttributeError.
        raise ValueError("unknown optimizer: {!r}".format(optimizer))

    # Plain cross-entropy for validation (no label smoothing / extra terms).
    vaild_loss = nn.CrossEntropyLoss()

    if scheduler == "CosineAnnealingLR":
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=lr * eta)  # cosine annealing
    elif scheduler == "StepLR":
        step_size = cfg.get('step_size', 30)
        scheduler = lr_scheduler.StepLR(optimizer, step_size, gamma=0.5, last_epoch=-1)  # fixed-interval decay
    elif scheduler == "MultiStepLR":
        milestones = cfg.get('milestones', [100, 300])
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.5, last_epoch=-1)  # milestone decay

    early_stop = EarlyStopping(mode='max', patience=patience)

    step_n = 0  # per-epoch counter used as the TensorBoard global_step

    run_name = "lr=" + str(lr) + experment_lable + now_time
    model_path = os.path.join(data_saving_path, "models", run_name)  # checkpoint directory
    log_path = os.path.join(data_saving_path, "log", run_name)       # TensorBoard directory
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)

    writer = tensorboardX.SummaryWriter(log_path)

    def log_input_grid(tag, inputs, labels, step, normalize=False):
        # Shared image logger (the original duplicated this in four places):
        # stack each sample's channels vertically and append a 30-px label
        # band so the class can be read straight off the TensorBoard image.
        pic = rearrange(inputs, "b c h w -> b 1 (c h) w")
        w = inputs.size(3)
        labels_pic = repeat(labels, " b -> b c h w ", c=1, h=30, w=w)
        pic = torch.cat((pic, labels_pic), 2)
        writer.add_image(tag, make_grid(pic, nrow=4, normalize=normalize, padding=2), step)

    prev_time = time.time()

    all_num = 0      # global training-batch counter (image-log step)
    all_num_val = 0  # global validation-batch counter (image-log step)
    print("pre_use_time = ", time.time() - time_pre)
    for epoch in range(epoch_num):
        y_true_train = np.empty(shape=(0, 1))
        y_true_test = np.empty(shape=(0, 1))
        y_pred_train = np.empty(shape=(0, 1))
        y_pred_test = np.empty(shape=(0, 1))
        time_start = time.time()
        train_acc_num = 0
        test_acc_num = 0
        train_loss = 0
        test_loss = 0
        input_num = 0

        net.train()  # enable dropout / BN updates once per epoch (was per batch)
        for inputs, labels in train_loader:
            input_num += inputs.size(0)
            inputs = inputs.to(device, non_blocking=True).float()
            labels = labels.to(device, non_blocking=True)

            if add_noise:
                # On-device Gaussian noise; avoids the CPU tensor allocation
                # plus host->device copy of the original torch.normal(...)
                # construction while keeping the same N(0, noise_std) law.
                outputs = net(inputs + torch.randn_like(inputs) * noise_std)
            else:
                outputs = net(inputs)
            if epoch == 0:
                log_input_grid("训练集输入", inputs, labels, all_num)

            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pred = outputs.detach().argmax(dim=1)
            y_pred_train = np.append(y_pred_train, pred.cpu().numpy())
            y_true_train = np.append(y_true_train, labels.cpu().numpy())
            train_acc_num += pred.eq(labels).sum().item()
            train_loss += loss.item()
            all_num += 1

        train_acc = 100.0 * train_acc_num / input_num
        kappa_train = cohen_kappa_score(y_true_train, y_pred_train)
        step_n += 1

        writer.add_scalar("train loss", train_loss, global_step=step_n)
        writer.add_scalar("train correct",
                          train_acc, global_step=step_n)
        writer.add_scalar("train kappa value", kappa_train, global_step=step_n)

        input_num = 0
        net.eval()  # freeze dropout / BN statistics for validation
        # BUG FIX: the original validated without no_grad(), building autograd
        # graphs (and wasting memory) for every validation batch.
        with torch.no_grad():
            for inputs, labels in val_loader:
                input_num += inputs.size(0)
                inputs = inputs.to(device, non_blocking=True).to(torch.float32)
                labels = labels.to(device, non_blocking=True)

                if val_add_noise:
                    outputs = net(inputs + torch.randn_like(inputs) * noise_std)
                    if epoch == 0:
                        # Only the noisy-validation grid was normalized originally.
                        log_input_grid("验证集输入", inputs, labels, all_num_val, normalize=True)
                else:
                    outputs = net(inputs)
                    if epoch == 0:
                        log_input_grid("验证集输入", inputs, labels, all_num_val)

                loss = vaild_loss(outputs, labels)
                pred = outputs.argmax(dim=1)
                y_pred_test = np.append(y_pred_test, pred.cpu().numpy())
                y_true_test = np.append(y_true_test, labels.cpu().numpy())
                test_acc_num += pred.eq(labels).sum().item()
                test_loss += loss.item()
                all_num_val += 1

        test_acc = 100.0 * test_acc_num / input_num
        kappa_test = cohen_kappa_score(y_true_test, y_pred_test)
        writer.add_text('y_pred_test', ','.join(str(tmp) for tmp in y_pred_test), global_step=epoch)
        writer.add_text('y_true_test', ','.join(str(tmp) for tmp in y_true_test), global_step=epoch)
        writer.add_scalar("test kappa value", kappa_test, global_step=step_n)
        writer.add_scalar("team" + str(0) + "val loss", test_loss, global_step=epoch)
        writer.add_scalar("team" + str(0) + "val correct", test_acc, global_step=epoch)

        # ---- result bookkeeping and checkpointing ------------------------

        if epoch <= 3:
            # Warm-up epochs: seed the baselines, no checkpoints yet.
            best_train_correct = train_acc
            min_train_loss = train_loss
            best_test_correct = test_acc
            min_test_loss = test_loss
            tmp_save_data_dict["训练集max_ACC"] = train_acc
            tmp_save_data_dict["训练集最佳epoch(acc)"] = epoch
            tmp_save_data_dict["训练集min_LOSS"] = train_loss

            tmp_save_data_dict["验证集max_ACC"] = test_acc
            tmp_save_data_dict["验证集min_LOSS"] = test_loss
            tmp_save_data_dict["验证集最佳epoch(acc)"] = epoch

        if best_train_correct < train_acc and epoch > 3:
            tmp_save_data_dict["训练集max_ACC"] = train_acc
            tmp_save_data_dict["训练集最佳epoch(acc)"] = epoch
            torch.save(net.state_dict(), "{}/best_train_correct.pth".format(model_path))
            best_train_correct = train_acc

        if min_train_loss > train_loss and epoch > 3:
            tmp_save_data_dict["训练集min_LOSS"] = train_loss
            torch.save(net.state_dict(), "{}/min_train_loss.pth".format(model_path))
            min_train_loss = train_loss

        if best_test_correct < test_acc and epoch > 3:
            tmp_save_data_dict["验证集max_ACC"] = test_acc
            tmp_save_data_dict["验证集最佳epoch(acc)"] = epoch
            torch.save(net.state_dict(), "{}/best_test_correct.pth".format(model_path))
            best_test_correct = test_acc

        if min_test_loss > test_loss and epoch > 3:
            tmp_save_data_dict["验证集min_LOSS"] = test_loss
            torch.save(net.state_dict(), "{}/min_test_loss.pth".format(model_path))
            min_test_loss = test_loss

        if scheduler is not None:
            scheduler.step()

        if epoch % model_save == 0 and epoch != 0:
            # Rolling checkpoint every `model_save` epochs.
            torch.save(net.state_dict(), "{}/last_turn.pth".format(model_path))

        if early_stop(test_acc):
            print("early_stop time = ", epoch + 1)
            tmp_save_data_dict["实际训练epoch"] = epoch + 1
            break

        time_end = time.time()
        epoch_left = epoch_num - epoch
        time_left = datetime.timedelta(seconds=epoch_left * (time.time() - prev_time))
        prev_time = time.time()

        # BUG FIX: "tess_kappa" typo corrected to "test_kappa" in the log line.
        print(
            "[Epoch %d/%d][lr:%.7f] [train_loss:%f] [train_acc:%3.3f%%][train_kappa:%3.3f] [test_loss: %f] [test_acc: %3.3f%%] [test_kappa:%3.3f][used time:%.3f] [early_stop_cout:%d]ETA: %s"
            % (epoch, epoch_num, optimizer.state_dict()['param_groups'][0]['lr'], train_loss, train_acc, kappa_train,
               test_loss,
               test_acc,
               kappa_test,
               time_end - time_start, early_stop.get_time(),
               time_left
               )
        )
    writer.close()
    print(tmp_save_data_dict)
    return tmp_save_data_dict


if __name__ == '__main__':
    # Smoke test: K-fold cross-validation over pre-computed CWT scalograms.
    train_list = glob.glob("..\\..\\CWT_PIC\\cmor3-3_300^100\\*\\EEG_C3\\*.jpg")  # experiment data
    data_saving_path = "F:\\train_function_Test"
    experment_lable = "function_test"  # experiment tag
    tmp_save_data_df = save_data_df  # result-table template from utils
    dataset_dis = {"dataset_name": "BCI competition IV dataset 2b",
                   "trip_lable": "B001",
                   "K": 10,
                   "Ki": 2,
                   }
    # NOTE(review): SEALNet is referenced below but never imported in this
    # file — confirm which module provides it before running.
    for fold in range(dataset_dis["K"]):
        dataset_dis["Ki"] = fold  # rotate the held-out fold
        data = train(train_list, SEALNet(), 0.0004, 128, 128, data_saving_path, epoch_num=10, kwargs=dataset_dis,
                     experment_lable=experment_lable)
        # BUG FIX: DataFrame.append() was removed in pandas 2.0; build a
        # one-row frame from the result dict and concatenate instead.
        tmp_save_data_df = pd.concat([tmp_save_data_df, pd.DataFrame([data])], ignore_index=True)

    print(tmp_save_data_df)
    excel_path = os.path.join(data_saving_path, "excel")
    os.makedirs(excel_path, exist_ok=True)
    tmp_save_data_df.to_excel(os.path.join(excel_path, "test.xlsx"), index=True, sheet_name='Sheet1', header=True)

相关标签：pytorch、算法、深度学习