20 Newsgroup文本分类-基于sklearn

本项目聚焦于通过机器学习的方法来进行文本自动分类,采用的是有监督的学习,根据已经标注好类别的文本语料进行特征提取、建模、训练,进而对未知样本进行预测。可用于此场景的分类模型有很多,例如贝叶斯、决策树、SVM、深度学习等。本项目中将会重点尝试几个有代表性的模型,并对其应用效果进行对比分析。

使用的数据集来自于业内著名的20 Newsgroups 数据集,包含20类标注好的样本,数据量共计约2万条记录。该数据集每篇文档均不长,即使同时使用多个类目数据合并起来进行建模,在单机上也可快速完成,因此具有很好的学习训练价值。

本项目涉及的是一个多分类问题,故可供选择的评估指标有macro-F1, micro-F1两种。通常情况下,为规避样本量不均衡带来的问题,业界更多会采用micro-F1作为多分类问题的评估指标。本项目中样本量相对均衡,理论上两种方法均可,我将选择**micro-F1**来做为最终的评估指标。
此外,为考核模型性能,还可以将运行时间作为一个参考指标。

一、数据获取,获取目标类目的训练数据与测试数据

from sklearn.datasets import fetch_20newsgroups
sample_cate = ['alt.atheism', 'soc.religion.christian','comp.graphics', 'sci.med', 'rec.sport.baseball']
newsgroups_train = fetch_20newsgroups(subset='train',categories=sample_cate,shuffle=True, random_state=42,remove = ('headers', 'footers', 'quotes'))
newsgroups_test = fetch_20newsgroups(subset='test', categories=sample_cate,shuffle=True, random_state=42,remove = ('headers', 'footers', 'quotes'))

print(len(newsgroups_train.data), len(newsgroups_test.data))

二、向量化处理

from sklearn.feature_extraction.text import TfidfVectorizer
vectorizer = TfidfVectorizer(stop_words='english',lowercase=True)
train_vector = vectorizer.fit_transform(newsgroups_train.data)
print(train_vector.shape)

test_vector = vectorizer.transform(newsgroups_test.data)
print(test_vector.shape)

三、分类算法

MultinomialNB

from time import time
b = time()
from sklearn.naive_bayes import MultinomialNB
mnb_clf = MultinomialNB(alpha=.01, fit_prior = False)
mnb_clf.fit(train_vector, newsgroups_train.target)
# 预测
pred = mnb_clf.predict(test_vector)
# 评分
from sklearn import metrics
print(metrics.f1_score(newsgroups_test.target, pred, average='micro'))
print("time spent %f" % (time()-b))

print(mnb_clf.coef_.shape)

使用Grid Search优化参数

from sklearn.model_selection import GridSearchCV
b = time()
parameters = {'fit_prior':(True, False), 'alpha':(0.01,0.05,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0)}
gs_clf = GridSearchCV(mnb_clf,parameters,n_jobs=-1)
gs_clf = gs_clf.fit(train_vector, newsgroups_train.target)
gs_pred = gs_clf.predict(test_vector)
print(metrics.f1_score(newsgroups_test.target, gs_pred, average='micro'))
print("time spent %f" % (time()-b))

print("best score: %f" % gs_clf.best_score_)  
for param_name in sorted(parameters.keys()):
    print("%s: %r" % (param_name, gs_clf.best_params_[param_name]))

 

你可能感兴趣的:(文本挖掘)