【机器学习】Python实现决策树的预剪枝与后剪枝

article/2025/9/16 1:12:43

决策树是一种用于分类和回归任务的非参数监督学习算法。它是一种分层树形结构,由根节点、分支、内部节点和叶节点组成。

95ac9de7d787b98fcee1afa427c05911.png

从上图中可以看出,决策树从根节点开始,根节点没有任何传入分支。然后,根节点的传出分支为内部节点(也称为决策节点)提供信息。两种节点都基于可用功能执行评估以形成同类子集,这些子集由叶节点或终端节点表示。叶节点表示数据集内所有可能的结果。

决策树的类型

Hunt 算法于 20 世纪 60 年代提出,起初用于模拟心理学中的人类学习,为许多常用的决策树算法奠定了基础,例如:

  • ID3:该算法的开发归功于 Ross Quinlan,全称为"迭代二叉树 3 代" ("Iterative Dichotomiser 3")。该算法利用信息熵与信息增益作为评估候选拆分的指标。

  • C4.5:该算法是 ID3 的后期扩展,同样由 Quinlan 开发。它可以使用信息增益或增益率来评估决策树中的切分点。

  • CART:术语 "CART" 的全称是"分类和回归",提出者是 Leo Breiman。该算法通常利用"基尼不纯度"来确定要拆分的理想属性。基尼不纯度衡量随机选择的属性被错误分类的频率。使用该评估方法时,基尼不纯度越小越理想。

决策树的构建

详细的构建过程可以参考:决策树的构建原理

案例数据集准备

1f8499feb5a86fb35e043da9d29e25ce.png
泰坦尼克号数据集
eaa94de4e9e7ff65e5da6badead15ac4.png
数据处理后的数据集
616c9cf7dee6b46b3076832b52a7d898.png
幸存者统计

决策树构建及可视化

# 定于预测目标变量名
Target = ["Survived"]
## 定义模型的自变量名
train_x = ["Pclass", "Sex", "SibSp", "Parch","Embarked", "Age_band","re"]
##将训练集切分为训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(data[train_x], data[Target],test_size = 0.25,random_state = 1)
## 先使用默认的参数建立一个决策树模型
dtc1 = DecisionTreeClassifier(random_state=1)
## 使用训练数据进行训练
dtc1 = dtc1.fit(X_train, y_train)
## 输出其在训练数据和验证数据集上的预测精度
dtc1_lab = dtc1.predict(X_train)
dtc1_pre = dtc1.predict(X_val)## 将获得的决策树结构可视化
dot_data = StringIO()
export_graphviz(dtc1, out_file=dot_data,feature_names=X_train.columns,filled=True, rounded=True,special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) 
Image(graph.create_png())
e8fd574596f51d999e80b3007ce5edef.png
未剪枝决策树

观察上图所示的模型结构可以发现,该模型是非常复杂的决策树模型,而且决策树的层数远远超过了10层,从而使用该决策树获得的规则会非常的复杂。通过模型的可视化进一步证明了获得的决策树模型具有严重的过拟合问题,需要对模型进行剪枝,精简模型。

模型在训练集上有74个错误样本,而在测试集上存在50个错误样本。

0963202cc33ddbf26597dd974830a623.png
训练数据集混淆矩阵
3955072246b181375a051762adc3b6dc.png
测试数据集混淆矩阵

观察图1所示的模型结构可以发现,该模型是非常复杂的决策树模型,而且决策树的层数远远超过了10层,从而使用该决策树获得的规则会非常的复杂。通过模型的可视化进一步证明了获得的决策树模型具有严重的过拟合问题,需要对模型进行剪枝,精简模型。

决策树的过拟合问题

决策树学习采用"一一击破"的策略,执行贪心搜索 (greedy search) 来识别决策树内的最佳分割点。然后以自上而下的回归方式重复此拆分过程,直到所有或者大多数记录都标记为特定的类别标签。是否将所有数据点归为同类子集在很大程度上取决于决策树的复杂性。较小的决策树更容易获得无法分裂的叶节点,即单个类别中的数据点。然而,决策树的体量越来越大,就会越来越难保持这种纯度,并且通常会导致落在给定子树内的数据过少。这种情况被称为数据碎片,通常会引起数据过拟合。

