CMeKG代码解读(以项目为导向从零开始学习知识图谱)(三)

书接上文:
https://blog.csdn.net/chen_nnn/article/details/122814434https://blog.csdn.net/chen_nnn/article/details/122814434

目录

evaluate(): 

run_train():

load_model():

get_triples():


evaluate(): 

def evaluate(data, is_print, model4s, model4po):
    X, Y, Z = 1e-10, 1e-10, 1e-10
    for d in data:
        R = set([SPO(spo) for spo in extract_spoes(d['text'], model4s, model4po)])  # 模型提取出的三元组数目
        T = set([SPO(spo) for spo in d['spo_list']])  # 正确的三元组数目
        if is_print:
            print('text:', d['text'])
            print('R:', R)
            print('T:', T)
        X += len(R & T)  # 模型提取出的三元组数目中正确的个数
        Y += len(R)  # 模型提取出的三元组个数
        Z += len(T)  # 正确的三元组总数
    f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z

    return f1, precision, recall

该函数顾名思义是对模型的训练效果进行评价的函数,首先定义了三个变量,但是定义的这三个变量很有意思,不是0但是非常接近0,难道是考虑到这三个参数中有可能出现在分母,所以令其不能为0。同时这也印证了SPO类的返回得确实是三元组的形式。然后之后的代码就比较易懂了。R中保存了模型提取出的三元组的数目,T中保存了正确的三元组数目,并且都使用了set方法除去其中重复的情况。然后根据传入的参数,确定是否打印三元组的结果,并利用交集的方式计算出模型提取出的三元组数目中正确的个数。并由此计算出f1,precision(准确率)和recall(召回率)。

run_train():

