机器学习之手写决策树以及sklearn中的决策树及其可视化

article/2025/9/24 21:33:50

文章目录

  • 决策树理论部分
    • 基本算法
    • 划分选择
      • 信息熵
    • 信息增益
    • 信息增益率
    • 基尼系数
    • 基尼指数
  • 决策树代码实现
  • 参考

决策树理论部分

在这里插入图片描述
决策树的思路很简单,就是从数据集中挑选一个特征,然后进行分类。

基本算法

在这里插入图片描述
从伪代码中可以看出,分三种情况考虑:
(1)如果输入样本同属于一类,那么将节点划分为此类的叶节点。
(2)如果属性划分次数达到上限,即属性划分完了,或者是样本中在此类属性取值都一样,可以认为全部划分仍然存在不同类的样本,那么这个节点就标记为类别数占较多的叶节点。
(3)需要继续划分的情况,选择一个属性对数据集进行划分。在这里插入图片描述

划分选择

划分选择还是比较重要的,因为不同的划分选择会建出不同的决策树。划分选择的指标就是希望叶节点的数据尽可能都是属于同一类,即节点的“纯度”越来越高。

信息熵

在这里插入图片描述
其中|y|是指样本标签的种类的个数,pk代表第k类样本所占的比例

信息增益

在这里插入图片描述
|Dv|代表a特征中同样是v值的样本的数量。
当前样本此特征的信息增益 = 当前样本的信息熵 - 加权求和的同特征值的样本的信息熵。

举个例子

西瓜数据集2.0如下
在这里插入图片描述
首先计算样本的信息熵
在这里插入图片描述
然后计算各个特征的信息增益
在这里插入图片描述
在这里插入图片描述
可见纹理的信息增益最大,也说明用纹理来划分当前数据,得到的纯度提升是最高的。

在这里插入图片描述
在这里插入图片描述

信息增益率

因为信息增益对可取值较多的属性有所偏好,为了减少这个影响,可以采用信息增益率。
在这里插入图片描述
但是仍然存在问题:
在这里插入图片描述

因此特征选择策略:

在这里插入图片描述

基尼系数

在这里插入图片描述

基尼指数

在这里插入图片描述

决策树代码实现

千言万语都在注释里了。

