风格迁移篇--- stargan代码详解以及论文解读翻译

目录

  • 引用
  • 说明
  • 一、data_loader.py文件
    • class CelebA
      • def preprocess(self)
      • def __getitem__()
      • def __len__(self)
      • def get_loader()
  • 二、logger.py文件
    • class Logger(object)
      • def __init__(self, log_dir)
      • def scalar_summary()
  • 三、main.py
  • 四、module.py文件
    • 残差块
    • 生成器
    • 鉴别器
  • 五、solver.py文件
    • 定义一个类class Solver(object)
      • def __init__(self, celeba_loader, rafd_loader, config):
      • def build_model(self)
      • def print_network()
      • def restore_model(self, resume_iters)
      • def build_tensorboard(self):
      • def update_lr(self, g_lr, d_lr):
      • def reset_grad(self):
      • def denorm(self, x):
      • def gradient_penalty(self, y, x):
      • def label2onehot(self, labels, dim)
      • def create_labels()
      • def classification_loss(self, logit, target, dataset='CelebA')
      • def train(self)
        • 1. Preprocess input data预处理输入数据
        • 2. Train the discriminator训练判别器
        • 3. Train the generator训练生成器
        • 4. Miscellaneous其他的
      • def train_multi(self):
        • 1. Preprocess input data预处理输入数据
        • 2. Train the discriminator
        • 3. Train the generator
        • 4. Miscellaneous
      • def test(self)
      • def test_multi(self)

引用

论文地址:https://ieeexplore.ieee.org/document/8579014/
论文代码(pytorch):https://github.com/yunjey/stargan
论文翻译与解读:https://blog.csdn.net/m0_61985580/article/details/125766783?spm=1001.2014.3001.5501

说明

1、此文仅作为学习笔记,注释可能会有一些偏差,如有注释错误欢迎留言更正。
2、如有使用此文,请标注出处。

一、data_loader.py文件

class CelebA

首先是定义一个CelebA的一个类。
在CelebA里面包含def preprocess(self)、def __getitem__(self, index)、def __len__(self)等方法;
另外,data_loader.py中还定义了一个模块级函数def get_loader()(它不属于CelebA类)。

def preprocess(self)

    def preprocess(self):
        """Preprocess the CelebA attribute file.

        Reads ``self.attr_path``. Its second line lists every attribute name;
        each line from the third on holds an image filename followed by one
        value per attribute. The method fills ``self.attr2idx`` and
        ``self.idx2attr`` and appends ``[filename, label]`` pairs to
        ``self.test_dataset`` / ``self.train_dataset``, where ``label`` is a
        list of booleans (one per entry of ``self.selected_attrs``) that is
        True when the attribute value is the string ``'1'``.
        """
        # Read through a context manager so the file handle is always closed
        # (the previous bare open() call never closed it).
        with open(self.attr_path, 'r') as attr_file:
            lines = [line.rstrip() for line in attr_file]

        # Second line of the file: whitespace-separated attribute names.
        all_attr_names = lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i     # attribute name -> column index
            self.idx2attr[i] = attr_name     # column index -> attribute name

        lines = lines[2:]
        random.seed(1234)                    # fixed seed: the split is reproducible
        random.shuffle(lines)
        for i, line in enumerate(lines):
            split = line.split()
            filename = split[0]              # image file name
            values = split[1:]               # attribute values of this image

            # '1' becomes True, anything else (e.g. '-1') becomes False.
            label = [values[self.attr2idx[attr_name]] == '1'
                     for attr_name in self.selected_attrs]

            # NOTE: (i+1) < 2000 keeps the first 1999 shuffled entries for
            # testing; everything after that goes to the training set.
            if (i+1) < 2000:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print('Finished preprocessing the CelebA dataset...')

def __getitem__()

    def __getitem__(self, index):
        """Return the transformed image at ``index`` and its attribute label tensor."""
        # Pick the split that matches the current mode.
        if self.mode == 'train':
            samples = self.train_dataset
        else:
            samples = self.test_dataset
        name, attrs = samples[index]
        picture = Image.open(os.path.join(self.image_dir, name))
        return self.transform(picture), torch.FloatTensor(attrs)

def __len__(self)

    def __len__(self):
        """Return the image count stored in ``self.num_images``."""
        total = self.num_images
        return total

def get_loader()

构建并返回数据加载器,对数据进行水平翻转,裁剪更改图片大小等操作

