SRGAN代码逐行精读并注释——CSDN第一人

SRGAN代码解读

SRGAN代码逐行解读并注释——CSDN第一人!

本人开创了CSDN的先河,实属CSDN的第一人,之前在CSDN查找过有关SRGAN代码的解读文章,奈何没有一篇文章做到了对SRGAN的完全解读,所以本人打算打破这尴尬的处境,对SRGAN的data_utils.py,model.py,loss.py,train.py四个文件的每行代码逐一精细解读,并做好完美的注释,简直细致的不能在细致,前无古人后无来者!
下面不多说,直接上代码解读。由于本人是第一次作代码解读,难免有错误之处,请大家指正,本人必虚心接收并修改。

model.py

import math
import torch
from torch import nn

# 深度残差网络:两部分【1】:深度残差模型 【2】:子像素卷积模型 【3】:除了深度残差模块和子像素卷积模块以外,在整个模型输入和输出部分均添加了一个卷积模块用于数据调整和增强。
class Generator(nn.Module):# Generator Network
    def __init__(self, scale_factor):
        # 上采样块数,8倍就有3个
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator, self).__init__()
        # 前面的
        # 连接卷积层和激活函数层
        ########## 【1】:深度残差模块  ########## 作用:进行高效的特征提取,可以在一定程度上削弱图像噪点。
        self.block1 = nn.Sequential(
            # 3个通道,64个卷积核,卷积核大小为9,需要扩充
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()# Parametric ReLU
        )
        # 中间的
        # 残差块
        self.block2 = ResidualBlock(64)# 解释k9n64s1是什么意思:每个卷积层对应的核大小(k)、特征映射数(n)和步长(s)
        self.block3 = ResidualBlock(64)# 两个具有小3*3核和64个特征映射的卷积层,然后使用批归一化层和参数化作为激活函数
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            # BN层
            nn.BatchNorm2d(64)
        )
        # 最后的
        # 上采样层
        ########## 【2】:子像素卷积模型  ########## 作用:用来放大图像尺寸
        #说明:这里有两个子像素卷积模块,每个子像素卷积模块使得输入图像放大2倍,因此这个模型最终可以将图像放大4倍
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        block8 = self.block8(block1 + block7)

        return (torch.tanh(block8) + 1) / 2

# 判别器,较为简单,VGG在计算损失时使用,这里没有。
class Discriminator(nn.Module):# Discriminator Network
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        batch_size = x.size(0)
        return torch.sigmoid(self.net(x).view(batch_size))

# 残差块的定义
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        return x + residual

# 用pixelshuffle进行上采样,详情参考ESPCN
class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x

train.py

import argparse
import os
from math import log10

import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm #进度条

import pytorch_ssim
from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
from loss import GeneratorLoss
from model import Generator, Discriminator

'''
argparse是一个Python模块:命令行选项、参数和子命令解析器。
主要有三个步骤:
【1】:创建 ArgumentParser() 对象
【2】:调用 add_argument() 方法添加参数
【3】:使用 parse_args() 解析添加的参数
'''

# 【1】:创建解析器  作用:ArgumentParser 对象包含将命令行解析成 Python 数据类型所需的全部信息。
parser = argparse.ArgumentParser(description='Train Super Resolution Models')#训练超分辨模型

# 【2】:添加参数  ''里的是参数名,default参数对应的默认值,type是值的类型,help是参数说明
parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')# 设置参数:【训练图像裁剪大小】

parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],help='super resolution upscale factor')
#设置参数:【放大倍数】,通过修改default的值来设置放大倍数,注意由于choices=[2, 4, 8],这里放大倍数只能设置为2或4或8

parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')# 设置参数:【训练轮数】


