import pandas as pd
import numpy as np
import warnings
# Silence every warning for the rest of the session.
# NOTE(review): this also hides library deprecation warnings — confirm that is intended.
warnings.filterwarnings('ignore')
# Do not truncate the column list when a DataFrame is displayed.
pd.set_option('display.max_columns', None)
import seaborn as sns
color = sns.color_palette()
sns.set_style("whitegrid")
# The training file is read with a single space as the field separator.
train_df = pd.read_csv('./used_car_train_20200313.csv', sep=' ')
print(train_df.shape)
# Summary statistics of the training columns.
train_df.describe()
(150000, 31)
SaleID | name | regDate | model | brand | bodyType | fuelType | gearbox | power | kilometer | regionCode | seller | offerType | creatDate | price | v_0 | v_1 | v_2 | v_3 | v_4 | v_5 | v_6 | v_7 | v_8 | v_9 | v_10 | v_11 | v_12 | v_13 | v_14 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 150000.000000 | 150000.000000 | 1.500000e+05 | 149999.000000 | 150000.000000 | 145494.000000 | 141320.000000 | 144019.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.0 | 1.500000e+05 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 | 150000.000000 |
mean | 74999.500000 | 68349.172873 | 2.003417e+07 | 47.129021 | 8.052733 | 1.792369 | 0.375842 | 0.224943 | 119.316547 | 12.597160 | 2583.077267 | 0.000007 | 0.0 | 2.016033e+07 | 5923.327333 | 44.406268 | -0.044809 | 0.080765 | 0.078833 | 0.017875 | 0.248204 | 0.044923 | 0.124692 | 0.058144 | 0.061996 | -0.001000 | 0.009035 | 0.004813 | 0.000313 | -0.000688 |
std | 43301.414527 | 61103.875095 | 5.364988e+04 | 49.536040 | 7.864956 | 1.760640 | 0.548677 | 0.417546 | 177.168419 | 3.919576 | 1885.363218 | 0.002582 | 0.0 | 1.067328e+02 | 7501.998477 | 2.457548 | 3.641893 | 2.929618 | 2.026514 | 1.193661 | 0.045804 | 0.051743 | 0.201410 | 0.029186 | 0.035692 | 3.772386 | 3.286071 | 2.517478 | 1.288988 | 1.038685 |
min | 0.000000 | 0.000000 | 1.991000e+07 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.500000 | 0.000000 | 0.000000 | 0.0 | 2.015062e+07 | 11.000000 | 30.451976 | -4.295589 | -4.470671 | -7.275037 | -4.364565 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | -9.168192 | -5.558207 | -9.639552 | -4.153899 | -6.546556 |
25% | 37499.750000 | 11156.000000 | 1.999091e+07 | 10.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 75.000000 | 12.500000 | 1018.000000 | 0.000000 | 0.0 | 2.016031e+07 | 1300.000000 | 43.135799 | -3.192349 | -0.970671 | -1.462580 | -0.921191 | 0.243615 | 0.000038 | 0.062474 | 0.035334 | 0.033930 | -3.722303 | -1.951543 | -1.871846 | -1.057789 | -0.437034 |
50% | 74999.500000 | 51638.000000 | 2.003091e+07 | 30.000000 | 6.000000 | 1.000000 | 0.000000 | 0.000000 | 110.000000 | 15.000000 | 2196.000000 | 0.000000 | 0.0 | 2.016032e+07 | 3250.000000 | 44.610266 | -3.052671 | -0.382947 | 0.099722 | -0.075910 | 0.257798 | 0.000812 | 0.095866 | 0.057014 | 0.058484 | 1.624076 | -0.358053 | -0.130753 | -0.036245 | 0.141246 |
75% | 112499.250000 | 118841.250000 | 2.007111e+07 | 66.000000 | 13.000000 | 3.000000 | 1.000000 | 0.000000 | 150.000000 | 15.000000 | 3843.000000 | 0.000000 | 0.0 | 2.016033e+07 | 7700.000000 | 46.004721 | 4.000670 | 0.241335 | 1.565838 | 0.868758 | 0.265297 | 0.102009 | 0.125243 | 0.079382 | 0.087491 | 2.844357 | 1.255022 | 1.776933 | 0.942813 | 0.680378 |
max | 149999.000000 | 196812.000000 | 2.015121e+07 | 247.000000 | 39.000000 | 7.000000 | 6.000000 | 1.000000 | 19312.000000 | 15.000000 | 8120.000000 | 1.000000 | 0.0 | 2.016041e+07 | 99999.000000 | 52.304178 | 7.320308 | 19.035496 | 9.854702 | 6.829352 | 0.291838 | 0.151420 | 1.404936 | 0.160791 | 0.222787 | 12.357011 | 18.819042 | 13.847792 | 11.147669 | 8.658418 |
import matplotlib.pyplot as plt
import seaborn as sns
# Distribution of the target column 'price', drawn in its own figure.
# NOTE(review): sns.distplot is deprecated in recent seaborn releases —
# consider sns.histplot / sns.kdeplot when the seaborn version allows it.
plt.figure()
sns.distplot(train_df['price'])
# Box plot of the same column, drawn in a second figure.
plt.figure()
train_df['price'].plot.box()
plt.show()
![output_3_0](output_3_0.png)
![output_3_1](output_3_1.png)
# Stack the test set under the training set so the profiling and feature
# engineering cells further down can work on both at once. Those cells all
# read a frame called `df`, so it has to be built here.
# Presumably the test file has no 'price' column, which leaves NaN in that
# column for test rows after the concat; later cells rely on
# df['price'].isnull() to tell the two parts apart — verify against the file.
# train_df is deliberately kept (not deleted): the cells that follow still
# inspect it directly.
test_df = pd.read_csv('./used_car_testA_20200313.csv', sep=' ')
print(test_df.shape)
df = pd.concat([train_df, test_df], axis=0, ignore_index=True)
df.head()
## .info() gives a quick view of the column names, dtypes and non-null counts (i.e. where NaN values are)
train_df.info()
RangeIndex: 150000 entries, 0 to 149999
Data columns (total 31 columns):
SaleID 150000 non-null int64
name 150000 non-null int64
regDate 150000 non-null int64
model 149999 non-null float64
brand 150000 non-null int64
bodyType 145494 non-null float64
fuelType 141320 non-null float64
gearbox 144019 non-null float64
power 150000 non-null int64
kilometer 150000 non-null float64
notRepairedDamage 150000 non-null object
regionCode 150000 non-null int64
seller 150000 non-null int64
offerType 150000 non-null int64
creatDate 150000 non-null int64
price 150000 non-null int64
v_0 150000 non-null float64
v_1 150000 non-null float64
v_2 150000 non-null float64
v_3 150000 non-null float64
v_4 150000 non-null float64
v_5 150000 non-null float64
v_6 150000 non-null float64
v_7 150000 non-null float64
v_8 150000 non-null float64
v_9 150000 non-null float64
v_10 150000 non-null float64
v_11 150000 non-null float64
v_12 150000 non-null float64
v_13 150000 non-null float64
v_14 150000 non-null float64
dtypes: float64(20), int64(10), object(1)
memory usage: 35.5+ MB
## 1) Count the NaN values in each column
train_df.isnull().sum()
SaleID 0
name 0
regDate 0
model 1
brand 0
bodyType 4506
fuelType 8680
gearbox 5981
power 0
kilometer 0
notRepairedDamage 0
regionCode 0
seller 0
offerType 0
creatDate 0
price 0
v_0 0
v_1 0
v_2 0
v_3 0
v_4 0
v_5 0
v_6 0
v_7 0
v_8 0
v_9 0
v_10 0
v_11 0
v_12 0
v_13 0
v_14 0
dtype: int64
## Collect the column names as a plain Python list
columns = list(train_df.columns)
print(columns)
['SaleID', 'name', 'regDate', 'model', 'brand', 'bodyType', 'fuelType', 'gearbox', 'power', 'kilometer', 'notRepairedDamage', 'regionCode', 'seller', 'offerType', 'creatDate', 'price', 'v_0', 'v_1', 'v_2', 'v_3', 'v_4', 'v_5', 'v_6', 'v_7', 'v_8', 'v_9', 'v_10', 'v_11', 'v_12', 'v_13', 'v_14']
# Print the value-frequency table of every column, prefixed with the column
# name. The loop body must be indented to be valid Python; the per-column
# print statements that used to be kept here as commented-out code were
# redundant with this loop and have been dropped.
for i in columns:
    print(i, train_df[i].value_counts())
