if __name__ == "__main__":
    # Initialize process group.
    dist.init_process_group(backend="nccl")

    prepare_cli_environment()

    try:
        yaml_path, args_list = sys.argv[1], sys.argv[2:]
    except IndexError:
        raise OlmoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]")

    cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])
    main(cfg)

The entry point initializes the NCCL process group, prepares the CLI environment, parses the config path and any option overrides, loads the TrainConfig, and hands it to main(cfg). Everything that follows is the body of main(cfg).
if cfg.run_name is None:
    raise OlmoConfigurationError("--run_name is required")
log_extra_field("run_name", cfg.run_name)

# Sanity check
if (cfg.reset_optimizer_state or cfg.reset_trainer_state) and cfg.load_path is None:
    log.warning(
        "You want to reset the optimizer or trainer state, but we're not loading from the checkpoint. The"
        "setting has no effect."
    )

barrier()

# Set CUDA device.
torch.cuda.set_device(f"cuda:{get_local_rank()}")
device = torch.device("cuda")

# Fill some configuration options.
cfg.model.precision = cfg.precision
cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size()
assert cfg.device_train_batch_size is not None  # for mypy
cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size

if cfg.optimizer.no_decay_norm_and_bias is not None:
    log.warning(
        "You set the deprecated config option `no_decay_norm_and_bias`. For compatibility, this"
        "setting will take precedence over all other weight decay configurations. Please change"
        "your config to use `decay_norm_and_bias` and `decay_embeddings` instead."
    )
    cfg.optimizer.decay_norm_and_bias = not cfg.optimizer.no_decay_norm_and_bias
    cfg.optimizer.decay_embeddings = not cfg.optimizer.no_decay_norm_and_bias
    cfg.optimizer.no_decay_norm_and_bias = None  # So nobody uses this by accident.
Going through this setup piece by piece:

if cfg.run_name is None:
    raise OlmoConfigurationError("--run_name is required")
log_extra_field("run_name", cfg.run_name)
This ensures a run_name is set and records it as an extra field in the logs.
if (cfg.reset_optimizer_state or cfg.reset_trainer_state) and cfg.load_path is None:
    log.warning(
        "You want to reset the optimizer or trainer state, but we're not loading from the checkpoint. The"
        "setting has no effect."
    )
This is a sanity check: if reset_optimizer_state or reset_trainer_state is set but no load_path is given, the user is warned that the setting has no effect, because nothing is being loaded from a checkpoint anyway.
barrier()
barrier() is a synchronization point for distributed training: it makes sure every process has reached this step before any of them continues.
torch.cuda.set_device(f"cuda:{get_local_rank()}")
device = torch.device("cuda")
torch.cuda.set_device picks the GPU for the current process based on its local rank (get_local_rank()), and device is the PyTorch device object this process will use from here on.
cfg.model.precision = cfg.precision
cfg.device_train_batch_size = cfg.global_train_batch_size // get_world_size()
assert cfg.device_train_batch_size is not None # for mypy
The model precision is copied from the top-level precision setting, and the per-device training batch size is computed by dividing the global batch size by the world size (the total number of processes). The assert is only there to convince mypy that the value is not None.
cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size
The number of gradient-accumulation steps per device is the per-device batch size divided by the micro-batch size.
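As a concrete illustration, here is the arithmetic with made-up numbers (these values are not taken from any OLMo config):

# Hypothetical numbers, purely to illustrate how the batch-size fields relate.
global_train_batch_size = 2048      # examples per optimizer step across the whole cluster
world_size = 64                     # number of GPUs / processes
device_train_microbatch_size = 8    # examples per forward/backward pass on one GPU

device_train_batch_size = global_train_batch_size // world_size                    # 2048 // 64 = 32
device_train_grad_accum = device_train_batch_size // device_train_microbatch_size  # 32 // 8 = 4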
If the deprecated option no_decay_norm_and_bias is set, the script logs the deprecation warning shown in the block above, maps it onto the newer decay_norm_and_bias and decay_embeddings options (which then take precedence over any other weight decay configuration), and finally clears it so nothing else uses it by accident.

decay_norm_and_bias:
Purpose: controls whether weight decay is applied to bias terms and to the scale/shift parameters of normalization layers.
Behaviour: if True, weight decay is applied to these parameters; if False, they are excluded.
Note: this gives finer-grained control, since applying weight decay to normalization parameters is often not the best choice.

decay_embeddings:
Purpose: controls whether weight decay is applied to the parameters of the embedding layers.
Behaviour: like decay_norm_and_bias, True applies weight decay to the embedding parameters and False excludes them.
Note: a separate setting for embeddings can help tune training behaviour, especially since embedding matrices are large and sometimes need special treatment.

Together, these two options let you control exactly which parameter groups weight decay applies to, depending on the model architecture and the training task.
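To make the effect of these flags concrete, here is a minimal sketch of how parameters are typically split into decay / no-decay groups before being handed to the optimizer. This illustrates the idea only; it is not OLMo's build_optimizer implementation:

import torch
import torch.nn as nn

def build_param_groups(model: nn.Module, weight_decay: float,
                       decay_norm_and_bias: bool, decay_embeddings: bool):
    # Split parameters into a group that receives weight decay and one that does not,
    # based on the two flags above.
    decay, no_decay = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if isinstance(module, nn.LayerNorm) or name.endswith("bias"):
                (decay if decay_norm_and_bias else no_decay).append(param)
            elif isinstance(module, nn.Embedding):
                (decay if decay_embeddings else no_decay).append(param)
            else:
                decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Example: exclude norm/bias parameters from weight decay, but decay embeddings.
model = nn.Sequential(nn.Embedding(100, 16), nn.LayerNorm(16), nn.Linear(16, 16))
optim = torch.optim.AdamW(build_param_groups(model, 0.1, False, True), lr=1e-4)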
# Display and save configuration.
if get_global_rank() == 0:
    if cfg.data.paths is not None and len(cfg.data.paths) < 50:
        log.info("Configuration:")
        log.info(cfg)
    if not cfg.dry_run and (cfg.load_path is None or Path(cfg.load_path).parent != Path(cfg.save_folder)):
        # Save config.
        save_path = Path(cfg.save_folder) / "config.yaml"
        if save_path.is_file() and not cfg.save_overwrite:
            raise OlmoConfigurationError(f"{save_path} already exists, use --save_overwrite to overwrite")
        else:
            log.info(f"Saving config to {save_path}")
            save_path.parent.mkdir(exist_ok=True, parents=True)
            cfg.save(save_path)
        del save_path

barrier()

# Maybe start W&B run.
if cfg.wandb is not None and (get_global_rank() == 0 or not cfg.wandb.rank_zero_only):
    wandb_dir = Path(cfg.save_folder) / "wandb"
    wandb_dir.mkdir(parents=True, exist_ok=True)
    wandb.init(
        dir=wandb_dir,
        project=cfg.wandb.project,
        entity=cfg.wandb.entity,
        group=cfg.wandb.group,
        name=cfg.wandb.name,
        tags=cfg.wandb.tags,
        config=cfg.asdict(exclude=["wandb"]),
    )

barrier()

# Set seed.
seed_all(cfg.seed)
This block runs only on the global rank-0 process (get_global_rank() == 0) and is responsible for displaying and saving the configuration. If data paths are present and there are fewer than 50 of them, the full configuration is printed to the log. If this is not a dry run and there is either no load path or the load path's parent directory differs from the save folder, the configuration is written to config.yaml in the save folder. If that file already exists and cfg.save_overwrite is False, an OlmoConfigurationError is raised; otherwise the config is saved and the temporary save_path variable is deleted. A barrier() then makes sure every process reaches this point before continuing, to avoid ranks racing ahead of each other.
A dry run means the script only prints information and performs setup without doing the actual work. In training scripts it is a convenient way to verify that the configuration, data loading, and model construction all behave as expected before launching a real run.
W&B configuration
If W&B is enabled in the config (cfg.wandb is not None), and either the current process is global rank 0 or the config does not restrict logging to rank 0 (cfg.wandb.rank_zero_only is False), the script creates a wandb directory under the save folder and initializes a W&B run, passing the full configuration (minus the wandb section) as the run's config. Another barrier() follows so all ranks wait for the run to be set up, and seed_all(cfg.seed) then seeds the random number generators.
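The rank_zero_only pattern is worth spelling out. A minimal sketch (with hypothetical helper names, not part of OLMo) of gating W&B calls on the process rank:

import wandb

# Only one process creates the W&B run and logs metrics, so each metric is
# reported once rather than once per GPU. Helper names here are made up.
def maybe_init_wandb(rank: int, rank_zero_only: bool = True) -> None:
    if rank == 0 or not rank_zero_only:
        wandb.init(project="my-project", name=f"rank{rank}")

def maybe_log(rank: int, metrics: dict, rank_zero_only: bool = True) -> None:
    if rank == 0 or not rank_zero_only:
        wandb.log(metrics)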
# Construct data loader.
train_loader = build_train_dataloader(cfg)

# Construct evaluators.
evaluators = build_evaluators(cfg, device)
barrier()

# Initialize the model.
log.info("Building model...")
olmo_model = Olmo(cfg.model)
log.info(f"Total number of parameters: {olmo_model.num_params():,d}")
log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embedding=False):,d}")
log.info(f"Peak GPU Memory (MB) before FSDP: {int(peak_gpu_memory() or 0)}")

olmo_model.set_activation_checkpointing(cfg.activation_checkpointing)

# Wrap the model in FSDP.
log.info("Wrapping model with FDSP...")
wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)

if version.parse(torch.__version__) >= version.parse("2.1.0"):
    # This prevents any parameters from being initialized twice
    def dummy_init_fn(module: torch.nn.Module) -> None:
        module.to_empty(device=get_default_device())

    param_init_fn = dummy_init_fn
else:
    param_init_fn = None

fsdp_model = FSDP(
    olmo_model,
    sharding_strategy=cfg.fsdp.sharding_strategy,
    mixed_precision=cfg.fsdp_precision,
    auto_wrap_policy=wrap_policy,
    use_orig_params=cfg.fsdp.use_orig_params,  # needed for compile and some of our optimizer/parameter metrics
    limit_all_gathers=True,
    device_id=get_local_rank(),
    param_init_fn=param_init_fn,
)
# when param_init_fn is None, FSDP will call reset_parameters() automatically
if param_init_fn is not None:
    olmo_model.reset_parameters()

log.info(f"Peak GPU Memory (MB) after FSDP: {int(peak_gpu_memory() or 0)}")
log.info("Model:")
log.info(fsdp_model)

# Construct optimizer and learning rate scheduler.
optim = build_optimizer(cfg, fsdp_model)
scheduler = build_scheduler(cfg)

# Data indices file.
indices_file: Optional[TextIO] = None
if cfg.save_data_indices:
    indices_file_path = Path(cfg.save_folder) / f"data-indices/rank{get_global_rank()}.tsv.gz"
    if indices_file_path.exists() and not cfg.save_overwrite:
        raise OlmoConfigurationError(f"{indices_file_path} already exists, use --save_overwrite to overwrite")
    indices_file_path.parent.mkdir(exist_ok=True, parents=True)
    indices_file = gzip.open(indices_file_path, "wt")
Olmo(cfg.model) builds the model (a class) from the model section of the config; the script then logs the total parameter count, the non-embedding parameter count, and the peak GPU memory before FSDP wrapping, and enables activation checkpointing if the config asks for it.
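For reference, a minimal sketch of what a parameter count that excludes embeddings can look like. This mirrors the idea behind Olmo.num_params(include_embedding=False), not its actual implementation:

import torch.nn as nn

def num_params(model: nn.Module, include_embedding: bool = True) -> int:
    # Sum parameter element counts, optionally skipping nn.Embedding weights.
    total = 0
    for module in model.modules():
        if not include_embedding and isinstance(module, nn.Embedding):
            continue
        total += sum(p.numel() for p in module.parameters(recurse=False))
    return total

toy = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 64))
print(num_params(toy), num_params(toy, include_embedding=False))  # 68160 4160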
The model is then wrapped in FSDP. The wrap policy comes from olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy), and the sharding strategy, mixed precision settings, use_orig_params flag (needed for torch.compile and some of the optimizer/parameter metrics), and device id all come from the config. On PyTorch 2.1.0 and newer a dummy param_init_fn is passed that simply moves each module onto the target device with to_empty(), which prevents parameters from being initialized twice; on older versions param_init_fn is None and FSDP calls reset_parameters() automatically, so the script only calls olmo_model.reset_parameters() itself when a param_init_fn was supplied.
optim = build_optimizer(cfg, fsdp_model)
scheduler = build_scheduler(cfg)
The optimizer and learning-rate scheduler come next: from the config and the FSDP-wrapped model, build_optimizer creates the optimizer that will update the model parameters, and build_scheduler creates the scheduler that adjusts the learning rate over the course of training; a sketch of a typical warmup-plus-decay schedule follows.
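As an illustration of what a schedule of this kind typically does, here is a minimal warmup-plus-cosine-decay sketch built with PyTorch's LambdaLR. It is not OLMo's build_scheduler, and t_warmup / t_max are just illustrative names:

import math
import torch

def make_scheduler(optimizer: torch.optim.Optimizer, t_warmup: int, t_max: int):
    # Returns a multiplier on the base learning rate: linear warmup for t_warmup
    # steps, then cosine decay from 1 down to 0 at t_max.
    def lr_lambda(step: int) -> float:
        if step < t_warmup:
            return step / max(1, t_warmup)
        progress = (step - t_warmup) / max(1, t_max - t_warmup)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=3e-4)
sched = make_scheduler(opt, t_warmup=2_000, t_max=100_000)  # call sched.step() once per training step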
Data indices
If cfg.save_data_indices is set, a gzip-compressed TSV file is opened for recording data indices (see the sketch after this list). If the file already exists and overwriting is not allowed, a configuration error is raised. The file path includes the process's global rank so that different processes never write to the same file.
Data indices are typically used to record which training examples were consumed; what they are good for depends on the context. Some common uses:
Data analysis and visualization: saved indices let you analyse, after training, which samples were used and how they were distributed, which helps diagnose model behaviour on different kinds of data and check dataset quality.
Reproducing experiments: reproducibility matters in research, and saved indices make it possible to reconstruct the same training stream at a later time or in a different environment, which helps verify consistency and stability.
Debugging: when debugging a model you sometimes need to inspect specific samples; saved indices let you trace back to the exact example and examine its inputs, outputs, or labels.
Bookkeeping in distributed training: each rank consumes its own subset of the data, and recording the indices each rank saw makes it easy to audit how the dataset was partitioned across processes.
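Here is a minimal, hypothetical sketch of what such an index file can contain (one line per step: the step number followed by the global example indices in that rank's batch). The exact format OLMo's Trainer writes may differ:

import gzip
from pathlib import Path

path = Path("data-indices/rank0.tsv.gz")
path.parent.mkdir(exist_ok=True, parents=True)
with gzip.open(path, "wt") as indices_file:
    step, global_indices = 0, [12, 7, 93, 41]
    # One tab-separated line per training step.
    indices_file.write(f"{step}\t" + "\t".join(str(i) for i in global_indices) + "\n")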
The Trainer context manager below drives the rest of the run: saving and loading checkpoints, optionally compiling the model, and executing the actual training loop.
with Trainer(
    cfg=cfg,
    epoch=cfg.epoch,
    model=olmo_model,
    fsdp_model=fsdp_model,
    optim=optim,
    scheduler=scheduler,
    train_loader=train_loader,
    device=device,
    evaluators=evaluators,
    indices_file=indices_file,
) as trainer:
    if not cfg.dry_run and not cfg.no_pre_train_checkpoint and cfg.load_path is None:
        checkpoint_type = (
            CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded
        )

        # We save a checkpoint up-front to make sure this won't fail (due to disk space or whatever).
        log.info("Saving pre-train checkpoint...")
        checkpoint_path, local_checkpoint_cache = trainer.save_checkpoint(checkpoint_type=checkpoint_type)
        log.info(f"Checkpoint saved to {checkpoint_path}")

        # And then we verify that we can load it.
        log.info("Attempting to load pre-train checkpoint...")
        trainer.restore_checkpoint(
            checkpoint_path, checkpoint_type=checkpoint_type, local_cache=local_checkpoint_cache
        )
        log.info("Checkpoint successfully loaded")

        # NOTE: https://github.com/allenai/LLM/issues/233
        # log.info("Removing pre-train checkpoint...")
        # trainer.remove_checkpoint(checkpoint_type=checkpoint_type)
        # log.info("Successfully removed checkpoint")

    if cfg.load_path is not None:
        log.info(f"Loading checkpoint from {cfg.load_path}...")
        trainer.restore_checkpoint(
            cfg.load_path,
            load_optimizer_state=not cfg.reset_optimizer_state,
            load_trainer_state=not cfg.reset_trainer_state,
            sharded_checkpointer=cfg.load_path_sharded_checkpointer,
        )
        log.info("Checkpoint successfully loaded")

        # If we have to, set a new scheduler:
        if cfg.reset_optimizer_state and not cfg.reset_trainer_state:
            trainer.scheduler = BoltOnWarmupScheduler.wrap(
                trainer.scheduler,
                trainer.global_step,
                int(trainer.global_step + cfg.scheduler.t_warmup),
            )

    if cfg.force_save_unsharded:
        log.info("Saving unsharded checkpoint...")
        checkpoint_path, _ = trainer.save_checkpoint(checkpoint_type=CheckpointType.unsharded)
        log.info(f"Unsharded checkpoint saved to {checkpoint_path}")

    if cfg.compile is not None:
        # TODO (epwalsh): trying to compile the whole train step results in a compile-time error from within
        # the optimizer. We should investigate this further at some point.
        # trainer.train_step = torch.compile(trainer.train_step, **cfg.compile.asdict())
        trainer.train_batch = torch.compile(trainer.train_batch, **cfg.compile.asdict())  # type: ignore
        # TODO (epwalsh): compiling the `eval_batch()` method is a little sketchy since the inputs will look
        # different for different eval tasks. That might be okay, but it might not be.
        # trainer.eval_batch = torch.compile(trainer.eval_batch, **cfg.compile.asdict())  # type: ignore
        # Alternatively, could just do this:
        # trainer.fsdp_model = torch.compile(trainer.fsdp_model, **cfg.compile.asdict())

    if not cfg.dry_run:
        log.info("Starting training...")
        trainer.fit()
        log.info("Training complete")
    else:
        log.info("Dry run complete")
If this is not a dry run (cfg.dry_run is False), pre-train checkpoints are not disabled, and no load path was given, the trainer does the following:
It picks the checkpoint type: sharded if cfg.save_num_checkpoints_to_keep != 0, otherwise unsharded.
It saves a checkpoint up front, so that failures (disk space and so on) surface immediately rather than hours into training.
It then restores that checkpoint right away to verify that it can actually be loaded.
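The pattern here is "save early, then prove you can read it back". A toy, stand-alone illustration of the same idea with plain torch.save / torch.load (OLMo's real checkpoints are sharded FSDP state handled by trainer.save_checkpoint / restore_checkpoint):

import torch

# Fail fast: write a checkpoint immediately and read it back, so disk-space or
# serialization problems show up before any real training time is spent.
state = {"step": 0, "model": {"w": torch.zeros(4)}}
torch.save(state, "pretrain-checkpoint.pt")
restored = torch.load("pretrain-checkpoint.pt")
assert restored["step"] == 0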
Resuming training
If a load path is given (cfg.load_path is not None), i.e. training is being resumed, the trainer restores that checkpoint. Depending on the config, loading the optimizer state and the trainer state can be skipped (reset_optimizer_state / reset_trainer_state), and a sharded checkpoint loader can be used. If the optimizer state is reset but the trainer state is not, a new scheduler is bolted on: BoltOnWarmupScheduler.wrap adds a warmup phase of cfg.scheduler.t_warmup steps starting from the current global step.
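My reading of what such a bolt-on warmup typically does, expressed as a stand-alone multiplier function; the actual BoltOnWarmupScheduler class may differ in detail:

def bolt_on_warmup_factor(step: int, warmup_start: int, warmup_end: int) -> float:
    # Multiplier applied on top of the learning rate the wrapped scheduler would
    # produce: 0 at warmup_start, ramping linearly up to 1 at warmup_end.
    if step >= warmup_end:
        return 1.0
    if step <= warmup_start:
        return 0.0
    return (step - warmup_start) / (warmup_end - warmup_start)

# E.g. resuming at global_step=50_000 with t_warmup=2_000:
print(bolt_on_warmup_factor(51_000, 50_000, 52_000))  # 0.5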
if cfg.force_save_unsharded:
    log.info("Saving unsharded checkpoint...")
    checkpoint_path, _ = trainer.save_checkpoint(checkpoint_type=CheckpointType.unsharded)
    log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
If cfg.force_save_unsharded is set, an unsharded checkpoint is saved before training begins.
trainer.train_batch = torch.compile(trainer.train_batch, **cfg.compile.asdict())  # type: ignore
If compilation is enabled in the config (cfg.compile is not None), the train_batch method is compiled with torch.compile as a speed optimization. As the TODO comments in the block above note, compiling the whole train step currently fails inside the optimizer, and compiling eval_batch is risky because its inputs differ across eval tasks.
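For readers who have not used it, a tiny self-contained torch.compile example (requires PyTorch 2.0 or newer); the trainer applies the same call to its train_batch method:

import torch

def gelu_like(x: torch.Tensor) -> torch.Tensor:
    # An arbitrary elementwise function, just to have something to compile.
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x**3)))

compiled = torch.compile(gelu_like)
print(compiled(torch.randn(8)))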
Start training, or finish the dry run
If this is not a dry run, trainer.fit() runs the training loop and a completion message is logged; otherwise the script simply logs that the dry run is complete.