def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128,
               batch_size=16, dataset='CelebA', mode='train', num_workers=1):
    """Build and return a data loader.

    Args:
        image_dir: directory holding the images.
        attr_path: attribute file, forwarded to ``CelebA`` (unused for RaFD).
        selected_attrs: attribute names, forwarded to ``CelebA`` (unused for RaFD).
        crop_size: side length of the center crop.
        image_size: size passed to ``T.Resize`` after cropping.
        batch_size: mini-batch size of the returned loader.
        dataset: ``'CelebA'`` or ``'RaFD'``.
        mode: ``'train'`` enables random horizontal flips and shuffling.
        num_workers: forwarded to the ``DataLoader``.

    Returns:
        A ``data.DataLoader`` over the selected dataset.

    Raises:
        ValueError: if ``dataset`` is neither ``'CelebA'`` nor ``'RaFD'``.
    """
    # Fail early: previously an unknown name fell through both branches and the
    # raw string itself was handed to DataLoader as the dataset.
    if dataset not in ('CelebA', 'RaFD'):
        raise ValueError("Unknown dataset {!r}; expected 'CelebA' or 'RaFD'.".format(dataset))

    transform = []
    if mode == 'train':
        transform.append(T.RandomHorizontalFlip())    # random horizontal flip (training only)
    transform.append(T.CenterCrop(crop_size))         # crop from the center
    transform.append(T.Resize(image_size))            # resize to the working resolution
    transform.append(T.ToTensor())
    # Per-channel (x - 0.5) / 0.5: maps values in [0, 1] to [-1, 1].
    transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = T.Compose(transform)

    if dataset == 'CelebA':
        dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode)
    else:  # 'RaFD'
        dataset = ImageFolder(image_dir, transform)

    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=(mode == 'train'),
                                  num_workers=num_workers)
    return data_loader

二、logger.py文件

这个文件主要用来向TensorBoard写入训练日志

class Logger(object)

object 是 Logger 继承的基类(在 Python 3 中可以省略不写)

def __init__(self, log_dir)

    def __init__(self, log_dir):
        """Create the summary writer that records into ``log_dir``."""
        writer = tf.summary.FileWriter(log_dir)
        self.writer = writer

def scalar_summary()

    def scalar_summary(self, tag, value, step):
        """Record one scalar ``value`` under ``tag`` at iteration ``step``."""
        entry = tf.Summary.Value(tag=tag, simple_value=value)
        record = tf.Summary(value=[entry])
        self.writer.add_summary(record, step)

三、main.py

这个主要是用来调参,设置配置参数

def str2bool(v):
    """Parse a command-line flag value into a bool.

    Returns True only when ``v`` equals ``'true'`` ignoring case; every other
    string yields False.
    """
    # ``in ('true')`` was a substring test against the *string* 'true' (the
    # parentheses do not make a tuple), so '', 't' or 'ru' also parsed as True.
    return v.lower() == 'true'

def main(config):
    """Create the output directories, build the data loaders and run the solver.

    Args:
        config: parsed command-line options. This function reads the four
            directory options, the loader options, ``dataset`` and ``mode``;
            the whole object is also handed to ``Solver``.
    """
    # For fast training.
    cudnn.benchmark = True

    # Create directories if not exist. exist_ok=True replaces the former
    # "check, then create" pair, which could fail if the directory appeared
    # between the two calls.
    for directory in (config.log_dir, config.model_save_dir,
                      config.sample_dir, config.result_dir):
        os.makedirs(directory, exist_ok=True)

    # Data loader. A loader stays None when its dataset is not selected.
    celeba_loader = None
    rafd_loader = None

    if config.dataset in ['CelebA', 'Both']:
        celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs,
                                   config.celeba_crop_size, config.image_size, config.batch_size,
                                   'CelebA', config.mode, config.num_workers)
    if config.dataset in ['RaFD', 'Both']:
        rafd_loader = get_loader(config.rafd_image_dir, None, None,
                                 config.rafd_crop_size, config.image_size, config.batch_size,
                                 'RaFD', config.mode, config.num_workers)

    # Solver for training and testing StarGAN.
    solver = Solver(celeba_loader, rafd_loader, config)

    # Single-dataset runs use train()/test(); 'Both' uses the *_multi variants.
    if config.mode == 'train':
        if config.dataset in ['CelebA', 'RaFD']:
            solver.train()
        elif config.dataset in ['Both']:
            solver.train_multi()
    elif config.mode == 'test':
        if config.dataset in ['CelebA', 'RaFD']:
            solver.test()
        elif config.dataset in ['Both']:
            solver.test_multi()


# Script entry point: declare every command-line option, parse them, print the
# resulting configuration and hand it to main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Model configuration.
    parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)')
    parser.add_argument('--c2_dim', type=int, default=8, help='dimension of domain labels (2nd dataset)')
    parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset')
    parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset')
    parser.add_argument('--image_size', type=int, default=128, help='image resolution')
    parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G')
    parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D')
    parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G')
    parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D')
    parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
    parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
    parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
    
    # Training configuration.
    parser.add_argument('--dataset', type=str, default='CelebA', choices=['CelebA', 'RaFD', 'Both'])
    parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size')
    parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
    parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
    parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
    parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
    parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
    parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
    # nargs sets how many command-line values the option consumes:
    # '+' means one or more, '*' means zero or more.
    parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset',
                        default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])

    # Test configuration.
    parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')

    # Miscellaneous.
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
    parser.add_argument('--use_tensorboard', type=str2bool, default=True)

    # Directories.
    parser.add_argument('--celeba_image_dir', type=str, default='data/celeba/images')
    parser.add_argument('--attr_path', type=str, default='data/celeba/list_attr_celeba.txt')
    parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train')
    parser.add_argument('--log_dir', type=str, default='stargan/logs')
    parser.add_argument('--model_save_dir', type=str, default='stargan/models')
    parser.add_argument('--sample_dir', type=str, default='stargan/samples')
    parser.add_argument('--result_dir', type=str, default='stargan/results')

    # Step size (all counted in training iterations).
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=1000)
    parser.add_argument('--model_save_step', type=int, default=10000)
    parser.add_argument('--lr_update_step', type=int, default=1000)

    config = parser.parse_args()
    print(config)
    main(config)

