Python: Softmax Classifier

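This is a softmax classifier implemented in Python, based on Python 2.7.9, numpy, and matplotlib.
The code comes from the Stanford course notes: http://cs231n.github.io/neural-networks-case-study/
It is copied over almost verbatim; working through the program is a good way to get familiar with Python syntax.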

import numpy as np
import matplotlib.pyplot as plt

N = 100  # number of points per class
D = 2    # dimensionality
K = 3    # number of classes
X = np.zeros((N*K,D))    #data matrix (each row = single example)
y = np.zeros(N*K, dtype='uint8')  # class labels

for j in xrange(K):
  ix = range(N*j,N*(j+1))
  r = np.linspace(0.0,1,N)            # radius
  t = np.linspace(j*4,(j+1)*4,N) + np.random.randn(N)*0.2 # theta
  X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
  y[ix] = j

# print y

# let's visualize the data:
plt.scatter(X[:,0], X[:,1], s=40, c=y, alpha=0.5)
plt.show()

# Train a Linear Classifier

# initialize parameters randomly
W = 0.01 * np.random.randn(D,K)
b = np.zeros((1,K))

# some hyperparameters
step_size = 1e-0
reg = 1e-3 # regularization strength

# gradient descent loop
num_examples = X.shape[0]

for i in xrange(200):

  # evaluate class scores, [N x K]
  scores = np.dot(X, W) + b 

  # compute the class probabilities
  exp_scores = np.exp(scores)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True) # [N x K]

  # compute the loss: average cross-entropy loss and regularization
  correct_logprobs = -np.log(probs[range(num_examples),y])
  data_loss = np.sum(correct_logprobs)/num_examples
  reg_loss = 0.5*reg*np.sum(W*W)
  loss = data_loss + reg_loss
  if i % 10 == 0:
    print "iteration %d: loss %f" % (i, loss)

  # compute the gradient on scores
  dscores = probs
  dscores[range(num_examples),y] -= 1
  dscores /= num_examples

  # backpropagate the gradient to the parameters (W,b)
  dW = np.dot(X.T, dscores)
  db = np.sum(dscores, axis=0, keepdims=True)

  dW += reg*W     #regularization gradient

  # perform a parameter update
  W += -step_size * dW
  b += -step_size * db

# evaluate training set accuracy
scores = np.dot(X, W) + b
predicted_class = np.argmax(scores, axis=1)
print 'training accuracy: %.2f' % (np.mean(predicted_class == y))
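
One thing worth knowing about the probability step in the loop: `np.exp(scores)` overflows once the scores get large. A common workaround is to subtract each row's maximum before exponentiating, which leaves the softmax output unchanged. The sketch below is not part of the course code. The helper names `softmax_rows` and `cross_entropy` are made up for illustration, and it is written to run on both Python 2.7 and Python 3.

```python
import numpy as np

def softmax_rows(scores):
    """Row-wise softmax; subtracts each row's max first to avoid overflow."""
    shifted = scores - np.max(scores, axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

def cross_entropy(probs, y):
    """Average negative log-probability assigned to the correct class."""
    n = probs.shape[0]
    return -np.mean(np.log(probs[np.arange(n), y]))

# The second row is the first row shifted by 999, so a plain np.exp on it
# would overflow, yet both rows describe the same distribution.
scores = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1001.0, 1002.0]])
probs = softmax_rows(scores)
labels = np.array([2, 2])
loss = cross_entropy(probs, labels)
```

With the small initial weights used in the script the overflow does not show up, so this only matters if the step size or the data scale is changed.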

Generated random data

[Image 1]
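
The scatter plot looks slightly different on every run because the script calls `np.random.randn` without a seed. If a reproducible picture is wanted, one option is to wrap the data generation in a function that takes a seed. This is only a sketch: the function name `make_spiral` and the `seed` parameter are my own additions, not something from the course notes. It uses `range` instead of `xrange` so it runs on both Python 2.7 and Python 3.

```python
import numpy as np

def make_spiral(N=100, D=2, K=3, seed=0):
    """Generate the same K-class toy data as the script, but reproducibly.

    N: points per class, D: dimensionality, K: number of classes.
    seed is a hypothetical extra parameter added for this sketch.
    """
    rng = np.random.RandomState(seed)
    X = np.zeros((N * K, D))             # data matrix (each row = one example)
    y = np.zeros(N * K, dtype='uint8')   # class labels
    for j in range(K):
        ix = np.arange(N * j, N * (j + 1))
        r = np.linspace(0.0, 1, N)                                    # radius
        t = np.linspace(j * 4, (j + 1) * 4, N) + rng.randn(N) * 0.2   # theta
        X[ix] = np.c_[r * np.sin(t), r * np.cos(t)]
        y[ix] = j
    return X, y

X, y = make_spiral(seed=0)
X_again, y_again = make_spiral(seed=0)   # same seed, same points
```

Calling `make_spiral` with a different seed gives a different noise pattern on the same three arms.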

Run results

[Image 2]
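
The loss only goes down like this if the analytic gradients `dW` and `db` are correct. A cheap way to convince yourself is a finite-difference check on a small random problem. The sketch below repackages the loss and gradient lines from the loop into a function and compares them against central differences. The names `loss_and_grads` and `numerical_grad`, the step `h`, and the tiny random dataset are all assumptions made for this example.

```python
import numpy as np

def loss_and_grads(X, y, W, b, reg):
    """Softmax loss plus L2 penalty, with gradients, as in the training loop."""
    n = X.shape[0]
    scores = np.dot(X, W) + b
    exp_scores = np.exp(scores)
    probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
    loss = -np.mean(np.log(probs[np.arange(n), y])) + 0.5 * reg * np.sum(W * W)
    dscores = probs.copy()          # copy so probs itself is not modified
    dscores[np.arange(n), y] -= 1
    dscores /= n
    dW = np.dot(X.T, dscores) + reg * W
    db = np.sum(dscores, axis=0, keepdims=True)
    return loss, dW, db

def numerical_grad(f, param, h=1e-5):
    """Central-difference gradient of f() with respect to param (edited in place)."""
    grad = np.zeros_like(param)
    for idx in np.ndindex(*param.shape):
        old = param[idx]
        param[idx] = old + h
        plus = f()
        param[idx] = old - h
        minus = f()
        param[idx] = old            # restore the original value
        grad[idx] = (plus - minus) / (2 * h)
    return grad

rng = np.random.RandomState(0)
X = rng.randn(20, 2)                # small made-up dataset, not the spiral
y = rng.randint(0, 3, size=20)
W = 0.01 * rng.randn(2, 3)
b = np.zeros((1, 3))
reg = 1e-3

_, dW, db = loss_and_grads(X, y, W, b, reg)
f = lambda: loss_and_grads(X, y, W, b, reg)[0]
max_err_W = np.max(np.abs(dW - numerical_grad(f, W)))
max_err_b = np.max(np.abs(db - numerical_grad(f, b)))
```

If both maximum errors come out tiny compared with the gradient values themselves, the backward pass matches the loss.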