if __name__ == '__main__':
    # 【3】:解析参数
    opt = parser.parse_args()#获得训练时的参数,训练时有三个参数,三个参数详情如上面所示,可按住Ctrl键然后鼠标移动到parser位置点击进入

    # 提取opt(选项器)中设置的参数,设定为常量
    CROP_SIZE = opt.crop_size# 将设置好的【裁剪大小】参数赋值给CROP_SIZE,方便后续调用
    UPSCALE_FACTOR = opt.upscale_factor# 将设置好的【放大倍数】参数赋值给UPSCALE_FACTOR,方便后续调用
    NUM_EPOCHS = opt.num_epochs# 将设置好的【训练轮数】参数赋值给NUM_EPOCHS,方便后续调用

    # 创建数据集,指定数据集所在路径,裁剪大小和放大因子
    # 训练数据集
    train_set = TrainDatasetFromFolder('./data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    # 验证数据集
    val_set = ValDatasetFromFolder('./data/DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)

    # 加载数据集,使用loader,从训练集中,一次性处理一个batch的文件(批量加载器)
    '''
        num_workers:int类型,表示用多少个子进程加载数据,进程越多加载速度越快。默认值为0,0表示数据将在主进程中加载
        batch_size:int类型,表示每一次batch记载多个样本,默认值为1
        shuffle:bool类型,设置为True时会在每个epoch重新打乱数据(默认: False).
    '''
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=16, shuffle=True)#原本:batch_size=64,根据GPU性能调整batch_size大小
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    # 创建生成器实例 netG
    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))# 显示netG网络的所有参数和
    '''
    sum(param.numel() for param in netG.parameters())的意思是获得这个网络的所有参数和
    
    parameters()函数:
    作用:返回一个生成器(迭代器),生成器每次生成的是Tensor类型的数据。
    
    numel()函数:
    作用:查看一个张量有多少元素,张量就是Tensor类型的数据
    
    为什么要求网络的所有参数和?是为了写论文?
    '''

    # 创建辨别器实例 netG
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))# 显示netD网络的所有参数和

    # 实例化【生成损失函数模型】
    generator_criterion = GeneratorLoss()# criterion是标准的意识

    if torch.cuda.is_available():# 判断cuda是否可用,如果可用,就利用GPU进行训练,注意,只有1、网络模型,2、数据(输入、标注),3、损失函数这三个能利用GPU训练
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    # 构建一个优化器optimizer,传入模型需要优化的参数和学习率(参数必须要写,学习率没有可不写),使用一个参数优化算法进行优化,调用step()可进行一次模型参数优化
    optimizerG = optim.Adam(netG.parameters())
    '''
    以上面代码为例,构建的优化器名字为optimizerG,这里需要优化的参数是模型netG的参数,使用的参数优化算法为Adam
    Adam的特点有:
    1、结合了Adagrad善于处理稀疏梯度和RMSprop善于处理非平稳目标的优点;
    2、对内存需求较小;
    3、为不同的参数计算不同的自适应学习率;
    4、也适用于大多非凸优化-适用于大数据集和高维空间。
    '''
    optimizerD = optim.Adam(netD.parameters())

    # 结果集:loss score psnr(峰值信噪比) ssim(结构相似性)
    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}# ??? loss和score是什么

