SMOTE算法原理 易用手搓小白版 数据集扩充 python

article/2025/11/11 10:24:20

前言

为啥要写这个呢,在做课题的时候想着扩充一下数据集,尝试过这个过采样降采样,交叉采样,我还研究了一周的对抗生成网络,对抗生成网络暂时还解决不了我要生成的信号模式崩塌的问题,然后就看着尝试一下别的,就又来实验了一下SMOTE,我看原理也不是很难,想着调库的话不如自己手搓一个稍微,可以简单理解一点的,最后呢也是成功了,然后呢对训练集进行了扩充,效果额,训练集准确率肯定是嗷嗷提升,训练的效果稳定了一点,但是测试集出来的效果,感觉变化不大,可能是我实验样本比较少的原因,说明普通的SMOTE还是比较吃原始数据分布,我写的这个是只用numpy 和 random 两个库,内容都是手搓的,和官方例程最大的不同,就是官方例程控制的是生成样本和原样本的比例,本程序控制的是生成样本的数量。也就是可以直接指定生成样本的数量进行输出。


一、SMOTE理论

SMOTE算法是一种2002年发表的根据样本之间的关系,生成新样本的,扩充数据集的算法,论文源地址贴在下面,然后用一个图表示一下一个样本的生成过程

SMOTE: Synthetic Minority Over-sampling Technique:
论文地址: https://www.jair.org/index.php/jair/article/download/10302/24590

在这里插入图片描述
虽然别人的图画的很好,但是想到自己作为一个研究生😭,还是少复制粘贴,代码都手搓了图也忍痛不复制自己画一下,好了,进入正题
描述一下这个图,可以看到图中分布着两种样本点,因为五边形表示的这一类的样本点为少数类样本,所以个图里选择五边形这一类样本进行扩充,随机认定一个五边形样本点为中心,搜索离它距离最近的K个同类样本点(也就是五边形样本点),随机选择一个被搜索到的样本点,用最开始认定的作为搜索中心的样本点和后来被随机选中的样本点生成一个新的样本。
那通过两个样本点是如何生成一个新的样本点呢这里用到的就是一个重要的线性代数的知识

对于 x 1 , x 2 \quad x_1,x_2\quad x1,x2如果 λ ∈ [ 0 , 1 ] \lambda\in[0,1]\quad λ[0,1] λ x 1 + ( 1 − λ ) x 2 \lambda x_1+(1-\lambda)x_2\quad λx1+(1λ)x2一定在 x 1 和 x 2 x_1和x_2 x1x2的连线上

其中 λ x 1 + ( 1 − λ ) x 2 \lambda x_1+(1-\lambda)x_2 λx1+(1λ)x2也可以转换为 x 2 + λ ( x 1 − x 2 ) x_2+\lambda(x_1-x_2) x2+λ(x1x2)或者 x 1 + λ ( x 2 − x 1 ) x_1+\lambda(x_2-x_1) x1+λ(x2x1)下图中 x 3 x_3 x3 x 1 x_1 x1 x 2 x_2 x2连接线上的一点,用初中的移项等知识就一定可以求到一个 λ \lambda λ,好了初中知识就不赘述了
请添加图片描述
因此可以通过随机生成一个0~1之间的数结合两个样本点就能合成一个新的数据

二.python代码

实际应用中定义一个class 类来实现功能在实例中定义了三个子函数
class SMOTE(object):
初始化函数
def __init__(self,sample,k=2,gen_num=3):
获取相邻点的函数
def get_neighbor_point(self):
获取合成的样本的函数
def get_syn_data(self):
后面依次介绍,首先调用一下需要用到的基础库

import numpy as np     
import random   #用于生成随机数
import matplotlib.pyplot as plt  #画图

2.1初始化部分

初始化部分需要输入三个参数
1.被扩充的样本
2.Smote算法需要设置的K值
3.生成样本的数量

    def __init__(self,sample,k=2,gen_num=3):#需要被扩充的样本self.sample = sample      #获取输入数据的形状self.sample_num,self.feature_len = self.sample.shape#近邻点  self.k = min(k,self.sample_num-1)#需要生成的样本的数量                self.gen_num = gen_num    # 定义一个数组存储生成的样本self.syn_data = np.zeros((self.gen_num,self.feature_len))  # 定义一个数组存储每一个点和其临近点的坐标self.k_neighbor = np.zeros((self.sample_num,self.k),dtype=int)  

