21. 日月光华 Python数据分析 - 机器学习 - 多元线性回归

数据描述: 数据集包含了 200 个不同市场的产品销售额, 每个销售额对应 3 种广告媒体,分别是 TV, radio 和 newspaper

任务描述:分析广告媒体与销售额之间的关系,基于广告媒体预算,预测销售额

评价指标:销售额为连续值,为回归问题,可以采用均方误差作为评价指标 开发环境

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

plt.style.use('ggplot')
data = pd.read_csv('Advertising.csv')
data.head()

Unnamed: 0 TV radio newspaper sales

0 1 230.1 37.8 69.2 22.1

1 2 44.5 39.3 45.1 10.4

2 3 17.2 45.9 69.3 9.3

3 4 151.5 41.3 58.5 18.5

4 5 180.8 10.8 58.4 12.9

查看TV媒体投入与销售额之间的关系

plt.scatter(data.TV, data.sales)

![image.png](https://upload-images.jianshu.io/upload_images/3968643-552ffc024602ca65.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)

查看radio媒体投入与销售额之间的关系

plt.scatter(data.radio, data.sales)

![image.png](https://upload-images.jianshu.io/upload_images/3968643-a882d926f4ebae33.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)

查看newspaper媒体投入与销售额之间的关系

plt.scatter(data.newspaper, data.sales)

![image.png](https://upload-images.jianshu.io/upload_images/3968643-bb822cc9dff772b9.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)
# 如何客观的评价我们的模型

x = data[['TV','radio','newspaper']]
y = data.sales
x_train,x_test,y_train,y_test = train_test_split(x, y)
len(x_train),len(y_train)

(150, 150)

len(x_test)

50

model = LinearRegression()
model.fit(x_train, y_train)
model.coef_

array([ 0.04371177, 0.19392061, -0.00030117])

for i in zip(x_train.columns, model.coef_): # zip打包
print(i)

('TV', 0.04371177253650498)

('radio', 0.193920608289641)

('newspaper', -0.0003011660943684182) # 发现 newspaper投入对销量影响最小

mean_squared_error(model.predict(x_test), y_test)

5.22399779076229

# 模型的改进

newspaper对结果影响不大,因此去掉该特征,反而是预测效果更好

x = data[['TV','radio']]
y = data.sales
x_train,x_test,y_train,y_test = train_test_split(x, y)
model2 = LinearRegression()
model2.fit(x_train,y_train)
model2.coef_

array([0.04603784, 0.18428908])

mean_squared_error(model2.predict(x_test),y_test)

1.7138986453138052

你可能感兴趣的:(21. 日月光华 Python数据分析 - 机器学习 - 多元线性回归)