使用深度森林(Deep Forest)进行分类-Python

article/2025/9/25 22:29:56

一、什么是深度森林?

传统DNN的不足:

1、需要大量的数据集来训练;

2DNN的模型太复杂;

3DNN有着太多的超参数

gcForest的优势:

1、更容易训练;

2、性能更佳;

3、效率高且可扩展、支持小规模训练数据。

        深度森林是一个新的基于树的集成学习方法,它通过对树构成的森林进行集成并串联起来达到让分类器做表征学习的目的,从而提高分类的效果。

二、深度森林的结构

其结构主要包括级联森林和多粒度扫描。

1.级联森林

        级联森林的构成:级联森林的每一个Level包含若干个集成学习的分类器(这里是决策树森林),这是一种集成中的集成的结构

        为了体现多样性,这里用了两种代表了若干不同的集成学习器。图中的级联森林每一层包括两个完全随机森林(黑色)和两个随机森林(蓝色)。这两种森林的主要区别在于候选特征空间,完全随机森林是在完整的特征空间中随机选取特征来分裂,而普通随机森林是在一个随机特征子空间内通过基尼系数来选取分裂节点。

       我们都知道决策树其实是在特征空间中不断划分子空间,并且给每个子空间打上标签(分类问题就是一个类别,回归问题就是一个目标值),所以给予一条测试样本,每棵树会根据样本所在的子空间中训练样本的类别占比生成一个类别的概率分布,然后对森林内所有树的各类比例取平均,输出整个森林对各类的比例。也就是每个森林都会生成一个长度为 C 的概率向量,假如 gcforest 的每一层由 N 个森林构成,那么每一层的输出就是 N C 维向量连接在一起,即 C*N 维向量。gcForest采用了DNN中的layer-by-layer结构,从前一层输入的数据和输出结果数据做concat作为下一层的输入。这个向量然后与输入到级联的下一层的原始特征向量相拼接(图 中粗红线部分),作为下一层的输入.这样我们就做了一次特征变化,并保留了原始特征继续后续处理,每一层都这样,最后一层将所有随机森林输出的三维向量加和求平均算出最大的一维作为最终输出。

例如,在图中的三分类问题中,每层由 4 个随机森林构成,而每个森林都将生成一个 3 维向量,因此,每层产生一个 4*3=12 维的特征向量,此特征向量将作为下一层的输入增强原始特征。

 

为了降低过拟合风险,每个森林生成的类向量是通过 k 折交叉验证产生的即,每个样本都会被 当作训练数据训练 k-1 ,产生 k-1 个类3维向量,然后对其取平均值即为这个森林最终特征向量,再将这4个森林的3维特征向量连在一起,作为下一层的增强特征向量.

那么这样一层一层接下去什么时候停止呢。在扩展一个新的层后,整个级联的性能将在验证集上进行评估,如果没有显着的性能提升,训练过程将终止 因此,级联中层的数量是自动确定的。

下图是每个森林的决策过程。

2.多粒度扫描

在日常生活中,由于数据的特征之间可能存在某种关系,例如,在图像识别中,位置相近的像素点之间有很强的空间关系,序列数据有顺序上的关系。gcForest 使用多粒度扫描对级联森林进行增强,,它利用多种大小的滑动 窗口进行采样,以获得更多的特征子样本,从而达到多粒度扫描的效果。

比如图中的例子:

对于序列数据,假设我们的输入特征是400维,扫描窗口大小是100维,这样就得到301100维的特征向量,每个100维的特征向量对应一个3分类的类向量(3维类向量),即得到:301个*3维类向量!最终每棵森林会得到903维的特征变量!特征就更多更丰富对于图像数据的处理和序列数据一样,图像数据的扫描方式当然是从左到右、从上到下,而序列数据只是从上到下。可以用各种尺寸不等的扫描窗口去扫描,这样就会得到更多的、更丰富的特征关系!

3.整体结构

那么深度森林一个整体结构就如图所示:

在图中,假设有 3 个类,并且分别使用 100 200 300 维的窗口在原始 400 维的特征上进行滑动。(3个级联森林 每个级联森林都N层)得到特征向量后再使用级联森林进行训练,得到最后的预测模型和结果。

三、实践-简单分类实例

Paper:https://arxiv.org/abs/1702.08835v2
Github:https://github.com/kingfengji/gcForest
Website:http://lamda.nju.edu.cn/code_gcForest.ashx

南京大学机器学习与数据挖掘研究所提供了基于Python 2.7官方实现版本,在本文中,我们使用基于Python3实现的gcForest实现分类任务。

Github:https://github.com/pylablanche/gcForest

gcForest类与sklearn包装的分类器使用方法类似,使用 a .fit() 进行训练,使用a .predict() 进行预测。其中需要我们进行设置的属性为shape_1X和window。shape_1X由数据集决定(所有样本必须具有相同的形状),而window取决于我们自己的选择。

分类器构建时需要的参数如下所示:

shape_1X: int or tuple list or np.array (default=None)训练量样本的大小,格式为[n_lines, n_cols]. n_mgsRFtree: int (default=30)多粒度扫描时构建随即森林使用的决策树数量.window: int (default=None)多粒度扫描时的数据扫描窗口大小.stride: int (default=1)数据切片时的步长大小.cascade_test_size: float or int (default=0.2)级联训练时的测试集大小.n_cascadeRF: int (default=2)每个级联层的随机森林的大小.n_cascadeRFtree: int (default=101)每个级联层的随即森林中包含的决策树的数量.min_samples_mgs: float or int (default=0.1)多粒度扫描期间,要执行拆分行为时节点中最小样本数.min_samples_cascade: float or int (default=0.1)训练级联层时,要执行拆分行为时节点中最小样本数.cascade_layer: int (default=np.inf)级联层层数的最大值tolerance: float (default=0.0)判断级联层是否增长的准确度公差。如果准确性的提高不如tolerance,那么层数将停止增长。n_jobs: int (default=1)随机森林并行运行的工作数量。如果为-1,则设置为cpu核心数.

在这里,我们使用sklearn带有的Iris数据集进行分类测试,Iris数据集是常用的分类实验数据集。Iris也称鸢尾花卉数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(SetosaVersicolourVirginica)三个种类中的哪一类。

iris.data 原始数据集,150×4,4代表四个属性。

[[5.1 3.5 1.4 0.2] [4.9 3.  1.4 0.2] [4.7 3.2 1.3 0.2] [4.6 3.1 1.5 0.2] [5.  3.6 1.4 0.2]…...等等,不一一列出。

iris.target:目标分类结果数据集  150x1  。012分别代表3个类 ,每个类有50个样本。

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2  2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]

具体代码如下所示:

# 对鸢尾花数据集进行测试
iris = load_iris()
X, y = iris.data, iris.target 

#===iris.data这是150个数据集 每个数据有四个属性

#iris.target150个数据的分类结果 0 1 2 分别表示三个类
print('==========================Data Shape======================')
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.3)

#从样本中随机的按比例选取train_data和test_data这里去0.3 即训练数据和测试数据都是随机选45

model = GCForest.gcForest(shape_1X=4, window=2, tolerance=0.0)
model.fit(X_train, y_train) #fit(X,y) 在输入数据X和相关目标y上训练gcForest;
joblib.dump(model,'irisModel.sav') #持久化存储 保存模型  加载模型
model=joblib.load('irisModel.sav')
y_predict = model.predict_proba(X_test)
#预测未知样本X的类概率;
y_predict = y_predict.tolist()
y_predict1 = model.predict(X_test) #预测未知样本X的类别;
print('==========================y_predict======================')
print('预测的分类结果',y_predict1)
print("---每个样本对应每个类别的概率---")
for one_res in y_predict:
   
print(one_res)
accuarcy = accuracy_score(
y_true=y_test, y_pred=y_predict1)
print('gcForest accuarcy : {}'.format(accuarcy))


http://chatgpt.dhexx.cn/article/PCyYiM2P.shtml

相关文章

深度森林浅析

深度森林 深度学习最大的贡献是表征学习(representation learning),通过端到端的训练,发现更好的features,而后面用于分类(或其他任务)的输出function,往往也只是普通的softmax&…

深度森林原理及实现

目录 背景 级联森林 多粒度扫描 代码 总结 背景 深度森林(Deep Forest)是周志华教授和冯霁博士在2017年2月28日发表的论文《Deep Forest: Towards An Alternative to Deep Neural Networks》中提出来的一种新的可以与深度神经网络相媲美的基于树的模型,其结构…

论文阅读:深度森林

论文地址:https://arxiv.org/pdf/1702.08835.pdf 相关代码:https://github.com/kingfengji/gcForest 深度森林是南大周志华老师前两年提出的一种基于随机森林的深度学习模型。 当前的深度学习模型大多基于深度学习神经网络(DNN)…

Deep Forest(gcforest)通俗易懂理解

DeepForest(gcforest)深度森林介绍 1.背景介绍 当前的深度学习模型主要建立在神经网络上,即可以通过反向传播训练的多层参数化可微分非线性模块,周志华老师希望探索深度学习模型的新模式,探索不可微模块构建深度模型的可能性。从而提出了一…

【深度学习】Deep Forest:gcForest算法理解

一、相关理论 本篇博文主要介绍南京大学周志华教授在2017年提出的一种深度森林结构——gcForest(多粒度级联森林)。近年来,深度神经网络在图像和声音处理领域取得了很大的进展。关于深度神经网络,我们可以把它简单的理解为多层非…

从深度学习到深度森林方法(Python)