先不用思考接下来我对每一句话进行解释

首先是获取数据样本的长和宽

#需要被扩充的样本
self.sample = sample      
#获取输入数据的形状
self.sample_num,self.feature_len = self.sample.shape  

举个例子如果输入的的样本的形状是10✖2的
也就意味着输入了10个样本
每一个样本有2个特征也就是一个样本由2个数构成
对应到代码中样本数量数据被存储到了self.sample_num=10
样本长度数据被存储到了self.feature_len=2
为什么要获取这两个数据呢先从这一句开始解释

self.k = min(k,self.sample_num-1)

如果输入的需要被扩充的数据有10个样本,也就是说每一个样本最多有10-1也就是9个相邻的点(样本),也就是相对输入数据中的每一个样本点,他能搜索到的邻近样本数量是有上限的,因此避免输入K值过大,超过能搜索的最大值,就需要结合输入样本的数量(self.sample_num)进行约束

接下来看最后三句,根据输入的需要生成的样本的数量(self.gen_num),和我们已经知道的每一个样本的长度(self.feature_len),就能生成一个self.syn_data形状是(self.gen_num×self.feature_len)的全0数组存储生成的数据

#需要生成的样本的数量                
self.gen_num = gen_num    
# 定义一个数组存储生成的样本
self.syn_data = np.zeros((self.gen_num,self.feature_len))  
# 定义一个数组存储每一个点的坐标和其临近点的坐标
self.k_neighbor = np.zeros((self.sample_num,self.k),dtype=int)  

最后一句,如果我们K值设置的是3,也就是寻找最邻近的三个点,若一共有10个数据那就是生成的是一个10×3的全零数组存储的是每一个点的与它最近的三个点的数据所在位置的索引值
例如一个数据为x = [1,4,3,2]
其对应索引值为[0,1,2,3] (x[0] = 1,x[1] = 4, x[2] = 3,x[3] = 2)
k值为2
则计算之后的数组(self.k_neighbor)为
[[3,2],
[2,3],
[1,3],
[0,2]]
标黄意味着 除了x[0] 的三个数中 x[3],x[2]离x[0]最近,x[3]更近一些
(越靠前的越近,同样近的索引值小的靠前)
同理
[[3,2],
[2,3],
[1,3],
[0,2]]
第二行意味着除了x[1] 的三个数中 x[2],x[3]离x[1]最近

2.2计算距离部分

再介绍一下函数有基础可以跳过

2.2.1 enumerate()

enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中

seasons = ['Spring', 'Summer', 'Fall', 'Winter']
print(list(enumerate(seasons)))
#[(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]

链接: 菜鸟教程enumerate

2.2.2 numpy.argsort()

numpy.argsort() 函数返回的是数组值从小到大的索引值。

import numpy as np 
x = np.array([3,  1,  2])  
print ('我们的数组是:')
print (x)
print ('\n')
print ('对 x 调用 argsort() 函数:')
y = np.argsort(x)  
print (y)
print ('\n')
print ('以排序后的顺序重构原数组:')
print (x[y])
print ('\n')
print ('使用循环重构原数组:')
for i in y:  print (x[i], end=" ")
'''
我们的数组是:
[3 1 2]
对 x 调用 argsort() 函数:
[1 2 0]
以排序后的顺序重构原数组:
[1 2 3]
使用循环重构原数组
1 2 3
'''

链接: 菜鸟教程argsort

2.2.3 numpy.square()

算数组中每一个数的平方

print('sqrt计算各个元素的平方根:')
num = np.array([1,2,3])
print(num)
print(np.square(num))
'''
sqrt计算各个元素的平方根:
[1,2,3]
[1,4,9]
'''

2.2.4 列表生成式(推导式)

Python 推导式是一种独特的数据处理方式,可以从一个数据序列构建另一个新的数据序列的结构体。

