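This post walks through the example code from Chapter 1 of Hands-On Machine Learning with Scikit-Learn and TensorFlow. The example is a linear model that looks at the relationship between GDP per capita and the happiness (life satisfaction) index.
Source code and data
First, the happiness-index data: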
oecd_bli = pd.read_csv(datapath + "oecd_bli_2015.csv", thousands=',')
oecd_bli = oecd_bli[oecd_bli["INEQUALITY"]=="TOT"]
oecd_bli = oecd_bli.pivot(index="Country", columns="Indicator", values="Value")
oecd_bli.head(2)
The goal is to extract just the data we need. The new table has one row per Country and one column per Indicator. Sample output:
out:
Indicator Air pollution Assault rate Consultation on rule-making \
Country
Australia 13.0 2.1 10.5
Austria 27.0 3.4 7.1
Indicator Dwellings without basic facilities Educational attainment \
Country
Australia 1.1 76.0
Austria 1.0 83.0
Indicator Employees working very long hours Employment rate Homicide rate \
Country
Australia 14.02 72.0 0.8
Austria 7.61 72.0 0.4
Indicator Household net adjusted disposable income \
Country
Australia 31588.0
Austria 31173.0
Indicator Household net financial wealth ... \
Country ...
Australia 47657.0 ...
Austria 49887.0 ...
Indicator Long-term unemployment rate Personal earnings \
Country
Australia 1.08 50449.0
Austria 1.19 45199.0
Indicator Quality of support network Rooms per person Self-reported health \
Country
Australia 92.0 2.3 85.0
Austria 89.0 1.6 69.0
Indicator Student skills Time devoted to leisure and personal care \
Country
Australia 512.0 14.41
Austria 500.0 14.46
Indicator Voter turnout Water quality Years in education
Country
Australia 93.0 91.0 19.4
Austria 75.0 94.0 17.0
[2 rows x 24 columns]
In [11]: oecd_bli["Life satisfaction"].head()
Out[11]:
Country
Australia 7.3
Austria 6.9
Belgium 6.9
Brazil 7.0
Canada 7.3
Name: Life satisfaction, dtype: float64
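The filter-and-pivot step that produced the table above can be tried on a tiny table first. This is only a sketch: the country names and values below are made up, and only the column names (Country, Indicator, INEQUALITY, Value) follow the code shown earlier.

```python
import pandas as pd

# Toy long-format rows; the values are invented for illustration only.
long_df = pd.DataFrame({
    "Country": ["A", "A", "B", "B", "A"],
    "Indicator": ["Life satisfaction", "Employment rate",
                  "Life satisfaction", "Employment rate",
                  "Life satisfaction"],
    "INEQUALITY": ["TOT", "TOT", "TOT", "TOT", "MN"],
    "Value": [7.0, 70.0, 6.0, 65.0, 6.8],
})

# Keep only the "TOT" rows so each (Country, Indicator) pair is unique,
# which pivot() requires.
totals = long_df[long_df["INEQUALITY"] == "TOT"]

# Reshape from long to wide: one row per country, one column per indicator.
wide_df = totals.pivot(index="Country", columns="Indicator", values="Value")
print(wide_df)
```

Without the filter, country A would have two "Life satisfaction" rows and pivot() would raise an error about duplicate entries, which is presumably why the original code filters on INEQUALITY first.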
Next, process the GDP data:
gdp_per_capita = pd.read_csv(datapath+"gdp_per_capita.csv", thousands=',', delimiter='\t',
encoding='latin1', na_values="n/a")
gdp_per_capita.rename(columns={"2015": "GDP per capita"}, inplace=True)
gdp_per_capita.set_index("Country", inplace=True)
gdp_per_capita.head(2)
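I won't show the output here; run it yourself to see the result. The code renames the "2015" column to "GDP per capita" and makes Country the index, so that the two tables can be merged on it.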
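As a minimal sketch of the rename and set_index step, here is the same pair of operations on a made-up two-row table (the values are not from the real file). It uses the non-inplace form, which returns a new DataFrame and leaves the input untouched.

```python
import pandas as pd

# Toy stand-in for the GDP table; values are invented.
gdp = pd.DataFrame({"Country": ["A", "B"], "2015": [1000.0, 2000.0]})

# Rename the year column, then use Country as the row index.
gdp = gdp.rename(columns={"2015": "GDP per capita"}).set_index("Country")
print(gdp)
```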
Here is the code that merges the two tables:
full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita, left_index=True, right_index=True)
full_country_stats.sort_values(by="GDP per capita", inplace=True)
full_country_stats
The full output renders oddly in IPython, so I'll leave it out for now.
In [19]: full_country_stats[["GDP per capita", 'Life satisfaction']].loc["United States"]
Out[19]:
GDP per capita 55805.204
Life satisfaction 7.200
Name: United States, dtype: float64
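To make the merge behavior concrete, here is a small sketch with invented values. With left_index=True and right_index=True, pd.merge joins on the row index, and its default inner join keeps only the countries present in both tables.

```python
import pandas as pd

# Toy tables indexed by country; values are invented.
left = pd.DataFrame(
    {"Life satisfaction": [7.0, 6.0]},
    index=pd.Index(["A", "B"], name="Country"),
)
right = pd.DataFrame(
    {"GDP per capita": [2000.0, 1000.0, 3000.0]},
    index=pd.Index(["A", "B", "C"], name="Country"),
)

# Index-on-index merge; country "C" is dropped because it is missing on the left.
merged = pd.merge(left=left, right=right, left_index=True, right_index=True)
merged = merged.sort_values(by="GDP per capita")
print(merged)
```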
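This shows what the merged table is for: each country now has its GDP per capita and life satisfaction side by side.
The merged table has 36 countries.
Split the data in a simple way: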
remove_indices = [0, 1, 6, 8, 33, 34, 35]
keep_indices = list(set(range(36)) - set(remove_indices))
sample_data = full_country_stats[["GDP per capita", 'Life satisfaction']].iloc[keep_indices]
missing_data = full_country_stats[["GDP per capita", 'Life satisfaction']].iloc[remove_indices]
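The positional split above hard-codes 36 rows. A slightly more defensive sketch, shown here on a made-up ten-row table, derives the row count from the DataFrame and sorts the kept positions so the row order does not depend on set iteration order. The names df, sample and held_out are hypothetical.

```python
import pandas as pd

# Toy table standing in for the sorted country table.
df = pd.DataFrame({"x": range(10)})

remove_indices = [0, 1, 8]
# Use len(df) instead of a hard-coded count, and sort to keep the original order.
keep_indices = sorted(set(range(len(df))) - set(remove_indices))

sample = df.iloc[keep_indices]      # rows used for fitting
held_out = df.iloc[remove_indices]  # rows set aside
print(len(sample), len(held_out))
```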
Now plot the data:
sample_data.plot(kind='scatter', x="GDP per capita", y='Life satisfaction', figsize=(5,3))
plt.axis([0, 60000, 0, 10])
position_text = {
    "Hungary": (5000, 1),
    "Korea": (18000, 1.7),
    "France": (29000, 2.4),
    "Australia": (40000, 3.0),
    "United States": (52000, 3.8),
}
for country, pos_text in position_text.items():
    pos_data_x, pos_data_y = sample_data.loc[country]
    country = "U.S." if country == "United States" else country
    plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,
                 arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))
    plt.plot(pos_data_x, pos_data_y, "ro")
save_fig('money_happy_scatterplot')
plt.show()
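Note that save_fig is not defined anywhere in the code shown in this post; it is a helper that has to be defined before the plotting code runs. The following is only a hypothetical stand-in, assuming all it needs to do is write the current figure to a PNG file. The signature, the output directory handling and the Agg backend are my assumptions, not the original helper.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt


def save_fig(fig_id, out_dir, tight_layout=True):
    """Hypothetical stand-in: save the current figure as out_dir/fig_id.png."""
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, fig_id + ".png")
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format="png", dpi=300)
    return path


# Usage: draw something simple and save it into a temporary directory.
out_dir = tempfile.mkdtemp()
plt.plot([0, 1], [0, 1])
saved_path = save_fig("demo_plot", out_dir)
print(saved_path)
```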
That covers most of the code. Here is the complete script:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.linear_model
def prepare_country_stats(oecd_bli, gdp_per_capita):
    oecd_bli = oecd_bli[oecd_bli["INEQUALITY"]=="TOT"]
    oecd_bli = oecd_bli.pivot(index="Country", columns="Indicator", values="Value")
    gdp_per_capita.rename(columns={"2015": "GDP per capita"}, inplace=True)
    gdp_per_capita.set_index("Country", inplace=True)
    full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,
                                  left_index=True, right_index=True)
    full_country_stats.sort_values(by="GDP per capita", inplace=True)
    remove_indices = [0, 1, 6, 8, 33, 34, 35]
    keep_indices = list(set(range(36)) - set(remove_indices))
    return full_country_stats[["GDP per capita", 'Life satisfaction']].iloc[keep_indices]
# Load the data
oecd_bli = pd.read_csv(datapath + "oecd_bli_2015.csv", thousands=',')
gdp_per_capita = pd.read_csv(datapath + "gdp_per_capita.csv",thousands=',',delimiter='\t',
encoding='latin1', na_values="n/a")
# Prepare the data
country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)
X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]
# Visualize the data
country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')
plt.show()
# Select a linear model
model = sklearn.linear_model.LinearRegression()
# Train the model
model.fit(X, y)
# Make a prediction for Cyprus
X_new = [[22587]] # Cyprus' GDP per capita
print(model.predict(X_new)) # outputs [[ 5.96242338]]
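To see what the np.c_, fit and predict calls in the script do, here is a small sketch on invented data that lies exactly on the line y = 2x + 1. LinearRegression fits a model of the form y = intercept + slope * x, so on this data it should recover a slope of about 2 and an intercept of about 1.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# np.c_ turns a 1-D array into a column of shape (n_samples, 1),
# which is the 2-D layout scikit-learn expects for the features.
X = np.c_[np.array([1.0, 2.0, 3.0, 4.0])]
y = 2.0 * X + 1.0  # invented targets on an exact line

model = LinearRegression()
model.fit(X, y)

slope = model.coef_[0][0]
intercept = model.intercept_[0]
prediction = model.predict([[5.0]])  # one new sample with one feature

print(X.shape, slope, intercept, prediction)
```

Because y is passed as a column here, just as in the script, predict also returns a 2-D array, which is why the script's printed prediction appears inside double brackets.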
Most of the work in this example is the data preparation; the modeling itself takes only a few lines.