torchtext 用法

官方文档:
torchtext包含两部分:

  • 数据处理实用程序
  • 流行的自然语言数据集

torchtext.data 的函数列表如下

torchtext.data.__init__

基于torchtext的常见的数据预处理流程

1.定义Field:声明如何处理数据
2.定义Dataset:得到数据集,此时数据集里每一个样本是一个 经过 Field声明的预处理 预处理后的 wordlist
3.建立vocab:在这一步建立词汇表,词向量(word embeddings)
4.构造迭代器:构造迭代器,用来分批次训练模型

一、torchtext.data

现在torchtext.data,torchtext.dataset,torchtext.vocab都放在torchtext.legacy当中了。

1. torchtext.legacy.data.field
torchtext.data.field
1.1 RawField 所有Field类的基类
RawField 定义通用数据类型。

每个数据集都由一种或多种类型的数据组成。 例如,文本分类数据集包含句子及其类别,而机器翻译数据集包含两种语言文本的配对示例。
这些类型的数据都由一个 RawField 对象表示。

  • preprocessing :预处理, 在创建示例之前将应用于使用此字段的示例的管道。默认:None。
  • postprocessing:后处理, 在分配给批处理之前将应用于使用此字段的示例列表的管道。功能签名:(batch(list)) -> object。 默认值:None。
  • is_target:该字段是否为目标变量。 影响批次迭代。 默认值:False
  • RawField 源码
class RawField(object):
    def __init__(self, preprocessing=None, postprocessing=None, is_target=False):
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing
        self.is_target = is_target

    def preprocess(self, x):
        """ 如果 指定preprocessing ,则返回预处理后的数据 """
        if self.preprocessing is not None:
            return self.preprocessing(x)
        else:
            return x

    def process(self, batch, *args, **kwargs):
        """ 基于指定的postprocessing,进行数据批处理,并返回批处理后的结果
        """
        if self.postprocessing is not None:
            batch = self.postprocessing(batch)
        return batch

1.2 Field

定义数据类型,将数据转换为tensor的类
Field 可以将常见文本处理成Tensor,可以用于模型的输入。
参数:

  • sequential:输入的数据是否是序列数据,默认是True. 如果sequential = False, 则 不可以输入词语切分参数 tokenization。
  • use_vocab:是否使用Vocab 对象,默认是True. 如果为 False,则该字段中的数据应该已经是数值化对象(word_to_idx)。
  • init_token:起始标记,默认是None。
  • eos_token:结束标记,默认是None。
  • fix_length:将序列填充至指定长度,默认是None。
  • dtype:数据类型,默认是torch.long。
  • preprocessing:前处理,数据进行分词化处理(tokenize)之后,数值化(word_to_idx)前,进行数据前处理,默认是None。许多数据集用自定义预处理方法替换此属性。
  • postprocessing:后处理,进行数值化(word_to_idx)之后的数据处理,基于构建好的vocab,对一个batch的数据进行数据处理,数据处理之后再将其转换为Tensor,默认是None。
  • lower:将文本字母全部转换为小写,默认是False.
  • tokenize:对序列数据进行分词化处理,默认是string.split。
    tokenize = 'spacy' 时,利用SpaCy 分词处理器处理文本。
  • tokenizer_language:要构建的tokenizer的语言。目前只有SpaCy支持各种语言。
  • include_lengths:是否返回填充的小批量的元组和包含每个示例长度的列表,或仅返回填充小批量。默认是False。
  • batch_first:是否先生成batch维度的张量,默认是False。
  • pad_token:填充字符串标记。 默认值:“”。
  • unk_token:表示 OOV 词的字符串标记。 默认值: ""。
  • pad_first:在文本开头以pad填充。
  • truncate_first:在开头做序列的截断。 默认值:False。
  • stop_words:预处理阶段,分词的时候丢掉指定的词,默认是None
  • is_target:该字段是否为目标变量。影响批次迭代。 默认值:False。
  • Field 源码
