泰酷辣!探索七种常用的机器学习图型

简介: 泰酷辣!探索七种常用的机器学习图型

一、引言

随着机器学习的快速发展,对于数据的可视化和解释变得越来越重要。机器学习图是一种以图形方式呈现数据、模型和结果的有效工具。它们不仅可以帮助我们更好地理解数据,而且可以支持决策制定和模型调优。

机器学习图通过直观的可视化呈现,帮助我们从数据中提取有价值的信息。它们可以揭示数据之间的关联、趋势和模式,帮助我们识别特征重要性、异常值和模型性能。此外,机器学习图还可以加深对机器学习算法和模型的理解,并促进与他人的交流和共享。

本文旨在探索九种常见的机器学习图,分别是:神经网络结构图、热力图(Heatmap)、散点图(Scatter Plot)、决策树(Decision Tree)、SHAP图、ROC曲线和特征重要性图(Feature Importance Plot)。这些图形在机器学习中具有广泛的应用,涵盖了数据分布、特征分析、模型评估和结果展示等方面。

二、神经网络结构图

「神经网络结构图」是一种直观的图形表示,用于展示神经网络模型的层次结构和连接方式。它由输入层、隐藏层和输出层组成,每个层都包含多个节点(或称为神经元),这些节点之间通过连接权重进行信息传递。

# 安装所需的包
install.packages("neuralnet")
# 导入所需的包
library(neuralnet)
library(ggplot2)
# 创建数据集
data <- iris
# 将类别变量编码为数值
data$Species <- as.numeric(factor(data$Species))
# 数据标准化
data <- scale(data)
# 训练神经网络模型
model <- neuralnet(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, data = data, hidden = 3)
# 绘制神经网络结构图
plot(model, rep="best")

神经网络结构图的解读方法如下:

  • 层次结构:神经网络结构图按照从左到右的顺序显示各层。最左边的层是输入层,用来接收原始数据。中间的层是隐藏层,用于对数据进行转换和特征提取。最右边的层是输出层,用于产生模型的预测结果。
  • 节点和连接:每个节点代表一个特征或变量,它接收来自上一层的输入,并通过激活函数对输入进行处理后输出到下一层。节点之间的连接代表变量之间的加权连接,每个连接都有一个权重,表示该连接的重要性。
  • 节点数量:每个层中的节点数量反映了该层的复杂性和表达能力。通常情况下,隐藏层中的节点数量越多,神经网络的学习能力就越强,但也容易导致过拟合问题。
  • 连接权重:连接权重代表了每个连接的重要性和影响力。较大的权重表示该连接对输出结果的影响更大,而较小的权重则表示影响较小。
  • 激活函数:激活函数决定了节点输出的范围和非线性转换。常见的激活函数包括Sigmoid函数、ReLU函数等。

总结:神经网络结构图是一种展示神经网络模型层次结构和连接方式的图形表示。通过解读节点、连接和层次结构,可以获得对模型的理解和洞察。这种图形化的表达方式有助于直观地展示模型的复杂性、特征重要性和信息流动。

三、热力图(Heatmap)

在机器学习中,可以使用热力图来可视化特征相关性矩阵或混淆矩阵等。以下是一些示例代码和简要介绍:

  • 「可视化特征相关性矩阵」
# 创建数据集
rel <- iris[, 1:4]
# 计算特征相关性矩阵
cor_matrix <- cor(rel)
# 绘制热力图
heatmap(cor_matrix, col = colorRampPalette(c("blue", "white", "red"))(100))

热力图中的颜色强度和梯度变化来判断特征之间的相关程度。较暗或较亮的颜色表示较强的相关性,而中间色调则表示较弱的相关性。

  • 「可视化混淆矩阵」
