深度森林(gcforest)原理讲解以及代码实现

article/2025/9/25 22:29:57

GcForest原理

gcforest采用Cascade结构,也就是多层级结构,每层由四个随机森林组成,两个随机森林和两个极端森林,每个极端森林包含1000(超参数)个完全随机树,每个森林都会对的数据进行训练,每个森林都输出结果,我们把这个结果叫做森林生成的类向量,为了避免过拟合,我们喂给每个森林训练的数据都是通过k折交叉验证的,每一层最后生成四个类向量,下一层以上一层的四个类向量以及原有的数据为新的train data进行训练,如此叠加,最后一层将类向量进行平均,作为预测结果

个人认为这种结构非常类似于神经网络,神经网络的每个单位是神经元,而深度森林的单位元却是随机森林,单个随机森林在性能上强于单个神经元的,这就是使得深度森林很多时候尽管层级和基础森林树不多,也能取得好的结果的主要原因

在这里插入图片描述

GcForest代码实现原理图

我们需要做出一个layer的结构,每个layer由二种四个森林组成

在这里插入图片描述

每个layer都输出两个结果:每个森林的预测结果、四个森林的预测的平均结果

为了防止过拟合我们的layer都由k折交叉验证产生

在这里插入图片描述

同时为了保留数据全部的特征我们将得到的小layer叠在一起定义一个Biger Layer

之后我们就可以构建深度森林了

在这里插入图片描述

GcForest代码

layer.py

extraTree(极端树)使用的所有的样本,只是特征是随机选取的,因为分裂是随机的,所以在某种程度上比随机森林得到的结果更加好

from sklearn.ensemble import ExtraTreesRegressor#引入极端森林回归
from sklearn.ensemble import RandomForestRegressor#引入随机森林回归
import numpy as npclass Layer:#定义层类def __init__(self, n_estimators, num_forests, max_depth=30, min_samples_leaf=1):self.num_forests = num_forests  # 定义森林数self.n_estimators = n_estimators  # 每个森林的树个数self.max_depth = max_depth#每一颗树的最大深度self.min_samples_leaf = min_samples_leaf#树会生长到所有叶子都分到一个类,或者某节点所代表的样本数已小于min_samples_leafself.model = []#最后产生的类向量def train(self, train_data, train_label, weight, val_data):#训练函数val_prob = np.zeros([self.num_forests, val_data.shape[0]])#定义出该层的类向量,有self.num_forersts行,val_data.shape[0]列,这里我们认为val_data应该就是我们的weightfor forest_index in range(self.num_forests):#对具体的layer内的森林进行构建if forest_index % 2 == 0:#如果是第偶数个设为随机森林clf = RandomForestRegressor(n_estimators=self.n_estimators#子树的个数,n_jobs=-1, #cpu并行树,-1表示和cpu的核数相同max_depth=self.max_depth,#最大深度min_samples_leaf=self.min_samples_leaf)clf.fit(train_data, train_label, weight)#weight是取样比重Sample weightsval_prob[forest_index, :] = clf.predict(val_data)#记录类向量else:#如果是第奇数个就设为极端森林clf = ExtraTreesRegressor(n_estimators=self.n_estimators,#森林所含树的个数n_jobs=-1, #并行数max_depth=self.max_depth,#最大深度min_samples_leaf=self.min_samples_leaf)#最小叶子限制clf.fit(train_data, train_label, weight)val_prob[forest_index, :] = clf.predict(val_data)#记录类向量self.model.append(clf)#组建layer层val_avg = np.sum(val_prob, axis=0)#按列进行求和val_avg /= self.num_forests#求平均val_concatenate = val_prob.transpose((1, 0))#对记录的类向量矩阵进行转置return [val_avg, val_concatenate]#返回平均结果和转置后的类向量矩阵def predict(self, test_data):#定义预测函数,也是最后一层的功能predict_prob = np.zeros([self.num_forests, test_data.shape[0]])for forest_index, clf in enumerate(self.model):predict_prob[forest_index, :] = clf.predict(test_data)predict_avg = np.sum(predict_prob, axis=0)predict_avg /= self.num_forestspredict_concatenate = predict_prob.transpose((1, 0))return [predict_avg, predict_concatenate]class KfoldWarpper:#定义每个树进行训练的所用的数据def __init__(self, num_forests, n_estimators, n_fold, kf, layer_index, max_depth=31, min_samples_leaf=1):#包括森林树,森林使用树的个数,k折的个数,k-折交叉验证,第几层,最大深度,最小叶子节点限制self.num_forests = num_forestsself.n_estimators = n_estimatorsself.n_fold = n_foldself.kf = kfself.layer_index = layer_indexself.max_depth = max_depthself.min_samples_leaf = min_samples_leafself.model = []def train(self, train_data, train_label, weight):num_samples, num_features = train_data.shapeval_prob = np.empty([num_samples])val_prob_concatenate = np.empty([num_samples, self.num_forests])#创建新的空矩阵,num_samples行,num_forest列,用于放置预测结果for train_index, test_index in self.kf:#进行k折交叉验证,在train_data里创建交叉验证的补充X_train = train_data[train_index, :]#选出训练集X_val = train_data[test_index, :]#验证集y_train = train_label[train_index]#训练标签weight_train = weight[train_index]#训练集对应的权重layer = Layer(self.n_estimators, self.num_forests, self.max_depth, self.min_samples_leaf)#加入层val_prob[test_index], val_prob_concatenate[test_index, :] = \layer.train(X_train, y_train, weight_train, X_val)#记录输出的结果self.model.append(layer)#在模型中填充层级,这也是导致程序吃资源的部分,每次进行return [val_prob, val_prob_concatenate]def predict(self, test_data):#定义预测函数,用做下一层的训练数据test_prob = np.zeros([test_data.shape[0]])test_prob_concatenate = np.zeros([test_data.shape[0], self.num_forests])for layer in self.model:temp_prob, temp_prob_concatenate = \layer.predict(test_data)test_prob += temp_probtest_prob_concatenate += temp_prob_concatenatetest_prob /= self.n_foldtest_prob_concatenate /= self.n_foldreturn [test_prob, test_prob_concatenate]