因此通常选择小型决策树,这与奥卡姆剃刀原理的"简单有效原理"相符,即"如无必要,勿增实体"。换句话说,我们应该只在必要时增加决策树的复杂性,因为最简单的往往是最好的。为了降低复杂性并防止过拟合,通常采用剪枝算法;这一过程会删除不太重要的特征的分支。然后,通过交叉验证评估模型的拟合。另一种保持决策树准确性的方法是使用随机森林算法形成一个集合;这种分类法可以得到更加准确的预测结果,特别是在决策树分支彼此不相关的情况下。

决策树的剪枝

决策树的剪枝有两种思路:

  • 预剪枝(Pre-Pruning)

  • 后剪枝(Post-Pruning)

预剪枝(Pre-Pruning)

预剪枝就是在构造决策树的过程中,先对每个结点在划分前进行估计,如果当前结点的划分不能带来决策树模型泛化性能的提升,则不对当前结点进行划分并且将当前结点标记为叶结点。

所有决策树的构建方法,都是在无法进一步降低熵的情况下才会停止创建分支的过程,为了避免过拟合,可以设定一个阈值,熵减小的数量小于这个阈值,即使还可以继续降低熵,也停止继续创建分支。但是这种方法实际中的效果并不好。

决策树模型的剪枝操作主要会用到DecisionTreeClassifier()函数中的

  • max_depth:指定了决策树的最大深度

  • max_leaf_nodes:指定了模型的叶子节点的最大数目

  • min_sample_split:指定了模型的节点允许分割的最小样本数

  • min_samples_leaf:指定了模型的一个叶节点上所需的最小样本数

这里使用参数网格搜索的方式,对该模型中的四个参数进行搜索,并通过该在验证集上的预测精度为准测,获取较合适的模型参数组合。

params = {'max_depth': np.arange(2,12,2),'max_leaf_nodes': np.arange(10,30,2),'min_samples_split': [2,3,4],'min_samples_leaf': [1,2]}clf = DecisionTreeClassifier(random_state=1)
gcv = GridSearchCV(estimator=clf,param_grid=params)
gcv.fit(X_train,y_train)model = gcv.best_estimator_
model.fit(X_train,y_train)## 可视化决策树经过剪剪枝后的树结构
dot_data = StringIO()
export_graphviz(model, out_file=dot_data,feature_names=X_train.columns,filled=True, rounded=True,special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) 
Image(graph.create_png())
bafe6b1828ea80baa7af93b1c80db086.png
预剪枝后决策树

从剪枝后决策树模型中可以发现:该模型和未剪枝的模型相比已经大大的简化了。模型在训练集上有95个错误样本,但在测试集上只存在47个错误样本。

220fdc8ee32e4cf97881e81d2819ced5.png
训练数据集混淆矩阵
1997a8ee184b604734670140ea88c326.png
测试数据集混淆矩阵

后剪枝(Post-Pruning)

决策树构造完成后进行剪枝。剪枝的过程是对拥有同样父节点的一组节点进行检查,判断如果将其合并,熵的增加量是否小于某一阈值。如果确实小,则这一组节点可以合并一个节点,其中包含了所有可能的结果。后剪枝是目前最普遍的做法。

后剪枝的剪枝过程是删除一些子树,然后用其叶子节点代替,这个叶子节点所标识的类别通过大多数原则(majority class criterion)确定。所谓大多数原则,是指剪枝过程中, 将一些子树删除而用叶节点代替,这个叶节点所标识的类别用这棵子树中大多数训练样本所属的类别来标识,所标识的类 称为majority class ,(majority class 在很多英文文献中也多次出现)。

