GBDT原理

article/2025/8/21 9:01:43

梯度提升树的使用

GBDT算法流程

GBDT流程

输入:训练数据集 D = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , … , ( x N , y N ) } D=\left\{\left(x_{1}, y_{1}\right),\left(x_{2}, y_{2}\right), \ldots,\left(x_{N}, y_{N}\right)\right\} D={(x1,y1),(x2,y2),,(xN,yN)}

1.初始化 f 0 ( x ) = 0 f_{0}(x) = 0 f0(x)=0

2.For m = 1 , 2 , … , M m=1,2, \ldots, M m=1,2,,M

3.针对每一个样本 ( x i , y i ) \left(x_{i}, y_{i}\right) (xi,yi),计算残差
r m , i = y i − f m − 1 ( x i ) , i = 1 , 2 , … , N r_{m, i}=y_{i}-f_{m-1}\left(x_{i}\right), i=1,2, \ldots, N rm,i=yifm1(xi),i=1,2,,N4.利用 { ( x i , r m , i ) } i = 1 , 2 , … , N \left\{\left(x_{i}, r_{m, i}\right)\right\}_{i=1,2, \ldots, N} {(xi,rm,i)}i=1,2,,N训练一个决策树(回归树),>得到 T ( x ; Θ m ) T\left(x ; \Theta_{m}\right) T(x;Θm)

5.更新 f m ( x ) = f m − 1 ( x ) + T ( x ; Θ m ) f_{m}(x)=f_{m-1}(x)+T\left(x ; \Theta_{m}\right) fm(x)=fm1(x)+T(x;Θm)

6.完成以上迭代,得到提升树 f M ( x ) = ∑ m = 1 M T ( x ; Θ m ) f_{M}(x)=\sum_{m=1}^{M} T\left(x ; \Theta_{m}\right) fM(x)=m=1MT(x;Θm)


负梯度和残差

GBDT全称:Gradient Boosting Decision Tree,即梯度提升决策树,理解为梯度提升 + 决策树。Friedman提出了利用最速下降的近似方法,利用损失函数的负梯度拟合集学习器:
− [ ∂ L ( y i , F ( x i ) ) ∂ F ( x i ) ] F ( x ) = F t − 1 ( x ) -\left[\frac{\partial L\left(y_{i}, F\left(\mathbf{x}_{\mathbf{i}}\right)\right)}{\partial F\left(\mathbf{x}_{\mathbf{i}}\right)}\right]_{F(\mathbf{x})=F_{t-1}(\mathbf{x})} [F(xi)L(yi,F(xi))]F(x)=Ft1(x)怎么理解这个近似,我们通过平方损失函数来给大家进行介绍

为了求导方便,在损失函数前面乘以1/2
L ( y i , F ( x i ) ) = 1 2 ( y i − F ( x i ) ) 2 L\left(y_{i}, F\left(\mathbf{x}_{\mathbf{i}}\right)\right)=\frac{1}{2}\left(y_{i}-F\left(\mathbf{x}_{\mathbf{i}}\right)\right)^{2} L(yi,F(xi))=21(yiF(xi))2 F ( X i ) F(X_{i}) F(Xi)求导,则有:
∂ L ( y i , F ( x i ) ) ∂ F ( x i ) = F ( x i ) − y i \frac{\partial L\left(y_{i}, F\left(\mathbf{x}_{\mathbf{i}}\right)\right)}{\partial F\left(\mathbf{x}_{\mathbf{i}}\right)}={F}\left(\mathbf{x}_{\mathbf{i}}\right)-y_{i} F(xi)L(yi,F(xi))=F(xi)yi残差是梯度的相反数,即:
r t i = y i − F t − 1 ( x ) = − [ ∂ L ( y i , F ( x i ) ) ∂ F ( x i ) ] F ( x ) = F t − 1 ( x ) r_{t i}=y_{i}-F_{t-1}(\mathbf{x})=-\left[\frac{\partial L\left(y_{i}, F\left(\mathbf{x}_{\mathbf{i}}\right)\right)}{\partial F\left(\mathbf{x}_{\mathbf{i}}\right)}\right]_{F(\mathbf{x})=F_{t-1}(\mathbf{x})} rti=yiFt1(x)=[F(xi)L(yi,F(xi))]F(x)=Ft1(x)在GBDT中使用负梯度作为残差进行拟合。


