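I. ML Olympiad - GOOD HEALTH AND WELL BEING
1. Introduction
The goal is to use machine learning to predict whether a patient has heart disease. A model trained on a large amount of patient data can pick up patterns that reflect clinical experience and flag people at risk, so that the disease is detected early and treated in time.
Competition page: www.kaggle.com/c/ml-olympi…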
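2. Data description
The data has 1 binary target variable and 21 feature variables:
- HighBP: Adults who have been told by a doctor, nurse, or other health professional that they have high blood pressure.
- HighChol: Have you ever been told by a doctor, nurse, or other health professional that your blood cholesterol is high?
- CholCheck: Cholesterol check within the past five years.
- BMI: Body Mass Index (BMI).
- Smoker: Have you smoked at least 100 cigarettes in your entire life? [Note: 5 packs = 100 cigarettes]
- Stroke: (Ever told) you had a stroke.
- Diabetes: 0 means no diabetes, 1 means prediabetes, 2 means diabetes.
- PhysActivity: Adults who reported doing physical activity or exercise during the past 30 days, other than their regular job.
- Fruits: Consume fruit 1 or more times per day.
- Veggies: Consume vegetables 1 or more times per day.
- HvyAlcoholConsump: Heavy drinkers (adult men having more than 14 drinks per week, adult women having more than 7 drinks per week).
- AnyHealthcare: Do you have any kind of health care coverage, including health insurance, prepaid plans such as HMOs, or government plans such as Medicare or Indian Health Service?
- NoDocbcCost: Was there a time in the past 12 months when you needed to see a doctor but could not because of cost?
- GenHlth: Would you say that in general your health is: (coded 1 to 5 in the data)
- MentHlth: Now thinking about your mental health, which includes stress, depression, and problems with emotions, for how many days during the past 30 days was your mental health not good?
- PhysHlth: Now thinking about your physical health, which includes physical illness and injury, for how many days during the past 30 days was your physical health not good?
- DiffWalk: Do you have serious difficulty walking or climbing stairs?
- Sex: Sex of the respondent.
- Age: Fourteen-level age category.
- Education: What is the highest grade or year of school you completed?
- Income: Annual household income from all sources. (If the patient refuses at any income level, code "Refused.")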
II. Import the required packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_style('whitegrid')
import warnings
warnings.filterwarnings("ignore")
III. Reading the data
1. Unzip the data
# Unzip the data
# !unzip data/data127426/ml-olympiad-good-health-and-well-being.zip
2. Read the data with pandas
# Read the data
sub = pd.read_csv('sample_submission.csv')
test = pd.read_csv('test.csv')
train = pd.read_csv('train.csv')
3. Inspect the data
train.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 177576 entries, 0 to 177575
Data columns (total 23 columns):
#   Column             Non-Null Count   Dtype
--- ------             --------------   -----
0   PatientID          177576 non-null  int64
1   HighBP             177576 non-null  int64
2   HighChol           177576 non-null  int64
3   CholCheck          177576 non-null  int64
4   BMI                177576 non-null  int64
5   Smoker             177576 non-null  int64
6   Stroke             177576 non-null  int64
7   Diabetes           177576 non-null  int64
8   PhysActivity       177576 non-null  int64
9   Fruits             177576 non-null  int64
10  Veggies            177576 non-null  int64
11  HvyAlcoholConsump  177576 non-null  int64
12  AnyHealthcare      177576 non-null  int64
13  NoDocbcCost        177576 non-null  int64
14  GenHlth            177576 non-null  int64
15  MentHlth           177576 non-null  int64
16  PhysHlth           177576 non-null  int64
17  DiffWalk           177576 non-null  int64
18  Sex                177576 non-null  int64
19  Age                177576 non-null  int64
20  Education          177576 non-null  int64
21  Income             177576 non-null  int64
22  target             177576 non-null  int64
dtypes: int64(23)
memory usage: 31.2 MB
train.head()
 | PatientID | HighBP | HighChol | CholCheck | BMI | Smoker | Stroke | Diabetes | PhysActivity | Fruits | ... | NoDocbcCost | GenHlth | MentHlth | PhysHlth | DiffWalk | Sex | Age | Education | Income | target |
0 | 42351 | 1 | 1 | 1 | 29 | 0 | 0 | 0 | 1 | 1 | ... | 0 | 3 | 0 | 0 | 0 | 0 | 13 | 5 | 8 | 0 |
1 | 135091 | 1 | 0 | 1 | 30 | 0 | 1 | 2 | 0 | 0 | ... | 0 | 2 | 0 | 0 | 0 | 0 | 9 | 5 | 6 | 0 |
2 | 201403 | 0 | 0 | 1 | 31 | 0 | 0 | 0 | 1 | 1 | ... | 0 | 2 | 0 | 7 | 0 | 0 | 10 | 6 | 8 | 0 |
3 | 72750 | 0 | 0 | 1 | 36 | 0 | 0 | 2 | 0 | 0 | ... | 0 | 2 | 0 | 0 | 0 | 0 | 11 | 5 | 6 | 0 |
4 | 133895 | 0 | 1 | 1 | 29 | 0 | 0 | 0 | 1 | 1 | ... | 0 | 4 | 0 | 0 | 1 | 1 | 10 | 6 | 7 | 0 |
5 rows × 23 columns
train.describe()
 | PatientID | HighBP | HighChol | CholCheck | BMI | Smoker | Stroke | Diabetes | PhysActivity | Fruits | ... | NoDocbcCost | GenHlth | MentHlth | PhysHlth | DiffWalk | Sex | Age | Education | Income | target |
count | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | ... | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 | 177576.000000 |
mean | 126899.481523 | 0.429230 | 0.423914 | 0.962180 | 28.380001 | 0.443061 | 0.040287 | 0.298244 | 0.756335 | 0.634078 | ... | 0.084505 | 2.512597 | 3.195364 | 4.252681 | 0.169021 | 0.440690 | 8.032808 | 5.048672 | 6.048233 | 0.094185 |
std | 73166.055829 | 0.494968 | 0.494178 | 0.190762 | 6.578401 | 0.496749 | 0.196632 | 0.699622 | 0.429294 | 0.481689 | ... | 0.278144 | 1.069184 | 7.426860 | 8.736637 | 0.374771 | 0.496471 | 3.053915 | 0.986419 | 2.072959 | 0.292087 |
min | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 12.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | ... | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | 1.000000 | 0.000000 |
25% | 63655.750000 | 0.000000 | 0.000000 | 1.000000 | 24.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | ... | 0.000000 | 2.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 6.000000 | 4.000000 | 5.000000 | 0.000000 |
50% | 126805.500000 | 0.000000 | 0.000000 | 1.000000 | 27.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | ... | 0.000000 | 2.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 8.000000 | 5.000000 | 7.000000 | 0.000000 |
75% | 190268.500000 | 1.000000 | 1.000000 | 1.000000 | 31.000000 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | ... | 0.000000 | 3.000000 | 2.000000 | 3.000000 | 0.000000 | 1.000000 | 10.000000 | 6.000000 | 8.000000 | 0.000000 |
max | 253680.000000 | 1.000000 | 1.000000 | 1.000000 | 98.000000 | 1.000000 | 1.000000 | 2.000000 | 1.000000 | 1.000000 | ... | 1.000000 | 5.000000 | 30.000000 | 30.000000 | 1.000000 | 1.000000 | 13.000000 | 6.000000 | 8.000000 | 1.000000 |
8 rows × 23 columns
print(train.shape)
print(test.shape)
(177576, 23)
(76104, 22)
Check for missing values
train.isnull().sum()
PatientID            0
HighBP               0
HighChol             0
CholCheck            0
BMI                  0
Smoker               0
Stroke               0
Diabetes             0
PhysActivity         0
Fruits               0
Veggies              0
HvyAlcoholConsump    0
AnyHealthcare        0
NoDocbcCost          0
GenHlth              0
MentHlth             0
PhysHlth             0
DiffWalk             0
Sex                  0
Age                  0
Education            0
Income               0
target               0
dtype: int64
test.isnull().sum()
PatientID            0
HighBP               0
HighChol             0
CholCheck            0
BMI                  0
Smoker               0
Stroke               0
Diabetes             0
PhysActivity         0
Fruits               0
Veggies              0
HvyAlcoholConsump    0
AnyHealthcare        0
NoDocbcCost          0
GenHlth              0
MentHlth             0
PhysHlth             0
DiffWalk             0
Sex                  0
Age                  0
Education            0
Income               0
dtype: int64
4. Check for duplicate rows
train.duplicated().any()
False
test.duplicated().any()
False
5. Exploratory data analysis (EDA)
train.hist(figsize=(20,12));
train[['BMI','Age', 'Income']].hist(figsize=(20,12));
# Number of distinct values in each column
train.nunique()
PatientID            177576
HighBP                    2
HighChol                  2
CholCheck                 2
BMI                      81
Smoker                    2
Stroke                    2
Diabetes                  3
PhysActivity              2
Fruits                    2
Veggies                   2
HvyAlcoholConsump         2
AnyHealthcare             2
NoDocbcCost               2
GenHlth                   5
MentHlth                 31
PhysHlth                 31
DiffWalk                  2
Sex                       2
Age                      13
Education                 6
Income                    8
target                    2
dtype: int64
# Count plots for a selection of categorical columns, laid out on a 2 x 4 grid
plt.figure(figsize=(20, 8))
cols = ['HighBP', 'HighChol', 'CholCheck', 'Smoker', 'Stroke', 'Sex', 'Education', 'target']
for i, col in enumerate(cols, start=1):
    plt.subplot(2, 4, i)
    # Pass the column as x=; recent seaborn versions take `data` as the first positional argument
    sns.countplot(x=train[col]);
# Correlation matrix
train.corr()
 | PatientID | HighBP | HighChol | CholCheck | BMI | Smoker | Stroke | Diabetes | PhysActivity | Fruits | ... | NoDocbcCost | GenHlth | MentHlth | PhysHlth | DiffWalk | Sex | Age | Education | Income | target |
PatientID | 1.000000 | 0.001398 | -0.004094 | -0.005792 | 0.003198 | -0.012735 | -0.003059 | 0.000102 | -0.013725 | -0.009701 | ... | 0.004976 | 0.011615 | 0.004878 | 0.002281 | 0.006494 | 0.001922 | -0.016703 | -0.028613 | -0.031153 | 0.004587 |
HighBP | 0.001398 | 1.000000 | 0.298993 | 0.100240 | 0.213145 | 0.097167 | 0.128931 | 0.272010 | -0.124287 | -0.039989 | ... | 0.017138 | 0.302095 | 0.056623 | 0.161690 | 0.224106 | 0.052872 | 0.343887 | -0.144162 | -0.172669 | 0.209839 |
HighChol | -0.004094 | 0.298993 | 1.000000 | 0.086977 | 0.106773 | 0.092282 | 0.093440 | 0.210004 | -0.078590 | -0.040844 | ... | 0.011461 | 0.210064 | 0.060897 | 0.122438 | 0.145454 | 0.031610 | 0.272945 | -0.071393 | -0.086426 | 0.181495 |
CholCheck | -0.005792 | 0.100240 | 0.086977 | 1.000000 | 0.035259 | -0.009770 | 0.024557 | 0.067935 | 0.003132 | 0.024665 | ... | -0.057044 | 0.047976 | -0.008637 | 0.032951 | 0.041208 | -0.021306 | 0.091361 | 0.001194 | 0.014525 | 0.044727 |
BMI | 0.003198 | 0.213145 | 0.106773 | 0.035259 | 1.000000 | 0.012823 | 0.020246 | 0.226629 | -0.148226 | -0.086939 | ... | 0.060098 | 0.241322 | 0.086540 | 0.121914 | 0.197492 | 0.042101 | -0.036413 | -0.106241 | -0.102497 | 0.051915 |
Smoker | -0.012735 | 0.097167 | 0.092282 | -0.009770 | 0.012823 | 1.000000 | 0.060902 | 0.064119 | -0.087836 | -0.078450 | ... | 0.046644 | 0.164181 | 0.089985 | 0.116181 | 0.122357 | 0.095626 | 0.121167 | -0.161809 | -0.124294 | 0.114722 |
Stroke | -0.003059 | 0.128931 | 0.093440 | 0.024557 | 0.020246 | 0.060902 | 1.000000 | 0.106447 | -0.068302 | -0.013449 | ... | 0.037732 | 0.176839 | 0.069918 | 0.148983 | 0.173380 | 0.005267 | 0.124307 | -0.074506 | -0.127133 | 0.200142 |
Diabetes | 0.000102 | 0.272010 | 0.210004 | 0.067935 | 0.226629 | 0.064119 | 0.106447 | 1.000000 | -0.122705 | -0.041532 | ... | 0.036016 | 0.305061 | 0.076245 | 0.179073 | 0.226021 | 0.030879 | 0.183597 | -0.132493 | -0.174196 | 0.181464 |
PhysActivity | -0.013725 | -0.124287 | -0.078590 | 0.003132 | -0.148226 | -0.087836 | -0.068302 | -0.122705 | 1.000000 | 0.142944 | ... | -0.060159 | -0.266612 | -0.124298 | -0.220504 | -0.254840 | 0.031316 | -0.091499 | 0.201007 | 0.200626 | -0.085003 |
Fruits | -0.009701 | -0.039989 | -0.040844 | 0.024665 | -0.086939 | -0.078450 | -0.013449 | -0.041532 | 0.142944 | 1.000000 | ... | -0.045057 | -0.104774 | -0.066113 | -0.046548 | -0.049326 | -0.092977 | 0.065622 | 0.110978 | 0.080931 | -0.020491 |
Veggies | -0.001558 | -0.061588 | -0.040526 | 0.005866 | -0.062893 | -0.031075 | -0.042665 | -0.058449 | 0.153179 | 0.255295 | ... | -0.035244 | -0.125923 | -0.059795 | -0.065585 | -0.083168 | -0.065248 | -0.010327 | 0.155088 | 0.155928 | -0.039542 |
HvyAlcoholConsump | -0.005415 | -0.001964 | -0.011583 | -0.026005 | -0.047951 | 0.101023 | -0.015058 | -0.057922 | 0.013856 | -0.034479 | ... | 0.003698 | -0.036098 | 0.024243 | -0.028085 | -0.038485 | 0.006423 | -0.036105 | 0.024861 | 0.054354 | -0.029198 |
AnyHealthcare | 0.000923 | 0.039009 | 0.041390 | 0.120153 | -0.019125 | -0.023787 | 0.007640 | 0.014328 | 0.035964 | 0.032208 | ... | -0.231749 | -0.041575 | -0.053290 | -0.009113 | 0.006787 | -0.019041 | 0.138523 | 0.122380 | 0.158234 | 0.020135 |
NoDocbcCost | 0.004976 | 0.017138 | 0.011461 | -0.057044 | 0.060098 | 0.046644 | 0.037732 | 0.036016 | -0.060159 | -0.045057 | ... | 1.000000 | 0.167472 | 0.189433 | 0.149342 | 0.118942 | -0.046204 | -0.118938 | -0.102038 | -0.202447 | 0.031585 |
GenHlth | 0.011615 | 0.302095 | 0.210064 | 0.047976 | 0.241322 | 0.164181 | 0.176839 | 0.305061 | -0.266612 | -0.104774 | ... | 0.167472 | 1.000000 | 0.301532 | 0.525179 | 0.457259 | -0.005133 | 0.152558 | -0.286244 | -0.370260 | 0.258040 |
MentHlth | 0.004878 | 0.056623 | 0.060897 | -0.008637 | 0.086540 | 0.089985 | 0.069918 | 0.076245 | -0.124298 | -0.066113 | ... | 0.189433 | 0.301532 | 1.000000 | 0.354641 | 0.235289 | -0.080406 | -0.090970 | -0.100319 | -0.208794 | 0.063413 |
PhysHlth | 0.002281 | 0.161690 | 0.122438 | 0.032951 | 0.121914 | 0.116181 | 0.148983 | 0.179073 | -0.220504 | -0.046548 | ... | 0.149342 | 0.525179 | 0.354641 | 1.000000 | 0.479627 | -0.042447 | 0.098759 | -0.157118 | -0.267023 | 0.179600 |
DiffWalk | 0.006494 | 0.224106 | 0.145454 | 0.041208 | 0.197492 | 0.122357 | 0.173380 | 0.226021 | -0.254840 | -0.049326 | ... | 0.118942 | 0.457259 | 0.235289 | 0.479627 | 1.000000 | -0.069397 | 0.204378 | -0.193109 | -0.321384 | 0.210210 |
Sex | 0.001922 | 0.052872 | 0.031610 | -0.021306 | 0.042101 | 0.095626 | 0.005267 | 0.030879 | 0.031316 | -0.092977 | ... | -0.046204 | -0.005133 | -0.080406 | -0.042447 | -0.069397 | 1.000000 | -0.027869 | 0.018549 | 0.125373 | 0.085802 |
Age | -0.016703 | 0.343887 | 0.272945 | 0.091361 | -0.036413 | 0.121167 | 0.124307 | 0.183597 | -0.091499 | 0.065622 | ... | -0.118938 | 0.152558 | -0.090970 | 0.098759 | 0.204378 | -0.027869 | 1.000000 | -0.102786 | -0.128530 | 0.221841 |
Education | -0.028613 | -0.144162 | -0.071393 | 0.001194 | -0.106241 | -0.161809 | -0.074506 | -0.132493 | 0.201007 | 0.110978 | ... | -0.102038 | -0.286244 | -0.100319 | -0.157118 | -0.193109 | 0.018549 | -0.102786 | 1.000000 | 0.448643 | -0.098432 |
Income | -0.031153 | -0.172669 | -0.086426 | 0.014525 | -0.102497 | -0.124294 | -0.127133 | -0.174196 | 0.200626 | 0.080931 | ... | -0.202447 | -0.370260 | -0.208794 | -0.267023 | -0.321384 | 0.125373 | -0.128530 | 0.448643 | 1.000000 | -0.142447 |
target | 0.004587 | 0.209839 | 0.181495 | 0.044727 | 0.051915 | 0.114722 | 0.200142 | 0.181464 | -0.085003 | -0.020491 | ... | 0.031585 | 0.258040 | 0.063413 | 0.179600 | 0.210210 | 0.085802 | 0.221841 | -0.098432 | -0.142447 | 1.000000 |
23 rows × 23 columns
# Heatmap of the correlation matrix
plt.figure(figsize=(20,12))
sns.heatmap(train.corr(), annot=True)
<matplotlib.axes._subplots.AxesSubplot at 0x7f845ce2eed0>
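A 23 × 23 heatmap is hard to scan when only the relationship with `target` matters. The following is a minimal sketch of ranking features by the absolute value of their correlation with the target. The helper name `rank_by_target_corr` is hypothetical, and the demo frame is synthetic (column names borrowed from the dataset, values invented) so that the block runs on its own; on the real data you would pass `train`.

```python
import numpy as np
import pandas as pd

def rank_by_target_corr(df, target="target", drop=("PatientID",)):
    """Return features sorted by absolute Pearson correlation with the target."""
    features = df.drop(columns=[c for c in drop if c in df.columns])
    corr = features.corr()[target].drop(target)
    # Order by magnitude, but keep the signed values for reading.
    return corr.reindex(corr.abs().sort_values(ascending=False).index)

# Synthetic stand-in for train.csv (values are invented for illustration only).
rng = np.random.default_rng(0)
n = 1000
demo = pd.DataFrame({
    "PatientID": np.arange(n),
    "GenHlth": rng.integers(1, 6, n),
    "Fruits": rng.integers(0, 2, n),
})
# Make the synthetic target depend on GenHlth only.
demo["target"] = (demo["GenHlth"] + rng.normal(0, 1, n) > 4).astype(int)

ranking = rank_by_target_corr(demo)
print(ranking)
```

Read off the `target` row of the matrix above, the strongest correlations in the real data are GenHlth (0.258), Age (0.222), DiffWalk (0.210), and HighBP (0.210). None of them is large on its own.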
train.columns
Index(['PatientID', 'HighBP', 'HighChol', 'CholCheck', 'BMI', 'Smoker', 'Stroke', 'Diabetes', 'PhysActivity', 'Fruits', 'Veggies', 'HvyAlcoholConsump', 'AnyHealthcare', 'NoDocbcCost', 'GenHlth', 'MentHlth', 'PhysHlth', 'DiffWalk', 'Sex', 'Age', 'Education', 'Income', 'target'], dtype='object')
IV. Feature selection
1. Select features
# Use every column except the ID and the label as a feature
x = train.drop(['PatientID', 'target'], axis=1)
y = train['target']
2. Train/test split
from sklearn.model_selection import train_test_split
# Hold out 30% of the training data for validation
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3, random_state=42)
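Only about 9.4% of the training rows have `target` = 1 (the mean of `target` in `train.describe()`), so a purely random split can leave the two parts with slightly different positive rates. One option, which the notebook above does not use, is the `stratify` argument of `train_test_split`. A minimal sketch on synthetic labels (the arrays `X_demo` and `y_demo` are made up for illustration):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in: 10,000 rows with roughly 9.4% positives (made-up data).
rng = np.random.default_rng(42)
X_demo = rng.normal(size=(10_000, 3))
y_demo = (rng.random(10_000) < 0.094).astype(int)

# stratify=y_demo keeps the class ratio (almost) identical in both parts.
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.3, random_state=42, stratify=y_demo
)

print("overall positive rate:", y_demo.mean())
print("train positive rate:  ", y_tr.mean())
print("test positive rate:   ", y_te.mean())
```

On the real data the equivalent call would add `stratify=y` to the split shown above. With 177,576 rows the effect is small, but it removes one source of noise when comparing models.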
V. Model training
1. Model selection
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression(solver='liblinear')
2. Fit the model
lr.fit(x_train,y_train)
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, l1_ratio=None, max_iter=100, multi_class='auto', n_jobs=None, penalty='l2', random_state=None, solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
# Baseline evaluation
from sklearn.metrics import classification_report, confusion_matrix, f1_score

y_pred = lr.predict(x_test)  # predict once and reuse the result below
print('Training score:', lr.score(x_train, y_train))
print('*' * 20)
print('Test score:', lr.score(x_test, y_test))
print('*' * 20)
print('f1_score:', f1_score(y_test, y_pred))
print('*' * 20)
print(confusion_matrix(y_test, y_pred))
print('*' * 20)
print(classification_report(y_test, y_pred))
Training score: 0.9081679444582995
********************
Test score: 0.9067820471908847
********************
f1_score: 0.19695989650711515
********************
[[47698   478]
 [ 4488   609]]
********************
              precision    recall  f1-score   support

           0       0.91      0.99      0.95     48176
           1       0.56      0.12      0.20      5097

    accuracy                           0.91     53273
   macro avg       0.74      0.55      0.57     53273
weighted avg       0.88      0.91      0.88     53273
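The 0.91 accuracy looks good, but it hides how few heart disease cases the model actually finds. As a worked check, the sketch below recomputes the metrics by hand from the four cells of the confusion matrix printed above. The variable names are only for illustration; the four counts are copied from the output.

```python
# Confusion matrix from the output above (rows: true class, columns: predicted class).
tn, fp = 47698, 478
fn, tp = 4488, 609
total = tn + fp + fn + tp

accuracy = (tn + tp) / total
precision = tp / (tp + fp)   # of the predicted positives, how many are correct
recall = tp / (tp + fn)      # of the actual positives, how many are found
f1 = 2 * precision * recall / (precision + recall)

# Accuracy of a trivial model that always predicts class 0.
majority_baseline = (tn + fp) / total

print(f"accuracy          = {accuracy:.4f}")
print(f"precision (1)     = {precision:.4f}")
print(f"recall (1)        = {recall:.4f}")
print(f"f1 (1)            = {f1:.4f}")
print(f"always-0 accuracy = {majority_baseline:.4f}")
```

The recomputed accuracy and F1 match the reported test score and `f1_score`. A model that always predicts 0 would already reach an accuracy of about 0.904, so accuracy alone says little here. The recall of 0.12 means roughly 88% of the positive cases in the hold-out set are missed.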
VI. Prediction
# Drop the ID column so the features match the ones used for training
testt = test.drop('PatientID', axis=1)
pred = lr.predict(testt)
# Build the submission DataFrame
subm = pd.DataFrame({'PatientID': sub.PatientID, 'target': pred}, index=None)
# Write it to a CSV file
subm.to_csv('first_submission.csv', index=False)
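The cell above takes the IDs from `sample_submission.csv`, which is correct only if that file lists patients in the same order as `test.csv`. A slightly more defensive sketch takes the IDs from the test frame itself and checks the lengths first. `build_submission` is a hypothetical helper, the demo values are made up, and the two-column layout (`PatientID`, `target`) is taken from the code above.

```python
import pandas as pd

def build_submission(test_df, predictions, id_col="PatientID"):
    """Pair each prediction with the ID from the same test row."""
    if len(predictions) != len(test_df):
        raise ValueError("number of predictions does not match number of test rows")
    return pd.DataFrame({
        id_col: test_df[id_col].to_numpy(),
        "target": list(predictions),
    })

# Tiny made-up example; the real call would be build_submission(test, pred).
demo_test = pd.DataFrame({"PatientID": [7, 3, 9], "HighBP": [1, 0, 1]})
demo_pred = [0, 1, 0]

demo_subm = build_submission(demo_test, demo_pred)
print(demo_subm)
```

With the real data, `build_submission(test, pred).to_csv('first_submission.csv', index=False)` would write the same kind of file as the original cell.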