vits复现gituhb项目--模型训练

在完成VITS论文学习后,对github上的官方仓库进行学习,帮助理解算法实现过程中的一些细节;仓库代码基于pytorch实现,链接为https://github.com/jaywalnut310/vits。论文和代码中都针对单speaker的数据集LJSpeech和多speaker的数据集VCTK进行了训练,本笔记主要针对多speaker设置下的训练代码进行注释解析,主要涉及仓库项目中的train_ms.py文件。

train_ms.py

VITS训练时,使用了混合精度训练,并且设置了对抗训练模式;其中判别器使用了多周期判别器,由多个子判别器组成,并且生成过程损失中还加上了feature_map损失。训练过程中,不是对完整的音频文件进行训练,而是提取一部分音频数据进行训练,进而在计算损失时,也要从ground truth中提取对应部分的数值进行计算。具体的训练代码及注释如下:

import os
import json
import argparse
import itertools
import math
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

import commons
import utils
from data_utils import (
    TextAudioSpeakerLoader,
    TextAudioSpeakerCollate,
    DistributedBucketSampler
)
from models import (
    SynthesizerTrn,
    MultiPeriodDiscriminator,
)
from losses import (
    generator_loss,
    discriminator_loss,
    feature_loss,
    kl_loss
)
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from text.symbols import symbols

torch.backends.cudnn.benchmark = True
global_step = 0


def main():
    """Assume Single Node Multi GPUs Training Only;只考虑单机多卡训练"""
    assert torch.cuda.is_available(), "CPU training is not allowed."

    n_gpus = torch.cuda.device_count()
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '80000'

    hps = utils.get_hparams()  # 获取参数超参数
    mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))


# run函数中是实际训练代码
def run(rank, n_gpus, hps):
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

    dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)
    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)  # 加载数据集
    # 分布式的基于桶的sampler
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [32, 300, 400, 500, 600, 700, 800, 900, 1000],  # 桶排序的边界
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True)
    collate_fn = TextAudioSpeakerCollate()
    # 构建训练数据
    train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True,
                              collate_fn=collate_fn, batch_sampler=train_sampler)
    if rank == 0:  # 在主机上进行验证,即此处是在主机上加载验证数据集
        eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
        eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False,
                                 batch_size=hps.train.batch_size, pin_memory=True,
                                 drop_last=False, collate_fn=collate_fn)
    # 生成器,表示文本到音频的整个模型
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model).cuda(rank)
    # 多周期的判别器
    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
    # 生成器的优化器
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps)
    # 判别器的优化器
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps)
    # 多卡分布式训练,使用DDP把生成器和判别器包裹起来
    net_g = DDP(net_g, device_ids=[rank])
    net_d = DDP(net_d, device_ids=[rank])

    try:  # 尝试加载可能存在的通过训练已经保存的模型参数
        _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
                                                   optim_g)
        _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
                                                   optim_d)
        global_step = (epoch_str - 1) * len(train_loader)
    except:
        epoch_str = 1
        global_step = 0

    # 定义生成器和判别器的学习率schedule
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)

    scaler = GradScaler(enabled=hps.train.fp16_run)  # 混合精度训练

    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:  # 如果为主机,除了参入正常训练参数,还需要传入验证数据集、logger等其他参数
            train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
                               [train_loader, eval_loader], logger, [writer, writer_eval])
        else:
            train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
                               [train_loader, None], None, None)
        # 更新学习率
        scheduler_g.step()
        scheduler_d.step()


# 训练和验证函数
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
    net_g, net_d = nets  # 生成器和判别器
    optim_g, optim_d = optims
    scheduler_g, scheduler_d = schedulers
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    train_loader.batch_sampler.set_epoch(epoch)  # 设置train_loader中桶排序的随机种子,随机种子是每次的epoch,用于打乱数据,但也可以复现
    global global_step

    net_g.train()
    net_d.train()
    for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(train_loader):
        x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True)
        spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
        speakers = speakers.cuda(rank, non_blocking=True)

        with autocast(enabled=hps.train.fp16_run):  # 模型计算部分进行半精度计算
            # 对整个音频序列采样进行训练,不是把整个音频序列送入进行训练,降低训练所需资源,ids_slice就对应采样后频谱的id
            # y_hat是预测的音频波形,l_length是时长预测器的损失,attn是对齐矩阵或时长信息
            y_hat, l_length, attn, ids_slice, x_mask, z_mask, \
            (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers)

            # 将线性谱转为mel谱图,便于后续计算L_recon
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax)
            # 以ids_slice作为指导,采样对应窗口的mel谱图作为target
            y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
            # 从生成的音频波形y_hat中提取对应的mel谱图
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax)
            # 从完整的音频数据中以ids_slice获取对应窗口部分的音频数据;判别器判别时需要真实波形数据
            y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size)  # slice

            # Discriminator;y_d_hat_r, y_d_hat_g记录所有子判别器对batch中真实波形y和生成波形y_hat的判别结果
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())

            with autocast(enabled=False):  # 损失的计算不进行半精度计算
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)  # 判别器的损失
                loss_disc_all = loss_disc
        # 判别器更新
        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)  # 梯度剪裁前先进行unscale
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)  # 梯度剪裁
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            # 将生成的波形和真实波形分别送入到判别器中,希望两者在判别器的中间特征尽可能保持一致,即论文中的L_{fm},需要fmap_r, fmap_g进行计算
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            with autocast(enabled=False):
                loss_dur = torch.sum(l_length.float())  # 时间预测器loss,直接求和
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel  # 重构loss,论文中系数c_mel为45
                # 计算模型基于文本学习到的先验分布和从音频线性谱图中学习到的后验分布之间的KL散度,系数c_kl为1
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

                loss_fm = feature_loss(fmap_r, fmap_g)  # feature map 的loss
                loss_gen, losses_gen = generator_loss(y_d_hat_g)  # 生成器的对抗loss
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
        # 生成器更新
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()
        # 主卡上进行loss打印、记录和模型验证、保存
        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]['lr']
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
                logger.info('Train Epoch: {} [{:.0f}%]'.format(
                    epoch,
                    100. * batch_idx / len(train_loader)))
                logger.info([x.item() for x in losses] + [global_step, lr])

                scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr,
                               "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}  # 记录损失和梯度
                scalar_dict.update(
                    {"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/dur": loss_dur, "loss/g/kl": loss_kl})

                scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
                scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
                scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
                # 以图像的形式记录mel谱图和对齐信息
                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
                    "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
                    "all/attn": utils.plot_alignment_to_numpy(attn[0, 0].data.cpu().numpy())
                }
                # 调用定义的tensorboard的writer记录上述信息
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict)

            if global_step % hps.train.eval_interval == 0:
                evaluate(hps, net_g, eval_loader, writer_eval)  # 验证
                # 保存生成器和判别器的参数
                utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,
                                      os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
                utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,
                                      os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
        global_step += 1

    if rank == 0:
        logger.info('====> Epoch: {}'.format(epoch))


