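I. Materials:
1. A pretrained embedding file. This example uses glove.6B.300d.txt, which is available for download online.
2. A corpus, i.e. the text that will be converted to numeric form. This example uses three files: train.tsv, dev.tsv, and test.tsv. train.tsv and dev.tsv have the following format: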
sentence label
hide new secretions from the parental units 0
contains no wit , only labored gags 0
that loves its characters and communicates something rather beautiful about human nature 1
remains utterly satisfied to remain the same throughout 0
on the worst revenge-of-the-nerds clichés the filmmakers could dredge up 0
that 's far too tragic to merit such superficial treatment 0
demonstrates that the director of such hollywood blockbusters as patriot games can still turn out a small , personal film with an emotional wallop . 1
of saucy 1
a depressed fifteen-year-old 's suicidal poetry 0
test.tsv has the following format:
index sentence
0 uneasy mishmash of styles and genres .
1 this film 's relationship to actual tension is the same as what christmas-tree flocking in a spray can is to actual snow : a poor -- if durable -- imitation .
2 by the end of no such thing the audience , like beatrice , has a watchful affection for the monster .
3 director rob marshall went out gunning to make a great one .
4 lathan and diggs have considerable personal charm , and their screen rapport makes the old story seem new .
5 a well-made and often lovely depiction of the mysteries of friendship .
6 none of this violates the letter of behan 's book , but missing is its spirit , its ribald , full-throated humor .
7 although it bangs a very cliched drum at times , this crowd-pleaser 's fresh dialogue , energetic music , and good-natured spunk are often infectious .
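Both raw materials above are plain text, so their layout can be inspected without torchtext. The sketch below assumes that each GloVe line is a word followed by space-separated floats and that the .tsv columns are separated by tabs (the samples above show them flattened to spaces). The helper names `load_glove` and `read_tsv` are hypothetical, and the 3-dimensional vectors are made up for illustration; the real file has 300 values per line.

```python
import csv
import io


def load_glove(handle):
    """Parse lines of the form "word v1 v2 ... vd" into {word: [floats]}."""
    vectors = {}
    for line in handle:
        parts = line.rstrip().split(" ")
        if len(parts) < 2:
            continue  # skip blank or malformed lines
        vectors[parts[0]] = [float(x) for x in parts[1:]]
    return vectors


def read_tsv(handle):
    """Read a tab-separated file, skip the header row, return tuples of strings."""
    # QUOTE_NONE keeps quote characters inside sentences as ordinary text.
    reader = csv.reader(handle, delimiter="\t", quoting=csv.QUOTE_NONE)
    next(reader)  # header: "sentence label" or "index sentence"
    return [tuple(cell.strip() for cell in row) for row in reader]


# Made-up embedding values, only to show the file layout.
glove = load_glove(io.StringIO("the 0.1 0.2 0.3\nfilm -0.5 0.0 0.25\n"))

# Rows copied from the samples above, with tabs restored (an assumption).
train_rows = read_tsv(io.StringIO(
    "sentence\tlabel\n"
    "contains no wit , only labored gags \t0\n"
    "of saucy \t1\n"
))
test_rows = read_tsv(io.StringIO(
    "index\tsentence\n"
    "0\tuneasy mishmash of styles and genres .\n"
))

print(len(glove["the"]), train_rows[0], test_rows[0])
```

Note that the column order differs between the two layouts: train.tsv and dev.tsv put the sentence first and the label second, while test.tsv puts an index first and the sentence second and has no label column.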
II. Complete demo
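import torch
# Field, TabularDataset and BucketIterator belong to the legacy torchtext API.
# In torchtext 0.9 to 0.11 they live in torchtext.legacy.data instead.
from torchtext import data
from torchtext import vocab
from tqdm import tqdm

embedding_file = '/home/jason/Desktop/data/embeddings/glove/glove.6B.300d.txt'
path = '/home/jason/Desktop/data/SST-2/'
cache_dir = '.cache/'
batch_size = 6

# Load the pretrained GloVe vectors; a binary cache is written to cache_dir.
vectors = vocab.Vectors(embedding_file, cache_dir)

# tokenize='spacy' requires spaCy and its English model to be installed.
# include_lengths=True makes batch.text a (token ids, lengths) tuple.
# fix_length=10 pads or truncates every sentence to 10 tokens.
text_field = data.Field(tokenize='spacy',
                        lower=True,
                        include_lengths=True,
                        fix_length=10)
label_field = data.LabelField(dtype=torch.long)

# train.tsv and dev.tsv have the columns: sentence, label.
train, dev = data.TabularDataset.splits(path=path,
                                        train='train.tsv',
                                        validation='dev.tsv',
                                        format='tsv',
                                        skip_header=True,
                                        fields=[('text', text_field), ('label', label_field)])

# test.tsv has the columns: index, sentence. It has no label column,
# so the index is skipped (None) and the sentence is mapped to text_field.
test = data.TabularDataset(path=path + 'test.tsv',
                           format='tsv',
                           skip_header=True,
                           fields=[('index', None), ('text', text_field)])

# Build the word vocabulary and attach a pretrained vector to every word.
# Words missing from GloVe are initialised from a normal distribution.
text_field.build_vocab(train,
                       dev,
                       test,
                       max_size=25000,
                       vectors=vectors,
                       unk_init=torch.Tensor.normal_)
# Only train and dev carry labels.
label_field.build_vocab(train, dev)

# Tensor of shape [vocabulary size, 300], aligned with the vocabulary indices.
pretrained_embeddings = text_field.vocab.vectors
# Mapping from label string to class index.
label_to_index = label_field.vocab.stoi

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_iter, dev_iter, test_iter = data.BucketIterator.splits((train, dev, test),
                                                             batch_sizes=(batch_size, len(dev), len(test)),
                                                             sort_key=lambda x: len(x.text),
                                                             sort_within_batch=True,
                                                             repeat=False,
                                                             shuffle=True,
                                                             device=device)

# Print the first training batch and stop.
for step, batch in enumerate(tqdm(train_iter, desc="Iteration")):
    print('>1: ', batch.text)   # (token ids, lengths)
    print('>2: ', batch.label)
    break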
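The build_vocab step is where words get their indices and their pretrained vectors. The following is a simplified pure-Python sketch of that idea, not torchtext's actual implementation: it counts tokens, keeps the `max_size` most frequent ones, places the special tokens `<unk>` and `<pad>` at indices 0 and 1, and fills one row per vocabulary entry from the pretrained table, drawing random normal values for entries that have no pretrained vector. The function name, the toy sentences, and the 2-dimensional vectors are all made up for illustration.

```python
import random
from collections import Counter


def build_vocab_sketch(token_lists, pretrained, dim, max_size=None, seed=0):
    """Return (itos, stoi, matrix) for a list of tokenized sentences."""
    counts = Counter(tok for toks in token_lists for tok in toks)
    # Most frequent first; ties are broken alphabetically.
    ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    if max_size is not None:
        ordered = ordered[:max_size]  # special tokens are not counted here
    itos = ["<unk>", "<pad>"] + [tok for tok, _ in ordered]
    stoi = {tok: i for i, tok in enumerate(itos)}

    rng = random.Random(seed)
    matrix = []
    for tok in itos:
        if tok in pretrained:
            matrix.append(list(pretrained[tok]))
        else:
            # No pretrained vector: random normal initialisation.
            matrix.append([rng.gauss(0.0, 1.0) for _ in range(dim)])
    return itos, stoi, matrix


# Toy data, invented for this example.
sentences = [["a", "film"], ["a", "good", "film"], ["zzz"]]
pretrained = {"a": [0.1, 0.2], "film": [0.3, 0.4], "good": [0.5, 0.6]}

itos, stoi, matrix = build_vocab_sketch(sentences, pretrained, dim=2, max_size=3)
print(itos)
print(stoi["film"], matrix[stoi["film"]])
```

Because row i of the matrix belongs to the word with index i, the matrix can be used directly as the weight of an embedding layer; that is the role `pretrained_embeddings` plays in the demo. Words cut off by `max_size` ("zzz" here) have no index of their own and are looked up as `<unk>`.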
III. Output:
Iteration:   0%|          | 0/11225 [00:00<?, ?it/s]
>1:  (tensor([[   26,  8340,    62,     4,   677,   115],
        [  549,     3,  1621,   594,     5,   175],
        [    8,    85,    10,     3,   148,   484],
        [  217,     9,    56,   213,     9,     6],
        [  128,  1111, 11676,   545,  1942,   157],
        [  273,  1429,   538,     1,     1,     1],
        [    1,     1,     1,     1,     1,     1],
        [    1,     1,     1,     1,     1,     1],
        [    1,     1,     1,     1,     1,     1],
        [    1,     1,     1,     1,     1,     1]]), tensor([6, 6, 6, 5, 5, 5]))
>2:  tensor([1, 1, 1, 0, 0, 0])
Iteration:   0%|          | 0/11225 [00:00<?, ?it/s]
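The first tensor in the output has shape [10, 6]: one column per sentence in the batch and one row per token position, padded with index 1 up to fix_length=10. The second tensor holds the unpadded length of each sentence. The sketch below reproduces that layout in plain Python for the first and fourth columns of the output above. It only illustrates the layout and is not torchtext code; the helper name `pad_batch` is hypothetical, and the pad index 1 is read off the printed output.

```python
def pad_batch(sequences, fix_length, pad_index=1):
    """Pad or truncate each sequence to fix_length, then transpose.

    Returns (rows, lengths), where rows has shape [fix_length, batch_size].
    """
    lengths = [min(len(seq), fix_length) for seq in sequences]
    padded = [
        list(seq[:fix_length]) + [pad_index] * (fix_length - n)
        for seq, n in zip(sequences, lengths)
    ]
    # Transpose from [batch_size, fix_length] to [fix_length, batch_size].
    rows = [list(col) for col in zip(*padded)]
    return rows, lengths


# Token ids taken from the first and fourth columns of the printed batch.
first = [26, 549, 8, 217, 128, 273]
fourth = [4, 594, 3, 213, 545]

rows, lengths = pad_batch([first, fourth], fix_length=10)
for row in rows:
    print(row)
print(lengths)
```

Reading a column from top to bottom gives one sentence followed by its padding. The lengths let a model ignore the padded positions, for example when packing sequences for an RNN.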