GBDT流程(回归)

GBDT是使用梯度提升的决策树(CART),CART树回归将空间划分为K个不相交的区域,并确定每个区域的输出 c k c_{k} ck,数学表达如下:
f ( X ) = ∑ k = 1 K c k I ( X ∈ R k ) f(\mathbf{X})=\sum_{k=1}^{K} c_{k} I\left(\mathbf{X} \in R_{k}\right) f(X)=k=1KckI(XRk)
在这里插入图片描述
输入:训练数据集 D = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , … , ( x N , y N ) } D=\left\{\left(x_{1}, y_{1}\right),\left(x_{2}, y_{2}\right), \ldots,\left(x_{N}, y_{N}\right)\right\} D={(x1,y1),(x2,y2),,(xN,yN)}
1.初始化: F 0 ( x ) = arg ⁡ min ⁡ h 0 ∑ i = 1 N L ( y i , h δ ( x ) ) = arg ⁡ min ⁡ c ∑ i = 1 N L ( y i , c ) ) \left.F_{0}(\mathbf{x})=\arg \min _{h_{0}} \sum_{i=1}^{N} L\left(y_{i}, h_{\delta}(\mathbf{x})\right)=\arg \min _{c} \sum_{i=1}^{N} L\left(y_{i}, c\right)\right) F0(x)=argminh0i=1NL(yi,hδ(x))=argminci=1NL(yi,c))

2.for t = 1 to T do

2.1 计算负梯度
y ~ i = − [ ∂ L ( y i , F ( x i ) ) ∂ F ( x i ) ] F ( x ) = F t − 1 ( x ) , i = 1 , 2 , ⋯ , N \tilde{y}_{i}=-\left[\frac{\partial L\left(y_{i}, F\left(\mathbf{x}_{\mathrm{i}}\right)\right)}{\partial F\left(\mathbf{x}_{\mathrm{i}}\right)}\right]_{F(\mathbf{x})=F_{t-1}(\mathbf{x})}, i=1,2, \cdots, N y~i=[F(xi)L(yi,F(xi))]F(x)=Ft1(x),i=1,2,,N2.2 拟合残差得到回归树,得到第t棵树的叶节点区域:
h t ( x ) = ∑ k = 1 K c k I ( X ∈ R t k ) h_{t}(\mathbf{x})=\sum_{k=1}^{K} c_{k} I\left(\mathbf{X} \in R_{t k}\right) ht(x)=k=1KckI(XRtk)2.3更新
F t ( x ) = F t − 1 ( x i ) + h t ( x ) = F t − 1 ( x i ) + ∑ k = 1 K c k I ( X ∈ R t k ) F_{t}(\mathbf{x})=F_{t-1}\left(\mathbf{x}_{\mathbf{i}}\right)+h_{t}(\mathbf{x})=F_{t-1}\left(\mathbf{x}_{\mathbf{i}}\right)+\sum_{k=1}^{K} c_{k} I\left(\mathbf{X} \in R_{t k}\right) Ft(x)=Ft1(xi)+ht(x)=Ft1(xi)+k=1KckI(XRtk)

3.得到加法模型: F ( x ) = ∑ t = 1 T h t ( x ) \boldsymbol{F}(\mathbf{x})=\sum_{t=1}^{T} h_{t}(\mathbf{x}) F(x)=t=1Tht(x)


GBDT流程(分类)