SaleID 2047 1
113949 1
15661 1
13612 1
3371 1
..
8913 1
10960 1
53967 1
56014 1
0 1
Name: SaleID, Length: 150000, dtype: int64
name 708 282
387 282
55 280
1541 263
203 233
...
5074 1
7123 1
11221 1
13270 1
174485 1
Name: name, Length: 99662, dtype: int64
regDate 20000008 180
20000011 158
20000004 157
20000010 157
20000002 155
...
19910807 1
19910902 1
20151209 1
19911011 1
20151201 1
Name: regDate, Length: 3894, dtype: int64
model 0.0 11762
19.0 9573
4.0 8445
1.0 6038
29.0 5186
...
245.0 2
209.0 2
240.0 2
242.0 2
247.0 1
Name: model, Length: 248, dtype: int64
brand 0 31480
4 16737
14 16089
10 14249
1 13794
6 10217
9 7306
5 4665
13 3817
11 2945
3 2461
7 2361
16 2223
8 2077
25 2064
27 2053
21 1547
15 1458
19 1388
20 1236
12 1109
22 1085
26 966
30 940
17 913
24 772
28 649
32 592
29 406
37 333
2 321
31 318
18 316
36 228
34 227
33 218
23 186
35 180
38 65
39 9
Name: brand, dtype: int64
bodyType 0.0 41420
1.0 35272
2.0 30324
3.0 13491
4.0 9609
5.0 7607
6.0 6482
7.0 1289
Name: bodyType, dtype: int64
fuelType 0.0 91656
1.0 46991
2.0 2212
3.0 262
4.0 118
5.0 45
6.0 36
Name: fuelType, dtype: int64
gearbox 0.0 111623
1.0 32396
Name: gearbox, dtype: int64
power 0 12829
75 9593
150 6495
60 6374
140 5963
...
1597 1
1596 1
572 1
316 1
575 1
Name: power, Length: 566, dtype: int64
kilometer 15.0 96877
12.5 15722
10.0 6459
9.0 5257
8.0 4573
7.0 4084
6.0 3725
5.0 3144
4.0 2718
3.0 2501
2.0 2354
0.5 1840
1.0 746
Name: kilometer, dtype: int64
notRepairedDamage 0.0 111361
- 24324
1.0 14315
Name: notRepairedDamage, dtype: int64
regionCode 419 369
764 258
125 137
176 136
462 134
...
6414 1
7063 1
4239 1
5931 1
7267 1
Name: regionCode, Length: 7905, dtype: int64
seller 0 149999
1 1
Name: seller, dtype: int64
offerType 0 150000
Name: offerType, dtype: int64
creatDate 20160403 5848
20160404 5606
20160320 5485
20160312 5383
20160402 5382
...
20151227 1
20151217 1
20160131 1
20160130 1
20160115 1
Name: creatDate, Length: 96, dtype: int64
price 500 2337
1500 2158
1200 1922
1000 1850
2500 1821
...
25321 1
8886 1
8801 1
37920 1
8188 1
Name: price, Length: 3763, dtype: int64
v_0 45.349115 20
48.087217 16
47.568450 15
48.618150 15
47.840357 15
..
44.752849 1
47.710369 1
45.626634 1
43.795918 1
42.340691 1
Name: v_0, Length: 143997, dtype: int64
v_1 -3.245133 20
3.183323 16
1.942732 15
3.354949 15
2.796739 15
..
3.418050 1
-2.994782 1
-3.022811 1
-3.220458 1
-3.309574 1
Name: v_1, Length: 143998, dtype: int64
v_2 -0.349860 20
0.826577 16
0.887762 15
-0.158364 15
0.980616 15
..
-0.722279 1
0.214464 1
0.019754 1
-1.154275 1
-0.891405 1
Name: v_2, Length: 143997, dtype: int64
v_3 -0.218201 20
-1.312990 16
-2.006612 15
-1.619873 15
-1.612432 15
..
3.480942 1
-1.724122 1
3.100129 1
-3.044833 1
1.927354 1
Name: v_3, Length: 143998, dtype: int64
v_4 -1.626828 20
0.696775 16
-0.365900 15
-0.364986 15
-0.379815 15
..
0.243463 1
1.202133 1
-0.973844 1
-0.550933 1
-0.629790 1
Name: v_4, Length: 143998, dtype: int64
v_5 0.000000 4485
0.269406 20
0.256226 16
0.261082 15
0.277097 15
...
0.237881 1
0.228538 1
0.272297 1
0.243947 1
0.254857 1
Name: v_5, Length: 139624, dtype: int64
v_6 0.000000 35465
0.000053 20
0.093472 16
0.075212 15
0.087562 15
...
0.112769 1
0.000540 1
0.122172 1
0.111154 1
0.111825 1
Name: v_6, Length: 109766, dtype: int64
v_7 0.000000 5467
0.124213 20
0.130272 16
0.139667 15
0.051680 15
...
0.051634 1
0.037816 1
0.024212 1
0.077910 1
0.125709 1
Name: v_7, Length: 138709, dtype: int64
v_8 0.000000 1597
0.067358 20
0.074742 16
0.073268 15
0.075905 15
...
0.039385 1
0.083553 1
0.032876 1
0.052331 1
0.037786 1
Name: v_8, Length: 142451, dtype: int64
v_9 0.000000 3486
0.014867 20
0.082765 16
0.101150 15
0.051535 15
...
0.067019 1
0.030366 1
0.095158 1
0.028713 1
0.098763 1
Name: v_9, Length: 140617, dtype: int64
v_10 2.329386 20
-4.303481 16
-3.163236 15
-4.757359 15
-4.383929 15
..
3.076637 1
-4.074684 1
2.034086 1
2.212329 1
2.903413 1
Name: v_10, Length: 143997, dtype: int64
v_11 -2.255591 20
-0.330053 16
-1.107940 15
-0.802614 15
-1.436494 15
..
-2.418401 1
-0.622637 1
0.195040 1
-0.802083 1
0.920200 1
Name: v_11, Length: 143997, dtype: int64
v_12 0.847433 20
2.486297 16
2.375256 15
3.097963 15
2.104470 15
..
0.342035 1
-2.909754 1
0.756550 1
-2.862605 1
-1.811578 1
Name: v_12, Length: 143997, dtype: int64
v_13 -1.698497 20
-0.043463 16
0.548697 15
-0.834697 15
-0.499453 15
..
-2.415566 1
-1.634300 1
-1.192429 1
-0.265945 1
1.151207 1
Name: v_13, Length: 143998, dtype: int64
v_14 0.003015 20
-2.290344 16
-3.059444 15
1.027487 15
0.869586 15
..
0.273255 1
1.406113 1
0.530243 1
0.423112 1
0.025664 1
Name: v_14, Length: 143998, dtype: int64
# Frequency of each distinct value in 'notRepairedDamage' (an object-dtype column).
train_df['notRepairedDamage'].value_counts()
0.0 111361
- 24324
1.0 14315
Name: notRepairedDamage, dtype: int64
# '-' is presumably a placeholder for "unknown"; turn it into NaN so it is
# counted as missing.
# The result is assigned back to the column instead of calling
# replace(..., inplace=True) on the selection train_df[...]: that form is a
# chained assignment, which newer pandas versions warn about and which does
# not update train_df under copy-on-write.
train_df['notRepairedDamage'] = train_df['notRepairedDamage'].replace('-', np.nan)
train_df['notRepairedDamage'].value_counts()
0.0 111361
1.0 14315
Name: notRepairedDamage, dtype: int64
### Identify the categorical and the numeric variables
#### SaleID and similar columns turn out to be numeric
date_cols = ['regDate', 'creatDate']
cate_cols = ['name', 'model', 'brand', 'bodyType', 'fuelType', 'gearbox', 'notRepairedDamage', 'regionCode', 'seller', 'offerType']
num_cols = ['power', 'kilometer'] + ['v_{}'.format(i) for i in range(15)]
cols = date_cols + cate_cols + num_cols