############################前面是把想要的列出来后面才是开始执行,用训练集##############################

    # 开始训练,一次epoch跑一趟训练集
    for epoch in range(1, NUM_EPOCHS + 1):
        '''
        恢复训练相关事项:
        如果想恢复训练,直接修改上面的代码即可,for epoch in range(?,NUM_EPOCHS + 1),将?号处改为想要恢复训练的开始轮数即可
        '''

        # 加载进度条
        train_bar = tqdm(train_loader)
        # tqdm 是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。

        # 初始化参数,为后面训练方便计算参数
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

        # 假如要恢复训练,要加载之前的权重
        # netG.Load_state_dict(torch.load('netG_epoch_4_75.pth'))
        # netD.Load_state_dict(torch.load('netD_epoch_4_75.pth'))

        # 进入训练模式
        netG.train()
        netD.train()
        '''
        train()函数:
        如果模型中有BN层(Batch Normalization)和 Dropout,需要在训练时添加model.train()。
        model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。
        
        什么是Dropout:
        这是一种算法,用于防止过拟合,通过阻止特征检测器的共同作用来提高神经网络的性能。
        
        Dropout可以作为训练深度神经网络的一种trick供选择。在每个训练批次中,通过忽略一半的特征检测器(让一半的隐层节点值为0),可以明显地减少过拟合现象。
        这种方式可以减少特征检测器(隐层节点)间的相互作用,检测器相互作用是指某些检测器依赖其他检测器才能发挥作用。

        Dropout说的简单一点就是:我们在前向传播的时候,让某个神经元的激活值以一定的概率p停止工作,这样可以使模型泛化性更强,因为它不会太依赖某些局部的特征

        过拟合与欠拟合:        
        '''

        # 训练细节
        for data, target in train_bar:# train_bar:进度条
            g_update_first = True
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size# 计算running_results中的batch_sizes参数
    
            ############################
            # (1) Update D network: maximize D(x)-1-D(G(z))
            # 最大化判别器判别原图(HR)概率,最小化生成图(SR)判别概率
            ###########################

            # HR
            real_img = Variable(target)# target是HR图片,real_img是真实高分辨率图像
            if torch.cuda.is_available():
                real_img = real_img.cuda()
            '''
            Variable()函数:
            autograd.Variable 是包的核心类。它包装了张量,并且支持几乎所有的操作。
            一旦你完成了你的计算, 就可以调用 .backward() 方法 来自动计算所有的梯度。
            你还可以通过 .data 属性来访问原始的张量,而关于该 variable(变量)的梯度会被累计到 .grad上去。
            
            Pytorch都是由Tensor(Tensor 是一个多维矩阵)计算的,而Tensor里面的参数都是Variable的形式。 
            如果用Variable计算的话,那返回的也是一个同类型的Variable。这正好就符合了反向传播,参数更新的属性。

            备注:Tensor不能反向传播,Variable可以反向传播。
            '''

            # LR
            z = Variable(data)# data是LR图片???,z是低分辨率图像
            if torch.cuda.is_available():
                z = z.cuda()

            # SR
            fake_img = netG(z)# 低分辨率图像(z)通过生成网络(netG)生成的虚假高分辨率图像(fake_img)

            '''
            GAN原理
            它由两部分组成
            (1):Generator生成器,它是一个深度神经网络,输入一个低维vector,输出高维vector(图片或文本或语音)
            (2):Discriminator判别器,它也是一个深度神经网络,输入一个高维vector(图片或文本或语音),输出一个标量。标量越大,代表输入图片
                 (或文本语音)越真实。
                 
            关于real_img(真实高分辨率图像)、z(低分辨率图像)、fake_img(虚假高分辨图像)的说明:
            生成器(netG)生成新图片,从而可以骗过判别器(netD)。判别器也在不断迭代进化,努力识别越来越接近真实的假图片(fake_img)。
            通过二者对抗学习,最终生成器生成的假图片(fake_img)越来越像真实图片(real_img),而判别器越来越能区分和真实图片(fake_img)很接近
            的假图片(real_img)。 二者能力在迭代过程中,都可以得到大幅提升。
            '''


            '''
            在用pytorch训练模型时,通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward()和optimizer.step()三个函数
            这三个函数的作用是
            (1):先将梯度清零:optimizer.zero_grad():
            (2):然后反向传播计算得到每个参数的梯度值:loss.backward():
            (3):最后通过梯度下降执行一步参数更新:optimizer.step():

            '''
            # 梯度清零
            netD.zero_grad()
            '''
            为什么Pytorch每一轮batch都需要设置optimizer.zero_grad?
            根据pytorch中backward()函数的计算,当网络参量进行反馈时,梯度是累积计算而不是被替换,
            但在处理每一个batch时并不需要与其他batch的梯度混合起来累积计算,因此需要对每个batch调用一遍zero_grad()将参数梯度置0.。

            损失函数表示预测值与实际值之间的误差。
            
            梯度下降:梯度(导数)下降就是用来求损失函数(误差函数)最小值时对应的自变量取值
            '''

            # 反向传播过程,对辨别器输出的标量取平均值
            real_out = netD(real_img).mean()# mean()函数:求平均值
            fake_out = netD(fake_img).mean()

            # 计算损失
            d_loss = 1 - real_out + fake_out

            # 反向传播计算梯度
            d_loss.backward(retain_graph=True)

            # 进行参数优化
            optimizerD.step()
            '''
            optimizer.step()函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,
            所以在执行optimizer.step()函数前应先执行loss.backward()函数来计算梯度。那么为什么optimizer.step()需要放在每一个batch训练中,
            而不是epoch训练中,这是因为现在的mini-batch训练模式是假定每一个训练集就只有mini-batch这样大,
            因此实际上可以将每一次mini-batch看做是一次训练,一次训练更新一次参数空间,因而optimizer.step()放在这里。
            
            注意:optimizer只负责通过梯度下降进行优化,而不负责产生梯度,梯度是tensor.backward()方法产生的。
            '''
    
            ############################
            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            # 最小化生成网络中SR被认出概率、感知损失(VGG计算)、图像损失(MSE)、平滑损失
            ###########################

            # 梯度损失
            netG.zero_grad()

            ## The two lines below are added to prevent runetime error in Google Colab ##
            ## 翻译:添加下面两行代码是为了防止在谷歌Colab中运行时出错 ##
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            #计算损失及反向传播
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            
            fake_img = netG(z)# 低分辨率图像(z)通过生成网络(netG)生成的虚假高分辨率图像(fake_img)
            fake_out = netD(fake_img).mean()# 对辨别器输出的标量取平均值

            optimizerG.step()

            # loss for current batch before optimization
            # 优化之前的当前批次的损失
            running_results['g_loss'] += g_loss.item() * batch_size# g_loss = generator_criterion(fake_out, fake_img, real_img)
            running_results['d_loss'] += d_loss.item() * batch_size# d_loss = 1 - real_out + fake_out real_img/fake_img通过判别器的差距
            running_results['d_score'] += real_out.item() * batch_size# real_img通过判别器的值
            running_results['g_score'] += fake_out.item() * batch_size# fake_img通过判别器的值
            '''
                item()函数:
                作用:将一个Tensor变量转换为python标量(int float等)常用于用于深度学习训练时,将loss值转换为标量并加,
                以及进行分类任务,计算准确值值时需要
            '''

            # 描述进度和损失函数,得分函数的平均值,%.4f代表输出小数位为4的浮点数
            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))

            # 以训练时弹出的训练窗口里的具体页面显示更能看懂上面代码:例如
            # [1/100] Loss_D: 0.7932 Loss_G: 0.0296 D(x): 0.4958 D(G(z)): 0.2523: 100%|██████████| 50/50 [11:03<00:00, 13.27s/it]


