Multi-class SVM with Gradient Descent + the Most Detailed Mathematical Derivation + Python Practice (Iris Dataset)!
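The basic Support Vector Machine (SVM) model finds the best separating hyperplane in feature space, namely the one that maximizes the margin between the positive and negative training samples. In other words, SVM looks for an optimal hyperplane that splits two classes of data such that the distance from the hyperplane to its nearest points is maximized; these nearest points are called support vectors. SVM is a supervised learning algorithm for binary classification, and it can be applied to multi-class problems through the one-vs-all strategy. This post shows how to optimize a multi-class SVM with gradient descent.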
Assume the dataset $\mathbf{X} \in \mathbb{R}^{k \times n}$, where $n$ is the number of training samples and $k$ is the dimension of each sample. Note also that the L2-SVM (squared hinge loss) is used below!
$$\begin{aligned} \mathcal{L}(\boldsymbol w_{c}, b_{c}) &=\sum_{c=1}^{C}\left[\frac{1}{2}\boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{i=1}^{n} \max \left\{0,1-y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)\right\}^{2}\right] \\ &=\frac{1}{2}\sum_{c=1}^{C} \boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c=1}^{C} \sum_{i=1}^{n} \max \left\{0,1-y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)\right\}^{2} \end{aligned} \tag{1}$$
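To make Eq. (1) concrete, here is a minimal NumPy sketch that evaluates the objective for all classes at once. The function name `l2svm_objective` and the argument layout (samples stored as columns, one-vs-all labels in {+1, -1}) are our assumptions, chosen to match the layout used by the program later in this post.

```python
import numpy as np


def l2svm_objective(W, b, X, Y, lam):
    """One-vs-all L2-SVM objective of Eq. (1).

    W: (k, C) weights, column c is w_c
    b: (C, 1) biases
    X: (k, n) samples stored as columns
    Y: (C, n) labels, Y[c, i] = +1 if sample i belongs to class c, else -1
    lam: penalty coefficient lambda
    """
    scores = W.T @ X + b                        # (C, n): w_c^T x_i + b_c
    hinge = np.maximum(0.0, 1.0 - Y * scores)   # max{0, 1 - y_i^c (w_c^T x_i + b_c)}
    return 0.5 * np.sum(W * W) + lam * np.sum(hinge ** 2)
```

With all parameters at zero every score is 0, so each of the C·n terms contributes a squared hinge of 1; once every margin is at least 1, only the regularizer $\frac12\sum_c \boldsymbol w_c^T\boldsymbol w_c$ is left.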
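First, samples with $1-y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)\le 0$ are classified correctly and lie on or outside the margin, so their hinge loss is zero and does not need to be added. If no sample violates the margin, the objective reduces to the regularization term: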
$$\mathcal{L}(\boldsymbol w_{c}, b_{c}) = \frac{1}{2}\sum_{c=1}^{C} \boldsymbol w_{c}^{T} \boldsymbol w_{c}=\frac{1}{2}\operatorname{tr}\left(\mathbf{W}^{T} \mathbf{W}\right) \tag{2}$$

$$\frac{\partial \mathcal{L}(\boldsymbol w_{c}, b_{c})}{\partial \boldsymbol w_{c}} =\boldsymbol w_{c},\qquad \frac{\partial \mathcal{L}(\boldsymbol w_{c}, b_{c})}{\partial b_{c}} = 0 \tag{3}$$
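Next, samples with $1-y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)>0$ are misclassified or lie inside the margin, so the hinge loss must be added. In the following, the inner sums run only over these samples (for class $c$) and $n$ denotes their number, which allows the $\max$ to be dropped: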
$$\begin{aligned} \mathcal{L}(\boldsymbol w_{c}, b_{c}) &=\sum_{c=1}^{C}\left[\frac{1}{2}\boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{i=1}^{n} \max \left\{0,1-y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)\right\}^{2}\right] \\ &=\frac{1}{2}\sum_{c=1}^{C} \boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c=1}^{C} \sum_{i=1}^{n}\left[1-y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)\right]^{2} \\ &=\frac{1}{2} \sum_{c=1}^{C} \boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c=1}^{C} \sum_{i=1}^{n}\left[1+\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)^{2}-2 y_{i}^{c}\left(\boldsymbol w_{c}^{T} \boldsymbol x_{i}+b_{c}\right)\right] \\ &=\frac{1}{2}\sum_{c=1}^{C} \boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c=1}^{C} \sum_{i=1}^{n}\left[1+\boldsymbol w_{c}^{T} \boldsymbol x_{i} \boldsymbol x_{i}^{T} \boldsymbol w_{c}+b_{c}^{2}+2 \boldsymbol w_{c}^{T} \boldsymbol x_{i} b_{c}-2 y_{i}^{c} \boldsymbol w_{c}^{T} \boldsymbol x_{i}-2 y_{i}^{c} b_{c}\right] \\ &=\frac{1}{2}\sum_{c=1}^{C} \boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c=1}^{C}\left[n\left(1+b_{c}^{2}\right)+\boldsymbol w_{c}^{T}\left(\sum_{i=1}^{n}\boldsymbol x_{i}\boldsymbol x_{i}^{T}\right) \boldsymbol w_{c}+2 b_{c} \boldsymbol w_{c}^{T}\left(\sum_{i=1}^{n} \boldsymbol x_{i}\right)-2 \boldsymbol w_{c}^{T}\left(\sum_{i=1}^{n} \boldsymbol x_{i} y_{i}^{c}\right)-2\left(\sum_{i=1}^{n} y_{i}^{c}\right) b_{c}\right] \end{aligned} \tag{4}$$
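A quick numerical check of the expansion in Eq. (4) for a single class $c$ can make the algebra less abstract. It relies on $(y_i^c)^2 = 1$ and keeps only the samples with $1 - y_i^c(\boldsymbol w_c^T\boldsymbol x_i + b_c) > 0$, whose count plays the role of $n$. The random data and variable names are illustrative assumptions, not part of the original program.

```python
import numpy as np

rng = np.random.default_rng(0)
k, n = 3, 8
X = rng.normal(size=(k, n))             # samples as columns
y = rng.choice([-1.0, 1.0], size=n)     # labels y_i^c of one class c
w = rng.normal(size=k)
b = 0.3

margin = y * (w @ X + b)
V = X[:, margin < 1]                    # only the samples that violate the margin
yv = y[margin < 1]
m = V.shape[1]                          # plays the role of n in Eq. (4)

# squared hinge over all samples (first line of Eq. (4), without the lambda factor)
hinge_sum = np.sum(np.maximum(0.0, 1 - y * (w @ X + b)) ** 2)
# the same sum written directly over the violating samples
direct = np.sum((1 - yv * (w @ V + b)) ** 2)
# the last line of Eq. (4)
expanded = (m * (1 + b ** 2) + w @ (V @ V.T) @ w + 2 * b * w @ V.sum(axis=1)
            - 2 * w @ (V @ yv) - 2 * yv.sum() * b)
```

The three quantities agree up to floating-point rounding; the identity also holds trivially if no sample happens to violate the margin.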
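Here we used $(y_i^c)^2 = 1$. Rearranging the above in matrix form, where for class $c$ the columns of $\mathbf X$ and the entries of $\mathbf y_c$ and $\mathbf E$ correspond to those violating samples (`x_c`, `y_c` and `e` in the code), gives: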
$$\mathcal{L}(\boldsymbol w_{c}, b_{c}) = \frac{1}{2}\sum_{c=1}^{C}\boldsymbol w_{c}^{T} \boldsymbol w_{c}+\lambda \sum_{c=1}^{C}\left[n\left(1+b_{c}^{2}\right)+\boldsymbol w_{c}^{T} \mathbf{X} \mathbf{X}^{T} \boldsymbol w_{c}+2 b_{c} \boldsymbol w_{c}^{T} \mathbf{X} \mathbf{E}-2 \boldsymbol w_{c}^{T} \mathbf{X} \mathbf{y}_{c}-2 b_{c} \mathbf{E}^{T} \mathbf{y}_{c}\right] \tag{5}$$
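Furthermore, if every sample violates the margin for every class (so that $\mathbf X$ and $n$ are shared by all classes), Eq. (5) can be simplified further into a fully matrix form. This is only meant as an idea: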
$$\mathcal{L}(\mathbf{W}, \mathbf{b}) =\frac{1}{2}\operatorname{tr}\left(\mathbf{W}^{T} \mathbf{W}\right)+\lambda\left[n\left(C+\mathbf{b}^{T} \mathbf{b}\right)+\operatorname{tr}\left(\mathbf{W}^{T} \mathbf{X} \mathbf{X}^{T} \mathbf{W}\right)+2 \mathbf{b}^{T} \mathbf{W}^{T} \mathbf{X} \mathbf{E}-2 \operatorname{tr}\left(\mathbf{W}^{T} \mathbf{X} \mathbf{Y}^{T}\right)-2 \mathbf{b}^{T} \mathbf{Y} \mathbf{E}\right] \tag{6}$$
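The matrix form (6) can be checked numerically against the per-sample sum, under the same assumption that every sample is treated as violating the margin (so the $\max$ is dropped everywhere). The sizes and random values below are arbitrary choices for the check.

```python
import numpy as np

rng = np.random.default_rng(1)
k, n, C, lam = 4, 10, 3, 0.5
X = rng.normal(size=(k, n))                          # k x n samples
Y = np.where(rng.random((C, n)) < 0.5, -1.0, 1.0)    # C x n one-vs-all labels
W = rng.normal(size=(k, C))
b = rng.normal(size=(C, 1))
E = np.ones((n, 1))

# per-sample form: every term (1 - y_i^c (w_c^T x_i + b_c))^2 is kept
sum_form = 0.5 * np.sum(W * W) + lam * np.sum((1 - Y * (W.T @ X + b)) ** 2)

# matrix form of Eq. (6)
matrix_form = (0.5 * np.trace(W.T @ W)
               + lam * (n * (C + (b.T @ b).item())
                        + np.trace(W.T @ X @ X.T @ W)
                        + 2 * (b.T @ W.T @ X @ E).item()
                        - 2 * np.trace(W.T @ X @ Y.T)
                        - 2 * (b.T @ Y @ E).item()))
```

Both expressions evaluate to the same number, because $\sum_c \boldsymbol w_c^T\mathbf X\mathbf y_c = \operatorname{tr}(\mathbf W^T\mathbf X\mathbf Y^T)$ and $\sum_c b_c\mathbf E^T\mathbf y_c = \mathbf b^T\mathbf Y\mathbf E$.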
Taking the partial derivatives of objective (5) with respect to $\boldsymbol w_{c}$ and $b_c$ gives:
$$\frac{\partial \mathcal{L}(\boldsymbol w_{c}, b_{c})}{\partial \boldsymbol w_{c}} = \boldsymbol w_{c}+2\lambda\left(\mathbf{X} \mathbf{X}^{T} \boldsymbol w_{c}+\mathbf{X} \mathbf{E} b_{c}-\mathbf{X}\mathbf{y}_{c}\right)\tag{7}$$

$$\frac{\partial \mathcal{L}(\boldsymbol w_{c}, b_{c})}{\partial b_{c}} = 2\lambda\left(n b_{c}+\boldsymbol w_{c}^{T}\mathbf{X}\mathbf{E}-\mathbf{y}_{c}^{T} \mathbf{E}\right)\tag{8}$$
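A central finite-difference check is a cheap way to confirm the gradients for a single class $c$. The analytic expressions below are written in the same form as `dl_dw` and `dl_db` in the program later in this post (including the factor $\lambda$ and the $\boldsymbol w_c^T\mathbf X\mathbf E$ term); the data are random and only illustrative. Since the squared hinge is continuously differentiable, the numerical and analytic gradients should match closely.

```python
import numpy as np

rng = np.random.default_rng(2)
k, n, lam = 3, 12, 0.5
X = rng.normal(size=(k, n))
y = rng.choice([-1.0, 1.0], size=n)
w = rng.normal(size=k)
b = 0.2


def loss(w, b):
    # objective for one class c: 1/2 w^T w + lam * sum_i max(0, 1 - y_i (w^T x_i + b))^2
    return 0.5 * w @ w + lam * np.sum(np.maximum(0.0, 1 - y * (w @ X + b)) ** 2)


# analytic gradients restricted to the violating samples
mask = y * (w @ X + b) < 1
Xv, yv = X[:, mask], y[mask]
Ev = np.ones(Xv.shape[1])
grad_w = w + 2 * lam * (Xv @ Xv.T @ w + Xv @ Ev * b - Xv @ yv)
grad_b = 2 * lam * (Xv.shape[1] * b + w @ Xv @ Ev - yv @ Ev)

# central finite differences
h = 1e-6
num_w = np.array([(loss(w + h * d, b) - loss(w - h * d, b)) / (2 * h) for d in np.eye(k)])
num_b = (loss(w, b + h) - loss(w, b - h)) / (2 * h)
```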
The matrices and vectors used above are:

$$\mathbf{X}_{k \times n}=\begin{bmatrix} x_{1}^{(1)} & x_{1}^{(2)} & \cdots & x_{1}^{(n)} \\ \vdots & \vdots & \ddots & \vdots \\ x_{k}^{(1)} & x_{k}^{(2)} & \cdots & x_{k}^{(n)} \end{bmatrix},\quad \mathbf{Y}_{C \times n}=\begin{bmatrix} y_{1}^{(1)} & y_{2}^{(1)} & \cdots & y_{n}^{(1)} \\ \vdots & \vdots & \ddots & \vdots \\ y_{1}^{(C)} & y_{2}^{(C)} & \cdots & y_{n}^{(C)} \end{bmatrix},\quad \mathbf{E}_{n \times 1}=\begin{bmatrix} 1 \\ \vdots \\ 1 \end{bmatrix}$$

$$\mathbf{W}_{k \times C}=\begin{bmatrix} w_{1}^{(1)} & w_{1}^{(2)} & \cdots & w_{1}^{(C)} \\ \vdots & \vdots & \ddots & \vdots \\ w_{k}^{(1)} & w_{k}^{(2)} & \cdots & w_{k}^{(C)} \end{bmatrix},\quad \mathbf{b}_{C \times 1}=\begin{bmatrix} b_{1} \\ \vdots \\ b_{C} \end{bmatrix},\quad \boldsymbol{y}_{c}=\begin{bmatrix} y_{1}^{(c)} \\ \vdots \\ y_{n}^{(c)} \end{bmatrix}_{n \times 1}$$

Here $\boldsymbol w_c$ is the $c$-th column of $\mathbf W$, $\boldsymbol y_c^T$ is the $c$-th row of $\mathbf Y$, and $y_i^{(c)} = +1$ if sample $i$ belongs to class $c$ and $-1$ otherwise (one-vs-all labels).
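As a small illustration of these definitions, the sketch below builds $\mathbf X$, $\mathbf Y$ and $\mathbf E$ from a few hypothetical samples and shows what the products $\mathbf X\mathbf E$ and $\mathbf Y\mathbf E$ appearing in Eqs. (5) to (8) compute. The labels and feature values are made up for the example.

```python
import numpy as np

labels = np.array([0, 2, 1, 0])          # hypothetical class indices for n = 4 samples
X = np.array([[5.1, 7.0, 6.3, 4.9],      # k = 2 features, one sample per column
              [3.5, 3.2, 3.3, 3.0]])
C = 3
n = X.shape[1]

Y = -np.ones((C, n))                     # Y[c, i] = +1 if sample i belongs to class c, else -1
Y[labels, np.arange(n)] = 1.0
E = np.ones((n, 1))

XE = X @ E      # (k, 1): sum of all samples, i.e. sum_i x_i
YE = Y @ E      # (C, 1): entry c equals E^T y_c = sum_i y_i^(c)
```

For these labels, class 0 has two positive and two negative samples, so its entry of $\mathbf Y\mathbf E$ is 0, while classes 1 and 2 each have one positive sample and give -2.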
The overall iteration is easy to understand and consists of the following two cases (see the code in Section 2.2 for the concrete procedure):
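- ① If a sample satisfies $1-y_{i}\left(\boldsymbol w^{T} \boldsymbol x_{i}+b\right)>0$ (it is misclassified or lies inside the margin), it contributes the squared hinge loss, and the update pushes it toward satisfying $y_{i}\left(\boldsymbol w^{T} \boldsymbol x_{i}+b\right)\ge 1$ (the problem is optimized with the hinge loss in the primal space of the SVM).
- ② If a sample satisfies $1-y_{i}\left(\boldsymbol w^{T} \boldsymbol x_{i}+b\right)\le 0$, it contributes no hinge-loss gradient. For a class without any violating sample, only the regularization term acts: $\boldsymbol w_c$ is updated with Eq. (3) and $b_c$ stays unchanged.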
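The two cases above can also be written without a per-class Python loop: masking with $\max(0,\cdot)$ makes every sample in case ② contribute exactly zero. The sketch below is an alternative, vectorized formulation (the name `svm_gradients` is ours, not the original author's); it should reproduce the per-class gradients that the program below computes for `dl_dw` and `dl_db`.

```python
import numpy as np


def svm_gradients(W, b, X, Y, lam):
    """Gradients of the L2-SVM objective for all C classes at once.

    W: (k, C), b: (C, 1), X: (k, n), Y: (C, n) with entries in {+1, -1}.
    A sample only contributes to class c when 1 - y_i^c (w_c^T x_i + b_c) > 0 (case 1);
    otherwise its slack is 0 and it drops out, leaving only the regularizer (case 2).
    """
    scores = W.T @ X + b                        # (C, n)
    slack = np.maximum(0.0, 1.0 - Y * scores)   # zero for samples in case 2
    # derivative of lam * slack^2 with respect to the score is -2 * lam * y * slack
    G = -2.0 * lam * Y * slack                  # (C, n)
    dW = W + X @ G.T                            # (k, C)
    db = G.sum(axis=1, keepdims=True)           # (C, 1)
    return dW, db
```

Expanding $-2\lambda\sum_i \boldsymbol x_i y_i^c(1 - y_i^c s_i)$ over the violating samples gives back $2\lambda(\mathbf X\mathbf X^T\boldsymbol w_c + \mathbf X\mathbf E b_c - \mathbf X\mathbf y_c)$, so the two formulations are algebraically the same.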
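The Iris dataset contains 150 records in 3 classes, 50 per class. Each record has 4 features: sepal length, sepal width, petal length and petal width. These 4 features are used to predict which species (iris-setosa, iris-versicolour, iris-virginica) a flower belongs to.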
- Here we use the Iris dataset as an example, stored in one.txt:
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica
- The program code is as follows:
import numpy as np


def obtain_w_via_gradient_descent(x, c, y, penalty_c, threshold=1e-19, learn_rate=1e-4):
    """Solve the following L2-SVM problem with gradient descent:
    min 1/2 * w^T * w + C * sum_{i=1..n} max(0, 1 - y_i * (w^T * x_i + b))^2
    :param x: training samples, shape (k, n): x = [x_1, x_2, ..., x_n]
    :param c: number of classes
    :param y: one-vs-all labels in {+1, -1}, shape (c, n): y = [y_1, y_2, ..., y_c]
    :param penalty_c: penalty coefficient C (lambda in the derivation)
    :param threshold: stopping threshold of gradient descent (squared parameter change)
    :param learn_rate: learning rate
    """
    data_num = np.shape(x)[1]
    feature_dim = np.shape(x)[0]
    w = np.ones([feature_dim, c], dtype=np.float32)
    b = np.ones([c, 1], dtype=np.float32)
    # np.float / np.int were removed in NumPy 1.24, so explicit dtypes are used
    dl_dw = np.zeros([feature_dim, c], dtype=np.float64)
    dl_db = np.zeros([c, 1], dtype=np.float64)
    it = 1
    th = 0.1
    while it < 60000 and th > threshold:
        # functional margins y_i^c * (w_c^T x_i + b_c), shape (c, n)
        ksi = (np.transpose(w) @ x + np.tile(b, [1, data_num])) * y
        index_matrix = ksi < 1  # True where a sample violates the margin of a class
        for class_num in range(c):
            index_vector = index_matrix[class_num, :]
            w_c = np.reshape(w[:, class_num], [feature_dim, 1])
            if np.any(index_vector):
                # case 1: gradient over the violating samples only, Eqs. (7) and (8)
                x_c = x[:, index_vector]
                data_num_c = np.shape(x_c)[1]
                e = np.ones([data_num_c, 1], dtype=np.float64)
                y_c = np.reshape(y[class_num, index_vector], [data_num_c, 1])
                b_c = b[class_num, 0]
                dl_dw[:, class_num] = (w_c + 2 * penalty_c * (x_c @ np.transpose(x_c) @ w_c +
                                                              x_c @ e * b_c -
                                                              x_c @ y_c))[:, 0]
                dl_db[class_num, 0] = 2 * penalty_c * (b_c * data_num_c +
                                                       (np.transpose(w_c) @ x_c @ e).item() -
                                                       (np.transpose(y_c) @ e).item())
            else:
                # case 2: no violating sample, only the regularizer contributes, Eq. (3)
                dl_dw[:, class_num] = w_c[:, 0]
                dl_db[class_num, 0] = 0
        # w moves along the gradient normalized by its spectral norm, b along the plain gradient
        w_ = w - learn_rate * (dl_dw / np.linalg.norm(dl_dw, ord=2))
        b_ = b - learn_rate * dl_db
        th = np.sum(np.square(w_ - w)) + np.sum(np.square(b_ - b))
        it = it + 1
        w = w_
        b = b_
        if it % 200 == 0:
            # training accuracy: each sample is assigned to the class with the largest score
            y_predict = np.transpose(w) @ x + np.tile(b, [1, data_num])
            correct_prediction = np.equal(np.argmax(y_predict, 0), np.argmax(y, 0))
            accuracy = np.mean(correct_prediction.astype(np.float64))
            print("epoch:", it, "acc:", accuracy)
    return w, b


def iris_type(s):
    # map the species name to a class index; newer NumPy passes str, older versions pass bytes
    if isinstance(s, bytes):
        s = s.decode()
    it = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
    return it[s]


def normalize_data(data):
    # z-score standardization of each feature column
    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    return (data - mean) / std


def convert_to_one_hot(y, C):
    return np.eye(C)[y.reshape(-1)]


def main():
    data = np.loadtxt('./one.txt', dtype=float, delimiter=',', converters={4: iris_type})
    x = data[:, :4]
    x = normalize_data(x)  # preprocess the data
    y = data[:, 4].astype(int)
    y_onehot = convert_to_one_hot(y, 3)
    y_onehot[y_onehot == 0] = -1  # one-vs-all labels: +1 for the true class, -1 otherwise
    x = np.transpose(x)  # k*n  k: feature dimension, n: number of samples
    y_onehot = np.transpose(y_onehot)  # c*n  c: number of classes, n: number of samples
    obtain_w_via_gradient_descent(x, 3, y_onehot, 0.5)


if __name__ == '__main__':
    main()
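The training function reports only the training accuracy. To classify new samples with trained parameters `w` (k×C) and `b` (C×1), a small helper along the following lines should suffice. `predict` is a hypothetical name; it relies on NumPy broadcasting of `b` across the samples instead of `np.tile`, which yields the same scores.

```python
import numpy as np


def predict(w, b, x):
    """Assign each column of x (k x n) to the class with the largest score w_c^T x + b_c."""
    scores = w.T @ x + b          # (C, n); b of shape (C, 1) broadcasts over the samples
    return np.argmax(scores, axis=0)
```

For example, `np.mean(predict(w, b, x) == labels)` gives the accuracy against integer class labels `labels`.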
- The final output of the program is as follows:
ssh://zhangkf@192.168.136.64:22/home/zhangkf/anaconda3/envs/py1/bin/python -u /home/zhangkf/johnCodes/TF1/svm_test/SVM_grad.py
epoch: 200 acc: 0.6666666666666666
epoch: 400 acc: 0.3333333333333333
epoch: 600 acc: 0.3333333333333333
epoch: 800 acc: 0.3333333333333333
epoch: 1000 acc: 0.3333333333333333
epoch: 1200 acc: 0.3333333333333333
epoch: 1400 acc: 0.3333333333333333
epoch: 1600 acc: 0.3333333333333333
epoch: 1800 acc: 0.3333333333333333
epoch: 2000 acc: 0.3333333333333333
epoch: 2200 acc: 0.3333333333333333
epoch: 2400 acc: 0.3333333333333333
epoch: 2600 acc: 0.34
epoch: 2800 acc: 0.34
epoch: 3000 acc: 0.36
epoch: 3200 acc: 0.36666666666666664
epoch: 3400 acc: 0.38
epoch: 3600 acc: 0.3933333333333333
epoch: 3800 acc: 0.4
epoch: 4000 acc: 0.4266666666666667
epoch: 4200 acc: 0.43333333333333335
epoch: 4400 acc: 0.47333333333333333
epoch: 4600 acc: 0.5
epoch: 4800 acc: 0.5066666666666667
epoch: 5000 acc: 0.52
epoch: 5200 acc: 0.52
epoch: 5400 acc: 0.52
epoch: 5600 acc: 0.54
epoch: 5800 acc: 0.5533333333333333
epoch: 6000 acc: 0.5733333333333334
epoch: 6200 acc: 0.58
epoch: 6400 acc: 0.58
epoch: 6600 acc: 0.58
epoch: 6800 acc: 0.5866666666666667
epoch: 7000 acc: 0.5866666666666667
epoch: 7200 acc: 0.5933333333333334
epoch: 7400 acc: 0.5933333333333334
epoch: 7600 acc: 0.6066666666666667
epoch: 7800 acc: 0.6066666666666667
epoch: 8000 acc: 0.6266666666666667
epoch: 8200 acc: 0.6333333333333333
epoch: 8400 acc: 0.64
epoch: 8600 acc: 0.64
epoch: 8800 acc: 0.6466666666666666
epoch: 9000 acc: 0.6533333333333333
epoch: 9200 acc: 0.66
epoch: 9400 acc: 0.66
epoch: 9600 acc: 0.66
epoch: 9800 acc: 0.66
epoch: 10000 acc: 0.6466666666666666
epoch: 10200 acc: 0.6533333333333333
epoch: 10400 acc: 0.6533333333333333
epoch: 10600 acc: 0.6533333333333333
epoch: 10800 acc: 0.6533333333333333
epoch: 11000 acc: 0.6533333333333333
epoch: 11200 acc: 0.6533333333333333
epoch: 11400 acc: 0.66
epoch: 11600 acc: 0.66
epoch: 11800 acc: 0.66
epoch: 12000 acc: 0.6666666666666666
epoch: 12200 acc: 0.6733333333333333
epoch: 12400 acc: 0.6866666666666666
epoch: 12600 acc: 0.6866666666666666
epoch: 12800 acc: 0.6866666666666666
epoch: 13000 acc: 0.6866666666666666
epoch: 13200 acc: 0.6866666666666666
epoch: 13400 acc: 0.6933333333333334
epoch: 13600 acc: 0.7133333333333334
epoch: 13800 acc: 0.72
epoch: 14000 acc: 0.7333333333333333
epoch: 14200 acc: 0.74
epoch: 14400 acc: 0.7466666666666667
epoch: 14600 acc: 0.7533333333333333
epoch: 14800 acc: 0.76
epoch: 15000 acc: 0.76
epoch: 15200 acc: 0.7666666666666667
epoch: 15400 acc: 0.7666666666666667
epoch: 15600 acc: 0.7666666666666667
epoch: 15800 acc: 0.7666666666666667
epoch: 16000 acc: 0.78
epoch: 16200 acc: 0.78
epoch: 16400 acc: 0.7866666666666666
epoch: 16600 acc: 0.7933333333333333
epoch: 16800 acc: 0.7933333333333333
epoch: 17000 acc: 0.7933333333333333
epoch: 17200 acc: 0.7933333333333333
epoch: 17400 acc: 0.8066666666666666
epoch: 17600 acc: 0.8066666666666666
epoch: 17800 acc: 0.82
epoch: 18000 acc: 0.8266666666666667
epoch: 18200 acc: 0.82
epoch: 18400 acc: 0.82
epoch: 18600 acc: 0.8266666666666667
epoch: 18800 acc: 0.8266666666666667
epoch: 19000 acc: 0.8266666666666667
epoch: 19200 acc: 0.8266666666666667
epoch: 19400 acc: 0.8333333333333334
epoch: 19600 acc: 0.8333333333333334
epoch: 19800 acc: 0.8333333333333334
epoch: 20000 acc: 0.8333333333333334
epoch: 20200 acc: 0.8333333333333334
epoch: 20400 acc: 0.8466666666666667
epoch: 20600 acc: 0.8533333333333334
epoch: 20800 acc: 0.86
epoch: 21000 acc: 0.8666666666666667
epoch: 21200 acc: 0.8666666666666667
epoch: 21400 acc: 0.8666666666666667
epoch: 21600 acc: 0.8666666666666667
epoch: 21800 acc: 0.8666666666666667
epoch: 22000 acc: 0.8666666666666667
epoch: 22200 acc: 0.8666666666666667
epoch: 22400 acc: 0.8666666666666667
epoch: 22600 acc: 0.8666666666666667
epoch: 22800 acc: 0.8666666666666667
epoch: 23000 acc: 0.8666666666666667
epoch: 23200 acc: 0.8666666666666667
epoch: 23400 acc: 0.8666666666666667
epoch: 23600 acc: 0.8666666666666667
epoch: 23800 acc: 0.8666666666666667
epoch: 24000 acc: 0.8666666666666667
epoch: 24200 acc: 0.8666666666666667
epoch: 24400 acc: 0.8666666666666667
epoch: 24600 acc: 0.8666666666666667
epoch: 24800 acc: 0.8666666666666667
epoch: 25000 acc: 0.8666666666666667
epoch: 25200 acc: 0.8666666666666667
epoch: 25400 acc: 0.8666666666666667
epoch: 25600 acc: 0.8666666666666667
epoch: 25800 acc: 0.8666666666666667
epoch: 26000 acc: 0.8666666666666667
epoch: 26200 acc: 0.8666666666666667
epoch: 26400 acc: 0.8666666666666667
epoch: 26600 acc: 0.86
epoch: 26800 acc: 0.86
epoch: 27000 acc: 0.86
epoch: 27200 acc: 0.86
epoch: 27400 acc: 0.86
epoch: 27600 acc: 0.86
epoch: 27800 acc: 0.86
epoch: 28000 acc: 0.86
epoch: 28200 acc: 0.8666666666666667
epoch: 28400 acc: 0.86
epoch: 28600 acc: 0.86
epoch: 28800 acc: 0.86
epoch: 29000 acc: 0.86
epoch: 29200 acc: 0.86
epoch: 29400 acc: 0.86
epoch: 29600 acc: 0.86
epoch: 29800 acc: 0.86
epoch: 30000 acc: 0.8533333333333334
epoch: 30200 acc: 0.8533333333333334
epoch: 30400 acc: 0.86
epoch: 30600 acc: 0.86
epoch: 30800 acc: 0.8666666666666667
epoch: 31000 acc: 0.8666666666666667
epoch: 31200 acc: 0.8666666666666667
epoch: 31400 acc: 0.8666666666666667
epoch: 31600 acc: 0.86
epoch: 31800 acc: 0.86
epoch: 32000 acc: 0.86
epoch: 32200 acc: 0.86
epoch: 32400 acc: 0.86
epoch: 32600 acc: 0.86
epoch: 32800 acc: 0.86
epoch: 33000 acc: 0.86
epoch: 33200 acc: 0.86
epoch: 33400 acc: 0.86
epoch: 33600 acc: 0.86
epoch: 33800 acc: 0.86
epoch: 34000 acc: 0.8666666666666667
epoch: 34200 acc: 0.8666666666666667
epoch: 34400 acc: 0.8666666666666667
epoch: 34600 acc: 0.8666666666666667
epoch: 34800 acc: 0.8666666666666667
epoch: 35000 acc: 0.8666666666666667
epoch: 35200 acc: 0.8866666666666667
epoch: 35400 acc: 0.8866666666666667
epoch: 35600 acc: 0.8933333333333333
epoch: 35800 acc: 0.9
epoch: 36000 acc: 0.9
epoch: 36200 acc: 0.9
epoch: 36400 acc: 0.9
epoch: 36600 acc: 0.9
epoch: 36800 acc: 0.9
epoch: 37000 acc: 0.9
epoch: 37200 acc: 0.9
epoch: 37400 acc: 0.9
epoch: 37600 acc: 0.9
epoch: 37800 acc: 0.9
epoch: 38000 acc: 0.9
epoch: 38200 acc: 0.9066666666666666
epoch: 38400 acc: 0.9066666666666666
epoch: 38600 acc: 0.9133333333333333
epoch: 38800 acc: 0.9133333333333333
epoch: 39000 acc: 0.9133333333333333
epoch: 39200 acc: 0.9133333333333333
epoch: 39400 acc: 0.9133333333333333
epoch: 39600 acc: 0.9133333333333333
epoch: 39800 acc: 0.9133333333333333
epoch: 40000 acc: 0.9133333333333333
epoch: 40200 acc: 0.9133333333333333
epoch: 40400 acc: 0.9133333333333333
epoch: 40600 acc: 0.9133333333333333
epoch: 40800 acc: 0.9133333333333333
epoch: 41000 acc: 0.9133333333333333
epoch: 41200 acc: 0.9133333333333333
epoch: 41400 acc: 0.9133333333333333
epoch: 41600 acc: 0.9266666666666666
epoch: 41800 acc: 0.9266666666666666
epoch: 42000 acc: 0.9266666666666666
epoch: 42200 acc: 0.9333333333333333
epoch: 42400 acc: 0.9333333333333333
epoch: 42600 acc: 0.9333333333333333
epoch: 42800 acc: 0.9333333333333333
epoch: 43000 acc: 0.9333333333333333
epoch: 43200 acc: 0.9333333333333333
epoch: 43400 acc: 0.9333333333333333
epoch: 43600 acc: 0.9333333333333333
epoch: 43800 acc: 0.9333333333333333
epoch: 44000 acc: 0.9333333333333333
epoch: 44200 acc: 0.9333333333333333
epoch: 44400 acc: 0.9333333333333333
epoch: 44600 acc: 0.9333333333333333
epoch: 44800 acc: 0.94
epoch: 45000 acc: 0.94
epoch: 45200 acc: 0.94
epoch: 45400 acc: 0.94
epoch: 45600 acc: 0.9466666666666667
epoch: 45800 acc: 0.9466666666666667
epoch: 46000 acc: 0.9466666666666667
epoch: 46200 acc: 0.9466666666666667
epoch: 46400 acc: 0.9466666666666667
epoch: 46600 acc: 0.9466666666666667
epoch: 46800 acc: 0.9466666666666667
epoch: 47000 acc: 0.9466666666666667
epoch: 47200 acc: 0.9466666666666667
epoch: 47400 acc: 0.9466666666666667
epoch: 47600 acc: 0.9466666666666667
epoch: 47800 acc: 0.9466666666666667
epoch: 48000 acc: 0.9466666666666667
epoch: 48200 acc: 0.9466666666666667
epoch: 48400 acc: 0.9466666666666667
epoch: 48600 acc: 0.9466666666666667
epoch: 48800 acc: 0.9466666666666667
epoch: 49000 acc: 0.9466666666666667
epoch: 49200 acc: 0.9466666666666667
epoch: 49400 acc: 0.9466666666666667
epoch: 49600 acc: 0.9466666666666667
epoch: 49800 acc: 0.9466666666666667
epoch: 50000 acc: 0.9466666666666667
epoch: 50200 acc: 0.9466666666666667
epoch: 50400 acc: 0.9466666666666667
epoch: 50600 acc: 0.9466666666666667
epoch: 50800 acc: 0.9466666666666667
epoch: 51000 acc: 0.9466666666666667
epoch: 51200 acc: 0.9466666666666667
epoch: 51400 acc: 0.9466666666666667
epoch: 51600 acc: 0.9533333333333334
epoch: 51800 acc: 0.9533333333333334
epoch: 52000 acc: 0.96
epoch: 52200 acc: 0.96
epoch: 52400 acc: 0.96
epoch: 52600 acc: 0.96
epoch: 52800 acc: 0.96
epoch: 53000 acc: 0.96
epoch: 53200 acc: 0.96
epoch: 53400 acc: 0.96
epoch: 53600 acc: 0.96
epoch: 53800 acc: 0.96
epoch: 54000 acc: 0.96
epoch: 54200 acc: 0.96
epoch: 54400 acc: 0.96
epoch: 54600 acc: 0.96
epoch: 54800 acc: 0.96
epoch: 55000 acc: 0.96
epoch: 55200 acc: 0.96
epoch: 55400 acc: 0.96
epoch: 55600 acc: 0.96
epoch: 55800 acc: 0.96
epoch: 56000 acc: 0.96
epoch: 56200 acc: 0.96
epoch: 56400 acc: 0.96
epoch: 56600 acc: 0.96
epoch: 56800 acc: 0.96
epoch: 57000 acc: 0.96
epoch: 57200 acc: 0.96
epoch: 57400 acc: 0.96
epoch: 57600 acc: 0.96
epoch: 57800 acc: 0.96
epoch: 58000 acc: 0.96
epoch: 58200 acc: 0.96
epoch: 58400 acc: 0.96
epoch: 58600 acc: 0.96
epoch: 58800 acc: 0.96
epoch: 59000 acc: 0.96
epoch: 59200 acc: 0.96
epoch: 59400 acc: 0.96
epoch: 59600 acc: 0.96
epoch: 59800 acc: 0.96
epoch: 60000 acc: 0.96
Process finished with exit code 0
- A mini-batch (stochastic) gradient descent version, trained and tested on one2.txt:
import numpy as np
from time import time

batchsz = 500
np.random.seed(0)


# 0. Split the data into mini-batches
def mini_batches(X, Y, mini_batch_size=batchsz, seed=0):
    np.random.seed(seed)
    m = X.shape[0]  # m: number of samples
    batches = []  # holds the mini-batches one by one
    num_complete_minibatches = m // mini_batch_size  # number of full batches
    for i in range(num_complete_minibatches):
        mini_batch_X = X[i * mini_batch_size:(i + 1) * mini_batch_size, :]
        mini_batch_Y = Y[i * mini_batch_size:(i + 1) * mini_batch_size, :]
        batches.append((mini_batch_X, mini_batch_Y))
    if m % mini_batch_size != 0:
        # if the sample count is not divisible by the batch size, keep the remainder as a last batch
        mini_batch_X = X[num_complete_minibatches * mini_batch_size:, :]
        mini_batch_Y = Y[num_complete_minibatches * mini_batch_size:, :]
        batches.append((mini_batch_X, mini_batch_Y))
    return batches


# 1. Optimize the SVM with mini-batch (stochastic) gradient descent
def obtain_w_via_gradient_descent(x, c, y, penalty_c, x_test, y_test_onehot, threshold=1e-19, learn_rate=1e-4):
    """Solve the following L2-SVM problem with mini-batch gradient descent:
    min 1/2 * w^T * w + C * sum_{i=1..n} max(0, 1 - y_i * (w^T * x_i + b))^2
    :param x: training samples, shape (k, n)
    :param c: number of classes
    :param y: one-vs-all training labels in {+1, -1}, shape (c, n)
    :param penalty_c: penalty coefficient C
    :param x_test: test samples, shape (k, n_test)
    :param y_test_onehot: one-vs-all test labels, shape (c, n_test)
    :param threshold: stopping threshold of gradient descent (squared parameter change)
    :param learn_rate: learning rate
    """
    feature_dim = np.shape(x)[0]
    w = np.ones([feature_dim, c], dtype=np.float32)
    b = np.ones([c, 1], dtype=np.float32)
    dl_dw = np.zeros([feature_dim, c], dtype=np.float64)
    dl_db = np.zeros([c, 1], dtype=np.float64)
    epoch = 1  # counts parameter updates (one per mini-batch), not passes over the data
    th = 0.1
    iterations = mini_batches(x.T, y.T, batchsz, seed=0)  # list of mini-batches
    print(iterations[0][0].shape)
    begin_time = time()
    while epoch < 100000 and th > threshold:
        for x_y in iterations:
            x_batch = x_y[0].T  # (k, batch size)
            y_batch = x_y[1].T  # (c, batch size)
            batch_num = x_batch.shape[1]  # the last batch may be smaller than batchsz
            ksi = (np.transpose(w) @ x_batch + np.tile(b, [1, batch_num])) * y_batch
            index_matrix = ksi < 1
            for class_num in range(c):
                index_vector = index_matrix[class_num, :]
                w_c = np.reshape(w[:, class_num], [feature_dim, 1])
                if np.any(index_vector):
                    x_c = x_batch[:, index_vector]
                    data_num_c = np.shape(x_c)[1]
                    e = np.ones([data_num_c, 1], dtype=np.float64)
                    y_c = np.reshape(y_batch[class_num, index_vector], [data_num_c, 1])
                    b_c = b[class_num, 0]
                    dl_dw[:, class_num] = (w_c + 2 * penalty_c * (x_c @ np.transpose(x_c) @ w_c +
                                                                  x_c @ e * b_c -
                                                                  x_c @ y_c))[:, 0]
                    dl_db[class_num, 0] = 2 * penalty_c * (b_c * data_num_c +
                                                           (np.transpose(w_c) @ x_c @ e).item() -
                                                           (np.transpose(y_c) @ e).item())
                else:
                    dl_dw[:, class_num] = w_c[:, 0]
                    dl_db[class_num, 0] = 0
            w_ = w - learn_rate * (dl_dw / np.linalg.norm(dl_dw, ord=2))
            b_ = b - learn_rate * dl_db
            th = np.sum(np.square(w_ - w)) + np.sum(np.square(b_ - b))
            epoch = epoch + 1
            w = w_
            b = b_
            #############################################################################
            if epoch % 100 == 0:  # print the accuracy on the current mini-batch during training
                y_predict = np.transpose(w) @ x_batch + np.tile(b, [1, batch_num])
                correct_prediction = np.equal(np.argmax(y_predict, 0), np.argmax(y_batch, 0))
                accuracy = np.mean(correct_prediction.astype(np.float64))
                print("epoch:", epoch, "acc:", accuracy)
    end_time = time()
    run_time = end_time - begin_time
    print('Run time:', run_time)  # running time of the training loop
    ########################################## test-set results ############################
    data_num = np.shape(x_test)[1]
    y_predict = np.transpose(w) @ x_test + np.tile(b, [1, data_num])
    correct_prediction = np.equal(np.argmax(y_predict, 0), np.argmax(y_test_onehot, 0))
    accuracy = np.mean(correct_prediction.astype(np.float64))
    print("Test_acc:", accuracy)
    return w, b


# 2. Normalize the data (z-score per feature)
def normalize_data(data):
    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    return (data - mean) / std


# 3. Convert to one-hot encoding
def convert_to_one_hot(y, C):
    return np.eye(C)[y.reshape(-1)]


# 4. Randomly shuffle the samples (features and label are stored in the same row)
def random_scattered(data):
    index = np.arange(data.shape[0])
    np.random.shuffle(index)
    return data[index, :]


def main():
    # 1. Load the dataset
    data = np.loadtxt('one2.txt', dtype=float, delimiter=',')
    # 2. Shuffle the samples together with their labels
    data = random_scattered(data)
    # 3. Split into training and test data
    train_num = int(0.75 * data.shape[0])
    data_train_label = data[:train_num, :]  # training set: 75%
    data_test_label = data[train_num:, :]  # test set: the remaining 25%
    ################################### training set ###################################
    x = data_train_label[:, :4]
    # normalize, keeping the training-set statistics for the test set
    mean, std = np.mean(x, axis=0), np.std(x, axis=0)
    x = normalize_data(x)
    y = data_train_label[:, 4].astype(int) - 1  # labels 1/2 -> 0/1
    # convert to one-hot, then to one-vs-all labels in {+1, -1}
    y_onehot = convert_to_one_hot(y, 2)
    y_onehot[y_onehot == 0] = -1
    x = np.transpose(x)  # k*n  k: feature dimension, n: number of samples
    y_onehot = np.transpose(y_onehot)  # c*n  c: number of classes, n: number of samples
    ################################### test set ###################################
    x_test = data_test_label[:, :4]
    # scale the test set with the training-set mean/std so both sets share the same scaling
    x_test = (x_test - mean) / std
    y_test = data_test_label[:, 4].astype(int) - 1
    y_test_onehot = convert_to_one_hot(y_test, 2)
    y_test_onehot[y_test_onehot == 0] = -1
    x_test = np.transpose(x_test)  # k*n  k: feature dimension, n: number of samples
    y_test_onehot = np.transpose(y_test_onehot)  # c*n  c: number of classes, n: number of samples
    obtain_w_via_gradient_descent(x, 2, y_onehot, 0.5, x_test, y_test_onehot)


if __name__ == '__main__':
    main()
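In `mini_batches` the batches are cut once, in a fixed order, and the same list is reused in every pass over the data. A common alternative is to reshuffle the sample order at the start of every epoch. The sketch below is such a variant; the name `shuffled_mini_batches` and the use of `np.random.default_rng` are our choices, not part of the original code, and the batches keep the same row layout as `mini_batches` (samples as rows, transpose before computing the gradient).

```python
import numpy as np


def shuffled_mini_batches(X, Y, batch_size, rng):
    """Yield (X_batch, Y_batch) pairs that cover every sample exactly once.

    X: (m, k) samples as rows, Y: (m, C) labels as rows.
    The order is reshuffled each time the generator is created, e.g. once per epoch.
    """
    m = X.shape[0]
    perm = rng.permutation(m)
    for start in range(0, m, batch_size):
        idx = perm[start:start + batch_size]   # the last batch may be smaller
        yield X[idx], Y[idx]
```

A training loop would call it once per epoch, for instance `for xb, yb in shuffled_mini_batches(x.T, y.T, batchsz, rng):`, so that each epoch sees the samples in a new order.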
one2.txt (partial):
6.64E+01,3.53E+02,5.35E+03,9.50E+02,1
6.18E+01,2.34E+02,4.77E+03,9.50E+02,1
6.71E+01,2.80E+02,5.44E+03,9.32E+02,1
6.30E+01,2.06E+02,4.89E+03,9.27E+02,1
6.67E+01,1.75E+02,5.37E+03,9.15E+02,1
7.54E+01,8.78E+01,6.60E+03,9.14E+02,1
6.83E+01,2.16E+02,5.58E+03,9.14E+02,1
6.58E+01,1.79E+02,5.24E+03,9.12E+02,1
6.93E+01,2.53E+02,5.71E+03,9.10E+02,1
6.73E+01,2.69E+02,5.44E+03,9.08E+02,1
6.20E+01,7.08E+01,4.75E+03,9.06E+02,1
6.70E+01,1.03E+02,5.39E+03,9.04E+02,1
6.43E+01,1.08E+02,5.04E+03,9.03E+02,1
6.73E+01,2.90E+02,5.43E+03,9.01E+02,1
6.64E+01,2.08E+02,5.31E+03,8.98E+02,1
6.52E+01,9.82E+01,5.15E+03,8.91E+02,1
6.45E+01,2.13E+02,5.05E+03,8.91E+02,1
6.21E+01,1.50E+02,4.75E+03,8.90E+02,1
6.72E+01,2.05E+02,5.40E+03,8.84E+02,1
6.33E+01,2.74E+02,4.89E+03,8.83E+02,1
6.46E+01,1.53E+02,5.05E+03,8.83E+02,1
6.48E+01,1.33E+02,4.94E+03,7.45E+02,2
7.55E+01,1.34E+02,6.44E+03,7.45E+02,2
7.04E+01,3.49E+02,5.70E+03,7.45E+02,2
7.03E+01,1.74E+02,5.69E+03,7.45E+02,2
6.70E+01,1.48E+02,5.23E+03,7.45E+02,2
7.25E+01,1.25E+02,6.00E+03,7.45E+02,2
6.72E+01,2.20E+02,5.26E+03,7.45E+02,2
7.56E+01,9.81E+01,6.46E+03,7.45E+02,2
7.54E+01,2.08E+02,6.43E+03,7.45E+02,2
6.83E+01,1.29E+02,5.41E+03,7.45E+02,2
6.35E+01,1.56E+02,4.78E+03,7.45E+02,2
6.45E+01,2.58E+02,4.90E+03,7.45E+02,2
6.44E+01,2.17E+02,4.89E+03,7.45E+02,2
6.42E+01,1.02E+02,4.87E+03,7.45E+02,2
6.09E+01,1.53E+02,4.45E+03,7.45E+02,2
6.66E+01,2.02E+02,5.18E+03,7.45E+02,2
6.53E+01,1.17E+02,5.01E+03,7.45E+02,2
new.txt
1000025,5,1,1,1,2,1,3,1,1,2
1002945,5,4,4,5,7,10,3,2,1,2
1151734,10,8,7,4,3,10,7,9,1,4
1156017,3,1,1,1,2,1,2,1,1,2
1158247,1,1,1,1,1,1,1,1,1,2
1238021,1,1,1,1,2,1,2,1,1,2
1238464,1,1,1,1,1,?,2,1,1,2
1238633,10,10,10,6,8,4,8,5,1,4
1295186,10,10,10,1,6,1,2,8,1,4
527337,4,1,1,1,2,1,1,1,1,2
558538,4,1,3,3,2,1,1,1,1,2
1266124,5,1,2,1,2,1,1,1,1,2
1296025,4,1,2,1,2,1,1,1,1,2
1296263,4,1,1,1,2,1,1,1,1,2
1296593,5,2,1,1,2,1,1,1,1,2
1299161,4,8,7,10,4,10,7,5,1,4
1301945,5,1,1,1,1,1,1,1,1,2
1302428,5,3,2,4,2,1,1,1,1,2
1318169,9,10,10,10,10,5,10,10,10,4
1113061,5,1,1,1,2,1,3,1,1,2
1116192,5,1,2,1,2,1,3,1,1,2
1135090,4,1,1,1,2,1,2,1,1,2
1145420,6,1,1,1,2,1,2,1,1,2
1158157,5,1,1,1,2,2,2,1,1,2
1171578,3,1,1,1,2,1,1,1,1,2
1174841,5,3,1,1,2,1,1,1,1,2
1184586,4,1,1,1,2,1,2,1,1,2
1186936,2,1,3,2,2,1,2,1,1,2
1197527,5,1,1,1,2,1,2,1,1,2
1222464,6,10,10,10,4,10,7,10,1,4
1240603,2,1,1,1,1,1,1,1,1,2
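import numpy as np
import pandas as pd


def main():
    # The dataset is comma-separated; the .txt file is read directly (a CSV file works the same way).
    # header=None: the file has no header row, otherwise the first record would be used as the header.
    data = pd.read_csv('new.txt', header=None)
    data = data.values  # convert the DataFrame to a NumPy array (column labels are not included)
    print(data)


if __name__ == '__main__':
    main()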
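new.txt contains `?` entries, which make pandas read that column as text unless they are declared as missing values. Below is a minimal sketch of turning such rows into the (features, ±1 label) form used by the SVM code. It assumes that the first column is a sample ID and the last column is the class label (the values 2 and 4 appear in the rows shown above); which label value maps to +1 is an arbitrary choice here, and dropping incomplete rows is only one possible way to handle the missing values.

```python
import io

import numpy as np
import pandas as pd

# A few rows copied from new.txt; the real file can be read with the same call.
sample = io.StringIO(
    "1000025,5,1,1,1,2,1,3,1,1,2\n"
    "1238464,1,1,1,1,1,?,2,1,1,2\n"
    "1238633,10,10,10,6,8,4,8,5,1,4\n"
)

# header=None: no header row; na_values='?': treat '?' as a missing value
df = pd.read_csv(sample, header=None, na_values='?')
df = df.dropna()                             # drop records with missing features

x = df.iloc[:, 1:-1].to_numpy(dtype=float)   # assumed: column 0 is an ID, the last column is the label
labels = df.iloc[:, -1].to_numpy()
y = np.where(labels == labels.max(), 1, -1)  # map the two label values (2, 4) to -1 / +1
```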