MATLAB实现SVM多分类(one-vs-rest),利用自带函数fitcsvm

article/2025/9/8 1:21:39

MATLAB实现SVM多分类(one-vs-rest),利用自带函数fitcsvm

  • SVM多分类
    • 一对一(one-vs-one)
    • 一对多(one-vs-rest)
    • fitcsvm简单介绍
    • 代码
    • 实验结果图
      • 第一次写博客,还请大家多多包涵,欢迎指教!

SVM多分类

SVM也叫支持向量机,其是一个二类分类器,但是对于多分类,SVM也可以实现。主要方法就是训练多个二类分类器。常见的有以下两种方式:

一对一(one-vs-one)

给定m个类,对m个类中的每两个类都训练一个分类器,总共的二类分类器个数为 m(m-1)/2 .比如有三个类,1,2,3,那么需要有三个分类器,分别是针对:1和2类,1和3类,2和3类。对于一个需要分类的数据x,它需要经过所有分类器的预测,最后使用投票的方式来决定x最终的类属性。

一对多(one-vs-rest)

给定m个类,需要训练m个二类分类器。其中的分类器 i 是将 i 类数据设置为类1(正类),其它所有m-1个i类以外的类共同设置为类2(负类),这样,针对每一个类都需要训练一个二类分类器,最后,我们一共有 m 个分类器。对于一个需要分类的数据 x,通常选择置信度最大的类别标记为分类结果。

fitcsvm简单介绍

在新版本中svmtrain和svmclassify函数提示已经被移除,所以我们应该跟上潮流学习使用fitcsvm。

// An highlighted block
SVMModel =  fitcsvm(X,Y,'ClassNames',{'negClass','posClass'},'Standardize',true,...'KernelFunction','rbf','BoxConstraint',1);

