bert4torch是一款简洁的训练框架,经过半年的维护和使用已经越发完善,近期的工作主要是增加了很多实战示例,拿来就用还是很香了。不了解bert4torch可以通过前述两篇文章来浅尝一下~
bert4torch(参考bert4keras的pytorch实现)15 赞同 · 9 评论文章
bert4torch快速上手16 赞同 · 3 评论文章
新增了xlnet和t5_pegasus两个预训练模型
比如在做实体提取时候,想尝试下加入token的词性来看看是否能提升模型效果,这个时候就需要增加额外的embedding,使用时候直接传入layer_add_embs
参数即可
build_transformer_model(
config_path=config_path, # 模型的config文件地址
checkpoint_path=checkpoint_path, # 模型文件地址,默认值None表示不加载预训练模型
model='bert', # 加载的模型结构,这里Model也可以基于nn.Module自定义后传入
application='encoder', # 模型应用,支持encoder,lm和unilm格式
segment_vocab_size=2, # type_token_ids数量,默认为2,如不传入segment_ids则需设置为0
with_pool=False, # 是否包含Pool部分
with_nsp=False, # 是否包含NSP部分
with_mlm=False, # 是否包含MLM部分
return_model_config=False, # 是否返回模型配置参数
output_all_encoded_layers=False, # 是否返回所有hidden_state层
layer_add_embs=nn.Embedding(2, 768), # 自定义额外的embedding输入
)
在训练过程中想打印一些指标来观测训练集上的指标(默认会打印loss),很多时候这些指标还需要自定义,参考keras的实现,目前bert4torch也支持了,使用方式如下
'''
定义使用的loss、optimizer和metrics,这里支持自定义
'''
def eval(y_pred, y_true):
# 仅做示意
return {'rouge-1': random.random(), 'rouge-2': random.random(), 'rouge-l': random.random(), 'bleu': random.random()}
def f1(y_pred, y_true):
# 仅做示意
return random.random()
model.compile(
loss=nn.CrossEntropyLoss(), # 可以自定义Loss
optimizer=optim.Adam(model.parameters(), lr=2e-5), # 可以自定义优化器
scheduler=None, # 可以自定义scheduler
adversarial_train={'name': 'fgm'}, # 训练trick方案设置,支持fgm, pgd, gradient_penalty, vat
metrics=['accuracy', eval, {'f1': f1}] # loss等默认打印的字段无需设置,可多种方式自定义回调函数
)
很多时候由于各种原因(显存不足,意外断电)等情况,虽然你保存的最优模型,但是你的优化器没保存,导致无法接着训练,bert4torch内置了简单的小函数,即可方便使用断点续训,训练进度条也会从上次断掉的地方重新开始记录
# =======断点续训========
# 在Callback中的on_epoch_end()或on_batch_end()保存需要的参数
model.save_weights(save_path, prefix=None) # 保存模型权重
model.save_steps_params(save_path) # 保存训练进度参数,当前的epoch和step,断点续训使用
torch.save(optimizer.state_dict(), save_path) # 保存优化器,断点续训使用
# 加载前序训练保存的参数
model.load_weights(save_path) # 加载模型权重
model.load_steps_params(save_path) # 加载训练进度参数,断点续训使用
state_dict = torch.load(save_path, map_location='cpu') # 加载优化器,断点续训使用
optimizer.load_state_dict(state_dict)
1. 每个Epoch会打印时间戳,方便查看训练的起止时间(仅仅记录训练时长总是需要换算)
打印Epoch同时记录时间戳
2. 句向量的获取简单配置即可get_pool_emb(hidden_state=None, pooler=None, attention_mask=None, pool_strategy='cls', custom_layer=None)
def get_pool_emb(hidden_state=None, pooler=None, attention_mask=None, pool_strategy='cls', custom_layer=None):
''' 获取句向量
'''
if pool_strategy == 'pooler':
return pooler
elif pool_strategy == 'cls':
if isinstance(hidden_state, (list, tuple)):
hidden_state = hidden_state[-1]
assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} strategy request tensor hidden_state'
return hidden_state[:, 0]
elif pool_strategy in {'last-avg', 'mean'}:
if isinstance(hidden_state, (list, tuple)):
hidden_state = hidden_state[-1]
assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} pooling strategy request tensor hidden_state'
hid = torch.sum(hidden_state * attention_mask[:, :, None], dim=1)
attention_mask = torch.sum(attention_mask, dim=1)[:, None]
return hid / attention_mask
elif pool_strategy in {'last-max', 'max'}:
if isinstance(hidden_state, (list, tuple)):
hidden_state = hidden_state[-1]
assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} pooling strategy request tensor hidden_state'
hid = hidden_state * attention_mask[:, :, None]
return torch.max(hid, dim=1)
elif pool_strategy == 'first-last-avg':
assert isinstance(hidden_state, list), f'{pool_strategy} pooling strategy request list hidden_state'
hid = torch.sum(hidden_state[1] * attention_mask[:, :, None], dim=1) # 这里不取0
hid += torch.sum(hidden_state[-1] * attention_mask[:, :, None], dim=1)
attention_mask = torch.sum(attention_mask, dim=1)[:, None]
return hid / (2 * attention_mask)
elif pool_strategy == 'custom':
# 取指定层
assert isinstance(hidden_state, list), f'{pool_strategy} pooling strategy request list hidden_state'
assert isinstance(custom_layer, (int, list, tuple)), f'{pool_strategy} pooling strategy request int/list/tuple custom_layer'
custom_layer = [custom_layer] if isinstance(custom_layer, int) else custom_layer
hid = 0
for i, layer in enumerate(custom_layer, start=1):
hid += torch.sum(hidden_state[layer] * attention_mask[:, :, None], dim=1)
attention_mask = torch.sum(attention_mask, dim=1)[:, None]
return hid / (i * attention_mask)
else:
raise ValueError('pool_strategy illegal')
3. 全局seed,固定随机种子一般会写几行简单的代码,但是太常用了,参考pytorch_lightning使用seed_everything(seed)
来固定随机数
def seed_everything(seed=None):
'''固定seed
'''
max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min
if (seed is None) or not (min_seed_value <= seed <= max_seed_value):
random.randint(np.iinfo(np.uint32).min, np.iinfo(np.uint32).max)
print(f"Global seed set to {seed}")
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
return seed