基于人工智能的【预测死亡-心力衰竭】患者模型建立

简介: 基于人工智能的【预测死亡-心力衰竭】患者模型建立

一、预测死亡-心力衰竭患者模型建立


image.png


1.数据集简介


  • 心血管疾病 (CVD) 是全球第一大死因,估计每年夺去 1790 万人的生命,占全球所有死亡人数的 31%。
  • 心力衰竭是由 CVD 引起的常见事件,该数据集包含 12 个可用于预测心力衰竭死亡率的特征。
  • 大多数心血管疾病可以通过使用全民策略解决烟草使用、不健康饮食和肥胖、缺乏身体活动和有害使用酒精等行为风险因素来预防。
  • 患有心血管疾病或处于高心血管风险(由于存在一种或多种风险因素,如高血压、糖尿病、高脂血症或已经确定的疾病)的人需要早期检测和管理,其中机器学习模型可以提供很大帮助。


2.scikiti-survival库的简介


  • scikit-survival 是一个 基于scikit-learn构建的用于生存分析的 Python 模块。它允许在利用 scikit-learn 的强大功能的同时进行生存分析,例如,用于预处理或进行交叉验证。
  • 生存分析(也称为事件发生时间或可靠性分析)的目标是在协变量和事件发生时间之间建立联系。生存分析与传统机器学习的不同之处在于,部分训练数据只能部分观察——它们被删减了。
  • 例如,在临床研究中,通常会在特定时间段内监测患者,并记录在该特定时间段内发生的事件。如果患者经历了事件,则可以记录事件的确切时间——患者的记录未经审查。相反,右截尾记录指的是在研究期间保持无事件的患者,并且不知道研究结束后事件是否发生。因此,生存分析需要考虑此类数据集的这一独特特征的模型。

文档: [User Guide — scikit-survival 0.20.1scikit-survival.readthedocs.io/en/latest/u… Guide — scikit-survival 0.20.1scikit-survival.readthedocs.io/en/latest/u…)


3.超参数调优框架optuna库的简介


optuna 是一个十分常用的超参数调优框架,具有操作简单,嵌入式强和动态调整参数空间等优点。


二、环境构设


from IPython.display import clear_output
!pip install scikit-survival
!pip install optuna
clear_output() # 清理很长的内容


三、数据处理


1.数据查看


import pandas as pd 
data=pd.read_csv('data/data209679/heart_failure_clinical_records_dataset.csv')
data.info()
data.head()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 299 entries, 0 to 298
Data columns (total 13 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   age                       299 non-null    float64
 1   anaemia                   299 non-null    int64  
 2   creatinine_phosphokinase  299 non-null    int64  
 3   diabetes                  299 non-null    int64  
 4   ejection_fraction         299 non-null    int64  
 5   high_blood_pressure       299 non-null    int64  
 6   platelets                 299 non-null    float64
 7   serum_creatinine          299 non-null    float64
 8   serum_sodium              299 non-null    int64  
 9   sex                       299 non-null    int64  
 10  smoking                   299 non-null    int64  
 11  time                      299 non-null    int64  
 12  DEATH_EVENT               299 non-null    int64  
dtypes: float64(3), int64(10)
memory usage: 30.5 KB

    .dataframe tbody tr th:only-of-type {         vertical-align: middle;     } .dataframe tbody tr th {     vertical-align: top; } .dataframe thead th {     text-align: right; }

age anaemia creatinine_phosphokinase diabetes ejection_fraction high_blood_pressure platelets serum_creatinine serum_sodium sex smoking time DEATH_EVENT
0 75.0 0 582 0 20 1 265000.00 1.9 130 1 0 4 1
1 55.0 0 7861 0 38 0 263358.03 1.1 136 1 0 6 1
2 65.0 0 146 0 20 0 162000.00 1.3 129 1 1 7 1
3 50.0 1 111 0 20 0 210000.00 1.9 137 1 0 7 1
4 65.0 1 160 1 20 0 327000.00 2.7 116 0 0 8 1
  • 生存类数据,样本量小,使用交叉验证方法
  • ✔️构建预测模型则用scikit-survival文库,这里可以预测未发生死亡事件的人群的死亡时间(从随访起点算起)。


2.X,y构建


from sksurv.util import Surv
from sksurv.ensemble import RandomSurvivalForest
from sklearn.impute import SimpleImputer
data['DEATH_EVENT']=[True if x==1 else 0 for x in data['DEATH_EVENT']]
y=Surv.from_dataframe(event='DEATH_EVENT',time='time',data=data)
cat_cols=['anaemia','diabetes','high_blood_pressure','sex','smoking']
data[cat_cols]=data[cat_cols].astype('category')
X=data.drop(['DEATH_EVENT','time'],axis=1)
X.head()

    .dataframe tbody tr th:only-of-type {         vertical-align: middle;     } .dataframe tbody tr th {     vertical-align: top; } .dataframe thead th {     text-align: right; }

age anaemia creatinine_phosphokinase diabetes ejection_fraction high_blood_pressure platelets serum_creatinine serum_sodium sex smoking
0 75.0 0 582 0 20 1 265000.00 1.9 130 1 0
1 55.0 0 7861 0 38 0 263358.03 1.1 136 1 0
2 65.0 0 146 0 20 0 162000.00 1.3 129 1 1
3 50.0 1 111 0 20 0 210000.00 1.9 137 1 0
4 65.0 1 160 1 20 0 327000.00 2.7 116 0 0


四、模型构建和评价


1.超参数搜索


# pipe-line
from sklearn.pipeline import make_pipeline
from sksurv.ensemble import RandomSurvivalForest
from sklearn.preprocessing import RobustScaler,StandardScaler,MinMaxScaler,OneHotEncoder
from sklearn.model_selection import cross_val_score
from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector as selector
import optuna
import numpy as np
def objective(trial):
    n_estimators=trial.suggest_int('n_estimators',100,1000,10)
    min_sample_split=trial.suggest_int('min_sample_split',1,29,2)
    min_sample_leaf=trial.suggest_int('min_sample_leaf',1,29,2)
    preprocessor=make_column_transformer((RobustScaler(),selector(dtype_include='number')))
    rsf=make_pipeline(preprocessor, RandomSurvivalForest(n_estimators=n_estimators,
                            min_samples_split=10,
                            min_samples_leaf=15,
                            n_jobs=-1,
                            random_state=0))
    scores=cross_val_score(rsf,X,y)
    return np.mean(scores)
study=optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)
study.best_params


