The dataset records information about a subset of the Titanic's passengers, including whether each survived. The Titanic, as everyone knows, sank in a disaster, which made passenger information hard to collect, so the dataset contains some missing values. It can be downloaded from Kaggle, or by clicking here.
pclass | survived | name | sex | age | sibsp | parch | ticket | fare | cabin | embarked | boat | body | home.dest | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 1 | Allen, Miss. Elisabeth Walton | female | 29.0000 | 0 | 0 | 24160 | 211.3375 | B5 | S | 2 | NaN | St Louis, MO |
1 | 1 | 1 | Allison, Master. Hudson Trevor | male | 0.9167 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | 11 | NaN | Montreal, PQ / Chesterville, ON |
2 | 1 | 0 | Allison, Miss. Helen Loraine | female | 2.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
3 | 1 | 0 | Allison, Mr. Hudson Joshua Creighton | male | 30.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | 135.0 | Montreal, PQ / Chesterville, ON |
4 | 1 | 0 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | female | 25.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1304 | 3 | 0 | Zabour, Miss. Hileni | female | 14.5000 | 1 | 0 | 2665 | 14.4542 | NaN | C | NaN | 328.0 | NaN |
1305 | 3 | 0 | Zabour, Miss. Thamine | female | NaN | 1 | 0 | 2665 | 14.4542 | NaN | C | NaN | NaN | NaN |
1306 | 3 | 0 | Zakarian, Mr. Mapriededer | male | 26.5000 | 0 | 0 | 2656 | 7.2250 | NaN | C | NaN | 304.0 | NaN |
1307 | 3 | 0 | Zakarian, Mr. Ortin | male | 27.0000 | 0 | 0 | 2670 | 7.2250 | NaN | C | NaN | NaN | NaN |
1308 | 3 | 0 | Zimmerman, Mr. Leo | male | 29.0000 | 0 | 0 | 315082 | 7.8750 | NaN | S | NaN | NaN | NaN |
1309 rows × 14 columns
Titanic field descriptions
Field | Description | Values |
---|---|---|
pclass | Passenger class | 1 = first class, 2 = second class, 3 = third class |
survived | Survived | 0 = no, 1 = yes |
name | Name | None |
sex | Sex | female, male |
age | Age | None |
sibsp | Siblings / spouses aboard | number of siblings or spouses on the same ship |
parch | Parents / children aboard | number of parents or children on the same ship |
ticket | Ticket number | None |
fare | Ticket fare | None |
cabin | Cabin number | None |
embarked | Port of embarkation | C=Cherbourg, Q=Queenstown, S=Southampton |
home.dest | Home / destination | home, destination |
The preprocessing follows the approach used for multiple linear regression, with the final output layer handling the classification. For the theory behind multiple linear regression and logistic regression, kindly see here
and there. Since this is a binary classification problem within logistic regression, the Sigmoid function is used as the activation function; the "there" link above explains the Sigmoid function.
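For quick reference, here is a minimal sketch of the Sigmoid function (assuming NumPy; illustrative only, not taken from the linked posts):
import numpy as np

def sigmoid(z):
    # squashes any real number into (0, 1), so the output can be
    # read directly as a survival probability
    return 1.0 / (1.0 + np.exp(-z))

print(sigmoid(0.0))  # 0.5, the decision boundary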
Create a new xls file under the data folder with the following contents
int | char | float | |
---|---|---|---|
0 | 1 | NaN | 1.1 |
1 | 2 | NaN | 2.2 |
2 | 3 | NaN | 3.3 |
3 | 4 | d | NaN |
4 | 5 | e | NaN |
5 | 6 | NaN | 6.6 |
6 | 7 | NaN | 7.7 |
import pandas as pd
demo_data = pd.read_excel('./data/demo.xls')
demo_data
int | char | float | |
---|---|---|---|
0 | 1 | NaN | 1.1 |
1 | 2 | NaN | 2.2 |
2 | 3 | NaN | 3.3 |
3 | 4 | d | NaN |
4 | 5 | e | NaN |
5 | 6 | NaN | 6.6 |
6 | 7 | NaN | 7.7 |
isnull() returns a DataFrame of boolean values
isnull().any() reports whether each column contains any nulls
isnull().sum() counts the nulls in each column
demo_data.isnull()
int | char | float | |
---|---|---|---|
0 | False | True | False |
1 | False | True | False |
2 | False | True | False |
3 | False | False | True |
4 | False | False | True |
5 | False | True | False |
6 | False | True | False |
demo_data.isnull().any()
int False
char True
float True
dtype: bool
demo_data.isnull().sum()
int 0
char 5
float 2
dtype: int64
demo_data['char'] = demo_data['char'].fillna('A')
demo_data['float'] = demo_data['float'].fillna(5.5)
demo_data
int | char | float | |
---|---|---|---|
0 | 1 | A | 1.1 |
1 | 2 | A | 2.2 |
2 | 3 | A | 3.3 |
3 | 4 | d | 5.5 |
4 | 5 | e | 5.5 |
5 | 6 | A | 6.6 |
6 | 7 | A | 7.7 |
map takes a dict and replaces each matching key with its value. Note that the mapping must cover every distinct value; anything unmapped becomes NaN, and the following astype(int) then raises an error.
try:
    # 'e' is missing from the dict, so map() turns it into NaN and
    # astype(int) raises before the assignment happens
    demo_data['char'] = demo_data['char'].map({'A': 0, 'd': 1}).astype(int)
except ValueError:
    print('Error raised: Cannot convert non-finite values (NA or inf) to integer')
finally:
    # this mapping covers every value, so the conversion succeeds
    demo_data['char'] = demo_data['char'].map({'A': 0, 'd': 1, 'e': 2}).astype(int)
    print('Mapping finished')
Error raised: Cannot convert non-finite values (NA or inf) to integer
Mapping finished
demo_data
int | char | float | |
---|---|---|---|
0 | 1 | 0 | 1.1 |
1 | 2 | 0 | 2.2 |
2 | 3 | 0 | 3.3 |
3 | 4 | 1 | 5.5 |
4 | 5 | 2 | 5.5 |
5 | 6 | 0 | 6.6 |
6 | 7 | 0 | 7.7 |
sample draws rows from the original data and shuffles them; frac is the fraction of rows to draw, where frac=1 means 100%.
shuffle_data_1 = demo_data.sample(frac=1)
shuffle_data_1
int | char | float | |
---|---|---|---|
1 | 2 | 0 | 2.2 |
0 | 1 | 0 | 1.1 |
2 | 3 | 0 | 3.3 |
3 | 4 | 1 | 5.5 |
5 | 6 | 0 | 6.6 |
4 | 5 | 2 | 5.5 |
6 | 7 | 0 | 7.7 |
shuffle_data_2 = demo_data.sample(frac=5/7)
shuffle_data_2
int | char | float | |
---|---|---|---|
6 | 7 | 0 | 7.7 |
0 | 1 | 0 | 1.1 |
2 | 3 | 0 | 3.3 |
5 | 6 | 0 | 6.6 |
1 | 2 | 0 | 2.2 |
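Drawing frac = 5/7 from the 7 rows yields exactly round(7 × 5/7) = 5 rows, which we can confirm:
print(len(shuffle_data_2))  # 5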
drop() does not modify the original data; it returns a new DataFrame. Use axis to specify rows (0) or columns (1).
demo_data = demo_data.drop(['char'], axis=1)
demo_data
int | float | |
---|---|---|
0 | 1 | 1.1 |
1 | 2 | 2.2 |
2 | 3 | 3.3 |
3 | 4 | 5.5 |
4 | 5 | 5.5 |
5 | 6 | 6.6 |
6 | 7 | 7.7 |
After the processing above, every value in demo_data is numeric; use .values to convert the DataFrame into an ndarray.
nd_array = demo_data.values
print('nd_array:\n', nd_array,
'\nnd_array type:', type(nd_array))
nd_array:
 [[1.  1.1]
 [2.  2.2]
 [3.  3.3]
 [4.  5.5]
 [5.  5.5]
 [6.  6.6]
 [7.  7.7]]
 nd_array type: <class 'numpy.ndarray'>
import numpy
import pandas as pd
import tensorflow as tf
import urllib.request
from sklearn import preprocessing
import matplotlib.pyplot as plt
import os
import datetime
tf.__version__
'2.0.0'
data_url = 'http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic3.xls'
data_file = './data/titanic3.xls'
if not os.path.exists(data_file):
    print('downloading from %s' % data_url)
    operation = urllib.request.urlretrieve(data_url, data_file)
else:
    print('titanic3.xls already exists in the data directory')
titanic3.xls already exists in the data directory
The data summary shows that the count values differ across columns, which indicates missing entries.
dataframe = pd.read_excel(data_file)
dataframe.describe()
pclass | survived | age | sibsp | parch | fare | body | |
---|---|---|---|---|---|---|---|
count | 1309.000000 | 1309.000000 | 1046.000000 | 1309.000000 | 1309.000000 | 1308.000000 | 121.000000 |
mean | 2.294882 | 0.381971 | 29.881135 | 0.498854 | 0.385027 | 33.295479 | 160.809917 |
std | 0.837836 | 0.486055 | 14.413500 | 1.041658 | 0.865560 | 51.758668 | 97.696922 |
min | 1.000000 | 0.000000 | 0.166700 | 0.000000 | 0.000000 | 0.000000 | 1.000000 |
25% | 2.000000 | 0.000000 | 21.000000 | 0.000000 | 0.000000 | 7.895800 | 72.000000 |
50% | 3.000000 | 0.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 | 155.000000 |
75% | 3.000000 | 1.000000 | 39.000000 | 1.000000 | 0.000000 | 31.275000 | 256.000000 |
max | 3.000000 | 1.000000 | 80.000000 | 8.000000 | 9.000000 | 512.329200 | 328.000000 |
dataframe
pclass | survived | name | sex | age | sibsp | parch | ticket | fare | cabin | embarked | boat | body | home.dest | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 1 | Allen, Miss. Elisabeth Walton | female | 29.0000 | 0 | 0 | 24160 | 211.3375 | B5 | S | 2 | NaN | St Louis, MO |
1 | 1 | 1 | Allison, Master. Hudson Trevor | male | 0.9167 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | 11 | NaN | Montreal, PQ / Chesterville, ON |
2 | 1 | 0 | Allison, Miss. Helen Loraine | female | 2.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
3 | 1 | 0 | Allison, Mr. Hudson Joshua Creighton | male | 30.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | 135.0 | Montreal, PQ / Chesterville, ON |
4 | 1 | 0 | Allison, Mrs. Hudson J C (Bessie Waldo Daniels) | female | 25.0000 | 1 | 2 | 113781 | 151.5500 | C22 C26 | S | NaN | NaN | Montreal, PQ / Chesterville, ON |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1304 | 3 | 0 | Zabour, Miss. Hileni | female | 14.5000 | 1 | 0 | 2665 | 14.4542 | NaN | C | NaN | 328.0 | NaN |
1305 | 3 | 0 | Zabour, Miss. Thamine | female | NaN | 1 | 0 | 2665 | 14.4542 | NaN | C | NaN | NaN | NaN |
1306 | 3 | 0 | Zakarian, Mr. Mapriededer | male | 26.5000 | 0 | 0 | 2656 | 7.2250 | NaN | C | NaN | 304.0 | NaN |
1307 | 3 | 0 | Zakarian, Mr. Ortin | male | 27.0000 | 0 | 0 | 2670 | 7.2250 | NaN | C | NaN | NaN | NaN |
1308 | 3 | 0 | Zimmerman, Mr. Leo | male | 29.0000 | 0 | 0 | 315082 | 7.8750 | NaN | S | NaN | NaN | NaN |
1309 rows × 14 columns
Drop ticket and cabin; fill the missing age and fare values with their column means; map sex to 0/1; fill missing embarked values with 'S'; then convert the remaining string values to numbers.
Note: use .copy() here to avoid pandas' SettingWithCopyWarning ⚠
selected_dataframe = dataframe[selected_cols].copy()  # OK
selected_dataframe = dataframe[selected_cols]  # not recommended: later assignments may trigger the warning
selected_cols = ['survived', 'name', 'pclass', 'sex', 'age', 'sibsp', 'parch', 'fare', 'embarked']
selected_dataframe = dataframe[selected_cols].copy()
age_mean_value = selected_dataframe['age'].mean()
selected_dataframe['age'] = selected_dataframe['age'].fillna(age_mean_value)
fare_mean_value = selected_dataframe['fare'].mean()
selected_dataframe['fare'] = selected_dataframe['fare'].fillna(fare_mean_value)
selected_dataframe['embarked'] = selected_dataframe['embarked'].fillna('S')
selected_dataframe.describe()
survived | pclass | age | sibsp | parch | fare | |
---|---|---|---|---|---|---|
count | 1309.000000 | 1309.000000 | 1309.000000 | 1309.000000 | 1309.000000 | 1309.000000 |
mean | 0.381971 | 2.294882 | 29.881135 | 0.498854 | 0.385027 | 33.295479 |
std | 0.486055 | 0.837836 | 12.883199 | 1.041658 | 0.865560 | 51.738879 |
min | 0.000000 | 1.000000 | 0.166700 | 0.000000 | 0.000000 | 0.000000 |
25% | 0.000000 | 2.000000 | 22.000000 | 0.000000 | 0.000000 | 7.895800 |
50% | 0.000000 | 3.000000 | 29.881135 | 0.000000 | 0.000000 | 14.454200 |
75% | 1.000000 | 3.000000 | 35.000000 | 1.000000 | 0.000000 | 31.275000 |
max | 1.000000 | 3.000000 | 80.000000 | 8.000000 | 9.000000 | 512.329200 |
selected_dataframe['sex'] = selected_dataframe['sex'].map({'female': 0, 'male': 1}).astype(int)
selected_dataframe['embarked'] = selected_dataframe['embarked'].map({'C': 0, 'Q': 1, 'S': 2}).astype(int)
Drop the name column
selected_dataframe = selected_dataframe.drop(['name'], axis=1)
selected_dataframe[:3]
survived | pclass | sex | age | sibsp | parch | fare | embarked | |
---|---|---|---|---|---|---|---|---|
0 | 1 | 1 | 0 | 29.0000 | 0 | 0 | 211.3375 | 2 |
1 | 1 | 1 | 1 | 0.9167 | 1 | 2 | 151.5500 | 2 |
2 | 0 | 1 | 0 | 2.0000 | 1 | 2 | 151.5500 | 2 |
features holds the feature columns, column 1 through the last
label holds the label, column 0
ndarray_data = selected_dataframe.values
features = ndarray_data[:, 1:]
label = ndarray_data[:, 0]
print('features:\n', features,
'\nlabel:', label)
features:
[[ 1. 0. 29. ... 0. 211.3375 2. ]
[ 1. 1. 0.9167 ... 2. 151.55 2. ]
[ 1. 0. 2. ... 2. 151.55 2. ]
...
[ 3. 1. 26.5 ... 0. 7.225 0. ]
[ 3. 1. 27. ... 0. 7.225 0. ]
[ 3. 1. 29. ... 0. 7.875 2. ]]
label: [1. 1. 0. ... 0. 0. 0.]
minmax_scale = preprocessing.MinMaxScaler(feature_range=(0, 1))
norm_features = minmax_scale.fit_transform(features)
print('norm_features:\n', norm_features,
'\nlabel:', label)
norm_features:
[[0. 0. 0.36116884 ... 0. 0.41250333 1. ]
[0. 1. 0.00939458 ... 0.22222222 0.2958059 1. ]
[0. 0. 0.0229641 ... 0.22222222 0.2958059 1. ]
...
[1. 1. 0.32985358 ... 0. 0.01410226 0. ]
[1. 1. 0.33611663 ... 0. 0.01410226 0. ]
[1. 1. 0.36116884 ... 0. 0.01537098 1. ]]
label: [1. 1. 0. ... 0. 0. 0.]
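What MinMaxScaler did above: each column is rescaled to [0, 1] via x' = (x - min) / (max - min). A minimal manual equivalent, for intuition only:
col_min = features.min(axis=0)
col_max = features.max(axis=0)
manual_norm = (features - col_min) / (col_max - col_min)  # matches norm_features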
def prepare_data(df_data):
    # name carries no numeric signal for this model, drop it
    df = df_data.drop(['name'], axis=1)
    # fill missing ages and fares with the column means
    age_mean = df['age'].mean()
    df['age'] = df['age'].fillna(age_mean)
    fare_mean = df['fare'].mean()
    df['fare'] = df['fare'].fillna(fare_mean)
    # encode the categorical columns as integers
    df['sex'] = df['sex'].map({'female': 0, 'male': 1}).astype(int)
    df['embarked'] = df['embarked'].fillna('S')
    df['embarked'] = df['embarked'].map({'C': 0, 'Q': 1, 'S': 2}).astype(int)
    # column 0 is the label, the remaining columns are features
    ndarray_data = df.values
    features = ndarray_data[:, 1:]
    label = ndarray_data[:, 0]
    # scale every feature into [0, 1]
    minmax_scale = preprocessing.MinMaxScaler(feature_range=(0, 1))
    norm_features = minmax_scale.fit_transform(features)
    return norm_features, label
dataframe = pd.read_excel('./data/titanic3.xls')
selected_cols= ['survived', 'name', 'pclass', 'sex', 'age', 'sibsp', 'parch', 'fare', 'embarked']
selected_dataframe = dataframe[selected_cols].copy()
selected_dataframe = selected_dataframe.sample(frac=1)
x_data, y_data = prepare_data(selected_dataframe)
train_size = int(len(x_data) * 0.8)
x_train = x_data[:train_size]
y_train = y_data[:train_size]
x_test = x_data[train_size:]
y_test = y_data[train_size:]
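A quick sanity check on the 80/20 split of the 1309 rows:
print(len(x_train), len(x_test))  # 1047 262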
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(units=64,
input_dim=7,
use_bias=True,
kernel_initializer='uniform',
bias_initializer='zeros',
activation='relu'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=32, activation='sigmoid'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=1, activation='sigmoid')
])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 64) 512
_________________________________________________________________
dropout (Dropout) (None, 64) 0
_________________________________________________________________
dense_1 (Dense) (None, 32) 2080
_________________________________________________________________
dropout_1 (Dropout) (None, 32) 0
_________________________________________________________________
dense_2 (Dense) (None, 1) 33
=================================================================
Total params: 2,625
Trainable params: 2,625
Non-trainable params: 0
_________________________________________________________________
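The parameter counts check out: a Dense layer has inputs × units weights plus units biases.
# dense:   7 * 64 + 64  = 512
# dense_1: 64 * 32 + 32 = 2080
# dense_2: 32 * 1 + 1   = 33
# total                 = 2625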
model.compile(optimizer=tf.keras.optimizers.Adam(0.003),
loss='binary_crossentropy',
metrics=['accuracy'])
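For reference, binary cross-entropy on a single example is -(y·log(p) + (1 - y)·log(1 - p)) for label y ∈ {0, 1} and predicted probability p. A minimal sketch, assuming NumPy:
import numpy as np

def binary_crossentropy(y, p, eps=1e-7):
    p = np.clip(p, eps, 1 - eps)  # guard against log(0)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))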
train_history = model.fit(
x=x_train,
y=y_train,
validation_split=0.2,
epochs=100,
batch_size=40,
verbose=1)
Train on 837 samples, validate on 210 samples
Epoch 1/100
837/837 [==============================] - 2s 2ms/sample - loss: 0.6968 - accuracy: 0.5412 - val_loss: 0.5988 - val_accuracy: 0.6571
Epoch 2/100
837/837 [==============================] - 0s 147us/sample - loss: 0.6342 - accuracy: 0.6404 - val_loss: 0.5504 - val_accuracy: 0.7571
Epoch 3/100
837/837 [==============================] - 0s 153us/sample - loss: 0.5555 - accuracy: 0.7276 - val_loss: 0.4904 - val_accuracy: 0.8143
Epoch 4/100
837/837 [==============================] - 0s 148us/sample - loss: 0.5088 - accuracy: 0.7766 - val_loss: 0.4638 - val_accuracy: 0.8095
Epoch 5/100
837/837 [==============================] - 0s 137us/sample - loss: 0.4948 - accuracy: 0.7814 - val_loss: 0.4526 - val_accuracy: 0.8095
Epoch 6/100
837/837 [==============================] - 0s 124us/sample - loss: 0.4832 - accuracy: 0.7897 - val_loss: 0.4526 - val_accuracy: 0.8000
Epoch 7/100
837/837 [==============================] - 0s 140us/sample - loss: 0.4672 - accuracy: 0.7861 - val_loss: 0.4508 - val_accuracy: 0.8000
Epoch 8/100
837/837 [==============================] - 0s 142us/sample - loss: 0.4707 - accuracy: 0.7897 - val_loss: 0.4431 - val_accuracy: 0.8190
Epoch 9/100
837/837 [==============================] - 0s 117us/sample - loss: 0.4760 - accuracy: 0.8005 - val_loss: 0.4452 - val_accuracy: 0.8000
Epoch 10/100
837/837 [==============================] - 0s 121us/sample - loss: 0.4568 - accuracy: 0.8017 - val_loss: 0.4414 - val_accuracy: 0.7952
Epoch 11/100
837/837 [==============================] - 0s 132us/sample - loss: 0.4533 - accuracy: 0.8100 - val_loss: 0.4473 - val_accuracy: 0.7952
Epoch 12/100
837/837 [==============================] - 0s 122us/sample - loss: 0.4624 - accuracy: 0.7933 - val_loss: 0.4527 - val_accuracy: 0.8000
Epoch 13/100
837/837 [==============================] - 0s 119us/sample - loss: 0.4452 - accuracy: 0.8088 - val_loss: 0.4455 - val_accuracy: 0.8000
Epoch 14/100
837/837 [==============================] - 0s 122us/sample - loss: 0.4564 - accuracy: 0.7993 - val_loss: 0.4430 - val_accuracy: 0.7952
Epoch 15/100
837/837 [==============================] - 0s 161us/sample - loss: 0.4722 - accuracy: 0.8005 - val_loss: 0.4404 - val_accuracy: 0.8000
Epoch 16/100
837/837 [==============================] - 0s 141us/sample - loss: 0.4660 - accuracy: 0.8065 - val_loss: 0.4444 - val_accuracy: 0.8048
Epoch 17/100
837/837 [==============================] - 0s 124us/sample - loss: 0.4551 - accuracy: 0.8196 - val_loss: 0.4392 - val_accuracy: 0.8000
Epoch 18/100
837/837 [==============================] - 0s 118us/sample - loss: 0.4589 - accuracy: 0.8053 - val_loss: 0.4472 - val_accuracy: 0.8000
Epoch 19/100
837/837 [==============================] - 0s 116us/sample - loss: 0.4559 - accuracy: 0.8029 - val_loss: 0.4402 - val_accuracy: 0.8048
Epoch 20/100
837/837 [==============================] - 0s 122us/sample - loss: 0.4479 - accuracy: 0.8124 - val_loss: 0.4398 - val_accuracy: 0.7952
Epoch 21/100
837/837 [==============================] - 0s 122us/sample - loss: 0.4510 - accuracy: 0.8148 - val_loss: 0.4373 - val_accuracy: 0.8000
Epoch 22/100
837/837 [==============================] - 0s 126us/sample - loss: 0.4537 - accuracy: 0.8065 - val_loss: 0.4361 - val_accuracy: 0.8048
Epoch 23/100
837/837 [==============================] - 0s 119us/sample - loss: 0.4591 - accuracy: 0.8088 - val_loss: 0.4423 - val_accuracy: 0.7905
Epoch 24/100
837/837 [==============================] - 0s 119us/sample - loss: 0.4482 - accuracy: 0.8088 - val_loss: 0.4429 - val_accuracy: 0.7952
Epoch 25/100
837/837 [==============================] - 0s 121us/sample - loss: 0.4556 - accuracy: 0.8053 - val_loss: 0.4376 - val_accuracy: 0.8000
Epoch 26/100
837/837 [==============================] - 0s 128us/sample - loss: 0.4536 - accuracy: 0.8112 - val_loss: 0.4382 - val_accuracy: 0.8000
Epoch 27/100
837/837 [==============================] - 0s 119us/sample - loss: 0.4537 - accuracy: 0.7969 - val_loss: 0.4483 - val_accuracy: 0.8000
Epoch 28/100
837/837 [==============================] - 0s 120us/sample - loss: 0.4420 - accuracy: 0.8148 - val_loss: 0.4442 - val_accuracy: 0.7857
Epoch 29/100
837/837 [==============================] - 0s 125us/sample - loss: 0.4462 - accuracy: 0.8053 - val_loss: 0.4371 - val_accuracy: 0.7905
Epoch 30/100
837/837 [==============================] - 0s 125us/sample - loss: 0.4550 - accuracy: 0.8124 - val_loss: 0.4387 - val_accuracy: 0.7905
Epoch 31/100
837/837 [==============================] - 0s 139us/sample - loss: 0.4421 - accuracy: 0.8088 - val_loss: 0.4406 - val_accuracy: 0.7857
Epoch 32/100
837/837 [==============================] - 0s 122us/sample - loss: 0.4525 - accuracy: 0.8112 - val_loss: 0.4384 - val_accuracy: 0.7905
Epoch 33/100
837/837 [==============================] - 0s 126us/sample - loss: 0.4459 - accuracy: 0.8100 - val_loss: 0.4384 - val_accuracy: 0.7905
Epoch 34/100
837/837 [==============================] - 0s 133us/sample - loss: 0.4338 - accuracy: 0.8065 - val_loss: 0.4442 - val_accuracy: 0.7857
Epoch 35/100
837/837 [==============================] - 0s 143us/sample - loss: 0.4419 - accuracy: 0.8065 - val_loss: 0.4405 - val_accuracy: 0.7905
Epoch 36/100
837/837 [==============================] - 0s 137us/sample - loss: 0.4461 - accuracy: 0.8053 - val_loss: 0.4362 - val_accuracy: 0.7857
Epoch 37/100
837/837 [==============================] - 0s 118us/sample - loss: 0.4414 - accuracy: 0.8148 - val_loss: 0.4479 - val_accuracy: 0.7810
Epoch 38/100
837/837 [==============================] - 0s 120us/sample - loss: 0.4382 - accuracy: 0.8136 - val_loss: 0.4365 - val_accuracy: 0.7905
Epoch 39/100
837/837 [==============================] - 0s 125us/sample - loss: 0.4356 - accuracy: 0.8184 - val_loss: 0.4488 - val_accuracy: 0.7810
Epoch 40/100
837/837 [==============================] - 0s 143us/sample - loss: 0.4383 - accuracy: 0.8184 - val_loss: 0.4375 - val_accuracy: 0.7905
Epoch 41/100
837/837 [==============================] - 0s 150us/sample - loss: 0.4347 - accuracy: 0.8124 - val_loss: 0.4451 - val_accuracy: 0.7810
Epoch 42/100
837/837 [==============================] - 0s 143us/sample - loss: 0.4573 - accuracy: 0.8112 - val_loss: 0.4411 - val_accuracy: 0.7857
Epoch 43/100
837/837 [==============================] - 0s 119us/sample - loss: 0.4358 - accuracy: 0.8136 - val_loss: 0.4379 - val_accuracy: 0.7905
Epoch 44/100
837/837 [==============================] - 0s 164us/sample - loss: 0.4453 - accuracy: 0.8160 - val_loss: 0.4458 - val_accuracy: 0.7810
Epoch 45/100
837/837 [==============================] - 0s 126us/sample - loss: 0.4401 - accuracy: 0.8076 - val_loss: 0.4405 - val_accuracy: 0.7952
Epoch 46/100
837/837 [==============================] - 0s 142us/sample - loss: 0.4364 - accuracy: 0.8160 - val_loss: 0.4465 - val_accuracy: 0.7810
Epoch 47/100
837/837 [==============================] - 0s 129us/sample - loss: 0.4311 - accuracy: 0.8184 - val_loss: 0.4386 - val_accuracy: 0.8000
Epoch 48/100
837/837 [==============================] - 0s 117us/sample - loss: 0.4377 - accuracy: 0.8160 - val_loss: 0.4444 - val_accuracy: 0.7857
Epoch 49/100
837/837 [==============================] - 0s 133us/sample - loss: 0.4546 - accuracy: 0.7957 - val_loss: 0.4432 - val_accuracy: 0.7810
Epoch 50/100
837/837 [==============================] - 0s 145us/sample - loss: 0.4403 - accuracy: 0.8208 - val_loss: 0.4436 - val_accuracy: 0.7905
Epoch 51/100
837/837 [==============================] - 0s 144us/sample - loss: 0.4259 - accuracy: 0.8148 - val_loss: 0.4374 - val_accuracy: 0.7952
Epoch 52/100
837/837 [==============================] - 0s 155us/sample - loss: 0.4300 - accuracy: 0.8160 - val_loss: 0.4411 - val_accuracy: 0.7857
Epoch 53/100
837/837 [==============================] - 0s 160us/sample - loss: 0.4381 - accuracy: 0.8136 - val_loss: 0.4432 - val_accuracy: 0.7905
Epoch 54/100
837/837 [==============================] - 0s 136us/sample - loss: 0.4290 - accuracy: 0.8256 - val_loss: 0.4414 - val_accuracy: 0.7810
Epoch 55/100
837/837 [==============================] - 0s 148us/sample - loss: 0.4360 - accuracy: 0.8160 - val_loss: 0.4385 - val_accuracy: 0.7952
Epoch 56/100
837/837 [==============================] - 0s 114us/sample - loss: 0.4364 - accuracy: 0.8232 - val_loss: 0.4415 - val_accuracy: 0.7810
Epoch 57/100
837/837 [==============================] - 0s 127us/sample - loss: 0.4364 - accuracy: 0.8076 - val_loss: 0.4397 - val_accuracy: 0.7952
Epoch 58/100
837/837 [==============================] - 0s 122us/sample - loss: 0.4370 - accuracy: 0.8148 - val_loss: 0.4378 - val_accuracy: 0.7905
Epoch 59/100
837/837 [==============================] - 0s 139us/sample - loss: 0.4435 - accuracy: 0.8088 - val_loss: 0.4444 - val_accuracy: 0.7810
Epoch 60/100
837/837 [==============================] - 0s 122us/sample - loss: 0.4354 - accuracy: 0.8172 - val_loss: 0.4372 - val_accuracy: 0.7952
Epoch 61/100
837/837 [==============================] - 0s 137us/sample - loss: 0.4375 - accuracy: 0.8112 - val_loss: 0.4420 - val_accuracy: 0.7857
Epoch 62/100
837/837 [==============================] - 0s 135us/sample - loss: 0.4307 - accuracy: 0.8136 - val_loss: 0.4392 - val_accuracy: 0.7905
Epoch 63/100
837/837 [==============================] - 0s 158us/sample - loss: 0.4362 - accuracy: 0.8160 - val_loss: 0.4406 - val_accuracy: 0.7952
Epoch 64/100
837/837 [==============================] - 0s 124us/sample - loss: 0.4405 - accuracy: 0.8124 - val_loss: 0.4440 - val_accuracy: 0.7810
Epoch 65/100
837/837 [==============================] - 0s 119us/sample - loss: 0.4298 - accuracy: 0.8232 - val_loss: 0.4386 - val_accuracy: 0.8000
Epoch 66/100
837/837 [==============================] - 0s 108us/sample - loss: 0.4284 - accuracy: 0.8065 - val_loss: 0.4446 - val_accuracy: 0.7905
Epoch 67/100
837/837 [==============================] - 0s 129us/sample - loss: 0.4340 - accuracy: 0.8148 - val_loss: 0.4438 - val_accuracy: 0.7905
Epoch 68/100
837/837 [==============================] - 0s 114us/sample - loss: 0.4389 - accuracy: 0.8088 - val_loss: 0.4381 - val_accuracy: 0.8000
Epoch 69/100
837/837 [==============================] - 0s 124us/sample - loss: 0.4338 - accuracy: 0.8220 - val_loss: 0.4386 - val_accuracy: 0.8000
Epoch 70/100
837/837 [==============================] - 0s 115us/sample - loss: 0.4323 - accuracy: 0.8184 - val_loss: 0.4423 - val_accuracy: 0.7952
Epoch 71/100
837/837 [==============================] - 0s 122us/sample - loss: 0.4237 - accuracy: 0.8232 - val_loss: 0.4406 - val_accuracy: 0.7905
Epoch 72/100
837/837 [==============================] - 0s 137us/sample - loss: 0.4355 - accuracy: 0.8208 - val_loss: 0.4437 - val_accuracy: 0.7905
Epoch 73/100
837/837 [==============================] - 0s 116us/sample - loss: 0.4362 - accuracy: 0.8196 - val_loss: 0.4364 - val_accuracy: 0.7905
Epoch 74/100
837/837 [==============================] - 0s 128us/sample - loss: 0.4293 - accuracy: 0.8196 - val_loss: 0.4445 - val_accuracy: 0.7810
Epoch 75/100
837/837 [==============================] - 0s 124us/sample - loss: 0.4252 - accuracy: 0.8184 - val_loss: 0.4400 - val_accuracy: 0.7905
Epoch 76/100
837/837 [==============================] - 0s 119us/sample - loss: 0.4335 - accuracy: 0.8256 - val_loss: 0.4470 - val_accuracy: 0.7810
Epoch 77/100
837/837 [==============================] - 0s 119us/sample - loss: 0.4284 - accuracy: 0.8184 - val_loss: 0.4384 - val_accuracy: 0.8000
Epoch 78/100
837/837 [==============================] - 0s 147us/sample - loss: 0.4398 - accuracy: 0.8136 - val_loss: 0.4412 - val_accuracy: 0.7905
Epoch 79/100
837/837 [==============================] - 0s 129us/sample - loss: 0.4339 - accuracy: 0.8160 - val_loss: 0.4454 - val_accuracy: 0.7810
Epoch 80/100
837/837 [==============================] - 0s 127us/sample - loss: 0.4286 - accuracy: 0.8160 - val_loss: 0.4397 - val_accuracy: 0.7905
Epoch 81/100
837/837 [==============================] - 0s 120us/sample - loss: 0.4315 - accuracy: 0.8220 - val_loss: 0.4393 - val_accuracy: 0.7905
Epoch 82/100
837/837 [==============================] - 0s 138us/sample - loss: 0.4263 - accuracy: 0.8184 - val_loss: 0.4415 - val_accuracy: 0.7905
Epoch 83/100
837/837 [==============================] - 0s 136us/sample - loss: 0.4298 - accuracy: 0.8208 - val_loss: 0.4405 - val_accuracy: 0.8048
Epoch 84/100
837/837 [==============================] - 0s 129us/sample - loss: 0.4341 - accuracy: 0.8112 - val_loss: 0.4377 - val_accuracy: 0.7952
Epoch 85/100
837/837 [==============================] - 0s 127us/sample - loss: 0.4325 - accuracy: 0.8100 - val_loss: 0.4432 - val_accuracy: 0.8000
Epoch 86/100
837/837 [==============================] - 0s 132us/sample - loss: 0.4277 - accuracy: 0.8124 - val_loss: 0.4415 - val_accuracy: 0.7857
Epoch 87/100
837/837 [==============================] - 0s 113us/sample - loss: 0.4274 - accuracy: 0.8196 - val_loss: 0.4427 - val_accuracy: 0.7905
Epoch 88/100
837/837 [==============================] - 0s 112us/sample - loss: 0.4243 - accuracy: 0.8280 - val_loss: 0.4400 - val_accuracy: 0.7905
Epoch 89/100
837/837 [==============================] - 0s 118us/sample - loss: 0.4280 - accuracy: 0.8220 - val_loss: 0.4418 - val_accuracy: 0.7952
Epoch 90/100
837/837 [==============================] - 0s 112us/sample - loss: 0.4340 - accuracy: 0.8208 - val_loss: 0.4409 - val_accuracy: 0.8000
Epoch 91/100
837/837 [==============================] - 0s 121us/sample - loss: 0.4298 - accuracy: 0.8136 - val_loss: 0.4403 - val_accuracy: 0.8000
Epoch 92/100
837/837 [==============================] - 0s 122us/sample - loss: 0.4275 - accuracy: 0.8208 - val_loss: 0.4409 - val_accuracy: 0.7952
Epoch 93/100
837/837 [==============================] - 0s 136us/sample - loss: 0.4228 - accuracy: 0.8244 - val_loss: 0.4394 - val_accuracy: 0.8095
Epoch 94/100
837/837 [==============================] - 0s 119us/sample - loss: 0.4313 - accuracy: 0.8208 - val_loss: 0.4434 - val_accuracy: 0.8000
Epoch 95/100
837/837 [==============================] - 0s 130us/sample - loss: 0.4277 - accuracy: 0.8196 - val_loss: 0.4365 - val_accuracy: 0.8095
Epoch 96/100
837/837 [==============================] - 0s 118us/sample - loss: 0.4273 - accuracy: 0.8220 - val_loss: 0.4383 - val_accuracy: 0.8000
Epoch 97/100
837/837 [==============================] - 0s 113us/sample - loss: 0.4311 - accuracy: 0.8124 - val_loss: 0.4373 - val_accuracy: 0.8095
Epoch 98/100
837/837 [==============================] - 0s 123us/sample - loss: 0.4221 - accuracy: 0.8327 - val_loss: 0.4419 - val_accuracy: 0.7952
Epoch 99/100
837/837 [==============================] - 0s 124us/sample - loss: 0.4378 - accuracy: 0.8196 - val_loss: 0.4380 - val_accuracy: 0.8048
Epoch 100/100
837/837 [==============================] - 0s 122us/sample - loss: 0.4238 - accuracy: 0.8232 - val_loss: 0.4451 - val_accuracy: 0.7857
fig = plt.gcf()
fig.set_size_inches(10, 5)
ax1 = fig.add_subplot(111)
ax1.set_title('Train and Validation Picture')
ax1.set_ylabel('Loss value')
line1, = ax1.plot(train_history.history['loss'], color=(0.5, 0.5, 1.0), label='Loss train')
line2, = ax1.plot(train_history.history['val_loss'], color=(0.5, 1.0, 0.5), label='Loss valid')
ax2 = ax1.twinx()
ax2.set_ylabel('Accuracy value')
line3, = ax2.plot(train_history.history['accuracy'], color=(0.5, 0.5, 0.5), label='Accuracy train')
line4, = ax2.plot(train_history.history['val_accuracy'], color=(1, 0, 0), label='Accuracy valid')
plt.legend(handles=(line1, line2, line3, line4), loc='best')
plt.show()
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print('test_loss:', test_loss,
'\ntest_acc:', test_acc,
'\nmetrics_names:', model.metrics_names)
262/1 - 0s - loss: 0.3581 - accuracy: 0.7672
test_loss: 0.48995060536242624
test_acc: 0.76717556
metrics_names: ['loss', 'accuracy']
Jack_info = [0, 'Jack', 3, 'male', 23, 1, 0, 5.0000, 'S']
Rose_info = [1, 'Rose', 1, 'female', 20, 1, 0, 100.0000, 'S']
new_passenger_pd = pd.DataFrame([Jack_info, Rose_info], columns=selected_cols)
all_passenger_pd = selected_dataframe.append(new_passenger_pd)
pred = model.predict(prepare_data(all_passenger_pd)[0])
print('Rose survived probability:', pred[-1:][0][0],
'\nJack survived probability:', pred[-2:][0][0])
Rose survived probability: 0.96711206
Jack survived probability: 0.12514974
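A note on the indexing: pred has shape (n_passengers, 1), so pred[-1:][0][0] slices the last row and then unwraps it twice. Plain integer indexing is equivalent:
rose_prob = pred[-1, 0]  # same as pred[-1:][0][0]
jack_prob = pred[-2, 0]  # same as pred[-2:][0][0]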
all_passenger_pd.insert(len(all_passenger_pd.columns), 'surv_prob', pred)
all_passenger_pd
survived | name | pclass | sex | age | sibsp | parch | fare | embarked | surv_prob | |
---|---|---|---|---|---|---|---|---|---|---|
75 | 0 | Colley, Mr. Edward Pomeroy | 1 | male | 47.0 | 0 | 0 | 25.5875 | S | 0.221973 |
321 | 0 | Wright, Mr. George | 1 | male | 62.0 | 0 | 0 | 26.5500 | S | 0.194789 |
712 | 0 | Celotti, Mr. Francesco | 3 | male | 24.0 | 0 | 0 | 8.0500 | S | 0.130646 |
345 | 0 | Berriman, Mr. William John | 2 | male | 23.0 | 0 | 0 | 13.0000 | S | 0.211936 |
1298 | 0 | Wittevrongel, Mr. Camille | 3 | male | 36.0 | 0 | 0 | 9.5000 | S | 0.106074 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
10 | 0 | Astor, Col. John Jacob | 1 | male | 47.0 | 1 | 0 | 227.5250 | C | 0.213140 |
434 | 1 | Hart, Miss. Eva Miriam | 2 | female | 7.0 | 0 | 2 | 26.2500 | S | 0.881991 |
690 | 0 | Brobeck, Mr. Karl Rudolf | 3 | male | 22.0 | 0 | 0 | 7.7958 | S | 0.136104 |
0 | 0 | Jack | 3 | male | 23.0 | 1 | 0 | 5.0000 | S | 0.125150 |
1 | 1 | Rose | 1 | female | 20.0 | 1 | 0 | 100.0000 | S | 0.967112 |
1311 rows × 10 columns
form = pd.DataFrame(columns=[column for column in all_passenger_pd], data=all_passenger_pd)
form.to_excel('./data/result.xls', encoding='utf-8', index=None, header=True)
def prepare_data(df_data):
df = df_data.drop(['name'], axis=1)
age_mean = df['age'].mean()
df['age'] = df['age'].fillna(age_mean)
fare_mean = df['fare'].mean()
df['fare'] = df['fare'].fillna(fare_mean)
df['sex'] = df['sex'].map({'female':0, 'male':1}).astype(int)
df['embarked'] = df['embarked'].fillna('S')
df['embarked'] = df['embarked'].map({'C':0, 'Q':1, 'S':2}).astype(int)
ndarray_data = df.values
features = ndarray_data[:, 1:]
label = ndarray_data[:, 0]
minmax_scale = preprocessing.MinMaxScaler(feature_range=(0, 1))
norm_features = minmax_scale.fit_transform(features)
return norm_features, label
dataframe = pd.read_excel('./data/titanic3.xls')
selected_cols= ['survived', 'name', 'pclass', 'sex', 'age', 'sibsp', 'parch', 'fare', 'embarked']
selected_dataframe = dataframe[selected_cols].copy()
selected_dataframe = selected_dataframe.sample(frac=1)
x_data, y_data = prepare_data(selected_dataframe)
train_size = int(len(x_data) * 0.8)
x_train = x_data[:train_size]
y_train = y_data[:train_size]
x_test = x_data[train_size:]
y_test = y_data[train_size:]
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(units=256,
input_dim=7,
use_bias=True,
kernel_initializer='uniform',
bias_initializer='zeros',
activation='relu'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=128, activation='sigmoid'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=64, activation='sigmoid'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=32, activation='sigmoid'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=1, activation='sigmoid')
])
model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_3 (Dense) (None, 256) 2048
_________________________________________________________________
dropout_2 (Dropout) (None, 256) 0
_________________________________________________________________
dense_4 (Dense) (None, 128) 32896
_________________________________________________________________
dropout_3 (Dropout) (None, 128) 0
_________________________________________________________________
dense_5 (Dense) (None, 64) 8256
_________________________________________________________________
dropout_4 (Dropout) (None, 64) 0
_________________________________________________________________
dense_6 (Dense) (None, 32) 2080
_________________________________________________________________
dropout_5 (Dropout) (None, 32) 0
_________________________________________________________________
dense_7 (Dense) (None, 1) 33
=================================================================
Total params: 45,313
Trainable params: 45,313
Non-trainable params: 0
_________________________________________________________________
model.compile(optimizer=tf.keras.optimizers.Adam(0.003),
loss='binary_crossentropy',
metrics=['accuracy'])
log_dir = os.path.join(
'logs2.x',
'train',
'plugins',
'profile',
datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
checkpoint_path = './checkpoint2.x/Titanic.{epoch:02d}.h5'
if not os.path.exists('./checkpoint2.x'):
os.mkdir('./checkpoint2.x')
callbacks = [tf.keras.callbacks.TensorBoard(log_dir=log_dir,
histogram_freq=2),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1,
period=5)]
WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.
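As the warning says, `period` is deprecated in favor of save_freq, which in this TF version counts samples seen rather than epochs. A hedged equivalent for this particular run (837 training samples per epoch is specific to this split):
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                   save_weights_only=True,
                                   verbose=1,
                                   save_freq=5 * 837)  # every 5 epochs' worth of samples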
train_history = model.fit(x=x_train, y=y_train,
validation_split=0.2,
epochs=100,
batch_size=40,
callbacks=callbacks,
verbose=1)
Partial training output:
Train on 837 samples, validate on 210 samples
Epoch 80/100
760/837 [==========================>...] - ETA: 0s - loss: 0.4277 - accuracy: 0.8132
Epoch 00080: saving model to ./checkpoint2.x/Titanic.80.h5
837/837 [==============================] - 0s 302us/sample - loss: 0.4353 - accuracy: 0.8112 - val_loss: 0.4639 - val_accuracy: 0.7810
Epoch 81/100
837/837 [==============================] - 0s 270us/sample - loss: 0.4455 - accuracy: 0.8017 - val_loss: 0.4768 - val_accuracy: 0.7810
Epoch 82/100
837/837 [==============================] - 0s 211us/sample - loss: 0.4376 - accuracy: 0.7993 - val_loss: 0.4654 - val_accuracy: 0.7905
Epoch 83/100
837/837 [==============================] - 0s 278us/sample - loss: 0.4377 - accuracy: 0.8065 - val_loss: 0.4703 - val_accuracy: 0.7810
Epoch 84/100
837/837 [==============================] - 0s 232us/sample - loss: 0.4368 - accuracy: 0.8160 - val_loss: 0.4631 - val_accuracy: 0.7952
Epoch 85/100
360/837 [===========>..................] - ETA: 0s - loss: 0.4669 - accuracy: 0.8056
Epoch 00085: saving model to ./checkpoint2.x/Titanic.85.h5
837/837 [==============================] - 0s 292us/sample - loss: 0.4437 - accuracy: 0.8124 - val_loss: 0.4627 - val_accuracy: 0.7810
Epoch 86/100
837/837 [==============================] - 0s 197us/sample - loss: 0.4365 - accuracy: 0.8017 - val_loss: 0.4686 - val_accuracy: 0.7905
Epoch 87/100
837/837 [==============================] - 0s 288us/sample - loss: 0.4500 - accuracy: 0.8148 - val_loss: 0.4689 - val_accuracy: 0.7857
Epoch 88/100
837/837 [==============================] - 0s 208us/sample - loss: 0.4356 - accuracy: 0.8029 - val_loss: 0.4794 - val_accuracy: 0.7905
Epoch 89/100
837/837 [==============================] - 0s 239us/sample - loss: 0.4283 - accuracy: 0.8148 - val_loss: 0.4621 - val_accuracy: 0.7857
Epoch 90/100
440/837 [==============>...............] - ETA: 0s - loss: 0.4083 - accuracy: 0.8295
Epoch 00090: saving model to ./checkpoint2.x/Titanic.90.h5
837/837 [==============================] - 0s 258us/sample - loss: 0.4359 - accuracy: 0.8172 - val_loss: 0.4736 - val_accuracy: 0.7905
Epoch 91/100
837/837 [==============================] - 0s 299us/sample - loss: 0.4365 - accuracy: 0.8053 - val_loss: 0.4658 - val_accuracy: 0.7905
Epoch 92/100
837/837 [==============================] - 0s 319us/sample - loss: 0.4376 - accuracy: 0.8148 - val_loss: 0.4696 - val_accuracy: 0.7905
Epoch 93/100
837/837 [==============================] - 0s 355us/sample - loss: 0.4375 - accuracy: 0.8005 - val_loss: 0.4698 - val_accuracy: 0.7952
Epoch 94/100
837/837 [==============================] - 0s 205us/sample - loss: 0.4384 - accuracy: 0.8005 - val_loss: 0.4682 - val_accuracy: 0.7905
Epoch 95/100
440/837 [==============>...............] - ETA: 0s - loss: 0.4514 - accuracy: 0.7909
Epoch 00095: saving model to ./checkpoint2.x/Titanic.95.h5
837/837 [==============================] - 0s 344us/sample - loss: 0.4392 - accuracy: 0.8005 - val_loss: 0.4620 - val_accuracy: 0.7952
Epoch 96/100
837/837 [==============================] - 0s 219us/sample - loss: 0.4347 - accuracy: 0.8053 - val_loss: 0.4643 - val_accuracy: 0.7857
Epoch 97/100
837/837 [==============================] - 0s 309us/sample - loss: 0.4410 - accuracy: 0.8005 - val_loss: 0.4772 - val_accuracy: 0.7905
Epoch 98/100
837/837 [==============================] - 0s 230us/sample - loss: 0.4325 - accuracy: 0.8076 - val_loss: 0.4629 - val_accuracy: 0.7857
Epoch 99/100
837/837 [==============================] - 0s 267us/sample - loss: 0.4308 - accuracy: 0.8005 - val_loss: 0.4658 - val_accuracy: 0.7857
Epoch 100/100
360/837 [===========>..................] - ETA: 0s - loss: 0.4338 - accuracy: 0.8139
Epoch 00100: saving model to ./checkpoint2.x/Titanic.100.h5
837/837 [==============================] - 0s 265us/sample - loss: 0.4314 - accuracy: 0.8124 - val_loss: 0.4623 - val_accuracy: 0.7857
fig = plt.gcf()
fig.set_size_inches(10, 5)
ax1 = fig.add_subplot(111)
ax1.set_title('Train and Validation Picture')
ax1.set_ylabel('Loss value')
line1, = ax1.plot(train_history.history['loss'], color=(0.5, 0.5, 1.0), label='Loss train')
line2, = ax1.plot(train_history.history['val_loss'], color=(0.5, 1.0, 0.5), label='Loss valid')
ax2 = ax1.twinx()
ax2.set_ylabel('Accuracy value')
line3, = ax2.plot(train_history.history['accuracy'], color=(0.5, 0.5, 0.5), label='Accuracy train')
line4, = ax2.plot(train_history.history['val_accuracy'], color=(1, 0, 0), label='Accuracy valid')
plt.legend(handles=(line1, line2, line3, line4), loc='best')
plt.show()
Jack_info = [0, 'Jack', 3, 'male', 23, 1, 0, 5.0000, 'S']
Rose_info = [1, 'Rose', 1, 'female', 20, 1, 0, 100.0000, 'S']
new_passenger_pd = pd.DataFrame([Jack_info, Rose_info], columns=selected_cols)
all_passenger_pd = selected_dataframe.append(new_passenger_pd)
pred = model.predict(prepare_data(all_passenger_pd)[0])
print('Rose survived probability:', pred[-1:][0][0],
'\nJack survived probability:', pred[-2:][0][0])
Rose survived probability: 0.9700622
Jack survived probability: 0.12726058
Since only the network weights were saved, not the architecture, the network has to be redefined (in a Jupyter session the model object is usually still in memory, so you can often skip this; a standalone .py file, however, must redefine it).
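If you would rather skip redefining the architecture entirely, saving the full model is an alternative (a sketch; the run above saved weights only):
model.save('./checkpoint2.x/Titanic_full.h5')  # architecture + weights together
restored_model = tf.keras.models.load_model('./checkpoint2.x/Titanic_full.h5')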
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(units=256,
input_dim=7,
use_bias=True,
kernel_initializer='uniform',
bias_initializer='zeros',
activation='relu'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=128, activation='sigmoid'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=64, activation='sigmoid'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=32, activation='sigmoid'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=1, activation='sigmoid')
])
model.load_weights('./checkpoint2.x/Titanic.100.h5')
model.compile(optimizer=tf.keras.optimizers.Adam(0.003),
loss='binary_crossentropy',
metrics=['accuracy'])
loss, acc = model.evaluate(x_test, y_test, verbose=2)
print('Restore model accuracy:{:5.4f}%'.format(100 * acc))
262/1 - 0s - loss: 0.5042 - accuracy: 0.8511
Restore model accuracy:85.1145%
import numpy
import pandas as pd
import tensorflow as tf
import urllib.request
from sklearn import preprocessing
import matplotlib.pyplot as plt
import os
import datetime
tf.__version__
'1.15.2'
def prepare_data(df_data):
df = df_data.drop(['name'], axis=1)
age_mean = df['age'].mean()
df['age'] = df['age'].fillna(age_mean)
fare_mean = df['fare'].mean()
df['fare'] = df['fare'].fillna(fare_mean)
df['sex'] = df['sex'].map({'female':0, 'male':1}).astype(int)
df['embarked'] = df['embarked'].fillna('S')
df['embarked'] = df['embarked'].map({'C':0, 'Q':1, 'S':2}).astype(int)
ndarray_data = df.values
features = ndarray_data[:, 1:]
label = ndarray_data[:, 0]
minmax_scale = preprocessing.MinMaxScaler(feature_range=(0, 1))
norm_features = minmax_scale.fit_transform(features)
return norm_features, label
dataframe = pd.read_excel('./data/titanic3.xls')
selected_cols= ['survived', 'name', 'pclass', 'sex', 'age', 'sibsp', 'parch', 'fare', 'embarked']
selected_dataframe = dataframe[selected_cols].copy()
selected_dataframe = selected_dataframe.sample(frac=1)
x_data, y_data = prepare_data(selected_dataframe)
train_size = int(len(x_data) * 0.8)
x_train = x_data[:train_size]
y_train = y_data[:train_size]
x_test = x_data[train_size:]
y_test = y_data[train_size:]
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(units=64,
input_dim=7,
use_bias=True,
kernel_initializer='uniform',
bias_initializer='zeros',
activation='relu'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=32, activation='sigmoid'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=1, activation='sigmoid')
])
WARNING:tensorflow:From e:\anaconda3\envs\tensorflow1.x\lib\site-packages\tensorflow_core\python\keras\initializers.py:119: calling RandomUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From e:\anaconda3\envs\tensorflow1.x\lib\site-packages\tensorflow_core\python\ops\resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 64) 512
_________________________________________________________________
dropout (Dropout) (None, 64) 0
_________________________________________________________________
dense_1 (Dense) (None, 32) 2080
_________________________________________________________________
dropout_1 (Dropout) (None, 32) 0
_________________________________________________________________
dense_2 (Dense) (None, 1) 33
=================================================================
Total params: 2,625
Trainable params: 2,625
Non-trainable params: 0
_________________________________________________________________
model.compile(optimizer=tf.keras.optimizers.Adam(0.003),
loss='binary_crossentropy',
metrics=['accuracy'])
WARNING:tensorflow:From e:\anaconda3\envs\tensorflow1.x\lib\site-packages\tensorflow_core\python\ops\nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
log_dir = os.path.join(
'logs1.x',
'train',
'plugins',
'profile',
datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
checkpoint_path = './checkpoint1.x/Titanic_{epoch:02d}-{val_loss:.2f}.ckpt'
callbacks = [tf.keras.callbacks.TensorBoard(log_dir=log_dir,
histogram_freq=2),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1,
period=5)]
WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.
train_history = model.fit(x=x_train,
y=y_train,
validation_split=0.2,
epochs=100,
batch_size=40,
callbacks=callbacks,
verbose=2)
Partial training output:
Train on 837 samples, validate on 210 samples
837/837 - 0s - loss: 0.4398 - acc: 0.8124 - val_loss: 0.4671 - val_acc: 0.7857
Epoch 80/100
Epoch 00080: saving model to ./checkpoint1.x/Titanic_80-0.47.ckpt
837/837 - 0s - loss: 0.4360 - acc: 0.8076 - val_loss: 0.4673 - val_acc: 0.7857
Epoch 81/100
837/837 - 0s - loss: 0.4307 - acc: 0.8005 - val_loss: 0.4703 - val_acc: 0.7905
Epoch 82/100
837/837 - 0s - loss: 0.4401 - acc: 0.7981 - val_loss: 0.4666 - val_acc: 0.8000
Epoch 83/100
837/837 - 0s - loss: 0.4311 - acc: 0.8017 - val_loss: 0.4678 - val_acc: 0.7952
Epoch 84/100
837/837 - 0s - loss: 0.4296 - acc: 0.8172 - val_loss: 0.4673 - val_acc: 0.8000
Epoch 85/100
Epoch 00085: saving model to ./checkpoint1.x/Titanic_85-0.46.ckpt
837/837 - 0s - loss: 0.4384 - acc: 0.8029 - val_loss: 0.4634 - val_acc: 0.7857
Epoch 86/100
837/837 - 0s - loss: 0.4345 - acc: 0.8076 - val_loss: 0.4666 - val_acc: 0.7905
Epoch 87/100
837/837 - 0s - loss: 0.4307 - acc: 0.8053 - val_loss: 0.4650 - val_acc: 0.8000
Epoch 88/100
837/837 - 0s - loss: 0.4394 - acc: 0.8148 - val_loss: 0.4638 - val_acc: 0.8000
Epoch 89/100
837/837 - 0s - loss: 0.4355 - acc: 0.8053 - val_loss: 0.4648 - val_acc: 0.8000
Epoch 90/100
Epoch 00090: saving model to ./checkpoint1.x/Titanic_90-0.46.ckpt
837/837 - 0s - loss: 0.4326 - acc: 0.8100 - val_loss: 0.4623 - val_acc: 0.8000
Epoch 91/100
837/837 - 0s - loss: 0.4387 - acc: 0.8029 - val_loss: 0.4658 - val_acc: 0.7905
Epoch 92/100
837/837 - 0s - loss: 0.4285 - acc: 0.8065 - val_loss: 0.4613 - val_acc: 0.7905
Epoch 93/100
837/837 - 0s - loss: 0.4355 - acc: 0.8088 - val_loss: 0.4656 - val_acc: 0.7905
Epoch 94/100
837/837 - 0s - loss: 0.4318 - acc: 0.8136 - val_loss: 0.4629 - val_acc: 0.7952
Epoch 95/100
Epoch 00095: saving model to ./checkpoint1.x/Titanic_95-0.46.ckpt
837/837 - 0s - loss: 0.4386 - acc: 0.7981 - val_loss: 0.4639 - val_acc: 0.8000
Epoch 96/100
837/837 - 0s - loss: 0.4346 - acc: 0.8041 - val_loss: 0.4647 - val_acc: 0.7857
Epoch 97/100
837/837 - 0s - loss: 0.4256 - acc: 0.8160 - val_loss: 0.4608 - val_acc: 0.8048
Epoch 98/100
837/837 - 0s - loss: 0.4357 - acc: 0.8029 - val_loss: 0.4613 - val_acc: 0.8000
Epoch 99/100
837/837 - 0s - loss: 0.4265 - acc: 0.8041 - val_loss: 0.4614 - val_acc: 0.7952
Epoch 100/100
Epoch 00100: saving model to ./checkpoint1.x/Titanic_100-0.46.ckpt
837/837 - 0s - loss: 0.4243 - acc: 0.8148 - val_loss: 0.4611 - val_acc: 0.8000
fig = plt.gcf()
fig.set_size_inches(10, 5)
ax1 = fig.add_subplot(111)
ax1.set_title('Train and Validation Picture')
ax1.set_ylabel('Loss value')
line1, = ax1.plot(train_history.history['loss'], color=(0.5, 0.5, 1.0), label='Loss train')
line2, = ax1.plot(train_history.history['val_loss'], color=(0.5, 1.0, 0.5), label='Loss valid')
ax2 = ax1.twinx()
ax2.set_ylabel('Accuracy value')
line3, = ax2.plot(train_history.history['acc'], color=(0.5, 0.5, 0.5), label='Accuracy train')
line4, = ax2.plot(train_history.history['val_acc'], color=(1, 0, 0), label='Accuracy valid')
plt.legend(handles=(line1, line2, line3, line4), loc='best')
plt.show()
Jack_info = [0, 'Jack', 3, 'male', 23, 1, 0, 5.0000, 'S']
Rose_info = [1, 'Rose', 1, 'female', 20, 1, 0, 100.0000, 'S']
new_passenger_pd = pd.DataFrame([Jack_info, Rose_info], columns=selected_cols)
all_passenger_pd = selected_dataframe.append(new_passenger_pd)
pred = model.predict(prepare_data(all_passenger_pd)[0])
print('Rose survived probability:', pred[-1:][0][0],
'\nJack survived probability:', pred[-2:][0][0])
Rose survived probability: 0.9762004
Jack survived probability: 0.10789904
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(units=64,
input_dim=7,
use_bias=True,
kernel_initializer='uniform',
bias_initializer='zeros',
activation='relu'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=32, activation='sigmoid'),
tf.keras.layers.Dropout(rate=0.3),
tf.keras.layers.Dense(units=1, activation='sigmoid')
])
model.compile(optimizer=tf.keras.optimizers.Adam(0.003),
loss='binary_crossentropy',
metrics=['accuracy'])
checkpoint_dir = os.path.dirname(checkpoint_path)
latest = tf.train.latest_checkpoint(checkpoint_dir)
model.load_weights(latest)
loss, acc = model.evaluate(x_test, y_test)
262/262 [==============================] - 0s 244us/sample - loss: 0.4393 - acc: 0.7977
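For symmetry with the 2.x run, you can print the restored accuracy as well (the exact value depends on the reshuffled split; the log above shows acc = 0.7977):
print('Restore model accuracy: {:5.4f}%'.format(100 * acc))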