机器学习实战之AdaBoost元算法

简介: 今天学习的机器学习算法不是一个单独的算法,我们称之为元算法或集成算法(Ensemble)。其实就是对其他算法进行组合的一种方式。俗话说的好:“三个臭皮匠,赛过诸葛亮”。

今天学习的机器学习算法不是一个单独的算法,我们称之为元算法或集成算法(Ensemble)。其实就是对其他算法进行组合的一种方式。俗话说的好:“三个臭皮匠,赛过诸葛亮”。集成算法有多种形式:对同一数据集,使用多个算法,通过投票或者平均等方法获得最后的预测模型;同一算法在不同设置下的集成;同一算法在多个不同实例下的集成。本文着重讲解最后一种集成算法。

bagging

如果训练集有n个样本,我们随机抽取S次,每次有放回的获取m个样本,用某个单独的算法对S个数据集(每个数据集有m个样本)进行训练,这样就可以获得S个分类器。最后通过投票箱来获取最后的结果(少数服从多数的原则)。这就是bagging方法的核心思想,如图所示。

img_115fe5b2b408f311a9d19527c62e695c.png

bagging中有个常用的方法,叫随机森林(random forest),该算法基于决策树,不仅对数据随机化,也对特征随机化。

  • 数据的随机化:应用bootstrap方法有放回地随机抽取k个新的自助样本集。
  • 特征随机化:n个特征,每棵树随机选择m个特征划分数据集。

每棵树无限生长,最后依旧通过投票箱来获取最后的结果。

boosting

boosting方法在模型选择方面和bagging一样:选择单个机器学习算法。但boosting方法是先在原数据集中训练一个分类器,然后将前一个分类器没能完美分类的数据重新赋权重(weight),用新的权重数据再训练出一个分类器,以此循环,最终的分类结果由加权投票决定。
所以:boosting是串行算法(必须依赖上一个分类器),而bagging是并行算法(可以同时进行);boosting的分类器权重不同,bagging相同(下文中详细讲解)。

boosting也有很多版本,本文只讲解AdaBoost(自适应boosting)方法的原理和代码实践。
如图所示,为AdaBoost方法的原理示意图。

  • 首先,训练样本赋权重,构成向量D(初始值相等,如100个数据,那每个数据权重为1/100)。
  • 在该数据上训练一个弱分类器并计算错误率和该分类器的权重值(alpha)。
  • 基于该alpha值重新计算权重(分错的样本权重变大,分对的权重变小)。
  • 循环2,3步,但完成给定的迭代次数或者错误阈值时,停止循环。
  • 最终的分类结果由加权投票决定。
img_1e790d5763f4c6d053a911540b9b53ec.jpe

alpha和D的计算见下图(来源于机器学习实战):

img_9954d4a542e57e4012997fe34ac3a300.png

AdaBoost方法实践

数据来源

数据通过代码创建:

from numpy import *

def loadSimpData():
    dataArr = array([[1., 2.1], [2., 1.1], [1.3, 1.], [1., 1.], [2., 1.]])
    labelArr = [1.0, 1.0, -1.0, -1.0, 1.0]
    return dataArr, labelArr
弱决策树

该数据有两个特征,我们只用一个特征进行分类(弱分类器),然后选择精度最高的分类器。

def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
    retArray = ones((shape(dataMatrix)[0],1))
    if threshIneq == 'lt':
        retArray[dataMatrix[:,dimen] <= threshVal] = -1.0
    else:
        retArray[dataMatrix[:,dimen] > threshVal] = -1.0
    return retArray

def buildStump(dataArr, labelArr, D):
    dataMat = mat(dataArr)
    labelMat = mat(labelArr).T
    m, n = shape(dataMat)
    numSteps = 10.0
    bestStump = {}
    bestClasEst = mat(zeros((m, 1)))
    minError = inf
    for i in range(n):
        rangeMin = dataMat[:, i].min()
        rangeMax = dataMat[:, i].max()
        stepSize = (rangeMax-rangeMin)/numSteps
        for j in range(-1, int(numSteps)+1):
            for inequal in ['lt', 'gt']:
                threshVal = (rangeMin + float(j) * stepSize)
                predictedVals = stumpClassify(dataMat, i, threshVal, inequal)
                # print predictedVals
                errArr = mat(ones((m, 1)))
                errArr[predictedVals == labelMat] = 0
                weightedError = D.T*errArr