[I 2023-04-17 00:42:44,270] Trial 95 finished with value: 0.7495800673166991 and parameters: {'n_estimators': 230, 'min_sample_split': 15, 'min_sample_leaf': 5}. Best is trial 37 with value: 0.7540914188972966.
[I 2023-04-17 00:42:47,104] Trial 96 finished with value: 0.7488762391145961 and parameters: {'n_estimators': 280, 'min_sample_split': 13, 'min_sample_leaf': 9}. Best is trial 37 with value: 0.7540914188972966.
[I 2023-04-17 00:42:49,221] Trial 97 finished with value: 0.7502194611898544 and parameters: {'n_estimators': 200, 'min_sample_split': 13, 'min_sample_leaf': 11}. Best is trial 37 with value: 0.7540914188972966.
[I 2023-04-17 00:42:51,100] Trial 98 finished with value: 0.7536458218120432 and parameters: {'n_estimators': 160, 'min_sample_split': 21, 'min_sample_leaf': 1}. Best is trial 37 with value: 0.7540914188972966.
[I 2023-04-17 00:42:52,665] Trial 99 finished with value: 0.7523008612365734 and parameters: {'n_estimators': 120, 'min_sample_split': 25, 'min_sample_leaf': 3}. Best is trial 37 with value: 0.7540914188972966.


2.模型训练


#best_model在后续预测中使用cindex=0.73
preprocessor=make_column_transformer((RobustScaler(),selector(dtype_include='number')))
rsf_best=make_pipeline(preprocessor, RandomSurvivalForest(n_estimators=170,
                            min_samples_split=15,
                            min_samples_leaf=25,
                            n_jobs=-1,
                            random_state=0))
rsf_best.fit(X,y)
import joblib
joblib.dump(rsf_best,'rsf_best.pkl')


['rsf_best.pkl']


3.模型预测


#限制累积风险为1,获得对应的时间。 
va_times=np.arange(4,241)
data_pre=data[data['DEATH_EVENT']!=True].drop(['DEATH_EVENT','time'],axis=1)
chf_funcs = rsf_best.predict_cumulative_hazard_function(data_pre)#产生对所有的test的风险函数,只需传入时间参数即可获得累积风险
outcome_period=[]
for fn in chf_funcs:#
    if fn(va_times[-1])<1:#在最后的预测时间内死亡全部累计概率不到0.6
        time_value=999
    else:
        for time in va_times:
            if fn(time)>1:
                time_value=time#发生结局的最短时间
                break
            # print(time)
    outcome_period.append(time_value)
outcome_predict=data_pre.copy()
outcome_predict['outcome_period']=outcome_period 
result=outcome_predict[outcome_predict['outcome_period']!=999]['outcome_period']


4.保存结果


patient_id=result.index
patient_surv_month=result.values
for i,x in zip(patient_id,patient_surv_month):
    print('{}号患者死亡的时间为{}个月时。'.format(i,x))
#这里的时间计算开始是从患者入组时间开始算起,不是当下日期。


