sklearn学习记录

# Load the training split of the 20 Newsgroups corpus, limited to four topics.
from sklearn.datasets import fetch_20newsgroups
categories = [
    'alt.atheism',
    'soc.religion.christian',
    'comp.graphics',
    'sci.med',
]
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,  # fixed seed so the shuffled order is reproducible
)
print(twenty_train.target_names)
['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']
from sklearn.feature_extraction.text import CountVectorizer
# Turn the training documents into a document-term count matrix.
count_vect = CountVectorizer()
X_train_counts = count_vect.fit_transform(twenty_train.data)
# shape[0] is the number of documents, shape[1] the number of distinct terms
# (the label previously contained a typo for "vocabulary count").
print('训练数据共有{0}篇,词汇计数为{1}个'.format(X_train_counts.shape[0],X_train_counts.shape[1]))
# vocabulary_ maps each term to its column index in the count matrix, NOT to
# the number of times the term occurs, so the value is reported as an index.
algorithm_index = count_vect.vocabulary_.get('algorithm')
print('algorithm在词汇表中的索引为{0}'.format(algorithm_index))
训练数据共有2257篇,词汇计数为35788个
algorithm在词汇表中的索引为4690
# Convert the raw term counts into tf-idf features.
from sklearn.feature_extraction.text import TfidfTransformer
tfidf_transformer = TfidfTransformer()
# Fit first, then transform the same matrix (two explicit steps).
tfidf_transformer.fit(X_train_counts)
X_train_tfidf = tfidf_transformer.transform(X_train_counts)
print(X_train_tfidf.shape)
(2257, 35788)

训练分类器

# Train a multinomial naive Bayes classifier on the tf-idf features.
from sklearn.naive_bayes import MultinomialNB
clf = MultinomialNB()
clf.fit(X_train_tfidf, twenty_train.target)
print('分类器的相关信息:' )
print(clf)
分类器的相关信息:
MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)
# Classify a previously unseen document with the trained classifier.
docs_new = ['Nvidia is awesome!']
# Use transform (not fit_transform) on the already-fitted vectorizer and
# tf-idf transformer so the new text is encoded with the training vocabulary.
X_new_counts = count_vect.transform(docs_new)
X_new_tfidf = tfidf_transformer.transform(X_new_counts)
# Predict on the new document's features. The previous code passed
# X_train_tfidf here, which paired docs_new with the prediction for the
# first training sample instead of classifying the new text.
predicted = clf.predict(X_new_tfidf)
for doc, category in zip(docs_new, predicted):
    print('%r => %s' % (doc, twenty_train.target_names[category]))
'Nvidia is awesome!' => comp.graphics
# Per-class probabilities and log-probabilities for every training document
# (both are computed on the training matrix X_train_tfidf).
predicted_proba = clf.predict_proba(X_train_tfidf)
predicted_log_proba = clf.predict_log_proba(X_train_tfidf)

print(predicted_proba)
print(predicted_log_proba)
[[ 0.03794718  0.81396819  0.07085396  0.07723068]
 [ 0.07027236  0.56938479  0.14892825  0.2114146 ]
 [ 0.04774093  0.00406403  0.06583678  0.88235825]
 ..., 
 [ 0.03827809  0.01056658  0.79105485  0.16010048]
 [ 0.02160756  0.03685754  0.85394925  0.08758564]
 [ 0.00548145  0.004625    0.98054771  0.00934584]]
[[-3.27156014 -0.20583399 -2.64713445 -2.56095856]
 [-2.65537672 -0.56319881 -1.90429065 -1.55393414]
 [-3.04196613 -5.50557911 -2.72057663 -0.12515712]
 ..., 
 [-3.26287771 -4.5500588  -0.23438797 -1.83195364]
 [-3.83471201 -3.30069502 -0.15788351 -2.43513819]
 [-5.20638618 -5.37627811 -0.01964397 -4.67282442]]

建立Pipeline

# Chain vectorizing, tf-idf weighting and classification into one estimator
# so raw text can be passed straight to fit/predict.
from sklearn.pipeline import Pipeline
text_clf = Pipeline([
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', MultinomialNB()),
])
text_clf.fit(twenty_train.data, twenty_train.target)
print(text_clf)
Pipeline(memory=None,
     steps=[('vect', CountVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), preprocessor=None, stop_words=None,
        strip...inear_tf=False, use_idf=True)), ('clf', MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True))])
# Measure the pipeline's accuracy on the held-out test split.
import numpy as np
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42,
)
docs_test = twenty_test.data
predicted = text_clf.predict(docs_test)
# Fraction of test documents whose predicted label equals the true label.
is_correct = predicted == twenty_test.target
print('准确率为:')
print(np.mean(is_correct))
准确率为:
0.834886817577

