【机器学习】决策树——CART分类回归树(理论+图解+公式)

article/2025/9/24 16:46:47

🌠 『精品学习专栏导航帖』

  • 🐳最适合入门的100个深度学习实战项目🐳
  • 🐙【PyTorch深度学习项目实战100例目录】项目详解 + 数据集 + 完整源码🐙
  • 🐶【机器学习入门项目10例目录】项目详解 + 数据集 + 完整源码🐶
  • 🦜【机器学习项目实战10例目录】项目详解 + 数据集 + 完整源码🦜
  • 🐌Java经典编程100例🐌
  • 🦋Python经典编程100例🦋
  • 🦄蓝桥杯历届真题题目+解析+代码+答案🦄
  • 🐯【2023王道数据结构目录】课后算法设计题C、C++代码实现完整版大全🐯

文章目录

  • 一、概述
  • 二、CART决策树
    • 1.分类树
      • 1.1 基尼系数
      • 1.1 特征离散
      • 1.2 特征连续
    • 2.回归树
  • 三、剪枝算法


2021人工智能领域新星创作者,带你从入门到精通,该博客每天更新,逐渐完善机器学习各个知识体系的文章,帮助大家更高效学习。


一、概述

针对于ID3和C4.5只能处理分类的问题,后来有人提出了CART,该模型是由Breima等人在1984年提出的,它是被应用广泛的决策树学习方法,它可以用于分类与回归问题,同样CART也是由特征选择、树的生成以及剪枝组成。

所以针对于该算法可以分为几种情况:

数据:离散型特征、连续型特征

标签:离散值、连续值

针对于不同的场景处理方式也大不相同,一般情况下选择特征划分节点时,如果标签为离散的,我们可以使用基尼系数作为划分标准,在ID3和C4.5中是使用信息增益方式进行评估,在CART中是使用基尼系数,如果标签是连续性的,显然不能够使用基尼系数,因为此时无法计算每个节点不同类别的概率,应使用均方误差来进行评估,原来是使用每个节点的熵值期望与原来的熵做差,如果标签连续使用均方误差,每个节点的均方误差与分割前节点的均方误差做对比。

二、CART决策树

1.分类树

其实CART分类树和ID3和C4.5的树生成算法差不多,只不过是在特征选择是采用了基尼系数

1.1 基尼系数

基尼系数公式的定义如下:
G i n i ( p ) = ∑ i = k K p k ( 1 − p k ) = 1 − ∑ k = 1 K p k 2 Gini(p)=\sum_{i=k}^Kp_k(1-p_k)=1-\sum_{k=1}^Kp_k^2 Gini(p)=i=kKpk(1pk)=1k=1Kpk2

G i n i ( D ) = 1 − ∑ k = 1 K ( ∣ C k ∣ ∣ D ∣ ) 2 Gini(D)=1-\sum_{k=1}^K(\frac{|C_k|}{|D|})^2 Gini(D)=1k=1K(DCk)2

  • K:样本的类别个数
  • p k p_k pk :每个类别的概率
  • C k C_k Ck :每个类别的样本数
  • D:样本总数

所以我们需要计算根据一个特征分割后的基尼系数与分割前的基尼系数做差:

假设A特征有两个值,所以可以分成两个节点,那么分割后的基尼系数为:
G i n i ( D , A ) = p 1 G i n i ( D 1 ) + p 2 G i n i ( D 2 ) = ∣ D 1 ∣ ∣ D ∣ G i n i ( D 1 ) + ∣ D 2 ∣ ∣ D ∣ G i n i ( D 2 ) Gini(D,A)=p_1Gini(D_1)+p_2Gini(D_2)\\=\frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2) Gini(D,A)=p1Gini(D1)+p2Gini(D2)=DD1Gini(D1)+DD2Gini(D2)
我们也需要获得增益:
G i n i ( D , A ) − G i n i ( D ) Gini(D,A)-Gini(D) Gini(D,A)Gini(D)
其实这个和熵非常相似,只不过是换个衡量指标罢了。

1.1 特征离散

如果特征是离散的,那么它就是按照特征的可选值进行划分节点,该特征有几个离散值,那么就划分成几个节点,和ID3、C4.5决策树一样。

1.2 特征连续

如果特征值是连续的,划分节点时就不能够按照特征的可选数量进行分割节点,因为连续特征有很多可选值,所以肯定不能和离散特征一样的分割方式,它是采用二叉树的方式,每次按照连续特征分成两个分支,分割方式为将待分割特征的所有值从小到大排序,然后选中其中一个值作为划分点,将样本划分为两个部分。