def run_train():
    load_schema(config.PATH_SCHEMA)
    train_path = config.PATH_TRAIN
    all_data = load_data(train_path)
    random.shuffle(all_data)

    # 8:2划分训练集、验证集
    idx = int(len(all_data) * 0.8)
    train_data = all_data[:idx]
    valid_data = all_data[idx:]

    # train
    train_data_loader = IterableDataset(train_data, True)
    num_train_data = len(train_data)
    checkpoint = torch.load(config.PATH_MODEL)

    model4s = Model4s()
    model4s.load_state_dict(checkpoint['model4s_state_dict'])
    # model4s.cuda()

    model4po = Model4po()
    model4po.load_state_dict(checkpoint['model4po_state_dict'])
    # model4po.cuda()

    param_optimizer = list(model4s.named_parameters()) + list(model4po.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    lr = config.learning_rate
    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    checkpoint = train(train_data_loader, model4s, model4po, optimizer)

    del train_data
    gc.collect()
    # save
    model_path = config.PATH_SAVE
    torch.save(checkpoint, model_path)
    print('saved!')

    # valid
    model4s.eval()
    model4po.eval()
    f1, precision, recall = evaluate(valid_data, True, model4s, model4po)
    print('f1: %.5f, precision: %.5f, recall: %.5f' % (f1, precision, recall))

首先是加载训练函数所需使用的各项数据和路径,并将all_data打乱。然后根据所有数据的百分之八十为界限,前百分之八十用来训练模型,后百分之八十用来验证模型效果。然后对主体模型和谓语-客体模型进行训练,用load_state_dict方法(源码如下)进行训练后的模型的匹配。

r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.

Args:
    state_dict (dict): a dict containing parameters and
        persistent buffers.
    strict (bool, optional): whether to strictly enforce that the keys
        in :attr:`state_dict` match the keys returned by this module's
        :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

Returns:
    ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
        * **missing_keys** is a list of str containing the missing keys
        * **unexpected_keys** is a list of str containing the unexpected keys

Note:
    If a parameter or buffer is registered as ``None`` and its corresponding key
    exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
    ``RuntimeError``.
"""

然后是构建一个param_optimizer列表,该列表利用named_parameters()(源码如下)方法将里面的内容对应的键值对取出。

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
Example:
>>> for name, param in self.named_parameters():
>>>    if name in ['bias']:
>>>        print(param.size()) 

根据条件构建一个optimizer_grouped_parameters列表,列表是两个字典元素,第一个字典中包含两个键值对,键名:params,存储的值是p,存储条件是从param_optimizer中取出的n和p,n要满足no_decay中的数据都不会出现在n中,才将p保存,另一个键名:weight_decay,值是0.1,第二个字典中也包含两个键值对,键名:params,存储的值是p,存储条件是从param_optimizer中取出的n和p,n要满足no_decay中的数据只要其中一个出现在n中,就将p保存,另一个键名:weight_decay,值是0。

加载设定好的学习率,并根据刚才创建好的列表optimizer_grouped_parameters,开始学习。对学习后的模型使用load_state_dict()方法进行匹配。然后对将训练数据,主体模型,客体模型和优化器放到train()函数中开始训练。删除 train_data释放空间,gc.collect()命令可以回收没有被使用的空间。然后对训练好的模型进行保存。保存好之后对模型的训练效果进行评估,首先将两个模型都调整到评价模式,然后使用之前已经书写好的评价函数,进行计算并打印计算结果。

load_model():

def load_model():
    load_schema(config.PATH_SCHEMA)
    checkpoint = torch.load(config.PATH_MODEL, map_location='cpu')

    model4s = Model4s()
    model4s.load_state_dict(checkpoint['model4s_state_dict'])
    # model4s.cuda()

    model4po = Model4po()
    model4po.load_state_dict(checkpoint['model4po_state_dict'])
    # model4po.cuda()

    return model4s, model4po

 该函数的作用顾名思义是用来加载已经写好的模型,并匹配加载的checkpoint。torch.load使用方式如下:

load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
Loads an object saved with torch.save from a file.
torch.load uses Python's unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the run time system doesn't have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the map_location argument.
If map_location is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to map_location. The builtin location tags are 'cpu' for CPU tensors and 'cuda:device_id' (e.g. 'cuda:2') for CUDA tensors. map_location should return either None or a storage. If map_location returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, torch.load will fall back to the default behavior, as if map_location wasn't specified.
If map_location is a torch.device object or a string containing a device tag, it indicates the location where all tensors should be loaded.
Otherwise, if map_location is a dict, it will be used to remap location tags appearing in the file (keys), to ones that specify where to put the storages (values).

翻译之后:加载(f,map_location=None,pickle_module=pickle,**pickle_load_args)

加载用torch保存的对象。从文件中保存。torch.load使用Python的解压功能,但特别处理张量下面的存储。它们首先在CPU上反序列化,然后移动到保存它们的设备。如果失败(例如,因为运行时系统没有特定的设备),将引发异常。但是,可以使用map_location参数将存储动态地重新映射到一组替代设备。

如果map_location是可调用的,则将为每个具有两个参数的序列化存储调用一次:storage和location。存储参数将是驻留在CPU上的存储的初始反序列化。每个序列化存储都有一个与之关联的位置标记,该标记标识保存它的设备,该标记是传递给map_location的第二个参数。内置的位置标签是cpu张量的“cpu”和cuda张量的“cuda:device_id”(例如“cuda:2”)。map_位置应返回None或storage。如果map_location返回一个存储,它将被用作最终的反序列化对象,并已移动到正确的设备上。否则,torch.load将返回默认行为,就像没有指定map_location一样。

如果map_location是torch。设备对象或包含设备标记的字符串,它指示应加载所有张量的位置。否则,如果map_location是dict,它将用于将文件中出现的位置标记(键)重新映射到指定存储位置(值)的位置标记。

get_triples():

def get_triples(content, model4s, model4po):
    if len(content) == 0:
        return []
    text_list = content.split('。')[:-1]
    res = []
    for text in text_list:
        if len(text) > 128:
            text = text[:128]
        triples = extract_spoes(text, model4s, model4po)
        res.append({
            'text': text,
            'triples': triples
        })
    return res

 首先判断,如果传入的content文本为空,则直接返回空列表,否则将content从一开始使用spilt方法根据“. ”进行切片。然后进入循环,取出其中一个元素,并取出128个字,根据前文中书写好的函数,判断找到返回text中存在的三元组的情况,并保存到res列表中,最后返回。

这个medical_re.py完结啦啊哈哈哈,撒花!!!!!!!!!!明天开一个新的文件,椰丝!!! 

你可能感兴趣的:(笔记,知识图谱,人工智能,python)