环境 Python3
要求:从 MySQL 数据库中读取两列数据进行线性回归预测。例如,从经济数据表中读取 GDP 列和 year 列,做 GDP 对 year 的线性回归分析,并根据得到的回归系数预测下一年的 GDP 值。
首先从mysql中读取数据并打印
from pymysql import *
import pandas as pd
import numpy as np
from sklearn import linear_model
from sqlalchemy import create_engine
import matplotlib.pyplot as plt
def show_linear_line(X_parameters, Y_parameters):
    """Fit a linear regression and plot the samples with the fitted line.

    The samples are drawn as a blue scatter plot and the model's predictions
    for the same X values as a red line. Axis ticks are hidden. The call
    blocks in ``plt.show()`` until the figure window is closed (with an
    interactive matplotlib backend).

    Args:
        X_parameters: 2-D array-like of shape (n_samples, 1), e.g.
            ``[[1978], [1979], ...]``.
        Y_parameters: 1-D array-like with one target value per sample.
    """
    regr = linear_model.LinearRegression()
    regr.fit(X_parameters, Y_parameters)
    plt.scatter(X_parameters, Y_parameters, color='blue')
    plt.plot(X_parameters, regr.predict(X_parameters), color='red',
             linewidth=4)
    # Empty tuples remove the tick marks and labels from both axes.
    plt.xticks(())
    plt.yticks(())
    plt.show()
# Read two columns of one MySQL table and keep each column as a Python list.


def _quote_identifier(name):
    """Return *name* quoted as a MySQL identifier (embedded backticks doubled)."""
    return "`" + str(name).replace("`", "``") + "`"


# Four parameters: table name, x column, y column, and the year to predict.
input1 = 'economy'
input2 = 'year'
input3 = 'GDP'
input4 = 2019

# NOTE: the host is a placeholder and the credentials are hard-coded for the
# demo; load them from configuration in real use.
conn = connect(host='10.x.xx.xxx', port=3306, database='population',
               user='root',
               password='root', charset='utf8')
try:
    cs1 = conn.cursor()
    try:
        # Table and column names cannot be bound as query parameters, so they
        # are backtick-quoted instead of being interpolated raw.
        # Both columns are fetched in ONE query: two separate SELECTs without
        # ORDER BY give no guarantee that row i of the first result and row i
        # of the second belong to the same table row.
        sql1 = "select %s, %s from %s;" % (
            _quote_identifier(input2),
            _quote_identifier(input3),
            _quote_identifier(input1),
        )
        cs1.execute(sql1)
        alldata1 = cs1.fetchall()
    finally:
        cs1.close()
finally:
    # Always release the connection, even if the query fails.
    conn.close()

# First selected column (year) -> x values.
datalist1 = [row[0] for row in alldata1]
print(datalist1)
# Second selected column (GDP) -> y values.
datalist2 = [row[1] for row in alldata1]
print(datalist2)
[1978, 1979, 1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018]
[109.0, 120.0, 139.0, 142.0, 155.0, 183.0, 217.0, 257.0, 285.0, 327.0, 410.0, 456.0, 501.0, 599.0, 709.0, 886.0, 1145.0, 1508.0, 1805.0, 2097.0, 2406.0, 2713.0, 3213.0, 3500.0, 3770.0, 4396.0, 6165.0, 7141.0, 8313.0, 10072.0, 11392.0, 12419.0, 14442.0, 16628.0, 18350.0, 20330.0, 21944.0, 23686.0, 25669.0, 28015.0, 30320.0]
修改自变量(year)列表的格式:从 [1, 2, 3] 变成 [[1], [2], [3]] 的样式(即 n 行 1 列的二维结构),之后才能传给 sklearn.linear_model 进行拟合。
# Recent scikit-learn versions require the feature matrix to be 2-D even when
# it holds a single feature column or a single sample. A single feature column
# (this case) must have shape (n_samples, 1), i.e. .reshape(-1, 1);
# .reshape(1, -1) is for a single sample with several features.
# Wrapping every year in its own one-element list yields that
# (n_samples, 1) layout while keeping plain Python lists.
datalist11 = [[year] for year in datalist1]
print(datalist11)
[[1978], [1979], [1980], [1981], [1982], [1983], [1984], [1985], [1986], [1987], [1988], [1989], [1990], [1991], [1992], [1993], [1994], [1995], [1996], [1997], [1998], [1999], [2000], [2001], [2002], [2003], [2004], [2005], [2006], [2007], [2008], [2009], [2010], [2011], [2012], [2013], [2014], [2015], [2016], [2017], [2018]]
# Plot the samples together with the fitted regression line.
show_linear_line(datalist11, datalist2)

# Regression and prediction.
# `normalize` is intentionally not passed: it was deprecated in
# scikit-learn 1.0 and removed in 1.2 (passing it raises TypeError there).
# Its old default was False, so the fitted model is unchanged.
regr = linear_model.LinearRegression(fit_intercept=True)
regr.fit(datalist11, datalist2)
# predict() expects a 2-D input of shape (n_samples, n_features); here that is
# one sample with one feature: the year to predict.
datalist12 = np.array([input4]).reshape(-1, 1)
output_predictvalue = regr.predict(datalist12)
print(output_predictvalue)
预测结果(2019 年的 GDP)为 [20861.57682927]。注意该值低于 2018 年的实际值 30320.0,说明这组数据的增长并不是线性的,简单线性回归在序列末端会明显低估。