'''
[表达式 for 变量 in 列表] 
[out_exp_res for out_exp in input_list]或者 [表达式 for 变量 in 列表 if 条件]
[out_exp_res for out_exp in input_list if condition]
'''
multiples = [i for i in range(30) if i % 3 == 0]
print(multiples)
[0, 3, 6, 9, 12, 15, 18, 21, 24, 27]

2.2.5 距离样本代码

好了铺垫完这回再看代码,应该不至于劝退了

    def get_neighbor_point(self):for index,single_signal in enumerate(self.sample):# 获取欧式距离Euclidean_distance = np.array([np.sum(np.square(single_signal-i)) for i in self.sample])# 获取欧式距离从小到大的索引排序序列Euclidean_distance_index = Euclidean_distance.argsort()# 截取k个距离最近的样本的索引值self.k_neighbor[index] = Euclidean_distance_index[1:self.k+1]

Euclidean_distance返回的是一个距离数组,计算距离使用欧式距离,也就是对应点的平方求和
Euclidean_distance_index返回的是从小到大的样本距离排序的索引,每个Euclidean_distance_index第一个索引值一定是本次循环的对比信号本身,因为距离是0,所以从列表的第二个数据开始截取K个索引存到最开始定义的self.k_neighbor变量的对应位置中

self.k_neighbor[index] = Euclidean_distance_index[1:self.k+1]

好了终于把计算距离这一部分说完了

2.3 生成数据

铺垫环节

2.3.1 random.randint (a,b)

random.randint(参数1, 参数2)
参数1,参数2必须是整数
函数返回参数1和参数2之间的任意整数

import random
result = random.randint(1,10)
print("result: ",result)
#输出:
#result: 6

2.3.2 random.uniform (a,b)

random.uniform(参数1,参数2) 返回参数1和参数2之间的任意值

import random
result = random.uniform(1,3)
print("result: ",result)
#输出:
#result: 2.639781736005787

2.3.3 生成部分代码

生成代码部分循环self.gen_num次每次的内部步骤都是,选择一个中心样本,然后选择一个他的临近样本,生成合成样本

def get_syn_data(self):self.get_neighbor_point()#生成self.gen_num个样本循环N次for i in range(self.gen_num):#随机选择的中心样本点的索引key = random.randint(0,self.sample_num-1)#随机选择的中心样本点的邻近样本点中的随机一个K_neighbor_point = self.k_neighbor[key][random.randint(0,self.k-1)]#gap = x1-x2 = self.sample[K_neighbor_point](随机选择的当前样本中前k近的样本点中的随机一个)- self.sample[key](随机选择的用于生成数据的中心样本点) gap = self.sample[K_neighbor_point] - self.sample[key]#公式 生成 = 被选中作为中心的样本 - 0到1中的一个数 × (被选中作为中心的样本 - 被选中作为中心的样本的临近样本点中的随机一个)self.syn_data[i] = self.sample[key] + random.uniform(0,1)*gapreturn self.syn_data

三.完整代码如下

import numpy as np
import random
import matplotlib.pyplot as pltclass SMOTE(object):def __init__(self,sample,k=2,gen_num=3):self.sample = sample      self.sample_num,self.feature_len = self.sample.shapeself.k = min(k,self.sample_num-1)                self.gen_num = gen_num    self.syn_data = np.zeros((self.gen_num,self.feature_len))  self.k_neighbor = np.zeros((self.sample_num,self.k),dtype=int)  def get_neighbor_point(self):for index,single_signal in enumerate(self.sample):Euclidean_distance = np.array([np.sum(np.square(single_signal-i)) for i in self.sample])Euclidean_distance_index = Euclidean_distance.argsort()self.k_neighbor[index] = Euclidean_distance_index[1:self.k+1]def get_syn_data(self):self.get_neighbor_point()for i in range(self.gen_num):key = random.randint(0,self.sample_num-1)K_neighbor_point = self.k_neighbor[key][random.randint(0,self.k-1)]gap = self.sample[K_neighbor_point] - self.sample[key]self.syn_data[i] = self.sample[key] + random.uniform(0,1)*gapreturn self.syn_dataif __name__ == '__main__':#随机生成原始数据data=np.random.uniform(0,1,size=[20,2])#生成对象k=5 gen_num=20Syntheic_sample = SMOTE(data,5,20)#生成数据new_data = Syntheic_sample.get_syn_data()#绘制原始数据for i in data:plt.scatter(i[0],i[1],c='b')#绘制生成数据for i in new_data:plt.scatter(i[0],i[1],c='y')plt.show()

