泰酷辣!探索七种常用的机器学习图型

简介: 泰酷辣!探索七种常用的机器学习图型

一、引言

随着机器学习的快速发展,对于数据的可视化和解释变得越来越重要。机器学习图是一种以图形方式呈现数据、模型和结果的有效工具。它们不仅可以帮助我们更好地理解数据,而且可以支持决策制定和模型调优。

机器学习图通过直观的可视化呈现,帮助我们从数据中提取有价值的信息。它们可以揭示数据之间的关联、趋势和模式,帮助我们识别特征重要性、异常值和模型性能。此外,机器学习图还可以加深对机器学习算法和模型的理解,并促进与他人的交流和共享。

本文旨在探索九种常见的机器学习图,分别是:神经网络结构图、热力图(Heatmap)、散点图(Scatter Plot)、决策树(Decision Tree)、SHAP图、ROC曲线和特征重要性图(Feature Importance Plot)。这些图形在机器学习中具有广泛的应用,涵盖了数据分布、特征分析、模型评估和结果展示等方面。

二、神经网络结构图

「神经网络结构图」是一种直观的图形表示,用于展示神经网络模型的层次结构和连接方式。它由输入层、隐藏层和输出层组成,每个层都包含多个节点(或称为神经元),这些节点之间通过连接权重进行信息传递。

# 安装所需的包
install.packages("neuralnet")
# 导入所需的包
library(neuralnet)
library(ggplot2)
# 创建数据集
data <- iris
# 将类别变量编码为数值
data$Species <- as.numeric(factor(data$Species))
# 数据标准化
data <- scale(data)
# 训练神经网络模型
model <- neuralnet(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, data = data, hidden = 3)
# 绘制神经网络结构图
plot(model, rep="best")

神经网络结构图的解读方法如下:

  • 层次结构:神经网络结构图按照从左到右的顺序显示各层。最左边的层是输入层,用来接收原始数据。中间的层是隐藏层,用于对数据进行转换和特征提取。最右边的层是输出层,用于产生模型的预测结果。
  • 节点和连接:每个节点代表一个特征或变量,它接收来自上一层的输入,并通过激活函数对输入进行处理后输出到下一层。节点之间的连接代表变量之间的加权连接,每个连接都有一个权重,表示该连接的重要性。
  • 节点数量:每个层中的节点数量反映了该层的复杂性和表达能力。通常情况下,隐藏层中的节点数量越多,神经网络的学习能力就越强,但也容易导致过拟合问题。
  • 连接权重:连接权重代表了每个连接的重要性和影响力。较大的权重表示该连接对输出结果的影响更大,而较小的权重则表示影响较小。
  • 激活函数:激活函数决定了节点输出的范围和非线性转换。常见的激活函数包括Sigmoid函数、ReLU函数等。

总结:神经网络结构图是一种展示神经网络模型层次结构和连接方式的图形表示。通过解读节点、连接和层次结构,可以获得对模型的理解和洞察。这种图形化的表达方式有助于直观地展示模型的复杂性、特征重要性和信息流动。

三、热力图(Heatmap)

在机器学习中,可以使用热力图来可视化特征相关性矩阵或混淆矩阵等。以下是一些示例代码和简要介绍:

  • 「可视化特征相关性矩阵」
# 创建数据集
rel <- iris[, 1:4]
# 计算特征相关性矩阵
cor_matrix <- cor(rel)
# 绘制热力图
heatmap(cor_matrix, col = colorRampPalette(c("blue", "white", "red"))(100))

热力图中的颜色强度和梯度变化来判断特征之间的相关程度。较暗或较亮的颜色表示较强的相关性,而中间色调则表示较弱的相关性。

  • 「可视化混淆矩阵」
# 将数据集分为训练集和测试集
set.seed(123)
train_index <- sample(1:nrow(iris), size = round(0.7*nrow(iris)), replace = FALSE)
train_data <- iris[train_index, ]
test_data <- iris[-train_index, ]
# 构建随机森林模型
library(randomForest)
model_rf <- randomForest(Species ~ ., data = train_data, ntree = 500)
# 输出模型结果
print(model_rf)
# 使用测试集评估模型
predicted <- predict(model_rf, newdata = test_data)
actual <- test_data$Species
# 绘制混淆矩阵
library(ggplot2)
library(ggExtra)
confusion_matrix <- table(predicted, actual)
heatmap(confusion_matrix, colors = "Blues")

热力图中的颜色强度反映了相应类别的数量或比例。较高的颜色强度表示模型对该类别的预测效果较好,较低的颜色强度表示预测效果较差。通过解读混淆矩阵的热力图,我们可以了解模型在不同类别上的预测情况,识别出可能存在的误分类或模型表现不佳的问题。

四、散点图(Scatter Plot)

# 导入必要的库
library(ggplot2)
# 示例数据集 - Iris数据集
data(iris)
# 创建散点图
ggplot(iris, aes(x = Sepal.Length, y = Sepal.Width, color = Species)) +
  geom_point() +
  labs(title = "Scatter Plot of Iris Dataset",
       x = "Sepal Length",
       y = "Sepal Width")

用于显示数据集中的样本分布,特别是在二维或三维空间中展示不同类别的样本。

五、决策树(Decision Tree)

# 导入必要的库
library(rpart)
library(rpart.plot)
# 示例数据集 - Iris数据集
data(iris)
# 拟合决策树模型
model <- rpart(Species ~ ., data = iris)
# 绘制决策树
rpart.plot(model)

