CART树回归

article/2025/9/24 16:23:38

说明:本博客是学习《python机器学习算法》赵志勇著的学习笔记,其图片截取也来源本书。

基于树的回归算法是一类基于局部的回归算法,通过将数据集切分成多份,在每一份数据中单独建模。与局部加权线性回归不同的是,基于树回归的算法是一种基于参数学习的算法,利用训练数据训练完模型后,参数一定确定,无需再改变。

分类回归树(Classification And Regression Tree,CART)算法是使用比较多的一种树模型,CART算法既可以解决分类问题也可以解决回归问题。前面的博客随机森林中有介绍CART算法处理分类的问题,在这次的博客中将涉及到CART的回归问题。CART树回归属于一种局部的回归算法,通过将全局的数据集划分成多份容易建模的数据集,这样在每一个局部的数据集上进行局部的回归建模。

CART算法中的树采用一种二分递归分割技术,即将当前的样本集分为左子树和右子树两个样本集,使得生成的每个非子叶节点都有两个分支。因此,CART算法生成的决策树是非典型的二叉树。

利用CART算法处理回归问题的主要步骤:1.CART会归树的生成;2.CART回归树的剪枝。

1、CART回归树的生成

1.1、CART回归树的划分

在CART分类算法中,利用Gini指数作为树的指标,通过样本中的特征,对样本进行划分,直到所有的叶节点的所有样本都为同样类别为止。但在CART回归树中,样本的标是一系列的连续值的集合,不能再使用Gini指数作为划分树的指标。同时,我们也注意到Gini指数是衡量数据的混乱程度的,对于连续的数据,当数据分布比较分散时,各个数据与平均数的差的平方和较大,方差就越大;反之,当数据分布比较集中时,各个数据与平均数的差的平方和较小,方差就越小,数据的波动就越小。因此,对于连续的数据,可以使用样本与平均值的平方和作为划分回归树的指标。
这里写图片描述

有了划分的标准,那如何划分数据呢?在CART中我们根据每一维特征中的每一个值,尝试将样本划分到树节点的左右子树中,如取样本特征中第j维特征中值x作为划分的值,如果一个样本在第j维初的值大于或者等于x,则将其划分到右子树中,具体划分过程如下图所示。

这里写图片描述
一般小于特征值在左子树,大于等于在右子树。

1.1、CART回归树的构建

CART分类树的构建过程如下所示:

1、对于当前序训练数据集,遍历所有属性及其所有可能的切分点,寻找最佳切分属性及其最佳切分点,使得切分之后的基尼指数最小,利用该最佳属性及其最佳切分点将数据集划分为两个子数据集,分别对应的结果就是左右两子树。

2、对第一步中生成的两个数据集递归调用第一步,直至满足条件为止。
3、生成CART决策树。

2、CART回归树剪枝

对CART进行剪枝的目的是防止CART出现过拟合。在剪枝中主要分为:前剪枝和后剪枝。

2.1、前剪枝

前剪枝是指在生成CART回归树可以通过设置最小的过程中对树的深度进行控制,防止生成过多的叶子节点。比如每次子树的最小样本数量和最小误差率。来控制是否进行更多的划分。

2.2、后剪枝

后剪枝是指将训练样本分为两个部分,一部分用来训练CART树模型,这部分数据称为训练数据,另一部分用来对生成的CART树模型进行剪枝,这一部分称为验证数据。

由上述过程可知,在后剪枝的过程中,通过验证生成好的CART树模型是否在验证数据集上发生了过拟合,如果出现了过拟合的现象,则合并一些叶子节点来达到对CART树模型的剪枝。

