深度森林原理及实现

article/2025/9/25 16:20:35

目录

背景

级联森林

多粒度扫描

代码

总结


背景

深度森林(Deep Forest)是周志华教授和冯霁博士在2017年2月28日发表的论文《Deep Forest: Towards An Alternative to Deep Neural Networks》中提出来的一种新的可以与深度神经网络相媲美的基于树的模型,其结构如图所示。


 

级联森林

 

上图表示gcForest的级联结构。

每一层都由多个随机森林组成。通过随机森林学习输入特征向量的特征信息,经过处理后输入到下一层。为了增强模型的泛化能力,每一层选取多种不同类型的随机森林,上图给了两种随机森林结构,分别为completely-random tree forests(蓝色)和random forests(黑色),每种两个。其中,每个completely-random tree forests包含1000棵树,每个节点通过随机选取一个特征作为判别条件,并根据这个判别条件生成子节点,直到每个叶子节点只包含同一类的实例而停止;每个random forests同样包含1000棵树,节点特征的选择通过随机选择√d个特征(d为输入特征的数量),然后选择基尼系数最大特征作为该节点划分的条件。

级联森林的迭代终止条件:迭代到效果不能提升就停止!!!

级联森林中每个森林是如何决策的呢?

每个森林中都包括好多棵决策树,每个决策树都会决策出一个类向量结果(以3类为例,下面也是),然后综合所有的决策树结果,再取均值,生成每个森林的最终决策结果——一个3维类向量!每个森林的决策过程如下图所示。

 

这样,每个森林都会决策出一个3维类向量,回到图1中,级联森林中的4个森林就都可以决策出一个3维类向量,然后对4个*3维类向量取均值,最后取最大值对应的类别,作为最后的预测结果!

多粒度扫描

多粒度扫描是为了增强级联森林,为了对特征做更多的处理的一种技术手段,具体扫描过程如下图所示。

 

上图表示对输入特征使用多粒度扫描的方式产生级联森林的输入特征向量。

对于400维的序列数据,采用100维的滑动窗对输入特征进行处理,得到301(400 - 100 + 1)个100维的特征向量。对于20×20的图像数据,采用10×10的滑动窗对输入特征进行处理,得到121((20-10+1)*(20-10+1))个10×10的二维特征图。然后将得到的特征向量(或特征图)分别输入到一个completely-random tree forest和一个random forest中(不唯一,也可使用多个森林),以三分类为例,会得到301(或121)个3维类分布向量,将这些向量进行拼接,得到1806(或726)维的特征向量。

​​​​​​​

 

第一步:使用多粒度扫描对输入特征进行预处理。以使用三个尺寸的滑动窗为例,分别为100-dim,200-dim和300-dim。输入数据为400-dim的序列特征,使用100-dim滑动窗会得到301个100-dim向量,然后输入到一个completely-random tree forest和一个random forest中,两个森林会分别得到的301个3-dim向量(3分类),将两个森林得到的特征向量进行拼接,会得到1806-dim的特征向量。同理,使用200-dim和300-dim滑动窗会分别得到1206-dim和606-dim特征向量。

第二步:将得到的特征向量输入到级联森林中进行训练。首先使用100-dim滑动窗得到的1806-dim特征向量输入到第一层级联森林中进行训练,得到12-dim的类分布向量(3分类,4棵树)。然后将得到的类分布向量与100-dim滑动窗得到的特征向量进行拼接,得到1818-dim特征向量,作为第二层的级联森林的输入数据;第二层级联森林训练得到的12-dim类分布向量再与200-dim滑动窗得到的特征向量进行拼接。作为第三层级联森林的输入数据;第三层级联森林训练得到的12-dim类分布向量再与300-dim滑动窗得到的特征向量进行拼接,做为下一层的输入。一直重复上述过程,直到验证收。

代码:

去GitHub上下载源码放在自己python文件夹下的site-packages里面

from gcforest.gcforest import GCForest

运行没有报错,即安装成功

参数配置

#训练的配置,采用默认的模型-即原库代码实现方式
def get_toy_config():config = {}ca_config = {}ca_config["random_state"] = 0  # 0 or 1ca_config["max_layers"] = 100 # 最大的层数,layer对应论文中的levelca_config["early_stopping_rounds"] = 3 #如果出现某层的三层以内的准确率都没有提升,层中止ca_config["n_classes"] = 2 #判别的类别数量ca_config["estimators"] = []ca_config["estimators"].append({"n_folds": 2, "type": "RandomForestClassifier", "n_estimators": 10, "max_depth": None, "n_jobs": -1})ca_config["estimators"].append({"n_folds": 2, "type": "ExtraTreesClassifier", "n_estimators": 10, "max_depth": None, "n_jobs": -1})ca_config["estimators"].append({"n_folds": 2, "type": "LogisticRegression"})config["cascade"] = ca_config #共使用了3个基学习器return config