import math
import numpy
import numpy as np
import collections
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier  # 导入决策树DTC包class DecisionNode(object):def __init__(self, f_idx, threshold, value=None, L=None, R=None):self.f_idx = f_idx  # 属性的下标,表示通过下标为f_idx的属性来划分样本self.threshold = threshold  # 下标 `f_idx` 对应属性的阈值self.value = value  # 如果该节点是叶子节点,对应的是被划分到这个节点的数据的类别self.L = L  # 左子树self.R = R  # 右子树# 寻找最优的阈值
def find_best_threshold(dataset: np.ndarray, f_idx: int, split_choice: str):  # dataset:numpy.ndarray (n,m+1) x<-[x,y]  f_idx:feature indexbest_gain = -math.inf  # 信息增益越小纯度越低best_gini = math.inf # 基尼值越大纯度越低best_threshold = Nonecandidate = [0, 1]  # 因为只有01,就用这两个来划分。候选值1代表是这个特征,0代表不是这个特征# 遍历候选值,找出纯度最大的划分值(这里是0或者1)for threshold in candidate:L, R = split_dataset(dataset, f_idx, threshold)   # 根据阈值分割数据集,小于阈值gain = Noneif split_choice == "gain":# 计算信息增益gain = calculate_gain(dataset, L, R)  # 根据数据集和分割之后的数if gain > best_gain:  # 如果增益大于最大增益,则更换最大增益和最大阈值best_gain = gainbest_threshold = thresholdif split_choice == "gain_ratio":# 计算信息增益率gain = calculate_gain_ratio(dataset, L, R)if gain > best_gain:  # 如果增益大于最大增益,则更换最大增益和最大阈值best_gain = gainbest_threshold = threshold# 计算基尼指数if split_choice == "gini":gini = calculate_gini_index(dataset, L, R)if gini < best_gini:  # gini指数越小越好best_gini = ginibest_threshold = threshold# 返回此特征最优的划分值(0或1)以及对应的信息增益/增益率/基尼指数return best_threshold, best_gain# 计算信息熵
def calculate_entropy(dataset: np.ndarray):  # 熵scale = dataset.shape[0]  # 多少条数据d = {}for data in dataset:# 一条数据的最后一位是标签key = data[-1]# 统计数据类别个数if key in d:d[key] += 1else:d[key] = 1entropy = 0.0for key in d.keys():# pkp = d[key] / scale# -pk * log2(pk)entropy -= p * math.log(p, 2)return entropy# 计算信息增益
def calculate_gain(dataset, l, r):# l:左子树的数据# r:右子树的数据# 计算信息熵e1 = calculate_entropy(dataset)# 因为每个特征只有两种取值,是或不是(l,r已然是按特征分开的两类)e2 = len(l) / len(dataset) * calculate_entropy(l) + len(r) / len(dataset) * calculate_entropy(r)gain = e1 - e2return gain# 计算信息增益率
def calculate_gain_ratio(dataset, l, r):s = 0gain = calculate_gain(dataset, l, r)p1 = len(l) / len(dataset)p2 = len(r) / len(dataset)# 会出现 1/0 的情况 全被划分到一边 s=0# 只有0,1两种取值if p1 == 0:s = p2 * math.log(p2, 2)elif p2 == 0:s = p1 * math.log(p1, 2)else:s = - p1 * math.log(p1, 2) - p2 * math.log(p2, 2)# 如果s为0,说明全都划分到一类,信息增益率可以看成无限大if s == 0:gain_ratio = math.infelse:gain_ratio = gain / sreturn gain_ratio# 计算基尼系数(随机抽取两个样本,其类别不一致的概率)
def calculate_gini(dataset: np.ndarray):scale = dataset.shape[0]  # 多少条数据d = {}for data in dataset:key = data[-1]if key in d:d[key] += 1else:d[key] = 1gini = 1.0for key in d.keys():p = d[key] / scalegini -= p * preturn gini# 计算基尼指数,基尼指数越小,纯度越高
def calculate_gini_index(dataset, l, r):gini_index = len(l) / len(dataset) * calculate_gini(l) + len(r) / len(dataset) * calculate_gini(r)return gini_indexdef split_dataset(X: np.ndarray, f_idx: int, threshold: float):# 左边是f_idx特征小于阈值的数据# 右边是大于阈值的数据L = X[:, f_idx] < thresholdR = ~Lreturn X[L], X[R]def majority_count(dataset):class_list = [data[-1] for data in dataset]# 返回数量最多的类别return collections.Counter(class_list).most_common(1)[0][0]def build_tree(dataset: np.ndarray, f_idx_list: list, split_choice: str):   # return DecisionNode 递归# f_idx_list 待选取特征的列表class_list = [data[-1] for data in dataset]  # 类别# 全属于同一类别(二分类)if class_list.count(class_list[0]) == len(class_list):return DecisionNode(None, None, value=class_list[0])# 若属性都用完, 标记为数量最多的那一类elif len(f_idx_list) == 0:value = collections.Counter(class_list).most_common(1)[0][0]return DecisionNode(None, None, value=value)else:# 找到划分 增益最大的属性best_gain = -math.infbest_gini = math.infbest_threshold = Nonebest_f_idx = None# 遍历所有特征,找出纯度最大的那个特征for i in f_idx_list:threshold, gain = find_best_threshold(dataset, i, split_choice)# 基尼指数越小纯度越大if split_choice == "gini":if gain < best_gini:best_gini = gainbest_threshold = thresholdbest_f_idx = i# 信息增益/信息增益率越大,纯度越大if split_choice == "gain" or split_choice == "gain_ratio" :if gain > best_gain:  # 如果增益大于最大增益,则更换最大增益和最大best_gain = gainbest_threshold = thresholdbest_f_idx = i# 拷贝原特征son_f_idx_list = f_idx_list.copy()# 移除进行分类的特征(挑选出的最优特征)son_f_idx_list.remove(best_f_idx)# 以最优阈值分割数据L, R = split_dataset(dataset, best_f_idx, best_threshold)# 左边的数据为0那么说明已经全都为一类了,那么叶节点就产生了if len(L) == 0:L_tree = DecisionNode(f_idx=None, threshold=None, value=majority_count(dataset))  # 叶子节点# 否则就继续往下划分else:L_tree = build_tree(L, son_f_idx_list, split_choice)  # return DecisionNode# 右边也同理if len(R) == 0:R_tree = DecisionNode(f_idx=None, threshold=None, value=majority_count(dataset))  # 叶子节点else:R_tree = build_tree(R, son_f_idx_list, split_choice)  # return DecisionNode# 递归调用建树return DecisionNode(f_idx=best_f_idx, threshold=best_threshold, value=None, L=L_tree, R=R_tree)def predict_one(model: DecisionNode, data):if model.value is not None:return model.valueelse:feature_one = data[model.f_idx]branch = Noneif feature_one >= model.threshold:branch = model.R  # 走右边else:branch = model.L   # 走左边return predict_one(branch, data)def predict_accuracy(y_predict, y_test):y_predict = y_predict.tolist()y_test = y_test.tolist()count = 0for i in range(len(y_predict)):if int(y_predict[i]) == y_test[i]:count = count + 1accuracy = count / len(y_predict)return accuracyclass SimpleDecisionTree(object):def __init__(self, split_choice):# split_choice 分割策略:信息增益、信息增益率或者基尼指数self.split_choice = split_choicedef fit(self, X: np.ndarray, y: np.ndarray):dataset_in = np.c_[X, y] # 纵向拼接f_idx_list = [i for i in range(X.shape[1])]# 特征列self.my_tree = build_tree(dataset_in, f_idx_list, self.split_choice) # 建树def predict(self, X: np.ndarray): predict_list = []for data in X:predict_list.append(predict_one(self.my_tree, data))return np.array(predict_list)if __name__ == "__main__":predict_accuracy_all = []import pandas as pdfor i in range(10):data = pd.read_csv("data.csv")y = data["label"].values x = data.drop(columns="label").valuesX_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2)predict_accuracy_list = []  # 储存4种结果split_choice_list = ["gain", "gain_ratio", "gini"]for split_choice in split_choice_list:m = SimpleDecisionTree(split_choice)m.fit(X_train, y_train)y_predict = m.predict(X_test)y_predict_accuracy = predict_accuracy(y_predict, y_test.reshape(-1))predict_accuracy_list.append(y_predict_accuracy)clf = DecisionTreeClassifier()  # 所以参数均置为默认状态clf.fit(X_train, y_train)  # 使用训练集训练模型predicted = clf.predict(X_test)predict_accuracy_list.append(clf.score(X_test, y_test))predict_accuracy_all.append(predict_accuracy_list)p = numpy.array(predict_accuracy_all)p = np.round(p, decimals=3)accs = []for i in p:accs.append(i)accs = pd.DataFrame(accs)accs.columns = ["gain", "gain_ratio", "gini", "sklearn"]print(accs)