from torchtext.legacy.vocab import Vocab, SubwordVocab
class Field(RawField): # 继承RawField类
    vocab_cls = Vocab
    ignore = ['dtype', 'tokenize']
    # Field初始化
    def __init__(self, sequential=True, use_vocab=True, init_token=None,
                 eos_token=None, fix_length=None, dtype=torch.long,
                 preprocessing=None, postprocessing=None, lower=False,
                 tokenize=None, tokenizer_language='en', include_lengths=False,
                 batch_first=False, pad_token="", unk_token="",
                 pad_first=False, truncate_first=False, stop_words=None,
                 is_target=False):
        self.sequential = sequential
        self.use_vocab = use_vocab
        self.init_token = init_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.fix_length = fix_length
        self.dtype = dtype
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing
        self.lower = lower # 将所有文本字母转换成小写
        # tokenize, tokenizer_language 两个参数要一起指定,目前只有spacy支持各种语言。
        # store params to construct tokenizer for serialization
        # in case the tokenizer isn't picklable (e.g. spacy)
        self.tokenizer_args = (tokenize, tokenizer_language)
        self.tokenize = get_tokenizer(tokenize, tokenizer_language)
        self.include_lengths = include_lengths
        self.batch_first = batch_first
        self.pad_token = pad_token if self.sequential else None
        self.pad_first = pad_first
        self.truncate_first = truncate_first
        try:
            self.stop_words = set(stop_words) if stop_words is not None else None
        except TypeError:
            raise ValueError("Stop words must be convertible to a set")
        self.is_target = is_target

常见用法
TEXT = data.Field(tokenize='spacy',tokenizer_language='en_core_web_sm', dtype = torch.float)

TEXT = data.Field(batch_first=True, eos_token='')

self.TEXT = data.Field(init_token='', eos_token='', lower=True, tokenize='spacy', fix_length=16)
self.LABEL = data.Field(sequential=False, unk_token=None)

  • 一些对象类型(譬如,文件对象)不能进行 pickle。处理这种不能 pickle 的对象的实例属性时可以使用特殊的方法(__ getstate__() 和 __setstate__() )来修改类实例的状态。
  • __ getstate__ 与 __ setstate__ 两个魔法方法分别用于Python 对象的序列化(pickle.dump)与反序列化(pickle.load)
# 反序列化
    def __getstate__(self):
        str_type = dtype_to_attr(self.dtype)
        if is_tokenizer_serializable(*self.tokenizer_args):
            tokenize = self.tokenize
        else:
            # signal to restore in `__setstate__`
            tokenize = None
        attrs = {k: v for k, v in self.__dict__.items() if k not in self.ignore}
        attrs['dtype'] = str_type
        attrs['tokenize'] = tokenize
        return attrs
