Deep Learning Day 21: Heart Disease Prediction with an RNN

I. Preface


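I haven't updated this series for almost two weeks, and I'm sorry about that. A lot has happened: last week I was busy revising for final exams, then the school ended the term early and asked us to go home as soon as possible. That started a long, hard journey home, and I almost didn't make it. After two days of rest I've recovered well, and now that I've settled back in, it's time to get back to studying.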

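Enough preamble, let's start a new round of learning. Starting with this post, we move from CNNs to RNNs, a new kind of neural network. The workflow is similar to before; the focus is on how to use an RNN to solve a problem, and in this post we use one to predict heart disease.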

II. My Environment

  • OS: Windows 11
  • Language: Python 3.8.5
  • IDE: DataSpell 2022.2
  • Deep learning framework: TensorFlow 2.4.0
  • GPU and VRAM: RTX 3070 (8 GB)

III. What Is an RNN?

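Before we start coding, let's look at what an RNN is. A recurrent neural network (RNN) is a class of artificial neural network in which connections between nodes can form a cycle, so that the output of some nodes affects the subsequent input to those same nodes. This lets the network exhibit temporal dynamic behavior. Derived from feedforward neural networks, RNNs can use their internal state (memory) to process input sequences of variable length, which makes them suitable for tasks such as unsegmented, connected handwriting recognition or speech recognition.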
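The recurrence that gives an RNN its memory can be written as follows. This is standard textbook notation; the symbol names are chosen here for illustration and are not taken from the original post:

```latex
\begin{aligned}
h_t &= \tanh\left(W_{xh}\,x_t + W_{hh}\,h_{t-1} + b_h\right), \qquad h_0 = \mathbf{0},\\
y_t &= g\left(W_{hy}\,h_t + b_y\right).
\end{aligned}
```

The same weights $W_{xh}$, $W_{hh}$, $b_h$ are reused at every time step $t$, which is what lets one network process sequences of any length. tanh is SimpleRNN's default activation; $g$ is whatever output activation the task needs (for example a sigmoid for a binary prediction).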

A simple RNN consists of an input layer, a hidden layer, and an output layer:

[Figure: structure of a simple RNN with an input layer, a hidden layer, and an output layer]

Unrolled along the time axis, the same network can be drawn like this:

[Figure: the same RNN unrolled along the time axis]
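The unrolled picture corresponds to a simple loop over time steps. Below is a minimal NumPy sketch, assuming a tanh activation and a zero initial state; the function and weight names are hypothetical and are not Keras internals. It also shows why one layer can handle sequences of different lengths: the loop just runs for more steps with the same weights.

```python
import numpy as np

def simple_rnn_forward(x_seq, W_xh, W_hh, b_h):
    """Unrolled forward pass of one simple RNN layer (tanh, zero initial state).

    x_seq has shape (timesteps, input_dim); returns every hidden state,
    shape (timesteps, units).
    """
    h = np.zeros(W_hh.shape[0])   # h_0 = 0
    states = []
    for x_t in x_seq:             # one iteration per time step
        # the same W_xh, W_hh and b_h are reused at every step
        h = np.tanh(W_xh @ x_t + W_hh @ h + b_h)
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
input_dim, units = 1, 4
W_xh = rng.normal(size=(units, input_dim))
W_hh = rng.normal(size=(units, units))
b_h = np.zeros(units)

# Sequences of different lengths go through the same weights
short_states = simple_rnn_forward(rng.normal(size=(5, input_dim)), W_xh, W_hh, b_h)
long_states = simple_rnn_forward(rng.normal(size=(13, input_dim)), W_xh, W_hh, b_h)
print(short_states.shape, long_states.shape)  # (5, 4) (13, 4)

# return_sequences=True corresponds to all states; False to only the last one
last_state = long_states[-1]
print(last_state.shape)  # (4,)
```

Keras' SimpleRNN computes essentially this recurrence with its own weight layout: with return_sequences=True it returns the whole stack of states, and with the default False only the last one.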

IV. Preliminary Work

1. Set Up the GPU

As before, if you have a capable GPU, train on it; otherwise, hide the GPU from TensorFlow and train on the CPU.

GPU only:

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]                                        # if there are several GPUs, use only GPU 0
    tf.config.experimental.set_memory_growth(gpu0, True)  # allocate GPU memory on demand
    tf.config.set_visible_devices([gpu0], "GPU")

print(gpus)
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

CPU only (hide every GPU from TensorFlow):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide all GPUs; set this before importing TensorFlow
import tensorflow as tf

2. Import the Data

Before importing the data, let's take a look at the CSV file:

[Figure: preview of the heart.csv data]

The columns mean the following:

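  • age: age
  • sex: sex
  • cp: chest pain type (4 values)
  • trestbps: resting blood pressure
  • chol: serum cholesterol (mg/dl)
  • fbs: fasting blood sugar > 120 mg/dl
  • restecg: resting electrocardiographic results (values 0, 1, 2)
  • thalach: maximum heart rate achieved
  • exang: exercise-induced angina
  • oldpeak: ST depression induced by exercise relative to rest
  • slope: slope of the peak exercise ST segment
  • ca: number of major vessels (0-3) colored by fluoroscopy
  • thal: 0 = normal; 1 = fixed defect; 2 = reversible defect
  • target: 0 = lower chance of a heart attack; 1 = higher chance of a heart attack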

With the columns introduced, let's load the data:

import pandas as pd
import numpy as np

df = pd.read_csv(r"E:\DL_data\Day21\heart.csv")  # raw string so the Windows backslashes are not read as escapes
df

[Figure: the loaded DataFrame]

303 rows and 14 columns in total (13 features plus the target).

3. Check the Data

Before preprocessing, we check the data to make sure no column contains missing values:

df.isnull().sum()
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64
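df.isnull() returns a boolean DataFrame of the same shape, and .sum() adds up the True values column by column, so a 0 everywhere means nothing is missing. A toy frame with made-up values (hypothetical, not rows from heart.csv) shows the mechanics:

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame with two missing values
toy = pd.DataFrame({
    "age": [50, 40, np.nan],
    "chol": [210, np.nan, 190],
    "target": [1, 1, 0],
})

mask = toy.isnull()          # boolean DataFrame: True where a value is missing
per_column = mask.sum()      # True counts as 1, so this counts NaNs per column
print(per_column)
print("total missing:", int(per_column.sum()))
```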

V. Data Preprocessing

1. Split the Dataset

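After checking the data, we split it into a training set (90%) and a test set (10%). There is no separate validation set here: the test set also serves as the validation data during training.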

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

x = df.iloc[:, :-1]  # all columns except the last: the 13 features
y = df.iloc[:, -1]   # the last column: target

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=1)
x_train.shape, y_train.shape
((272, 13), (272,))
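If you want a true three-way split (training, validation and test), one common approach is to call train_test_split twice. The sketch below uses random stand-in data with the same shape as heart.csv (hypothetical values, not the real dataset) and the same 10% ratios:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical stand-in data: 303 rows, 13 features, binary labels
rng = np.random.default_rng(1)
x = rng.normal(size=(303, 13))
y = rng.integers(0, 2, size=303)

# First hold out 10% as the test set: ceil(0.1 * 303) = 31 rows
x_rest, x_test, y_rest, y_test = train_test_split(x, y, test_size=0.1, random_state=1)
# Then carve a validation set out of the remaining 272 rows: ceil(0.1 * 272) = 28
x_train, x_val, y_train, y_val = train_test_split(x_rest, y_rest, test_size=0.1, random_state=1)

print(x_train.shape, x_val.shape, x_test.shape)  # (244, 13) (28, 13) (31, 13)
```

The first split reproduces the 272/31 sizes seen above; x_val would then be passed as validation_data, keeping the test set untouched until the final evaluation.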

2. Standardize the Data

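# Standardize each feature column to zero mean and unit variance (per column, not per sample)
sc = StandardScaler()
x_train = sc.fit_transform(x_train)  # learn mean and std from the training set only
x_test  = sc.transform(x_test)       # reuse the training-set statistics

# Reshape to (samples, timesteps, features) = (n, 13, 1) for SimpleRNN
x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)
x_test  = x_test.reshape(x_test.shape[0], x_test.shape[1], 1)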

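Here we use StandardScaler, which removes the mean and scales to unit variance. It works on each feature dimension (column), not on each sample. The reshape then turns every sample into a sequence of 13 time steps with one value per step, matching the (batch, timesteps, features) input that SimpleRNN expects.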
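What fit_transform and transform compute can be sketched in a few lines of NumPy, assuming StandardScaler's defaults (with_mean=True, with_std=True, population standard deviation). The numbers below are made up for illustration:

```python
import numpy as np

# Hypothetical toy data: 4 training samples, 2 feature columns on different scales
x_train = np.array([[40.0, 200.0],
                    [50.0, 240.0],
                    [60.0, 260.0],
                    [70.0, 300.0]])
x_test = np.array([[55.0, 250.0]])

# "fit": per-column mean and population (ddof=0) std of the TRAINING data only
mean = x_train.mean(axis=0)
std = x_train.std(axis=0)

# "transform": both sets are scaled with the training statistics
x_train_std = (x_train - mean) / std
x_test_std = (x_test - mean) / std

# Reshape to (samples, timesteps, features): each feature becomes one time step
x_train_seq = x_train_std.reshape(x_train_std.shape[0], x_train_std.shape[1], 1)

print(x_train_std.mean(axis=0), x_train_std.std(axis=0))
print(x_train_seq.shape)  # (4, 2, 1)
```

Fitting only on the training set keeps information from the test set out of the preprocessing step.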

VI. Build the RNN Model

To build the model we use the tf.keras.layers.SimpleRNN() layer, whose signature is:

tf.keras.layers.SimpleRNN(
    units,
    activation='tanh',
    use_bias=True,
    kernel_initializer='glorot_uniform',
    recurrent_initializer='orthogonal',
    bias_initializer='zeros',
    kernel_regularizer=None,
    recurrent_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    recurrent_constraint=None,
    bias_constraint=None,
    dropout=0.0,
    recurrent_dropout=0.0,
    return_sequences=False,
    return_state=False,
    go_backwards=False,
    stateful=False,
    unroll=False,
    **kwargs
)

For the parameter descriptions, here is a screenshot of the official documentation:

[Figure: screenshot of the SimpleRNN parameter descriptions from the official TensorFlow documentation]

Now that the layer has been introduced, let's build the model:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, SimpleRNN

model = Sequential()
# Input: 13 time steps (one per feature), 1 value per step
model.add(SimpleRNN(128, input_shape=(13, 1), return_sequences=True, activation='relu'))
model.add(SimpleRNN(64, return_sequences=True, activation='relu'))
# The last RNN layer returns only its final hidden state
model.add(SimpleRNN(32, activation='relu'))
model.add(Dense(64, activation='relu'))
# Sigmoid output: predicted probability that target == 1
model.add(Dense(1, activation='sigmoid'))
model.summary()
Model: "sequential_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn_31 (SimpleRNN)    (None, 13, 128)           16640     
_________________________________________________________________
simple_rnn_32 (SimpleRNN)    (None, 13, 64)            12352     
_________________________________________________________________
simple_rnn_33 (SimpleRNN)    (None, 32)                3104      
_________________________________________________________________
dense_22 (Dense)             (None, 64)                2112      
_________________________________________________________________
dense_23 (Dense)             (None, 1)                 65        
=================================================================
Total params: 34,273
Trainable params: 34,273
Non-trainable params: 0
_________________________________________________________________
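The Param # column can be reproduced by hand. A SimpleRNN layer with u units and input dimension d holds an input kernel (d * u), a recurrent kernel (u * u) and a bias (u), i.e. u * (d + u + 1) parameters; a Dense layer holds u * (d + 1). A quick check (the helper names are hypothetical):

```python
def simple_rnn_params(units, input_dim):
    # input kernel (input_dim x units) + recurrent kernel (units x units) + bias (units)
    return units * (input_dim + units + 1)

def dense_params(units, input_dim):
    # weight matrix (input_dim x units) + bias (units)
    return units * (input_dim + 1)

layers = [
    ("simple_rnn (128)", simple_rnn_params(128, 1)),
    ("simple_rnn (64)", simple_rnn_params(64, 128)),
    ("simple_rnn (32)", simple_rnn_params(32, 64)),
    ("dense (64)", dense_params(64, 32)),
    ("dense (1)", dense_params(1, 64)),
]
for name, n in layers:
    print(f"{name:<18}{n:>8,}")
print(f"{'total':<18}{sum(n for _, n in layers):>8,}")  # total 34,273
```

Each layer's input dimension is the previous layer's unit count, which is why the 64-unit RNN layer is much larger than the 32-unit one.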

VII. Compile the Model

opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(loss='binary_crossentropy', optimizer=opt,metrics=['accuracy'])
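The binary_crossentropy loss averages -[y*log(p) + (1-y)*log(1-p)] over the samples. A small NumPy sketch (the function name is hypothetical; the clipping mirrors, in spirit, the small epsilon Keras uses to avoid log(0)) also explains why the first epoch in the training log starts near 0.70: an untrained sigmoid output close to 0.5 gives a loss close to ln 2.

```python
import numpy as np

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy: -[y*log(p) + (1-y)*log(1-p)]."""
    y_true = np.asarray(y_true, dtype=float)
    # clip probabilities away from 0 and 1 so log() stays finite
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))))

# Predicting 0.5 for everything gives ln 2, about 0.693
print(binary_crossentropy([0, 1, 1, 0], [0.5, 0.5, 0.5, 0.5]))
# Confident, correct predictions give a much smaller loss (-ln 0.9, about 0.105)
print(binary_crossentropy([0, 1], [0.1, 0.9]))
```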

VIII. Train the Model

epochs = 50

history = model.fit(x_train, y_train,
                    epochs=epochs,
                    batch_size=128,
                    validation_data=(x_test, y_test),
                    verbose=1)
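The "3/3" on each line of the training log is the number of batches per epoch: with 272 training samples and batch_size=128, Keras runs two full batches and one partial batch. Likewise, model.evaluate uses its default batch_size of 32, so the 31 test samples fit in a single step ("1/1"). The arithmetic:

```python
import math

n_train = 272      # training samples, from x_train.shape
batch_size = 128   # as passed to model.fit

steps_per_epoch = math.ceil(n_train / batch_size)
last_batch = n_train - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch)  # 3 16

eval_steps = math.ceil(31 / 32)     # 31 test samples, evaluate's default batch_size=32
print(eval_steps)  # 1
```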

The training output:

Epoch 1/50
3/3 [==============================] - 0s 154ms/step - loss: 0.7026 - accuracy: 0.3272 - val_loss: 0.7091 - val_accuracy: 0.3548
Epoch 2/50
3/3 [==============================] - 0s 17ms/step - loss: 0.6947 - accuracy: 0.4485 - val_loss: 0.7001 - val_accuracy: 0.4839
Epoch 3/50
3/3 [==============================] - 0s 21ms/step - loss: 0.6884 - accuracy: 0.5772 - val_loss: 0.6914 - val_accuracy: 0.5806
Epoch 4/50
3/3 [==============================] - 0s 14ms/step - loss: 0.6831 - accuracy: 0.6691 - val_loss: 0.6834 - val_accuracy: 0.5806
Epoch 5/50
3/3 [==============================] - 0s 15ms/step - loss: 0.6779 - accuracy: 0.6801 - val_loss: 0.6759 - val_accuracy: 0.5806
Epoch 6/50
3/3 [==============================] - 0s 13ms/step - loss: 0.6729 - accuracy: 0.6765 - val_loss: 0.6692 - val_accuracy: 0.6129
Epoch 7/50
3/3 [==============================] - 0s 14ms/step - loss: 0.6677 - accuracy: 0.6912 - val_loss: 0.6627 - val_accuracy: 0.6129
Epoch 8/50
3/3 [==============================] - 0s 14ms/step - loss: 0.6623 - accuracy: 0.7022 - val_loss: 0.6558 - val_accuracy: 0.6129
Epoch 9/50
3/3 [==============================] - 0s 14ms/step - loss: 0.6568 - accuracy: 0.7132 - val_loss: 0.6492 - val_accuracy: 0.6774
Epoch 10/50
3/3 [==============================] - 0s 15ms/step - loss: 0.6511 - accuracy: 0.7169 - val_loss: 0.6418 - val_accuracy: 0.7419
Epoch 11/50
3/3 [==============================] - 0s 13ms/step - loss: 0.6450 - accuracy: 0.7390 - val_loss: 0.6342 - val_accuracy: 0.8065
Epoch 12/50
3/3 [==============================] - 0s 14ms/step - loss: 0.6386 - accuracy: 0.7463 - val_loss: 0.6260 - val_accuracy: 0.8065
Epoch 13/50
3/3 [==============================] - 0s 13ms/step - loss: 0.6325 - accuracy: 0.7463 - val_loss: 0.6170 - val_accuracy: 0.8065
Epoch 14/50
3/3 [==============================] - 0s 16ms/step - loss: 0.6261 - accuracy: 0.7463 - val_loss: 0.6083 - val_accuracy: 0.8065
Epoch 15/50
3/3 [==============================] - 0s 14ms/step - loss: 0.6192 - accuracy: 0.7574 - val_loss: 0.5989 - val_accuracy: 0.8065
Epoch 16/50
3/3 [==============================] - 0s 15ms/step - loss: 0.6122 - accuracy: 0.7684 - val_loss: 0.5888 - val_accuracy: 0.8387
Epoch 17/50
3/3 [==============================] - 0s 14ms/step - loss: 0.6046 - accuracy: 0.7610 - val_loss: 0.5781 - val_accuracy: 0.8387
Epoch 18/50
3/3 [==============================] - 0s 14ms/step - loss: 0.5967 - accuracy: 0.7684 - val_loss: 0.5666 - val_accuracy: 0.8387
Epoch 19/50
3/3 [==============================] - 0s 13ms/step - loss: 0.5880 - accuracy: 0.7684 - val_loss: 0.5543 - val_accuracy: 0.8710
Epoch 20/50
3/3 [==============================] - 0s 14ms/step - loss: 0.5784 - accuracy: 0.7794 - val_loss: 0.5413 - val_accuracy: 0.8710
Epoch 21/50
3/3 [==============================] - 0s 13ms/step - loss: 0.5680 - accuracy: 0.8015 - val_loss: 0.5282 - val_accuracy: 0.8710
Epoch 22/50
3/3 [==============================] - 0s 13ms/step - loss: 0.5572 - accuracy: 0.8015 - val_loss: 0.5149 - val_accuracy: 0.8710
Epoch 23/50
3/3 [==============================] - 0s 14ms/step - loss: 0.5462 - accuracy: 0.8051 - val_loss: 0.4997 - val_accuracy: 0.8710
Epoch 24/50
3/3 [==============================] - 0s 17ms/step - loss: 0.5360 - accuracy: 0.7904 - val_loss: 0.4845 - val_accuracy: 0.8710
Epoch 25/50
3/3 [==============================] - 0s 16ms/step - loss: 0.5256 - accuracy: 0.7868 - val_loss: 0.4675 - val_accuracy: 0.8710
Epoch 26/50
3/3 [==============================] - 0s 16ms/step - loss: 0.5142 - accuracy: 0.7831 - val_loss: 0.4489 - val_accuracy: 0.8710
Epoch 27/50
3/3 [==============================] - 0s 16ms/step - loss: 0.5028 - accuracy: 0.7794 - val_loss: 0.4294 - val_accuracy: 0.8710
Epoch 28/50
3/3 [==============================] - 0s 16ms/step - loss: 0.4881 - accuracy: 0.7904 - val_loss: 0.4108 - val_accuracy: 0.8710
Epoch 29/50
3/3 [==============================] - 0s 15ms/step - loss: 0.4783 - accuracy: 0.8051 - val_loss: 0.3924 - val_accuracy: 0.9032
Epoch 30/50
3/3 [==============================] - 0s 14ms/step - loss: 0.4673 - accuracy: 0.8051 - val_loss: 0.3738 - val_accuracy: 0.9032
Epoch 31/50
3/3 [==============================] - 0s 14ms/step - loss: 0.4555 - accuracy: 0.8088 - val_loss: 0.3546 - val_accuracy: 0.9032
Epoch 32/50
3/3 [==============================] - 0s 14ms/step - loss: 0.4448 - accuracy: 0.8125 - val_loss: 0.3376 - val_accuracy: 0.9032
Epoch 33/50
3/3 [==============================] - 0s 14ms/step - loss: 0.4350 - accuracy: 0.8125 - val_loss: 0.3235 - val_accuracy: 0.9032
Epoch 34/50
3/3 [==============================] - 0s 14ms/step - loss: 0.4248 - accuracy: 0.8125 - val_loss: 0.3131 - val_accuracy: 0.9032
Epoch 35/50
3/3 [==============================] - 0s 15ms/step - loss: 0.4159 - accuracy: 0.8162 - val_loss: 0.3055 - val_accuracy: 0.9032
Epoch 36/50
3/3 [==============================] - 0s 14ms/step - loss: 0.4075 - accuracy: 0.8125 - val_loss: 0.3003 - val_accuracy: 0.9032
Epoch 37/50
3/3 [==============================] - 0s 15ms/step - loss: 0.4000 - accuracy: 0.8199 - val_loss: 0.2925 - val_accuracy: 0.9032
Epoch 38/50
3/3 [==============================] - 0s 16ms/step - loss: 0.3931 - accuracy: 0.8272 - val_loss: 0.2859 - val_accuracy: 0.9032
Epoch 39/50
3/3 [==============================] - 0s 14ms/step - loss: 0.3865 - accuracy: 0.8382 - val_loss: 0.2806 - val_accuracy: 0.9032
Epoch 40/50
3/3 [==============================] - 0s 14ms/step - loss: 0.3789 - accuracy: 0.8309 - val_loss: 0.2800 - val_accuracy: 0.9032
Epoch 41/50
3/3 [==============================] - 0s 13ms/step - loss: 0.3736 - accuracy: 0.8346 - val_loss: 0.2874 - val_accuracy: 0.9032
Epoch 42/50
3/3 [==============================] - 0s 13ms/step - loss: 0.3745 - accuracy: 0.8272 - val_loss: 0.2857 - val_accuracy: 0.9032
Epoch 43/50
3/3 [==============================] - 0s 13ms/step - loss: 0.3655 - accuracy: 0.8419 - val_loss: 0.2740 - val_accuracy: 0.9032
Epoch 44/50
3/3 [==============================] - 0s 13ms/step - loss: 0.3560 - accuracy: 0.8456 - val_loss: 0.2719 - val_accuracy: 0.9032
Epoch 45/50
3/3 [==============================] - 0s 14ms/step - loss: 0.3504 - accuracy: 0.8456 - val_loss: 0.2714 - val_accuracy: 0.9032
Epoch 46/50
3/3 [==============================] - 0s 13ms/step - loss: 0.3444 - accuracy: 0.8566 - val_loss: 0.2723 - val_accuracy: 0.9032
Epoch 47/50
3/3 [==============================] - 0s 12ms/step - loss: 0.3391 - accuracy: 0.8640 - val_loss: 0.2725 - val_accuracy: 0.9032
Epoch 48/50
3/3 [==============================] - 0s 13ms/step - loss: 0.3353 - accuracy: 0.8566 - val_loss: 0.2706 - val_accuracy: 0.9032
Epoch 49/50
3/3 [==============================] - 0s 12ms/step - loss: 0.3307 - accuracy: 0.8566 - val_loss: 0.2685 - val_accuracy: 0.9032
Epoch 50/50
3/3 [==============================] - 0s 13ms/step - loss: 0.3254 - accuracy: 0.8713 - val_loss: 0.2657 - val_accuracy: 0.9032
model.evaluate(x_test,y_test)
1/1 [==============================] - 0s 999us/step - loss: 0.2657 - accuracy: 0.9032
[0.26569074392318726, 0.9032257795333862]

IX. Model Evaluation

import matplotlib.pyplot as plt

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

[Figure: training and validation accuracy (left) and loss (right) curves]

Let's print the final accuracy:

scores = model.evaluate(x_test, y_test, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
accuracy: 90.32%
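With a single sigmoid output and binary_crossentropy, Keras treats metrics=['accuracy'] as binary accuracy: each predicted probability is thresholded at 0.5 and compared with the label. A NumPy sketch (the helper name is hypothetical), plus the arithmetic behind the reported figure on the 31-sample test set:

```python
import numpy as np

def binary_accuracy(y_true, y_prob, threshold=0.5):
    # a sigmoid output above the threshold is read as class 1
    y_pred = (np.asarray(y_prob) > threshold).astype(int)
    return float(np.mean(y_pred == np.asarray(y_true)))

print(binary_accuracy([1, 0, 1, 0], [0.8, 0.3, 0.4, 0.1]))  # 0.75

# With 31 test samples, 90.32% corresponds to 28 correct predictions
print(f"{28 / 31:.2%}")  # 90.32%
```

Because the test set is so small, each misclassified sample moves the accuracy by about 3.2 percentage points.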

The result looks decent, but the model still has room for improvement.

X. Final Thoughts

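Unlike other neural networks, an RNN has memory. In a conventional network, inputs are treated independently: each input is processed on its own, with no relationship between successive inputs. Many real tasks, however, involve sequential information in which successive inputs are related, and that is where RNNs fit. RNNs also have limitations: they struggle to capture long-range dependencies, which makes long sequences hard to handle, and because of their recurrent structure they are more prone to vanishing and exploding gradients than feedforward networks.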
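A toy illustration of vanishing and exploding gradients, assuming a one-dimensional linear recurrence h_t = w * h_{t-1} + x_t. This is a deliberate simplification: a real SimpleRNN multiplies Jacobian matrices, and the tanh or relu derivatives contribute further factors. Backpropagating through T steps multiplies the gradient by w once per step, so dh_T/dh_0 = w ** T.

```python
# Hypothetical scalar model: h_t = w * h_{t-1} + x_t, so d h_T / d h_0 = w ** T

def grad_through_time(w, steps):
    grad = 1.0
    for _ in range(steps):
        grad *= w  # one chain-rule factor per time step
    return grad

T = 50
vanishing = grad_through_time(0.5, T)   # |w| < 1: shrinks toward 0
exploding = grad_through_time(1.5, T)   # |w| > 1: grows without bound
print(f"w=0.5: {vanishing:.3e}")
print(f"w=1.5: {exploding:.3e}")
```

After 50 steps the signal from the first time step is either negligible or enormous, which is why plain RNNs have trouble learning long-range dependencies.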

12-10 06:38