gcforest.py

from sklearn.model_selection import KFold
from layer import *
import numpy as npdef compute_loss(target, predict):#对数误差函数temp = np.log(abs(target + 1)) - np.log(abs(predict + 1))res = np.dot(temp, temp) / len(temp)#向量点成后平均return resclass gcForest:#定义gcforest模型def __init__(self, num_estimator, num_forests, max_layer=2, max_depth=31, n_fold=5):self.num_estimator = num_estimatorself.num_forests = num_forestsself.n_fold = n_foldself.max_depth = max_depthself.max_layer = max_layerself.model = []def train(self,train_data, train_label, weight):num_samples, num_features = train_data.shape# basis processtrain_data_new = train_data.copy()# return valueval_p = []best_train_loss = 0.0layer_index = 0best_layer_index = 0bad = 0kf = KFold(2,True,self.n_fold).split(train_data_new.shape[0])
#这里加入k折交叉验证while layer_index < self.max_layer:print("layer " + str(layer_index))layer = KfoldWarpper(self.num_forests, self.num_estimator, self.n_fold, kf, layer_index, self.max_depth, 1)#其实这一个layer是个夹心layer,是2层layer的平均结果val_prob, val_stack= layer.train(train_data_new, train_label, weight)#使用该层进行训练	train_data_new = np.concatenate([train_data, val_stack], axis=1)#将该层的训练结果也加入到train_data中temp_val_loss = compute_loss(train_label, val_prob)print("val   loss:" + str(temp_val_loss))if best_train_loss < temp_val_loss:#用于控制加入的层数,如果加入的层数较多,且误差没有下降也停止运行bad += 1else:bad = 0best_train_loss = temp_val_lossbest_layer_index = layer_indexif bad >= 3:breaklayer_index = layer_index + 1self.model.append(layer)for index in range(len(self.model), best_layer_index + 1, -1):#删除多余的layerself.model.pop()def predict(self, test_data):test_data_new = test_data.copy()test_prob = []for layer in self.model:predict, test_stack = layer.predict(test_data_new)test_data_new = np.concatenate([test_data, test_stack], axis=1)return predict

test.py