import numpy as np
# import cPickle as pickleclass node:'''树的节点的类'''def __init__(self, fea=-1, value=None, results=None, right=None, left=None):self.fea = fea  # 用于切分数据集的属性的列索引值self.value = value  # 设置划分的值self.results = results  # 存储叶节点的值self.right = right  # 右子树self.left = left  # 左子树def load_data(data_file):'''导入训练数据input:  data_file(string):保存训练数据的文件output: data(list):训练数据'''data = []f = open(data_file)for line in f.readlines():sample = []lines = line.strip().split("\t")for x in lines:sample.append(float(x))  # 转换成float格式data.append(sample)f.close()return datadef split_tree(data, fea, value):'''根据特征fea中的值value将数据集data划分成左右子树input:  data(list):训练样本fea(float):需要划分的特征indexvalue(float):指定的划分的值output: (set_1, set_2)(tuple):左右子树的聚合'''set_1 = []  # 右子树的集合set_2 = []  # 左子树的集合for x in data:if x[fea] >= value:set_1.append(x)else:set_2.append(x)return (set_1, set_2)def leaf(dataSet):'''计算叶节点的值input:  dataSet(list):训练样本output: np.mean(data[:, -1])(float):均值'''data = np.mat(dataSet)return np.mean(data[:, -1])def err_cnt(dataSet):'''回归树的划分指标input:  dataSet(list):训练数据output: m*s^2(float):总方差'''data = np.mat(dataSet)return np.var(data[:, -1]) * np.shape(data)[0]def build_tree(data, min_sample, min_err):'''构建树input:  data(list):训练样本min_sample(int):叶子节点中最少的样本数min_err(float):最小的erroroutput: node:树的根结点'''# 构建决策树,函数返回该决策树的根节点if len(data) <= min_sample:return node(results=leaf(data))# 1、初始化best_err = err_cnt(data)bestCriteria = None  # 存储最佳切分属性以及最佳切分点bestSets = None  # 存储切分后的两个数据集# 2、开始构建CART回归树feature_num = len(data[0]) - 1for fea in range(0, feature_num):feature_values = {}for sample in data:feature_values[sample[fea]] = 1for value in feature_values.keys():# 2.1、尝试划分(set_1, set_2) = split_tree(data, fea, value)if len(set_1) < 2 or len(set_2) < 2:continue# 2.2、计算划分后的error值now_err = err_cnt(set_1) + err_cnt(set_2)# 2.3、更新最优划分if now_err < best_err and len(set_1) > 0 and len(set_2) > 0:best_err = now_errbestCriteria = (fea, value)bestSets = (set_1, set_2)# 3、判断划分是否结束if best_err > min_err:right = build_tree(bestSets[0], min_sample, min_err)left = build_tree(bestSets[1], min_sample, min_err)return node(fea=bestCriteria[0], value=bestCriteria[1], \right=right, left=left)else:return node(results=leaf(data))  # 返回当前的类别标签作为最终的类别标签def predict(sample, tree):'''对每一个样本sample进行预测input:  sample(list):样本tree:训练好的CART回归树模型output: results(float):预测值'''# 1、只是树根if tree.results != None:return tree.resultselse:# 2、有左右子树val_sample = sample[tree.fea]  # fea处的值branch = None# 2.1、选择右子树if val_sample >= tree.value:branch = tree.right# 2.2、选择左子树else:branch = tree.leftreturn predict(sample, branch)def cal_error(data, tree):''' 评估CART回归树模型input:  data(list):tree:训练好的CART回归树模型output: err/m(float):均方误差'''m = len(data)  # 样本的个数   n = len(data[0]) - 1  # 样本中特征的个数err = 0.0for i in range(m):tmp = []for j in range(n):tmp.append(data[i][j])pre = predict(tmp, tree)  # 对样本计算其预测值# 计算残差err += (data[i][-1] - pre) * (data[i][-1] - pre)return err / m# def save_model(regression_tree, result_file):
#     '''将训练好的CART回归树模型保存到本地
#     input:  regression_tree:回归树模型
#             result_file(string):文件名
#     '''
#     with open(result_file, 'w') as f:
#         pickle.dump(regression_tree, f)if __name__ == "__main__":# 1、导入训练数据print("----------- 1、load data -------------")data = load_data("C:\\Python-Machine-Learning-Algorithm-master\\Chapter_9 CART\\sine.txt")# 2、构建CART树print("----------- 2、build CART ------------")regression_tree = build_tree(data, 30, 0.3)# 3、评估CART树print("----------- 3、cal err -------------") err = cal_error(data, regression_tree)print("\t--------- err : ", err)
#     # 4、保存最终的CART模型
#     print "----------- 4、save result -----------"  
#     save_model(regression_tree, "regression_tree")
----------- 1、load data -------------
----------- 2、build CART ------------
----------- 3、cal err ---------------------- err :  0.017472194888

