AG News: https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz
DBPedia: https://s3.amazonaws.com/fast-ai-nlp/dbpedia_csv.tgz
Sogou news: https://s3.amazonaws.com/fast-ai-nlp/sogou_news_csv.tgz
Yelp Review Polarity: https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz
Yelp Review Full: https://s3.amazonaws.com/fast-ai-nlp/yelp_review_full_csv.tgz
Yahoo! Answers: https://s3.amazonaws.com/fast-ai-nlp/yahoo_answers_csv.tgz
Amazon Review Full: https://s3.amazonaws.com/fast-ai-nlp/amazon_review_full_csv.tgz
Amazon Review Polarity: https://s3.amazonaws.com/fast-ai-nlp/amazon_review_polarity_csv.tgz
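These archives can be downloaded and unpacked with the standard library alone; a minimal sketch (the data/ destination and the rename to AG/ are assumptions chosen to match the paths the loader below expects):

import os
import tarfile
import urllib.request

# Minimal sketch: fetch the AG News archive and unpack it under ./data.
# The other archives above work the same way.
URL = "https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz"

def download_and_extract(url, dest_dir="data"):
    os.makedirs(dest_dir, exist_ok=True)
    archive = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(archive):
        urllib.request.urlretrieve(url, archive)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest_dir)  # creates data/ag_news_csv/ with train.csv and test.csv

download_and_extract(URL)
# The loader below reads data/AG/train.csv, so rename the extracted directory:
# os.rename("data/ag_news_csv", "data/AG")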
1. Load the dataset
2. Read the labels and text
3. Build word2id
3.1 Count word frequencies
3.2 Add pad:0 and unk:1, then build word2id
4. Convert the text to ids (a toy sketch of steps 3 and 4 follows below)
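Before the full implementation, here is a toy sketch of steps 3 and 4 on a single made-up token list (the tokens and min_count are illustrative, not from the dataset):

from collections import Counter

tokens = ["the", "cat", "sat", "the", "cat", "ran"]

# Step 3.1: count word frequencies.
freq = Counter(tokens)

# Step 3.2: reserve 0 for <pad> and 1 for <unk>, then add words
# that meet the frequency threshold.
min_count = 2
word2id = {"<pad>": 0, "<unk>": 1}
for word, count in freq.items():
    if count >= min_count:
        word2id[word] = len(word2id)

# Step 4: map tokens to ids, falling back to <unk> for rare words.
ids = [word2id.get(w, 1) for w in tokens]
print(word2id)  # {'<pad>': 0, '<unk>': 1, 'the': 2, 'cat': 3}
print(ids)      # [2, 3, 1, 2, 3, 1]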
from torch.utils import data
import os
import csv
import nltk  # nltk.word_tokenize requires the "punkt" tokenizer data: nltk.download("punkt")
import numpy as np
class AG_Data(data.Dataset):
    def __init__(self, data_path, min_count, max_length, n_gram=1, word2id=None, uniwords_num=0):
        self.path = os.path.abspath(".")
        if "data" not in self.path:
            self.path += "/data"
        self.n_gram = n_gram
        self.load(data_path)  # load the dataset: read labels and text
        if word2id is None:
            self.get_word2id(self.data, min_count)  # build word2id
        else:
            self.word2id = word2id
            self.uniwords_num = uniwords_num
        self.data = self.convert_data2id(self.data, max_length)  # convert every word in the text to its id
        self.data = np.array(self.data)
        self.y = np.array(self.y)
    # Load the dataset: read labels and text
    def load(self, data_path, lowercase=True):
        self.label = []
        self.data = []
        with open(self.path + data_path, "r") as f:
            datas = list(csv.reader(f, delimiter=",", quotechar='"'))
        for row in datas:
            self.label.append(int(row[0]) - 1)  # CSV labels are 1-based; shift to start at 0
            txt = " ".join(row[1:])
            if lowercase:
                txt = txt.lower()
            txt = nltk.word_tokenize(txt)  # split the sentence into word tokens
            new_txt = []
            for i in range(0, len(txt)):
                for j in range(self.n_gram):  # add n-gram tokens ending at position i
                    if j <= i:
                        new_txt.append(" ".join(txt[i - j:i + 1]))
            self.data.append(new_txt)
        self.y = self.label
    # Build word2id
    def get_word2id(self, datas, min_count=3):
        word_freq = {}
        for data in datas:  # first count word frequencies, used below to filter rare words
            for word in data:
                word_freq[word] = word_freq.get(word, 0) + 1
        word2id = {"<pad>": 0, "<unk>": 1}
        for word in word_freq:  # uni-gram words first, since they need no hashing
            if word_freq[word] < min_count or " " in word:
                continue
            word2id[word] = len(word2id)
        self.uniwords_num = len(word2id)
        for word in word_freq:  # 2-gram and longer tokens, which will be hashed
            if word_freq[word] < min_count or " " not in word:
                continue
            word2id[word] = len(word2id)
        self.word2id = word2id
    # Convert every word in the text to its id
    def convert_data2id(self, datas, max_length):
        for i, data in enumerate(datas):
            for j, word in enumerate(data):
                if " " not in word:
                    datas[i][j] = self.word2id.get(word, 1)  # unknown uni-grams map to <unk>
                else:
                    # hash function: fold n-gram ids (n >= 2) into 100000 buckets placed after the uni-gram ids
                    datas[i][j] = self.word2id.get(word, 1) % 100000 + self.uniwords_num
            # truncate to max_length, then pad with <pad> (id 0)
            datas[i] = datas[i][0:max_length] + [0] * (max_length - len(datas[i]))
        return datas
    def __getitem__(self, idx):
        X = self.data[idx]
        y = self.y[idx]
        return X, y

    def __len__(self):
        return len(self.label)
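One detail worth noting in convert_data2id: higher-order n-grams do not each get their own embedding row. Their sequential word2id values are folded into 100000 buckets by the modulo, so distinct n-grams can share an embedding row, as in fastText's hashing trick. A tiny illustration with made-up ids:

uniwords_num = 5000  # illustrative value
for ngram_id in [5001, 105001, 205001]:  # made-up word2id values for three bigrams
    print(ngram_id % 100000 + uniwords_num)  # prints 10001 three times: all share one bucket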
if __name__ == "__main__":
    ag_data = AG_Data("/AG/train.csv", 3, 100)
    print(ag_data.data.shape)
    print(ag_data.data[-20:])
    print(ag_data.y.shape)
    print(len(ag_data.word2id))
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Fasttext(nn.Module):
    def __init__(self, vocab_size, embedding_size, max_length, label_num):
        super(Fasttext, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)  # word embedding table
        self.avg_pool = nn.AvgPool1d(kernel_size=max_length, stride=1)  # average over the whole sequence
        self.fc = nn.Linear(embedding_size, label_num)  # linear classifier

    def forward(self, x):
        out = self.embedding(x)  # batch_size * length * embedding_size, e.g. bs*100*200
        out = out.transpose(1, 2).contiguous()  # batch_size * embedding_size * length, e.g. bs*200*100
        out = self.avg_pool(out).squeeze(-1)  # batch_size * embedding_size
        out = self.fc(out)  # batch_size * label_num
        return out
if __name__ == "__main__":
    fasttext = Fasttext(100, 200, 100, 4)
    x = torch.Tensor(np.zeros([64, 100])).long()
    out = fasttext(x)
    print(out.size())
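The AvgPool1d with kernel_size equal to the sequence length is just a mean over the time dimension; a quick check of that equivalence (shapes here mirror the comments in forward):

import torch

x = torch.randn(64, 200, 100)  # batch_size * embedding_size * length
pooled = torch.nn.AvgPool1d(kernel_size=100, stride=1)(x).squeeze(-1)
mean = x.mean(dim=2)  # average over the length dimension
print(torch.allclose(pooled, mean, atol=1e-6))  # True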
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from model import Fasttext
from data import AG_Data
import numpy as np
from tqdm import tqdm
import config as argumentparser
config = argumentparser.ArgumentParser()  # hyperparameters from the project-local config module
torch.manual_seed(config.seed)
if config.cuda and torch.cuda.is_available():
    torch.cuda.set_device(config.gpu)
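The config module imported above is project-local and not shown in this section. A hypothetical sketch of a compatible config.py, with field names matching the attributes the training script uses; every default value below is an assumption, not taken from the original project:

# config.py -- hypothetical sketch; field names match the attributes used below,
# default values are assumptions.
import argparse

def ArgumentParser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--embed_size", type=int, default=200)      # embedding dimension
    parser.add_argument("--epoch", type=int, default=5)             # number of training epochs
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--min_count", type=int, default=3)         # frequency threshold for word2id
    parser.add_argument("--max_length", type=int, default=100)      # sentence length after pad/truncate
    parser.add_argument("--n_gram", type=int, default=2)
    parser.add_argument("--label_num", type=int, default=4)         # AG News has 4 classes
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--cuda", action="store_true")              # train on GPU if available
    parser.add_argument("--gpu", type=int, default=0)
    return parser.parse_args()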
def get_test_result(data_iter, data_set):
    # compute accuracy on the given dataset
    model.eval()
    true_sample_num = 0
    for data, label in data_iter:
        if config.cuda and torch.cuda.is_available():
            data = data.cuda()
            label = label.cuda()
        else:
            data = torch.autograd.Variable(data).long()
        out = model(data)
        true_sample_num += np.sum((torch.argmax(out, 1) == label).cpu().numpy())
    acc = true_sample_num / len(data_set)
    return acc
training_set = AG_Data("/AG/train.csv", min_count=config.min_count,
                       max_length=config.max_length, n_gram=config.n_gram)
training_iter = torch.utils.data.DataLoader(dataset=training_set,
                                            batch_size=config.batch_size,
                                            shuffle=True,
                                            num_workers=0)
# the test set reuses the training set's word2id and uni-gram count
test_set = AG_Data(data_path="/AG/test.csv", min_count=config.min_count,
                   max_length=config.max_length, n_gram=config.n_gram,
                   word2id=training_set.word2id,
                   uniwords_num=training_set.uniwords_num)
test_iter = torch.utils.data.DataLoader(dataset=test_set,
                                        batch_size=config.batch_size,
                                        shuffle=False,
                                        num_workers=0)
# vocabulary size = uni-gram ids plus 100000 hash buckets for higher-order n-grams
model = Fasttext(vocab_size=training_set.uniwords_num + 100000, embedding_size=config.embed_size,
                 max_length=config.max_length, label_num=config.label_num)
if config.cuda and torch.cuda.is_available():
    model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
loss = -1  # running smoothed loss; -1 marks "not yet initialized"
for epoch in range(config.epoch):
    model.train()
    process_bar = tqdm(training_iter)
    for data, label in process_bar:
        if config.cuda and torch.cuda.is_available():
            data = data.cuda()
            label = label.cuda()
        else:
            data = torch.autograd.Variable(data).long()
            label = torch.autograd.Variable(label).squeeze()
        out = model(data)
        loss_now = criterion(out, autograd.Variable(label.long()))
        if loss == -1:
            loss = loss_now.data.item()
        else:
            # exponential moving average of the loss for a smoother progress display
            loss = 0.95 * loss + 0.05 * loss_now.data.item()
        process_bar.set_postfix(loss=loss_now.data.item())
        process_bar.update()
        optimizer.zero_grad()
        loss_now.backward()
        optimizer.step()
    test_acc = get_test_result(test_iter, test_set)
    print("The test acc is: %.5f" % test_acc)
Output: the test-set accuracy printed after each epoch: