DW-GAN Training Code

While reading recent dehazing papers, I studied DW-GAN, the winning entry of the CVPR 2021 dehazing challenge track. The paper introduces DW-GAN, a new dehazing network built on the 2D discrete wavelet transform. It uses a two-branch design to handle complex haze distributions and overfitting: the DWT branch applies the wavelet transform, the knowledge-adaptation branch builds on Res2Net, and a patch-based discriminator is used to reduce artifacts in the restored images.
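
To make the DWT branch concrete: a single level of the 2D discrete wavelet transform splits an image (or feature map) into one low-frequency sub-band and three high-frequency detail sub-bands, which is what gives the network explicit access to fine texture. Below is a minimal single-level 2D Haar DWT sketch in PyTorch, for illustration only; the helper name haar_dwt_2d and the sub-band layout are my own assumptions and are not taken from the authors' model.py.

import torch

def haar_dwt_2d(x):
    # Hypothetical helper, not part of the DW-GAN source code.
    # x: tensor of shape (N, C, H, W) with even H and W.
    # Returns four sub-bands, each of shape (N, C, H/2, W/2).
    x00 = x[:, :, 0::2, 0::2]  # even rows, even cols
    x01 = x[:, :, 0::2, 1::2]  # even rows, odd cols
    x10 = x[:, :, 1::2, 0::2]  # odd rows, even cols
    x11 = x[:, :, 1::2, 1::2]  # odd rows, odd cols
    ll = (x00 + x01 + x10 + x11) / 2  # low-frequency approximation
    lh = (x10 + x11 - x00 - x01) / 2  # detail: difference between odd and even rows
    hl = (x01 + x11 - x00 - x10) / 2  # detail: difference between odd and even columns
    hh = (x00 + x11 - x01 - x10) / 2  # diagonal detail
    return ll, lh, hl, hh

# Example: each sub-band of a 256x256 RGB image comes out as (1, 3, 128, 128)
# ll, lh, hl, hh = haar_dwt_2d(torch.randn(1, 3, 256, 256))

In the actual network the four sub-bands are typically concatenated along the channel dimension and fed to the DWT branch, but the exact wiring depends on the authors' model.py.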

The source code released by the authors only contains test code, so below is the training script train.py as I understand it.

import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from test_dataset import dehaze_test_dataset
from model import fusion_net,Discriminator
from torchvision.utils import save_image as imwrite
import os
import time
import re
from train_dataset import dehaze_train_dataset
from torchvision.models import vgg16
from utils_test import to_psnr, to_ssim_skimage
import torch.nn.functional as F
from perceptual import LossNetwork
from pytorch_msssim import msssim

# --- Parse hyper-parameters train --- #
parser = argparse.ArgumentParser(description='Dehaze Training')
parser.add_argument('--train_dir', type=str, default='./train/')
parser.add_argument('--output_dir', type=str, default='./trained_result/')
parser.add_argument('-train_batch_size', help='Set the training batch size', default=2, type=int)
parser.add_argument('-learning_rate', type=float, default=1e-4)
parser.add_argument('-train_epoch', help='Set the training epoch', default=500, type=int)
parser.add_argument('--model_save_dir', type=str, default='./output_result')

# --- Parse hyper-parameters test --- #
parser.add_argument('--test_dir', type=str, default='./test_image/')
parser.add_argument('-test_batch_size', help='Set the testing batch size', default=1, type=int)
parser.add_argument('--vgg_model', default='', type=str, help='load trained model or not')
args = parser.parse_args()

# --- train --- #
learning_rate = args.learning_rate
train_batch_size = args.train_batch_size
train_epoch = args.train_epoch
train_dir = args.train_dir
output_dir = args.output_dir
train_dataset = dehaze_train_dataset(train_dir)

# --- test --- #
test_dir = args.test_dir
test_dataset = dehaze_test_dataset(test_dir)
test_batch_size = args.test_batch_size

# --- output picture and check point --- #
if not os.path.exists(args.model_save_dir):
    os.makedirs(args.model_save_dir)

# --- Gpu device --- #
# Move the model and data onto the GPU
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Define the network --- #
net = fusion_net()
DNet = Discriminator()
print('# Discriminator parameters:', sum(param.numel() for param in DNet.parameters()))

# --- Multi-GPU --- #
net = net.to(device)
net = nn.DataParallel(net)
DNet = DNet.to(device)
DNet = nn.DataParallel(DNet)

#net.load_state_dict(torch.load('./weights/dehaze.pkl'))

# --- Build optimizer --- #
G_optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
scheduler_G = torch.optim.lr_scheduler.MultiStepLR(G_optimizer, milestones=[5000, 7000, 8000], gamma=0.5)
D_optim = torch.optim.Adam(DNet.parameters(), lr=learning_rate)
scheduler_D = torch.optim.lr_scheduler.MultiStepLR(D_optim, milestones=[5000,7000,8000], gamma=0.5)

# --- Load training data --- #
train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True)

# --- Load testing data --- #
test_loader = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0)


# --- Define the perceptual loss network --- #
# Use a pretrained VGG16 as the loss network to measure perceptual similarity
vgg_model = vgg16(pretrained=True)
vgg_model = vgg_model.features[:16].to(device)
for param in vgg_model.parameters():
    param.requires_grad = False
loss_network = LossNetwork(vgg_model)
loss_network.eval()

msssim_loss = msssim

# --- Load the network weight --- #
try:
    net.load_state_dict(torch.load(os.path.join(args.model_save_dir, 'epoch100000.pkl')))
    print('--- weight loaded ---')
except (FileNotFoundError, RuntimeError):
    print('--- no weight loaded ---')

# --- Start training --- #
iteration = 0
for epoch in range(train_epoch):
    start_time = time.time()
    net.train()
    DNet.train()
    print(epoch)
    for batch_idx, (name, hazy, clean) in enumerate(train_loader):
        iteration += 1
        hazy = hazy.to(device)
        frame_out = net(hazy)
        clean = clean.to(device)

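        # Discriminator update: raise the (averaged) score for clean images and
        # lower it for the dehazed output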
        DNet.zero_grad()
        real_out = DNet(clean).mean()
        fake_out = DNet(frame_out).mean()
        D_loss = 1 - real_out + fake_out

        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()

        D_loss.backward(retain_graph=True)
        net.zero_grad()
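        # Generator update: smooth L1 + perceptual (VGG16 features) + adversarial + MS-SSIM,
        # combined with the weights below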
        adversarial_loss = torch.mean(1 - fake_out)
        smooth_loss_l1 = F.smooth_l1_loss(frame_out, clean)
        perceptual_loss = loss_network(frame_out, clean)
        msssim_loss_ = -msssim_loss(frame_out, clean, normalize=True)
        total_loss = smooth_loss_l1 + 0.01 * perceptual_loss + 0.0005 * adversarial_loss + 0.5 * msssim_loss_

        total_loss.backward()
        D_optim.step()
        G_optimizer.step()

    # Step the LR schedulers once per epoch, after the optimizer updates
    scheduler_G.step()
    scheduler_D.step()

    if epoch % 5 == 0:
        print('we are testing on epoch: ' + str(epoch))
        with torch.no_grad():
            psnr_list = []
            ssim_list = []
            recon_psnr_list = []
            recon_ssim_list = []
            net.eval()
            for batch_idx, (name, hazy, clean) in enumerate(test_loader):
                clean = clean.to(device)
                hazy = hazy.to(device)
                frame_out = net(hazy)
                if not os.path.exists(output_dir + '/'):
                    os.makedirs(output_dir + '/')
                name = re.findall(r"\d+", str(name))
                #imwrite(frame_out, output_dir + '/' + str(name[0]) + '.png', range=(0, 1))  # save the image
                psnr_list.extend(to_psnr(frame_out, clean))
                ssim_list.extend(to_ssim_skimage(frame_out, clean))

            avr_psnr = sum(psnr_list) / len(psnr_list)
            avr_ssim = sum(ssim_list) / len(ssim_list)
            print(epoch, 'dehazed', avr_psnr, avr_ssim)
            frame_debug = torch.cat((frame_out, clean), dim=0)
            torch.save(net.state_dict(), os.path.join(args.model_save_dir, 'epoch' + str(epoch) + '.pkl'))
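
Assuming dehaze_train_dataset and dehaze_test_dataset can find their images under the default ./train/ and ./test_image/ directories, the script can be launched simply as python train.py; every five epochs it prints the average PSNR/SSIM on the test set and writes a checkpoint epoch<N>.pkl to --model_save_dir.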