四、module.py文件

这个文件主要是生成器与鉴别器的网络结构以及两者的具体参数

残差块

这个是残差块的定义

class ResidualBlock(nn.Module):
    """Residual block: conv -> instance norm -> ReLU -> conv -> instance norm, added to the input."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Both convolutions share the same 3x3, stride-1, bias-free settings.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        stages = [
            nn.Conv2d(dim_in, dim_out, **conv_kwargs),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, **conv_kwargs),
            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        # Skip connection: the input is added to the transformed features.
        residual = self.main(x)
        return x + residual

生成器

生成器用的是cycleGAN里面的生成器参数。

class Generator(nn.Module):
    """Generator network.

    Layout: a 7x7 stem convolution, two stride-2 down-sampling convolutions,
    ``repeat_num`` residual blocks, two stride-2 transposed convolutions and a
    final 7x7 convolution followed by ``Tanh``.
    """

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=6):
        super().__init__()

        def norm_relu(channels):
            # Instance norm + ReLU pair that follows every convolution except the last.
            return [nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
                    nn.ReLU(inplace=True)]

        # Stem: the input is the image (3 channels) concatenated with the label map (c_dim channels).
        modules = [nn.Conv2d(3 + c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)]
        modules += norm_relu(conv_dim)

        # Down-sampling layers: each one doubles the channel count.
        width = conv_dim
        for _ in range(2):
            modules.append(nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1, bias=False))
            modules += norm_relu(width * 2)
            width = width * 2

        # Bottleneck layers.
        modules += [ResidualBlock(dim_in=width, dim_out=width) for _ in range(repeat_num)]

        # Up-sampling layers: each one halves the channel count.
        for _ in range(2):
            modules.append(nn.ConvTranspose2d(width, width // 2, kernel_size=4, stride=2, padding=1, bias=False))
            modules += norm_relu(width // 2)
            width = width // 2

        # Output head: back to 3 channels.
        modules.append(nn.Conv2d(width, 3, kernel_size=7, stride=1, padding=3, bias=False))
        modules.append(nn.Tanh())
        self.main = nn.Sequential(*modules)

    def forward(self, x, c):
        # Replicate spatially and concatenate domain information.
        # Note that this type of label conditioning does not work at all if we use reflection padding in Conv2d.
        # This is because instance normalization ignores the shifting (or bias) effect.
        batch, channels = c.size(0), c.size(1)
        # Reshape c to (batch, channels, 1, 1), then tile it to x's last two dimensions.
        label_map = c.view(batch, channels, 1, 1).repeat(1, 1, x.size(2), x.size(3))
        # Concatenate image and label map along the channel dimension.
        return self.main(torch.cat([x, label_map], dim=1))

鉴别器


class Discriminator(nn.Module):
    """Discriminator network with PatchGAN."""

    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super().__init__()
        # First strided convolution: 3 image channels -> conv_dim.
        stack = [nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1),
                 nn.LeakyReLU(0.01)]

        # Remaining repeat_num - 1 strided convolutions, each doubling the channels.
        channels = conv_dim
        for _ in range(repeat_num - 1):
            stack += [nn.Conv2d(channels, channels * 2, kernel_size=4, stride=2, padding=1),
                      nn.LeakyReLU(0.01)]
            channels = channels * 2

        # Kernel as large as the feature map left after repeat_num halvings.
        kernel_size = int(image_size / np.power(2, repeat_num))
        self.main = nn.Sequential(*stack)
        # Real/fake head.
        self.conv1 = nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1, bias=False)
        # Domain-label head.
        self.conv2 = nn.Conv2d(channels, c_dim, kernel_size=kernel_size, bias=False)

    def forward(self, x):
        features = self.main(x)          # with the default arguments this has 2048 channels
        out_src = self.conv1(features)   # real/fake scores
        out_cls = self.conv2(features)   # label logits
        return out_src, out_cls.view(out_cls.size(0), out_cls.size(1))

五、solver.py文件

定义一个类class Solver(object)

"""Solver for training and testing StarGAN."""

def __init__(self, celeba_loader, rafd_loader, config):

对这些配置参数进行实例化,具体参数可以看main.py

    def __init__(self, celeba_loader, rafd_loader, config):
        """Store the loaders, copy every option from ``config`` and build the model."""
        # Data loader.
        self.celeba_loader = celeba_loader
        self.rafd_loader = rafd_loader

        # Each name below is copied verbatim from ``config`` onto the solver.
        copied_options = (
            # Model configurations.
            'c_dim', 'c2_dim', 'image_size', 'g_conv_dim', 'd_conv_dim',
            'g_repeat_num', 'd_repeat_num', 'lambda_cls', 'lambda_rec', 'lambda_gp',
            # Training configurations.
            'dataset', 'batch_size', 'num_iters', 'num_iters_decay', 'g_lr', 'd_lr',
            'n_critic', 'beta1', 'beta2', 'resume_iters', 'selected_attrs',
            # Test configurations.
            'test_iters',
            # Miscellaneous.
            'use_tensorboard',
            # Directories.
            'log_dir', 'sample_dir', 'model_save_dir', 'result_dir',
            # Step size.
            'log_step', 'sample_step', 'model_save_step', 'lr_update_step',
        )
        for option in copied_options:
            setattr(self, option, getattr(config, option))

        # Use the GPU when one is available, otherwise the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

def build_model(self)

    def build_model(self):
        """Create the generator, the discriminator and one Adam optimizer for each."""
        if self.dataset in ('CelebA', 'RaFD'):
            self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
            self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
        elif self.dataset == 'Both':
            # Joint training: the labels of both datasets are concatenated.
            joint_dim = self.c_dim + self.c2_dim
            # The generator gets 2 extra channels for the mask vector.
            self.G = Generator(self.g_conv_dim, joint_dim + 2, self.g_repeat_num)
            self.D = Discriminator(self.image_size, self.d_conv_dim, joint_dim, self.d_repeat_num)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        # Print both networks, then move them to the selected device.
        for net, label in ((self.G, 'G'), (self.D, 'D')):
            self.print_network(net, label)
        for net in (self.G, self.D):
            net.to(self.device)

def print_network()

    def print_network(self, model, name):
        """Print the model, its name and its total number of parameters."""
        total = sum(param.numel() for param in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(total))

def restore_model(self, resume_iters)

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator saved at step ``resume_iters``."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        checkpoints = (
            (self.G, '{}-G.ckpt'),
            (self.D, '{}-D.ckpt'),
        )
        for net, pattern in checkpoints:
            path = os.path.join(self.model_save_dir, pattern.format(resume_iters))
            # The identity map_location returns each storage as deserialized,
            # without moving it to the device it was saved from.
            state = torch.load(path, map_location=lambda storage, loc: storage)
            # load_state_dict copies the parameters and buffers into the module.
            net.load_state_dict(state)

def build_tensorboard(self):

    def build_tensorboard(self):
        """Create the tensorboard logger that writes into ``self.log_dir``."""
        # Function-scope import: the logger module is only loaded when this method runs.
        import logger as logger_module
        self.logger = logger_module.Logger(self.log_dir)

def update_lr(self, g_lr, d_lr):

    def update_lr(self, g_lr, d_lr):
        """Set the learning rate of every parameter group of the G and D optimizers."""
        schedule = ((self.g_optimizer, g_lr), (self.d_optimizer, d_lr))
        for optimizer, rate in schedule:
            for group in optimizer.param_groups:
                group['lr'] = rate

def reset_grad(self):

    def reset_grad(self):
        """Clear the gradient buffers held by both optimizers."""
        for optimizer in (self.g_optimizer, self.d_optimizer):
            optimizer.zero_grad()

def denorm(self, x):


    def denorm(self, x):
        """Map values from [-1, 1] to [0, 1], clamping anything that falls outside."""
        # (x + 1) / 2 allocates a new tensor, so the in-place clamp never touches x.
        return ((x + 1) / 2).clamp_(0, 1)

def gradient_penalty(self, y, x):

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2, averaged over the first dimension."""
        ones = torch.ones(y.size()).to(self.device)
        grads = torch.autograd.grad(outputs=y,
                                    inputs=x,
                                    grad_outputs=ones,
                                    retain_graph=True,
                                    create_graph=True,   # keep the penalty itself differentiable
                                    only_inputs=True)
        # Flatten every sample's gradient into one row before taking its L2 norm.
        flat = grads[0].view(grads[0].size(0), -1)
        norms = flat.pow(2).sum(dim=1).sqrt()
        return (norms - 1).pow(2).mean()