后剪枝算法有很多种,这里简要总结如下:

  • Reduced-Error Pruning(REP)

  • Pesimistic-Error Pruning(PEP)

  • Cost-Complexity Pruning(CCP)

Reduced-Error Pruning (REP,错误率降低剪枝)

这个思路很直接,完全的决策树不是过度拟合么,我再搞一个测试数据集来纠正它。对于完全决策树中的每一个非叶子节点的子树,我们尝试着把它替换成一个叶子节点,该叶子节点的类别我们用子树所覆盖训练样本中存在最多的那个类来代替,这样就产生了一个简化决策树,然后比较这两个决策树在测试数据集中的表现,如果简化决策树在测试数据集中的错误比较少,那么该子树就可以替换成叶子节点。该算法以bottom-up的方式遍历所有的子树,直至没有任何子树可以替换使得测试数据集的表现得以改进时,算法就可以终止。

Pessimistic Error Pruning (PEP,悲观剪枝)

PEP剪枝算法是在C4.5决策树算法中提出的, 把一颗子树(具有多个叶子节点)用一个叶子节点来替代(我研究了很多文章貌似就是用子树的根来代替)的话,比起REP剪枝法,它不需要一个单独的测试数据集。

PEP算法首先确定这个叶子的经验错误率(empirical)为(E+0.5)/N,0.5为一个调整系数。对于一颗拥有L个叶子的子树,则子树的错误数和实例数都是就应该是叶子的错误数和实例数求和的结果,则子树的错误率为e

然后用一个叶子节点替代子树,该新叶子节点的类别为原来子树节点的最优叶子节点所决定,J为这个替代的叶子节点的错判个数,但是也要加上0.5,即KJ+0.5。最终是否应该替换的标准为

被替换子树的错误数-标准差 > 新叶子错误数

出现标准差,是因为子树的错误个数是一个随机变量,经过验证可以近似看成是二项分布,就可以根据二项分布的标准差公式算出标准差,就可以确定是否应该剪掉这个树枝了。子树中有N的实例,就是进行N次试验,每次实验的错误的概率为e,符合 B(N,e) 的二项分布,根据公式,均值为Ne,方差为Ne(1-e),标准差为方差开平方。

Cost-Complexity Pruning(CCP,代价复杂度剪枝)

在决策树中,这种剪枝技术是由代价复杂性参数ccp_alpha来参数化的。ccp_alpha值越大,剪枝的节点数就越多。简单地说,代价复杂性是一个阈值。只有当模型的整体不纯度改善了一个大于该阈值的值时,该模型才会将一个节点进一步拆分为其子节点,否则将停止。

当CCP值较低时,即使不纯度减少不多,该模型也会将一个节点分割成子节点。随着树的深度增加,这一点很明显,也就是说,当我们沿着决策树往下走时,我们会发现分割对模型整体不纯度的变化没有太大贡献。然而,更高的分割保证了类的正确分类,即准确度更高。

当CCP值较低时,会创建更多的节点。节点越高,树的深度也越高。

下面的代码(Scikit Learn)说明了如何对alpha进行调整,以获得更高精度分数的模型。

path = model_gini.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impuritiesfig, ax = plt.subplots(figsize=(16,8))
ax.plot(ccp_alphas[:-1], impurities[:-1], marker='o', drawstyle="steps-post")
a9ddba1ebb10e9e101f3f2d4dcac8653.png e334d0a87e6f88393f7301a486362b21.png

从结果可知如果alpha设置为0.04得到的测试集精度最好,我们将从新训练模型。

clf_ccp = DecisionTreeClassifier(random_state=1,ccp_alpha=0.04)
clf_ccp.fit(X_train,y_train)
31e6912b2f5893c6e2734d186bf956b2.png
后剪枝后决策树

可以看到,模型深度非常浅,也能达到很好的效果。模型在训练集上有140个错误样本,但在测试集上只存在54个错误样本。

5644728c9903140b95225fbc0d3317fc.png
训练数据集混淆矩阵
268f6ca3f8ebf7b4ef8576f885243e6e.png
测试数据集混淆矩阵

