LightGBM 可视化调参

简介: LightGBM 可视化调参

大家好,在100天搞定机器学习|Day63 彻底掌握 LightGBM一文中,我介绍了LightGBM 的模型原理和一个极简实例。最近我发现Huggingface与Streamlit好像更配,所以就开发了一个简易的 LightGBM 可视化调参的小工具,旨在让大家可以更深入地理解 LightGBM


网址:

https://huggingface.co/spaces/beihai/LightGBM-parameter-tuning


640.gif


只随便放了几个参数,调整这些参数可以实时看到模型评估指标的变化。代码我也放到文章中了,大家有好的优化思路可以留言。


下面就详细介绍一下实现过程:


LightGBM 的参数


在完成模型构建之后,必须对模型的效果进行评估,根据评估结果来继续调整模型的参数、特征或者算法,以达到满意的结果。


LightGBM,有核心参数,学习控制参数,IO参数,目标参数,度量参数,网络参数,GPU参数,模型参数,这里我常修改的便是核心参数,学习控制参数,度量参数等。


Control Parameters 含义 用法
max_depth 树的最大深度 当模型过拟合时,可以考虑首先降低 max_depth
min_data_in_leaf 叶子可能具有的最小记录数 默认20,过拟合时用
feature_fraction 例如 为0.8时,意味着在每次迭代中随机选择80%的参数来建树 boosting 为 random forest 时用
bagging_fraction 每次迭代时用的数据比例 用于加快训练速度和减小过拟合
early_stopping_round 如果一次验证数据的一个度量在最近的early_stopping_round 回合中没有提高,模型将停止训练 加速分析,减少过多迭代
lambda 指定正则化 0~1
min_gain_to_split 描述分裂的最小 gain 控制树的有用的分裂
max_cat_group 在 group 边界上找到分割点 当类别数量很多时,找分割点很容易过拟合时

CoreParameters 含义 用法
Task 数据的用途 选择 train 或者 predict
application 模型的用途 选择 regression: 回归时,binary: 二分类时,multiclass: 多分类时
boosting 要用的算法 gbdt, rf: random forest, dart: Dropouts meet Multiple Additive Regression Trees, goss: Gradient-based One-Side Sampling
num_boost_round 迭代次数 通常 100+
learning_rate 如果一次验证数据的一个度量在最近的 early_stopping_round 回合中没有提高,模型将停止训练 常用 0.1, 0.001, 0.003…
num_leaves
默认 31
device
cpu 或者 gpu
metric
mae: mean absolute error , mse: mean squared error , binary_logloss: loss for binary classification , multi_logloss: loss for multi classification

Faster Speed better accuracy over-fitting
将 max_bin 设置小一些 用较大的 max_bin max_bin 小一些
num_leaves 大一些 num_leaves 小一些
用 feature_fraction 来做 sub-sampling
用 feature_fraction
用 bagging_fraction 和 bagging_freq
设定 bagging_fraction 和 bagging_freq
training data 多一些 training data 多一些
用 save_binary 来加速数据加载 直接用 categorical feature 用 gmin_data_in_leaf 和 min_sum_hessian_in_leaf
用 parallel learning 用 dart 用 lambda_l1, lambda_l2 ,min_gain_to_split 做正则化
num_iterations 大一些,learning_rate 小一些 用 max_depth 控制树的深度


模型评估指标


以分类模型为例,常见的模型评估指标有一下几种:


混淆矩阵


混淆矩阵是能够比较全面的反映模型的性能,从混淆矩阵能够衍生出很多的指标来。


ROC曲线


ROC曲线,全称The Receiver Operating Characteristic Curve,译为受试者操作特性曲线。这是一条以不同阈值 下的假正率FPR为横坐标,不同阈值下的召回率Recall为纵坐标的曲线。让我们衡量模型在尽量捕捉少数类的时候,误伤多数类的情况如何变化的。


AUC


AUC(Area Under the ROC Curve)指标是在二分类问题中,模型评估阶段常被用作最重要的评估指标来衡量模型的稳定性。ROC曲线下的面积称为AUC面积,AUC面积越大说明ROC曲线越靠近左上角,模型越优;


Streamlit 实现


Streamlit我就不再多做介绍了,老读者应该都特别熟悉了。就再列一下之前开发的几个小东西:



核心代码如下,完整代码我放到Github,欢迎大家给个Star


https://github.com/tjxj/visual-parameter-tuning-with-streamlit

from definitions import *
st.set_option('deprecation.showPyplotGlobalUse', False)
st.sidebar.subheader("请选择模型参数:sunglasses:")
# 加载数据
breast_cancer = load_breast_cancer()
data = breast_cancer.data
target = breast_cancer.target
# 划分训练数据和测试数据
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)
# 转换为Dataset数据格式
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
# 模型训练
params = {'num_leaves': num_leaves, 'max_depth': max_depth,
            'min_data_in_leaf': min_data_in_leaf, 
            'feature_fraction': feature_fraction,
            'min_data_per_group': min_data_per_group, 
            'max_cat_threshold': max_cat_threshold,
            'learning_rate':learning_rate,'num_leaves':num_leaves,
            'max_bin':max_bin,'num_iterations':num_iterations
            }
