【树模型与集成学习】(task2)代码实现CART树(更新ing)

article/2025/9/24 15:58:33

学习心得

task2学习GYH大佬的回归CART树,并在此基础上改为分类CART树。
更新ing。。

这里做一些对决策树分裂依据更深入的思考引导:我们在task1证明离散变量信息增益非负时曾提到,信息增益本质上就是联合分布和边缘分布乘积的kl散度,而事实上kl散度属于f-divergence(https://en.wikipedia.org/wiki/F-divergence)中的一类特殊情况,由于在分裂时我们衡量的是这两个分布的差异到底有多大,因此f-divergence中的任意一种距离度量都可以用来作为分裂依据,那么在树结构上进行分裂,这些散度究竟对树的生长结果产生了怎样的影响,似乎还没有看到文章讨论过这些(可以试图充分地讨论它们之间的一些理论性质和联系)

(1)可能会发现在与sklearn对比时,有时会产生两者结果预测部分不一致的情况,这种现象主要来自于当前节点在分裂的时候不同的特征和分割点组合产生了相同的信息增益,但由于遍历特征的顺序(和sklearn内的遍历顺序)不一样,因此在预测时会产生差异,并不是算法实现上有问题。
(2)对比的时候作差后要取绝对值,(np.abs(res1-res2)<1e-8).mean()。

文章目录

  • 学习心得
  • 一、回顾决策树算法
  • 二、代码实践
  • Reference

一、回顾决策树算法

在这里插入图片描述

在这里插入图片描述

二、代码实践

from CART import DecisionTreeRegressor
from CARTclassifier import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor as dt
from sklearn.tree import DecisionTreeClassifier as dc
from sklearn.datasets import make_regression
from sklearn.datasets import make_classificationif __name__ == "__main__":# 模拟回归数据集X, y = make_regression(n_samples=200, n_features=10, n_informative=5, random_state=0)# 回归树my_cart_regression = DecisionTreeRegressor(max_depth=2)my_cart_regression.fit(X, y)res1 = my_cart_regression.predict(X)importance1 = my_cart_regression.feature_importances_sklearn_cart_r = dt(max_depth=2)sklearn_cart_r.fit(X, y)res2 = sklearn_cart_r.predict(X)importance2 = sklearn_cart_r.feature_importances_# 预测一致的比例print(((res1-res2)<1e-8).mean())# 特征重要性一致的比例print(((importance1-importance2)<1e-8).mean())# 模拟分类数据集X, y = make_classification(n_samples=200, n_features=10, n_informative=5, random_state=0)# 分类树my_cart_classification = DecisionTreeClassifier(max_depth=2)my_cart_classification.fit(X, y)res3 = my_cart_classification.predict(X)importance3 = my_cart_classification.feature_importances_sklearn_cart_c = dc(max_depth=2)sklearn_cart_c.fit(X, y)res4 = sklearn_cart_c.predict(X)importance4 = sklearn_cart_c.feature_importances_# 预测一致的比例print(((res3-res4)<1e-8).mean())# 特征重要性一致的比例print(((importance3-importance4)<1e-8).mean())
# -*- coding: utf-8 -*-
"""
Created on Sun Oct 17 10:46:08 2021@author: 86493
"""
import numpy as np
from collections import Counterdef MSE(y):return ((y - y.mean())**2).sum() / y.shape[0]# 基尼指数
def Gini(y):c = Counter(y)return 1 - sum([(val / y.shape[0]) ** 2 for val in c.values()])class Node:def __init__(self, depth, idx):self.depth = depthself.idx = idxself.left = Noneself.right = Noneself.feature = Noneself.pivot = Noneclass Tree:def __init__(self, max_depth):self.max_depth = max_depthself.X = Noneself.y = Noneself.feature_importances_ = Nonedef _able_to_split(self, node):return (node.depth < self.max_depth) & (node.idx.sum() >= 2)def _get_inner_split_score(self, to_left, to_right):total_num = to_left.sum() + to_right.sum()left_val = to_left.sum() / total_num * Gini(self.y[to_left])right_val = to_right.sum() / total_num * Gini(self.y[to_right])return left_val + right_valdef _inner_split(self, col, idx):data = self.X[:, col]best_val = np.inftyfor pivot in data[:-1]:to_left = (idx==1) & (data<=pivot)to_right = (idx==1) & (~to_left)if to_left.sum() == 0 or to_left.sum() == idx.sum():continueHyx = self._get_inner_split_score(to_left, to_right)if best_val > Hyx:best_val, best_pivot = Hyx, pivotbest_to_left, best_to_right = to_left, to_rightreturn best_val, best_to_left, best_to_right, best_pivotdef _get_conditional_entropy(self, idx):best_val = np.inftyfor col in range(self.X.shape[1]):Hyx, _idx_left, _idx_right, pivot = self._inner_split(col, idx)if best_val > Hyx:best_val, idx_left, idx_right = Hyx, _idx_left, _idx_rightbest_feature, best_pivot = col, pivotreturn best_val, idx_left, idx_right, best_feature, best_pivotdef split(self, node):# 首先判断本节点是不是符合分裂的条件if not self._able_to_split(node):return None, None, None, None# 计算H(Y)entropy = Gini(self.y[node.idx==1])# 计算最小的H(Y|X)(conditional_entropy,idx_left,idx_right,feature,pivot) = self._get_conditional_entropy(node.idx)# 计算信息增益G(Y, X)info_gain = entropy - conditional_entropy# 计算相对信息增益relative_gain = node.idx.sum() / self.X.shape[0] * info_gain# 更新特征重要性self.feature_importances_[feature] += relative_gain# 新建左右节点并更新深度node.left = Node(node.depth+1, idx_left)node.right = Node(node.depth+1, idx_right)self.depth = max(node.depth+1, self.depth)return idx_left, idx_right, feature, pivotdef build_prepare(self):self.depth = 0self.feature_importances_ = np.zeros(self.X.shape[1])self.root = Node(depth=0, idx=np.ones(self.X.shape[0]) == 1)def build_node(self, cur_node):if cur_node is None:returnidx_left, idx_right, feature, pivot = self.split(cur_node)cur_node.feature, cur_node.pivot = feature, pivotself.build_node(cur_node.left)self.build_node(cur_node.right)def build(self):self.build_prepare()self.build_node(self.root)def _search_prediction(self, node, x):if node.left is None and node.right is None:# return self.y[node.idx].mean()return self.y[node.idx].min()if x[node.feature] <= node.pivot:node = node.leftelse:node = node.rightreturn self._search_prediction(node, x)def predict(self, x):return self._search_prediction(self.root, x)class DecisionTreeClassifier:"""max_depth控制最大深度,类功能与sklearn默认参数下的功能实现一致"""def __init__(self, max_depth):self.tree = Tree(max_depth=max_depth)def fit(self, X, y):self.tree.X = Xself.tree.y = yself.tree.build()self.feature_importances_ = (self.tree.feature_importances_ / self.tree.feature_importances_.sum())return selfdef predict(self, X):return np.array([self.tree.predict(x) for x in X])

输出结果如下,可见在误差范围内,实现的分类树和回归树均和sklearn实现的模块近似。

1.0
1.0
1.0
1.0

Reference

(0)datawhale notebook
(1)CART决策树(Decision Tree)的Python源码实现
(2)https://github.com/RRdmlearning/Decision-Tree
(3)《机器学习技法》—决策树


http://chatgpt.dhexx.cn/article/3jR6Wyxe.shtml

相关文章

CART 决策树

ID3使用信息增益&#xff0c;而C4.5使用增益比率进行拆分。 在此&#xff0c;CART是另一种决策树构建算法。 它可以处理分类和回归任务。 该算法使用名为gini索引的新度量标准来创建分类任务的决策点。 CART树的核心是决策规则将通过GINI索引值决定。 停止条件。 如果我们继续…

CART决策树算法

在进行自动识别窃漏电用户分析实战时&#xff0c;用到了CART决策树算法&#xff0c;所以整理记录该算法的内容。内容整理参考文档决策树——CART算法及其后的参考文章。 一、CART&#xff08;classification and regression tree&#xff09;分类与回归树&#xff0c;既可用于…

CART树算法解析加举例

算法步骤 CART假设决策树是二叉树&#xff0c;内部结点特征的取值为“是”和“否”&#xff0c;左分支是取值为“是”的分支&#xff0c;右分支是取值为“否”的分支。这样的决策树等价于递归地二分每个特征&#xff0c;将输入空间即特征空间划分为有限个单元&#xff0c;并在…

ID3、C4.5与CART树的联系与区别

ID3、C4.5与CART树的联系与区别&#xff1a; 参考博客&#xff1a; 链接1 链接2 特征选择准则&#xff1a; ID3的特征选择准则为信息增益&#xff0c;即集合D的经验熵H(D)与给定特征A下条件经验熵H(D|A)之差&#xff0c;即&#xff1a; H(D)表现了数据集D进行分类的不确定性…

决策树构建算法—ID3、C4.5、CART树

决策树构建算法—ID3、C4.5、CART树 决策树构建算法—ID3、C4.5、CART树 构建决策树的主要算法ID3C4.5CART三种算法总结对比 决策树构建算法—ID3、C4.5、CART树 构建决策树的主要算法 ID3C4.5CART &#xff08;Classification And Rsgression Tree&#xff09; ID3 ID3算法…

3-6 决策树、CART树、GBDT、xgboost、lightgbm一些关键点梳理

目录 1、决策树 2、CART树 2.1 CART分类树-输入样本特征&#xff1b;输出样本对应的类别(离散型) 2.2 CART回归树-输入样本特征&#xff1b;输出样本的回归值(连续型) 3、GBDT 3.1 提升树 3.2 GBDT 4、xgboost 4.1 损失函数及节点展开 4.2 精确贪心算法及相关近似算法…

CART树回归

说明&#xff1a;本博客是学习《python机器学习算法》赵志勇著的学习笔记&#xff0c;其图片截取也来源本书。 基于树的回归算法是一类基于局部的回归算法&#xff0c;通过将数据集切分成多份&#xff0c;在每一份数据中单独建模。与局部加权线性回归不同的是&#xff0c;基于…

剪枝、cart树

一、剪枝 1. 为什么要剪枝 在决策树生成的时候&#xff0c;更多考虑的是训练数据&#xff0c;而不是未知数据&#xff0c;这会导致过拟合&#xff0c;使树过于复杂&#xff0c;对于未知的样本不准确。 2. 剪枝的依据——通过极小化决策树的损失函数 损失函数的定义为&#x…

【机器学习】决策树——CART分类回归树(理论+图解+公式)

&#x1f320; 『精品学习专栏导航帖』 &#x1f433;最适合入门的100个深度学习实战项目&#x1f433;&#x1f419;【PyTorch深度学习项目实战100例目录】项目详解 数据集 完整源码&#x1f419;&#x1f436;【机器学习入门项目10例目录】项目详解 数据集 完整源码&…

CART树(分类回归树)

主要内容 &#xff08;1&#xff09;CART树简介 &#xff08;2&#xff09;CART树节点分裂规则 &#xff08;3&#xff09;剪枝 --------------------------------------------------------------------------------------------------------------------- 一、简介 CART…

CART树

算法概述 CART(Classification And Regression Tree)算法是一种决策树分类方法。 它采用一种二分递归分割的技术&#xff0c;分割方法采用基于最小距离的基尼指数估计函数&#xff0c;将当前的样本集分为两个子样本集&#xff0c;使得生成的的每个非叶子节点都有两个分支。因此…

Pytorch之view,reshape,resize函数

对于深度学习中的一下数据&#xff0c;我们通常是要变成tensor格式&#xff0c;并且需要对其调整形状&#xff0c;很多时候我们往往只关注view之后的结果&#xff08;比如输出的尺寸&#xff09;&#xff0c;而不关心过程。但有时候还是要关注一下这个到底是怎么变换过来的&…

OpenCV-Python图像处理:插值方法及使用resize函数进行图像缩放

☞ ░ 前往老猿Python博客 https://blog.csdn.net/LaoYuanPython ░ 图像缩放用于对图像进行缩小或扩大&#xff0c;当图像缩小时需要对输入图像重采样去掉部分像素&#xff0c;当图像扩大时需要在输入图像中根据算法生成部分像素&#xff0c;二者都会利用插值算法来实现。 一…

vector的resize函数和reserve函数

博客原文&#xff1a;C基础篇 -- vector的resize函数和reserve函数_VampirEM_Chosen_One的博客-CSDN博客&#xff0c;写的特别好&#xff0c;谢谢原博主。 正文&#xff1a; 对于C的vector容器模板类&#xff0c;存在size和capacity这样两个概念&#xff0c;可以分别通过vect…

OpenCV 图片尺寸缩放——resize函数

文章目录 OpenCV中的缩放&#xff1a;resize函数代码案例 OpenCV中的缩放&#xff1a; 如果要放大或缩小图片的尺寸&#xff0c;可以使用OpenCV提供的两种方法&#xff1a; resize函数&#xff0c;是最直接的方式&#xff1b;pyrUp&#xff0c;pyrDown函数&#xff0c;即图像…

OpenCV的resize函数优化

背景 在使用OpenCV做图像处理的时候&#xff0c;最常见的问题是c版本性能不足&#xff0c;以resize函数为例来说明&#xff0c;将size为[864,1323,3]的函数缩小一半&#xff1a; Mat img0;gettimeofday(&t4, NULL);cv::resize(source, img0, cv::Size(cols_out,rows_out)…

C++ | resize函数的用法

最近在leetcode用C刷题&#xff0c;刚遇到一题需要给重新弄一个容器&#xff0c;并给容器初始化。新建容器要在private类中声明resize函数用来初始化preSum容器。 resize函数是C中序列式容器的一个共性函数&#xff0c;vv.resize(int n,element)表示调整容器vv的大小为n&#x…

opencv的resize函数

一、Opencv官方文档中resize的描述&#xff1a; resize Resizes an image. C: void resize(InputArray src, OutputArray dst, Size dsize, double fx0, double fy0, int interpolationINTER_LINEAR ) Python: cv2.resize(src, dsize[, dst[, fx[, fy[, interpolation]]]]) …

resize()函数

resize()&#xff0c;设置大小&#xff08;size&#xff09;; reserve()&#xff0c;设置容量&#xff08;capacity&#xff09;; size()是分配容器的内存大小&#xff0c;而capacity()只是设置容器容量大小&#xff0c;但并没有真正分配内存。 打个比方&#xff1a;正在建造…

OpenCV 图像缩放:cv.resize() 函数详解

目录 系列前言API函数详解参数列表缩放方式其一缩放方式其二两种方式的优先级关于插值方式 扩展 —— 相关函数 系列前言 这个系列是我第一个想要更下去的系列。每篇会全面介绍一个 OpenCV 函数&#xff0c;会给出 API 和示例。示例主要是用 Python 去写&#xff0c;但是 Open…