# Per-column profile: non-null count, missing rate, number of distinct values
# and the most frequent value together with its count and share of all rows.
n_rows = df.shape[0]
# value_counts() sorts by descending frequency, so position 0 is the most
# common value. Compute it once per column and keep only (count, value).
top_entries = []
for f in cols:
    counts = df[f].value_counts()
    top_entries.append((counts.values[0], counts.index[0]))

tmp = pd.DataFrame()
tmp['count'] = df[cols].count().values
tmp['missing_rate'] = (n_rows - tmp['count']) / n_rows
tmp['nunique'] = df[cols].nunique().values
tmp['max_value_counts'] = [top_count for top_count, _ in top_entries]
tmp['max_value_counts_prop'] = tmp['max_value_counts'] / n_rows
tmp['max_value_counts_value'] = [top_value for _, top_value in top_entries]
tmp.index = cols
tmp
count | missing_rate | nunique | max_value_counts | max_value_counts_prop | max_value_counts_value | |
---|---|---|---|---|---|---|
regDate | 200000 | 0.000000 | 3900 | 228 | 0.001140 | 20000008 |
creatDate | 200000 | 0.000000 | 101 | 7814 | 0.039070 | 20160403 |
name | 200000 | 0.000000 | 128466 | 378 | 0.001890 | 708 |
model | 199999 | 0.000005 | 248 | 15658 | 0.078290 | 0 |
brand | 200000 | 0.000000 | 40 | 41828 | 0.209140 | 0 |
bodyType | 194081 | 0.029595 | 8 | 55405 | 0.277025 | 0 |
fuelType | 188427 | 0.057865 | 7 | 122312 | 0.611560 | 0 |
gearbox | 192109 | 0.039455 | 2 | 148924 | 0.744620 | 0 |
notRepairedDamage | 200000 | 0.000000 | 3 | 148610 | 0.743050 | 0.0 |
regionCode | 200000 | 0.000000 | 8021 | 515 | 0.002575 | 419 |
seller | 200000 | 0.000000 | 2 | 199999 | 0.999995 | 0 |
offerType | 200000 | 0.000000 | 1 | 200000 | 1.000000 | 0 |
power | 200000 | 0.000000 | 623 | 17024 | 0.085120 | 0 |
kilometer | 200000 | 0.000000 | 13 | 129066 | 0.645330 | 15 |
v_0 | 200000 | 0.000000 | 189804 | 26 | 0.000130 | 45.3491 |
v_1 | 200000 | 0.000000 | 189805 | 26 | 0.000130 | -3.24513 |
v_2 | 200000 | 0.000000 | 189804 | 26 | 0.000130 | -0.34986 |
v_3 | 200000 | 0.000000 | 189805 | 26 | 0.000130 | -0.218201 |
v_4 | 200000 | 0.000000 | 189805 | 26 | 0.000130 | -1.62683 |
v_5 | 200000 | 0.000000 | 184098 | 5893 | 0.029465 | 0 |
v_6 | 200000 | 0.000000 | 144757 | 47171 | 0.235855 | 0 |
v_7 | 200000 | 0.000000 | 182826 | 7322 | 0.036610 | 0 |
v_8 | 200000 | 0.000000 | 187737 | 2163 | 0.010815 | 0 |
v_9 | 200000 | 0.000000 | 185365 | 4614 | 0.023070 | 0 |
v_10 | 200000 | 0.000000 | 189804 | 26 | 0.000130 | 2.32939 |
v_11 | 200000 | 0.000000 | 189804 | 26 | 0.000130 | -2.25559 |
v_12 | 200000 | 0.000000 | 189804 | 26 | 0.000130 | 0.847433 |
v_13 | 200000 | 0.000000 | 189805 | 26 | 0.000130 | -1.6985 |
v_14 | 200000 | 0.000000 | 189805 | 26 | 0.000130 | 0.00301494 |
from tqdm import tqdm
def date_proc(x):
    """Turn a 'YYYYMMDD' string into 'YYYY-M-DD', mapping month 00 to 1.

    Some dates in the data carry month ``00`` (for example ``20000008``),
    which is not a calendar month; those are moved to January so the value
    can be parsed as a date afterwards.

    Args:
        x: Date as a string whose characters 0-3 are the year, 4-5 the month
            and 6 onward the day.

    Returns:
        The date as ``year-month-day``. The month is written without zero
        padding; year and day are copied verbatim.

    Raises:
        ValueError: If characters 4-5 of ``x`` are not an integer.
    """
    month = int(x[4:6])
    if month == 0:
        month = 1
    return x[:4] + '-' + str(month) + '-' + x[6:]