GBDT用于分类仍然使用CART回归树,使用softmax进行概率的映射,然后对概率的残差进行拟合
在这里插入图片描述
1.针对每个类别都先训练一个回归书,如三个类别,训练三棵树。就是比如对于样本 x i x_{i} xi为第二类,则输入三棵树分别为:( x i x_{i} xi,0),( x i x_{i} xi,1),( x i x_{i} xi,0)这其实是典型的OvR的多分类训练方式。而每棵树的训练过程就是CART的训练过程。这样,对于样本 x i x_{i} xi就得出了三棵树的预测值 F 1 ( x i ) F_{1}(x_{i}) F1(xi) F 2 ( x i ) F_{2}(x_{i}) F2(xi) F 3 ( x i ) F_{3}(x_{i}) F3(xi),模仿多分类的逻辑回归,用softmax来产生概率,以类别1为例: p 1 ( x i ) = exp ⁡ ( F 1 ( x i ) ) / ∑ l = 1 3 exp ⁡ ( F l ( x i ) ) p_{1}\left(\mathbf{x}_{\mathbf{i}}\right)=\exp \left(F_{1}\left(\mathbf{x}_{\mathbf{i}}\right)\right) / \sum_{l=1}^{3} \exp \left(F_{l}\left(\mathbf{x}_{\mathbf{i}}\right)\right) p1(xi)=exp(F1(xi))/l=13exp(Fl(xi))

2.对每个类别分别计算残差,如类别1: y ~ i 1 = 0 − p 1 ( x i ) \tilde{y}_{i 1}=0-p_{1}\left(\mathbf{x}_{\mathbf{i}}\right) y~i1=0p1(xi),类别2: y ~ i 2 = 1 − p 2 ( x i ) \tilde{y}_{i 2}=1-p_{2}\left(\mathbf{x}_{\mathbf{i}}\right) y~i2=1p2(xi),类别3: y ~ i 3 = 0 − p 3 ( x i ) \tilde{y}_{i 3}=0-p_{3}\left(\mathbf{x}_{\mathbf{i}}\right) y~i3=0p3(xi)

3.开始第二轮的训练,针对第一类输入为 ( x i , y ~ i 1 ) \left(\mathbf{x}_{\mathbf{i}}, \tilde{y}_{i 1}\right) (xi,y~i1),针对第二类输入为 ( x i , y ~ i 2 ) \left(\mathbf{x}_{\mathbf{i}}, \tilde{y}_{i 2}\right) (xi,y~i2),针对第三类输入为 ( x i , y ~ i 3 ) \left(\mathbf{x}_{\mathbf{i}}, \tilde{y}_{i 3}\right) (xi,y~i3),继续训练出三棵树。

4.重复3直到迭代M轮,就得到了最后的模型。预测的时候只要找出概率最高的即为对应的类别

GBDT原理案例举例

import numpy as np
import matplotlib.pyplot as plt
#回归时分类的极限思想
#分类的类别多到一定程度,那么就是回归
from sklearn.ensemble import GradientBoostingClassifier,GradientBoostingRegressor
from sklearn import tree# X数据:购物金额和上网时间
# y目标:14(高一),16(高三),24(大学毕业),26(工作两年)
X = np.array([[800,3],[1200,1],[1800,4],[2500,2]])
y = np.array([14,16,24,26]) gbdt = GradientBoostingRegressor(n_estimators=10)
gbdt.fit(X,y)
gbdt.predict(X)
#array([16.09207064, 17.39471376, 22.60528624, 23.90792936])

第一颗决策树,根据平均值,计算了残差[-6,-4,4,6]

plt.rcParams["font.sans-serif"] = ["Heiti TC"]
plt.figure(figsize=(9,6))
_ = tree.plot_tree(gbdt[0,0],filled=True,feature_names=["消费","上网"])

在这里插入图片描述

#计算friedman_mse
((y-y.mean())**2).mean()
#26.0
((y[:2]-y[:2].mean())**2).mean()
#1.0

value(-6,-4,6,4)是14,16,26,24和20的差,即残差
残差越小——>越好——>越准确

第二颗决策树,根据梯度提升,减少残差(残差越小,结果越好,越准确)

plt.rcParams["font.sans-serif"] = ["Heiti TC"]
plt.figure(figsize=(9,6))
_ = tree.plot_tree(gbdt[1,0],filled=True,feature_names=["消费","上网"])

在这里插入图片描述

gbdt1 = np.array([-6,-4,6,4])
#梯度提升
gbdt2 = gbdt1 - gbdt1*0.1   #learning_rate = 0.1
#array([-5.4, -3.6,  5.4,  3.6])