# 将数据集分为训练集和测试集
set.seed(123)
train_index <- sample(1:nrow(iris), size = round(0.7*nrow(iris)), replace = FALSE)
train_data <- iris[train_index, ]
test_data <- iris[-train_index, ]
# 构建随机森林模型
library(randomForest)
model_rf <- randomForest(Species ~ ., data = train_data, ntree = 500)
# 输出模型结果
print(model_rf)
# 使用测试集评估模型
predicted <- predict(model_rf, newdata = test_data)
actual <- test_data$Species
# 绘制混淆矩阵
library(ggplot2)
library(ggExtra)
confusion_matrix <- table(predicted, actual)
heatmap(confusion_matrix, colors = "Blues")

热力图中的颜色强度反映了相应类别的数量或比例。较高的颜色强度表示模型对该类别的预测效果较好,较低的颜色强度表示预测效果较差。通过解读混淆矩阵的热力图,我们可以了解模型在不同类别上的预测情况,识别出可能存在的误分类或模型表现不佳的问题。

四、散点图(Scatter Plot)

# 导入必要的库
library(ggplot2)
# 示例数据集 - Iris数据集
data(iris)
# 创建散点图
ggplot(iris, aes(x = Sepal.Length, y = Sepal.Width, color = Species)) +
  geom_point() +
  labs(title = "Scatter Plot of Iris Dataset",
       x = "Sepal Length",
       y = "Sepal Width")

用于显示数据集中的样本分布,特别是在二维或三维空间中展示不同类别的样本。

五、决策树(Decision Tree)

# 导入必要的库
library(rpart)
library(rpart.plot)
# 示例数据集 - Iris数据集
data(iris)
# 拟合决策树模型
model <- rpart(Species ~ ., data = iris)
# 绘制决策树
rpart.plot(model)

绘制的决策树图形将显示节点分裂的条件(例如,使用哪个特征进行分割),以及每个叶节点的类别标签和样本数量。通过观察决策树,可以获得以下信息:

  1. 特征的重要性:决策树在顶部初始分裂的特征通常是最具有区分度的特征。观察分裂节点所使用的特征可以了解它们对模型的贡献程度。
  2. 分类规则:从决策树的根节点到叶节点的路径表示了一系列的分类规则。每个节点上的条件分裂代表一个判断条件,通过遵循不同路径,可以按照决策树的分类规则对新样本进行分类。

对于决策树的解读方法,可以考虑以下几个方面:

  1. 根据特征重要性,确定哪些特征对于模型分类起到关键作用。
  2. 根据分类规则,解读决策树每个节点的条件分裂,理解模型如何基于特征进行分类决策。
  3. 通过观察叶节点的类别标签和样本数量,了解模型对不同类别的分类结果和置信度。

需要注意的是,决策树可能会过拟合训练数据,因此在实际应用中应该进行模型评估和调优,例如通过剪枝来减少过拟合。

六、SHAP图

install.packages("xgboost")
install.packages("SHAPforxgboost")
library(xgboost)
library(SHAPforxgboost)
data("iris")
X1 = as.matrix(iris[,-5])
mod1 = xgboost::xgboost(
  data = X1, label = iris$Species, gamma = 0, eta = 1,
  lambda = 0, nrounds = 1, verbose = FALSE)
shap_values <- shap.values(xgb_model = mod1, X_train = X1)
shap_values$mean_shap_score
shap_values_iris <- shap_values$shap_score
# shap.prep() returns the long-format SHAP data from either model or
shap_long_iris <- shap.prep(xgb_model = mod1, X_train = X1)
# is the same as: using given shap_contrib
shap_long_iris <- shap.prep(shap_contrib = shap_values_iris, X_train = X1)

七、ROC曲线