比如说,有一列特征A,值为 [ 1 , 5 , 2 , 6 , 8 , 3 ] [1,5,2,6,8,3] [1,5,2,6,8,3]​ ,按照顺序进行排序, [ 1 , 2 , 3 , 5 , 6 , 8 ] [1,2,3,5,6,8] [1,2,3,5,6,8]​ ,所以可选的值很多,我们假设选中3作为划分点,将原始样本划分为: [ 1 , 2 , 3 ] [1,2,3] [1,2,3]​ 和 [ 5 , 6 , 8 ] [5,6,8] [5,6,8]​ 。

image-20210826145838753

按照连续型特征分割后然后在用基尼系数进行评估。

2.回归树

其实回归树就是标签为连续型的,所以此时不能够使用基尼系数、熵这种的概率评估作为评估指标,因为不是分类不能够利用古典概型求出概率,所以我们考虑使用均方误差作为特征划分的好坏,将划分后的每个节点所有样本的均方误差之和之前没划分的节点的均方误差做差来代替基尼系数。

之前分类问题是计算所有特征的信息增益,此时我们会计算每个特征按照每个划分点的均方误差:
m i n j , s [ m i n c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + m i n c 2 ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] min_{j,s}[min_{c_1}\sum_{xi\in R_1(j,s)}(y_i-c1)^2+min_{c_2}\sum_{xi\in R_2(j,s)}(y_i-c2)^2] minj,s[minc1xiR1(j,s)(yic1)2+minc2xiR2(j,s)(yic2)2]
上面的j是不同的特征,s是对应每个特征可供选择的划分点,因为一个连续特征的值很多,所以划分点很多,要选择最优的。

中括号内的意思就是找出针对特征j的最优划分点s,采用均方误差,最外层是特征,计算不同特征。

回归的比分类相对麻烦一些,分类只需要计算每个特征的信息增益,回归是计算每个特征的均方误差增益,但是它多了一个步骤就是求每个特征增益的时候还要找出最优划分值s。

这样生成的树成为最小二乘回归树。

算法流程:

  1. 选择最优切分特征j和切分点s

m i n j , s [ m i n c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + m i n c 2 ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] min_{j,s}[min_{c_1}\sum_{xi\in R_1(j,s)}(y_i-c1)^2+min_{c_2}\sum_{xi\in R_2(j,s)}(y_i-c2)^2] minj,s[minc1xiR1(j,s)(yic1)2+minc2xiR2(j,s)(yic2)2]

  1. 用选定的对(j,s)划分区域并决定相应的输出值:

R 1 ( j , s ) = { x ∣ x ( j ) ≤ s } R 2 ( j , s ) = { x ∣ x ( j ) > s } R_1(j,s)=\{x|x^{(j)}\leq s\}\quad R_2(j,s)=\{x|x^{(j)}> s\} R1(j,s)={xx(j)s}R2(j,s)={xx(j)>s}

c m = 1 N m ∑ x i ∈ R m ( j , s ) y i x ∈ R m , m = 1 , 2 c_m=\frac{1}{N_m}\sum_{x_i\in R_m(j,s)}y_i \quad x\in R_m,m=1,2 cm=Nm1xiRm(j,s)yixRm,m=1,2

第一个式子是将数据按照切分点分成两个节点,第二个是求每个节点的均方误差之和。

  1. 继续对两个子区域调用步骤1,2直至满足停止条件
  2. 将输入空间划分为M个区域, R 1 , R 2 , . . . R M R_1,R_2,...R_M R1,R2,...RM ,生成决策树:

f ( x ) = ∑ i = 1 M c m I ( x ∈ R m ) f(x)=\sum_{i=1}^Mc_mI(x\in R_m) f(x)=i=1McmI(xRm)

该式子的意思是求分到相同节点的均值作为预测值,后面的指示函数作为划分到那么区域。

三、剪枝算法

同样针对于CART决策树也存在防止过拟合的方法剪枝,CART剪枝算法由两步组成,首先从生成算法产生的决策树 T 0 T_0 T0 底端开始不断剪枝,直到 T 0 T_0 T0 的根节点,形成一个子树序列 { T 0 , T 1 , . . . , T n } \{T_0,T_1,...,T_n\} {T0,T1,...,Tn} ,然后通过交叉验证法在独立的验证数据集熵对于子树序列进行测试,从中选择最优子树。

我们定义树模型的损失函数为:
C a ( T ) = C ( T ) + a ∣ T ∣ C_a(T)=C(T)+a|T| Ca(T)=C(T)+aT
其中 C ( T ) C(T) C(T) 为模型的预测误差(基尼系数、熵信息增益等), a ∣ T ∣ a|T| aT 代表模型的复杂度,其中 ∣ T ∣ |T| T 代表模型叶节点的个数,所以 C a ( T ) C_a(T) Ca(T) 可以作为树的整体损失,参数 a用于权衡训练数据的拟合程度与模型的复杂度。

取两个极端情况,如果a=0,那么此时的树是最茂盛的,如果a趋于无穷大,那么此时的树就为一个根节点,所以随着a的增大,我们的树会不断变小。