def label2onehot(self, labels, dim)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors with ``dim`` columns."""
        rows = labels.size(0)
        onehot = torch.zeros(rows, dim)
        row_index = np.arange(rows)
        # .long() casts the labels to an integer tensor so they can be used as column indices.
        onehot[row_index, labels.long()] = 1
        return onehot

def create_labels()

    def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
        """Generate target domain labels for debugging and testing.

        Returns a list of ``c_dim`` label tensors, each moved to ``self.device``.
        """
        # Get hair color indices.
        if dataset == 'CelebA':
            hair_names = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']
            hair_color_indices = [pos for pos, attr_name in enumerate(selected_attrs)
                                  if attr_name in hair_names]

        targets = []
        for i in range(c_dim):
            if dataset == 'CelebA':
                c_trg = c_org.clone()
                if i in hair_color_indices:
                    # Set one hair color to 1 and the rest to 0.
                    for j in hair_color_indices:
                        c_trg[:, j] = 1 if j == i else 0
                else:
                    # Reverse attribute value.
                    c_trg[:, i] = (c_trg[:, i] == 0)
            elif dataset == 'RaFD':
                # Every sample gets the one-hot label of class i.
                class_ids = torch.ones(c_org.size(0)) * i
                c_trg = self.label2onehot(class_ids, c_dim)

            targets.append(c_trg.to(self.device))
        return targets

def classification_loss(self, logit, target, dataset='CelebA')

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Compute binary or softmax cross entropy loss.

        Args:
            logit: classifier outputs; ``logit.size(0)`` is used as the batch size.
            target: labels compared against ``logit``.
            dataset: ``'CelebA'`` selects per-attribute binary cross entropy
                (summed, then divided by the batch size); ``'RaFD'`` selects
                softmax cross entropy.

        Raises:
            ValueError: if ``dataset`` is neither ``'CelebA'`` nor ``'RaFD'``
                (the function used to return None silently in that case).
        """
        if dataset == 'CelebA':
            # CelebA labels are multi-attribute rather than one-hot, so each
            # attribute is treated as its own binary classification.
            # reduction='sum' is the replacement for the deprecated size_average=False.
            return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)
        if dataset == 'RaFD':
            return F.cross_entropy(logit, target)
        raise ValueError("Unknown dataset {!r}; expected 'CelebA' or 'RaFD'.".format(dataset))