# 导入库
library(randomForest)
library(pROC)
library(survival)
set.seed(1234)
trainIndex <- sample(1:nrow(gbsg), 0.8 * nrow(gbsg))
train <- gbsg[trainIndex,]
test <- gbsg[-trainIndex,]
train$status <- as.factor(train$status)
# 构建随机森林模型
rf_mod <- randomForest(status ~ ., data = train)
# 获取模型预测的概率
pred_prob <- predict(rf_mod, newdata = test, type = "prob")
# 计算真阳性率和假阳性率
roc <- pROC::roc(test$status, pred_prob[, 1])
# 绘制ROC曲线
plot(roc, main = "ROC Curve", print.auc = TRUE, auc.polygon = TRUE, grid = TRUE, legacy.axes = TRUE,col="blue")

八、特征重要性图

# 导入库
library(randomForest)
library(pROC)
library(survival)
set.seed(1234)
trainIndex <- sample(1:nrow(gbsg), 0.8 * nrow(gbsg))
train <- gbsg[trainIndex,]
test <- gbsg[-trainIndex,]
train$status <- as.factor(train$status)
# 构建随机森林模型
rf_mod <- randomForest(status ~ ., data = train)
varImpPlot(rf_mod, main = "variable importance")

目录
相关文章
|
机器学习/深度学习 人工智能 算法
机器学习算法竞赛实战--1,初见竞赛
在时代的洪流之下,各行各业都在寻求生存之道利用先进的技术完成转型则是一个很好的办法,有些企业就开始寻求人工智能的助力开始向社会征求优秀的算法解决方案,此外,在学术领域的研究者们也渴望获得企业的场景和数据用于算法研究这就催生出了各种竞赛平台。对于有志于进军机器学习相关领域从事研究或者相关工作的初学者来说竞赛是性价比极高的一个实战选择,可以说是0门槛,任何人都可以参加。
128 0
机器学习算法竞赛实战--1,初见竞赛
|
机器学习/深度学习 算法 索引
机器学习实战运用:速刷牛客5道机器学习题目
机器学习实战运用:速刷牛客5道机器学习题目
186 0
机器学习实战运用:速刷牛客5道机器学习题目
|
机器学习/深度学习
漫画:什么是机器学习?
机器学习按照方式不同主要分为三大类,有监督学习(Supervised learning)、无监督学习(Unsupervised learning)以及半监督学习(Semi-supervised learning)。
261 0
漫画:什么是机器学习?
|
机器学习/深度学习 人工智能 编解码
天池读书会又来啦,五月场,数据分析、机器学习、深度学习、神经网络通吃!
天池读书会又来啦,五月场分享主题多样,包含了数据分析、机器学习、深度学习、神经网络等方面,相信总有你想看的。
652 0
天池读书会又来啦,五月场,数据分析、机器学习、深度学习、神经网络通吃!
|
机器学习/深度学习 人工智能 算法
机器学习:从入门到晋级
什么是机器学习,为什么学习机器学习,如何学习机器学习,这篇文章都告诉给你。
1931 0
|
机器学习/深度学习 算法 数据挖掘
白话机器学习
机器学习是什么 一段程序可以看作一连串从输入到输出的过程,无论是工程师还是程序员,我们都想通过设计来完成某种功能。以做一个网页为例,要画视觉图、UI 图,以及前端后端交互图等,我们要给计算机设计一套解决具体问题的流程。
1629 0
|
机器学习/深度学习 人工智能 算法
|
机器学习/深度学习 人工智能 算法
AutoML破解深度学习寒冬论,夏粉教小白5分钟搞定机器学习建模
昨天,国内AutoML领域创业公司智铀科技发布了自动化机器学习产品“小智”,据公开数据显示,这是国内首款可私有部署的AutoML商用产品。新智元创始人杨静女士作为特邀嘉宾,在智铀科技产品发布会上对AI软硬件发展现状和趋势以及AutoML应用做了主题演讲。
1717 0
|
机器学习/深度学习 算法 数据挖掘
在一头扎进机器学习前应该知道的那些事儿
本文简单总结了机器学习的几大任务及其对应的方法,方便初学者根据自己的任务选择合适的方法。当掌握机器学习基本知识以及清楚自己所要处理的任务后,应用机器学习就不会那么难了。
6210 0