本文只选择了三个模型,没有使用xgboost【xgboost输入格式更加严格,需要调试】

config=get_toy_config()
gc = GCForest(config)
#X_train_enc是每个模型最后一层输出的结果,每一个类别的可能性
X_train_enc = gc.fit_transform(data, label)
y_pred = gc.predict(X_test)
acc = accuracy_score(y_test, y_pred)

Test Accuracy of GcForest = 82.84 %

可以使用gcForest得到的X_enc数据进行其他模型的训练比如xgboost/RF

clf = RandomForestClassifier(n_estimators=1000, max_depth=None, n_jobs=-1)
clf.fit(X_train_enc, label)
y_pred = clf.predict(X_test_enc)
acc1 = accuracy_score(y_test, y_pred)
print("Test Accuracy of Other classifier using gcforest's X_encode = {:.2f} %".format(acc1 * 100))

Test Accuracy of Other classifier using gcforest's X_encode = 79.48 %

完整代码放在GitHub上:Recommendation_algorithm/gcForest.py at master · Andyszl/Recommendation_algorithm · GitHub

总结

相比深度神经网络,gcForest有如下若干有点:

1. 容易训练,计算开销小
2.天然适用于并行的部署,效率高
3. 超参数少,模型对超参数调节不敏感,并且一套超参数可使用到不同数据集
4.可以适应于不同大小的数据集,模型复杂度可自适应伸缩
5. 每个级联的生成使用了交叉验证,避免过拟合
6. 在理论分析方面也比深度神经网络更加容易。


 

参考:

官方开源地址:GitHub - kingfengji/gcForest: This is the official implementation for the paper 'Deep forest: Towards an alternative to deep neural networks'

【深度学习】Deep Forest:gcForest算法理解_z小白的博客-CSDN博客_深度森林算法


http://chatgpt.dhexx.cn/article/3L2Cwlpw.shtml

相关文章

论文阅读:深度森林

论文地址:https://arxiv.org/pdf/1702.08835.pdf 相关代码:https://github.com/kingfengji/gcForest 深度森林是南大周志华老师前两年提出的一种基于随机森林的深度学习模型。 当前的深度学习模型大多基于深度学习神经网络(DNN)…

Deep Forest(gcforest)通俗易懂理解

DeepForest(gcforest)深度森林介绍 1.背景介绍 当前的深度学习模型主要建立在神经网络上,即可以通过反向传播训练的多层参数化可微分非线性模块,周志华老师希望探索深度学习模型的新模式,探索不可微模块构建深度模型的可能性。从而提出了一…

【深度学习】Deep Forest:gcForest算法理解

一、相关理论 本篇博文主要介绍南京大学周志华教授在2017年提出的一种深度森林结构——gcForest(多粒度级联森林)。近年来,深度神经网络在图像和声音处理领域取得了很大的进展。关于深度神经网络,我们可以把它简单的理解为多层非…

从深度学习到深度森林方法(Python)