http://chatgpt.dhexx.cn/article/Nw8Xjjri.shtml

相关文章

剪枝、cart树

一、剪枝 1. 为什么要剪枝 在决策树生成的时候&#xff0c;更多考虑的是训练数据&#xff0c;而不是未知数据&#xff0c;这会导致过拟合&#xff0c;使树过于复杂&#xff0c;对于未知的样本不准确。 2. 剪枝的依据——通过极小化决策树的损失函数 损失函数的定义为&#x…

【机器学习】决策树——CART分类回归树(理论+图解+公式)

&#x1f320; 『精品学习专栏导航帖』 &#x1f433;最适合入门的100个深度学习实战项目&#x1f433;&#x1f419;【PyTorch深度学习项目实战100例目录】项目详解 数据集 完整源码&#x1f419;&#x1f436;【机器学习入门项目10例目录】项目详解 数据集 完整源码&…

CART树(分类回归树)

主要内容 &#xff08;1&#xff09;CART树简介 &#xff08;2&#xff09;CART树节点分裂规则 &#xff08;3&#xff09;剪枝 --------------------------------------------------------------------------------------------------------------------- 一、简介 CART…

CART树

算法概述 CART(Classification And Regression Tree)算法是一种决策树分类方法。 它采用一种二分递归分割的技术&#xff0c;分割方法采用基于最小距离的基尼指数估计函数&#xff0c;将当前的样本集分为两个子样本集&#xff0c;使得生成的的每个非叶子节点都有两个分支。因此…

Pytorch之view,reshape,resize函数

对于深度学习中的一下数据&#xff0c;我们通常是要变成tensor格式&#xff0c;并且需要对其调整形状&#xff0c;很多时候我们往往只关注view之后的结果&#xff08;比如输出的尺寸&#xff09;&#xff0c;而不关心过程。但有时候还是要关注一下这个到底是怎么变换过来的&…

OpenCV-Python图像处理:插值方法及使用resize函数进行图像缩放

☞ ░ 前往老猿Python博客 https://blog.csdn.net/LaoYuanPython ░ 图像缩放用于对图像进行缩小或扩大&#xff0c;当图像缩小时需要对输入图像重采样去掉部分像素&#xff0c;当图像扩大时需要在输入图像中根据算法生成部分像素&#xff0c;二者都会利用插值算法来实现。 一…

vector的resize函数和reserve函数

博客原文&#xff1a;C基础篇 -- vector的resize函数和reserve函数_VampirEM_Chosen_One的博客-CSDN博客&#xff0c;写的特别好&#xff0c;谢谢原博主。 正文&#xff1a; 对于C的vector容器模板类&#xff0c;存在size和capacity这样两个概念&#xff0c;可以分别通过vect…

OpenCV 图片尺寸缩放——resize函数

文章目录 OpenCV中的缩放&#xff1a;resize函数代码案例 OpenCV中的缩放&#xff1a; 如果要放大或缩小图片的尺寸&#xff0c;可以使用OpenCV提供的两种方法&#xff1a; resize函数&#xff0c;是最直接的方式&#xff1b;pyrUp&#xff0c;pyrDown函数&#xff0c;即图像…

OpenCV的resize函数优化

背景 在使用OpenCV做图像处理的时候&#xff0c;最常见的问题是c版本性能不足&#xff0c;以resize函数为例来说明&#xff0c;将size为[864,1323,3]的函数缩小一半&#xff1a; Mat img0;gettimeofday(&t4, NULL);cv::resize(source, img0, cv::Size(cols_out,rows_out)…

C++ | resize函数的用法