蓝色是原始样本橘色是生成样本
在这里插入图片描述
虽然点看着分散,你要是细心观察你会发现所有的橘色的点都在两个蓝色的点的联线上,为清晰这点其实有一个更直观的方法,直接把生成的点选择成好几百,如果k=5 gen_num = 100还是不够明显
在这里插入图片描述
再来个gen_num = 500 的这回生成的样本点在连线上已经很明显了
在这里插入图片描述
这回将k设置成3同样生成500个样本比起k=5的时候交叉线明显减少了
在这里插入图片描述
将k设置成1再来一次可以看到生成的样本已经没有交叉线了
在这里插入图片描述
最后再试一下生成整数原始数据,扩充之后将原始数据和生成数据打印出来

在这里插入图片描述

总结

这个代码目前只能生成一维的数据,高维的需要处理成一维的才能使用,然后之后会尝试写SMOTE的各种延伸版本
也非常感谢这位老哥的参考
链接: 原版论文复现.


http://chatgpt.dhexx.cn/article/YoaGS2oJ.shtml

相关文章

机器学习_SMOTE:简单原理图示_算法实现及R和Python调包简单实现

一、SMOTE原理 SMOTE的全称是Synthetic Minority Over-Sampling Technique 即“人工少数类过采样法”,非直接对少数类进行重采样,而是设计算法来人工合成一些新的少数样本。 SMOTE步骤__1.选一个正样本 红色圈覆盖 SMOTE步骤__2.找到该正样本的K个近…

Hash碰撞(冲突)

2019独角兽企业重金招聘Python工程师标准>>> 什么是哈希(哈希算法) 哈希算法是将任意长度的二进制值映射为较短的固定长度的二进制值,这个小的二进制值称为哈希值。 哈希值是一段数据唯一且极其紧凑的数值表示形式。如果散列一段明…

Hash 碰撞是什么?如何解决(开放寻址法和拉链法)?hash链表和红黑树知识扩展?

一、什么是Hash碰撞 hash碰撞指的是,两个不同的值(比如张三、李四的学号)经过hash计算后,得到的hash值相同,后来的李四要放到原来的张三的位置,但是数组的位置已经被张三占了,导致冲突 二、Ha…

hash碰撞解决方法

Hash碰撞冲突 我们知道,对象Hash的前提是实现equals()和hashCode()两个方法,那么HashCode()的作用就是保证对象返回唯一hash值,但当两个对象计算值一样时,这就发生了碰撞冲突。如下将介绍如何处理冲突,当然其前提是一…

Java 集合深入理解 (十一) :HashMap之实现原理及hash碰撞

文章目录 前言哈希表原理实现示例HashMap实现原理全篇注释分析实现注意事项默认属性分析属性分析构造方法分析重要的put方法总结 前言 哈希表(hashMap)又叫散列表 是一种非常重要的数据结构基于map接口实现应用场景及其丰富,本地临时缓存&a…

java基础篇 - HashMap 理解Hash碰撞

HashMap是大家都在用,面试的时候也经常会被考的考点,在这篇文章中说下HashMap的hash碰撞和减轻碰撞的优化。 1、什么是hash碰撞 在解释Hash碰撞之前先说一下hashmap的存储结构、添加和检索是怎么实现的 1.1HashMap的存储结构 HashMap的存储结构是En…

大白话解释hash碰撞是什么以及如何解决

一、Hash如何存数据 hash表的本质其实就是数组,hash表中通常存放的是键值对Entry。 这里的id是个key,哈希表就是根据key值来通过哈希函数计算得到一个值,这个值就是下标值,用来确定这个Entry要存放在哈希表中哪个位置。 二、Ha…

hash碰撞的概率推导(生日攻击生日问题)

1.关于hash碰撞 哈希碰撞是指,两个不同的输入得到了相同的输出; hash碰撞不可避免,hash算法是把一个无限输入的集合映射到一个有限的集合里,必然会发生碰撞; 2.碰撞概率的问题描述的其他形式 n个球,&…