def train(self)

    def train(self):
        """Train StarGAN within a single dataset."""
        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)    # x_fixed: image batch; c_org: labels of that batch, e.g. tensor([[ 1.,  0.,  0.,  1.,  1.]])
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
        # Example of print(c_fixed_list) recorded by the note's author for the c_org above:
        # [tensor([[ 1.,  0.,  0.,  1.,  1.]]), tensor([[ 0.,  1.,  0.,  1.,  1.]]), tensor([[ 0.,  0.,  1.,  1.,  1.]]),
        # tensor([[ 1.,  0.,  0.,  0.,  1.]]), tensor([[ 1.,  0.,  0.,  1.,  0.]])]

        # Learning rate cache for decaying.
        g_lr = self.g_lr         # current generator learning rate
        d_lr = self.d_lr         # current discriminator learning rate

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:   # --resume_iters defaults to None, so this is skipped unless it is set
            start_iters = self.resume_iters     # continue from a previously saved step
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

1. Preprocess input data预处理输入数据

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch real images and labels.
            try:
                x_real, label_org = next(data_iter)
            except StopIteration:
                # The loader is exhausted: start another pass over the data.
                # Only StopIteration is caught, so genuine loading errors and
                # KeyboardInterrupt are no longer swallowed by a bare except.
                data_iter = iter(data_loader)
                x_real, label_org = next(data_iter)

            # Generate target domain labels randomly: the target labels are the
            # batch's own labels in a random order.
            rand_idx = torch.randperm(label_org.size(0))
            label_trg = label_org[rand_idx]

            if self.dataset == 'CelebA':
                c_org = label_org.clone()
                c_trg = label_trg.clone()
            elif self.dataset == 'RaFD':
                c_org = self.label2onehot(label_org, self.c_dim)
                c_trg = self.label2onehot(label_trg, self.c_dim)

            x_real = x_real.to(self.device)           # Input images.
            c_org = c_org.to(self.device)             # Original domain labels.
            c_trg = c_trg.to(self.device)             # Target domain labels.
            label_org = label_org.to(self.device)     # Labels for computing classification loss.
            label_trg = label_trg.to(self.device)     # Labels for computing classification loss.

2. Train the discriminator训练判别器

            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #
            # D takes a batch of real images and returns out_src (real/fake scores) and out_cls
            # (label logits). With the default options (batch 16, image_size 128, d_repeat_num 6,
            # c_dim 5) out_src is [16, 1, 2, 2] and out_cls is [16, 5].
            # Compute loss with real images.
            out_src, out_cls = self.D(x_real)    # out_src: real/fake scores; out_cls: label logits
            d_loss_real = - torch.mean(out_src)   # minimizing this raises D's score on real images
            # value recorded by the note's author: d_loss_real = tensor(1.00000e-04 * 3.8965)
            d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset)  # predicted labels vs. the real labels
            # value recorded by the note's author: d_loss_cls = tensor(3.4666)
            # Compute loss with fake images.
            # G receives the real images together with the target labels c_trg and produces x_fake.
            x_fake = self.G(x_real, c_trg)         # fake images generated from real images + target labels
            out_src, out_cls = self.D(x_fake.detach())    # detach(): no gradient flows back into G here
            d_loss_fake = torch.mean(out_src)      # minimizing this lowers D's score on fake images; recorded value: tensor(1.00000e-05 *-1.0045)

            """
                        out_src
                        tensor(1.00000e-03 *
                        [[[[-1.5289,  0.8110],
                           [ 0.2153,  0.4624]]]])
                        out_cls
                        tensor(1.00000e-03 *
                           [[ 1.4681,  1.9497,  1.2743, -1.1915,  0.7609]])
                        """

            # Compute loss for gradient penalty.
            # alpha holds one random value in [0, 1) per sample (e.g. tensor([[[[ 0.7610]]]]));
            # it blends x_real and x_fake into x_hat, on which the gradient of D is penalized.
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
            # x_hat has the shape of an image batch and lies between x_real and x_fake.
            out_src, _ = self.D(x_hat)       # D's real/fake scores on the interpolated images
            d_loss_gp = self.gradient_penalty(out_src, x_hat)     # the note's author observed values around 0.9954-0.9956

            # Backward and optimize.
            # The discriminator loss has 4 terms:
            # 1. real images should get a high real/fake score
            # 2. images generated by G from real images + target labels should get a low score
            # 3. classification loss between D's predicted labels for real images and their real labels
            # 4. gradient penalty on images interpolated between the real and the generated images
            d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_cls'] = d_loss_cls.item()
            loss['D/loss_gp'] = d_loss_gp.item()

