Svm实现多分类

article/2025/10/6 11:54:04

机器学习---Svm实现多分类详解

  • Svm实现多类分类原理
    • 代码实现
    • 训练的图片

Svm实现多类分类原理

1.支持向量机分类算法最初只用于解决二分类问题,缺乏处理多分类问题的能力。后来随着需求的变化,需要svm处理多分类分为。目前构造多分类支持向
量机分类器的方法主要有两类: 一类是“同时考虑所有分类”方法,另一类是组合二分类器解诀多分类问题。
第一类方法主要思想是在优化公式的同时考虑所有的类别数据,J.Weston 和C.Watkins 提出的“K-Class 多分类算法”就属于这一类方法。该算法在经典的SVM理论的基础上,重新构造多类分类型,同时考虑多个类别,然后将问题也转化为-个解决二次规划(Quadratic Programming,简称QP)问题,从而实现多分类。该算法由于涉及到的变量繁多,选取的目标函数复杂,实现起来比较困难,计算复杂度高。
第二类方法的基本思想是通过组合多个二分类器实现对多分类器的构造,常见的构造方法有“一对一”(one-against-one)和“一对其余”(one-against-the rest两种。 其中“一对一”方法需要对n类训练数据两两组合,构建n(n- 1)/2个支持向量机,每个支持向量机训练两种不同类别的数据,最后分类的时候采取“投票”的方式决定分类结果。“一对其余”方法对n分类问题构建n个支持向量机,每个支持向量机负责区分本类数据和非本类数据。该分类器为每个类构造一个支持向量机, 第k个支持向量机在第k类和其余n-1个类之间构造一个超平面,最后结果由输出离分界面距离wx+ b最大的那个支持向量机决定。

本文将上述“一对其余”的SVM多分类方法对莺尾花数据进行分类识别,并设法减少训练样本个数,提高训练速度。

代码实现

#-*- coding:utf-8 -*-
'''
@project: exuding-bert-all
@author: exuding
@time: 2019-04-23 09:59:52
'''
#svm 高斯核函数实现多分类
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasetssess = tf.Session()#加载数据集,并为每类分离目标值
iris = datasets.load_iris()
#提取数据的方法
x_vals = np.array([[x[0],x[3]] for x in iris.data])y_vals1 = np.array([1 if y==0 else -1 for y in iris.target])
y_vals2 = np.array([1 if y==1 else -1 for y in iris.target])
y_vals3 = np.array([1 if y==2 else -1 for y in iris.target])
#合并数据的方法
y_vals = np.array([y_vals1,y_vals2,y_vals3])
#数据集四个特征,只是用两个特征就可以
class1_x = [x[0] for i,x in enumerate(x_vals) if iris.target[i]==0]
class1_y = [x[1] for i,x in enumerate(x_vals) if iris.target[i]==0]
class2_x = [x[0] for i,x in enumerate(x_vals) if iris.target[i]==1]
class2_y = [x[1] for i,x in enumerate(x_vals) if iris.target[i]==1]
class3_x = [x[0] for i,x in enumerate(x_vals) if iris.target[i]==2]
class3_y = [x[1] for i,x in enumerate(x_vals) if iris.target[i]==2]
#从单类目标分类到三类目标分类,利用矩阵传播和reshape技术一次性计算所有的三类SVM,一次性计算,y_target的占位符维度是[3,None]
batch_size = 50
x_data = tf.placeholder(shape = [None,2],dtype=tf.float32)
y_target = tf.placeholder(shape=[3,None],dtype=tf.float32)
#TODO
prediction_grid = tf.placeholder(shape=[None,2],dtype=tf.float32)
b = tf.Variable(tf.random_normal(shape=[3,batch_size]))
#计算高斯核函数 TODO
gamma = tf.constant(-10.0)
dist = tf.reduce_sum(tf.square(x_data),1)
dist = tf.reshape(dist,[-1,1])
sq_dists = tf.add(tf.subtract(dist,tf.multiply(2.,tf.matmul(x_data,tf.transpose(x_data)))),tf.transpose(dist))
my_kernel = tf.exp(tf.multiply(gamma,tf.abs(sq_dists)))
#扩展矩阵维度
def reshape_matmul(mat):v1 = tf.expand_dims(mat,1)v2 = tf.reshape(v1,[3,batch_size,1])return (tf.matmul(v2,v1))
#计算对偶损失函数
model_output = tf.matmul(b,my_kernel)
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b),b)
y_target_cross = reshape_matmul(y_target)
second_term = tf.reduce_sum(tf.multiply(my_kernel,tf.multiply(b_vec_cross,y_target_cross)),[1,2])
loss = tf.reduce_sum(tf.negative(tf.subtract(first_term,second_term)))
#创建预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data),1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid),1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA,tf.multiply(2.,tf.matmul(x_data,tf.transpose(prediction_grid)))),tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma,tf.abs(pred_sq_dist)))
#创建预测函数,这里实现的是一对多的方法,所以预测值是分类器有最大返回值的类别
prediction_output = tf.matmul(tf.multiply(y_target,b),pred_kernel)
prediction = tf.arg_max(prediction_output-tf.expand_dims(tf.reduce_mean(prediction_output,1),1),0)
accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction,tf.arg_max(y_target,0)),tf.float32))
#准备好核函数,损失函数,预测函数以后,声明优化器函数和初始化变量
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)
init = tf.initialize_all_variables()
sess.run(init)
#该算法收敛的相当快,所以迭代训练次数不超过100次
loss_vec = []
batch_accuracy = []
for i in range(100):rand_index = np.random.choice(len(x_vals),size=batch_size)rand_x = x_vals[rand_index]rand_y = y_vals[:,rand_index]sess.run(train_step,feed_dict={x_data:rand_x,y_target:rand_y})temp_loss = sess.run(loss,feed_dict={x_data:rand_x,y_target:rand_y})loss_vec.append(temp_loss)acc_temp = sess.run(accuracy,feed_dict={x_data:rand_x,y_target:rand_y,prediction_grid:rand_x})batch_accuracy.append(acc_temp)if(i+1)%25 ==0:print('Step #' + str(i+1))print('Loss #' + str(temp_loss))x_min,x_max = x_vals[:,0].min()-1,x_vals[:,0].max()+1
y_min,y_max = x_vals[:,1].min()-1,x_vals[:,1].max()+1
xx,yy = np.meshgrid(np.arange(x_min,x_max,0.02),np.arange(y_min,y_max,0.02))
grid_points = np.c_[xx.ravel(),yy.ravel()]
grid_predictions = sess.run(prediction,feed_dict={x_data:rand_x,y_target:rand_y,prediction_grid:grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)
#绘制训练结果,批量准确度和损失函数
#等高线图
plt.contourf(xx,yy,grid_predictions,cmap=plt.cm.Paired,alpha=0.8)
plt.plot(class1_x,class1_y,'ro',label = 'I.setosa')
plt.plot(class2_x,class2_y,'kx',label = 'I.versicolor')
plt.plot(class3_x,class3_y,'gv',label = 'T.virginica')
plt.title('Gaussian Svm Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5,3.0])
plt.xlim([3.5,8.5])
plt.show()plt.plot(batch_accuracy,'k-',label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Sepal Width')
plt.legend(loc = 'lower right')
plt.show()plt.plot(loss_vec,'k--')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

训练的图片

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述


http://chatgpt.dhexx.cn/article/8bEqSFUO.shtml

相关文章

SVM学习(二):线性分类器

1.线性分类器概念 线性分类器(一定意义上,也可以叫做感知机) 是最简单也很有效的分类器形式.在一个线性分类器中,可以看到SVM形成的思路,并接触很多SVM的核心概念。用一个二维空间里仅有两类样本的分类问题来举个小例子。如图所示: C1和C2是要区分的两个类别&#x…

机器学习笔记之(5)——SVM分类器

本博客为SVM分类器的学习笔记~由于仅仅是自学的笔记,大部分内容来自参考书籍以及个人理解,还请广大读者多多赐教 主要参考资料如下: 《机器学习实战》《Python机器学习》《机器学习Python实践》《Python机器学习算法》《Python大战机器学习》…

机器学习之SVM分类器介绍——核函数、SVM分类器的使用

系类文章目录 机器学习算法——KD树算法介绍以及案例介绍 机器学习的一些常见算法介绍【线性回归,岭回归,套索回归,弹性网络】 文章目录 一、SVM支持向量机介绍 1.1、SVM介绍 1.2、几种核函数简介 a、sigmoid核函数 b、非线性SVM与核函…

SVM分类器(matlab)

源自:https://blog.csdn.net/lwwangfang/article/details/52351715 支持向量机(Support Vector Machine,SVM),可以完成对数据的分类,包括线性可分情况和线性不可分情况。1、线性可分 首先,对于SVM来说&…

线性分类器(SVM,softmax)

目录 导包和处理数据 数据预处理--减平均值和把偏置并入权重 SVM naive版 向量版 Softmax navie版 向量版 线性分类器--采用SGD算法 SVM版线性分类 Softmax版线性分类 使用验证集调试学习率和正则化系数 画出结果 测试准确率 可视化权重 值得注意的地方 赋值 ran…

SVM多分类的两种方式

以下内容参考:https://www.cnblogs.com/CheeseZH/p/5265959.html http://blog.csdn.net/rainylove1/article/details/32101113 王正海《基于决策树多分类支持向量机岩性波谱分类》 SVM本身是一个二值分类器,SVM算法最初是为二值分类问题设计的&#xff0…

使用SVM分类器进行图像多分类

ResNet backbone SVM分类器 对于样本较为均衡小型数据集,SVM作为分类器的效果与MLP的效果相近。 从经验上看,对于样本不均衡的大型数据集,MLP的效果强于SVM。 本博客在自己的小型数据集上进行实验,本来使用MLP已经达到很好的效果…

SVM分类器原理详解

第一层、了解SVM 支持向量机,因其英文名为support vector machine,故一般简称SVM,通俗来讲,它是一种二类分类模型,其基本模型定义为特征空间上的间隔最大的线性分类器,其学习策略便是间隔最大化&#xff0c…

【CV-Learning】线性分类器(SVM基础)

数据集介绍(本文所用) CIFAR10数据集 包含5w张训练样本、1w张测试样本,分为飞机、汽车、鸟、猫、鹿、狗、蛙、马、船、卡车十个类别,图像均为彩色图像,其大小为32*32。 图像类型(像素表示) 二…

支持向量机通俗导论(理解SVM的三层境界)

支持向量机通俗导论(理解SVM的三层境界) 作者:July 。致谢:pluskid、白石、JerryLead。说明:本文最初写于2012年6月,而后不断反反复复修改&优化,修改次数达上百次,最后修改于2016年1月。 前言 动笔写这个支持向量机(support vector machine)是费了不少劲和…

[机器学习] 支持向量机通俗导论节选(一)

本文转载自:http://blog.csdn.net/v_july_v/article/details/7624837 支持向量机通俗导论(理解SVM的三层境界) 作者: July、pluskid ; 致谢:白石、J erryLead 出处:结构之法算法之道 blog …

机器学习之旅---SVM分类器

本次内容主要讲解什么是支持向量,SVM分类是如何推导的,最小序列SMO算法部分推导。 最后给出线性和非线性2分类问题的smo算法matlab实现代码。 一、什么是支持向量机(Support Vector Machine) 本节内容部分翻译Opencv教程: http://docs.open…

人工智能学习笔记 实验五 python 实现 SVM 分类器的设计与应用

学习来源 【机器学习】基于SVM人脸识别算法的一些对比探究(先降维好还是先标准化好等对比分析)_○( ^皿^)っHiahiahia…的博客-CSDN博客 实验原理 有关svm原理 请移步该篇通俗易懂的博客 机器学习算法(一&#xff0…

Matlab-SVM分类器

支持向量机(Support Vector Machine,SVM),可以完成对数据的分类,包括线性可分情况和线性不可分情况。 1、线性可分 首先,对于SVM来说,它用于二分类问题,也就是通过寻找一个分类线(二维是直线&…

UGUI——RectTransform详解

什么是RectTransform 创建一个UGUI控件时,查看其Inspector面板,原先熟悉的Transform已经被替换成RectTransform,面板也与原先的Transform的面板相去甚远。 先看看Unity官方对RectTransform的描述: Position, size, anchor and pi…

【Unity3D】UGUI之Button

1 Button属性面板 在 Hierarchy 窗口右键,选择 UI 列表里的 Button 控件,即可创建 Button 控件,选中创建的 Button 控件,按键盘【T】键,可以调整 Button 控件的大小和位置。创建 Button 控件时,系统会自动给…

UGUI基础

UGUI基础 ##1、UGUI概述 1.1、Unity界面发展史 【老版本界面onGUI】>【GUI插件NGUI】>【新版本界面UGUI】 1.2、UGUI特点 新的UI系统是从Unity4.6开始被集成到Unity编译器中的。Unity官方给这个新的UI系统赋予的标签是:灵活,快速和可视化。 对…

【Unity基础】ugui的案例篇(个人学习)

文章目录 前言案例1、点击游戏物体改变一次颜色,被UI遮挡的情况下点击无效1.动态图演示2.实现方式I.实现方案1 通过射线检测实现 3.源码演示Lua部分代码CSharp部分代码 案例2、圆形图片的制作1.图演示2.实现方式I.实现方案1 使用Mask组件实现II.实现方案2 通过重写G…

Unity UGUI源码解析

前言 这篇文章想写的目的也是因为我面试遇到了面试官问我关于UGUI原理性的问题,虽然我看过,但是并没有整理出完整的知识框架,导致描述的时候可能自己都是不清不楚的。自己说不清楚的东西,别人就等于你不会。每当学完一个东西的时…

UGUI基础学习

目录 TEXT IMAGE ROWIMAGE TEXT fontsize:字体 color:字体颜色 ;inespacing:字行间隔 代码展示: using System.Collections; using System.Collections.Generic; using UnityEngine; using UnityEngine.UI;public class TEXTtest : MonoBehaviour {p…