20号患者死亡的时间为235个月时。
38号患者死亡的时间为198个月时。
89号患者死亡的时间为235个月时。
96号患者死亡的时间为235个月时。
98号患者死亡的时间为235个月时。
100号患者死亡的时间为235个月时。
102号患者死亡的时间为235个月时。
112号患者死亡的时间为235个月时。
117号患者死亡的时间为198个月时。
131号患者死亡的时间为198个月时。
137号患者死亡的时间为193个月时。
155号患者死亡的时间为235个月时。
157号患者死亡的时间为235个月时。
173号患者死亡的时间为235个月时。
190号患者死亡的时间为180个月时。
198号患者死亡的时间为235个月时。
203号患者死亡的时间为196个月时。
210号患者死亡的时间为235个月时。
223号患者死亡的时间为235个月时。
224号患者死亡的时间为235个月时。
226号患者死亡的时间为235个月时。
228号患者死亡的时间为196个月时。
229号患者死亡的时间为235个月时。
247号患者死亡的时间为180个月时。
281号患者死亡的时间为207个月时。
282号患者死亡的时间为198个月时。


6.预测个案


加载存储的模型,然后进行预测

def survival_time(model,patient):
    chf_funcs=model.predict_cumulative_hazard_function(patient)
    for fn in chf_funcs:#
        if fn(va_times[-1])<1:#在最后的预测时间内死亡全部累计概率不到0.6
            time_value=999
            print('该患者在241个月内未预测到因疾病原因的死亡')
        else:
            for time in va_times:
                if fn(time)>1:
                    time_value=time#发生结局的最短时间
                    break
            print('该患者预测在{}月时因疾病原因死亡'.format(time))


#加载储存的模型
model=joblib.load('rsf_best.pkl')
#输入患者数据,我们这里加载了20号患者,可以看到和前面的批量预测是一致的。
patient=data_pre[data_pre.index==20]
print(patient)
#预测死亡时间
survival_time(model,patient)


age anaemia  creatinine_phosphokinase diabetes  ejection_fraction  \
20  65.0       1                        52        0                 25   
   high_blood_pressure  platelets  serum_creatinine  serum_sodium sex smoking  
20                   1   276000.0               1.3           137   0       0  
该患者预测在235月时因疾病原因死亡

image.png


目录
相关文章
|
6天前
|
机器学习/深度学习 人工智能 边缘计算
大模型:引领人工智能新纪元的引擎
大模型:引领人工智能新纪元的引擎
|
6天前
|
人工智能 安全 网络安全
欧盟《人工智能法案》对通用AI模型的监管要求
【2月更文挑战第24天】欧盟《人工智能法案》对通用AI模型的监管要求
111 2
欧盟《人工智能法案》对通用AI模型的监管要求
|
6天前
|
机器学习/深度学习 人工智能 自然语言处理
人工智能与文本生成:基于Transformer的文本生成模型
人工智能与文本生成:基于Transformer的文本生成模型
140 0
|
6天前
|
机器学习/深度学习 人工智能 自然语言处理
人工智能基础——模型部分:模型介绍、模型训练和模型微调 !!
人工智能基础——模型部分:模型介绍、模型训练和模型微调 !!
172 0
|
6天前
|
机器学习/深度学习 人工智能 自然语言处理
人工智能大模型引领智能时代的革命
随着AI技术的飞速发展,人工智能大模型正成为推动社会进步和经济发展的重要力量,比如GPT-3、BERT和其他深度学习架构,正在开启一个全新的智能时代。在人机交互、计算范式和认知协作三个领域,大模型带来了深刻的变革。那么本文就来分享一下关于大模型如何提升人机交互的自然性和智能化程度,以及它们如何影响现有的计算模式并推动新一代计算技术的演进,并探讨这些变革对未来的意义。
50 1
人工智能大模型引领智能时代的革命
|
1天前
|
机器学习/深度学习 人工智能 算法
食物识别系统Python+深度学习人工智能+TensorFlow+卷积神经网络算法模型
食物识别系统采用TensorFlow的ResNet50模型,训练了包含11类食物的数据集,生成高精度H5模型。系统整合Django框架,提供网页平台,用户可上传图片进行食物识别。效果图片展示成功识别各类食物。[查看演示视频、代码及安装指南](https://www.yuque.com/ziwu/yygu3z/yhd6a7vai4o9iuys?singleDoc#)。项目利用深度学习的卷积神经网络(CNN),其局部感受野和权重共享机制适于图像识别,广泛应用于医疗图像分析等领域。示例代码展示了一个使用TensorFlow训练的简单CNN模型,用于MNIST手写数字识别。
14 3
|
6天前
|
机器学习/深度学习 人工智能 自然语言处理
什么是人工智能模型的泛化能力
什么是人工智能模型的泛化能力
14 2
|
6天前
|
存储 人工智能 算法
【论文阅读-问答】人工智能生成内容增强的甲状腺结节计算机辅助诊断模型:CHATGPT风格的助手
【论文阅读-问答】人工智能生成内容增强的甲状腺结节计算机辅助诊断模型:CHATGPT风格的助手
16 6
|
6天前
|
机器学习/深度学习 存储 人工智能
人工智能平台PAI产品使用合集之是否可以在模型部署发布后以http接口形式提供给业务开发人员使用
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。
|
6天前
|
机器学习/深度学习 人工智能 NoSQL
人工智能平台PAI产品使用合集之机器学习PAI EasyRec训练时,怎么去除没有意义的辅助任务的模型,用于部署
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。

热门文章

最新文章