########################################进入eval模式(测试模式参数固定,只有前向传播)用测试集###################################

        # 测试模式,无需更新网络
        netG.eval()
        '''
        如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。
        model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。
        对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

        训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。
        这是model中含有BN层和Dropout所带来的的性质。

        '''
        # 模型保存
        out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
        '''
            str()函数:
            将对象转化为适于人阅读的形式。是一种Python内置函数
            返回值:一个对象的string格式
            str(UPSCALE_FACTOR)的意思是将int类型数据UPSCALE_FACTOR转换为string数据类型,例如UPSCALE_FACTOR为4,
            则str(UPSCALE_FACTOR)变为'4','training_results/SRF_' + str(UPSCALE_FACTOR) + '/'变为
            'training_results/SRF_4/',这是一个路径。
        '''
        # 如果路径不存在则创建一个路径
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        '''
        os.path.exists()函数:
        作用:判断括号里的文件是否存在的意思,括号内的可以是文件路径。
        os.makedirs()函数:
        作用:创建递归的目录树,可以是相对路径或者绝对路径.
        '''

        # 参数计算
        with torch.no_grad():
            # 加载进度条
            val_bar = tqdm(val_loader)# 用验证集进行计算MSE,输出图像

            # 初始化参数
            valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
            val_images = [] # 声明一个空列表

            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                '''
                size()函数:
                作用:是用来统计矩阵元素个数,或矩阵某一维上的元素个数的函数。

                例:.size(a, axis=None)
                   a:输入的矩阵
                   axis:int型的可选参数,指定返回哪一维的元素个数。当没有指定时,返回整个矩阵的元素个数
                   
                   如果传入的参数只有一个,则返回矩阵的元素个数;
                   如果传入的第二个参数是0,则返回矩阵的行数;
                   如果传入的第二个参数是1,则返回矩阵的列数。
                '''

                # 已经测试过的数目
                valing_results['batch_sizes'] += batch_size

                lr = val_lr
                hr = val_hr

                if torch.cuda.is_available():
                    lr = lr.cuda()
                    hr = hr.cuda()

                sr = netG(lr)# 低分辨率图像(lr)通过生成网络(netG)修复生成的高分辨率图像(sr)

                # 计算MSE(均方误差),计算PSNR时需要用到MSE
                batch_mse = ((sr - hr) ** 2).data.mean()
                valing_results['mse'] += batch_mse * batch_size

                # 计算SSIM(结构相似性)
                batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                valing_results['ssims'] += batch_ssim * batch_size

                # 计算PSNR(峰值信噪比)
                valing_results['psnr'] = 10 * log10((hr.max()**2) / (valing_results['mse'] / valing_results['batch_sizes']))
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']

                val_bar.set_description(
                    desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        valing_results['psnr'], valing_results['ssim']))

                # 以训练时弹出的训练窗口里的具体页面显示更能看懂上面代码:例如
                # [converting LR images to SR images] PSNR: 19.0167 dB SSIM: 0.4825: 100%|██████████| 100/100 [1:06:36<00:00, 39.97s/it]

                # 通过extend把三张图连在一起 如果想提高训练速度 下面到 index += 1 可以注释
                val_images.extend(
                    [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
                     display_transform()(sr.data.cpu().squeeze(0))])
                '''
                extend()函数:
                extend()向列表尾部追加一个列表,对象必须是一个可以迭代的序列,将列表中的每个元素都追加进来,会在已存在的列表中添加新的列表内容。
                
                .cpu()
                将数据的处理设备从其他设备转到cpu上(如.cuda()拿到cpu上),不会改变变量类型,转换后仍然是Tensor变量。
                
                hr.data
                hr.data返回和hr相同的数据 tensor,这个新的tensor和原来的tensor(即hr)是共用数据的,一者改变,另一者也会跟着改变
                
                squeeze()函数:
                作用是对tensor变量进行维度压缩,去除维数为1的维度。
                squeeze(0)的意思是压缩第0维。
                '''

            # 按行拼接,按列拼接
            val_images = torch.stack(val_images)
            val_images = torch.chunk(val_images, val_images.size(0) // 15)
            '''
            stack()函数:
            沿一个新维度对输入张量序列进行连接,序列中所有张量应为相同形状;
            stack 函数返回的结果会新增一个维度,而stack()函数指定的dim参数,就是新增维度的(下标)位置。
            
            chunk()函数:
            作用:张量分块,返回一个张量列表。
            例:torch.chunk(tensor, chunks, dim=0)
              (1):tensor (Tensor) – the tensor to split
              (2):chunks (int) – number of chunks to return(分割的块数)
              (3):dim (int) – dimension along which to split the tensor(沿着哪个轴分块)
            '''

            val_save_bar = tqdm(val_images, desc='[saving training results]')# 传入str类型,作为进度条标题(类似于说明)
            index = 1
            # 以训练时弹出的训练窗口里的具体页面显示更能看懂上面代码:例如
            # [saving training results]: 100%|██████████| 20/20 [00:21<00:00,  1.09s/it]

            for image in val_save_bar:
                #每一行显示三个图像
                image = utils.make_grid(image, nrow=3, padding=5)
                utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
                '''
                保存训练生成过程中生成的图片
                out_path是存放路径,定义在该文件的第292行,这里路径设置为out_path = 'training_results/SRF_4/'
                epoch_%d_index_%d.png是图片名称 ,例如 epoch_1_index_1.png
                '''
                index += 1

################################################ 保存工作 ############################################

        # save model parameters
        # 保存训练生成好的模型,这个模型是用来测试的,
        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        '''
        state_dict()函数:
        pytorch 中的 state_dict 是一个简单的python的字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),
        将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等),从而为PyTorch模型和优化器增加了大量的模块化

        模型保存在epochs文件夹中,保存的模型名字示例为netD_epoch_4_400.pth或netG_epoch_4_400.pth
        '''

        # save loss\scores\psnr\ssim
        # 保存各项指标
        # append()函数:向列表末尾添加元素
        results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        # 10轮保存一次
        if epoch % 10 == 0 and epoch != 0:
            # 各项指标保存路径
            out_path = 'statistics/'

            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))
            '''
            DataFrame是Python中Pandas库中的一种数据结构,它类似excel,是一种二维表。
            DataFrame的单元格可以存放数值、字符串等,这和excel表很像,同时DataFrame可以设置列名columns与行名index
            '''

            # .to_csv()是DataFrame类的方法,用来文件内容的导出和添加
            data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')