3. Train the generator训练生成器

            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #
            # G first translates the real images to the target domain, then translates the result
            # back to the original domain (reconstruction).
            if (i+1) % self.n_critic == 0:   # G is updated once every n_critic iterations (default 5); D is updated every iteration
                # Original-to-target domain.
                x_fake = self.G(x_real, c_trg)    # fake images generated from real images + target labels
                out_src, out_cls = self.D(x_fake)  # D's real/fake scores and label logits for the fake images
                g_loss_fake = - torch.mean(out_src)  # adversarial loss: smaller when D scores the fake images higher
                g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset)   # smaller when D's predicted labels match the target labels

                # Target-to-original domain.
                x_reconst = self.G(x_fake, c_org)    # rebuild the original images from the fake images + original labels
                g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))   # reconstruction loss: mean absolute difference to the real images

                # Backward and optimize.
                g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls  # total generator loss
                self.reset_grad()    # clear the gradients
                g_loss.backward()    # back-propagate the loss
                self.g_optimizer.step()   # update the generator parameters

                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_rec'] = g_loss_rec.item()
                loss['G/loss_cls'] = g_loss_cls.item()

4. Miscellaneous其他的

            # =================================================================================== #
            #                                 4. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i+1) % self.log_step == 0:    # every log_step iterations (default 10)
                et = time.time() - start_time   # elapsed time in seconds
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i+1)

            # Translate fixed images for debugging (this is what saves the sample images).
            if (i+1) % self.sample_step == 0:     # every sample_step iterations (default 1000)
                with torch.no_grad():             # no gradients are needed for sampling
                    x_fake_list = [x_fixed]       # the first entry is the fixed input batch itself
                    for c_fixed in c_fixed_list:    # one translation per fixed target label
                        x_fake_list.append(self.G(x_fixed, c_fixed))
                    x_concat = torch.cat(x_fake_list, dim=3)    # concatenate along the last dimension
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:           # every model_save_step iterations (default 10000)
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))    # generator checkpoint path
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))    # discriminator checkpoint path
                torch.save(self.G.state_dict(), G_path)           # save the generator weights
                torch.save(self.D.state_dict(), D_path)           # save the discriminator weights
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates: only in the last num_iters_decay iterations,
            # once every lr_update_step iterations.
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))   # subtract initial_lr / num_iters_decay
                d_lr -= (self.d_lr / float(self.num_iters_decay))   # same step for the discriminator
                self.update_lr(g_lr, d_lr)                          # apply the new learning rates
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))  # report the new learning rates

def train_multi(self):

"""Train StarGAN with multiple datasets.使用多个数据集训练 StarGAN"""

    def train_multi(self):
        """Train StarGAN jointly on the CelebA and RaFD datasets.

        Every iteration of the outer loop processes one batch from CelebA and
        then one batch from RaFD.  For each batch the discriminator is updated
        once; the generator is updated only on iterations where
        ``(i + 1) % self.n_critic == 0``.

        The conditioning vector given to the generator is the concatenation
        ``[CelebA labels, RaFD labels, 2-way mask]``.  The label slot of the
        dataset that is not in use is filled with zeros and the mask
        (``[1, 0]`` for CelebA, ``[0, 1]`` for RaFD) marks which slot is valid.

        Side effects: prints progress, optionally logs scalars through
        ``self.logger``, writes sample images into ``self.sample_dir``, writes
        generator/discriminator checkpoints into ``self.model_save_dir`` and
        lowers the learning rates through ``self.update_lr``.
        """
        # Data iterators.
        celeba_iter = iter(self.celeba_loader)    # Iterator over the CelebA loader.
        rafd_iter = iter(self.rafd_loader)        # Iterator over the RaFD loader.

        # Fetch fixed inputs for debugging (the first CelebA batch is reused for every sample image).
        x_fixed, c_org = next(celeba_iter)     # next() returns the iterator's next item.
        x_fixed = x_fixed.to(self.device)
        c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)    # Target label list for CelebA.
        c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')          # Target label list for RaFD.
        zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device)   # Zero vector for CelebA.
        zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device)    # Zero vector for RaFD.
        mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device)  # Mask vector: [1, 0].
        mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device)     # Mask vector: [0, 1].

        # Learning rate cache for decaying.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            for dataset in ['CelebA', 'RaFD']:

1. Preprocess input data预处理输入数据

                # =================================================================================== #
                #                             1. Preprocess input data                                #
                # =================================================================================== #
                
                # Fetch real images and labels.
                data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter
                
                try:
                    x_real, label_org = next(data_iter)
                except:
                    # NOTE(review): presumably meant to restart an exhausted loader
                    # (StopIteration), but the bare ``except`` also swallows every other
                    # error raised while fetching a batch - consider narrowing it.
                    if dataset == 'CelebA':
                        celeba_iter = iter(self.celeba_loader)
                        x_real, label_org = next(celeba_iter)  # First batch of a fresh CelebA iterator: real images and their labels.
                    elif dataset == 'RaFD':
                        rafd_iter = iter(self.rafd_loader)
                        x_real, label_org = next(rafd_iter)    # First batch of a fresh RaFD iterator: real images and their labels.

                # Generate target domain labels randomly.
                rand_idx = torch.randperm(label_org.size(0))   # torch.randperm(n): random permutation of 0..n-1.
                label_trg = label_org[rand_idx]                # Target labels are the batch's own labels, shuffled.
                # Append the other dataset's zero slot and the mask vector.
                # Layout of the conditioning vector: CelebA labels first, RaFD labels
                # second, then the one-hot mask that tells which dataset's labels are known.
                if dataset == 'CelebA':
                    c_org = label_org.clone()     # Copy of the original labels.
                    c_trg = label_trg.clone()     # Copy of the target labels.
                    zero = torch.zeros(x_real.size(0), self.c2_dim)
                    mask = self.label2onehot(torch.zeros(x_real.size(0)), 2)
                    c_org = torch.cat([c_org, zero, mask], dim=1)
                    c_trg = torch.cat([c_trg, zero, mask], dim=1)
                elif dataset == 'RaFD':
                    c_org = self.label2onehot(label_org, self.c2_dim)
                    c_trg = self.label2onehot(label_trg, self.c2_dim)
                    zero = torch.zeros(x_real.size(0), self.c_dim)
                    mask = self.label2onehot(torch.ones(x_real.size(0)), 2)
                    c_org = torch.cat([zero, c_org, mask], dim=1)
                    c_trg = torch.cat([zero, c_trg, mask], dim=1)

                x_real = x_real.to(self.device)             # Input images.
                c_org = c_org.to(self.device)               # Original domain labels.
                c_trg = c_trg.to(self.device)               # Target domain labels.
                label_org = label_org.to(self.device)  # Labels for computing classification loss.
                label_trg = label_trg.to(self.device)  # Labels for computing classification loss.

2. Train the discriminator

                # =================================================================================== #
                #                             2. Train the discriminator                              #
                # =================================================================================== #

                # Compute loss with real images.
                out_src, out_cls = self.D(x_real)
                # Keep only the classification columns that belong to the current dataset:
                # the first c_dim columns for CelebA, the remaining columns for RaFD.
                out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
                d_loss_real = - torch.mean(out_src)
                d_loss_cls = self.classification_loss(out_cls, label_org, dataset)

                # Compute loss with fake images.
                x_fake = self.G(x_real, c_trg)          # Generate fakes from the real images and the target labels.
                out_src, _ = self.D(x_fake.detach())    # detach(): no gradient flows back into G here.
                d_loss_fake = torch.mean(out_src)       # Mean source score of the fakes.

                # Compute loss for gradient penalty.
                alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)   # One random mixing weight per sample.
                x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True)
                # x_hat is a random interpolation between the real and the fake images.
                out_src, _ = self.D(x_hat)   # D is evaluated on x_hat so the penalty can use its gradient w.r.t. x_hat.
                d_loss_gp = self.gradient_penalty(out_src, x_hat)

                # Backward and optimize.
                # The discriminator loss has four terms:
                # 1. d_loss_real: source score on real images.
                # 2. d_loss_fake: source score on images generated from real images + target labels.
                # 3. d_loss_cls:  classification loss between D's prediction on real images and their labels.
                # 4. d_loss_gp:   gradient penalty on the real/fake interpolation.
                d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging.
                loss = {}
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_cls'] = d_loss_cls.item()
                loss['D/loss_gp'] = d_loss_gp.item()