输出结果:
在这里插入图片描述
我们还可以可视化一下sklearn帮我们建立的决策树:

from sklearn import tree
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['font.sans-serif'] = ['FangSong']  # 指定中文字体
mpl.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False  # 正常显示负号fn=data.columns[:-1]
cn=['坏瓜', '好瓜']
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4,4), dpi=300)
tree.plot_tree(clf,feature_names = fn, class_names=cn,filled = True);
# value表示对应类别的样例分别有多少个。

在这里插入图片描述
还是sklearn比较好。

参考

机器学习——周志华
手写分类决策树(鸢尾花数据集)


http://chatgpt.dhexx.cn/article/rD8prJME.shtml

相关文章

android使用友盟推送注册失败获取不到token accs bindapp error!

使用友盟推送注册失败获取不到token public void onFailure(String s, String s1)的值分别是“-9”和“accs bindapp error!”或者s的值为-11.都是同一个问题 就是主工程&#xff08;除友盟PushSDK 其他的module均看成为主工程&#xff09;so目录与PushSDK下的so目录不一致…

同时集成阿里云旺与友盟推送,初始化失败s:-11,s1:accs bindapp error!的解决办法

在应用中需要同时集成聊天和推送功能&#xff0c;聊天选用阿里的sdk&#xff08;百川云旺&#xff09;&#xff0c;推送选用友盟的pushSDK。 这时候悲剧就出现了&#xff0c;注册友盟的时候报错。 I/com.umeng.message.PushAgent: register-->onFailure-->s:-11,s1:accs …