简单说一下参数:
X是训练样本,nxm的矩阵,n是样本数,m是特征维数;
Y是样本标签,nx1的矩阵,n是样本数;
‘ClassNames’,{‘negClass’,‘posClass’} 为键值对参数,指定正负类别,负类名在前,正类名在后,与样本标签Y中的元素对应;
‘Standardize’,true 为键值对参数,指示软件是否应在训练分类器之前使预测期标准化!
‘KernelFunction’,‘rbf’ 为键值对参数,有3种 ‘linear’(默认), ‘gaussian’ (or ‘rbf’), ‘polynomial’
‘BoxConstraint’,1 为键值对参数,直观上可以理解为一个惩罚因子(或者说正则参数),这个参数和svmtrain里的-c是一个道理。其实际上涉及到软间隔SVM的间隔(Margin)大小。
基本思想如下:当原始数据未能呈现出较好的可分性时,算法允许其在训练集上呈现出一些误分类,matlab默认的BoxConstraint为1。框约束的数值越大,意味着惩罚力度越小,最后得到的分类超平面的间隔越小,支持向量数越多,模型越复杂。这也就是很多机器学习理论书中一开始推导的硬间隔支持向量机(Hard-Margin SVM)。因为该参数默认为1,所以使用默认参数训练时,我们采用的是软间隔SVM。
更详细的大家可以参考官方说明文档 [https://ww2.mathworks.cn/help/stats/fitcsvm.html].

代码

说一下思路:
1.我自己造的数据不用太关心,训练数据是60x2,60是样本数,2是特征数;测试数据是20x2的。
2.目标是分5类,一对多的方式,就要分别训练5个SVM模型;每个模型都是一个二分类,所以需要正、负样本的划分。我是这么做的正样本全部来自该类别,负样本从其它4个类别中随机选择,但数目与正样本相同。有了每一类的正、负样本,这就得到了训练样本X;再设定标签,我设的是+1,-1,这就得到了样本标签Y;其它参数均默认不设,这样就可以为每一类样本训练SVM模型了。
3.测试样本并不需要对每一类划分正、负样本,只要知道测试数据和样本标签即可。
4.每个测试样本在5个SVM模型中均得到一个得分score,利用最大得分判定该样本最终属于哪一类。
5.这个混淆矩阵函数confusionmat是真的好用,只需要知道真实标签和预测标签就能算出查准率(precision)、查全率(recall)和综合评价指标(F-measure)。
如图:

哈哈哈
类别1的查准率 = a / ( a + d + g ) =a/(a+d+g) =a/(a+d+g)
类别1的查全率 = a / ( a + b + c ) =a/(a+b+c) =a/(a+b+c)
类别2的查准率 = e / ( b + e + h ) =e/(b+e+h) =e/(b+e+h)
类别2的查全率 = e / ( d + e + f ) =e/(d+e+f) =e/(d+e+f)
···

// An highlighted block
clc;
clear;
close all;
tic
fprintf('-----已开始请等待-----\n\n');
%% 造数据不用关心,直接跳过
% 造数据 20*2
data = [0.4,0.3;-0.5,0.1;-0.2,-0.3;0.5,-0.3;2.1,1.9;1.8,2.2;1.7,2.5;2.3,1.6;-2.2,1.6;-1.9,2.1;-1.7,2.6;-2.3,2.5;-3.1,-1.9;-2.8,-2.1;-1.9,-2.5;-2.3,-3.2;3.9,-3.5;2.8,-2.2;1.7,-3.1;2.5,-3.4];
data1 = data + 2.5*rand(20,2);
data2 = data + 2.5*rand(20,2);
data3 = data + 2.5*rand(20,2); data1(17:20,:);
% 训练数据
train_data = [data1(1:4,:);data2(1:4,:);data3(1:4,:);data1(5:8,:);data2(5:8,:);data3(5:8,:);data1(9:12,:);data2(9:12,:);data3(9:12,:);data1(13:16,:);data2(13:16,:);data3(13:16,:);data1(17:20,:);data2(17:20,:);data3(17:20,:)];% 画图显示
figure;
% gscatter函数可以按分类或者分组画离散点
% group为分组向量,对应每一个坐标的类别
group_train = [1;1;1;1;1;1;1;1;1;1;1;1;2;2;2;2;2;2;2;2;2;2;2;2;3;3;3;3;3;3;3;3;3;3;3;3;4;4;4;4;4;4;4;4;4;4;4;4;5;5;5;5;5;5;5;5;5;5;5;5];
gscatter(train_data(:,1),train_data(:,2),group_train);title('训练数据样本分布');
xlabel('样本特征1');
ylabel('样本特征2');
legend('Location','Northwest');
grid on;%%
% 测试数据
test_data = data + 3.0*rand(20,2);
test_features = test_data;
% 测试数据的真实标签
test_labels = [1;1;1;1;2;2;2;2;3;3;3;3;4;4;4;4;5;5;5;5];%%
% 训练数据分为5% 类别i的 正样本 选择类别i的全部,负样本 从其余类别中随机选择(个数与正样本相同)
% 类别1
class1_p = train_data(1:12,:);
% randperm(n,k)是从1到n的序号中随机返回k个
index1 = randperm(48,12);
% 从其余样本中随机选择k个
train_data_c = train_data;
train_data_c(1:12,:) = [];
class1_n = train_data_c(index1,:);train_features1 = [class1_p;class1_n];
% 正类表示为1,负类表示为-1
train_labels1 = [ones(12,1);-1*ones(12,1)];% 类别2
class2_p = train_data(13:24,:);
% randperm(n,k)是从1到n的序号中随机返回k个
index1 = randperm(48,12);
% 从其余样本中随机选择k个
train_data_c = train_data;
train_data_c(13:24,:) = [];
class2_n = train_data_c(index1,:);train_features2 = [class2_p;class2_n];
% 正类表示为1,负类表示为-1
train_labels2 = [ones(12,1);-1*ones(12,1)];% 类别3
class3_p = train_data(25:36,:);
% randperm(n,k)是从1到n的序号中随机返回k个
index1 = randperm(48,12);
% 从其余样本中随机选择k个
train_data_c = train_data;
train_data_c(25:36,:) = [];
class3_n = train_data_c(index1,:);train_features3 = [class3_p;class3_n];
% 正类表示为1,负类表示为-1
train_labels3 = [ones(12,1);-1*ones(12,1)];% 类别4
class4_p = train_data(37:48,:);
% randperm(n,k)是从1到n的序号中随机返回k个
index1 = randperm(48,12);
% 从其余样本中随机选择k个
train_data_c = train_data;
train_data_c(37:48,:) = [];
class4_n = train_data_c(index1,:);train_features4 = [class4_p;class4_n];
% 正类表示为1,负类表示为-1
train_labels4 = [ones(12,1);-1*ones(12,1)];% 类别5
class5_p = train_data(49:60,:);
% randperm(n,k)是从1到n的序号中随机返回k个
index1 = randperm(48,12);
% 从其余样本中随机选择k个
train_data_c = train_data;
train_data_c(49:60,:) = [];
class5_n = train_data_c(index1,:);train_features5 = [class5_p;class5_n];
% 正类表示为1,负类表示为-1
train_labels5 = [ones(12,1);-1*ones(12,1)];%%
% 分别训练5个类别的SVM模型
model1 = fitcsvm(train_features1,train_labels1,'ClassNames',{'-1','1'});
model2 = fitcsvm(train_features2,train_labels2,'ClassNames',{'-1','1'});
model3 = fitcsvm(train_features3,train_labels3,'ClassNames',{'-1','1'});
model4 = fitcsvm(train_features4,train_labels4,'ClassNames',{'-1','1'});
model5 = fitcsvm(train_features5,train_labels5,'ClassNames',{'-1','1'});
fprintf('-----模型训练完毕-----\n\n');
%%
% label是n*1的矩阵,每一行是对应测试样本的预测标签;
% score是n*2的矩阵,第一列为预测为“负”的得分,第二列为预测为“正”的得分。
% 用训练好的5SVM模型分别对测试样本进行预测分类,得到5个预测标签
[label1,score1] = predict(model1,test_features);
[label2,score2] = predict(model2,test_features);
[label3,score3] = predict(model3,test_features);
[label4,score4] = predict(model4,test_features);
[label5,score5] = predict(model5,test_features);
% 求出测试样本在5个模型中预测为“正”得分的最大值,作为该测试样本的最终预测标签
score = [score1(:,2),score2(:,2),score3(:,2),score4(:,2),score5(:,2)];
% 最终预测标签为k*1矩阵,k为预测样本的个数
final_labels = zeros(20,1);
for i = 1:size(final_labels,1)% 返回每一行的最大值和其位置[m,p] = max(score(i,:));% 位置即为标签final_labels(i,:) = p;
end
fprintf('-----样本预测完毕-----\n\n');
% 分类评价指标group = test_labels; % 真实标签
grouphat = final_labels; % 预测标签
[C,order] = confusionmat(group,grouphat,'Order',[1;2;3;4;5]); % 'Order'指定类别的顺序
c1_p = C(1,1) / sum(C(:,1));
c1_r = C(1,1) / sum(C(1,:));
c1_F = 2*c1_p*c1_r / (c1_p + c1_r);
fprintf('c1类的查准率为%f,查全率为%f,F测度为%f\n\n',c1_p,c1_r,c1_F);c2_p = C(2,2) / sum(C(:,2));
c2_r = C(2,2) / sum(C(2,:));
c2_F = 2*c2_p*c2_r / (c2_p + c2_r);
fprintf('c2类的查准率为%f,查全率为%f,F测度为%f\n\n',c2_p,c2_r,c2_F);c3_p = C(3,3) / sum(C(:,3));
c3_r = C(3,3) / sum(C(3,:));
c3_F = 2*c3_p*c3_r / (c3_p + c3_r);
fprintf('c3类的查准率为%f,查全率为%f,F测度为%f\n\n',c3_p,c3_r,c3_F);c4_p = C(4,4) / sum(C(:,4));
c4_r = C(4,4) / sum(C(4,:));
c4_F = 2*c4_p*c4_r / (c4_p + c4_r);
fprintf('c4类的查准率为%f,查全率为%f,F测度为%f\n\n',c4_p,c4_r,c4_F);c5_p = C(5,5) / sum(C(:,5));
c5_r = C(5,5) / sum(C(5,:));
c5_F = 2*c5_p*c5_r / (c5_p + c5_r);
fprintf('c5类的查准率为%f,查全率为%f,F测度为%f\n\n',c5_p,c5_r,c5_F);  figure;
subplot(121);
% gscatter函数可以按分类或者分组画离散点
% group为分组向量,对应每一个坐标的类别
group_test = test_labels;
gscatter(test_data(:,1),test_data(:,2),group_test);title('测试数据样本真实分布');
xlabel('样本特征1');
ylabel('样本特征2');
legend('Location','Northwest');
grid on;subplot(122);
% gscatter函数可以按分类或者分组画离散点
% group为分组向量,对应每一个坐标的类别
group_test = final_labels;
gscatter(test_data(:,1),test_data(:,2),group_test);title('测试数据样本预测分布');
xlabel('样本特征1');
ylabel('样本特征2');
legend('Location','Northwest');
grid on;

实验结果图

在这里插入图片描述
在这里插入图片描述

-----已开始请等待----------模型训练完毕----------样本预测完毕-----c1类的查准率为0.375000,查全率为0.750000,F测度为0.500000c2类的查准率为0.800000,查全率为1.000000,F测度为0.888889c3类的查准率为1.000000,查全率为0.750000,F测度为0.857143c4类的查准率为1.000000,查全率为0.250000,F测度为0.400000c5类的查准率为1.000000,查全率为0.750000,F测度为0.857143

第一次写博客,还请大家多多包涵,欢迎指教!

参考资料:
[https://www.cnblogs.com/litthorse/p/9303711.html].
[https://blog.csdn.net/qq_39328617/article/details/95207473].
[https://baijiahao.baidu.com/s?id=1619821729031070174&wfr=spider&for=pc].


http://chatgpt.dhexx.cn/article/G0lGikb7.shtml

相关文章

SVM多分类原理学习

https://scikit-learn.org/stable/modules/svm.html https://sklearn.apachecn.org/docs/master/5.html 中文翻译 SVC,NuSVC,LinearSVC在一个数据集上可以实现二分类,也能多类分类 SCV和NuSVC是相似的方法,但是接受的参数设置可…

SVM多类分类

从 SVM的那几张图可以看出来,SVM是一种典型的两类分类器,即它只回答属于正类还是负类的问题。而现实中要解决的问题,往往是多类的问题(少部分例外,例如垃圾邮件过滤,就只需要确定“是”还是“不是”垃圾邮件…

傻瓜攻略(十九)——MATLAB实现SVM多分类

对于组合二元支持向量机模型的多类学习,使用纠错输出码(ECOC,error-correcting output codes )。有关详细信息,请参阅fitcecoc。 ECOC 可以用来将 Multiclass Learning 问题转化为 Binary Classification 问题。 以下…

【机器学习】SVM多分类问题及基于sklearn的Python代码实现

SVM多分类问题及Python代码实现 1. 什么是SVM?2. SVM的分类3. SVM决策函数类型4. SVM多分类的Python代码实现参考资料1. 什么是SVM? 对于这个点已经介绍的非常多了,不管是西瓜书还是各种博客,就是需要找到一个超平面,用这个超平面把数据划分成两个类别,最开始的SVM就是在…

SVM多分类问题

SVM本身是一个二值分类器,SVM算法最初是为二值分类问题设计的,当处理多类问题时,就需要构造合适的多类分类器。 1、直接法 :直接在目标函数上进行修改,将多个分类面的参数求解合并到一个最优化问题中,通…

《机器学习算法》SVM进行多分类及代码实现

最近做了一个工作就是对属性进行分类,然后用了不同的分类器,其中就用到了SVM,再次做一个总结。 1、什么是SVM? 对于这个点已经介绍的非常多了,不管是西瓜书还是各种博客,就是我们需要找到一个超平面&…

android uevent机制,安卓linux uevent内核上报机制实例

uevent可以实现内核通知上层的一种机制,最常见的电池状态的变化就是kernel uevent通知的,每次百分比或者其他的变化通过power_supply_changed通知上层update; 每个device下面都有kobj,找到device就可以通过kobject_uevent_env 通知android了; 以拔出T卡为例,内核通知上层。…

Linux设备模型剖析系列之二(uevent、sysfs)

CSDN链接: Linux设备模型剖析系列一(基本概念、kobject、kset、kobj_type) Linux设备模型剖析系列之二(uevent、sysfs) Linux设备模型剖析系列之三(device和device driver) Linux设备模型剖析系…

Linux下的uevent

查找linux的uevent节点(find /sys -name uevent),大概有1000多个,那这些节点是怎么实现的呢。 drivers/base/core.c 有如下代码,每创建一个device,都会创建一个event节点 static ssize_t uevent_show(struct device *dev, struc…

Android UEvent事件分析

1.背景概述 众所周知,在安卓系统中有状态栏,在插入外设的时候,会在顶部状态栏显示小图标。 比如,camera设备,耳机设备,U盘,以及电池等等。这些都需要在状态栏动态显示。 从上面这张图片可以看出这些设备都有自己的服务一直在跑,并且都是继承了UEventObserver.java这个…

嵌入式Linux——uevent机制:uevent原理分析

简介: 本文主要介绍uevent机制是什么,并通过代码分析使用uevent机制生成设备节点的过程。而本文将分为两部分,第一部分我们介绍一些预备知识和uevent的原理,而第二部分——通过代码介绍使用uevent机制创建设备节点。 Linux内核&am…

uevent机制:uevent原理分析

简介: 本文主要介绍uevent机制是什么,并通过代码分析使用uevent机制生成设备节点的过程。而本文将分为两部分,第一部分我们介绍一些预备知识和uevent的原理,而第二部分——通过代码介绍使用uevent机制创建设备节点。 声明&#…

900 多道 LeetCode 题解,这个 GitHub 项目值得 Star!

公众号关注 “GitHubPorn” 设为 “星标”,每天带你逛 GitHub! 大家好,我是小 G。 周末风和日丽,适合刷 LeetCode 今天给你们推荐个 GitHub 项目,里面收集了 900 多道 LeetCode 题解,并包含中英文两个版本&…

Leetcode中你的代码执行之后显示超出时间限制

Leetcode中但凡是你的代码执行之后显示超出时间限制 比如: 那么必定是你写的代码不够完善甚至是还存在错误!

升职加薪,必不可少!Python刷题打怪,你要的LeetCode答案都在这里了!

对于还不了解LeetCode的同学,那比较厉害了,估计离大厂还有一步距离! LeetCode,让程序员进阶的在线平台,找工作备战名企技术面试!(文末阅读原文到达学习平台) 本公众号之前陪伴了几期LeetCode的打卡之旅&…

LeetCode 96~100

前言 本文隶属于专栏《LeetCode 刷题汇总》,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢! 本专栏目录结构请见LeetCode 刷题汇总 正文 幕布 幕布链接 96. 不同的二叉搜索树 题解 官方…

【下载】快速通过Python笔试?学大家一样先把LeetCode答案私藏了

如今学习python的同学越来越多了,也正是同学们秋招时期,去年分享了LeetCode答案后,已经有上百位同学找到小编开始实践这个平台。 LeetCode,让程序员进阶的在线平台,找工作备战名企技术面试!(文末阅读原文到…

面试失败总结,这 577 道 LeetCode 题 Java 版答案你值得拥有

去字节、美团、BAT 等大厂面试,刷 LeetCode 上的数据结构算法题是必修课。许多读者说,刷题的时候经常会遇到困难,想要找一本答案题解做参考。 下面分享几个用 Java 语言实现的开源 LeetCode 题解,也要感谢这些优秀的开源作者们&a…

LeetCode答案大全题(java版)

思路:查找时, 建立索引(Hash查找) 或进行排序(二分查找)。本题缓存可在找的过程中建立索引,故一个循环可以求出解(总是使用未 使用元素查找使用元素,可以保证每一对都被检…

LeetCode数据库题目汇总一(附答案)

1、基础SQL 数据表: dept: deptno(primary key), dname, loc emp: empno(primary key), ename, job, mgr(references emp(empno)), sal, deptno(references dept(deptno)) 1 列出emp表中各部门的部门号,最高工资,最低工资 select max(sal) as 最高工资,min(sal) as 最…