线性分类

线性分类

“查准率”、“查全率”

查准率(precision):算法挑出来的西瓜中有多少比例是好西瓜;

查全率(recall):所有的好西瓜中有多少比例被算法跳了出来。
查准率P与查全率R分别定义为:
线性分类_第1张图片
两者之间的关系:
查准率和查全率是一对矛盾的指标,一般说,当查准率高的时候,查全率一般很低;查全率高时,查准率一般很低。
在实际的模型评估中,单用Precision或者Recall来评价模型是不完整的,评价模型时必须用Precision/Recall两个值。有三种使用方法:平衡点(Break-Even Point,BEP)、F1度量、F1度量的一般化形式。(具体的可以上网具体了解)

F1-Score

F1-score: 是准确率与召回率的综合。 可以认为是平均效果。
公式:F1-score: 2TP/(2TP + FP + FN)

ROC

ROC全称是"受试者工作特征"曲线。
ROC是真正类率(召回率的另一种称呼)和假正类率(FPR)。FPR是被错误分为正类的负类实例比率。它等于1-真负类率(TNR)。
公式为:
线性分类_第2张图片

混淆矩阵

在人工智能中,混淆矩阵(confusion matrix)是可视化工具。

在机器学习领域,混淆矩阵(Confusion Matrix),又称为可能性矩阵或错误矩阵。混淆矩阵是可视化工具,特别用于监督学习,在无监督学习一般叫做匹配矩阵。在图像精度评价中,主要用于比较分类结果和实际测得值,可以把分类结果的精度显示在一个混淆矩阵里面。

混淆矩阵要表达的含义:

混淆矩阵的每一列代表了预测类别,每一列的总数表示预测为该类别的数据的数目;
每一行代表了数据的真实归属类别,每一行的数据总数表示该类别的数据实例的数目;每一列中的数值表示真实数据被预测为该类的数目。

Jupyter编程完成对手写体Mnist数据集中10个字符 (0-9)的分类识别

导入需求的库:

在这里插入代码片# 使用sklearn的函数来获取MNIST数据集、
from sklearn.datasets import fetch_openml
import numpy as np
import os
np.random.seed(42)
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False

定义导入数据集函数:


def sort_by_target(mnist):
    reorder_train=np.array(sorted([(target,i) for i, target in enumerate(mnist.target[:60000])]))[:,1]
    reorder_test=np.array(sorted([(target,i) for i, target in enumerate(mnist.target[60000:])]))[:,1]
    mnist.data[:60000]=mnist.data[reorder_train]
    mnist.target[:60000]=mnist.target[reorder_train]
    mnist.data[60000:]=mnist.data[reorder_test+60000]
    mnist.target[60000:]=mnist.target[reorder_test+60000]

导入数据集函数:

import time
a=time.time()
mnist=fetch_openml('mnist_784',version=1,cache=True)
mnist.target=mnist.target.astype(np.int8)
sort_by_target(mnist)
b=time.time()
print(b-a)

我的结果:

在这里插入图片描述
取出mnist数据集的数据:

X,y=mnist["data"],mnist["target"]

# 展示图片
def plot_digit(data):
    image = data.reshape(28, 28)
    plt.imshow(image, cmap = mpl.cm.binary,
               interpolation="nearest")
    plt.axis("off")
some_digit = X[36000]
plot_digit(X[36000].reshape(28,28))

我的结果:
线性分类_第3张图片
定义mnist数据集中数字0-9展示功能函数:

def plot_digits(instances,images_per_row=10,**options):
    size=28
    image_pre_row=min(len(instances),images_per_row)
    images=[instances.reshape(size,size) for instances in instances]
    n_rows=(len(instances)-1) // image_pre_row+1
    row_images=[]
    n_empty=n_rows*image_pre_row-len(instances)
    images.append(np.zeros((size,size*n_empty)))
    for row in range(n_rows):
        rimages=images[row*image_pre_row:(row+1)*image_pre_row]
        row_images.append(np.concatenate(rimages,axis=1)
    image=np.concatenate(row_images,axis=0)
    plt.imshow(image,cmap=mpl.cm.binary,**options)
    plt.axis("off")

plt.figure(figsize=(9,9))
example_images=np.r_[X[:12000:600],X[13000:30600:600],X[30600:60000:590]]
plot_digits(example_images,images_per_row=10)
plt.show()

我的结果:
线性分类_第4张图片

你可能感兴趣的:(机器学习)