'''
[1/100] Loss_D: 0.7932 Loss_G: 0.0296 D(x): 0.4958 D(G(z)): 0.2523: 100%|██████████| 50/50 [11:03<00:00, 13.27s/it]
[converting LR images to SR images] PSNR: 19.0167 dB SSIM: 0.4825: 100%|██████████| 100/100 [1:06:36<00:00, 39.97s/it]
[saving training results]: 100%|██████████| 20/20 [00:21<00:00,  1.09s/it]
[2/100] Loss_D: 0.9053 Loss_G: 0.0180 D(x): 0.7315 D(G(z)): 0.6270: 100%|██████████| 50/50 [10:09<00:00, 12.20s/it]
[converting LR images to SR images] PSNR: 19.1935 dB SSIM: 0.5281:  41%|████      | 41/100 [28:27<37:56, 38.59s/it]
'''

loss.py

import torch
from torch import nn
# 可以改成VGG19
from torchvision.models.vgg import vgg16,vgg19


class GeneratorLoss(nn.Module):# 生成损失函数
    def __init__(self):
        super(GeneratorLoss, self).__init__()

        # 下面都说再说VGG,若改成VGG19这里也要改
        vgg = vgg19(pretrained=True)
        '''
        pretrained=True的意思是加载模型中预先训练好的参数
        原因:卷积神经网络的训练是耗时的,很多场合不可能每次都从随机初始化参数开始训练网络。
             pytorch中自带几种常用的深度学习网络预训练模型,如VGG、ResNet等。
             往往为了加快学习的进度,在训练的初期我们直接加载pre-train模型中预先训练好的参数
        '''
        # 用VGG前31层(相当于全部)计算,跟论文有关,具体可以自己数
        # 具体参考 blog.csdn.net/zml194849/article/details/112790683
        # 其中的卷积层要x2,因为包括激活函数层。之后数出来32层(到第一个全连接层)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval() #这个对应VGG16,注意[0:31]是0到30层
        # loss_network = nn.Sequential(*list(vgg.features)[:9]).eval() #这个对应VGG19版本的SRGAN-VGG22,[0:9]是0到8层
        # loss_network = nn.Sequential(*list(vgg.features)[:36]).eval() #这个对应VGG19版本的SRGAN-VGG54,[0:36]是0到35层
        '''
        [:31]的意思是取0到30层的特征(注意是从第0层开始,不是从第1层开始),共31层

        对VGG网络进行特征提取,主要是用于论文中的SRGAN-VGG22以及SRGAN-VGG54(注意这里是针对VGG19),如果要使用这个,要将VGG16改成VGG19
        '''

        for param in loss_network.parameters():
            param.requires_grad = False# 屏蔽预训练模型的权重,只训练全连接层的权重

        self.loss_network = loss_network# 用于VGG19的特征提取网络

        self.mse_loss = nn.MSELoss() # MSE损失

        self.tv_loss = TVLoss() #正则化损失

    def forward(self, out_labels, out_images, target_images):
        # Adversarial Loss 对抗性损失
        adversarial_loss = torch.mean(1 - out_labels)

        # Perception Loss 感知VGG损失
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # 在假高分图像和真高分图像进行MSE均方根误差的计算

        # MSE Loss MSE损失
        MSE_loss = self.mse_loss(out_images, target_images)# 直接使用MSE在假高分图像和真高分图像之间计算

        # TV Loss 正则化损失
        tv_loss = self.tv_loss(out_images)

        return MSE_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss
        # return MSE_loss + 0.001 * adversarial_loss + 2e-8 * tv_loss