import numpy as np
from gcForest import *
from time import timedef load_data():train_data = np.load()  train_label = np.load()train_weight = np.load()test_data = np.load()test_label = np.load()test_file = np.load() return [train_data, train_label, train_weight, test_data, test_label, test_file]if __name__ == '__main__':train_data, train_label, train_weight, test_data, test_label, test_file = load_data()clf = gcForest(num_estimator = 100, num_forests = 4, max_layer=2, max_depth=100, n_fold=5)start = time()clf.train(train_data, train_label, train_weight)end = time()print("fitting time: " + str(end - start) + " sec")start = time()prediction = clf.predict(test_data)end = time()print("prediction time: " + str(end - start) + " sec")result = {}for index, item in enumerate(test_file):if item not in result:result[item] = prediction[index]else:result[item] = (result[item] + prediction[index]) / 2print(result)


代码源自:科大讯飞开放平台
原理主要参考:周志华老师论文


http://chatgpt.dhexx.cn/article/C8CrGkPy.shtml

相关文章

有关 -fPIC 选项的编译问题

嵌入式 Linux 上基于 makefile 的编译&#xff0c;执行编译出现如下错误&#xff1a; error: test/fake_blemgr_test.o: requires unsupported dynamic reloc R_ARM_REL32; recompile with -fPIC 这个问题涉及到生成位置无关代码的机制&#xff0c;在我这里的情况是&#xff0…

-fpic 与-fPIC的区别

-fpic 与-fPIC的区别 前言 在编译动态库的时候&#xff0c;我们应该需要使用-fpic 或-fPIC参数。如下所示&#xff1a; 然后&#xff0c;使用gcc或g 命令生成动态库 pic 与PIC的异同 相同点&#xff1a;都是为了在动态库中生成位置无关的代码。通过全局偏移表&#xff08;GOT&…

在字符串中删除指定的特定字符

使用C语言&#xff0c;编写一个尽可能高效的函数&#xff0c;删除字符串中特定字符。 思路&#xff1a;要求尽可能高效&#xff0c;定义一个256的int数组&#xff0c;将需要删除的字符ASCII作为数组下标&#xff0c;要删除的置1&#xff0c;注意点&#xff1a;处理后的字符串要…

实现一个删除字符串中的指定字符的简单函数

做出此函数只需将函数分步运行 第一步&#xff1a;确定函数接口和定义变量 因为是要删除字符串中的指定函数&#xff0c;所以要定义出字符串以及字符&#xff0c;即str[]和c。函数接口为void delchar&#xff08;char*str&#xff0c;char c); #include<stdio.h> int …

删除字符串某一指定子字符串

文章目录 功能&#xff1a;删除字符串某一指定子字符串业务场景实现方法1、Java实现 功能&#xff1a;删除字符串某一指定子字符串 业务场景 在数据库中有一个String类型的字符串&#xff0c;该字符串通过逗号进行分割&#xff0c;现在前端传过来字符串中的一个子字符串&…

删除字符串中指定的字符C语言详解

问题描述&#xff1a; 从键盘输入一个字符串和一个字符&#xff0c;删除字符串中所指定的字符&#xff0c;将结果保存到一个新的字符串中并输出 编译环境&#xff1a;vc 6.0; 代码 #include <stdio.h>int main() {char str1[80], str2[80], ch;int i,j0;printf("…

C语言删除字符串中的指定字符

一、函数方法&#xff08;推荐使用这个方法&#xff09; 只需要进行一次对目标字符串的遍历即可完成删除目标字符的功能&#xff0c;具体的代码如下所示&#xff1a; void delchar( char *str, char c ) {int i,j;for(ij0;str[i]!\0;i){if(str[i]!c)//判断是否有和待删除字符一…

C++ 删除指定字符串中的某些字符

C 删除指定字符串中的某些字符 题目 输入URL前缀和后缀&#xff0c;删除字符串中的“,/”&#xff0c;把URL拼接在一起且后边有“/”。 用例&#xff1a; 输入&#xff1a;/abhdsjvf/,/afsggfd 输出&#xff1a;/abhdsjvf/afsggfd/ 程序实现&#xff1a; #include<iostrea…

Windows server 2012 R2安装教程

镜像下载地址&#xff1a; ed2k://|file|cn_windows_server_2012_r2_vl_with_update_x64_dvd_6052729.iso|5545527296|BD499EBCABF406AB82293DD8A5803493|/ 1&#xff0c;语言&#xff0c;键盘输入法&#xff0c;默认&#xff0c;点击下一步 2&#xff0c;点击 现在安装 3&…

SQL Serevr 2012 安装教程