3. Train the generator

                # =================================================================================== #
                #                               3. Train the generator                                #
                # =================================================================================== #
                # G maps (image, target labels) to a translated image, and maps
                # (translated image, original labels) back to a reconstruction of the input.
                if (i+1) % self.n_critic == 0:      # Update G only once every n_critic iterations.
                    # Original-to-target domain.
                    x_fake = self.G(x_real, c_trg)      # Fake images from the real batch and the target labels.
                    out_src, out_cls = self.D(x_fake)   # Source score and predicted labels for the fakes.
                    out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:]
                    g_loss_fake = - torch.mean(out_src)  # Adversarial loss: smaller when D scores the fakes higher.
                    g_loss_cls = self.classification_loss(out_cls, label_trg, dataset)

                    # Target-to-original domain.
                    x_reconst = self.G(x_fake, c_org)    # Reconstruct the input from the fake image and the original labels.
                    g_loss_rec = torch.mean(torch.abs(x_real - x_reconst))     # Reconstruction loss: mean absolute difference.

                    # Backward and optimize.
                    g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Logging.
                    loss['G/loss_fake'] = g_loss_fake.item()
                    loss['G/loss_rec'] = g_loss_rec.item()
                    loss['G/loss_cls'] = g_loss_cls.item()

4. Miscellaneous


                # =================================================================================== #
                #                                 4. Miscellaneous                                    #
                # =================================================================================== #

                # Print out training info.
                if (i+1) % self.log_step == 0:
                    et = time.time() - start_time
                    et = str(datetime.timedelta(seconds=et))[:-7]   # [:-7] drops the fractional-seconds suffix.
                    log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset)
                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    if self.use_tensorboard:
                        for tag, value in loss.items():
                            self.logger.scalar_summary(tag, value, i+1)

            # Translate fixed images for debugging.
            if (i+1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_celeba_list:   # One translation per CelebA target label.
                        c_trg = torch.cat([c_fixed, zero_rafd, mask_celeba], dim=1)
                        x_fake_list.append(self.G(x_fixed, c_trg))
                    for c_fixed in c_rafd_list:     # One translation per RaFD target label.
                        c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1)
                        x_fake_list.append(self.G(x_fixed, c_trg))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay learning rates (only during the last num_iters_decay iterations).
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

def test(self)

    def test(self):
        """Translate images using StarGAN trained on a single dataset.

        Restores the checkpoint saved at ``self.test_iters``, then, for every
        batch of the loader selected by ``self.dataset``, generates one
        translated batch per target label.  The real batch followed by all
        translations is concatenated along dim 3 and written to
        ``<self.result_dir>/<batch index + 1>-images.jpg``.

        Raises:
            ValueError: If ``self.dataset`` is neither ``'CelebA'`` nor
                ``'RaFD'``.
        """
        # Load the trained generator.
        self.restore_model(self.test_iters)

        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader
        else:
            # Without this branch ``data_loader`` stayed unbound and the loop
            # below failed with an opaque UnboundLocalError.
            raise ValueError(
                "test() supports only 'CelebA' or 'RaFD', got {!r}".format(self.dataset))

        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

                # Translate images: the real batch first, then one fake batch per target label.
                x_fake_list = [x_real]
                for c_trg in c_trg_list:
                    x_fake_list.append(self.G(x_real, c_trg))

                # Save the translated images.
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))

def test_multi(self)

    def test_multi(self):
        """Translate images using StarGAN trained on multiple datasets.

        Restores the checkpoint saved at ``self.test_iters`` and iterates over
        ``self.celeba_loader``.  For each batch it generates one translated
        batch per CelebA target label and one per RaFD target label, then
        writes the real batch followed by all translations (concatenated
        along dim 3) to ``<self.result_dir>/<batch index + 1>-images.jpg``.
        """
        # Restore the trained weights for the requested iteration.
        self.restore_model(self.test_iters)

        with torch.no_grad():
            for batch_idx, (x_real, c_org) in enumerate(self.celeba_loader):
                # Move the images to the device and build the target labels of both datasets.
                x_real = x_real.to(self.device)
                batch_size = x_real.size(0)
                celeba_targets = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
                rafd_targets = self.create_labels(c_org, self.c2_dim, 'RaFD')

                # Zero placeholders that fill the label slot of the dataset not in use.
                celeba_zeros = torch.zeros(batch_size, self.c_dim).to(self.device)
                rafd_zeros = torch.zeros(batch_size, self.c2_dim).to(self.device)

                # Mask vectors: [1, 0] marks CelebA labels, [0, 1] marks RaFD labels.
                celeba_mask = self.label2onehot(torch.zeros(batch_size), 2).to(self.device)
                rafd_mask = self.label2onehot(torch.ones(batch_size), 2).to(self.device)

                # Translate: real batch first, then CelebA targets, then RaFD targets.
                outputs = [x_real]
                outputs += [
                    self.G(x_real, torch.cat([label, rafd_zeros, celeba_mask], dim=1))
                    for label in celeba_targets
                ]
                outputs += [
                    self.G(x_real, torch.cat([celeba_zeros, label, rafd_mask], dim=1))
                    for label in rafd_targets
                ]

                # Save the translated images.
                stacked = torch.cat(outputs, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(batch_idx + 1))
                save_image(self.denorm(stacked.data.cpu()), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))

你可能感兴趣的:(GAN,深度学习,pytorch,人工智能)