Hash碰撞(冲突)的解决方案

hash算法就是,用同一个哈希函数计算: 两个相同的值,计算出的hash值一定相同, 两个不同的值,计算出的hash值可能不同,也可能相同,当相同时就是hash冲突 一、链式寻址法 也叫“拉链法”&#…

MD5 hash碰撞实现解密

目录 1.前言 2.MD5 hash单个碰撞解密 3.MD5 hash批量碰撞解密 1.前言 在日常渗透中,获取到后台密码往往是加密的,常见的就是MD5加密,常见的做法我们会使用在线网站去解密,常用的有cmd5,somd5,cmd5对于一些密文是要收费的,有时我们就想白嫖。 这时我们会用so…

哈希碰撞+mysql_HashMap之Hash碰撞冲突解决方案及未来改进

HashMap位置决定与存储 通过前面的源码分析可知,HashMap 采用一种所谓的“Hash 算法”来决定每个元素的存储位置。当程序执行put(String,Obect)方法 时,系统将调用String的 hashCode() 方法得到其 hashCode 值——每个 Java 对象都有 hashCode() 方法&am…

Hash碰撞概率

计算Hash冲突的概率 虽然已经很多可以选择的Hash函数,但创建一个好的Hash函数仍然是一个活跃的研究领域。一些Hash函数是快的,一些是慢的,一些Hash值均匀地分布在值域上,一些不是。对于我们的目的,让我们假设这个Hash函数是非常好的。它的Hash值均匀地分布在值域上。 在这…

HashMap之Hash碰撞

详细理解了Hash碰撞及处理方法 为什么会出现hash碰撞 在hash算法下,假设两个输入串的值不同,但是得到的hash值相同, 即会产生hash碰撞 一个很简单的例子: 假设你自己设计了一个计算hash的算法toHashValue(String). 是取的输入值的Unicode编码值(当然实际的情况会比这复杂很…

hashmap存储方式 hash碰撞及其解决方式

1.Map 的存储特点 在 Map 这个结构中,数据是以键值对(key-value)的形式进行存储的,每一个存储进 map 的数据都是一一对应的。 创建一个 Map 结构可以使用 new HashMap() 以及 new TreeMap() 两种方式,两者之间的区别…

Java Hash 碰撞

散列函数(英语:Hash function)又称散列算法、哈希函数,是一种从任何一种数据中创建小的数字“指纹”的方法。散列函数把消息或数据压缩成摘要,使得数据量变小,将数据的格式固定下来。 该函数将数据打乱混合…

通俗解释hash碰撞是什么以及如何解决

Hash如何存数据 hash表的本质其实就是数组,hash表中通常存放的是键值对Entry。 如下图: 这里的学号是个key,哈希表就是根据key值来通过哈希函数计算得到一个值,这个值就是下标值,用来确定这个Entry要存放在哈希表中哪个位置。 H…

Hash碰撞

Hash碰撞 什么是Hash碰撞 Hash碰撞是指两个不同的输入值,经过哈希函数的处理后,得到相同的输出值,这种情况被称之为哈希碰撞。 例如:两个不同的对象(object1和object2的值)经过Hash函数计算后的&#xf…

浅谈“越权访问”

一:漏洞名称: 越权访问漏洞 描述: 越权访问,这类漏洞是指应用在检查授权(Authorization)时存在纰漏,使得攻击者在获得低权限用户帐后后,可以利用一些方式绕过权限检查,访…

逻辑越权——垂直、水平越权

水平越权:通过更换的某个ID之类的身份标识,从而使A账号获取(修改、删除等)B账号数据。 垂直越权:使用低权限身份的账号,发送高权限账号才能有的请求,获得其高权限的操作。 未授权访问&#xff1…

横向越权和纵向越权(水平越权、垂直越权)

越权:顾名思义,就是获得了本不应该有的权限。 我们都喜欢创造一些复杂的词汇,而实际上这些词就是一个代词,根本没有那么复杂。 越权漏洞往往是基于业务逻辑的漏洞,这样的漏洞很难被WAF保护。 越权的分类 按照方向…