第三颗决策树

plt.rcParams["font.sans-serif"] = ["Heiti TC"]
plt.figure(figsize=(9,6))
_ = tree.plot_tree(gbdt[0,0],filled=True,feature_names=["消费","上网"])

在这里插入图片描述

gbdt1 = np.array([-5.4,-3.6,5.4,3.6])
#梯度提升
gbdt1  - gbdt1 *0.1 #learning_rate = 0.1
#array([-4.86, -3.24,  4.86,  3.24])

最后一棵树

plt.rcParams["font.sans-serif"] = ["Heiti TC"]
plt.figure(figsize=(9,6))
_ = tree.plot_tree(gbdt[-1,0],filled=True,feature_names=["消费","上网"])

在这里插入图片描述

#learning_rate = 0.1
gbdt = np.array([-2.325,-1.55,1.55,2.325])
#梯度提升 学习率0.1
residual = gbdt - gbdt*0.1
residual
#array([-2.0925, -1.395 ,  1.395 ,  2.0925])y - residual
#array([16.0925, 17.395 , 22.605 , 23.9075])
gbdt.predict(X)
#array([16.09207064, 17.39471376, 22.60528624, 23.90792936])

根据最后一棵树的残差,计算了算法最终的预测值
直接使用算法predict返回的值和手算一模一样


http://chatgpt.dhexx.cn/article/EUTBNR71.shtml

相关文章

GBDT总结

一:声明 本文基本转自刘建平先生的该篇文章,原文写的很好,读者可以去看看。本文中,作者将根据自己实际项目和所学结合该文章,阐述自己的观点和看法。 二:GBDT概述 GBDT也是集成学习Boosting家族的成员&a…

GBDT模型

GBDT(Gradient Boosting Decision Tree,梯度提升树)属于一种有监督的集成学习算法,与前面几章介绍的监督算法类似,同样可用于分类问题的识别和预测问题的解决。该集成算法体现了三方面的优势,分别是提升Boo…

GBDT模型详解

GBDT算法 GBDT(Gradient Boosting Decision Tree),全名叫梯度提升决策树,是一种迭代的决策树算法,又叫MART(Multiple Additive Regression Tree),它通过构造一组弱的学习器&#xf…

机器学习之集成学习:GBDT

目录 一、什么是GBDT 二、GBDT的理解 2.1、GBDT通俗解释 2.2、GBDT详解 三、GBDT的应用 3.1、二分类问题 3.2、多分类问题 3.3、回归问题 四、GBDT如何选择特征 五、GBDT优缺点 一、什么是GBDT GBDT,Gardient Boosting Decision Tree,梯度…

传统机器学习笔记7——GBDT模型详解

目录 前言一.GBDT算法1.1.Boosting1.2.GDBT1.2.1.GBDT与负梯度近似残差1.2.2.GDBT训练过程 二.梯度提升与梯度下降三.GDBT模型优缺点四.GDBT vs 随机森林 前言 上篇博文我们介绍了关于回归树模型的基本知识点,有不懂的小伙伴可以回到前面再看下,传统机器…

机器学习集成学习——GBDT(Gradient Boosting Decision Tree 梯度提升决策树)算法