# Parse each raw date column into datetimes and derive its calendar parts
# as new '<column>_<part>' columns.
for f in tqdm(date_cols):
    parsed = pd.to_datetime(df[f].astype('str').apply(date_proc))
    df[f] = parsed
    for part in ('year', 'month', 'day', 'dayofweek'):
        df[f + '_' + part] = getattr(df[f].dt, part)
100%|██████████| 2/2 [00:01<00:00, 1.41it/s]
# Bar charts of how often each year / month / day / weekday occurs, one row
# of four subplots per raw date column.
# Only one figure is created: an extra bare plt.figure() before this one
# would leave an empty figure behind.
plt.figure(figsize=(16, 6))
i = 1
for f in date_cols:
    for col in ['year', 'month', 'day', 'dayofweek']:
        plt.subplot(2, 4, i)
        i += 1
        v = df[f + '_' + col].value_counts()
        ax = sns.barplot(x=v.index, y=v.values)
        # Rotate the tick labels so they do not overlap.
        for item in ax.get_xticklabels():
            item.set_rotation(90)
        plt.title(f + '_' + col)
plt.tight_layout()
plt.show()
![output_25_1](output_25_1.png)
# Take 'seller' and 'offerType' out of the categorical feature list.
for dropped_col in ('seller', 'offerType'):
    cate_cols.remove(dropped_col)

# From here on date_cols names the derived calendar-part columns.
date_cols = ['regDate_year', 'regDate_month', 'regDate_day', 'regDate_dayofweek', 'creatDate_month', 'creatDate_day', 'creatDate_dayofweek']

# Absolute correlation between price and the date / numeric columns,
# computed on the rows that have a price.
has_price = ~df['price'].isnull()
corr_input = df.loc[has_price, ['price'] + date_cols + num_cols]
corr1 = corr_input.corr().abs()
plt.figure(figsize=(10, 10))
sns.heatmap(corr1, linewidths=0.1, cmap=sns.cm.rocket_r)
![output_28_1](output_28_1.png)
# Density of every v_* feature on rows with a price ('train') versus rows
# without one ('test'), one subplot per feature.
# Only one figure is created (a preceding bare plt.figure() would leave an
# empty figure behind), and sns.kdeplot replaces the deprecated
# sns.distplot(..., hist=False), which draws the same kind of curve.
plt.figure(figsize=(15, 15))
# The row split does not depend on the feature, so do it once.
missing_price = df['price'].isnull()
train_rows = df[~missing_price]
test_rows = df[missing_price]
i = 1
# num_cols[2:] skips 'power' and 'kilometer', leaving the v_* columns.
for f in num_cols[2:]:
    plt.subplot(5, 3, i)
    i += 1
    sns.kdeplot(train_rows[f], label='train', color='y')
    sns.kdeplot(test_rows[f], label='test', color='g')
    # Make the train/test labels visible regardless of seaborn version.
    plt.legend()
plt.tight_layout()
plt.show()
![output_30_1](output_30_1.png)
# Mean price per level for every feature with at most 50 distinct values.
# Only one figure is created; a preceding bare plt.figure() would leave an
# empty figure behind.
plt.figure(figsize=(20, 18))
# Only rows with a price can contribute to the mean.
train_rows = df[~df['price'].isnull()]
i = 1
for f in cate_cols + date_cols + num_cols:
    if df[f].nunique() > 50:
        continue
    plt.subplot(5, 3, i)
    i += 1
    mean_col = f + '_price_mean'
    # Passing a {new_name: func} dict to SeriesGroupBy.agg is no longer
    # supported by pandas; compute the mean and name the column on reset.
    v = train_rows.groupby(f)['price'].mean().reset_index(name=mean_col)
    ax = sns.barplot(x=f, y=mean_col, data=v)
    for item in ax.get_xticklabels():
        item.set_rotation(90)
plt.tight_layout()
plt.show()
![output_32_1](output_32_1.png)
# Convert 'notRepairedDamage' to float16: values are first rendered as text,
# the '-' entries become None (and therefore NaN after the cast).
damage_text = df['notRepairedDamage'].astype('str')
damage_clean = damage_text.apply(lambda s: None if s == '-' else s)
df['notRepairedDamage'] = damage_clean.astype('float16')
from scipy.stats import entropy
# Names of all engineered feature columns, filled in as they are created.
feat_cols = []
### Count encoding: replace each value by how often it occurs in its column
count_sources = ['regDate', 'creatDate', 'regDate_year', 'model', 'brand', 'regionCode']
for f in tqdm(count_sources):
    count_col = f + '_count'
    frequency = df[f].value_counts()
    df[count_col] = df[f].map(frequency)
    feat_cols.append(count_col)
### Describe categorical features through statistics of numeric features;
### the numeric ones are a few anonymous features picked fairly arbitrarily
### among those most correlated with price.
def _mean_abs_dev(values):
    """Return the mean absolute deviation of *values* around their mean."""
    return (values - values.mean()).abs().mean()


for f1 in tqdm(['model', 'brand', 'regionCode']):
    g = df.groupby(f1)
    for f2 in tqdm(['v_0', 'v_3', 'v_8', 'v_12']):
        prefix = '{}_{}'.format(f1, f2)
        # Named aggregation (output_name=func) replaces the {name: func} dict
        # form, which pandas no longer accepts on a single column. 'mad' is
        # computed by a helper because pandas dropped that aggregation.
        named_aggs = {
            prefix + '_max': 'max', prefix + '_min': 'min',
            prefix + '_median': 'median', prefix + '_mean': 'mean',
            prefix + '_std': 'std', prefix + '_mad': _mean_abs_dev,
        }
        # reset_index() turns the group key back into a column for the merge.
        feat = g[f2].agg(**named_aggs).reset_index()
        df = df.merge(feat, on=f1, how='left')
        feat_cols.extend(named_aggs)