'''
        使用到SRGAN-VGG22或SRGAN-VGG54时
        用 return MSE_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss 返回
        
        未使用到SRGAN-VGG22或SRGAN-VGG54,只用到SRGAN-MSE时
        用return MSE_loss + 0.001 * adversarial_loss + 2e-8 * tv_loss 返回
'''

class TVLoss(nn.Module):# 正则化损失函数:这种正则化损失倾向于保存图像的光滑性,防止图像出来变得过于像素化。
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()# 继承
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]# 取x的第一个数
        h_x = x.size()[2]# 第三个数
        w_x = x.size()[3]# 最后一个数
        count_h = self.tensor_size(x[:, :, 1:, :])# H-X的最大值就是他
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()# 相减然后再平方,再求和
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size# 为什么要除以这个,而且tv_loss_weight是1

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]


if __name__ == "__main__":
    g_loss = GeneratorLoss()
    print(g_loss)

'''
        #以下是完整vgg16的所有层数打印结果:
        VGG16(
          (features): Sequential(
            (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU(inplace)
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): ReLU(inplace)
            (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (6): ReLU(inplace)
            (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (8): ReLU(inplace)
            (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (11): ReLU(inplace)
            (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (13): ReLU(inplace)
            (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (15): ReLU(inplace)
            (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (18): ReLU(inplace)
            (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (20): ReLU(inplace)
            (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (22): ReLU(inplace)
            (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (25): ReLU(inplace)
            (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (27): ReLU(inplace)
            (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (29): ReLU(inplace)
            (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          )

        #以下是完整vgg19的所有层数打印结果:
        VGG19(
          (features): Sequential(
            (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): ReLU(inplace=True)
            (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (6): ReLU(inplace=True)
            (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (8): ReLU(inplace=True)--------------------------------------------------------SRGAN-VGG22
            (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (11): ReLU(inplace=True)
            (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (13): ReLU(inplace=True)
            (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (15): ReLU(inplace=True)
            (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (17): ReLU(inplace=True)
            (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (20): ReLU(inplace=True)
            (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (22): ReLU(inplace=True)
            (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (24): ReLU(inplace=True)
            (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (26): ReLU(inplace=True)
            (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (29): ReLU(inplace=True)
            (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (31): ReLU(inplace=True)
            (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (33): ReLU(inplace=True)
            (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (35): ReLU(inplace=True)--------------------------------------------------------SRGAN-VGG54
            (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          )     

'''

data.utils.py

from os import listdir
from os.path import join

