HW1主要是使用liner model 进行pm2.5的预测
作业连接:https://ntumlta.github.io/2017fall-ml-hw1/
由于部分ppf被强,所以这里加上一个百度网盘连接:
https://pan.baidu.com/s/1Ff-3zdzqMEi1W2qUf3Agdg 密码 ooqn
内容是这个作业的相关内容
作业要求:
1. 使用前9个小时的数据,预测出第十个小时的PM2.5的值是多少
2.提供2014年的12个月每个月的前20天的24小时数据作为train data
3.每小时有18组数据(so2 甲烷 之类的指标)
下面解析 作业的sample code
import xlrd
import numpy as np
data =
for i in range( 18):
data.append([]) #18 组数据
if __name__ == '__main__':
iFileDir = "./";
iFileName = iFileDir + "train.xlsx";
print('iFileName = %r'%iFileName)
try:
wb = xlrd.open_workbook(iFileName)
except:
print( "file %s is not exist" % (iFileName) )
for sheet_name in wb.sheet_names():
sheet = wb.sheet_by_name(sheet_name)
for row_num in range(1, sheet.nrows):
for i in range(3, 27): #3-27 是对应的24小时数据
data_tmp = sheet.cell_value(row_num, i) #将数据转换成浮点数
if data_tmp == 'NR': #NR是没有检测到数据
data_tmp = float(0)
data[(row_num-1)%18].append(float(data_tmp))
x = []
y = []
上面只是就相应的train数据吃进18个list
month = 12
#10hour as one data
data_len = 9 # 连续9个小时的数据作为输入
data_length = 20 * 24 # 一个月24天 每天20个数据
for i in range(12): # 12 个月
for j in range(480 - 9): #
x.append([]) #
for t in range(18): # 18组数据
for k in range(9): # 连续9个数据作为一个input
x[471*i +j].append(data[t][480*i + j+k])
y.append(data[9][480*i+j+9]) # 第9个list存放的是pm2.5的数据
x = np.array(x)
#得到x 每一维度是 18 * 9 个数据, 每个月会有471个维度 一共有471*12个维度
y = np.array(y)
y = b + w10*x10 + w20*x20 + ...... w90*x90
+ w11 *x11 +...............................+w91*x91
......
+ w117*x117..............................+w917*x917
x = np.concatenate((np.ones((x.shape[0],1)),x),axis=1) #每个维度加一个数据1 作为bias
w = np.zeros(len(x[0])) #选定一个起始点,这里做了维度为1 长度为18 *9 +1的 数值为0的矩阵
l_rate = 10 #初始learning rate
repeat = 10000
补 Loss function 的公式
x_t = x.transpose()
s_gra = np.zeros(len(x[0]))
for i in range(repeat): #重复计算10000次
hypo = np.dot(x,w) #得到一个一维矩阵 每个值对应一个y`
loss = hypo - y
cost = np.sum(loss**2) / len(x)
cost_a = math.sqrt(cost)
gra = np.dot(x_t,loss) #这里不懂,应该是矩阵的导数
s_gra += gra**2
ada = np.sqrt(s_gra)
w = w - l_rate * gra/ada #通过adagrad 更新初始点
print ('iteration: %d | Cost: %f ' % ( i,cost_a))