### Second-order crosses of categorical features
def _value_entropy(values):
    """Return the entropy of the value frequencies of *values*."""
    return entropy(values.value_counts() / values.shape[0])


for f_pair in tqdm([
    ['model', 'brand'], ['model', 'regionCode'], ['brand', 'regionCode']
]):
    first, second = f_pair
    pair_count = '_'.join(f_pair) + '_count'
    ### Co-occurrence count of the pair
    df[pair_count] = df.groupby(f_pair)['SaleID'].transform('count')
    ### Number of distinct partners and their entropy, in both directions
    for key, other in ((first, second), (second, first)):
        # Named aggregation replaces the {name: func} dict form, which pandas
        # no longer accepts on a single column; reset_index() restores the
        # group key as a column for the merge.
        pair_stats = df.groupby(key)[other].agg(**{
            '{}_{}_nunique'.format(key, other): 'nunique',
            '{}_{}_ent'.format(key, other): _value_entropy,
        }).reset_index()
        df = df.merge(pair_stats, on=key, how='left')
    ### Proportion preference: pair count relative to each member's own count
    df['{}_in_{}_prop'.format(first, second)] = df[pair_count] / df[second + '_count']
    df['{}_in_{}_prop'.format(second, first)] = df[pair_count] / df[first + '_count']
    feat_cols.extend([
        pair_count,
        '{}_{}_nunique'.format(first, second), '{}_{}_ent'.format(first, second),
        '{}_{}_nunique'.format(second, first), '{}_{}_ent'.format(second, first),
        '{}_in_{}_prop'.format(first, second), '{}_in_{}_prop'.format(second, first),
    ])