#                 print("split: dim %d, thresh %.2f, thresh ineqal: %s, the weighted error is %.3f" % (i, threshVal, inequal, weightedError))
                if weightedError < minError:
                    minError = weightedError
                    bestClasEst = predictedVals.copy()
                    bestStump['dim'] = i
                    bestStump['thresh'] = threshVal
                    bestStump['ineq'] = inequal
    return bestStump, minError, bestClasEst
AdaBoost算法

该函数用于构造多棵树,并保存每棵树的信息。

def adaBoostTrainDS(dataArr,classLabels, numIt=40):
    weakClassArr = []
    m = shape(dataArr)[0]
    D = mat(ones((m,1))/m)
    aggClassEst = mat(zeros((m,1)))
    for i in range(numIt):
        bestStump,error,classEst = buildStump(dataArr, classLabels, D)
        print('D:',D.T)
        alpha = float(0.5*log((1.0-error)/max(error,1e-16)))
        bestStump['alpha'] = alpha
        weakClassArr.append(bestStump)
        print('classEst:',classEst.T)
        expon = multiply(-1*alpha*mat(classLabels).T,classEst)
        D = multiply(D, exp(expon))
        D = D/D.sum()
        aggClassEst += alpha*classEst
        print('aggClassEst:',aggClassEst.T)
        aggErrors = multiply(sign(aggClassEst) != mat(classLabels).T, ones((m,1)))
        errorRate = aggErrors.sum()/m
        print('total error:',errorRate,'\n')
        if errorRate == 0:break
    return weakClassArr

算法优缺点

  • 优点:精度高
  • 缺点:容易过拟合
相关文章
|
前端开发 Java 测试技术
java通用分页(后端)
1.通用分页是什么? Java通用分页是指在Java编程语言中实现的一种通用分页功能。它通常用于在Java Web应用中展示大量数据或查询结果,并将其分页显示给用户。
388 1
|
7月前
|
人工智能 Python
2025自学编程实操指南第一课面向AI编程
2025自学编程实操指南第一课面向AI编程,第一个实践案例:贪吃蛇游戏
|
测试技术 微服务 负载均衡
微服务部署:蓝绿部署、滚动部署、灰度发布、金丝雀发布
在项目迭代的过程中,不可避免需要”上线“。上线对应着部署,或者重新部署;部署对应着修改;修改则意味着风险。 目前有很多用于部署的技术,有的简单,有的复杂;有的得停机,有的不需要停机即可完成部署。
3019 0
|
12月前
|
Python
Python 三方库下载安装
Python 三方库下载安装
156 1
|
Kubernetes 安全 开发工具
Kubernetes系统安全-准入控制(admission control)
文章详细介绍了Kubernetes中的准入控制机制,包括各种准入控制器的功能、如何创建和使用LimitRange和ResourceQuota资源,以及PodSecurityPolicy和准入控制器扩展的使用方法。
145 1
Kubernetes系统安全-准入控制(admission control)
|
存储
【 uniapp - 黑马优购 | 搜索框 】如何实现自定义搜索组件、搜索建议、搜索历史
【 uniapp - 黑马优购 | 搜索框 】如何实现自定义搜索组件、搜索建议、搜索历史
1069 0
|
前端开发 UED
深入理解CSS中的文本对齐方式:水平对齐与垂直对齐
深入理解CSS中的文本对齐方式:水平对齐与垂直对齐
717 5
|
网络协议 安全 Linux
在IntelliJ IDEA中使用固定公网地址远程SSH连接服务器环境进行开发
在IntelliJ IDEA中使用固定公网地址远程SSH连接服务器环境进行开发
370 2
|
JavaScript 前端开发 开发者
从零到一:教你如何发布自己的npm插件包
从零到一:教你如何发布自己的npm插件包
|
负载均衡 安全 网络安全