作者 |泳鱼 来源 |算法进阶 一、深度森林的介绍 目前深度神经网络(DNN)做得好的几乎都是涉及图像视频(CV)、自然语言处理(NLP)等的任务,都是典型的数值建模任务(在表格数据tabular …

深度森林(gcforest)原理讲解以及代码实现

GcForest原理 gcforest采用Cascade结构,也就是多层级结构,每层由四个随机森林组成,两个随机森林和两个极端森林,每个极端森林包含1000(超参数)个完全随机树,每个森林都会对的数据进行训练,每个森林都输出结…

有关 -fPIC 选项的编译问题

嵌入式 Linux 上基于 makefile 的编译,执行编译出现如下错误: error: test/fake_blemgr_test.o: requires unsupported dynamic reloc R_ARM_REL32; recompile with -fPIC 这个问题涉及到生成位置无关代码的机制,在我这里的情况是&#xff0…

-fpic 与-fPIC的区别

-fpic 与-fPIC的区别 前言 在编译动态库的时候,我们应该需要使用-fpic 或-fPIC参数。如下所示: 然后,使用gcc或g 命令生成动态库 pic 与PIC的异同 相同点:都是为了在动态库中生成位置无关的代码。通过全局偏移表(GOT&…

在字符串中删除指定的特定字符

使用C语言,编写一个尽可能高效的函数,删除字符串中特定字符。 思路:要求尽可能高效,定义一个256的int数组,将需要删除的字符ASCII作为数组下标,要删除的置1,注意点:处理后的字符串要…

实现一个删除字符串中的指定字符的简单函数

做出此函数只需将函数分步运行 第一步&#xff1a;确定函数接口和定义变量 因为是要删除字符串中的指定函数&#xff0c;所以要定义出字符串以及字符&#xff0c;即str[]和c。函数接口为void delchar&#xff08;char*str&#xff0c;char c); #include<stdio.h> int …

删除字符串某一指定子字符串

文章目录 功能&#xff1a;删除字符串某一指定子字符串业务场景实现方法1、Java实现 功能&#xff1a;删除字符串某一指定子字符串 业务场景 在数据库中有一个String类型的字符串&#xff0c;该字符串通过逗号进行分割&#xff0c;现在前端传过来字符串中的一个子字符串&…

删除字符串中指定的字符C语言详解

问题描述&#xff1a; 从键盘输入一个字符串和一个字符&#xff0c;删除字符串中所指定的字符&#xff0c;将结果保存到一个新的字符串中并输出 编译环境&#xff1a;vc 6.0; 代码 #include <stdio.h>int main() {char str1[80], str2[80], ch;int i,j0;printf("…

C语言删除字符串中的指定字符

一、函数方法&#xff08;推荐使用这个方法&#xff09; 只需要进行一次对目标字符串的遍历即可完成删除目标字符的功能&#xff0c;具体的代码如下所示&#xff1a; void delchar( char *str, char c ) {int i,j;for(ij0;str[i]!\0;i){if(str[i]!c)//判断是否有和待删除字符一…

C++ 删除指定字符串中的某些字符

C 删除指定字符串中的某些字符 题目 输入URL前缀和后缀&#xff0c;删除字符串中的“,/”&#xff0c;把URL拼接在一起且后边有“/”。 用例&#xff1a; 输入&#xff1a;/abhdsjvf/,/afsggfd 输出&#xff1a;/abhdsjvf/afsggfd/ 程序实现&#xff1a; #include<iostrea…

Windows server 2012 R2安装教程

镜像下载地址&#xff1a; ed2k://|file|cn_windows_server_2012_r2_vl_with_update_x64_dvd_6052729.iso|5545527296|BD499EBCABF406AB82293DD8A5803493|/ 1&#xff0c;语言&#xff0c;键盘输入法&#xff0c;默认&#xff0c;点击下一步 2&#xff0c;点击 现在安装 3&…

SQL Serevr 2012 安装教程

需要的工具 SQL Server 2012R2 镜像 ISO WINDOWS SERVER 2012R2 操作系统 安装过程 1、打开安装文件&#xff0c;打开 setup 应用程序图标 2、在 SQL server 安装中心窗口中&#xff0c;点击安装– 点击全新 SQL Server 独立安装或向现有安装添加功能 3、点击确定 4、输入产…

Microsoft SQL Server 2008 R2 官方简体中文正式版下载(附激活序列号密钥)

微软官方发布的Microsoft SQL Server 2008 R2 简体中文完整版。基于SQL Server 2008提供可靠高效的智能数据平台构建而成&#xff0c;SQL Server 2008 R2 提供了大量新改进&#xff0c;可帮助您的组织满怀信心地调整规模、提高 IT 效率并实现管理完善的自助 BI。此版本中包含应…

SqlServer2012下载+安装+启动(资源+密钥)

一、下载 此处提供一个下载链接。具体地址如下&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1zaMOa-nA19mZxStrBK9jjg 提取码&#xff1a;k03s 下载压缩包是SqlServer2012版镜像&#xff0c; 二、安装 直接双击下载得到的镜像文件&#xff0c;后双击文件夹的如下set…

超详细windows server2012 r2 虚拟机安装步骤

1、Windows Server 2012 R2简介 Windows Server 2012 R2 是基于Windows8.1 以及Windows RT 8.1 界面的新一代 Windows Server 操作系统&#xff0c;提供企业级数据中心和混合云解决方案&#xff0c;易于部署、具有成本效益、以应用程序为重点、以用户为中心。 在 Microsoft 云…