系列文章目录 机器学习神经网络——Adaboost分离器算法 机器学习之SVM分类器介绍——核函数、SVM分类器的使用 机器学习的一些常见算法介绍【线性回归,岭回归,套索回归,弹性网络】 文章目录 系列文章目录 前言 一、GBDT(Gradient Boos…

移动端UI框架大比拼

1、vonic vonic是一个基于 vue.js 和 ionic 样式的 UI 框架,用于快速构建移动端单页应用 onic 依赖以下几个库,在创建 vonic 项目之前,请确保引入它们。 vue.js vue-router.js axios.js (vue.js 官方推荐的 ajax 方案) 中文文档 在线预览 …

手机/移动端的UI框架-Vant和NutUI

下面推荐2款手机/移动端的UI框架。 其实还有很多的框架,各个大厂都有UI框架。目前,找来找去,只有腾讯的移动端是setup语法写的TDesign,其他大厂,虽然都是VUE3写的,但是都还未改成setup的语法,而…

一:移动端UI框架mint-ui

官网按需引入的.babelrc写法是老的写法,配置的plugins一直报错是因为"component"后面不要[]直接跟{} 一:Mint-UI中按钮组件的使用 简介:Mint UI是基于 Vue.js 的移动端组件库.mint-ui官网链接 1.安装 // 安装 # Vue 1.x npm install mint-ui1 -S # Vue…

vue3使用的移动端UI框架,vue3.0 ui组件库

vue 3.0 项目中 element-ui form 表单元素中 仅 el-button 显示,其他不显示,如何解决? 谷歌人工智能写作项目:小发猫 在页面中引用了laydate插件,在显示的时候,字体图标一直显示不出来 typescript&#x…

值得推荐的Vue 移动端UI框架

1. Vant(支持Vue3) 是有赞前端团队基于有赞统一的规范实现的 Vue 组件库,提供了一整套 UI 基础组件和业务组件,这是我目前用过最好用的框架。 中文文档 | github地址 | 在线预览 2. Mint UI Mint UI 由饿了么前端团队推出的 M…

移动端UI框架总结

1. React Native 网站地址:React Native 中文网 使用React来编写原生应用的框架 GitHub:https://github.com/facebook/react-native 网站描述:Facebookt推出基于 React 的创建跨平台移动应用开发框架 React Native使你能够在Javascript和React的基础上获得完全一致的开发体验…

与运算()、或运算(|)、异或运算(^)的本质 及 用途,文末附加 位运算面试题

目录 一:与运算符(&)and 1、运算规则: 2、例如:3&5 3、用途: 1)判断 奇偶性 2)清零。 3)取一个数中指定位 二:或运算(|&#xff…

Python与或运算

今天碰到一道有意思的题目,看了之后发现自己对Python与或的理解还是欠缺,如下。 题目:求12…n 来源:Leetcode 如果不加限制,我们有很多方法计算该值,例如高斯公式,递归等。 我们思考下递归的解…

sql查询数据表某列的重复值并计数

查询sql为: SELECTdevice_id,count( device_id ) AS number FROMcms_sticker_member GROUP BYdevice_id HAVINGcount( device_id ) > 1 ORDER BYnumber DESC; 结果:

查询多个字段同时重复2次以上的记录的sql的次数

表数据如上图, 1.筛选 type、pid 重复的数据的次数大于等于2的 次数和对应的数据值 SELECT COUNT(*),TYPE,pid FROM AREA GROUP BY TYPE,pid HAVING COUNT(*)>2; 2.筛选 type、pid 重复的数据的次数大于等于2,并且对应的 pid和type值相反的重复的数…

sql查询、删除重复相同数据的语句或只保留一条数据

1、查询(字段1, 字段2, 字段3)全部重复相同的数据 SELECT * FROM 表 WHERE (字段1, 字段2, 字段3) IN (SELECT 字段1, 字段2, 字段3 FROM 表 GROUP BY 字段1, 字段2, 字段3 HAVING COUNT(*) > 1) ORDER BY 排序字段2、过滤(字段1, 字段…

分享SQL重复记录查询的几种方法

SQL重复记录查询的几种方法,需要的朋友可以参考一下 1、查找表中多余的重复记录,重复记录是根据单个字段(peopleId)来判断 代码如下: select * from people where peopleId in (select peopleId from people group by peo…

SQL查询重复记录

1、查找表中多余的重复记录,重复记录是根据单个字段(peopleId)来判断 select * from people where peopleId in (select peopleId from people group by peopleId having count(peopleId) > 1) 2、删除表中多余的重复记录,重…

为什么int无法转换为Double????

规律:拆、装箱和升、降级两者可以在同一条语句中进行,但是一定要先拆箱或装箱再升级或者降级。。。 一条语句中,int无法转换为Double,因为这里涉及到先升级再装箱子,拆装箱一定要在升降级前面。。。。。 一条语句中&…