from PIL import Image
# torchvsion.transforms - 图像预处理包
# Compose - 把多个步骤整合一起
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
from torchvision.transforms import InterpolationMode

# 通过后缀检查是否为图片文件
def is_image_file(filename):
    # 判断文件后缀
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])

'''
    any()函数用法:
    例:any(iterable)
    any() 函数用于判断给定的可迭代参数 iterable 是否全部为 False,则返回 False,如果有一个为 True,则返回 True。
    元素除了是 0、空、FALSE 外都算 TRUE。
    
    for in 语句用法
    说明:循环结构的一种,经常用于遍历字符串、列表,元组,字典等
    例:for x in y:
    执行流程:x依次表示y中的一个元素,遍历完所有元素循环结束。
    
    endswith()函数
    描述:判断字符串是否以指定字符或子字符串结尾。
    语法:str.endswith("suffix", start, end) 或 str[start,end].endswith("suffix")   
    用于判断字符串中某段字符串是否以指定字符或子字符串结尾。
    —> bool    返回值为布尔类型(True,False)

    suffix — 后缀,可以是单个字符,也可以是字符串,还可以是元组("suffix"中的引号要省略,常用于判断文件类型)。
    start —索引字符串的起始位置。
    end — 索引字符串的结束位置。
    str.endswith(suffix)  star默认为0,end默认为字符串的长度len(str)
    
    注意:空字符的情况。返回值通常为True
    
    例:filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']
    参数说明:
    filename: 被检测的字符串
    extension: 指定的字符或者子字符串(可以使用元组,会逐一匹配),这里extension代表是一个元组['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']
    总之:只要文件后缀名为['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']其中的一个,
    any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])就是True值
'''

# 计算有效的裁剪尺寸
def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)#有效裁剪尺寸计算公式

# 训练集高分辨率图预处理函数
def train_hr_transform(crop_size):# 输入有效裁剪尺寸
    return Compose([
        # 随机裁剪
        RandomCrop(crop_size),# 在随机位置上进行尺寸大小为 crop_size*crop_size 的裁剪
        # 变为张量
        ToTensor(),# 将图像转换为张量
    ])
'''
        Compose()函数:
        用法:把多个步骤进行整合
        例:Compose([RandomCrop(crop_size),ToTensor(),])
        把RandomCrop(crop_size)以及ToTensor()这两个步骤整合在一起

        RandomCrop()函数:
        对PIL 或者 Tensor类型图片进行随机裁剪,可以接受sequence或者int两种类型的参数
        如果参数类型是sequence:(H, W)。高H,宽W,则进行尺寸大小为 H*W 的裁剪
        如果参数类型是int: (int*int),则进行尺寸大小为 int*int 的裁剪
        例:
        RandomCrop(crop_size),crop_size是int类型的参数,所以进行尺寸大小为 crop_size*crop_size的裁剪
        
        ToTensor()函数:
        (1):将PIL类型或numpy类型的图片转换为tensor类型的图片
        (2):原始数据的shape是(H x W x C),通过ToTensor()后shape会变为(C x H x W)。
        (3):数据归一化,由[0,255] -> [0,1]。 
'''