需要的工具 SQL Server 2012R2 镜像 ISO WINDOWS SERVER 2012R2 操作系统 安装过程 1、打开安装文件&#xff0c;打开 setup 应用程序图标 2、在 SQL server 安装中心窗口中&#xff0c;点击安装– 点击全新 SQL Server 独立安装或向现有安装添加功能 3、点击确定 4、输入产…

Microsoft SQL Server 2008 R2 官方简体中文正式版下载(附激活序列号密钥)

微软官方发布的Microsoft SQL Server 2008 R2 简体中文完整版。基于SQL Server 2008提供可靠高效的智能数据平台构建而成&#xff0c;SQL Server 2008 R2 提供了大量新改进&#xff0c;可帮助您的组织满怀信心地调整规模、提高 IT 效率并实现管理完善的自助 BI。此版本中包含应…

SqlServer2012下载+安装+启动(资源+密钥)

一、下载 此处提供一个下载链接。具体地址如下&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1zaMOa-nA19mZxStrBK9jjg 提取码&#xff1a;k03s 下载压缩包是SqlServer2012版镜像&#xff0c; 二、安装 直接双击下载得到的镜像文件&#xff0c;后双击文件夹的如下set…

超详细windows server2012 r2 虚拟机安装步骤

1、Windows Server 2012 R2简介 Windows Server 2012 R2 是基于Windows8.1 以及Windows RT 8.1 界面的新一代 Windows Server 操作系统&#xff0c;提供企业级数据中心和混合云解决方案&#xff0c;易于部署、具有成本效益、以应用程序为重点、以用户为中心。 在 Microsoft 云…

server 2012 各个版本的安装及激活教程

服务器引导盘的部分我就不多解释了&#xff0c;直接讲重点&#xff0c;需要注意的地方 Windows Server 2012 R2 安装密钥&#xff08;只适用安装&#xff0c;不支持激活&#xff09; 标准版 NB4WH-BBBYV-3MPPC-9RCMV-46XCB 数据中心版 BH9T4-4N7CW-67J3M-64J36-WW98Y 安装的…

VMware虚拟机安装Windows Server 2012 R2

想必同学们已经开学了&#xff0c;也都进入了军训阶段吧&#xff0c;而很多计算机网络专业的同学们要开始接触到Windows Server了&#xff0c;这也是计算机网络技术专业的专业基础课程&#xff0c;想当年我们实训课学习使用的好像是2008版的&#xff0c;也不晓得现在各个学校会…

win2012 R2安装与配置

安装win2012 R2的时候&#xff0c;我们最好使用update的iso镜像&#xff0c;不然安装完后会出很多问题&#xff0c;比如vmtools安装不了。 iso镜像迅雷下载&#xff1a; ed2k://|file|cn_windows_server_2012_r2_with_update_x64_dvd_6052725.iso|5545705472|121EC13B53882E50…

windows server 2008R2/2012R2安装

Windows安装 1.镜像下载 镜像网站&#xff1a;https://msdn.itellyou.cn/ Windows Server 2012R2&#xff1a;cn_windows_server_2012_r2_vl_x64_dvd_2979220.iso 迅雷下载地址&#xff1a;ed2k://|file|cn_windows_server_2012_r2_vl_x64_dvd_2979220.iso|4453249024|1F71…

windowsserver2012R2安装教程

下载windowsserver2012R2系统镜像。 刻录光盘或是写入到U盘做成启动u盘&#xff0c;操作文档可搜索iso写入U盘。 服务器或是电脑开机从U盘启动安装。 现在安装。 输入激活码 NH3KG-P864D-XYCJH-82DMH-4CX8M -- Standard 7H6M3-4N78W-RFFKJ-H9KPW-K2C2M -- Datacenter 安装什…

服务器2012r2系统安装设置,正确安装windows server 2012 r2的方法

Windows Server 2012 R2 是基于Windows8.1 以及Windows RT 8.1 界面的新一代 Windows Server 操作系统&#xff0c;提供企业级数据中心和混合云解决方案&#xff0c;易于部署、具有成本效益、以应用程序为重点、以用户为中心。windows server 2012 R2 具备的众多新特点大大的增…

安装Windows server 2012 R2

项目一 Windows Server 2012 R2服务器安装与配置 项目背景 你是一家公司的网络管理员&#xff0c;负责管理和维护公司的网络。你的公司新购置了一台计算机&#xff0c;希望你安装Windows Server 2012 R2企业版操作系统&#xff0c;设置好TCP/IP参数。 任务1&#xff1a;安装Wi…