# 梯度下降法
import decimal
decimal.getcontext().prec = 100
decimal.getcontext().rounding = getattr(decimal, 'ROUND_CEILING') # 总是趋向无穷大向上取整
def gradientdescent(learning_rate, a0, a1, a2, x1, x2, yi): # x1:房屋面积 单位:千feet**2; x2:房间数量; yi:房屋价格的真实值 单位:万$
count, a = 0, []
for i in range(len(yi)):
while count < 10000:
a0, a1, a2, x1[i], x2[i], yi[i], learning_rate = decimal.Decimal(a0), decimal.Decimal(a1), \
decimal.Decimal(a2), decimal.Decimal(x1[i]), \
decimal.Decimal(x2[i]), decimal.Decimal(yi[i]), \
decimal.Decimal(learning_rate)
J = decimal.Decimal(1 / 2) * (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) ** 2 # J:损失函数
a0 = a0 - learning_rate * (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) # 梯度下降计算a0
a1 = a1 - learning_rate * (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) * x1[i] # 梯度下降计算a1
a2 = a2 - learning_rate * (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) * x2[i] # 梯度下降计算a2
if 0 <= J - decimal.Decimal(1 / 2) * (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) ** 2 < 0.00001 and \
0 <= a0 + a1 * x1[i] + a2 * x2[i] - yi[i] <= 0.000001 and \
0 <= (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) * x1[i] <= 0.000001 and \
0 <= (a0 + a1 * x1[i] + a2 * x2[i] - yi[i]) * x2[i] <= 0.000001:
a.append([a0, a1, a2])
break
count += 1
print(f'第一组训练值得到的最小的损失函数对应的线性回归函数的系数为{a[0]},'
f'第二组训练值得到的最小的损失函数对应的线性回归函数的系数为{a[1]},'
f'第三组训练值得到的最小的损失函数对应的线性回归函数的系数为{a[2]},'
f'第四组训练值得到的最小的损失函数对应的线性回归函数的系数为{a[3]},'
f'第五组训练值得到的最小的损失函数对应的线性回归函数的系数为{a[4]}.')
return a
def verify(A, x1, x2):
Y_list = [] # 用梯度下降法得到的线性回归模型算出来的预测值,用来与真实值比较
for i in range(len(A)):
Y = A[i][0] + A[i][1] * decimal.Decimal(x1[i]) + A[i][2] * decimal.Decimal(x2[i])
Y_list.append(Y)
print(f'第一个预测值:{Y_list[0]}, 第二个预测值:{Y_list[1]}, 第三个预测值:{Y_list[2]}, 第四个预测值:{Y_list[3]}, 第五个预测值:{Y_list[4]}')
grad = gradientdescent(0.1, 0, 1, 1, [2.104, 1.600, 2.400, 1.416, 3.000], [3, 3, 3, 2, 4], [40.0, 33.0, 36.9, 23.2, 54.0])
verify(grad, [2.104, 1.600, 2.400, 1.416, 3.000], [3, 3, 3, 2, 4])
运行结果
第一组训练值得到的最小的损失函数对应的线性回归函数的系数为[Decimal(‘3.673876718510160019204096914391439045200052863069541805224781098532263420653296705003597778074836454’), Decimal(‘7.956852954170839274877895466755233374776639040226408202927314923289800060092752530733368662920025895’), Decimal(‘6.528301555304797803824627476965244426719948529777984312552428810084841757314553422970788265218600869’)]
第二组训练值得到的最小的损失函数对应的线性回归函数的系数为[Decimal(‘3.353446100142499098521371677672852926644124684676268631396267923277963110803997013853648209750483478’), Decimal(‘7.495432863721407526326806306254629650582593903127004441195724700753264064267482882924187465251742628’), Decimal(‘5.884620529127840587211435894661066871093601478411576870601004786757149856463281371312402722098662467’)]
第三组训练值得到的最小的损失函数对应的线性回归函数的系数为[Decimal(‘3.135494425141821066112276029498382388863205909527359943995108924971284162034302310654418104363015159’), Decimal(‘7.024657245719942996649057004873011295940022904084983333038654930770407022777703405028841245012233646’), Decimal(‘5.635109451587064370976472825653698091613144122668925474037909641308687565706506877418520949275162127’)]
第四组训练值得到的最小的损失函数对应的线性回归函数的系数为[Decimal(‘2.932664512997315719856367901580875346317498138265050518039295808522626058363133719158401876407202801’), Decimal(‘6.766170805682985398594123474790777912286411864689085906541723388972190721483751744936315466728731920’), Decimal(‘5.343218969545421141943483626620091223835148329903316963997779537164939132378513760247860835592288697’)]
第五组训练值得到的最小的损失函数对应的线性回归函数的系数为[Decimal(‘3.824120604491861713890478335961970487109515983399412363241708292688614716083923211794014256095138669’), Decimal(‘9.173102252718259567640495474325807705895776858862129039452930123936541738674159393254422123218779881’), Decimal(‘5.664143162483457537482490554983683650726319461018005602613825363893848802096210244702860349958431301’)].
第一个预测值:40.00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000,
第二个预测值:33.00000026948027356800674103850984781898315054905535011149023455240818981736569586449650965059031408,
第三个预测值:36.90000016963087674708453417429737639448552730722127339012969189134115341549318534685331979245893810,
第四个预测值:23.20000031293526482334899493166328343044307765994809569429857142169715129090130832355990228704099924,
第五个预测值:54.00000001258047056674192697887412820770212440405782189205580012007363514049124237036872202558520353