# 验证
def evaluate(hps, generator, eval_loader, writer_eval):
    generator.eval()  # 验证模式
    with torch.no_grad():
        for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers) in enumerate(eval_loader):
            x, x_lengths = x.cuda(0), x_lengths.cuda(0)
            spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0)
            y, y_lengths = y.cuda(0), y_lengths.cuda(0)
            speakers = speakers.cuda(0)

            # remove else
            x = x[:1]
            x_lengths = x_lengths[:1]
            spec = spec[:1]
            spec_lengths = spec_lengths[:1]
            y = y[:1]
            y_lengths = y_lengths[:1]
            speakers = speakers[:1]
            break
        y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, max_len=1000)  # 基于文本生成音频
        y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length

        # 提取真实的mel谱图
        mel = spec_to_mel_torch(
            spec,
            hps.data.filter_length,
            hps.data.n_mel_channels,
            hps.data.sampling_rate,
            hps.data.mel_fmin,
            hps.data.mel_fmax)
        # 从预测的音频的提取mel谱图
        y_hat_mel = mel_spectrogram_torch(
            y_hat.squeeze(1).float(),
            hps.data.filter_length,
            hps.data.n_mel_channels,
            hps.data.sampling_rate,
            hps.data.hop_length,
            hps.data.win_length,
            hps.data.mel_fmin,
            hps.data.mel_fmax
        )
    image_dict = {
        "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
    }
    audio_dict = {
        "gen/audio": y_hat[0, :, :y_hat_lengths[0]]
    }
    if global_step == 0:
        image_dict.update({"gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
        audio_dict.update({"gt/audio": y[0, :, :y_lengths[0]]})

    # 记录信息
    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=hps.data.sampling_rate
    )
    generator.train()


if __name__ == "__main__":
    main()

losses.py

从论文中可知,本模型训练过程中涉及很多的损失,对抗训练过程中,判别器是常规的判别器损失结构,但是使用的是多周期判别器,由多个子判别器组成;生成器的损失,包括mel重建损失、KL散度、时长预测器损失、对抗训练生成损失以及特征图损失,其中时长预测器损失在模型forward函数中直接计算、mel重建损失是直接计算L1损失,剩下的四种损失在losses.py文件中定义,代码如下:

import torch
from torch.nn import functional as F

import commons


# 计算对抗训练中生成波形和真实波形在判别器中间特征之间的距离损失
def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):  # 遍历真实波形和预测波形在判别器每层的特征图
        for rl, gl in zip(dr, dg):
            rl = rl.float().detach()
            gl = gl.float()
            loss += torch.mean(torch.abs(rl - gl))  # 计算L1损失

    return loss * 2


# 判别器损失
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):  # 遍历多个子判别器的判别结果
        dr = dr.float()  # 一个子判别器对真实波形的判别结果
        dg = dg.float()  # 一个子判别器对生成波形的判别结果
        r_loss = torch.mean((1 - dr) ** 2)  # 真实波形的判别结果越接近于1越好
        g_loss = torch.mean(dg ** 2)  # 生成波形的判别结果越接近于0越好
        loss += (r_loss + g_loss)  # 累加当前子判别器的损失
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


# 生成器的对抗损失,就是将生成器生成的波形经过判别器后的输出与1计算距离损失,L2损失
def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        dg = dg.float()
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses


# 先验分布和后验分布之间的KL散度
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2. * logs_p)
    kl = torch.sum(kl * z_mask)
    l = kl / torch.sum(z_mask)
    return l

本笔记主要记录vits官方仓库中模型训练相关代码,其中涉及到的一些辅助函数,如果有必要后续会进行补充。本笔记主要是对代码进行详细的注释,读者若发现问题或错误,请评论指出,互相学习。

你可能感兴趣的:(github项目代码,TTS,1024程序员节)