4eb8aea9636da648bc73990a5c8383be.gif

🏴‍☠️宝藏级🏴‍☠️ 原创公众号『数据STUDIO』内容超级硬核。公众号以Python为核心语言,垂直于数据科学领域,包括可戳👉 PythonMySQL数据分析数据可视化机器学习与数据挖掘爬虫 等,从入门到进阶!

长按👇关注- 数据STUDIO -设为星标,干货速递85ffafe9aa829aabf59782c9e77e3e8e.giff915275bbfefd93e75114871f91fd99f.gif


http://chatgpt.dhexx.cn/article/W1YrlIZM.shtml

相关文章

决策树的预剪枝与后剪枝

前言: 本次讲解参考的仍是周志华的《机器学习》,采用的是书中的样例,按照我个人的理解对其进行了详细解释,希望大家能看得懂。 1、数据集 其中{1,2,3,6,7,10,14,15,16,17}为测试集,{4,5,8,9,11,12,13}为训练集。 2、…

YOLOv5剪枝✂️ | 模型剪枝理论篇

文章目录 1. 前言2. 摘要精读3. 背景4. 本文提出的解决方式5. 通道层次稀疏性的优势6. 挑战7. 缩放因素和稀疏性惩罚8. 利用BN图层中的缩放因子9. 通道剪枝和微调10. 多通道方案11. 处理跨层连接和预激活结构12. 实验结果12.1 CIFAR-10数据集剪枝效果12.2 CIFAR-100数据集剪枝效…

决策树及决策树生成与剪枝

文章目录 1. 决策树学习2. 最优划分属性的选择2.1 信息增益 - ID32.1.1 什么是信息增益2.1.2 ID3 树中最优划分属性计算举例 2.2 信息增益率 - C4.52.3 基尼指数 - CART 3. 决策树剪枝3.1 决策树的损失函数3.2 如何进行决策树剪枝3.2.1 预剪枝3.2.2 后剪枝3.3.3 两种剪枝策略对…

剪枝

将复杂的决策树进行简化的过程称为剪枝,它的目的是去掉一些节点,包括叶节点和中间节点。 剪枝常用方法:预剪枝与后剪枝两种。 预剪枝:在构建决策树的过程中,提前终止决策树生长,从而避免过多的节点产生。该…

(剪枝)剪枝的理论

剪枝参考视频 本文将介绍深度学习模型压缩方法中的剪枝,内容从剪枝简介、剪枝步骤、结构化剪枝与非结构化剪枝、静态剪枝与动态剪枝、硬剪枝与软剪枝等五个部分展开。 剪枝简介 在介绍剪枝之前,首先来过参数化这个概念,过参数化主要是指在训…

剪枝总结

一、引子 剪枝,就是减小搜索树规模、尽早排除搜索树中不必要的分支的一种手段。 形象地看,就好像剪掉了搜索树的枝条,故被称为剪枝。 二、常见剪枝方法 1.优化搜索顺序 在一些问题中,搜索树的各个分支之间的顺序是不固定的 …

搜索剪枝

目录 什么是剪枝 几种常见的剪枝 1.可行性剪枝 2.排除等效冗余 3.最优性剪枝 4.顺序剪枝 5.记忆化 运用实例 1.选数 2.吃奶酪 3.小木棍 什么是剪枝 剪枝:通过某种判断,避免一些不必要的遍历过程。搜索的时间复杂度通常很大,通过剪…

【模型压缩】(二)—— 剪枝

一、概述 剪枝(Pruning)的一些概念: 当提及神经网络的"参数"时,大多数情况指的是网络的学习型参数,也就是权重矩阵weights和偏置bias;现代网络的参数量大概在百万至数十亿之间,因此…

环形队列的基本运算算法-数据结构教程

环形队列的基本概念 如图,其实它就是一个队列,就是有点难理解而已,它避免了普通队列的缺点,一样有队列头,队列尾,一样是先进先出的原则。我们采用顺时针的方式来对队列进行排序。 队列头(front) :允许进行删…