关于友盟s=-11;s1=accs bindapp error!的解决处理

项目使用了友盟推送之后&#xff0c;在部分手机上出现accs bindapp error&#xff0c;错误码-11的问题&#xff0c;为什么会出现这个问题呢&#xff0c;网上查找了很久&#xff0c;友盟给出的解释是so文件不正确。 具体链接&#xff1a;http://bbs.umeng.com/thread-23018-1-1…

友盟register failed: -11 accs bindapp error!

下载官方Demo后,替换自己的id包名后出现 register failed: -11 accs bindapp error! 经过一番搜索之后,都是说这二种原因 1、检查appkey和secret key是否配置正确&#xff0c;如果正确无误&#xff0c;请看步骤2。2、so文件配置有误&#xff0c;需重新配置&#xff1a; Pus…

阿里无线11.11 | 手机淘宝移动端接入网关基础架构演进之路

移动网络优化是超级App永恒的话题&#xff0c;对于无线电商来说更为重要&#xff0c;网络请求体验跟用户的购买行为息息相关&#xff0c;手机淘宝从过去的HTTP API网关&#xff0c;到2014年升级支持SPDY&#xff0c;2015年双十一自研高性能、全双工、安全的ACCS&#xff08;阿里…

VS2015 realease模式下调试

一、将项目属性设置为Release&#xff0c;生成--->配置管理器&#xff1a; 二、按AltF7&#xff0c;弹出属性页进行设置&#xff1a;

AndroidStudio如何打包生成realease版本的arr包,并上传到Nexus搭建的maven仓库,供项目远程依赖(二)

AndroidStudio如何打包生成realease版本的arr包&#xff0c;并上传到Nexus搭建的maven仓库&#xff0c;供项目远程依赖&#xff08;二&#xff09; AndroidStudio如何打包生成realease版本的arr包&#xff0c;并上传到Nexus搭建的maven仓库&#xff0c;供项目远程依赖&#xff…

AndroidStudio如何打包生成realease版本的arr包,并上传到Nexus搭建的maven仓库,供项目远程依赖(一)

AndroidStudio如何打包生成realease版本的arr包&#xff0c;并上传到Nexus搭建的maven仓库&#xff0c;供项目远程依赖&#xff08;一&#xff09; 背景: 公司之前在eclipse上做开发&#xff0c;写了很多library库供项目依赖使用&#xff0c;现在转AS上了&#xff0c;并用Nexu…

QT debug 功能正常 realease和windeplayqt工具打包部分功能无法使用或者不正常

目录 说明开发环境错误说明结论 说明 在项目的开发中&#xff0c;一般程序员都是使用debug版本进行程序的编写和调试&#xff0c;习惯好一些的程序员可能会天天用realease跑一遍自己写的程序是否正常&#xff0c;但是很多程序员可能都不会这么做&#xff0c;直到程序功能完成时…

Python OpenCV10:OpenCV 视频基本操作

1. 读视频 1.1 获取视频对象 要在 OpenCV 中获取视频&#xff0c;需要创建一个 VideoCapture 对象并指定要读取的视频文件。 cv.VideoCapture(filepath) 参数&#xff1a; filepath 视频文件路径 返回值&#xff1a; cap 读取视频的对象 1.2 获取视频属性 cap.get(propId) 获…