首先对 T 0 T_0 T0​ 的任意内部节点t,以t为单节点树的损失函数为:
C a ( t ) = C ( t ) + a C_a(t)=C(t)+a Ca(t)=C(t)+a
因为此时只有一个叶子节点。

以t为根节点的子树 T t T_t Tt 的损失函数为:
C a ( T t ) = C ( T t ) + a ∣ T t ∣ C_a(T_t)=C(T_t)+a|T_t| Ca(Tt)=C(Tt)+aTt
当a=0时,有:
C a ( T t ) < C a ( t ) C_a(T_t)<C_a(t) Ca(Tt)<Ca(t)
因为此时此时过拟合,很显然可以看出,当a增大时,存在a使得:
C a ( T t ) = C a ( t ) C_a(T_t)=C_a(t) Ca(Tt)=Ca(t)
此时我们认为t节点和以该节点为根节点的子树损失值相同,损失同等情况下,我们选择复杂度小的t,所以进行剪枝,将t作为叶子节点。

此时的a为:
g ( t ) = a = C ( t ) − C ( T t ) ∣ T t ∣ − 1 g(t)=a=\frac{C(t)-C(T_t)}{|T_t|-1} g(t)=a=Tt1C(t)C(Tt)
T 0 T_0 T0 中减去 g ( t ) g(t) g(t) 最小的 T ( t ) T(t) T(t) ,将得到的子树作为 T 1 T_1 T1 ,同时将最小的 g ( t ) g(t) g(t) 设为 a 1 a_1 a1 ,如此剪枝下去,直至得到根节点,然后利用独立的验证数据集取交叉验证获得的子树序列 T 0 , T 1 , . . . T n T_0,T_1,...T_n T0,T1,...Tn 获得最优决策树 T a T_a Ta ,其中每个决策子树对应一个 a。

写在最后

        大家好,我是阿光,觉得文章还不错的话,记得“一键三连”哦!!!

img


http://chatgpt.dhexx.cn/article/n9iLbri2.shtml

相关文章

CART树(分类回归树)

主要内容 &#xff08;1&#xff09;CART树简介 &#xff08;2&#xff09;CART树节点分裂规则 &#xff08;3&#xff09;剪枝 --------------------------------------------------------------------------------------------------------------------- 一、简介 CART…

CART树

算法概述 CART(Classification And Regression Tree)算法是一种决策树分类方法。 它采用一种二分递归分割的技术&#xff0c;分割方法采用基于最小距离的基尼指数估计函数&#xff0c;将当前的样本集分为两个子样本集&#xff0c;使得生成的的每个非叶子节点都有两个分支。因此…

Pytorch之view,reshape,resize函数

对于深度学习中的一下数据&#xff0c;我们通常是要变成tensor格式&#xff0c;并且需要对其调整形状&#xff0c;很多时候我们往往只关注view之后的结果&#xff08;比如输出的尺寸&#xff09;&#xff0c;而不关心过程。但有时候还是要关注一下这个到底是怎么变换过来的&…

OpenCV-Python图像处理:插值方法及使用resize函数进行图像缩放

☞ ░ 前往老猿Python博客 https://blog.csdn.net/LaoYuanPython ░ 图像缩放用于对图像进行缩小或扩大&#xff0c;当图像缩小时需要对输入图像重采样去掉部分像素&#xff0c;当图像扩大时需要在输入图像中根据算法生成部分像素&#xff0c;二者都会利用插值算法来实现。 一…

vector的resize函数和reserve函数

博客原文&#xff1a;C基础篇 -- vector的resize函数和reserve函数_VampirEM_Chosen_One的博客-CSDN博客&#xff0c;写的特别好&#xff0c;谢谢原博主。 正文&#xff1a; 对于C的vector容器模板类&#xff0c;存在size和capacity这样两个概念&#xff0c;可以分别通过vect…

OpenCV 图片尺寸缩放——resize函数

文章目录 OpenCV中的缩放&#xff1a;resize函数代码案例 OpenCV中的缩放&#xff1a; 如果要放大或缩小图片的尺寸&#xff0c;可以使用OpenCV提供的两种方法&#xff1a; resize函数&#xff0c;是最直接的方式&#xff1b;pyrUp&#xff0c;pyrDown函数&#xff0c;即图像…

OpenCV的resize函数优化

背景 在使用OpenCV做图像处理的时候&#xff0c;最常见的问题是c版本性能不足&#xff0c;以resize函数为例来说明&#xff0c;将size为[864,1323,3]的函数缩小一半&#xff1a; Mat img0;gettimeofday(&t4, NULL);cv::resize(source, img0, cv::Size(cols_out,rows_out)…

C++ | resize函数的用法

