词典构建方法

一. 注意
本文只是对 HanLP 源码的解读,仅供学习交流,不得用于商业用途。

二. 代码粘贴

# -*- coding:utf-8 -*-
# user: wbb
# Date: 2020-06-11
# 功能:词典
from typing import List, Dict, Union, Iterable

from model.common.structure import Serializable
from model.common.constant import PAD, UNK
import tensorflow as tf
from tensorflow.python.ops.lookup_ops import index_table_from_tensor, index_to_string_table_from_tensor


class Vocab(Serializable):
    """Dict-like, bidirectional mapping between string tokens and integer indices.

    A vocab is either *mutable* (unlocked) or *immutable* (locked). While mutable,
    looking up an unknown token through ``get_idx`` adds it. ``lock()`` freezes the
    vocab and builds the reverse list ``idx_to_token`` and the TensorFlow lookup table.
    """

    def __init__(self, idx_to_token: List[str] = None, token_to_idx: Dict = None, mutable=True, pad_token=PAD,
                 unk_token=UNK) -> None:
        """Create a vocab.

        Parameters
        ----------
        idx_to_token
            Optional list of tokens; each token's position becomes its index.
        token_to_idx
            Optional token -> index dict. When ``idx_to_token`` is also given, these
            entries override the ones derived from the list. When used alone, the
            dict is stored by reference (not copied).
        mutable
            Whether tokens may be added after construction.
        pad_token
            Padding token; only inserted when neither mapping argument is supplied.
        unk_token
            Unknown token; only inserted when neither mapping argument is supplied.
        """
        super().__init__()
        if idx_to_token:
            # Derive token -> index from the list, then let an explicit dict override it.
            t2i = {token: idx for idx, token in enumerate(idx_to_token)}
            if token_to_idx:
                t2i.update(token_to_idx)
            token_to_idx = t2i
        if token_to_idx is None:
            # Nothing supplied: start with just the special tokens (pad first, then unk).
            token_to_idx = {}
            if pad_token:
                token_to_idx[pad_token] = len(token_to_idx)
            if unk_token:
                token_to_idx[unk_token] = len(token_to_idx)
        self.token_to_idx = token_to_idx
        # Reverse mapping; only materialised by build_idx_to_token() (e.g. on lock()).
        self.idx_to_token: list = None
        self.mutable = mutable
        self.pad_token = pad_token
        self.unk_token = unk_token
        # TensorFlow lookup tables; built by build_lookup_table().
        self.token_to_idx_table: tf.lookup.StaticHashTable = None
        self.idx_to_token_table = None

    def __setitem__(self, token: str, idx: int):
        """Assign ``idx`` to ``token``; only allowed while the vocab is mutable."""
        assert self.mutable, 'Update an immutable Vocab object is not allowed'
        self.token_to_idx[token] = idx

    def __getitem__(self, key: Union[str, int, List]) -> Union[int, str, List]:
        """Translate in either direction depending on the key type.

        A ``str`` is mapped to its index (via ``get_idx``, which may add the token
        while the vocab is mutable), an ``int`` to its token, and a list is mapped
        element-wise according to the type of its first element. Any other key
        type yields ``None``.
        """
        if isinstance(key, str):
            return self.get_idx(key)
        if isinstance(key, int):
            return self.get_token(key)
        if isinstance(key, list):
            if len(key) == 0:
                return []
            if isinstance(key[0], str):
                return [self.get_idx(x) for x in key]
            if isinstance(key[0], int):
                return [self.get_token(x) for x in key]
        return None

    def __contains__(self, key: Union[str, int]):
        """Return whether a token (``str``) or an index (``int``) is present."""
        if isinstance(key, str):
            return key in self.token_to_idx
        if isinstance(key, int):
            if self.idx_to_token is not None:
                return 0 <= key < len(self.idx_to_token)
            # Unlocked vocab: the reverse list does not exist yet, so test the
            # assigned indices directly instead of failing on len(None).
            return key in self.token_to_idx.values()
        return False

    def add(self, token: str) -> int:
        """Add ``token`` if it is new and return its index.

        Raises
        ------
        AssertionError
            If the vocab is immutable, or ``token`` is not a non-empty ``str``.
        """
        assert self.mutable, 'It is not allowed to call add on an immutable Vocab'
        assert isinstance(token, str), f'Token type must be str but got {type(token)} from {token}'
        assert token, 'Token must not be None or length 0'
        idx = self.token_to_idx.get(token, None)
        if idx is None:
            # New token: it receives the next free index.
            idx = len(self.token_to_idx)
            self.token_to_idx[token] = idx
        return idx

    def update(self, tokens: Iterable[str]) -> None:
        """
        Update the vocab with these tokens by adding them to vocab one by one.
        Parameters
        ----------
        tokens
            Iterable of tokens; each one is passed to ``add``.
        """
        assert self.mutable, 'It is not allowed to update an immutable Vocab'
        for token in tokens:
            self.add(token)

    def get_idx(self, token: str) -> int:
        """Return the index of ``token``.

        For an unknown token: a mutable vocab adds it and returns the new index;
        an immutable vocab returns the index of ``unk_token`` (``None`` if there
        is no such entry).
        """
        idx = self.token_to_idx.get(token, None)
        if idx is None:
            if self.mutable:
                idx = len(self.token_to_idx)
                self.token_to_idx[token] = idx
            else:
                idx = self.token_to_idx.get(self.unk_token, None)
        return idx

    def get_idx_without_add(self, token: str) -> int:
        """Return the index of ``token`` without ever adding it; fall back to the unk index."""
        idx = self.token_to_idx.get(token, None)
        if idx is None:
            idx = self.token_to_idx.get(self.unk_token, None)
        return idx

    def get_token(self, idx: int) -> str:
        """Return the token stored at ``idx``, or ``None`` if no token has that index.

        Uses the reverse list when it has been built; otherwise scans
        ``token_to_idx``. The scan is no longer restricted to mutable vocabs, so a
        vocab constructed with ``mutable=False`` (whose reverse list was never
        built) still resolves indices instead of always returning ``None``.
        """
        if self.idx_to_token:
            return self.idx_to_token[idx]
        for token, token_idx in self.token_to_idx.items():
            if token_idx == idx:
                return token
        return None

    def has_key(self, token):
        """Return whether ``token`` is in the vocab."""
        return token in self.token_to_idx

    def __len__(self):
        """Return the number of tokens."""
        return len(self.token_to_idx)

    def lock(self):
        """Freeze the vocab and build the reverse list and lookup table; returns ``self``."""
        if self.locked:
            return self
        self.mutable = False
        self.build_idx_to_token()
        self.build_lookup_table()
        return self

    def _ordered_tokens(self) -> list:
        """Return tokens placed at their index; unassigned positions hold ``None``."""
        # default=-1 yields an empty list for an empty vocab instead of a ValueError.
        size = max(self.token_to_idx.values(), default=-1) + 1
        ordered = [None] * size
        for token, idx in self.token_to_idx.items():
            ordered[idx] = token
        return ordered

    def build_idx_to_token(self):
        """Build ``idx_to_token`` from ``token_to_idx`` (an empty vocab gives ``[]``)."""
        self.idx_to_token = self._ordered_tokens()

    def build_lookup_table(self):
        """Build the TensorFlow token -> index table from ``idx_to_token``.

        Without an unk index, one out-of-vocabulary bucket is requested; with one,
        unknown tokens resolve to the unk index through ``default_value``.
        """
        unk_idx = self.unk_idx
        tensor = tf.constant(self.idx_to_token, dtype=tf.string)
        self.token_to_idx_table = index_table_from_tensor(tensor, num_oov_buckets=1 if unk_idx is None else 0,
                                                          default_value=-1 if unk_idx is None else unk_idx)

    def unlock(self):
        """Make the vocab mutable again and drop the derived tables; returns ``self``.

        Always returns ``self`` (also when already unlocked), matching ``lock()``.
        """
        if not self.locked:
            return self
        self.mutable = True
        self.idx_to_token = None
        self.idx_to_token_table = None
        self.token_to_idx_table = None
        return self

    @property
    def locked(self):
        """``True`` when the vocab is immutable."""
        return not self.mutable

    @property
    def unk_idx(self):
        """Index of ``unk_token``, or ``None`` when it is unset or not in the vocab."""
        if self.unk_token is None:
            return None
        return self.token_to_idx.get(self.unk_token, None)

    @property
    def pad_idx(self):
        """Index of ``pad_token``, or ``None`` when it is unset or not in the vocab."""
        if self.pad_token is None:
            return None
        return self.token_to_idx.get(self.pad_token, None)

    @property
    def tokens(self):
        """View of all tokens (the keys of ``token_to_idx``)."""
        return self.token_to_idx.keys()

    def __str__(self) -> str:
        return self.token_to_idx.__str__()

    def summary(self, verbose=True) -> str:
        """Return a one-line report with the size and up to the first 50 tokens.

        The report is also printed when ``verbose`` is true.
        """
        report = '[{}] = '.format(len(self))
        report += str(list(self.token_to_idx.keys())[:min(50, len(self))])
        if verbose:
            print(report)
        return report

    def __call__(self, some_token: Union[str, List[str]]) -> Union[int, List[int]]:
        """Map one token, or a list of tokens, to indices via ``get_idx``."""
        if isinstance(some_token, list):
            return [self.get_idx(token) for token in some_token]
        return self.get_idx(some_token)

    def lookup(self, token_tensor: tf.Tensor) -> tf.Tensor:
        """Look up a tensor of tokens in the TensorFlow table, locking the vocab first if needed."""
        if self.mutable:
            self.lock()
        elif self.token_to_idx_table is None:
            # Immutable from construction: lock() never ran, so build the tables now.
            if self.idx_to_token is None:
                self.build_idx_to_token()
            self.build_lookup_table()
        return self.token_to_idx_table.lookup(token_tensor)

    def to_dict(self) -> dict:
        """Return the serialisable state: ``idx_to_token``, ``pad_token``, ``unk_token``, ``mutable``.

        For an unlocked vocab the token list is derived on the fly (without
        changing this object), so the tokens are not lost on serialisation.
        """
        idx_to_token = self.idx_to_token
        if idx_to_token is None:
            idx_to_token = self._ordered_tokens()
        return {
            'idx_to_token': idx_to_token,
            'pad_token': self.pad_token,
            'unk_token': self.unk_token,
            'mutable': self.mutable,
        }

    def copy_from(self, item: dict):
        """Restore state from a dict such as the one produced by ``to_dict``.

        Every key of ``item`` is set as an attribute. ``token_to_idx`` is rebuilt
        from ``idx_to_token`` when that list is present (``None`` placeholders are
        skipped); if it is absent the current mapping is kept.
        """
        for key, value in item.items():
            setattr(self, key, value)
        if self.idx_to_token is not None:
            self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token) if token is not None}
        if self.mutable:
            # A mutable vocab must not keep a reverse list that would go stale
            # as soon as a token is added (same invariant as unlock()).
            self.idx_to_token = None
        else:
            if self.idx_to_token is None:
                self.build_idx_to_token()
            self.build_lookup_table()

    def lower(self):
        """Unlock the vocab and re-add every token lower-cased; returns ``self``.

        Tokens are re-indexed in iteration order and tokens that collapse to the
        same lower-case form share one index. ``pad_token`` / ``unk_token``
        attributes are left as they are.
        """
        self.unlock()
        previous = self.token_to_idx
        self.token_to_idx = {}
        for token in previous:
            self.add(token.lower())
        return self

    @property
    def first_token(self):
        """Token at index 0 if the reverse list exists, else the first key, else ``None``."""
        if self.idx_to_token:
            return self.idx_to_token[0]
        if self.token_to_idx:
            return next(iter(self.token_to_idx))
        return None

    def merge(self, other):
        """Look up every token of ``other`` in this vocab (adding it when this vocab is mutable)."""
        for word in other.token_to_idx:
            self.get_idx(word)

    @property
    def safe_pad_token(self) -> str:
        """
        Get the pad token safely. It always returns a pad token, which is the token
        closest to pad if not presented in the vocab.

        Returns
        -------
            str pad token
        """
        if self.pad_token:
            return self.pad_token
        if self.first_token:
            return self.first_token
        return PAD

    @property
    def safe_pad_token_idx(self) -> int:
        """Index of ``safe_pad_token``; 0 when that token is not in the vocab."""
        return self.token_to_idx.get(self.safe_pad_token, 0)

    @property
    def safe_unk_token(self) -> str:
        """
        Get the unk token safely. It always returns a unk token, which is the token
        closest to unk if not presented in the vocab.

        Returns
        -------
            str unk token
        """
        if self.unk_token:
            return self.unk_token
        if self.first_token:
            return self.first_token
        return UNK


def create_label_vocab() -> Vocab:
    """Create a vocab for labels: no padding token and no unknown token, so it starts empty."""
    label_vocab = Vocab(unk_token=None, pad_token=None)
    return label_vocab

三. 重要操作
1.index_table_from_tensor

import tensorflow as tf
from tensorflow.python.ops.lookup_ops import index_table_from_tensor,index_to_string_table_from_tensor

# Vocabulary: each string's position in the tensor is its id.
mapping_strings = tf.constant(["emerson", "lake", "palmer"])
# tf.contrib was removed in TensorFlow 2.x, so call the function imported above
# (the same one Vocab.build_lookup_table uses). Its first parameter is named
# `vocabulary_list` rather than `mapping`, hence the positional argument.
table = index_table_from_tensor(mapping_strings, num_oov_buckets=1, default_value=-1)
features = tf.constant(["emerson", "lake", "and", "palmer"])
ids = table.lookup(features)
# With a single OOV bucket the unknown word "and" maps to id 3 (= vocabulary size),
# so the ids should be [0, 1, 3, 2] -- printed directly when running eagerly (TF 2.x).
print(ids)

你可能感兴趣的:(机器学习)