010-RNN-Heart Disease Prediction

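  • This post is a study log from the 365-Day Deep Learning Training Camp
  • Reference article: Deep Learning 100 Examples - Recurrent Neural Network (RNN) Heart Disease Prediction | Day 46
  • Author: K同学啊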

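Difficulty: Beginner ⭐

Requirements:

Load the data from a local file.
Understand how a recurrent neural network (RNN) is built.
Reach 87% accuracy on the test set.

Stretch goal:

Reach 89% accuracy on the test set.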

import tensorflow as tf

# Check whether a GPU is available
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]                                        # If there are several GPUs, use only GPU 0
    tf.config.experimental.set_memory_growth(gpu0, True)  # Allocate GPU memory on demand
    tf.config.set_visible_devices([gpu0], "GPU")

gpus
[]

Data description

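  • age: 1) age
  • sex: 2) sex
  • cp: 3) chest pain type (4 values)
  • trestbps: 4) resting blood pressure
  • chol: 5) serum cholesterol (mg/dl)
  • fbs: 6) fasting blood sugar > 120 mg/dl
  • restecg: 7) resting electrocardiographic results (values 0, 1, 2)
  • thalach: 8) maximum heart rate achieved
  • exang: 9) exercise-induced angina
  • oldpeak: 10) ST depression induced by exercise relative to rest
  • slope: 11) slope of the peak exercise ST segment
  • ca: 12) number of major vessels colored by fluoroscopy (0-3)
  • thal: 13) 0 = normal; 1 = fixed defect; 2 = reversible defect
  • target: 14) 0 = lower chance of a heart attack; 1 = higher chance of a heart attack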
import pandas as pd
import numpy as np

# Load the data
df = pd.read_csv(r"D:\workspace_anaconda\data\heart.csv")
df
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
0 63 1 3 145 233 1 0 150 0 2.3 0 0 1 1
1 37 1 2 130 250 0 1 187 0 3.5 0 0 2 1
2 41 0 1 130 204 0 0 172 0 1.4 2 0 2 1
3 56 1 1 120 236 0 1 178 0 0.8 2 0 2 1
4 57 0 0 120 354 0 1 163 1 0.6 2 0 2 1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
298 57 0 0 140 241 0 1 123 1 0.2 1 0 3 0
299 45 1 3 110 264 0 1 132 0 1.2 1 0 3 0
300 68 1 0 144 193 1 1 141 0 3.4 1 2 3 0
301 57 1 0 130 131 0 1 115 1 1.2 1 1 3 0
302 57 0 1 130 236 0 0 174 0 0.0 1 1 2 0

303 rows × 14 columns

# Check for missing values
df.isnull().sum()
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64

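Relationship between the test set and the validation set:

  1. The validation set does not take part in gradient descent during training; strictly speaking, it is never used to update the model's parameters.
  2. Broadly speaking, however, the validation set does take part in a kind of "manual tuning": after each epoch we look at the model's performance on the validation data to decide whether to stop training early (early stopping), or to adjust hyperparameters such as the learning rate and batch_size.
  3. So one can also say that the validation set participates in training. The optimizer never fits it directly, so the model does not overfit it the way it can overfit the training set, although heavy tuning against it can still make the validation score optimistic.

In this post the test set is passed to model.fit as validation_data, so the validation set and the test set are the same 31 samples.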
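The early-stopping decision from point 2 can be sketched as a small patience loop over the per-epoch validation losses. This is a minimal pure-Python sketch: the function name `early_stopping_epoch` is hypothetical, it does not touch a real model, and it only roughly mirrors what the Keras callback `tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=...)` decides.

```python
def early_stopping_epoch(val_losses, patience=5, min_delta=0.0):
    """Hypothetical helper: return (best_epoch, stop_epoch), both 0-based.

    Training stops once val_loss has failed to improve by more than
    min_delta for `patience` consecutive epochs.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    # Never triggered: training runs to the last epoch
    return best_epoch, len(val_losses) - 1


# Toy validation losses (not from this post's training run)
val_losses = [0.68, 0.60, 0.55, 0.56, 0.57, 0.58]
print(early_stopping_epoch(val_losses, patience=3))  # (2, 5)
```

The best epoch is the one whose weights you would keep (Keras does this with `restore_best_weights=True`); the stop epoch is where training would end.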
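# Data preprocessing
# 1. Split into training and test sets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

X = df.iloc[:, :-1]   # the 13 feature columns
y = df.iloc[:, -1]    # the target column

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)
X_train.shape, y_train.shape

# 2. Standardization
# Rescale each feature column to zero mean and unit variance; standardization is computed per column.
# Fit the scaler on the training set only, then apply the same statistics to the test set.
sc      = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test  = sc.transform(X_test)

# Reshape to (samples, timesteps, features) = (n, 13, 1): each feature becomes one time step
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test  = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)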
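To see what `StandardScaler` computes per column, here is a NumPy sketch on toy random data (not the heart dataset, only shaped like the 272/31 split). Each column is shifted by its training mean and divided by its training standard deviation (population form, ddof=0), and the test set reuses the training statistics. The helper `standardize_fit` is hypothetical; the final reshape turns each of the 13 features into one time step with a single feature.

```python
import numpy as np


def standardize_fit(X):
    """Hypothetical helper: per-column mean and population std (ddof=0)."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    std[std == 0] = 1.0  # leave constant columns unscaled instead of dividing by zero
    return mean, std


rng = np.random.default_rng(0)
# Toy stand-ins for the 13 feature columns
X_train_demo = rng.normal(loc=50.0, scale=10.0, size=(272, 13))
X_test_demo = rng.normal(loc=50.0, scale=10.0, size=(31, 13))

mean, std = standardize_fit(X_train_demo)          # statistics from the training set only
X_train_std = (X_train_demo - mean) / std
X_test_std = (X_test_demo - mean) / std            # same statistics: no test-set leakage

# (samples, timesteps=13, features=1), the layout SimpleRNN expects
X_train_seq = X_train_std.reshape(X_train_std.shape[0], X_train_std.shape[1], 1)
print(X_train_seq.shape)  # (272, 13, 1)
```

The test-set columns will not have exactly zero mean and unit variance, which is expected: only the training set defines the scale.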

The SimpleRNN layer

tf.keras.layers.SimpleRNN(
    units,
    activation='tanh',
    use_bias=True,
    kernel_initializer='glorot_uniform',
    recurrent_initializer='orthogonal',
    bias_initializer='zeros',
    kernel_regularizer=None,
    recurrent_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    recurrent_constraint=None,
    bias_constraint=None,
    dropout=0.0,
    recurrent_dropout=0.0,
    return_sequences=False,
    return_state=False,
    go_backwards=False,
    stateful=False,
    unroll=False,
    **kwargs
)

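  • units: positive integer, dimensionality of the output space.
  • activation: activation function to use. Default: hyperbolic tangent (tanh). If you pass None, no activation is applied (i.e. "linear" activation: a(x) = x).
  • use_bias: Boolean, whether the layer uses a bias vector.
  • kernel_initializer: initializer for the kernel weights matrix, used for the linear transformation of the inputs (see initializers).
  • recurrent_initializer: initializer for the recurrent_kernel weights matrix, used for the linear transformation of the recurrent state (see initializers).
  • bias_initializer: initializer for the bias vector (see initializers).
  • dropout: float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs.
  • return_sequences: Boolean, whether to return the full output sequence or only the last output.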
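The recurrence inside a SimpleRNN layer can be sketched in NumPy as h_t = activation(x_t · W + h_{t-1} · U + b), starting from an all-zero state. The argument names `kernel`, `recurrent_kernel` and `bias` mirror the weights described above, but the function `simple_rnn_forward` is hypothetical and the weights are small random values, not Keras's glorot_uniform/orthogonal initializers. The `return_sequences` flag decides whether every h_t or only the last one is returned.

```python
import numpy as np


def simple_rnn_forward(x, kernel, recurrent_kernel, bias,
                       activation=np.tanh, return_sequences=False):
    """Hypothetical sketch of a SimpleRNN forward pass.

    x: (batch, timesteps, input_dim)
    returns (batch, timesteps, units) if return_sequences else (batch, units)
    """
    batch, timesteps, _ = x.shape
    units = kernel.shape[1]
    h = np.zeros((batch, units))  # initial state is all zeros
    outputs = []
    for t in range(timesteps):
        # h_t = activation(x_t @ W + h_{t-1} @ U + b)
        h = activation(x[:, t, :] @ kernel + h @ recurrent_kernel + bias)
        outputs.append(h)
    if return_sequences:
        return np.stack(outputs, axis=1)
    return h


def relu(z):
    return np.maximum(z, 0.0)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 13, 1))  # 4 samples, 13 time steps, 1 feature
units = 8
kernel = rng.normal(scale=0.1, size=(1, units))
recurrent_kernel = rng.normal(scale=0.1, size=(units, units))
bias = np.zeros(units)

seq = simple_rnn_forward(x, kernel, recurrent_kernel, bias, return_sequences=True)
last = simple_rnn_forward(x, kernel, recurrent_kernel, bias)
relu_last = simple_rnn_forward(x, kernel, recurrent_kernel, bias, activation=relu)
print(seq.shape, last.shape)  # (4, 13, 8) (4, 8)
```

The last time step of the full sequence equals the single output returned when return_sequences=False, which is why stacked RNN layers need return_sequences=True on every layer except the last.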
# Build the RNN model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, SimpleRNN

model = Sequential()
# Stacked SimpleRNN layers: return_sequences=True passes the full sequence on to the next RNN layer
model.add(SimpleRNN(128, input_shape=(13, 1), return_sequences=True, activation='relu'))
model.add(SimpleRNN(64, return_sequences=True, activation='relu'))
model.add(SimpleRNN(32, return_sequences=True, activation='relu'))
# The last RNN layer returns only its final output, shape (None, 16)
model.add(SimpleRNN(16, activation='relu'))
model.add(Dense(16, activation='relu'))
model.add(Dense(1, activation='sigmoid'))   # probability that target = 1
model.summary()
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 simple_rnn_7 (SimpleRNN)    (None, 13, 128)           16640     
                                                                 
 simple_rnn_8 (SimpleRNN)    (None, 13, 64)            12352     
                                                                 
 simple_rnn_9 (SimpleRNN)    (None, 13, 32)            3104      
                                                                 
 simple_rnn_10 (SimpleRNN)   (None, 16)                784       
                                                                 
 dense_6 (Dense)             (None, 16)                272       
                                                                 
 dense_7 (Dense)             (None, 1)                 17        
                                                                 
=================================================================
Total params: 33,169
Trainable params: 33,169
Non-trainable params: 0
_________________________________________________________________
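The Param # column can be checked by hand. A SimpleRNN layer holds a kernel (input_dim × units), a recurrent_kernel (units × units) and a bias (units); a Dense layer holds input_dim × units weights plus units biases. The two helper functions below are illustrative only, not part of Keras.

```python
def simple_rnn_params(input_dim, units):
    # kernel + recurrent_kernel + bias
    return input_dim * units + units * units + units


def dense_params(input_dim, units):
    # weight matrix + bias
    return input_dim * units + units


layer_params = [
    simple_rnn_params(1, 128),   # 16640
    simple_rnn_params(128, 64),  # 12352
    simple_rnn_params(64, 32),   # 3104
    simple_rnn_params(32, 16),   # 784
    dense_params(16, 16),        # 272
    dense_params(16, 1),         # 17
]
print(sum(layer_params))  # 33169
```

Each layer's input_dim is the previous layer's units, and the first layer sees 1 feature per time step; the number of time steps (13) does not change the count because the same weights are reused at every step.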
# Compile the model
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(loss='binary_crossentropy',
              optimizer=opt,
              metrics=["accuracy"])
# Train the model
epochs = 100

history = model.fit(X_train, y_train,
                    epochs=epochs,
                    batch_size=128,
                    validation_data=(X_test, y_test),
                    verbose=1)
Epoch 1/100
3/3 [==============================] - 2s 167ms/step - loss: 0.6832 - accuracy: 0.5551 - val_loss: 0.6836 - val_accuracy: 0.5484
Epoch 2/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6732 - accuracy: 0.5625 - val_loss: 0.6714 - val_accuracy: 0.5484
Epoch 3/100
3/3 [==============================] - 0s 17ms/step - loss: 0.6646 - accuracy: 0.5772 - val_loss: 0.6611 - val_accuracy: 0.6452
Epoch 4/100
3/3 [==============================] - 0s 17ms/step - loss: 0.6573 - accuracy: 0.5919 - val_loss: 0.6535 - val_accuracy: 0.6452
Epoch 5/100
3/3 [==============================] - 0s 17ms/step - loss: 0.6499 - accuracy: 0.6176 - val_loss: 0.6447 - val_accuracy: 0.7097
Epoch 6/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6425 - accuracy: 0.6544 - val_loss: 0.6351 - val_accuracy: 0.7097
Epoch 7/100
3/3 [==============================] - 0s 17ms/step - loss: 0.6361 - accuracy: 0.6838 - val_loss: 0.6267 - val_accuracy: 0.7419
Epoch 8/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6291 - accuracy: 0.7169 - val_loss: 0.6193 - val_accuracy: 0.8065
Epoch 9/100
3/3 [==============================] - 0s 17ms/step - loss: 0.6215 - accuracy: 0.7243 - val_loss: 0.6098 - val_accuracy: 0.8065
Epoch 10/100
3/3 [==============================] - 0s 16ms/step - loss: 0.6140 - accuracy: 0.7390 - val_loss: 0.5991 - val_accuracy: 0.8065
Epoch 11/100
3/3 [==============================] - 0s 18ms/step - loss: 0.6058 - accuracy: 0.7574 - val_loss: 0.5882 - val_accuracy: 0.8387
Epoch 12/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5977 - accuracy: 0.7721 - val_loss: 0.5763 - val_accuracy: 0.8387
Epoch 13/100
3/3 [==============================] - 0s 16ms/step - loss: 0.5884 - accuracy: 0.7757 - val_loss: 0.5646 - val_accuracy: 0.8387
Epoch 14/100
3/3 [==============================] - 0s 19ms/step - loss: 0.5790 - accuracy: 0.7757 - val_loss: 0.5537 - val_accuracy: 0.8065
Epoch 15/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5694 - accuracy: 0.7831 - val_loss: 0.5430 - val_accuracy: 0.8065
Epoch 16/100
3/3 [==============================] - 0s 19ms/step - loss: 0.5595 - accuracy: 0.7978 - val_loss: 0.5310 - val_accuracy: 0.8387
Epoch 17/100
3/3 [==============================] - 0s 19ms/step - loss: 0.5489 - accuracy: 0.8088 - val_loss: 0.5185 - val_accuracy: 0.8387
Epoch 18/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5387 - accuracy: 0.8125 - val_loss: 0.5062 - val_accuracy: 0.8387
Epoch 19/100
3/3 [==============================] - 0s 17ms/step - loss: 0.5284 - accuracy: 0.8088 - val_loss: 0.4951 - val_accuracy: 0.8387
Epoch 20/100
3/3 [==============================] - 0s 18ms/step - loss: 0.5182 - accuracy: 0.8162 - val_loss: 0.4832 - val_accuracy: 0.8387
Epoch 21/100
3/3 [==============================] - 0s 18ms/step - loss: 0.5089 - accuracy: 0.8199 - val_loss: 0.4712 - val_accuracy: 0.8387
Epoch 22/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4985 - accuracy: 0.8088 - val_loss: 0.4600 - val_accuracy: 0.8387
Epoch 23/100
3/3 [==============================] - 0s 18ms/step - loss: 0.4896 - accuracy: 0.8125 - val_loss: 0.4497 - val_accuracy: 0.8387
Epoch 24/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4799 - accuracy: 0.8088 - val_loss: 0.4385 - val_accuracy: 0.8387
Epoch 25/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4709 - accuracy: 0.8162 - val_loss: 0.4277 - val_accuracy: 0.8387
Epoch 26/100
3/3 [==============================] - 0s 20ms/step - loss: 0.4619 - accuracy: 0.8199 - val_loss: 0.4172 - val_accuracy: 0.8710
Epoch 27/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4524 - accuracy: 0.8162 - val_loss: 0.4089 - val_accuracy: 0.9032
Epoch 28/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4435 - accuracy: 0.8235 - val_loss: 0.4027 - val_accuracy: 0.9032
Epoch 29/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4360 - accuracy: 0.8272 - val_loss: 0.3966 - val_accuracy: 0.9032
Epoch 30/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4287 - accuracy: 0.8272 - val_loss: 0.3900 - val_accuracy: 0.9032
Epoch 31/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4213 - accuracy: 0.8272 - val_loss: 0.3832 - val_accuracy: 0.9032
Epoch 32/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4147 - accuracy: 0.8272 - val_loss: 0.3778 - val_accuracy: 0.9032
Epoch 33/100
3/3 [==============================] - 0s 17ms/step - loss: 0.4086 - accuracy: 0.8272 - val_loss: 0.3736 - val_accuracy: 0.9032
Epoch 34/100
3/3 [==============================] - 0s 16ms/step - loss: 0.4024 - accuracy: 0.8272 - val_loss: 0.3727 - val_accuracy: 0.9032
Epoch 35/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3982 - accuracy: 0.8199 - val_loss: 0.3673 - val_accuracy: 0.9032
Epoch 36/100
3/3 [==============================] - 0s 20ms/step - loss: 0.3914 - accuracy: 0.8235 - val_loss: 0.3579 - val_accuracy: 0.9032
Epoch 37/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3867 - accuracy: 0.8346 - val_loss: 0.3518 - val_accuracy: 0.8710
Epoch 38/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3872 - accuracy: 0.8235 - val_loss: 0.3455 - val_accuracy: 0.9032
Epoch 39/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3815 - accuracy: 0.8346 - val_loss: 0.3414 - val_accuracy: 0.9032
Epoch 40/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3744 - accuracy: 0.8346 - val_loss: 0.3414 - val_accuracy: 0.9032
Epoch 41/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3710 - accuracy: 0.8346 - val_loss: 0.3433 - val_accuracy: 0.9032
Epoch 42/100
3/3 [==============================] - 0s 18ms/step - loss: 0.3684 - accuracy: 0.8346 - val_loss: 0.3365 - val_accuracy: 0.9032
Epoch 43/100
3/3 [==============================] - 0s 21ms/step - loss: 0.3638 - accuracy: 0.8382 - val_loss: 0.3316 - val_accuracy: 0.9032
Epoch 44/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3603 - accuracy: 0.8456 - val_loss: 0.3357 - val_accuracy: 0.9032
Epoch 45/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3542 - accuracy: 0.8456 - val_loss: 0.3482 - val_accuracy: 0.9032
Epoch 46/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3591 - accuracy: 0.8419 - val_loss: 0.3534 - val_accuracy: 0.9032
Epoch 47/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3583 - accuracy: 0.8382 - val_loss: 0.3437 - val_accuracy: 0.9032
Epoch 48/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3496 - accuracy: 0.8493 - val_loss: 0.3341 - val_accuracy: 0.9032
Epoch 49/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3433 - accuracy: 0.8529 - val_loss: 0.3305 - val_accuracy: 0.9032
Epoch 50/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3399 - accuracy: 0.8640 - val_loss: 0.3199 - val_accuracy: 0.9032
Epoch 51/100
3/3 [==============================] - 0s 19ms/step - loss: 0.3373 - accuracy: 0.8640 - val_loss: 0.3152 - val_accuracy: 0.8710
Epoch 52/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3354 - accuracy: 0.8676 - val_loss: 0.3139 - val_accuracy: 0.8710
Epoch 53/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3310 - accuracy: 0.8713 - val_loss: 0.3152 - val_accuracy: 0.9032
Epoch 54/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3280 - accuracy: 0.8640 - val_loss: 0.3125 - val_accuracy: 0.8710
Epoch 55/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3254 - accuracy: 0.8640 - val_loss: 0.3068 - val_accuracy: 0.8387
Epoch 56/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3236 - accuracy: 0.8676 - val_loss: 0.3070 - val_accuracy: 0.8387
Epoch 57/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3192 - accuracy: 0.8713 - val_loss: 0.3110 - val_accuracy: 0.8710
Epoch 58/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3166 - accuracy: 0.8676 - val_loss: 0.3188 - val_accuracy: 0.9032
Epoch 59/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3142 - accuracy: 0.8640 - val_loss: 0.3129 - val_accuracy: 0.8387
Epoch 60/100
3/3 [==============================] - 0s 17ms/step - loss: 0.3100 - accuracy: 0.8787 - val_loss: 0.3092 - val_accuracy: 0.8387
Epoch 61/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3079 - accuracy: 0.8897 - val_loss: 0.3073 - val_accuracy: 0.8387
Epoch 62/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3046 - accuracy: 0.8897 - val_loss: 0.3112 - val_accuracy: 0.8387
Epoch 63/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3027 - accuracy: 0.8676 - val_loss: 0.3199 - val_accuracy: 0.9032
Epoch 64/100
3/3 [==============================] - 0s 16ms/step - loss: 0.3029 - accuracy: 0.8713 - val_loss: 0.3119 - val_accuracy: 0.9032
Epoch 65/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2957 - accuracy: 0.8750 - val_loss: 0.3008 - val_accuracy: 0.8387
Epoch 66/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2976 - accuracy: 0.8824 - val_loss: 0.2979 - val_accuracy: 0.8387
Epoch 67/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2928 - accuracy: 0.8934 - val_loss: 0.3039 - val_accuracy: 0.9032
Epoch 68/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2909 - accuracy: 0.8787 - val_loss: 0.2989 - val_accuracy: 0.8710
Epoch 69/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2868 - accuracy: 0.8824 - val_loss: 0.2895 - val_accuracy: 0.8387
Epoch 70/100
3/3 [==============================] - 0s 18ms/step - loss: 0.2841 - accuracy: 0.8971 - val_loss: 0.2878 - val_accuracy: 0.8387
Epoch 71/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2813 - accuracy: 0.8934 - val_loss: 0.2896 - val_accuracy: 0.8387
Epoch 72/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2785 - accuracy: 0.9007 - val_loss: 0.2878 - val_accuracy: 0.8387
Epoch 73/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2761 - accuracy: 0.8971 - val_loss: 0.2850 - val_accuracy: 0.8387
Epoch 74/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2745 - accuracy: 0.8934 - val_loss: 0.2941 - val_accuracy: 0.8710
Epoch 75/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2697 - accuracy: 0.8971 - val_loss: 0.3162 - val_accuracy: 0.9032
Epoch 76/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2820 - accuracy: 0.8787 - val_loss: 0.3146 - val_accuracy: 0.9032
Epoch 77/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2731 - accuracy: 0.8860 - val_loss: 0.2922 - val_accuracy: 0.8387
Epoch 78/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2643 - accuracy: 0.9044 - val_loss: 0.2846 - val_accuracy: 0.8387
Epoch 79/100
3/3 [==============================] - 0s 15ms/step - loss: 0.2688 - accuracy: 0.9044 - val_loss: 0.2876 - val_accuracy: 0.8387
Epoch 80/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2630 - accuracy: 0.9007 - val_loss: 0.2934 - val_accuracy: 0.8710
Epoch 81/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2573 - accuracy: 0.9081 - val_loss: 0.2981 - val_accuracy: 0.8710
Epoch 82/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2557 - accuracy: 0.9007 - val_loss: 0.3007 - val_accuracy: 0.8710
Epoch 83/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2534 - accuracy: 0.8971 - val_loss: 0.3000 - val_accuracy: 0.8710
Epoch 84/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2505 - accuracy: 0.9044 - val_loss: 0.3017 - val_accuracy: 0.8710
Epoch 85/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2494 - accuracy: 0.9044 - val_loss: 0.3089 - val_accuracy: 0.9032
Epoch 86/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2484 - accuracy: 0.9044 - val_loss: 0.3021 - val_accuracy: 0.8710
Epoch 87/100
3/3 [==============================] - 0s 20ms/step - loss: 0.2439 - accuracy: 0.9044 - val_loss: 0.2959 - val_accuracy: 0.8710
Epoch 88/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2454 - accuracy: 0.9154 - val_loss: 0.2916 - val_accuracy: 0.8387
Epoch 89/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2392 - accuracy: 0.9118 - val_loss: 0.3101 - val_accuracy: 0.9032
Epoch 90/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2451 - accuracy: 0.9044 - val_loss: 0.3084 - val_accuracy: 0.9032
Epoch 91/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2416 - accuracy: 0.9081 - val_loss: 0.2871 - val_accuracy: 0.8387
Epoch 92/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2335 - accuracy: 0.9154 - val_loss: 0.2807 - val_accuracy: 0.8387
Epoch 93/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2325 - accuracy: 0.9154 - val_loss: 0.2886 - val_accuracy: 0.9032
Epoch 94/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2312 - accuracy: 0.9081 - val_loss: 0.2951 - val_accuracy: 0.9032
Epoch 95/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2312 - accuracy: 0.9191 - val_loss: 0.2757 - val_accuracy: 0.8387
Epoch 96/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2305 - accuracy: 0.9154 - val_loss: 0.2772 - val_accuracy: 0.8387
Epoch 97/100
3/3 [==============================] - 0s 19ms/step - loss: 0.2248 - accuracy: 0.9118 - val_loss: 0.2982 - val_accuracy: 0.9032
Epoch 98/100
3/3 [==============================] - 0s 16ms/step - loss: 0.2242 - accuracy: 0.9044 - val_loss: 0.3163 - val_accuracy: 0.9032
Epoch 99/100
3/3 [==============================] - 0s 20ms/step - loss: 0.2360 - accuracy: 0.9081 - val_loss: 0.3202 - val_accuracy: 0.9032
Epoch 100/100
3/3 [==============================] - 0s 17ms/step - loss: 0.2318 - accuracy: 0.9044 - val_loss: 0.2915 - val_accuracy: 0.9032
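The `3/3` in the log and the coarse jumps in val_accuracy both follow from the split sizes. With a float `test_size`, `train_test_split` rounds the test count up, giving 31 test samples and 272 training samples; with batch_size=128 that is 3 batches per epoch, the last holding 16 samples. Because the validation data has only 31 samples, each extra correct prediction moves val_accuracy by about 3.2 percentage points, so it only takes values such as 0.8387, 0.8710 and 0.9032 (26, 27 and 28 out of 31). The arithmetic as a short sketch:

```python
import math

n_samples = 303   # rows in heart.csv
test_size = 0.1
batch_size = 128

# train_test_split rounds a fractional test count up
n_test = math.ceil(test_size * n_samples)                   # 31
n_train = n_samples - n_test                                 # 272
steps_per_epoch = math.ceil(n_train / batch_size)            # 3 -> the "3/3" in the log
last_batch = n_train - (steps_per_epoch - 1) * batch_size    # 16 samples

# val_accuracy values near the top of the observed range
for correct in (26, 27, 28):
    print(correct, f"{correct / n_test:.4f}")
# 26 0.8387
# 27 0.8710
# 28 0.9032
```

With so few validation samples, a swing between 0.8387 and 0.9032 from one epoch to the next is only two predictions changing, so single-epoch values should be read with care.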
# Model evaluation
import matplotlib.pyplot as plt

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

[Figure: Training and Validation Accuracy (left); Training and Validation Loss (right)]

scores = model.evaluate(X_test, y_test, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
accuracy: 90.32%
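The reported accuracy is the fraction of test samples whose predicted class matches the label, where the sigmoid output is turned into a class with a 0.5 threshold. The NumPy sketch below uses a hypothetical helper `binary_accuracy` and toy inputs; with this model you would pass `model.predict(X_test).ravel()` as the probabilities. With 31 test samples, 90.32% corresponds to 28 correct predictions.

```python
import numpy as np


def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Hypothetical helper: share of samples whose thresholded probability matches the label."""
    y_pred = (np.asarray(y_prob) > threshold).astype(int)  # above threshold -> class 1
    return float(np.mean(y_pred == np.asarray(y_true)))


# Toy labels and sigmoid outputs (not from this model)
y_true = [1, 0, 1, 1, 0]
y_prob = [0.91, 0.20, 0.45, 0.73, 0.60]
print(binary_accuracy(y_true, y_prob))  # 0.6

# 28 of 31 test samples correct, formatted as in the evaluation cell
print("%.2f%%" % (28 / 31 * 100))  # 90.32%
```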
