Python 机器学习基石 作业2

实现简单的决策树(先找出合适的决策参量,再估计错误率)

import sys
import numpy as np
import math
from random import *

def read_input(path):
    x = []
    y = []
    for line in open(path).readlines():
        items = line.strip().split(' ')
        tmp_x = []
        for i in range(0, len(items)-1):
            tmp_x.append(float(items[i]))
        x.append(tmp_x)
        y.append(float(items[-1]))
    return np.array(x), np.array(y)

def calculate_Ein(x, y):
    thetas = np.array([float("-inf")]+[(x[i]+x[i+1])/2 for i in range(0, x.shape[0]-1)]+[float("inf")])
    target_theta = 0.0
    sign = 1
    Ein = x.shape[0]
    for theta in thetas:
        y_positive = np.where(x>theta, 1, -1)
        y_negative = np.where(x err_negative:
            if Ein > err_negative:
                Ein = err_negative
                sign = -1
                target_theta = theta
        else:
            if Ein > err_positive:
                Ein = err_positive
                sign = 1
                target_theta = theta
    return Ein, target_theta, sign
        

if __name__ == '__main__':
    x, y = read_input('train.txt')
    x_test, y_test = read_input('test.txt')
    Ein = x.shape[0]
    theta = 0.0
    sign = 1
    index = 0
    for i in range(0, x.shape[1]):
        cur_x = x[:,i]
        cur_data = np.transpose(np.array([cur_x, y]))
        cur_data = cur_data[np.argsort(cur_data[:,0])]
        cur_Ein, cur_theta, cur_sign = calculate_Ein(cur_data[:,0], cur_data[:,1])
        if cur_Ein < Ein:
            Ein = cur_Ein
            theta = cur_theta
            sign = cur_sign
            index = i
    print(Ein/x.shape[0])
    # test
    x_test = x_test[:, index]
    if sign == 1:
        y_predict = np.where(x_test > theta, 1.0, -1.0)
    else:
        y_predict = np.where(x_test < theta, 1.0, -1.0)
    Eout = sum(y_predict != y_test)
    print(Eout/x_test.shape[0])

注意:

  1. python中把数据每一行分开后得到的是一系列字符串,要先把它转化成浮点数再后续计算。
  2. np.array([cur_x, y])两个列向量拼接后得到了两个行向量,要转置变成想要的格式。

references:
http://www.cnblogs.com/xbf9xbf/p/4595990.html

你可能感兴趣的:(Python 机器学习基石 作业2)