随机森林算法demo python spark

简介:

关键参数

最重要的,常常需要调试以提高算法效果的有两个参数:numTrees,maxDepth。

  • numTrees(决策树的个数):增加决策树的个数会降低预测结果的方差,这样在测试时会有更高的accuracy。训练时间大致与numTrees呈线性增长关系。
  • maxDepth:是指森林中每一棵决策树最大可能depth,在决策树中提到了这个参数。更深的一棵树意味模型预测更有力,但同时训练时间更长,也更倾向于过拟合。但是值得注意的是,随机森林算法和单一决策树算法对这个参数的要求是不一样的。随机森林由于是多个的决策树预测结果的投票或平均而降低而预测结果的方差,因此相对于单一决策树而言,不容易出现过拟合的情况。所以随机森林可以选择比决策树模型中更大的maxDepth。 
    甚至有的文献说,随机森林的每棵决策树都最大可能地进行生长而不进行剪枝。但是不管怎样,还是建议对maxDepth参数进行一定的实验,看看是否可以提高预测的效果。 
    另外还有两个参数,subsamplingRate,featureSubsetStrategy一般不需要调试,但是这两个参数也可以重新设置以加快训练,但是值得注意的是可能会影响模型的预测效果(如果需要调试的仔细读下面英文吧)。

We include a few guidelines for using random forests by discussing the various parameters. We omit some decision tree parameters since those are covered in the decision tree guide. 
The first two parameters we mention are the most important, and tuning them can often improve performance: 
(1)numTrees: Number of trees in the forest. 
Increasing the number of trees will decrease the variance in predictions, improving the model’s test-time accuracy. 
Training time increases roughly linearly in the number of trees. 
(2)maxDepth: Maximum depth of each tree in the forest. 
Increasing the depth makes the model more expressive and powerful. However, deep trees take longer to train and are also more prone to overfitting. 
In general, it is acceptable to train deeper trees when using random forests than when using a single decision tree. One tree is more likely to overfit than a random forest (because of the variance reduction from averaging multiple trees in the forest). 
The next two parameters generally do not require tuning. However, they can be tuned to speed up training. 
(3)subsamplingRate: This parameter specifies the size of the dataset used for training each tree in the forest, as a fraction of the size of the original dataset. The default (1.0) is recommended, but decreasing this fraction can speed up training. 
(4)featureSubsetStrategy: Number of features to use as candidates for splitting at each tree node. The number is specified as a fraction or function of the total number of features. Decreasing this number will speed up training, but can sometimes impact performance if too low. 
We include a few guidelines for using random forests by discussing the various parameters. We omit some decision tree parameters since those are covered in the decision tree guide.

复制代码
"""
Random Forest Classification Example.
"""
from __future__ import print_function

from pyspark import SparkContext
# $example on$
from pyspark.mllib.tree import RandomForest, RandomForestModel
from pyspark.mllib.util import MLUtils
# $example off$

if __name__ == "__main__":
    sc = SparkContext(appName="PythonRandomForestClassificationExample")
    # $example on$
    # Load and parse the data file into an RDD of LabeledPoint.
    data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
    # Split the data into training and test sets (30% held out for testing)
    (trainingData, testData) = data.randomSplit([0.7, 0.3])

    # Train a RandomForest model.
    #  Empty categoricalFeaturesInfo indicates all features are continuous.
    #  Note: Use larger numTrees in practice.
    #  Setting featureSubsetStrategy="auto" lets the algorithm choose.
    model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                         numTrees=3, featureSubsetStrategy="auto",
                                         impurity='gini', maxDepth=4, maxBins=32)

    # Evaluate model on test instances and compute test error
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
    testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())
    print('Test Error = ' + str(testErr))
    print('Learned classification forest model:')
    print(model.toDebugString())

    # Save and load model
    model.save(sc, "target/tmp/myRandomForestClassificationModel")
    sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel")
    # $example off$
复制代码

 模型样子:

复制代码
TreeEnsembleModel classifier with 3 trees

  Tree 0:
    If (feature 511 <= 0.0)
     If (feature 434 <= 0.0)
      Predict: 0.0
     Else (feature 434 > 0.0)
      Predict: 1.0
    Else (feature 511 > 0.0)
     Predict: 0.0
  Tree 1:
    If (feature 490 <= 31.0)
     Predict: 0.0
    Else (feature 490 > 31.0)
     Predict: 1.0
  Tree 2:
    If (feature 302 <= 0.0)
     If (feature 461 <= 0.0)
      If (feature 208 <= 107.0)
       Predict: 1.0
      Else (feature 208 > 107.0)
       Predict: 0.0
     Else (feature 461 > 0.0)
      Predict: 1.0
    Else (feature 302 > 0.0)
     Predict: 0.0
复制代码

 
















