机器学习基石笔记:Homework #2 Decision Stump相关习题

问题描述

机器学习基石笔记:Homework #2 Decision Stump相关习题_第1张图片
图1 16-18

机器学习基石笔记:Homework #2 Decision Stump相关习题_第2张图片
图2 19-20

程序实现

17-18

# coding: utf-8

import numpy as np
import matplotlib.pyplot as plt


def sign(n):
    """Return +1 for strictly positive ``n`` and -1 otherwise.

    Zero is mapped to -1, so the result is always a valid ±1 label.
    """
    return 1 if n > 0 else -1

def gen_data(num_data=20, noise_rate=0.2):
    """Generate a noisy one-dimensional data set for the decision stump.

    Inputs are drawn uniformly from [-1, 1). The clean label of ``x`` is
    +1 when ``x > 0`` and -1 otherwise; a randomly chosen subset of the
    points then has its label flipped.

    Args:
        num_data: Number of samples to draw (default 20).
        noise_rate: Fraction of samples whose label is flipped
            (default 0.2).

    Returns:
        Array of shape ``(num_data, 2)``: column 0 holds x, column 1 holds
        the (possibly flipped) label as +/-1.0.
    """
    data_X = np.random.uniform(-1, 1, (num_data, 1))  # [-1, 1)
    # Clean labels follow the same convention as sign(): 0 maps to -1.
    data_Y = np.where(data_X > 0, 1.0, -1.0)
    # The first `num_data * noise_rate` entries of a random permutation
    # pick which samples are flipped, so the noisy subset is uniform.
    shuffled_ids = np.random.permutation(num_data)
    flip_ids = shuffled_ids[np.arange(num_data) < num_data * noise_rate]
    data_Y[flip_ids] *= -1
    return np.concatenate((data_X, data_Y), axis=1)

def decision_stump(dataArray):
    minErrors=20
    min_s_theta_list=[]
    num_data=dataArray.shape[0]
    data=dataArray.tolist()
    data.sort(key=lambda x:x[0])
    for s in [-1.0,1.0]:
        for i in range(num_data):
            if(i==num_data-1):
                theta=(data[i][0]+1.0)/2
            else:
                theta=(data[i][0]+data[i+1][0])/2
            errors=0
            for i in range(20):
                pred=s*sign(data[i][0]-theta)
                if(pred!=data[i][1]):
                    errors+=1
            if(minErrors>errors):
                minErrors=errors
                min_s_theta_list=[]
            elif(minErrors

19-20

# coding: utf-8

import numpy as np

def read_data(dataFile):
    """Load a whitespace-separated numeric text file into a NumPy array.

    Each non-blank line becomes one row; every token on the line is parsed
    with ``float``. Blank or whitespace-only lines are skipped so that a
    stray empty line does not produce a ragged array.

    Args:
        dataFile: Path of the text file to read.

    Returns:
        ``np.ndarray`` built from the parsed rows, one row per data line.

    Raises:
        ValueError: If a token cannot be parsed as a float.
    """
    with open(dataFile, 'r') as handle:
        # Iterate the file lazily instead of materialising readlines().
        rows = [
            [float(token) for token in line.split()]
            for line in handle
            if line.strip()
        ]
    return np.array(rows)

def predict(s, theta, dataX):
    """Evaluate the decision stump ``s * sign(x - theta)`` on ``dataX``.

    Args:
        s: Direction of the stump, +1 or -1.
        theta: Threshold the inputs are compared against.
        dataX: Input values (array or array-like).

    Returns:
        Array with the same shape as ``dataX`` holding the predictions.
        ``np.sign`` returns 0 at 0, so an input exactly equal to ``theta``
        yields 0 rather than +/-1.
    """
    return s * np.sign(np.asanyarray(dataX) - theta)

def decision_stump(dataArray):
    min_s_theta_list=[]
    num_data=dataArray.shape[0]
    minErrors=num_data
    data=dataArray.tolist()
    data.sort(key=lambda x:x[0])
    dataArray=np.array(data)
    dataX=dataArray[:,0].reshape(num_data,1)
    dataY=dataArray[:,1].reshape(num_data,1)
    for s in [-1.0,1.0]:
        for i in range(num_data):
            if(i==num_data-1):
                theta=(dataX[i][0]*2+1)/2
            else:
                theta=(dataX[i][0]+dataX[i+1][0])/2
            pred=predict(s,theta,dataX)
            errors=np.sum(pred!=dataY)
            if(minErrors>errors):
                minErrors=errors
                min_s_theta_list=[]
            elif(minErrors

运行结果

17-18

图3 17-18结果1

机器学习基石笔记:Homework #2 Decision Stump相关习题_第3张图片
图4 17-18结果2

19-20

机器学习基石笔记:Homework #2 Decision Stump相关习题_第4张图片
图5 19-20结果

你可能感兴趣的:(机器学习基石笔记:Homework #2 Decision Stump相关习题)