```python
# Save this source file as ml.py
__all__ = ['Classifier']

import numpy as np
import matplotlib.pyplot as plt


class Classifier:
    def __init__(self, num_classes: int, max_iter: int = 200, lr: float = 0.1):
        self.max_iter = max_iter        # maximum number of training iterations
        self.lr = lr                    # learning rate
        self.num_classes = num_classes  # number of categories
        self.scores = []                # held-out accuracy recorded per iteration

    def __data_matrix(self, X):
        '''
        Parameters
        ----------
        X : numpy.array
            input matrix
        Returns
        -------
        numpy.array
            augmented matrix, with a leading column of ones for the bias term
        '''
        ones = np.ones(X.shape[0])
        return np.insert(X, 0, ones, axis=1)

    def __softmax(self, horizon):
        '''
        Parameters
        ----------
        horizon : numpy.array
            one row of the data set.
        Returns
        -------
        numpy.array
            softmax result.
        '''
        # shift by the maximum before exponentiating, for numerical stability
        shifted = np.exp(horizon - np.max(horizon))
        return shifted / np.sum(shifted)

    def fit(self, X, y, X_test=None, y_test=None) -> None:
        '''
        Parameters
        ----------
        X : numpy.array
            data set to train on
        y : numpy.array
            correct labels
        X_test, y_test : numpy.array, optional
            held-out set, scored after every iteration
        Returns
        -------
        None
        '''
        augmented = self.__data_matrix(X)
        self.weights = np.zeros((augmented.shape[1], self.num_classes), dtype=np.float64)
        for step in range(self.max_iter):
            for index in range(augmented.shape[0]):
                res = self.__softmax(np.dot(augmented[index], self.weights))
                obj = np.eye(self.num_classes)[int(y[index])]  # one-hot target
                err = res - obj
                # per-sample gradient step: outer product of the input row and the error
                self.weights -= self.lr * np.transpose([augmented[index]]) * err
            if X_test is not None and y_test is not None:
                score = self.score(X_test, y_test)
                self.scores.append(score)
                if step % 20 == 0:
                    print("Training Error: {0:<}, Testing Score: {1:<}".format(np.linalg.norm(err), score))

    def score(self, X, y) -> float:
        '''
        Parameters
        ----------
        X : numpy.array
            data set to be tested
        y : numpy.array
            correct labels
        Returns
        -------
        float
            accuracy, the fraction of correctly predicted labels
        '''
        X = self.__data_matrix(X)
        multiply = np.dot(X, self.weights)
        predicted = np.argmax(multiply, axis=1)
        return (predicted == y).sum() / X.shape[0]

    def predict(self, X):
        '''
        Parameters
        ----------
        X : numpy.array
            data set to be predicted
        Returns
        -------
        numpy.array
            predicted labels
        '''
        X = self.__data_matrix(X)
        multiply = np.dot(X, self.weights)
        return np.argmax(multiply, axis=1)

    def plot(self, color: str = "slateblue", mark: str = 'o', style: str = 'dashed') -> None:
        '''
        Parameters
        ----------
        color : str
            the color of the plot line.
        mark : str
            the marker for data points.
        style : str
            the line style of the plot.
        Returns
        -------
        None
        '''
        axis_x = range(1, len(self.scores) + 1)
        # optionally label the figure with plt.xlabel, plt.ylabel, plt.title
        plt.plot(axis_x, self.scores, c=color, marker=mark, linestyle=style)
        plt.show()
```
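For reference, the inner loop of `fit` is plain stochastic gradient descent on the softmax cross-entropy loss. Writing an augmented sample as a row vector $x \in \mathbb{R}^{1 \times (d+1)}$, the weights as $W \in \mathbb{R}^{(d+1) \times K}$, and the one-hot row for label $y$ as $e_y$, each inner step performs

$$
W \leftarrow W - \eta \, x^{\top}\big(\operatorname{softmax}(xW) - e_y\big),
$$

which is exactly the `np.transpose([augmented[index]]) * err` outer product in the code, with learning rate $\eta$.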
```python
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio     # used to parse the .mat data file
from ml import Classifier  # ml.py holds the classifier source above
```
Download the dataset required for the experiment: textretrieval.mat. Its dimensions are 2866 × 6603.
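If you want to confirm what the file contains before loading it, `scipy.io.whosmat` lists the stored variables without reading them into memory. A minimal sketch (the relative path is an assumption about your directory layout):

```python
import scipy.io as sio

# Lists (name, shape, dtype) for each variable stored in the .mat file;
# this dataset is expected to hold 'X' (2866 x 6603 features) and
# 'class' (a one-hot label matrix).
print(sio.whosmat('../data/textretrieval.mat'))
```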
```python
def process_data(url) -> tuple:
    data = sio.loadmat(url)
    return data['X'], data['class']

features, labels = process_data('../data/textretrieval.mat')  # path depends on how the environment is laid out

def pretreat(features, labels) -> tuple:
    labels_mo = np.argwhere(labels == 1)[:, 1]  # convert the one-hot label matrix to class indices
    return features[:2000, :], features[2000:, :], labels_mo[:2000], labels_mo[2000:]

X_train, X_test, y_train, y_test = pretreat(features, labels)
```
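As a quick sanity check, the split sizes follow from the 2866 × 6603 matrix described above:

```python
# 2000 rows go to training, the remaining 866 to testing
print(X_train.shape, y_train.shape)  # (2000, 6603) (2000,)
print(X_test.shape, y_test.shape)    # (866, 6603) (866,)
```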
```python
model = Classifier(num_classes=10, max_iter=1000, lr=0.1)
model.fit(X_train, y_train, X_test, y_test)  # the test set is scored after every iteration
print(model.score(X_test, y_test))
model.plot(color="slateblue", mark='o', style='dashed')
```
During training, `fit` prints the norm of the last per-sample error together with the test accuracy every 20 iterations; with `max_iter=1000` this produces the 50 rows below.

Train Loss | Test Score |
---|---|
0.9408422674049407 | 0.14665127020785218 |
0.912061033842642 | 0.35219399538106233 |
0.8803304469065384 | 0.6189376443418014 |
0.8466429683180274 | 0.7274826789838337 |
0.8119254918344274 | 0.7621247113163973 |
0.7769214437087549 | 0.7875288683602771 |
0.7421752752868741 | 0.7979214780600462 |
0.7080714015970305 | 0.8175519630484989 |
0.6748833415377617 | 0.8233256351039261 |
0.6428083590006325 | 0.8290993071593533 |
0.6119867581256945 | 0.8371824480369515 |
0.5825126784879504 | 0.8418013856812933 |
0.5544415616129514 | 0.8429561200923787 |
0.5277965860484851 | 0.8418013856812933 |
0.50257469637334 | 0.8441108545034642 |
0.4787522109951752 | 0.8452655889145496 |
0.45628985777781966 | 0.8464203233256351 |
0.4351371279645393 | 0.8475750577367206 |
0.4152359142178181 | 0.8498845265588915 |
0.39652345984224263 | 0.851039260969977 |
0.3789346841487715 | 0.851039260969977 |
0.3624039669447191 | 0.8521939953810623 |
0.3468664793907125 | 0.8533487297921478 |
0.33225914434725806 | 0.8556581986143187 |
0.31852130079255164 | 0.8579676674364896 |
0.3055951365513337 | 0.8579676674364896 |
0.29342594303711267 | 0.859122401847575 |
0.2819622358734153 | 0.859122401847575 |
0.2711557765560219 | 0.8602771362586605 |
0.2609615228918985 | 0.8614318706697459 |
0.2513375297795615 | 0.8602771362586605 |
0.2422448168706154 | 0.8614318706697459 |
0.23364721562579857 | 0.8602771362586605 |
0.22551120509581035 | 0.8602771362586605 |
0.21780574326964636 | 0.8614318706697459 |
0.2105020989096215 | 0.8637413394919169 |
0.20357368731934128 | 0.8637413394919169 |
0.1969959123742666 | 0.8637413394919169 |
0.19074601630665308 | 0.8648960739030023 |
0.18480293811527693 | 0.8660508083140878 |
0.17914718101563548 | 0.8683602771362586 |
0.17376068901913053 | 0.8695150115473441 |
0.1686267324993392 | 0.8706697459584296 |
0.16372980244602603 | 0.8706697459584296 |
0.15905551300459905 | 0.8695150115473441 |
0.1545905118362361 | 0.8683602771362586 |
0.15032239780087622 | 0.8683602771362586 |
0.1462396454537189 | 0.8695150115473441 |
0.14233153584943295 | 0.8706697459584296 |
0.13858809316228626 | 0.8706697459584296 |