gbm = lgb.train(params, lgb_train, num_boost_round=2000, valid_sets=lgb_eval, early_stopping_rounds=500)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)  
probs = gbm.predict(X_test, num_iteration=gbm.best_iteration)  # 输出的是概率结果  
fpr, tpr, thresholds = roc_curve(y_test, probs)
st.write('------------------------------------')
st.write('Confusion Matrix:')
st.write(confusion_matrix(y_test, np.where(probs > 0.5, 1, 0)))
st.write('------------------------------------')
st.write('Classification Report:')
report = classification_report(y_test, np.where(probs > 0.5, 1, 0), output_dict=True)
report_matrix = pd.DataFrame(report).transpose()
st.dataframe(report_matrix)
st.write('------------------------------------')
st.write('ROC:')
plot_roc(fpr, tpr)


上传Huggingface


Huggingface 前一篇文章(腾讯的这个算法,我搬到了网上,随便玩!)我已经介绍过了,这里就顺便再讲一下步骤吧。


step1:注册Huggingface账号

step2:创建Space,SDK记得选择Streamlit


640.png


step3:克隆新建的space代码,然后将改好的代码push上去


git lfs install 
git add .
git commit -m "commit from $beihai"
git push


push的时候会让输入用户名(就是你的注册邮箱)和密码,解决git总输入用户名和密码的问题:git config --global credential.helper store


push完成就大功告成了,回到你的space页对应项目,就可以看到效果了。

640.gif


https://huggingface.co/spaces/beihai/LightGBM-parameter-tuning


相关文章
|
人工智能 Linux 测试技术
NexaAI, 一行命令运行魔搭社区模型,首次在设备上运行 Qwen2-Audio
Qwen2-Audio是一个 70亿参数量 SOTA 多模态模型,可处理音频和文本输入。
1130 8
|
存储 设计模式 前端开发
Streamlit应用中构建多页面(三):两种方案(上)
Streamlit应用中构建多页面(三):两种方案
4493 0
|
9月前
|
人工智能 并行计算 算法
企业内训|智能驾驶与智能座舱技术——某汽车厂商
本课程系统讲解智能汽车两大核心领域技术架构与实现路径。课程涵盖智能驾驶感知层(激光雷达/毫米波雷达/视觉融合)、决策规划(A*/RRT算法与端到端模型)及高精地图定位(SLAM与无图方案),解析智能座舱系统演化(IVI/AR-HUD多屏交互)及硬件软件架构(高通芯片选型/QNX/鸿蒙车机)。
273 15
|
11月前
|
数据可视化 固态存储 图形学
解锁3D创作新姿势!Autodesk 3ds Max 2022中文版安装教程(附官方下载渠道)
Autodesk 3ds Max 2022 是一款专业三维建模、动画和渲染软件,广泛应用于影视、游戏、建筑等领域。其特点包括智能建模工具、高效Arnold渲染引擎、跨平台协作及多语言支持。安装需满足Win10/11系统、i5以上处理器、8GB内存等要求。正版安装流程包括下载官方程序、配置组件、激活许可证并验证功能。常见问题如安装失败、中文乱码等提供了解决方案。扩展学习资源推荐Forest Pack、V-Ray等插件,助力用户深入掌握软件功能。
2694 24
|
12月前
|
监控 Java 计算机视觉
Python图像处理中的内存泄漏问题:原因、检测与解决方案
在Python图像处理中,内存泄漏是常见问题,尤其在处理大图像时。本文探讨了内存泄漏的原因(如大图像数据、循环引用、外部库使用等),并介绍了检测工具(如memory_profiler、objgraph、tracemalloc)和解决方法(如显式释放资源、避免循环引用、选择良好内存管理的库)。通过具体代码示例,帮助开发者有效应对内存泄漏挑战。
609 1
|
人工智能 vr&ar
TRELLIS:微软联合清华和中科大推出的高质量 3D 生成模型,支持局部控制和多种输出格式
TRELLIS 是由微软、清华大学和中国科学技术大学联合推出的高质量 3D 生成模型,能够根据文本或图像提示生成多样化的 3D 资产,支持多种输出格式和灵活编辑。
1054 3
TRELLIS:微软联合清华和中科大推出的高质量 3D 生成模型,支持局部控制和多种输出格式
|
前端开发 NoSQL JavaScript
Websocket 替代方案:如何使用 Firestore 监听实时事件
Websocket 替代方案:如何使用 Firestore 监听实时事件
|
机器学习/深度学习 人工智能 开发者
谷歌推世界首个AI游戏引擎,2000亿游戏产业恐颠覆!0代码生成游戏,老黄预言成真
【9月更文挑战第22天】谷歌近日推出的AI游戏引擎GameNGen,作为全球首款神经模型驱动的游戏引擎,引发了广泛关注。该引擎使用户无需编写代码即可生成游戏,并实现了与复杂环境的实时交互,显著提升了模拟质量。在单TPU上,GameNGen能以超20帧/秒的速度流畅模拟经典游戏《DOOM》。这项技术不仅简化了游戏开发流程,降低了成本,还为游戏设计带来了更多可能性。然而,它也可能改变游戏产业的商业模式和创意多样性。无论如何,GameNGen标志着游戏开发领域的一次重大革新。
405 2
|
存储 传感器 物联网
MQTT 客户端和代理连接如何工作?
MQTT 客户端和代理连接如何工作?
650 2
MQTT 客户端和代理连接如何工作?
|
机器学习/深度学习 存储 人工智能
手推公式:LSTM单元梯度的详细的数学推导
手推公式:LSTM单元梯度的详细的数学推导
473 0
手推公式:LSTM单元梯度的详细的数学推导