对以上结果的改进,换用其他的分类算法

# Same pipeline layout, but with an SGDClassifier (hinge loss, L2 penalty)
# as the final step; accuracy is measured on the same test documents.
import numpy as np
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(
    loss='hinge',
    penalty='l2',
    alpha=1e-3,
    max_iter=5,
    random_state=42,
)
text_clf = Pipeline([
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', sgd_clf),
])
text_clf.fit(twenty_train.data, twenty_train.target)
predicted = text_clf.predict(docs_test)
print('准确率为:')
print(np.mean(predicted == twenty_test.target))
准确率为:
0.912782956059

对分类器的性能进行分析

# Per-class precision / recall / f1 for the latest test-set predictions.
from sklearn import metrics
print('打印分类器性能指标:')
report = metrics.classification_report(
    twenty_test.target,
    predicted,
    target_names=twenty_test.target_names,
)
print(report)
打印分类器性能指标:
                        precision    recall  f1-score   support

           alt.atheism       0.95      0.81      0.87       319
         comp.graphics       0.88      0.97      0.92       389
               sci.med       0.94      0.90      0.92       396
soc.religion.christian       0.90      0.95      0.93       398

           avg / total       0.92      0.91      0.91      1502
# Confusion matrix of true test labels versus the latest predictions.
print('打印混淆矩阵:')
confusion = metrics.confusion_matrix(twenty_test.target, predicted)
print(confusion)
打印混淆矩阵:
[[258  11  15  35]
 [  4 379   3   3]
 [  5  33 355   3]
 [  5  10   4 379]]

使用网格搜索来进行参数优化,找到最合适的参数

# sklearn.grid_search was deprecated in scikit-learn 0.18 and removed in 0.20;
# GridSearchCV is provided by sklearn.model_selection.
from sklearn.model_selection import GridSearchCV
# Keys follow the "<pipeline step name>__<parameter name>" convention, so each
# entry tunes one parameter of the 'vect', 'tfidf' or 'clf' step of text_clf.
parameters = {'vect__ngram_range':[(1,1),(1,2)],
             'tfidf__use_idf':(True,False),
             'clf__alpha':(1e-2,1e-3)}
gs_clf = GridSearchCV(text_clf,parameters)
# NOTE: the exact repr printed here differs between scikit-learn versions.
print(gs_clf)
GridSearchCV(cv=None, error_score='raise',
       estimator=Pipeline(memory=None,
     steps=[('vect', CountVectorizer(analyzer='word', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=True, max_df=1.0, max_features=None, min_df=1,
        ngram_range=(1, 1), preprocessor=None, stop_words=None,
        strip...ty='l2', power_t=0.5, random_state=42, shuffle=True,
       tol=None, verbose=0, warm_start=False))]),
       fit_params={}, iid=True, n_jobs=1,
       param_grid={'clf__alpha': (0.01, 0.001), 'tfidf__use_idf': (True, False), 'vect__ngram_range': [(1, 1), (1, 2)]},
       pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)
# Run the parameter search on the first 400 training samples only, then
# classify a sample sentence with the fitted search object.
train_subset_docs = twenty_train.data[:400]
train_subset_labels = twenty_train.target[:400]
gs_clf = gs_clf.fit(train_subset_docs, train_subset_labels)
print(gs_clf.predict(['An apple a day keeps doctor away']))
[2]
# Translate the predicted class index into its newsgroup name.
predicted_index = gs_clf.predict(['An apple a day keeps doctor away'])[0]
print(twenty_train.target_names[predicted_index])
sci.med

# Best mean cross-validated score found by the grid search.
# (Typographic quotes replaced with ASCII quotes; they were a SyntaxError.)
print('最佳准确率:%r' % (gs_clf.best_score_,))

sorted:内置排序函数,返回排序副本

# Print the best value found for every searched parameter, in name order.
# (ASCII quotes and the loop-body indentation are required for valid Python.)
for param_name in sorted(parameters.keys()):
    print('%s : %r' % (param_name, gs_clf.best_params_[param_name]))

**%r 和 %s 的区别:**
%s 用str()方法处理对象
%r 用repr()方法处理对象,打印时能够重现它所代表的对象(repr() unambiguously recreates the object it represents)

# %s formats the value with str(); %r formats it with repr(), which keeps
# the quotes around a string.
a = 'sunday'
print('Today is %s' % (a,))
print('Today is %r' % (a,))
Today is sunday
Today is 'sunday'

你可能感兴趣的:(机器学习)