GitHub - roedoejet/FastSpeech2: An implementation of Microsoft's "FastSpeech 2: Fast and High-Quality End-to-End Text to Speech"
python3 train.py -p config/AISHELL3/preprocess.yaml -m config/AISHELL3/model.yaml -t config/AISHELL3/train.yaml
通过 if __name__ == "__main__": 入口运行整个 train.py 文件,主要流程如下:
读取 "train.txt",并通过 dataset.py 加载数据;
调用 utils 文件夹下的 model.py 加载模型(及其优化器)和声码器;
调用 model 文件夹下 loss.py 中的 FastSpeech2Loss 类设置损失函数;
用前面加载的模型和损失函数训练模型,定期保存模型检查点并记录日志。
Step 0 : 定义命令行训练参数,读取三个 YAML 配置文件,然后调用 main 函数
if __name__ == "__main__":
    # Define the command-line arguments that control training.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--restore_step", type=int, default=0, help="step to resume training from"
    )
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    args = parser.parse_args()  # user-controllable training options

    # Read the three YAML config files. Each file is opened in a `with` block
    # so its handle is closed as soon as it has been parsed; the previous bare
    # `open(...)` calls left the handles open until garbage collection.
    with open(args.preprocess_config, "r") as f:
        preprocess_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.model_config, "r") as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.train_config, "r") as f:
        train_config = yaml.load(f, Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    # Hand the parsed arguments and configs to the training entry point.
    main(args, configs)
Step 1 : 进入 main 函数,从 configs 中拆分出三个配置(preprocess / model / train)
def main(args, configs):
    # Excerpt (Step 1) of the training entry point; the rest of the body is
    # shown piecewise in the following steps.
    #   args    -- parsed command-line arguments (args.restore_step is used
    #              to resume the step counter)
    #   configs -- (preprocess_config, model_config, train_config) tuple
    #              built in the __main__ block
    print("Prepare training ...")
    # Unpack the three configs for use throughout main().
    preprocess_config, model_config, train_config = configs
Step 2 : 从 train.txt 加载数据,并经由 dataset.py 中的 Dataset 和 PyTorch 的 DataLoader 处理
def main(args, configs):
    # Excerpt (Step 2): build the training dataset and its DataLoader.
    # preprocess_config / train_config are the configs unpacked from
    # `configs` at the start of main().

    # Get dataset: the utterance list is read from "train.txt".
    dataset = Dataset(
        "train.txt", preprocess_config, train_config, sort=True, drop_last=True
    )
    batch_size = train_config["optimizer"]["batch_size"]
    # Set this larger than 1 to enable sorting in Dataset. The DataLoader
    # below loads batch_size * group_size samples at a time; presumably
    # Dataset.collate_fn splits them into group_size batches (the training
    # loop iterates `for batch in batchs`) -- TODO confirm.
    group_size = 4
    # The dataset must be larger than one group. Check explicitly instead of
    # with `assert`, which is stripped when Python runs with -O.
    if batch_size * group_size >= len(dataset):
        raise ValueError(
            "batch_size * group_size ({}) must be smaller than the number of "
            "training samples ({})".format(batch_size * group_size, len(dataset))
        )
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )
Step 3 : 定义模型(及其优化器)、声码器和损失函数
def main(args, configs):
    # Excerpt (Step 3): build the model, optimizer, loss and vocoder.
    # NOTE(review): `device` is not defined in these excerpts; presumably a
    # module-level torch.device defined elsewhere in train.py -- confirm.

    # Prepare model. get_model returns both the model and its optimizer.
    model, optimizer = get_model(args, configs, device, train=True)
    # Wrap the model in nn.DataParallel for data-parallel training; the
    # wrapped model is reachable afterwards as model.module.
    model = nn.DataParallel(model)  # Model Has Been Defined
    # Count the model's parameters and report them.
    num_param = get_param_num(model)  # Number of TTS Parameters: num_param
    print("Number of FastSpeech2 Parameters:", num_param)
    # Loss function module, placed on `device`.
    Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
    # Load the vocoder; it is later passed to synth_one_sample and evaluate.
    vocoder = get_vocoder(model_config, device)
Step 4 : 初始化日志:在 train_config["path"]["log_path"] 指定的目录(本例为 ./output/log/AISHELL3)下建立 train 和 val 两个文件夹,并分别创建 TensorBoard 的 SummaryWriter 来记录日志
def main(args, configs):
    # Excerpt (Step 4): create output directories and TensorBoard writers.
    # Init logger: first make sure every path listed under
    # train_config["path"] exists as a directory (this includes log_path and
    # ckpt_path, which are used later).
    for p in train_config["path"].values():
        os.makedirs(p, exist_ok=True)
    # Separate sub-directories for training and validation logs.
    train_log_path = os.path.join(train_config["path"]["log_path"], "train")
    val_log_path = os.path.join(train_config["path"]["log_path"], "val")
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    # One TensorBoard SummaryWriter per sub-directory.
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)
Step 5 : 准备训练:从 train_config 读取梯度累积、梯度裁剪和各类步数间隔参数,并初始化总进度条
def main(args, configs):
    # Excerpt (Step 5): training counters and step-schedule settings.
    # Training. `step` is the index of the next step to run; when resuming it
    # continues right after args.restore_step. `epoch` restarts at 1 even
    # when resuming.
    step = args.restore_step + 1
    epoch = 1
    # Optimizer settings: number of steps to accumulate gradients over, and
    # the max gradient norm used for clipping.
    grad_acc_step = train_config["optimizer"]["grad_acc_step"]
    grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
    # Step schedule: training stops at total_step; logging, checkpointing,
    # synthesis and validation happen whenever step is a multiple of the
    # corresponding interval.
    total_step = train_config["step"]["total_step"]
    log_step = train_config["step"]["log_step"]
    save_step = train_config["step"]["save_step"]
    synth_step = train_config["step"]["synth_step"]
    val_step = train_config["step"]["val_step"]
    # Overall progress bar, moved to the resume point.
    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = args.restore_step
    # NOTE(review): tqdm.update() with no argument advances the counter by 1,
    # so the bar shows restore_step + 1 before the first step has finished.
    outer_bar.update()
Step 6 : 进入训练循环,为每个 epoch 创建进度条,并调用 utils 文件夹下 tools.py 中的 to_device 函数把每个 batch 的数据放到计算设备(device)上
# Excerpt from the body of main(): the outer training loop. Each pass of
# `while True` is one epoch; there is no `break`, so the loop only ends when
# Step 12 exits the process.
while True:
    # Per-epoch progress bar over the loader's items.
    inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
    for batchs in loader:
        # Each loader item is iterated again. The DataLoader batch size is
        # batch_size * group_size, so this presumably yields group_size
        # sub-batches -- TODO confirm against Dataset.collate_fn.
        for batch in batchs:
            # to_device (utils/tools.py) presumably moves the batch's tensors
            # onto `device`.
            batch = to_device(batch, device)
Step 7 : 开始训练:前向传播,计算损失,反向传播(按 grad_acc_step 做梯度累积),梯度裁剪,更新模型权重参数
# Load Data (the same inner loop as in Step 6, repeated here for context).
for batch in batchs:
    batch = to_device(batch, device)
    # Forward: everything from batch[2] on is passed positionally to the
    # model; the first two entries are presumably ids/raw text the model does
    # not consume -- TODO confirm against Dataset.collate_fn.
    output = model(*(batch[2:]))
    # Cal Loss: losses[0] is the total loss used for backpropagation.
    losses = Loss(batch, output)
    total_loss = losses[0]
    # Backward. Dividing by grad_acc_step makes the gradients accumulated
    # over grad_acc_step steps an average rather than a sum.
    total_loss = total_loss / grad_acc_step
    total_loss.backward()
    # Gradient accumulation: only update the weights every grad_acc_step steps.
    if step % grad_acc_step == 0:
        # Clip the global gradient norm to grad_clip_thresh to avoid
        # exploding gradients.
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
        # Update weights. step_and_update_lr presumably also advances the
        # learning-rate schedule (inferred from the name; confirm in the
        # optimizer wrapper).
        optimizer.step_and_update_lr()
        optimizer.zero_grad()
Step 8 : 当训练步数是预先设定的 log_step 的整数倍时,把各项 loss 写入 log.txt 并打印,再调用 utils 文件夹下 tools.py 里的 log 函数把 loss 和 step 记录到 TensorBoard
# Every log_step steps, report the individual losses.
if step % log_step == 0:
    # Convert each loss to a plain Python number for formatting/logging.
    losses = [l.item() for l in losses]
    message1 = "Step {}/{}, ".format(step, total_step)
    # NOTE(review): this format string assumes the loss returns exactly six
    # values in this order: total, mel, mel postnet, pitch, energy, duration.
    message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
        *losses
    )
    # Append to <log_path>/train/log.txt.
    with open(os.path.join(train_log_path, "log.txt"), "a") as f:
        f.write(message1 + message2 + "\n")
    # Print through tqdm so the progress bars are not broken.
    outer_bar.write(message1 + message2)
    # Also record the losses in the training SummaryWriter.
    log(train_logger, step, losses=losses)
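Step 9 : 当训练步数是预先设定的 synth_step 的整数倍时,调用 utils 文件夹下 tools.py 里的 synth_one_sample 函数,用当前 batch、模型输出和声码器合成一条样本。该函数返回一张图(fig)、重建音频(wav_reconstruction)、合成音频(wav_prediction)以及样本标签(tag);随后用 log 函数把图和两段音频写入 TensorBoard,方便对比。(推测:重建音频由声码器从真实梅尔频谱还原,合成音频由声码器从模型预测的梅尔频谱生成,待核实。)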
# Every synth_step steps, synthesize one sample from the current batch and
# log it to the training SummaryWriter for inspection.
if step % synth_step == 0:
    # synth_one_sample (utils/tools.py) returns a figure, two waveforms and a
    # tag identifying the sample. Judging by the tags used below, one waveform
    # is a "reconstructed" reference and the other is "synthesized" from the
    # model output; presumably both are produced with `vocoder` -- TODO confirm.
    fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
        batch,
        output,
        vocoder,
        model_config,
        preprocess_config,
    )
    # No step argument is passed to log(); the step is encoded in the tag.
    log(
        train_logger,
        fig=fig,
        tag="Training/step_{}_{}".format(step, tag),
    )
    # Audio sampling rate from the preprocessing config, needed to log audio.
    sampling_rate = preprocess_config["preprocessing"]["audio"][
        "sampling_rate"
    ]
    log(
        train_logger,
        audio=wav_reconstruction,
        sampling_rate=sampling_rate,
        tag="Training/step_{}_{}_reconstructed".format(step, tag),
    )
    log(
        train_logger,
        audio=wav_prediction,
        sampling_rate=sampling_rate,
        tag="Training/step_{}_{}_synthesized".format(step, tag),
    )
Step 10 : 当训练步数是预先设定的 val_step 的整数倍时,先把模型切换到 eval 模式,调用 evaluate.py 里的 evaluate 函数进行验证,把结果追加到 ./output/log/AISHELL3/val/log.txt,然后切回 train 模式
# Every val_step steps, run validation and append the result to
# <log_path>/val/log.txt.
if step % val_step == 0:
    # Switch to evaluation mode for validation (affects layers such as dropout).
    model.eval()
    # evaluate() also receives val_logger and the vocoder; it returns a
    # message string that is written to the log file and printed.
    message = evaluate(model, step, configs, val_logger, vocoder)
    with open(os.path.join(val_log_path, "log.txt"), "a") as f:
        f.write(message + "\n")
    outer_bar.write(message)
    # Back to training mode before continuing.
    model.train()
Step 11 : 当训练步数是预先设定的 save_step 的整数倍时,把模型和优化器的状态保存到 ckpt_path 下的 {step}.pth.tar
# Every save_step steps, write a checkpoint to <ckpt_path>/<step>.pth.tar.
if step % save_step == 0:
    torch.save(
        {
            # model is wrapped in nn.DataParallel; .module is the underlying
            # model, so the saved keys carry no "module." prefix.
            "model": model.module.state_dict(),
            # NOTE(review): optimizer is a wrapper (it has step_and_update_lr);
            # _optimizer is presumably the wrapped torch optimizer -- confirm.
            "optimizer": optimizer._optimizer.state_dict(),
        },
        os.path.join(
            train_config["path"]["ckpt_path"],
            "{}.pth.tar".format(step),
        ),
    )
Step 12 : 当训练步数等于预先设定的 total_step 时退出训练;否则 step 加 1 并更新进度条,每遍历完一次 loader 后 epoch 加 1
# Excerpt from the end of the training loop body shown in Step 6. The lines
# below span three nesting levels (indentation lost in this excerpt):
#   while True -> for batchs in loader -> for batch in batchs
# Stop once the configured total number of steps has been reached.
# `quit()` is an interactive-shell helper injected by the `site` module and
# may be unavailable (e.g. under `python -S`); raising SystemExit is the
# program-safe equivalent and likewise ends the process with status 0.
if step == total_step:
    raise SystemExit
step += 1
outer_bar.update(1)
# Presumably one level out (once per loader item), since inner_bar was
# created with total=len(loader).
inner_bar.update(1)
# Presumably at the `while True` level: one full pass over loader = one epoch.
epoch += 1
输出:在 train_log_path 和 val_log_path 下输出日志(log.txt 以及 SummaryWriter 写出的 TensorBoard 事件文件);
在 ckpt_path 下输出训练过程中每隔 save_step 步保存的模型检查点({step}.pth.tar)。