R语言逻辑回归、决策树、随机森林、神经网络预测患者心脏病数据混淆矩阵可视化(上):https://developer.aliyun.com/article/1498749
具有固定缺陷地中海贫血的人患心脏病的可能性更高
ggplot(heartDiseaseData,aes(target, fill=target)) + ... scale_fill_manual(values=c("#97BE11","#DC1E0B"))
可以观察到仅有少数参数,如胸痛类型、性别、运动诱发心绞痛、血管数量和ST段压低,对结果有显著影响。因此,可以舍弃其他参数。
log <- glm(...
显著特征的总结
d <- heartDiseaseDa...
逻辑回归
log <- glm(...=binomial) summary(log)
log.df <- tidy...
观察表明,如果个体患有2型或3型胸痛,患心脏病的可能性更高。随着血管数量、运动诱发心绞痛、ST段压低和男性性别数值的增加,患心脏病的可能性较低。
log.df %>% mutate(term=reorder(term,estimate)) %>% ... geom_hline(yintercept=0) + coord_flip()
随着ST段压低值的增加,患心脏病的可能性降低。随着血管数量的增加,女性患心脏病的可能性降低,而男性的可能性增加。
逻辑回归
data <- d set.seed(1237) train <- sample(nrow(data), .8*nrow(data), replace = FALSE) ... #调整参数 fitControl <- trainControl(method = "repeatedcv", ... TrainSet$target <- as.factor(TrainSet$target)
gbm.ada.1 <- caret::train(target ~ ., ... metric="ROC") gbm.ada.1
ST段压低是最重要的特征,其次是胸痛类型2等等。
varImp(gbm.ada.1)
pred <- predict(gbm.ada.1,ValidSet) .... res<-caret::confusionMatrix(t...
混淆矩阵
ggplot(data = t.df, aes(x = Var2, y = pred, label=Freq)) + ... ggtitle("Logistic Regression")
随机森林
gbm.ada.1 <- caret::train(target ~ ., ... metric="ROC") gbm.ada.1
变量重要性
varImp(gbm.ada.1)
pred <- predict(gbm.ada.1,ValidSet) ... res<-caret::confusionMatrix(t, positive="Heart Disease") res
混淆矩阵
ggplot(data = t.df, aes(x = Var1, y = pred, label=Freq)) + ... ggtitle("Random Forest")
绘制决策树
gbmGrid <- expand.grid(cp=c(0.01)) fitControl <- trainControl(method = "repeatedcv", ... summaryFunction = twoClassSummary) d$target<-make.names(d$target) system.time(gbm.ada.1 <- caret::train(target ~ ., ... tuneGrid=gbmGrid))
gbm.ada.1
varImp(gbm.ada.1)
rpart.plot(gbm.ada.1$finalModel, ... nn=TRUE)
神经网络
fitControl <- trainControl(method = "repeatedcv", ... summaryFunction = twoClassSummary) gbm.ada.1 <- caret::train(target ~ ., ... metric="ROC")
gbm.ada.1
变量重要性
varImp(gbm.ada.1)
pred <- predict(gbm.ada.1,ValidSet) ... res<-caret::confusionMa...
混淆矩阵
混淆矩阵(Confusion Matrix)是用于评估分类模型性能的一种表格。它以四个不同的指标来总结模型对样本的分类结果:真阳性(True Positive, TP)、真阴性(True Negative, TN)、假阳性(False Positive, FP)和假阴性(False Negative, FN)。
ggplot(data = t.df, aes(x = Var1, y = pred, label=Freq)) + ... ggtitle("Neural Network")