Renderers

渲染器 (Renderers) 在将 TemplateResponse 实例返回给客户端之前&#xff0c;必须渲染它。渲染过程采用模板和上下文的中间表示&#xff0c;并将其转换为可以提供给客户端的最终字节流。—— Django 文档 REST framework 包含许多内置的渲染器 (Renderer) 类&#xff0c;允许…

python调用opencv实现视频读写

文章目录 一、从文件中读取视频并播放1.1 基本API讲解1.2 python实现 二、保存视频2.1 基本API讲解2.1 python实现范例 一、从文件中读取视频并播放 1.1 基本API讲解 在OpenCV中我们要获取一个视频&#xff0c;需要创建一个VideoCapture对象&#xff0c;指定你要读取的视频文…

记一次有趣的debug,VS编译器上Debug和Realease的差异

之前自己写过一个imageread的函数&#xff0c;用了好久一直没问题。最近两天&#xff0c;同事让我realease一个项目给他&#xff0c;其中就包含了我自己写的imageread函数。 我的函数就长这样&#xff0c;不包含公司的code&#xff0c;不算泄密哈。 在realse之前&#xff0c;我…

C++语言基础篇

✅作者简介&#xff1a;CSDN内容合伙人&#xff0c;全栈领域新星创作者&#xff0c;阿里云专家博主&#xff0c;华为云云享专家博主&#xff0c;掘金后端评审团成员 &#x1f495;前言&#xff1a; 学长出的这一系列专栏适合有⼀点 C 基础&#xff0c…

PCL12.1 Realease 附加依赖项

PCL12.1 Realease 附加依赖项 libboost_atomic-vc142-mt-g-x64-1_78.lib libboost_bzip2-vc142-mt-g-x64-1_78.lib libboost_chrono-vc142-mt-g-x64-1_78.lib libboost_container-vc142-mt-g-x64-1_78.lib libboost_context-vc142-mt-g-x64-1_78.lib libboost_contract-vc142-…

Vue强制刷新页面重新加载数据方法

业务场景 在管理后台执行完增删改查的操作之后&#xff0c;需要重新加载页面刷新数据以便页面数据的更新 实现原理 就是通过控制router-view 的显示与隐藏&#xff0c;来重渲染路由区域&#xff0c;重而达到页面刷新的效果&#xff0c;show -> flase -> show 具体代码…

Linux 重新加载 nginx 配置命令

1. 查找 nginx 位置 whereis nginx2. 进入 nginx 目录 cd /usr/local/nginx/sbin3. 检查 nginx 配置文件是否正确 ./nginx -t 4. 重新加载配置文件 ./nginx -s reload

IDEA 无法重新加载Maven项目

IDEA 无法重新加载Maven项目 如图&#xff1a; 真头疼&#xff0c;搞了半小时才搞明白&#xff0c;我的Maven版本是3.8.6&#xff0c;而idea版本是2020&#xff0c;用不了这么新版的maven。 解决方案 maven版本高于idea版本&#xff0c;去查找低于idea版本日期的maven或者直…

Unity架构之详解域重新加载和场景重新加载

一、unity进入运行模式包括以下主要阶段 备份当前场景&#xff1a;这仅在场景被修改后发生。这样当退出运行模式时&#xff0c;Unity 将场景恢复为运行模式开始前的状态。Domain Reload&#xff1a;通过重新加载脚本域来重置脚本状态。Scene Reload&#xff1a;通过重新加载场…

Pycharm如何重新加载

在日常工作中&#xff0c;我们可能会经常遇到一种情况&#xff0c;那就是当我们程序执行结束后&#xff0c;相应的项目文件并没有自动显现出来&#xff0c;这时为避免我们关闭Pycharm再重新打开的麻烦&#xff0c;我们可以使用Pycharm中的同步或者快捷键进行重新加载。 1.同步 …