作者 |泳鱼 来源 |算法进阶 一、深度森林的介绍 目前深度神经网络(DNN)做得好的几乎都是涉及图像视频(CV)、自然语言处理(NLP)等的任务,都是典型的数值建模任务(在表格数据tabular …

深度森林(gcforest)原理讲解以及代码实现

GcForest原理 gcforest采用Cascade结构,也就是多层级结构,每层由四个随机森林组成,两个随机森林和两个极端森林,每个极端森林包含1000(超参数)个完全随机树,每个森林都会对的数据进行训练,每个森林都输出结…

有关 -fPIC 选项的编译问题

嵌入式 Linux 上基于 makefile 的编译,执行编译出现如下错误: error: test/fake_blemgr_test.o: requires unsupported dynamic reloc R_ARM_REL32; recompile with -fPIC 这个问题涉及到生成位置无关代码的机制,在我这里的情况是&#xff0…

-fpic 与-fPIC的区别

-fpic 与-fPIC的区别 前言 在编译动态库的时候,我们应该需要使用-fpic 或-fPIC参数。如下所示: 然后,使用gcc或g 命令生成动态库 pic 与PIC的异同 相同点:都是为了在动态库中生成位置无关的代码。通过全局偏移表(GOT&…

在字符串中删除指定的特定字符

使用C语言,编写一个尽可能高效的函数,删除字符串中特定字符。 思路:要求尽可能高效,定义一个256的int数组,将需要删除的字符ASCII作为数组下标,要删除的置1,注意点:处理后的字符串要…

实现一个删除字符串中的指定字符的简单函数

做出此函数只需将函数分步运行 第一步&#xff1a;确定函数接口和定义变量 因为是要删除字符串中的指定函数&#xff0c;所以要定义出字符串以及字符&#xff0c;即str[]和c。函数接口为void delchar&#xff08;char*str&#xff0c;char c); #include<stdio.h> int …

删除字符串某一指定子字符串

文章目录 功能&#xff1a;删除字符串某一指定子字符串业务场景实现方法1、Java实现 功能&#xff1a;删除字符串某一指定子字符串 业务场景 在数据库中有一个String类型的字符串&#xff0c;该字符串通过逗号进行分割&#xff0c;现在前端传过来字符串中的一个子字符串&…

删除字符串中指定的字符C语言详解

问题描述&#xff1a; 从键盘输入一个字符串和一个字符&#xff0c;删除字符串中所指定的字符&#xff0c;将结果保存到一个新的字符串中并输出 编译环境&#xff1a;vc 6.0; 代码 #include <stdio.h>int main() {char str1[80], str2[80], ch;int i,j0;printf("…

C语言删除字符串中的指定字符

一、函数方法&#xff08;推荐使用这个方法&#xff09; 只需要进行一次对目标字符串的遍历即可完成删除目标字符的功能&#xff0c;具体的代码如下所示&#xff1a; void delchar( char *str, char c ) {int i,j;for(ij0;str[i]!\0;i){if(str[i]!c)//判断是否有和待删除字符一…

C++ 删除指定字符串中的某些字符

C 删除指定字符串中的某些字符 题目 输入URL前缀和后缀&#xff0c;删除字符串中的“,/”&#xff0c;把URL拼接在一起且后边有“/”。 用例&#xff1a; 输入&#xff1a;/abhdsjvf/,/afsggfd 输出&#xff1a;/abhdsjvf/afsggfd/ 程序实现&#xff1a; #include<iostrea…

Windows server 2012 R2安装教程

镜像下载地址&#xff1a; ed2k://|file|cn_windows_server_2012_r2_vl_with_update_x64_dvd_6052729.iso|5545527296|BD499EBCABF406AB82293DD8A5803493|/ 1&#xff0c;语言&#xff0c;键盘输入法&#xff0c;默认&#xff0c;点击下一步 2&#xff0c;点击 现在安装 3&…

SQL Serevr 2012 安装教程

需要的工具 SQL Server 2012R2 镜像 ISO WINDOWS SERVER 2012R2 操作系统 安装过程 1、打开安装文件&#xff0c;打开 setup 应用程序图标 2、在 SQL server 安装中心窗口中&#xff0c;点击安装– 点击全新 SQL Server 独立安装或向现有安装添加功能 3、点击确定 4、输入产…

Microsoft SQL Server 2008 R2 官方简体中文正式版下载(附激活序列号密钥)

微软官方发布的Microsoft SQL Server 2008 R2 简体中文完整版。基于SQL Server 2008提供可靠高效的智能数据平台构建而成&#xff0c;SQL Server 2008 R2 提供了大量新改进&#xff0c;可帮助您的组织满怀信心地调整规模、提高 IT 效率并实现管理完善的自助 BI。此版本中包含应…

SqlServer2012下载+安装+启动(资源+密钥)

一、下载 此处提供一个下载链接。具体地址如下&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1zaMOa-nA19mZxStrBK9jjg 提取码&#xff1a;k03s 下载压缩包是SqlServer2012版镜像&#xff0c; 二、安装 直接双击下载得到的镜像文件&#xff0c;后双击文件夹的如下set…

超详细windows server2012 r2 虚拟机安装步骤

1、Windows Server 2012 R2简介 Windows Server 2012 R2 是基于Windows8.1 以及Windows RT 8.1 界面的新一代 Windows Server 操作系统&#xff0c;提供企业级数据中心和混合云解决方案&#xff0c;易于部署、具有成本效益、以应用程序为重点、以用户为中心。 在 Microsoft 云…

server 2012 各个版本的安装及激活教程

服务器引导盘的部分我就不多解释了&#xff0c;直接讲重点&#xff0c;需要注意的地方 Windows Server 2012 R2 安装密钥&#xff08;只适用安装&#xff0c;不支持激活&#xff09; 标准版 NB4WH-BBBYV-3MPPC-9RCMV-46XCB 数据中心版 BH9T4-4N7CW-67J3M-64J36-WW98Y 安装的…

VMware虚拟机安装Windows Server 2012 R2

想必同学们已经开学了&#xff0c;也都进入了军训阶段吧&#xff0c;而很多计算机网络专业的同学们要开始接触到Windows Server了&#xff0c;这也是计算机网络技术专业的专业基础课程&#xff0c;想当年我们实训课学习使用的好像是2008版的&#xff0c;也不晓得现在各个学校会…