绘制的决策树图形将显示节点分裂的条件(例如,使用哪个特征进行分割),以及每个叶节点的类别标签和样本数量。通过观察决策树,可以获得以下信息:

  1. 特征的重要性:决策树在顶部初始分裂的特征通常是最具有区分度的特征。观察分裂节点所使用的特征可以了解它们对模型的贡献程度。
  2. 分类规则:从决策树的根节点到叶节点的路径表示了一系列的分类规则。每个节点上的条件分裂代表一个判断条件,通过遵循不同路径,可以按照决策树的分类规则对新样本进行分类。

对于决策树的解读方法,可以考虑以下几个方面:

  1. 根据特征重要性,确定哪些特征对于模型分类起到关键作用。
  2. 根据分类规则,解读决策树每个节点的条件分裂,理解模型如何基于特征进行分类决策。
  3. 通过观察叶节点的类别标签和样本数量,了解模型对不同类别的分类结果和置信度。

需要注意的是,决策树可能会过拟合训练数据,因此在实际应用中应该进行模型评估和调优,例如通过剪枝来减少过拟合。

六、SHAP图

install.packages("xgboost")
install.packages("SHAPforxgboost")
library(xgboost)
library(SHAPforxgboost)
data("iris")
X1 = as.matrix(iris[,-5])
mod1 = xgboost::xgboost(
  data = X1, label = iris$Species, gamma = 0, eta = 1,
  lambda = 0, nrounds = 1, verbose = FALSE)
shap_values <- shap.values(xgb_model = mod1, X_train = X1)
shap_values$mean_shap_score
shap_values_iris <- shap_values$shap_score
# shap.prep() returns the long-format SHAP data from either model or
shap_long_iris <- shap.prep(xgb_model = mod1, X_train = X1)
# is the same as: using given shap_contrib
shap_long_iris <- shap.prep(shap_contrib = shap_values_iris, X_train = X1)

七、ROC曲线

# 导入库
library(randomForest)
library(pROC)
library(survival)
set.seed(1234)
trainIndex <- sample(1:nrow(gbsg), 0.8 * nrow(gbsg))
train <- gbsg[trainIndex,]
test <- gbsg[-trainIndex,]
train$status <- as.factor(train$status)
# 构建随机森林模型
rf_mod <- randomForest(status ~ ., data = train)
# 获取模型预测的概率
pred_prob <- predict(rf_mod, newdata = test, type = "prob")
# 计算真阳性率和假阳性率
roc <- pROC::roc(test$status, pred_prob[, 1])
# 绘制ROC曲线
plot(roc, main = "ROC Curve", print.auc = TRUE, auc.polygon = TRUE, grid = TRUE, legacy.axes = TRUE,col="blue")

八、特征重要性图

# 导入库
library(randomForest)
library(pROC)
library(survival)
set.seed(1234)
trainIndex <- sample(1:nrow(gbsg), 0.8 * nrow(gbsg))
train <- gbsg[trainIndex,]
test <- gbsg[-trainIndex,]
train$status <- as.factor(train$status)
# 构建随机森林模型
rf_mod <- randomForest(status ~ ., data = train)
varImpPlot(rf_mod, main = "variable importance")

目录
相关文章
|
机器学习/深度学习 数据可视化 算法
机器学习-可解释性机器学习:随机森林与fastshap的可视化模型解析
机器学习-可解释性机器学习:随机森林与fastshap的可视化模型解析
1399 1
|
机器学习/深度学习 人工智能 项目管理
【机器学习】集成学习——Stacking模型融合(理论+图解)
【机器学习】集成学习——Stacking模型融合(理论+图解)
5909 1
【机器学习】集成学习——Stacking模型融合(理论+图解)
|
机器学习/深度学习 算法 数据可视化
JAMA | 机器学习中的可解释性:SHAP分析图像复刻与解读
JAMA | 机器学习中的可解释性:SHAP分析图像复刻与解读
2886 1
|
机器学习/深度学习 数据采集 监控
机器学习-特征选择:如何使用递归特征消除算法自动筛选出最优特征?
机器学习-特征选择:如何使用递归特征消除算法自动筛选出最优特征?
1836 0
|
11月前
|
机器学习/深度学习 监控 算法
机器学习在图像识别中的应用:解锁视觉世界的钥匙
机器学习在图像识别中的应用:解锁视觉世界的钥匙
1423 95
|
机器学习/深度学习 人工智能 自然语言处理
简述人工智能,及其三大学派:符号主义、连接主义、行为主义
简述人工智能,及其三大学派:符号主义、连接主义、行为主义
7441 0
简述人工智能,及其三大学派:符号主义、连接主义、行为主义
|
12月前
|
机器学习/深度学习 数据采集 算法
机器学习在医疗诊断中的前沿应用,包括神经网络、决策树和支持向量机等方法,及其在医学影像、疾病预测和基因数据分析中的具体应用
医疗诊断是医学的核心,其准确性和效率至关重要。本文探讨了机器学习在医疗诊断中的前沿应用,包括神经网络、决策树和支持向量机等方法,及其在医学影像、疾病预测和基因数据分析中的具体应用。文章还讨论了Python在构建机器学习模型中的作用,面临的挑战及应对策略,并展望了未来的发展趋势。
761 1
|
10月前
|
机器学习/深度学习 算法 PyTorch
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
3192 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
|
机器学习/深度学习 并行计算 算法
机器学习算法原理:详细介绍各种机器学习算法的原理、优缺点和适用场景
机器学习算法原理:详细介绍各种机器学习算法的原理、优缺点和适用场景
4381 0
|
机器学习/深度学习 人工智能 文字识别
【学习打卡03】可解释机器学习笔记之CAM类激活热力图
【学习打卡03】可解释机器学习笔记之CAM类激活热力图