# 训练集低分辨率图预处理函数
def train_lr_transform(crop_size, upscale_factor):
    return Compose([
        # 变为图片
        ToPILImage(),# convert a tensor to PIL image

        # 整除下采样
        Resize(crop_size // upscale_factor, interpolation=InterpolationMode.BICUBIC),# 通过双三次插值法对图像进行下采样
        # Resize()函数:将输入图像的大小调整为给定的大小。
        # nterpolation:指定插值的方式,图像缩放之后,肯定像素要进行重新计算的,就靠这个参数来指定重新计算像素的方式,
        # 可按住Ctrl键然后鼠标移动到InterpolationMode位置点击进入查看有哪些插值方式
        # 注意Resize()函数返回的是裁剪后的图片

        ToTensor() # 将图像转换为张量
    ])


def display_transform():
    return Compose([
        ToPILImage(),
        Resize(400),# 将图片缩放至400*400大小,注意这里没有Crop,不是裁剪
        CenterCrop(400),# 在中心区域裁剪一个尺寸大小为400*400的图片,注意这里有Crop,是裁剪
        ToTensor()
    ])

# Train Dataset From Folder 翻译:从文件夹获取数据集
class TrainDatasetFromFolder(Dataset):# TrainDatasetFromFolder是定义的类,Dataset是继承的类
    def __init__(self, dataset_dir, crop_size, upscale_factor):# __init__是一个特殊方法用于在创建对象时进行初始化操作
    # 通过这个方法我们可以为TrainDatasetFromFolder对象绑定dataset_dir, crop_size, upscale_factor这三个属性,分别是数据集路径、裁剪大小、放大因子

        super(TrainDatasetFromFolder, self).__init__()
        '''
        super(Son, self).init()是指首先找到Son的父类(比如是类Father),然后把类Son的对象self转换为类Father的对象。
        然后“被转换”的类Father对象调用自己的init函数,其实简单理解就是子类把父类的__init__()放到自己的__init__()当中,
        这样子类就有了父类的__init__()的那些东西。
        例如:以上面的代码为例,
        TrainDatasetFromFolder类继承Dataset,
        super(TrainDatasetFromFolder, self).init()就是对继承自父类Dataset的属性进行初始化。
        而且是用Dataset的初始化方法来初始化继承的属性。
        '''
        # 获取图片列表
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]# 获得所有图像
        '''
        listdir()函数:【将文件夹下的所有东西变成一个列表】
        listdir函数在模块os中,os.listdir(path)用于获取path目录下,所有文件和文件夹的名称,并按照字母顺序存入返回值列表中。
        path可以为空,为空时默认为当前路径;可以返回隐藏文件名和隐藏文件夹名。
        路径可以是str类型或字节类型。如果路径是字节类型,返回的文件名也将是字节类型; 在其他情况下,它们将是str类型。
        例:
        listdir(dataset_dir)的意思是返回指定目录【dataset_dir】下的所有文件名,形成一个列表
        进一步listdir(dataset_dir) if is_image_file(x)的意思是返回指定目录【dataset_dir】下文件后缀名为
        ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']的所有文件名,形成一个列表
        
        join()函数:【把目录和文件名进行拼接得到一个路径】
        用来路径拼接,注意是os.path.join(),不是os.join(),不确定是哪个可以看最上面的包导入from os.path import join
        例:join(dataset_dir, x)
        把路径dataset_dir与文件名x进行拼接,得到一个路径       
        '''

        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)# 获得有效的裁剪尺寸

        # 随机裁剪原图像
        self.hr_transform = train_hr_transform(crop_size)# 高分辨率图预处理函数
        # 将裁剪好的图像处理成低分辨率的图片
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)# 低分辨率图预处理函数

    # 根据索引,迭代的读取路径和标签。因此我们需要有一个路径和标签的‘容器’,返回数据集和标签
    def __getitem__(self, index):# 传入一个索引,然后返回索引相对应图片的高清与低清图片
        # 获取该index的高清图像,同时转化得到低清图像
        hr_image = self.hr_transform(Image.open(self.image_filenames[index]))# 随机裁剪获得高清图片
        lr_image = self.lr_transform(hr_image)# 获得低分辨图(注意这里,低清图片是根据高清图片转换而来)
        return lr_image, hr_image

    #返回列表的长度
    def __len__(self):
        return len(self.image_filenames)

# 验证集
class ValDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        # 获取图片列表
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]

    def __getitem__(self, index):
        # 打开图片
        hr_image = Image.open(self.image_filenames[index])# 原始图片为高清图

        w, h = hr_image.size# 获取图片的长和宽

        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)# 获取有效的裁剪尺寸大小

        # 定义两种图片裁剪方式,一种针对高清图片,另一种针对低清图片,都是通过双三次插值法对图像进行下采样,
        # 不过两者裁剪尺寸不一样,高清图片裁剪尺寸为crop_size // self.upscale_factor,低清图片裁剪尺寸为crop_size
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=InterpolationMode.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=InterpolationMode.BICUBIC)

        hr_image = CenterCrop(crop_size)(hr_image)# 对图片hr_image的中间区域进行裁剪
        lr_image = lr_scale(hr_image)# 对图片hr_image使用lr_scale方式进行裁剪
        hr_restore_img = hr_scale(lr_image)# 对图片lr_image使用hr_scale方式进行裁剪

        # 注意返回的是图片hr_image、lr_image、hr_restore_img的张量形式
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.image_filenames)

# 测试集
class TestDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        # 有hr lr两个文件目录
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]
        self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]

    def __getitem__(self, index):#为了支持下标操作,既索引dataset[index]
        # 获取hr lr 图像
        image_name = self.lr_filenames[index].split('/')[-1]
        lr_image = Image.open(self.lr_filenames[index])
        w, h = lr_image.size
        hr_image = Image.open(self.hr_filenames[index])
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=InterpolationMode.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):# 为了知道数据集中一共有多少样本 为了使用len(dataset)
        return len(self.lr_filenames)

你可能感兴趣的:(人工智能,生成对抗网络,超分辨率重建)