# 序列化
    def __setstate__(self, state):
        state['dtype'] = getattr(torch, state['dtype'])
        if not state['tokenize']:
            state['tokenize'] = get_tokenizer(*state['tokenizer_args'])
        self.__dict__.update(state)
  • 数据预处理方法
    def preprocess(self, x):
        # 如果数据是序列数据,就需要进行分词处理
        if self.sequential and isinstance(x, str):
            x = self.tokenize(x.rstrip('\n'))
        # 是否将全部字母变成小写
        if self.lower:
            x = Pipeline(str.lower)(x)
        # 删除输入数据的 stop_words
        if self.sequential and self.use_vocab and self.stop_words is not None:
            x = [w for w in x if w not in self.stop_words]
        # 用户自定义的数据预处理方法
        if self.preprocessing is not None:
            return self.preprocessing(x)
        else:
            return x
  • 数据处理
    主要读取每个batch的数据进行两个操作:
  1. 每个batch的pad填充
  2. 对pad之后的数据进行数字化(word_to_idx)
    def process(self, batch, device=None):
        padded = self.pad(batch)
        tensor = self.numericalize(padded, device=device)
        return tensor
  • pad 函数
  • 常用参数:fix_length ,pad_first ,init_token ,eos_token
    def pad(self, minibatch):
        minibatch = list(minibatch)
        # 如果不是序列化的数据,不做任何操作
        if not self.sequential:
            return minibatch
        # 如果有提供fix_length ,则最大长度为fix_length ,否则最大长度为该batch中的最大序列长度
        if self.fix_length is None:
            max_len = max(len(x) for x in minibatch)
        else:
            max_len = self.fix_length + (
                self.init_token, self.eos_token).count(None) - 2
        # 创建两个空列表,padded 用于存放pad处理之后的数据,lengths 用于存放pad处理之后的数据的长度。
        padded, lengths = [], []
        # 对该batch中的每个数据进行pad填充
        for x in minibatch:
            # 如果pad_first = True, 则填充后的结构为 pad + 起始词 + 数据 + 终止词
            if self.pad_first:
                padded.append(
                    [self.pad_token] * max(0, max_len - len(x))
                    + ([] if self.init_token is None else [self.init_token])
                    + list(x[-max_len:] if self.truncate_first else x[:max_len])
                    + ([] if self.eos_token is None else [self.eos_token]))
            # 如果pad_first = False, 则填充后的结构为  起始词 + 数据 + 终止词 + pad
            else:
                padded.append(
                    ([] if self.init_token is None else [self.init_token])
                    + list(x[-max_len:] if self.truncate_first else x[:max_len])
                    + ([] if self.eos_token is None else [self.eos_token])
                    + [self.pad_token] * max(0, max_len - len(x)))
            lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
        # 如果include_lengths = True,则返回pad后的数据和lengths的结果,否则,只返回pad后的数据
        if self.include_lengths:
            return (padded, lengths)
        return padded
  • 构建词典vocab
    一般在NLP任务中,会先将text出现过的词构建成一个vocab, vocab主要实现三个功能:counter(x), 对text中的单词出现次数统计;idx_to_word, text中出现的词的列表;word_to_idx, 将word转换为vocab中的idx(vocab中的词编号)
    def build_vocab(self, *args, **kwargs):
        # 实例化一个counter计数器
        counter = Counter()
        sources = []
        for arg in args:
            if isinstance(arg, Dataset):
                sources += [getattr(arg, name) for name, field in
                            arg.fields.items() if field is self]
            else:
                sources.append(arg)
        for data in sources:
            for x in data:
                if not self.sequential:
                    x = [x]
                try:
                    counter.update(x)
                except TypeError:
                    counter.update(chain.from_iterable(x))
        specials = list(OrderedDict.fromkeys(
            tok for tok in [self.unk_token, self.pad_token, self.init_token,
                            self.eos_token] + kwargs.pop('specials', [])
            if tok is not None))
        self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
  • 将text中的单词转换成vocab中的idx
  • 计算机只能处理数字,而不能处理单词,因此需要将单词转化为数字(word_to_idx)
    def numericalize(self, arr, device=None):
        if self.include_lengths and not isinstance(arr, tuple):
            raise ValueError("Field has include_lengths set to True, but "
                             "input data is not a tuple of "
                             "(data batch, batch lengths).")
        if isinstance(arr, tuple):
            arr, lengths = arr
            lengths = torch.tensor(lengths, dtype=self.dtype, device=device)

        if self.use_vocab:
        """基于vocab,执行vocab.stoi[x],将word 转化为 idx"""
            if self.sequential:
                arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
            else:
                arr = [self.vocab.stoi[x] for x in arr]

            if self.postprocessing is not None:
                arr = self.postprocessing(arr, self.vocab)
        else:
            if self.dtype not in self.dtypes:
                raise ValueError(
                    "Specified Field dtype {} can not be used with "
                    "use_vocab=False because we do not know how to numericalize it. "
                    "Please raise an issue at "
                    "https://github.com/pytorch/text/issues".format(self.dtype))
            numericalization_func = self.dtypes[self.dtype]
            # It doesn't make sense to explicitly coerce to a numeric type if
            # the data is sequential, since it's unclear how to coerce padding tokens
            # to a numeric type.
            if not self.sequential:
                arr = [numericalization_func(x) if isinstance(x, str)
                       else x for x in arr]
            if self.postprocessing is not None:
                arr = self.postprocessing(arr, None)

        var = torch.tensor(arr, dtype=self.dtype, device=device)

        if self.sequential and not self.batch_first:
            var.t_()
        if self.sequential:
            var = var.contiguous()

        if self.include_lengths:
            return var, lengths
        return var
  • Field类源码主要功能如下所示:


    Field类总结
1.3 LabelField 源码
"""LabelField 继承Field类,
并对sequential,unk_token,is_target 参数进行重写,
用于标签数据处理。"""
class LabelField(Field):
    def __init__(self, **kwargs):
        # whichever value is set for sequential, unk_token, and is_target
        # will be overwritten
        kwargs['sequential'] = False
        kwargs['unk_token'] = None
        kwargs['is_target'] = True

        super(LabelField, self).__init__(**kwargs)
2. torchtext.legacy.vocab
  • 可以拿来用的预训练词向量
pretrained_aliases

你可能感兴趣的:(torchtext 用法)