最近在leetcode用C刷题&#xff0c;刚遇到一题需要给重新弄一个容器&#xff0c;并给容器初始化。新建容器要在private类中声明resize函数用来初始化preSum容器。 resize函数是C中序列式容器的一个共性函数&#xff0c;vv.resize(int n,element)表示调整容器vv的大小为n&#x…

opencv的resize函数

一、Opencv官方文档中resize的描述&#xff1a; resize Resizes an image. C: void resize(InputArray src, OutputArray dst, Size dsize, double fx0, double fy0, int interpolationINTER_LINEAR ) Python: cv2.resize(src, dsize[, dst[, fx[, fy[, interpolation]]]]) …

resize()函数

resize()&#xff0c;设置大小&#xff08;size&#xff09;; reserve()&#xff0c;设置容量&#xff08;capacity&#xff09;; size()是分配容器的内存大小&#xff0c;而capacity()只是设置容器容量大小&#xff0c;但并没有真正分配内存。 打个比方&#xff1a;正在建造…

OpenCV 图像缩放:cv.resize() 函数详解

目录 系列前言API函数详解参数列表缩放方式其一缩放方式其二两种方式的优先级关于插值方式 扩展 —— 相关函数 系列前言 这个系列是我第一个想要更下去的系列。每篇会全面介绍一个 OpenCV 函数&#xff0c;会给出 API 和示例。示例主要是用 Python 去写&#xff0c;但是 Open…

安卓中的几种线程间通信方式

一&#xff1a;Handler实现线程间的通信 andriod提供了 Handler 和 Looper 来满足线程间的通信。例如一个子线程从网络上下载了一副图片&#xff0c;当它下载完成后会发送消息给主线程&#xff0c;这个消息是通过绑定在主线程的Handler来传递的。 在Android&#xff0c;这里的…

Java中的线程通信的几种方式

Java中的线程间通信是指不同线程之间相互协作&#xff0c;以完成一些复杂的任务或实现某些功能的过程。线程间通信主要包括两个方面&#xff1a;线程之间的互斥和同步&#xff0c;以及线程之间的数据共享和通信。Java提供了多种方式来实现线程间通信&#xff0c;本文将介绍Java…

创建线程的四种方式 线程通信

文章目录 1.1 创建线程1.1.1 创建线程的四种方式1.1.2 Thread类与Runnable接口的比较1.1.3 Callable、Future与FutureTask 1.2 线程组和线程优先级1.3 Java线程的状态及主要转化方法1.4 Java线程间的通信1.4.1 等待/通知机制1.4.2 信号量1.4.3 管道 1.1 创建线程 1.1.1 创建线…

【多线程间几种通信方式】

一、使用 volatile 关键字 基于 volatile 关键字来实现线程间相互通信是使用共享内存的思想。大致意思就是多个线程同时监听一个变量&#xff0c;当这个变量发生变化的时候 &#xff0c;线程能够感知并执行相应的业务。这也是最简单的一种实现方式 代码案例 package com.han…

线程之间的通信方式

前言 我只是个搬运工&#xff0c;尊重原作者的劳动成果&#xff0c;本文来源下列文章链接&#xff1a; https://zhuanlan.zhihu.com/p/129374075 https://blog.csdn.net/jisuanji12306/article/details/86363390 线程之间为什么要通信&#xff1f; 通信的目的是为了更好的协…

Java线程间的通信方式

文章目录 线程间通信的定义一、等待—通知&#xff08;1&#xff09;等待—通知机制的相关方法&#xff1a;&#xff08;2&#xff09;注意事项&#xff1a;&#xff08;4&#xff09;notify()方法的核心原理&#xff08;5&#xff09;等待—通知机制的经典范式&#xff08;6&a…

线程间实现通信的几种方式

目录 线程通信相关概述提出问题方式一&#xff1a;使用Object类的wait() 和 notify() 方法方式二&#xff1a;Lock 接口中的 newContition() 方法返回 Condition 对象&#xff0c;Condition 类也可以实现等待/通知模式方法三&#xff1a;使用 volatile 关键字方法四&#xff1a…