100%|██████████| 6/6 [00:01<00:00, 3.36it/s]
0%| | 0/3 [00:00, ?it/s]
0%| | 0/4 [00:00, ?it/s][A
25%|██▌ | 1/4 [00:00<00:00, 3.20it/s][A
50%|█████ | 2/4 [00:00<00:00, 3.89it/s][A
75%|███████▌ | 3/4 [00:00<00:00, 4.17it/s][A
100%|██████████| 4/4 [00:00<00:00, 4.32it/s][A
33%|███▎ | 1/3 [00:00<00:01, 1.02it/s]
0%| | 0/4 [00:00, ?it/s][A
25%|██▌ | 1/4 [00:00<00:01, 2.02it/s][A
50%|█████ | 2/4 [00:00<00:00, 3.35it/s][A
75%|███████▌ | 3/4 [00:00<00:00, 4.27it/s][A
100%|██████████| 4/4 [00:00<00:00, 4.95it/s][A
67%|██████▋ | 2/3 [00:01<00:00, 1.08it/s]
0%| | 0/4 [00:00, ?it/s][A
25%|██▌ | 1/4 [00:04<00:12, 4.10s/it][A
50%|█████ | 2/4 [00:07<00:07, 3.79s/it][A
75%|███████▌ | 3/4 [00:11<00:03, 3.87s/it][A
100%|██████████| 4/4 [00:15<00:00, 3.78s/it][A
100%|██████████| 3/3 [00:17<00:00, 5.72s/it]
100%|██████████| 3/3 [00:17<00:00, 5.98s/it]
from sklearn.model_selection import KFold
# Rows with a price form the training set, rows without one the test set.
train_df = df[~df['price'].isnull()].reset_index(drop=True)
test_df = df[df['price'].isnull()].reset_index(drop=True)


def _mean_abs_dev(values):
    """Return the mean absolute deviation of *values* around their mean."""
    return (values - values.mean()).abs().mean()


### Target encoding. In a regression setting there are more encodings to
### choose from: besides the mean, the standard deviation, the median, etc.
enc_cols = []
# Global statistics of the price, used for categories that do not occur in
# the fold an encoding was fitted on.
stats_default_dict = {
    'max': train_df['price'].max(),
    'min': train_df['price'].min(),
    'median': train_df['price'].median(),
    'mean': train_df['price'].mean(),
    'sum': train_df['price'].sum(),
    'std': train_df['price'].std(),
    'skew': train_df['price'].skew(),
    'kurt': train_df['price'].kurt(),
    # Series.mad() no longer exists in pandas, so compute it directly.
    'mad': _mean_abs_dev(train_df['price'])
}
### Use these three encodings for now
enc_stats = ['mean', 'std', 'mad']
# Statistics that pandas cannot resolve by name map to a callable here.
stat_funcs = {'mad': _mean_abs_dev}
skf = KFold(n_splits=5, shuffle=True, random_state=2020)
for f in tqdm(['model', 'brand', 'regionCode']):
    enc_dict = {}
    for stat in enc_stats:
        enc_col = '{}_target_{}'.format(f, stat)
        enc_dict[enc_col] = stat_funcs.get(stat, stat)
        # Start from float zeros: the columns receive float values below.
        train_df[enc_col] = 0.0
        test_df[enc_col] = 0.0
        enc_cols.append(enc_col)
    for trn_idx, val_idx in skf.split(train_df, train_df['price']):
        trn_x = train_df.iloc[trn_idx]
        val_x = train_df.iloc[val_idx]
        # Fit the encoding on the training part of the fold only. Named
        # aggregation replaces the {name: func} dict form, which pandas no
        # longer accepts on a single column.
        enc_df = trn_x.groupby(f)['price'].agg(**enc_dict).reset_index()
        val_enc = val_x[[f]].merge(enc_df, on=f, how='left')
        test_enc = test_df[[f]].merge(enc_df, on=f, how='left')
        for stat in enc_stats:
            enc_col = '{}_target_{}'.format(f, stat)
            default = stats_default_dict[stat]
            # Validation rows get the out-of-fold encoding; test rows get the
            # average of the encodings from all folds.
            train_df.loc[val_idx, enc_col] = val_enc[enc_col].fillna(default).values
            test_df[enc_col] += test_enc[enc_col].fillna(default).values / skf.n_splits
# Final feature list: categorical, date-part, numeric, engineered and
# target-encoded columns, in that order.
cols = [*cate_cols, *date_cols, *num_cols, *feat_cols, *enc_cols]

# Keep the target and the test ids before narrowing both frames to `cols`.
labels = train_df['price'].values
sub = test_df.loc[:, ['SaleID']].copy()
train_df = train_df[cols]
test_df = test_df[cols]
print(train_df.shape)
train_df.head()
100%|██████████| 3/3 [00:23<00:00, 7.91s/it]
(150000, 140)
name | model | brand | bodyType | fuelType | gearbox | notRepairedDamage | regionCode | regDate_year | regDate_month | regDate_day | regDate_dayofweek | creatDate_month | creatDate_day | creatDate_dayofweek | power | kilometer | v_0 | v_1 | v_2 | v_3 | v_4 | v_5 | v_6 | v_7 | v_8 | v_9 | v_10 | v_11 | v_12 | v_13 | v_14 | regDate_count | creatDate_count | regDate_year_count | model_count | brand_count | regionCode_count | model_v_0_max | model_v_0_min | model_v_0_median | model_v_0_mean | model_v_0_std | model_v_0_mad | model_v_3_max | model_v_3_min | model_v_3_median | model_v_3_mean | model_v_3_std | model_v_3_mad | model_v_8_max | model_v_8_min | model_v_8_median | model_v_8_mean | model_v_8_std | model_v_8_mad | model_v_12_max | model_v_12_min | model_v_12_median | model_v_12_mean | model_v_12_std | model_v_12_mad | brand_v_0_max | brand_v_0_min | brand_v_0_median | brand_v_0_mean | brand_v_0_std | brand_v_0_mad | brand_v_3_max | brand_v_3_min | brand_v_3_median | brand_v_3_mean | brand_v_3_std | brand_v_3_mad | brand_v_8_max | brand_v_8_min | brand_v_8_median | brand_v_8_mean | brand_v_8_std | brand_v_8_mad | brand_v_12_max | brand_v_12_min | brand_v_12_median | brand_v_12_mean | brand_v_12_std | brand_v_12_mad | regionCode_v_0_max | regionCode_v_0_min | regionCode_v_0_median | regionCode_v_0_mean | regionCode_v_0_std | regionCode_v_0_mad | regionCode_v_3_max | regionCode_v_3_min | regionCode_v_3_median | regionCode_v_3_mean | regionCode_v_3_std | regionCode_v_3_mad | regionCode_v_8_max | regionCode_v_8_min | regionCode_v_8_median | regionCode_v_8_mean | regionCode_v_8_std | regionCode_v_8_mad | regionCode_v_12_max | regionCode_v_12_min | regionCode_v_12_median | regionCode_v_12_mean | regionCode_v_12_std | regionCode_v_12_mad | model_brand_count | model_brand_nunique | model_brand_ent | brand_model_nunique | brand_model_ent | model_in_brand_prop | brand_in_model_prop | model_regionCode_count | model_regionCode_nunique | model_regionCode_ent | 
regionCode_model_nunique | regionCode_model_ent | model_in_regionCode_prop | regionCode_in_model_prop | brand_regionCode_count | brand_regionCode_nunique | brand_regionCode_ent | regionCode_brand_nunique | regionCode_brand_ent | brand_in_regionCode_prop | regionCode_in_brand_prop | model_target_mean | model_target_std | model_target_mad | brand_target_mean | brand_target_std | brand_target_mad | regionCode_target_mean | regionCode_target_std | regionCode_target_mad | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 736 | 30.0 | 6 | 1.0 | 0.0 | 0.0 | 0.0 | 1046 | 2004 | 4 | 2 | 4 | 4 | 4 | 0 | 60 | 12.5 | 43.357796 | 3.966344 | 0.050257 | 2.159744 | 1.143786 | 0.235676 | 0.101988 | 0.129549 | 0.022816 | 0.097462 | -2.881803 | 2.804097 | -2.420821 | 0.795292 | 0.914762 | 100 | 7416 | 11899 | 3145.0 | 13719 | 47 | 46.342726 | 28.987024 | 42.918210 | 42.750103 | 2.039852 | 1.335433 | 7.375598 | -3.651126 | 1.361821 | 1.119369 | 2.000747 | 1.685237 | 0.132452 | 0.000000 | 0.041461 | 0.044344 | 0.028413 | 0.023923 | 9.825077 | -7.781645 | -1.560370 | -1.352894 | 2.292161 | 1.896074 | 49.216535 | 28.987024 | 43.518332 | 43.331114 | 2.282470 | 1.548165 | 9.121568 | -5.433129 | 0.953054 | 0.733410 | 2.075349 | 1.700596 | 0.153849 | 0.0 | 0.045827 | 0.048987 | 0.029004 | 0.024024 | 11.778541 | -8.293423 | -1.411952 | -1.041002 | 2.414547 | 1.953658 | 50.376809 | 33.511607 | 45.942332 | 45.445490 | 2.369732 | 1.507840 | 6.062231 | -4.480148 | -0.613237 | -0.180360 | 2.049067 | 1.627380 | 0.132026 | 0.000000 | 0.063441 | 0.061578 | 0.029357 | 0.024186 | 6.353017 | -4.362750 | 0.558707 | 0.699833 | 2.183664 | 1.773069 | 3145.0 | 1.0 | 0.000000 | 15 | 2.161420 | 0.229244 | 1.000000 | 1.0 | 1938.0 | 7.390204 | 27 | 3.062849 | 0.021277 | 0.000318 | 3 | 4553 | 8.091984 | 14 | 2.234541 | 0.063830 | 0.000219 | 2774.671714 | 3039.693856 | 2298.168884 | 3613.811755 | 4698.830289 | 3197.888310 | 8415.931034 | 10276.594200 | 6996.870392 |
1 | 2262 | 40.0 | 1 | 2.0 | 0.0 | 0.0 | NaN | 4366 | 2003 | 3 | 1 | 5 | 3 | 9 | 2 | 0 | 15.0 | 45.305273 | 5.236112 | 0.137925 | 1.380657 | -1.422165 | 0.264777 | 0.121004 | 0.135731 | 0.026597 | 0.020582 | -4.900482 | 2.096338 | -1.030483 | -1.722674 | 0.245522 | 116 | 6937 | 12074 | 5911.0 | 18326 | 10 | 49.517245 | 33.708559 | 45.717711 | 45.594824 | 1.643656 | 1.135157 | 7.849033 | -4.628179 | -0.051705 | 0.065820 | 1.824363 | 1.539744 | 0.129159 | 0.000000 | 0.056804 | 0.056247 | 0.026896 | 0.022472 | 10.572566 | -5.644524 | 0.735713 | 0.650309 | 2.031854 | 1.690149 | 50.735456 | 31.524342 | 46.038090 | 45.836925 | 2.083827 | 1.463073 | 9.201960 | -5.411330 | -0.611653 | -0.448543 | 1.993718 | 1.670792 | 0.153444 | 0.0 | 0.064640 | 0.064518 | 0.029553 | 0.024529 | 12.973057 | -6.589651 | 1.299017 | 1.201088 | 2.362357 | 1.941684 | 48.151271 | 37.210840 | 45.141957 | 44.408016 | 2.905189 | 1.872938 | 2.132799 | -3.210486 | -1.729640 | -0.641526 | 2.017365 | 1.808277 | 0.096599 | 0.020296 | 0.080059 | 0.066296 | 0.030225 | 0.026461 | 5.690846 | -2.829548 | 1.329232 | 1.163445 | 2.361794 | 1.661170 | 5911.0 | 1.0 | 0.000000 | 17 | 2.033879 | 0.322547 | 1.000000 | 2.0 | 3201.0 | 7.864350 | 7 | 1.886697 | 0.200000 | 0.000338 | 2 | 5432 | 8.273403 | 7 | 1.886697 | 0.200000 | 0.000109 | 6769.837131 | 6267.143090 | 4722.694793 | 9241.418072 | 9358.438119 | 6902.470132 | 5974.500000 | 3637.599987 | 2774.500000 |
2 | 14874 | 115.0 | 15 | 1.0 | 0.0 | 0.0 | 0.0 | 2806 | 2004 | 4 | 3 | 5 | 4 | 2 | 5 | 163 | 12.5 | 45.978359 | 4.823792 | 1.319524 | -0.998467 | -0.996911 | 0.251410 | 0.114912 | 0.165147 | 0.062173 | 0.027075 | -4.846749 | 1.803559 | 1.565330 | -0.832687 | -0.229963 | 102 | 7092 | 11899 | 1249.0 | 1969 | 19 | 49.023367 | 34.908390 | 46.523215 | 46.462299 | 1.329109 | 0.905913 | 4.719014 | -4.147652 | -1.652962 | -1.578260 | 1.176329 | 0.962648 | 0.136903 | 0.034065 | 0.075430 | 0.079243 | 0.021157 | 0.017590 | 11.064900 | -0.897979 | 2.176214 | 2.343856 | 1.380674 | 1.041263 | 49.023367 | 34.908390 | 46.311901 | 46.200290 | 1.506889 | 0.959293 | 6.473598 | -4.640231 | -1.590475 | -1.538817 | 1.149450 | 0.908954 | 0.138584 | 0.0 | 0.077238 | 0.080276 | 0.020360 | 0.016847 | 11.064900 | -5.134526 | 2.070449 | 2.267722 | 1.438300 | 1.042965 | 47.727832 | 32.526985 | 44.869182 | 44.612175 | 3.239949 | 1.850588 | 6.840967 | -3.751112 | -0.073282 | -0.168359 | 2.306767 | 1.516114 | 0.115919 | 0.001932 | 0.059694 | 0.061692 | 0.026073 | 0.017904 | 4.274930 | -2.240666 | 0.024260 | 0.574365 | 1.652139 | 1.319930 | 1249.0 | 1.0 | 0.000000 | 5 | 0.957182 | 0.634332 | 1.000000 | 3.0 | 948.0 | 6.751637 | 12 | 2.333197 | 0.157895 | 0.002402 | 3 | 1374 | 7.081183 | 9 | 2.104653 | 0.157895 | 0.001524 | 10902.680217 | 5741.391564 | 4742.415530 | 9826.330479 | 5330.449114 | 4283.976617 | 6625.615385 | 7177.660814 | 5590.236686 |
3 | 71865 | 109.0 | 10 | 0.0 | 0.0 | 1.0 | 0.0 | 434 | 1996 | 9 | 8 | 6 | 3 | 12 | 5 | 193 | 15.0 | 45.687478 | 4.492574 | -0.050616 | 0.883600 | -2.228079 | 0.274293 | 0.110300 | 0.121964 | 0.033395 | 0.000000 | -4.509599 | 1.285940 | -0.501868 | -2.438353 | -0.478699 | 32 | 7185 | 6411 | 504.0 | 19015 | 47 | 49.593459 | 34.331296 | 46.216312 | 46.256499 | 1.607789 | 1.054037 | 7.351431 | -4.903170 | -0.676848 | -0.904197 | 1.617832 | 1.346071 | 0.135339 | 0.000000 | 0.069205 | 0.071043 | 0.025596 | 0.021051 | 11.338190 | -4.235065 | 1.491440 | 1.833323 | 1.966263 | 1.589044 | 50.885397 | 32.521871 | 45.863752 | 45.759286 | 1.849069 | 1.268704 | 9.381599 | -5.500692 | -0.251426 | -0.303559 | 1.870598 | 1.543436 | 0.159710 | 0.0 | 0.060434 | 0.062277 | 0.028016 | 0.023270 | 13.562011 | -7.195385 | 0.709913 | 0.914852 | 2.163917 | 1.748443 | 47.498653 | 33.866864 | 44.567279 | 44.218650 | 2.545722 | 1.769000 | 4.006365 | -3.611927 | -0.159675 | -0.004396 | 2.082738 | 1.748414 | 0.113320 | 0.009842 | 0.063436 | 0.061249 | 0.030280 | 0.026340 | 7.710618 | -4.132363 | 0.089942 | 0.131581 | 2.538530 | 1.995444 | 504.0 | 1.0 | 0.000000 | 19 | 2.269429 | 0.026505 | 1.000000 | 1.0 | 409.0 | 5.934558 | 22 | 2.815751 | 0.021277 | 0.001984 | 3 | 4914 | 8.151037 | 15 | 2.520178 | 0.063830 | 0.000158 | 12845.219355 | 12622.845286 | 8920.376358 | 8473.354291 | 9048.727516 | 6463.659835 | 4756.730769 | 4136.174450 | 3471.263314 |
4 | 111080 | 110.0 | 5 | 1.0 | 0.0 | 0.0 | 0.0 | 6977 | 2012 | 1 | 3 | 1 | 3 | 13 | 6 | 68 | 5.0 | 44.383511 | 2.031433 | 0.572169 | -1.571239 | 2.246088 | 0.228036 | 0.073205 | 0.091880 | 0.078819 | 0.121534 | -1.896240 | 0.910783 | 0.931110 | 2.834518 | 1.923482 | 44 | 3450 | 5695 | 713.0 | 6234 | 6 | 44.959813 | 31.204724 | 41.477199 | 41.596730 | 1.881916 | 1.212541 | 7.058858 | -3.228070 | 1.819648 | 1.574585 | 1.837949 | 1.556478 | 0.107752 | 0.000000 | 0.035158 | 0.037542 | 0.025606 | 0.021604 | 6.219816 | -7.326931 | -3.054149 | -2.526677 | 2.115925 | 1.770558 | 47.778230 | 30.829302 | 43.597325 | 43.353802 | 1.916984 | 1.307947 | 8.205395 | -5.001960 | 0.270870 | 0.294369 | 1.799809 | 1.443789 | 0.148567 | 0.0 | 0.053159 | 0.053869 | 0.025691 | 0.020953 | 11.369898 | -7.326931 | -1.042476 | -0.908179 | 2.047795 | 1.585105 | 47.293305 | 44.383511 | 45.893775 | 45.910207 | 1.072447 | 0.870473 | 0.349628 | -3.807374 | -2.217431 | -1.990793 | 1.367304 | 0.919992 | 0.119831 | 0.038455 | 0.096406 | 0.088482 | 0.027794 | 0.019896 | 4.434116 | -0.495203 | 2.236023 | 2.066459 | 1.752542 | 1.394655 | 660.0 | 2.0 | 0.264708 | 7 | 1.563337 | 0.105871 | 0.925666 | 1.0 | 579.0 | 6.284667 | 5 | 1.560710 | 0.166667 | 0.001403 | 1 | 3051 | 7.795699 | 4 | 1.242453 | 0.166667 | 0.000160 | 1573.425581 | 1536.318881 | 1246.169984 | 3297.186928 | 3349.492435 | 2318.576826 | 14166.333333 | 10302.790900 | 7822.444444 |
from sklearn.model_selection import KFold
from lightgbm.sklearn import LGBMRegressor
from lightgbm import early_stopping, log_evaluation
from sklearn.metrics import mean_squared_error, mean_absolute_error
import time

# Out-of-fold predictions for the training rows, averaged predictions for
# the test rows, and feature importances averaged over the folds.
oof = np.zeros(train_df.shape[0])
sub['price'] = 0.0
feat_imp_df = pd.DataFrame({
    'feat': cols, 'imp': 0.0})
skf = KFold(n_splits=5, shuffle=True, random_state=2020)
### The parameters are worth re-tuning; the hosted environment was slow, so
### this is only a quick run.
clf = LGBMRegressor(
    learning_rate=0.1,
    n_estimators=1000,
    num_leaves=31,
    subsample=0.8,
    colsample_bytree=0.8,
    random_state=2020,
    metric=None
)
for i, (trn_idx, val_idx) in enumerate(skf.split(train_df, labels)):
    print('--------------------- {} fold ---------------------'.format(i))
    t = time.time()
    trn_x, trn_y = train_df.iloc[trn_idx].reset_index(drop=True), labels[trn_idx]
    val_x, val_y = train_df.iloc[val_idx].reset_index(drop=True), labels[val_idx]
    # Early stopping and periodic logging are passed as callbacks: the
    # early_stopping_rounds / verbose arguments of fit() are no longer
    # accepted by current LightGBM releases.
    clf.fit(
        trn_x, trn_y,
        eval_set=[(val_x, val_y)],
        categorical_feature=cate_cols,
        eval_metric='mae',
        callbacks=[
            early_stopping(stopping_rounds=100),
            log_evaluation(period=200),
        ]
    )
    feat_imp_df['imp'] += clf.feature_importances_ / skf.n_splits
    oof[val_idx] = clf.predict(val_x)
    sub['price'] += clf.predict(test_df) / skf.n_splits
    print('val mse:', mean_squared_error(val_y, oof[val_idx]))
    print('runtime: {}\n'.format(time.time() - t))
# Cross-validated scores over all out-of-fold predictions.
mae = mean_absolute_error(labels, oof)
mse = mean_squared_error(labels, oof)
print('cv mae:', mae)
print('cv mse:', mse)
print('sub mean:', sub['price'].mean())
# sub.to_csv('sub_{}_{}_{}.csv'.format(mae, mse, sub['price'].mean()), index=False)
--------------------- 0 fold ---------------------
Training until validation scores don't improve for 100 rounds.
[200] valid_0's l1: 677.099
[400] valid_0's l1: 656.464
[600] valid_0's l1: 641.186
[800] valid_0's l1: 630.216
[1000] valid_0's l1: 621.368
Did not meet early stopping. Best iteration is:
[1000] valid_0's l1: 621.368
val mse: 1925185.46164
runtime: 101.27237272262573
--------------------- 1 fold ---------------------
Training until validation scores don't improve for 100 rounds.
[200] valid_0's l1: 671.094
[400] valid_0's l1: 645.48
[600] valid_0's l1: 631.551
[800] valid_0's l1: 621.686
[1000] valid_0's l1: 614.845
Did not meet early stopping. Best iteration is:
[1000] valid_0's l1: 614.845
val mse: 1568075.45505
runtime: 102.32736563682556
--------------------- 2 fold ---------------------
Training until validation scores don't improve for 100 rounds.
[200] valid_0's l1: 678.367
[400] valid_0's l1: 654.512
[600] valid_0's l1: 644.471
[800] valid_0's l1: 631.984
[1000] valid_0's l1: 623.846
Did not meet early stopping. Best iteration is:
[1000] valid_0's l1: 623.846
val mse: 1939169.62885
runtime: 96.74052476882935
--------------------- 3 fold ---------------------
Training until validation scores don't improve for 100 rounds.
[200] valid_0's l1: 681.855
[400] valid_0's l1: 657.384
[600] valid_0's l1: 641.93
[800] valid_0's l1: 629.722
[1000] valid_0's l1: 622.407
Did not meet early stopping. Best iteration is:
[1000] valid_0's l1: 622.407
val mse: 1761196.44532
runtime: 101.86149859428406
--------------------- 4 fold ---------------------
Training until validation scores don't improve for 100 rounds.
[200] valid_0's l1: 674.189
[400] valid_0's l1: 649.298
[600] valid_0's l1: 633.508
[800] valid_0's l1: 627.868
[1000] valid_0's l1: 619.026
Did not meet early stopping. Best iteration is:
[1000] valid_0's l1: 619.026
val mse: 1545660.45939
runtime: 102.83581805229187
cv mae: 620.3181929
cv mse: 1747857.49005
sub mean: 5929.64175126
# Compare the density of the true prices, the out-of-fold predictions and
# the test predictions. sns.kdeplot replaces the deprecated
# sns.distplot(..., hist=False), which draws the same kind of curve.
sns.kdeplot(labels, label='train', color='y')
sns.kdeplot(oof, label='oof', color='g')
sns.kdeplot(sub['price'], label='test', color='r')
# Make the three labels visible regardless of seaborn version.
plt.legend()
![output_38_1](output_38_1.png)
# Order the features by averaged importance, then draw one horizontal bar
# per feature on a tall figure.
ranked_imp = feat_imp_df.sort_values(by='imp')
feat_imp_df = ranked_imp.reset_index(drop=True)
plt.figure(figsize=(15, 30))
sns.barplot(data=feat_imp_df, x='imp', y='feat')
![output_39_1](output_39_1.png)