OLMo Code Walkthrough: train.py

The entry point of train.py

if __name__ == "__main__":
    # Initialize process group.
    dist.init_process_group(backend="nccl")

    prepare_cli_environment()

    try:
        yaml_path, args_list = sys.argv[1], sys.argv[2:]
    except IndexError:
        raise OlmoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]")

    cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])
    main(cfg)
  1. dist.init_process_group(backend="nccl")
    Initializes the distributed process group so that multiple GPUs can communicate with each other.
  2. prepare_cli_environment()
    Sets up the command-line environment (e.g., logging).
  3. yaml_path, args_list = sys.argv[1], sys.argv[2:]
    Reads the path of the training configuration file, plus any remaining arguments, from the command line.
  4. cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])
    Loads the chosen configuration file together with the cleaned list of command-line overrides.
  5. main(cfg)
    Finally, passes the loaded configuration into the training function.
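
The details of clean_opt live in OLMo's utilities; as a rough idea, a minimal sketch of such an option-cleaning helper might look like the following (an illustrative assumption, not the actual OLMo implementation). It normalizes CLI overrides such as --optimizer.learning_rate=3e-4 into dotlist entries that a config loader can consume.

# Hypothetical sketch of an option-cleaning helper (not OLMo's actual code).
def clean_opt(opt: str) -> str:
    if "=" not in opt:
        opt = f"{opt}=True"  # treat a bare flag as a boolean override
    name, value = opt.split("=", 1)
    name = name.strip("-").replace("-", "_")  # drop leading dashes, kebab-case -> snake_case
    return f"{name}={value}"

# Hypothetical launch command (paths are illustrative):
#   torchrun --nproc_per_node=8 scripts/train.py path/to/config.yaml --run_name=my_run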

main()

Part 1

    if cfg.run_name is None:
        raise OlmoConfigurationError("--run_name is required")
    log_extra_field("run_name", cfg.run_name)

    # Sanity check
    if (cfg.reset_optimizer_state or cfg.reset_trainer_state) and cfg.load_path is None:
        log.warning(
            "You want to reset the optimizer or trainer state, but we're not loading from the checkpoint. The"
            "setting has no effect."
        )

    barrier()

    # Set CUDA device.
    torch.cuda.set_device(f"cuda:{get_local_rank()}")
    device = torch.device("cuda")

    # Fill some configuration options.
    cfg.model.precision = cfg.precision
    cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size()
    assert cfg.device_train_batch_size is not None  # for mypy
    cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size
    if cfg.optimizer.no_decay_norm_and_bias is not None:
        log.warning(
            "You set the deprecated config option `no_decay_norm_and_bias`. For compatibility, this"
            "setting will take precedence over all other weight decay configurations. Please change"
            "your config to use `decay_norm_and_bias` and `decay_embeddings` instead."
        )
        cfg.optimizer.decay_norm_and_bias = not cfg.optimizer.no_decay_norm_and_bias
        cfg.optimizer.decay_embeddings = not cfg.optimizer.no_decay_norm_and_bias
        cfg.optimizer.no_decay_norm_and_bias = None  # So nobody uses this by accident.
  1. if cfg.run_name is None:
    raise OlmoConfigurationError("--run_name is required")
    log_extra_field("run_name", cfg.run_name)
    Ensures that run_name is set and records it as an extra field in the logs.

  2. if (cfg.reset_optimizer_state or cfg.reset_trainer_state) and cfg.load_path is None:
    log.warning(
    "You want to reset the optimizer or trainer state, but we're not loading from the checkpoint. The"
    "setting has no effect."
    )
    A sanity check: if reset_optimizer_state or reset_trainer_state is set but no load_path is provided, warn the user that these settings have no effect, since nothing is being loaded from a checkpoint.

  3. barrier()
    A synchronization point for distributed training: it blocks until all processes have reached this step (see the sketch of the distributed helpers at the end of this part).

  4. torch.cuda.set_device(f"cuda:{get_local_rank()}")
    device = torch.device("cuda")
    Sets the CUDA device for the current process, using get_local_rank() to obtain the process's local GPU rank, and creates a PyTorch device object for the GPU this process will use.

  5. cfg.model.precision = cfg.precision
    cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size()
    assert cfg.device_train_batch_size is not None  # for mypy
    Copies the precision setting (e.g., fp32 or a mixed-precision mode) into the model config, then computes the per-device training batch size by dividing the global batch size by the total number of processes. The assert exists only to convince the type checker (mypy) that device_train_batch_size is no longer None.
  6. cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size
    Computes the number of gradient-accumulation steps per device: the per-device batch size divided by the microbatch size. For example, a global batch size of 2048 on 64 GPUs gives each device a batch of 32; with a microbatch size of 4, each optimizer step accumulates gradients over 8 microbatches.

  7. if cfg.optimizer.no_decay_norm_and_bias is not None:
    log.warning(
    "You set the deprecated config option `no_decay_norm_and_bias`. For compatibility, this"
    "setting will take precedence over all other weight decay configurations. Please change"
    "your config to use `decay_norm_and_bias` and `decay_embeddings` instead."
    )
    cfg.optimizer.decay_norm_and_bias = not cfg.optimizer.no_decay_norm_and_bias
    cfg.optimizer.decay_embeddings = not cfg.optimizer.no_decay_norm_and_bias
    cfg.optimizer.no_decay_norm_and_bias = None  # So nobody uses this by accident.
    If the deprecated option optimizer.no_decay_norm_and_bias is set, a warning is emitted. For backward compatibility, the new options decay_norm_and_bias and decay_embeddings are derived from it, and the deprecated option is then cleared (set to None) so nothing uses it by accident.
    Weight decay is a regularization technique that adds a penalty to parameter updates during optimization to help prevent overfitting. decay_norm_and_bias and decay_embeddings control which parameter groups that penalty is applied to.

decay_norm_and_bias:
Purpose: controls whether weight decay is applied to bias terms and to the scale and offset parameters of normalization layers.
Behavior: if True, weight decay is applied to these parameters; if False, it is not.
Note: this option exists for finer-grained control, because applying weight decay to normalization parameters is often not the best choice.
decay_embeddings:
Purpose: controls whether weight decay is applied to the parameters of embedding layers.
Behavior: as with decay_norm_and_bias, True applies weight decay to the embedding parameters, False does not.
Note: a separate weight-decay setting for embeddings can help tune training behavior, particularly since embedding layers tend to hold a large share of the parameters and may deserve special treatment.
Together, these two options let users control exactly where weight decay is applied and adjust it to the model architecture and the training task.
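
To make this concrete, here is a generic PyTorch sketch, not OLMo's build_optimizer, of how parameters are commonly split into decay and no-decay groups based on these two options:

import torch
import torch.nn as nn

# Generic sketch: exclude norm/bias (and optionally embedding) parameters from weight decay.
def param_groups(model: nn.Module, weight_decay: float,
                 decay_norm_and_bias: bool = False,
                 decay_embeddings: bool = False):
    decay, no_decay = [], []
    for module in model.modules():
        for name, p in module.named_parameters(recurse=False):
            if not p.requires_grad:
                continue
            if isinstance(module, nn.LayerNorm) or name == "bias":
                (decay if decay_norm_and_bias else no_decay).append(p)
            elif isinstance(module, nn.Embedding):
                (decay if decay_embeddings else no_decay).append(p)
            else:
                decay.append(p)
    return [{"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0}]

# Toy usage:
model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16), nn.LayerNorm(16))
optim = torch.optim.AdamW(param_groups(model, weight_decay=0.1), lr=3e-4)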
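
One more note on this part: the helpers barrier(), get_local_rank(), get_global_rank(), and get_world_size() used throughout the file are, presumably, thin wrappers over torch.distributed and the environment variables set by torchrun. A sketch under that assumption (OLMo's real versions may differ):

import os
import torch.distributed as dist

def get_local_rank() -> int:
    return int(os.environ.get("LOCAL_RANK", 0))  # set by torchrun per process

def get_global_rank() -> int:
    return dist.get_rank() if dist.is_initialized() else 0

def get_world_size() -> int:
    return dist.get_world_size() if dist.is_initialized() else 1

def barrier() -> None:
    if dist.is_initialized():
        dist.barrier()  # block until every rank reaches this point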

Part 2

    # Display and save configuration.
    if get_global_rank() == 0:
        if cfg.data.paths is not None and len(cfg.data.paths) < 50:
            log.info("Configuration:")
            log.info(cfg)
        if not cfg.dry_run and (cfg.load_path is None or Path(cfg.load_path).parent != Path(cfg.save_folder)):
            # Save config.
            save_path = Path(cfg.save_folder) / "config.yaml"
            if save_path.is_file() and not cfg.save_overwrite:
                raise OlmoConfigurationError(f"{save_path} already exists, use --save_overwrite to overwrite")
            else:
                log.info(f"Saving config to {save_path}")
                save_path.parent.mkdir(exist_ok=True, parents=True)
                cfg.save(save_path)
            del save_path

    barrier()

    # Maybe start W&B run.
    if cfg.wandb is not None and (get_global_rank() == 0 or not cfg.wandb.rank_zero_only):
        wandb_dir = Path(cfg.save_folder) / "wandb"
        wandb_dir.mkdir(parents=True, exist_ok=True)
        wandb.init(
            dir=wandb_dir,
            project=cfg.wandb.project,
            entity=cfg.wandb.entity,
            group=cfg.wandb.group,
            name=cfg.wandb.name,
            tags=cfg.wandb.tags,
            config=cfg.asdict(exclude=["wandb"]),
        )

    barrier()

    # Set seed.
    seed_all(cfg.seed)
  1. Display and save the configuration
# Display and save configuration.
if get_global_rank() == 0:
    if cfg.data.paths is not None and len(cfg.data.paths) < 50:
        log.info("Configuration:")
        log.info(cfg)
    if not cfg.dry_run and (cfg.load_path is None or Path(cfg.load_path).parent != Path(cfg.save_folder)):
        # Save config.
        save_path = Path(cfg.save_folder) / "config.yaml"
        if save_path.is_file() and not cfg.save_overwrite:
            raise OlmoConfigurationError(f"{save_path} already exists, use --save_overwrite to overwrite")
        else:
            log.info(f"Saving config to {save_path}")
            save_path.parent.mkdir(exist_ok=True, parents=True)
            cfg.save(save_path)
        del save_path
barrier()

This block runs only on the global rank-0 process (get_global_rank() == 0) and is responsible for displaying and saving the configuration:
If the config contains data paths and there are fewer than 50 of them, log the configuration.
If this is not a dry run (cfg.dry_run is False) and either no load path is set or the load path's parent directory differs from the save folder, save the config file to the save folder.
If the config file already exists at the save path and overwriting is not allowed (cfg.save_overwrite is False), raise a configuration error.
Otherwise, save the config to that path and then delete the temporary variable.
Finally, barrier() ensures every process reaches this synchronization point before continuing.
Dry run: executing the program so that it only prints information without doing the real work. In deep learning, a dry run can verify that the configuration is correct, the data loads properly, and the network is structured as expected, without actually training.
  2. W&B setup

# Maybe start W&B run.
if cfg.wandb is not None and (get_global_rank() == 0 or not cfg.wandb.rank_zero_only):
    wandb_dir = Path(cfg.save_folder) / "wandb"
    wandb_dir.mkdir(parents=True, exist_ok=True)
    wandb.init(
        dir=wandb_dir,
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
        group=cfg.wandb.group,
        name=cfg.wandb.name,
        tags=cfg.wandb.tags,
        config=cfg.asdict(exclude=["wandb"]),
    )

A W&B run is started if W&B is enabled in the config (cfg.wandb is not None) and either the current process is the global rank-0 process (get_global_rank() == 0) or the config does not restrict W&B to rank 0 (cfg.wandb.rank_zero_only is False).
The code creates the W&B directory and initializes the W&B run, passing in the full configuration (excluding the wandb section itself).

  3. Set the random seed
    seed_all(cfg.seed)
    Seeds the random number generators with the seed value from the config, which helps make training reproducible.
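
seed_all presumably seeds every RNG that training touches. A plausible sketch (the actual OLMo helper may differ in details):

import random
import numpy as np
import torch

def seed_all(seed: int) -> None:
    random.seed(seed)                 # Python's built-in RNG
    np.random.seed(seed)              # NumPy's global RNG
    torch.manual_seed(seed)           # PyTorch CPU and current CUDA RNGs
    torch.cuda.manual_seed_all(seed)  # all CUDA devices, explicitly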

Part 3

    # Construct data loader.
    train_loader = build_train_dataloader(cfg)

    # Construct evaluators.
    evaluators = build_evaluators(cfg, device)
    barrier()

    # Initialize the model.
    log.info("Building model...")
    olmo_model = Olmo(cfg.model)
    log.info(f"Total number of parameters: {olmo_model.num_params():,d}")
    log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}")
    log.info(f"Peak GPU Memory (MB) before FSDP: {int(peak_gpu_memory() or 0)}")

    olmo_model.set_activation_checkpointing(cfg.activation_checkpointing)

    # Wrap the model in FSDP.
    log.info("Wrapping model with FDSP...")
    wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)
    if version.parse(torch.__version__) >= version.parse("2.1.0"):
        # This prevents any parameters from being initialized twice
        def dummy_init_fn(module: torch.nn.Module) -> None:
            module.to_empty(device=get_default_device())

        param_init_fn = dummy_init_fn
    else:
        param_init_fn = None
    fsdp_model = FSDP(
        olmo_model,
        sharding_strategy=cfg.fsdp.sharding_strategy,
        mixed_precision=cfg.fsdp_precision,
        auto_wrap_policy=wrap_policy,
        use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
        limit_all_gathers=True,
        device_id=get_local_rank(),
        param_init_fn=param_init_fn,
    )
    # when param_init_fn is None, FSDP will call reset_parameters() automatically
    if param_init_fn is not None:
        olmo_model.reset_parameters()

    log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")
    log.info("Model:")
    log.info(fsdp_model)

    # Construct optimizer and learning rate scheduler.
    optim = build_optimizer(cfg, fsdp_model)
    scheduler = build_scheduler(cfg)

    # Data indices file.
    indices_file: Optional[TextIO] = None
    if cfg.save_data_indices:
        indices_file_path = Path(cfg.save_folder) / f"data-indices/rank{get_global_rank()}.tsv.gz"
        if indices_file_path.exists() and not cfg.save_overwrite:
            raise OlmoConfigurationError(f"{indices_file_path} already exists, use --save_overwrite to overwrite")
        indices_file_path.parent.mkdir(exist_ok=True, parents=True)
        indices_file = gzip.open(indices_file_path, "wt")
  1. train_loader = build_train_dataloader(cfg)
    Builds the training data loader. build_train_dataloader constructs the loader from the parameters in cfg; the details live inside that function.
  2. evaluators = build_evaluators(cfg, device)
    Builds the evaluators. build_evaluators constructs one or more evaluators for model evaluation from the config and the given device.
  3. barrier()
    Synchronizes all processes.
  4. Build the model.
# Initialize the model.
log.info("Building model...")
olmo_model = Olmo(cfg.model)
log.info(f"Total number of parameters: {olmo_model.num_params():,d}")
log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}")
log.info(f"Peak GPU Memory (MB) before FSDP: {int(peak_gpu_memory() or 0)}")

Constructs an Olmo model (a class) from the model section of the config, and logs information about the parameter counts and GPU memory.

  5. olmo_model.set_activation_checkpointing(cfg.activation_checkpointing)
    Enables activation checkpointing according to the config. Activation checkpointing is an optimization technique that trades compute for memory, which matters for deep models, where stored activations grow with depth and can dominate memory use: instead of keeping every intermediate activation for the backward pass, checkpoints are stored at selected points in the forward pass, and the activations in between are recomputed when needed. (A generic PyTorch sketch follows after the FSDP code below.)
  6. Wrap the model in FSDP (Fully Sharded Data Parallel), which shards parameters, gradients, and optimizer state across ranks. On PyTorch >= 2.1 a dummy param_init_fn that simply calls module.to_empty() is supplied so parameters are not initialized twice; in that case olmo_model.reset_parameters() is called manually afterwards, whereas with param_init_fn=None FSDP calls reset_parameters() automatically.
# Wrap the model in FSDP.
    log.info("Wrapping model with FDSP...")
    wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)
    if version.parse(torch.__version__) >= version.parse("2.1.0"):
        # This prevents any parameters from being initialized twice
        def dummy_init_fn(module: torch.nn.Module) -> None:
            module.to_empty(device=get_default_device())

        param_init_fn = dummy_init_fn
    else:
        param_init_fn = None
    fsdp_model = FSDP(
        olmo_model,
        sharding_strategy=cfg.fsdp.sharding_strategy,
        mixed_precision=cfg.fsdp_precision,
        auto_wrap_policy=wrap_policy,
        use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
        limit_all_gathers=True,
        device_id=get_local_rank(),
        param_init_fn=param_init_fn,
    )
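
As promised above, here is a generic PyTorch illustration of activation checkpointing (not OLMo's internal implementation): the wrapped sub-module's activations are discarded after the forward pass and recomputed during backward.

import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self, dim: int = 128):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Activations inside self.ff are not stored; they are rebuilt during backward.
        return checkpoint(self.ff, x, use_reentrant=False)

x = torch.randn(2, 128, requires_grad=True)
Block()(x).sum().backward()  # self.ff's forward runs a second time here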

  7. optim = build_optimizer(cfg, fsdp_model)
    scheduler = build_scheduler(cfg)
    Builds the optimizer and the learning-rate scheduler: from the parameters in cfg and the FSDP-wrapped model, this constructs the optimizer that updates the model parameters and the scheduler that adjusts the learning rate over the course of training.
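
For intuition, a hedged sketch of the kind of thing such builders produce for LLM pretraining, assuming AdamW with linear warmup followed by cosine decay (the real hyperparameters and schedule come from cfg):

import math
import torch

def make_optimizer_and_scheduler(model, lr=3e-4, weight_decay=0.1,
                                 t_warmup=2000, t_max=100_000):
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    def lr_lambda(step: int) -> float:
        if step < t_warmup:
            return step / max(1, t_warmup)  # linear warmup from 0 to lr
        progress = (step - t_warmup) / max(1, t_max - t_warmup)
        return 0.5 * (1 + math.cos(math.pi * min(1.0, progress)))  # cosine decay

    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda)
    return optim, scheduler  # call scheduler.step() once per training step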
  8. Data indices file

# Data indices file.
indices_file: Optional[TextIO] = None
if cfg.save_data_indices:
    indices_file_path = Path(cfg.save_folder) / f"data-indices/rank{get_global_rank()}.tsv.gz"
    if indices_file_path.exists() and not cfg.save_overwrite:
        raise OlmoConfigurationError(f"{indices_file_path} already exists, use --save_overwrite to overwrite")
    indices_file_path.parent.mkdir(exist_ok=True, parents=True)
    indices_file = gzip.open(indices_file_path, "wt")

If the config enables saving data indices (cfg.save_data_indices), a file is created for them. If the file already exists and overwriting is not allowed, a configuration error is raised. The file path includes the process's global rank, so multiple processes never write to the same file.
Data indices typically record which training examples were used; what they are for depends on the context. Some common uses:
Data analysis and visualization: with saved indices you can analyze, after training, which samples were used and how they were distributed, which helps diagnose model behavior on different kinds of data and check dataset quality.
Reproducing experiments: reproducibility matters in research. Saved indices make it possible to reconstruct the same training stream at a later time or in a different environment, which is useful for verifying consistency and stability.
Debugging: when debugging a model you sometimes need to inspect the input or output for specific samples. Saved indices let you trace individual examples and check their inputs, outputs, or labels.
Bookkeeping in distributed training: each process consumes a different shard of the data, so saving the indices per rank records exactly which process saw which examples.
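For illustration only (the exact format OLMo writes may differ), such an index log could look roughly like this: one tab-separated row per step, listing the global dataset indices that rank consumed.

import gzip
from pathlib import Path

path = Path("data-indices/rank0.tsv.gz")
path.parent.mkdir(parents=True, exist_ok=True)
with gzip.open(path, "wt") as f:
    for step, batch_indices in enumerate([[3, 14], [15, 92]]):  # toy batches
        f.write(f"{step}\t" + "\t".join(map(str, batch_indices)) + "\n")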

Part 4

This part drives the whole training flow: saving and loading checkpoints, compiling the model, and running the actual training.

    with Trainer(
        cfg=cfg,
        epoch=cfg.epoch,
        model=olmo_model,
        fsdp_model=fsdp_model,
        optim=optim,
        scheduler=scheduler,
        train_loader=train_loader,
        device=device,
        evaluators=evaluators,
        indices_file=indices_file,
    ) as trainer:
        if not cfg.dry_run and not cfg.no_pre_train_checkpoint and cfg.load_path is None:
            checkpoint_type = (
                CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded
            )

            # We save a checkpoint up-front to make sure this won't fail (due to disk space or whatever).
            log.info("Saving pre-train checkpoint...")
            checkpoint_path, local_checkpoint_cache = trainer.save_checkpoint(checkpoint_type=checkpoint_type)
            log.info(f"Checkpoint saved to {checkpoint_path}")

            # And then we verify that we can load it.
            log.info("Attempting to load pre-train checkpoint...")
            trainer.restore_checkpoint(
                checkpoint_path, checkpoint_type=checkpoint_type, local_cache=local_checkpoint_cache
            )
            log.info("Checkpoint successfully loaded")

            # NOTE: https://github.com/allenai/LLM/issues/233
            #  log.info("Removing pre-train checkpoint...")
            #  trainer.remove_checkpoint(checkpoint_type=checkpoint_type)
            #  log.info("Successfully removed checkpoint")

        if cfg.load_path is not None:
            log.info(f"Loading checkpoint from {cfg.load_path}...")
            trainer.restore_checkpoint(
                cfg.load_path,
                load_optimizer_state=not cfg.reset_optimizer_state,
                load_trainer_state=not cfg.reset_trainer_state,
                sharded_checkpointer=cfg.load_path_sharded_checkpointer,
            )
            log.info("Checkpoint successfully loaded")

            # If we have to, set a new scheduler:
            if cfg.reset_optimizer_state and not cfg.reset_trainer_state:
                trainer.scheduler = BoltOnWarmupScheduler.wrap(
                    trainer.scheduler,
                    trainer.global_step,
                    int(trainer.global_step + cfg.scheduler.t_warmup),
                )

        if cfg.force_save_unsharded:
            log.info("Saving unsharded checkpoint...")
            checkpoint_path, _ = trainer.save_checkpoint(checkpoint_type=CheckpointType.unsharded)
            log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

        if cfg.compile is not None:
            # TODO (epwalsh): trying to compile the whole train step results in a compile-time error from within
            # the optimizer. We should investigate this further at some point.
            #  trainer.train_step = torch.compile(trainer.train_step, **cfg.compile.asdict())
            trainer.train_batch = torch.compile(trainer.train_batch, **cfg.compile.asdict())  # type: ignore
            # TODO (epwalsh): compiling the `eval_batch()` method is a little sketchy since the inputs will look
            # different for different eval tasks. That might be okay, but it might not be.
            #  trainer.eval_batch = torch.compile(trainer.eval_batch, **cfg.compile.asdict())  # type: ignore
            # Alternatively, could just do this:
            #  trainer.fsdp_model = torch.compile(trainer.fsdp_model, **cfg.compile.asdict())

        if not cfg.dry_run:
            log.info("Starting training...")
            trainer.fit()
            log.info("Training complete")
        else:
            log.info("Dry run complete")
  1. with Trainer(
    cfg=cfg,
    epoch=cfg.epoch,
    model=olmo_model,
    fsdp_model=fsdp_model,
    optim=optim,
    scheduler=scheduler,
    train_loader=train_loader,
    device=device,
    evaluators=evaluators,
    indices_file=indices_file,
    ) as trainer:
    Runs training through a Trainer object. The Python with statement guarantees that trainer cleanup (trainer.close()) happens when the block is left, even on error. The Trainer object manages the entire training process.
  2. Save and verify a pre-train checkpoint
    if not cfg.dry_run and not cfg.no_pre_train_checkpoint and cfg.load_path is None:
        checkpoint_type = (
            CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded
        )

        # We save a checkpoint up-front to make sure this won't fail (due to disk space or whatever).
        log.info("Saving pre-train checkpoint...")
        checkpoint_path, local_checkpoint_cache = trainer.save_checkpoint(checkpoint_type=checkpoint_type)
        log.info(f"Checkpoint saved to {checkpoint_path}")

        # And then we verify that we can load it.
        log.info("Attempting to load pre-train checkpoint...")
        trainer.restore_checkpoint(
            checkpoint_path, checkpoint_type=checkpoint_type, local_cache=local_checkpoint_cache
        )
        log.info("Checkpoint successfully loaded")

        # NOTE: https://github.com/allenai/LLM/issues/233
        #  log.info("Removing pre-train checkpoint...")
        #  trainer.remove_checkpoint(checkpoint_type=checkpoint_type)
        #  log.info("Successfully removed checkpoint")

If this is not a dry run (cfg.dry_run), pre-train checkpointing is not disabled, and no load path was given:
Choose sharded or unsharded checkpointing, depending on how many checkpoints the config keeps.
Save a checkpoint up front, to make sure that saving will not fail later (due to disk space or anything else).
Verify that the saved checkpoint can be loaded back.
  3. Resume from a checkpoint

    if cfg.load_path is not None:
        log.info(f"Loading checkpoint from {cfg.load_path}...")
        trainer.restore_checkpoint(
            cfg.load_path,
            load_optimizer_state=not cfg.reset_optimizer_state,
            load_trainer_state=not cfg.reset_trainer_state,
            sharded_checkpointer=cfg.load_path_sharded_checkpointer,
        )
        log.info("Checkpoint successfully loaded")

        # If we have to, set a new scheduler:
        if cfg.reset_optimizer_state and not cfg.reset_trainer_state:
            trainer.scheduler = BoltOnWarmupScheduler.wrap(
                trainer.scheduler,
                trainer.global_step,
                int(trainer.global_step + cfg.scheduler.t_warmup),
            )

If a load path is given (cfg.load_path, i.e., we are resuming training), the model checkpoint is loaded from it. Depending on the config, the optimizer state and trainer state may or may not be loaded as well, and a sharded checkpoint loader may be used. If the optimizer state is reset but the trainer state is not, a new scheduler is bolted on: BoltOnWarmupScheduler wraps the existing scheduler so the learning rate warms up again from the current global step over cfg.scheduler.t_warmup steps.
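
A minimal sketch of the bolt-on warmup idea, assuming the scheduler exposes a get_lr(initial_lr, step) interface (OLMo's actual BoltOnWarmupScheduler may differ): the wrapper returns 0 before the warmup window, linearly scales the inner scheduler's LR inside it, and defers to the inner scheduler afterwards.

class BoltOnWarmup:
    def __init__(self, inner, warmup_start: int, warmup_end: int):
        self.inner = inner
        self.warmup_start = warmup_start
        self.warmup_end = warmup_end

    def get_lr(self, initial_lr: float, step: int) -> float:
        if step < self.warmup_start:
            return 0.0
        if step < self.warmup_end:
            frac = (step - self.warmup_start) / (self.warmup_end - self.warmup_start)
            return frac * self.inner.get_lr(initial_lr, step)
        return self.inner.get_lr(initial_lr, step)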
  4. if cfg.force_save_unsharded:
    log.info("Saving unsharded checkpoint...")
    checkpoint_path, _ = trainer.save_checkpoint(checkpoint_type=CheckpointType.unsharded)
    log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
    If the config forces saving an unsharded checkpoint (cfg.force_save_unsharded), one is saved here.
  5. trainer.train_batch = torch.compile(trainer.train_batch, **cfg.compile.asdict())  # type: ignore
    If compilation options are set in the config (cfg.compile), the train_batch method is compiled with torch.compile as a speed optimization; the comments in the code explain why the whole train step and eval_batch are not compiled.
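
In general, torch.compile (PyTorch >= 2.0) captures a function's graph on the first call and reuses the compiled code on later calls; a generic example:

import torch

def train_batch(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x) * 2.0

train_batch = torch.compile(train_batch)  # compiled lazily, on first invocation
out = train_batch(torch.randn(8, 16))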
  6. Start training, or finish the dry run

    if not cfg.dry_run:
        log.info("Starting training...")
        trainer.fit()
        log.info("Training complete")
    else:
        log.info("Dry run complete")

If this is not a dry run, training starts via trainer.fit() and a completion message is logged; otherwise only a dry-run-complete message is printed.