本文转自张昺华-sky博客园博客,原文链接:http://www.cnblogs.com/bonelee/p/7204096.html,如需转载请自行联系原作者

相关文章
|
9天前
|
机器学习/深度学习 存储 算法
解锁文件共享软件背后基于 Python 的二叉搜索树算法密码
文件共享软件在数字化时代扮演着连接全球用户、促进知识与数据交流的重要角色。二叉搜索树作为一种高效的数据结构,通过有序存储和快速检索文件,极大提升了文件共享平台的性能。它依据文件名或时间戳等关键属性排序,支持高效插入、删除和查找操作,显著优化用户体验。本文还展示了用Python实现的简单二叉搜索树代码,帮助理解其工作原理,并展望了该算法在分布式计算和机器学习领域的未来应用前景。
|
25天前
|
监控 算法 安全
深度洞察内网监控电脑:基于Python的流量分析算法
在当今数字化环境中,内网监控电脑作为“守城卫士”,通过流量分析算法确保内网安全、稳定运行。基于Python的流量分析算法,利用`scapy`等工具捕获和解析数据包,提取关键信息,区分正常与异常流量。结合机器学习和可视化技术,进一步提升内网监控的精准性和效率,助力企业防范潜在威胁,保障业务顺畅。本文深入探讨了Python在内网监控中的应用,展示了其实战代码及未来发展方向。
|
1月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的眼疾识别系统实现~人工智能+卷积网络算法
眼疾识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了4种常见的眼疾图像数据集(白内障、糖尿病性视网膜病变、青光眼和正常眼睛) 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,实现用户上传一张眼疾图片识别其名称。
130 5
基于Python深度学习的眼疾识别系统实现~人工智能+卷积网络算法
|
2月前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
347 55
|
2月前
|
存储 缓存 监控
局域网屏幕监控系统中的Python数据结构与算法实现
局域网屏幕监控系统用于实时捕获和监控局域网内多台设备的屏幕内容。本文介绍了一种基于Python双端队列(Deque)实现的滑动窗口数据缓存机制,以处理连续的屏幕帧数据流。通过固定长度的窗口,高效增删数据,确保低延迟显示和存储。该算法适用于数据压缩、异常检测等场景,保证系统在高负载下稳定运行。 本文转载自:https://www.vipshare.com
132 66
|
5天前
|
监控 算法 安全
内网桌面监控软件深度解析:基于 Python 实现的 K-Means 算法研究
内网桌面监控软件通过实时监测员工操作,保障企业信息安全并提升效率。本文深入探讨K-Means聚类算法在该软件中的应用,解析其原理与实现。K-Means通过迭代更新簇中心,将数据划分为K个簇类,适用于行为分析、异常检测、资源优化及安全威胁识别等场景。文中提供了Python代码示例,展示如何实现K-Means算法,并模拟内网监控数据进行聚类分析。
28 10
|
23天前
|
存储 算法 安全
控制局域网上网软件之 Python 字典树算法解析
控制局域网上网软件在现代网络管理中至关重要,用于控制设备的上网行为和访问权限。本文聚焦于字典树(Trie Tree)算法的应用,详细阐述其原理、优势及实现。通过字典树,软件能高效进行关键词匹配和过滤,提升系统性能。文中还提供了Python代码示例,展示了字典树在网址过滤和关键词屏蔽中的具体应用,为局域网的安全和管理提供有力支持。
50 17
|
1月前
|
存储 监控 算法
员工电脑监控屏幕场景下 Python 哈希表算法的探索
在数字化办公时代,员工电脑监控屏幕是保障信息安全和提升效率的重要手段。本文探讨哈希表算法在该场景中的应用,通过Python代码例程展示如何使用哈希表存储和查询员工操作记录,并结合数据库实现数据持久化,助力企业打造高效、安全的办公环境。哈希表在快速检索员工信息、优化系统性能方面发挥关键作用,为企业管理提供有力支持。
45 20
|
27天前
|
存储 人工智能 算法
深度解密:员工飞单需要什么证据之Python算法洞察
员工飞单是企业运营中的隐性风险,严重侵蚀公司利润。为应对这一问题,精准搜集证据至关重要。本文探讨如何利用Python编程语言及其数据结构和算法,高效取证。通过创建Transaction类存储交易数据,使用列表管理订单信息,结合排序算法和正则表达式分析交易时间和聊天记录,帮助企业识别潜在的飞单行为。Python的强大功能使得从交易流水和沟通记录中提取关键证据变得更加系统化和高效,为企业维权提供有力支持。
|
3月前
|
搜索推荐 Python
利用Python内置函数实现的冒泡排序算法
在上述代码中,`bubble_sort` 函数接受一个列表 `arr` 作为输入。通过两层循环,外层循环控制排序的轮数,内层循环用于比较相邻的元素并进行交换。如果前一个元素大于后一个元素,就将它们交换位置。
155 67

热门文章

最新文章