一道亚马逊算法面试题的情景分析

阅读博客的朋友可以观看视频: http://study.163.com/course/courseMain.htm?courseId1002942008 我们聚焦于一道亚马逊的算法面试题,通过分析该题,复盘它的解题情景,我们可以初步体会到算法面试的应对步骤,并从中窥…

LeetCode刷题笔记 标准模板库巧解算法题 优先队列

优先队列简介 ​ 优先队列(priority queue)可以在 O(1) 时间内获得最大值,并且可以在 O(log n) 时间内取出最大值或插入任意值。 ​ 优先队列常常用堆(heap)来实现。堆是一个完全二叉树,其每个节点的值总…

Python数据结构与算法(3.4)——队列相关应用与习题

Python数据结构与算法(3.4)——队列相关应用与习题 0. 学习目标1. 使用两个栈实现一个队列2. 使用两个队列实现一个栈3. 栈中元素连续性判断4. 重新排列队列中元素顺序5. 反转队列中前 m 个元素的顺序相关链接0. 学习目标 我们已经学习了队列的相关概念以及其实现,同时也了…

第十七章 优先队列优化Dijkstra算法

第十七章 优先队列优化Dijkstra算法 一、普通dijkstra算法的缺陷1、选出最小距离的过程:2、松弛所有点的过程: 二、如何优化1、代码模板(1)问题:(2)模板: 2、详细解读 三、优化分析1…

【自顶向下模块化编程】C语言实现多级反馈队列调度算法

自顶向下-多级反馈队列 多级反馈队列算法算法原理算法描述题目摘要 自顶向下模块化设计整体框架具体实现GeneratorSchedulerExecutor 整体代码实现 总结及心得总结心得 多级反馈队列算法 多级反馈队列调度算法是一种CPU处理机调度算法,UNIX操作系统采取的便是这种调…

[算法] 栈和队列

欢迎来到老胡的算法解题思路,本文章主要使用的语言为java,使用的题型为力扣算法题,基于这一篇文章,我将为你介绍栈和队列的基础知识和栈和队列的题型,喜欢的朋友可以关注一下,下次更新不迷路! 目…

十道经典面试算法真题详解

前言 分享一下 腾讯常考的十道算法题(真题)。在金三银四,希望对大家有帮助呀。 重排链表 最长递增子序列 环形链表 反转链表 最长回文子串 全排列 LRU 缓存 合并K个升序链表 无重复字符的最长子串 删除链表的倒数第 N 个结点 1. …

队列相关习题

1.已知循环队列存储在一维数组A0…n-1]中,且队列非空时front和rear分别指向队头元素和队尾元素。若初始时队列为空,且要求第一个进入队列的元素存储在A[0]处,则初始时front和rear的值分别是( )。 A.0,0 B. 0,n-1 C. n-…

java算法面试题_Java算法面试题汇总

原标题:Java算法面试题汇总 1. 字符串 如果IDE没有代码自动补全功能,所以你应该记住下面的这些方法。 toCharArray() // 获得字符串对应的char数组 Arrays.sort() // 数组排序 Arrays.toString(char[] a) // 数组转成字符串 charAt(int x) // 获得某个索…

详解单调队列算法

前言 嘿!彩蛋!感觉有帮助就三连呗! 如果你对这篇文章可感兴趣,可以点击「【访客必读 - 指引页】一文囊括主页内所有高质量博客」,查看完整博客分类与对应链接。 在上一篇文章中,我们介绍了「单调栈」这一最常考察的线性数据结构。而今天我们将继续沿着这个思路,介绍另…

栈和队列相关经典算法题总结(数据结构+C语言)

我们这里针对栈和队列的一些经典算法题做详细讲解: 1.括号匹配问题. 2.用队列实现栈. 3.用栈实现队列. 4.设计循环队列. 一.详细讲解如下: 1.括号匹配问题.(如下图) 给定一个只包括 (,),{,},[,] 的字符串 s &am…