Preparing training data with torchtext + an embedding file

I. Raw materials:

       1. A pretrained embedding file. Here we use glove.6B.300d.txt, which can be downloaded online.
       2. The corpus, i.e., the text to be numericalized. Here we use three files: train.tsv, dev.tsv, and test.tsv. The first two, train.tsv and dev.tsv, share the following format:

sentence        label
hide new secretions from the parental units     0     
contains no wit , only labored gags     0     
that loves its characters and communicates something rather beautiful about human nature        1
remains utterly satisfied to remain the same throughout         0
on the worst revenge-of-the-nerds clichés the filmmakers could dredge up        0
that 's far too tragic to merit such superficial treatment      0     
demonstrates that the director of such hollywood blockbusters as patriot games can still turn out a small , personal film with an emotional wallop .    1
of saucy        1
a depressed fifteen-year-old 's suicidal poetry         0

test.tsv is different: its first column is an index rather than a label, so it needs its own field list when loaded (handled in the demo below):

index   sentence
0       uneasy mishmash of styles and genres .
1       this film 's relationship to actual tension is the same as what christmas-tree flocking in a spray can is to actual snow : a poor -- if durable -- imitation .
2       by the end of no such thing the audience , like beatrice , has a watchful affection for the monster .
3       director rob marshall went out gunning to make a great one .
4       lathan and diggs have considerable personal charm , and their screen rapport makes the old story seem new .
5       a well-made and often lovely depiction of the mysteries of friendship .
6       none of this violates the letter of behan 's book , but missing is its spirit , its ribald , full-throated humor .
7       although it bangs a very cliched drum at times , this crowd-pleaser 's fresh dialogue , energetic music , and good-natured spunk are often infectious .
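For reference, each line of glove.6B.300d.txt is one token followed by 300 space-separated floats, which is exactly the format torchtext's vocab.Vectors parses. A quick sanity check (assuming the file sits in the working directory):

# Each line of the GloVe file: a token followed by 300 floats.
with open('glove.6B.300d.txt', encoding='utf-8') as f:
    parts = f.readline().rstrip().split(' ')
print(parts[0], len(parts) - 1)   # e.g.: the 300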

II. Complete demo

import torch

from torchtext import data
from torchtext import vocab
from tqdm import tqdm


embedding_file = '/home/jason/Desktop/data/embeddings/glove/glove.6B.300d.txt'
path = '/home/jason/Desktop/data/SST-2/'

cache_dir = '.cache/'   # Vectors caches a serialized copy of the embedding file here
batch_size = 6
vectors = vocab.Vectors(embedding_file, cache_dir)

text_field = data.Field(tokenize='spacy',        # use the spaCy tokenizer (English model required)
                        lower=True,
                        include_lengths=True,    # batches become (token_ids, lengths) tuples
                        fix_length=10)           # pad / truncate every example to 10 tokens
label_field = data.LabelField(dtype=torch.long)

# train.tsv / dev.tsv columns: (sentence, label)
train, dev = data.TabularDataset.splits(path=path,
                                        train='train.tsv',
                                        validation='dev.tsv',
                                        format='tsv',
                                        skip_header=True,
                                        fields=[('text', text_field), ('label', label_field)])
# test.tsv columns: (index, sentence) -- no label, so it needs its own field list;
# a None field tells torchtext to skip that column
test = data.TabularDataset(path=path + 'test.tsv',
                           format='tsv',
                           skip_header=True,
                           fields=[('index', None), ('text', text_field)])

# Build the text vocab over all three splits and attach the GloVe vectors;
# words not found in GloVe get vectors drawn from N(0, 1) via unk_init.
text_field.build_vocab(train,
                       dev,
                       test,
                       max_size=25000,
                       vectors=vectors,
                       unk_init=torch.Tensor.normal_)
label_field.build_vocab(train, dev)   # test.tsv carries no labels

pretrained_embeddings = text_field.vocab.vectors   # FloatTensor of shape (vocab_size, 300)
label_stoi = label_field.vocab.stoi                # label-to-index mapping, e.g. {'0': 0, '1': 1}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# BucketIterator batches examples of similar length together to minimize padding;
# dev and test are delivered as single full-size batches here.
train_iter, dev_iter, test_iter = data.BucketIterator.splits((train, dev, test),
                                                             batch_sizes=(batch_size, len(dev), len(test)),
                                                             sort_key=lambda x: len(x.text),
                                                             sort_within_batch=True,
                                                             repeat=False,
                                                             shuffle=True,
                                                             device=device)

for step, batch in enumerate(tqdm(train_iter, desc="Iteration")):
    print('>1: ', batch.text)    # (token-id matrix, lengths) tuple, since include_lengths=True
    print('>2: ', batch.label)
    break
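With the vocab built, pretrained_embeddings is typically copied into the model's embedding layer. A minimal sketch of that step (the model itself is outside this demo; index 1 is torchtext's default for <pad>):

import torch.nn as nn

vocab_size, embed_dim = pretrained_embeddings.shape
pad_idx = text_field.vocab.stoi['<pad>']
embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
embedding.weight.data.copy_(pretrained_embeddings)   # load the GloVe weights
embedding.weight.data[pad_idx].zero_()               # keep the <pad> vector at zero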

III. The output of one run:

Iteration:   0%|                                                                                         | 0/11225 [00:00<?, ?it/s]
>1:  (tensor([[   26,  8340,    62,     4,   677,   115],
        [  549,     3,  1621,   594,     5,   175],
        [    8,    85,    10,     3,   148,   484],
        [  217,     9,    56,   213,     9,     6],
        [  128,  1111, 11676,   545,  1942,   157],
        [  273,  1429,   538,     1,     1,     1],
        [    1,     1,     1,     1,     1,     1],
        [    1,     1,     1,     1,     1,     1],
        [    1,     1,     1,     1,     1,     1],
        [    1,     1,     1,     1,     1,     1]]), tensor([6, 6, 6, 5, 5, 5]))
>2:  tensor([1, 1, 1, 0, 0, 0])
Iteration:   0%|                                                                                         | 0/11225 [00:00<?, ?it/s]

NOTE: The first tensor is the token-id matrix of shape (fix_length=10, batch_size=6), where id 1 is <pad>; the second tensor holds the true lengths that include_lengths=True adds. Because the examples are shuffled, the output differs on every run!
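If identical batches are needed across runs, seeding the RNGs before building the iterators pins the order down. A sketch, assuming the torchtext 0.x line used here (its RandomShuffler draws its state from Python's random module):

import random

random.seed(42)        # fixes torchtext's batch shuffling (RandomShuffler)
torch.manual_seed(42)  # fixes the unk_init=torch.Tensor.normal_ draws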