最近在leetcode用C刷题&#xff0c;刚遇到一题需要给重新弄一个容器&#xff0c;并给容器初始化。新建容器要在private类中声明resize函数用来初始化preSum容器。 resize函数是C中序列式容器的一个共性函数&#xff0c;vv.resize(int n,element)表示调整容器vv的大小为n&#x…

opencv的resize函数

一、Opencv官方文档中resize的描述&#xff1a; resize Resizes an image. C: void resize(InputArray src, OutputArray dst, Size dsize, double fx0, double fy0, int interpolationINTER_LINEAR ) Python: cv2.resize(src, dsize[, dst[, fx[, fy[, interpolation]]]]) …

resize()函数

resize()&#xff0c;设置大小&#xff08;size&#xff09;; reserve()&#xff0c;设置容量&#xff08;capacity&#xff09;; size()是分配容器的内存大小&#xff0c;而capacity()只是设置容器容量大小&#xff0c;但并没有真正分配内存。 打个比方&#xff1a;正在建造…

OpenCV 图像缩放:cv.resize() 函数详解

目录 系列前言API函数详解参数列表缩放方式其一缩放方式其二两种方式的优先级关于插值方式 扩展 —— 相关函数 系列前言 这个系列是我第一个想要更下去的系列。每篇会全面介绍一个 OpenCV 函数&#xff0c;会给出 API 和示例。示例主要是用 Python 去写&#xff0c;但是 Open…

安卓中的几种线程间通信方式

一&#xff1a;Handler实现线程间的通信 andriod提供了 Handler 和 Looper 来满足线程间的通信。例如一个子线程从网络上下载了一副图片&#xff0c;当它下载完成后会发送消息给主线程&#xff0c;这个消息是通过绑定在主线程的Handler来传递的。 在Android&#xff0c;这里的…

Java中的线程通信的几种方式

Java中的线程间通信是指不同线程之间相互协作&#xff0c;以完成一些复杂的任务或实现某些功能的过程。线程间通信主要包括两个方面&#xff1a;线程之间的互斥和同步&#xff0c;以及线程之间的数据共享和通信。Java提供了多种方式来实现线程间通信&#xff0c;本文将介绍Java…

创建线程的四种方式 线程通信

文章目录 1.1 创建线程1.1.1 创建线程的四种方式1.1.2 Thread类与Runnable接口的比较1.1.3 Callable、Future与FutureTask 1.2 线程组和线程优先级1.3 Java线程的状态及主要转化方法1.4 Java线程间的通信1.4.1 等待/通知机制1.4.2 信号量1.4.3 管道 1.1 创建线程 1.1.1 创建线…

【多线程间几种通信方式】

一、使用 volatile 关键字 基于 volatile 关键字来实现线程间相互通信是使用共享内存的思想。大致意思就是多个线程同时监听一个变量&#xff0c;当这个变量发生变化的时候 &#xff0c;线程能够感知并执行相应的业务。这也是最简单的一种实现方式 代码案例 package com.han…

线程之间的通信方式

前言 我只是个搬运工&#xff0c;尊重原作者的劳动成果&#xff0c;本文来源下列文章链接&#xff1a; https://zhuanlan.zhihu.com/p/129374075 https://blog.csdn.net/jisuanji12306/article/details/86363390 线程之间为什么要通信&#xff1f; 通信的目的是为了更好的协…

Java线程间的通信方式

文章目录 线程间通信的定义一、等待—通知&#xff08;1&#xff09;等待—通知机制的相关方法&#xff1a;&#xff08;2&#xff09;注意事项&#xff1a;&#xff08;4&#xff09;notify()方法的核心原理&#xff08;5&#xff09;等待—通知机制的经典范式&#xff08;6&a…

线程间实现通信的几种方式

目录 线程通信相关概述提出问题方式一&#xff1a;使用Object类的wait() 和 notify() 方法方式二&#xff1a;Lock 接口中的 newContition() 方法返回 Condition 对象&#xff0c;Condition 类也可以实现等待/通知模式方法三&#xff1a;使用 volatile 关键字方法四&#xff1a…

线程间的通信方式

对共享数据进行更改的时候&#xff0c;先到主内存中拷贝一份到本地内存中&#xff0c;然后进行数据的更改&#xff0c;再重新将数据刷到主内存&#xff0c;这中间的过程&#xff0c;其他线程是看不到的。 1、为什么需要线程通信 线程是操作系统调度的最小单位&#xff0c;有自…

进程和线程的几种通信方式

进程之间通信的几种方式 1. 管道&#xff1a;是内核里面的一串缓存 管道传输的数据是单向的&#xff0c;若相互进行通信的话&#xff0c;需要进行创建两个管道才行的。 2. 消息队列&#xff1a; 例如&#xff0c;A进程给B进程发送消息&#xff0c;A进程把数据放在对应的消息队…