机器学习基石:PLA算法代码实现

article/2025/11/6 23:59:41

一、前言

本篇是面向机器学习基石第一次作业而言。

15-20题都是需要编程实现才能正确做出选择。

前面14个选择题,我觉得题目出得并不好。这里就不再多说。主要面向最后的LPA和pocket算法的实现。

代码对应的gitee地址

二、PLA算法实现

数据集中每个样本都是的 X X X都是四维向量 [ x 1 , x 2 , x 3 , x 4 ] [x_1,x_2,x_3,x_4] [x1,x2,x3,x4],对应y = {1, -1}

1. 按照数据集给定的顺序更新算法

def pla(datas):size = len(datas)if size <= 1:return;err_i = -1  # 标记当前用于更新的data行dms = len(datas[0])if dms == 0:return;w = [0 for x in range(0, dms)]run_times = 0last_pause = sizenow = 0while True:run_times = run_times + 1 # 整个数据循环的圈数while now != last_pause:    # 转一圈之后,两个碰在一起p = 0now %= size # 当前在size中的位置for x in range(0, dms-1):p += w[x] * datas[now][x]p += w[-1]if p <= 0 and datas[now][-1] > 0 or p >0 and datas[now][-1] < 0:err_i = nowlast_pause = err_iif last_pause == 0:last_pause == sizenow += 1breaknow += 1# 更新w(w_0放在末尾)if err_i != -1:for x in range(0, dms - 1):w[x] += datas[err_i][-1]*datas[err_i][x]w[-1] += datas[err_i][-1]err_i = -1else:break;return [w, run_times]

其中需要是last_pause是当前最后错误的位置,如果从当前错误的位置转了一圈又回到这里而没有遇到其他错误的更新点的时候,说明已经更新完毕。

2. 随机访问数据的顺序更新算法

# 永远保证当前时刻是[0,n)
# 每次交换当前i的随机的数
def randomIndex(n):index = [i for i in range(0,n)]def swap(l,x,y):l[x] = l[x]+l[y]l[y] = l[x] - l[y]l[x] = l[x] - l[y]for i in range(0,n):swap(index,i,int(random.random()*n))return indexdef plaImproved(datas,n = 1):size = len(datas)if size<=1:return;err_i = -1dms = len(datas[0])if dms == 0:return;para = [0 for x in range(0,dms)]run_times = 0index = randomIndex(size)last_pause = sizei = 0while True:#if run_times>=50:#breakrun_times+=1#for i in range(0, size):while i != last_pause:p = 0i %= sizefor x in range(0, dms - 1):p += para[x] * datas[index[i]][x]p += para[-1]if p <= 0 and datas[index[i]][-1] > 0 or p > 0 and datas[index[i]][-1] < 0:#ignore datas[i][-1] == 0err_i = index[i]break; #遇到错误推出循环i+=1if err_i != -1:for x in range(0, dms - 1): #用这个错误来更新参数para[x] = para[x]+ n* datas[err_i][-1] * datas[err_i][x]  # update the parameterspara[-1] += n * datas[err_i][-1]last_pause = iif last_pause == 0:last_pause = sizei+=1err_i = -1;else:break;return [para,run_times]

注意

  1. 所谓improved版本的PLA,主要是实现利用打乱的后的index来作为更新顺序。
  2. 更改权重w:会影响到最终的效果。但是从最终结果来看,平均更新次数差不多。
    W t + 1 = W t + w ∗ y n ( t ) X n ( t ) W_{t+1} = W_t+w*y_n(t)X_n(t) Wt+1=Wt+wyn(t)Xn(t)

三、Pocket算法实现

def pocket(datas, max_time=50, greedy=1):size=len(datas)if size <= 1:returnerr_i = -1dms = len(datas[0])if dms == 0:returnw = [0 for x in range(0,dms)]new_w = [0 for x in range(0,dms)]new_error = 0last_error = sizerun_times = 0while True:index = randomIndex(size)if run_times>max_time:breakrun_times += 1for i in range(0, size):p = 0for x in range(0, dms-1):p += new_w[x]*datas[index[i]][x]p += new_w[-1]if p <= 0 and datas[index[i]][-1] > 0 or p > 0 and datas[index[i]][-1] < 0:#ignore datas[i][-1] == 0err_i = index[i]breakif err_i != -1:for x in range(0, dms - 1): #用这个错误来更新参数new_w[x] += datas[err_i][-1] * datas[err_i][x]  # update the parametersnew_w[-1] += datas[err_i][-1]if greedy == 1:           for i in range(0, size):p = 0for x in range(0, dms-1):p += new_w[x]*datas[index[i]][x]p += new_w[-1]if p <= 0 and datas[index[i]][-1] > 0 or p > 0 and datas[index[i]][-1] < 0:#ignore datas[i][-1] == 0new_error += 1if (new_error < last_error):w = copy.deepcopy(new_w)    # 如果不是deepcopy,就等于只是引用last_error = new_errornew_error = 0err_i = -1else: breakif greedy == 0:return [new_w, run_times]else:return [w, run_times]

注意事项

  1. python中的拷贝:使用deepcopy才能真正实现我们想要的拷贝
  2. 更新次数的提升,会降低错误分类的比例。100次更新time比50次要好1%左右
  3. 就50次更新而言,使用pocket比直接使用更新后的w效果好,大约1%左右

四、算法可视化

随机生成二维平面[0-20]的点:

def random2DDatas(num):result = []g1 = [random.random()*20,random.random()*20]g2 = [random.random()*20,random.random()*20]# 由数据范围内的两个点来确定分割线,保证划分线一定会经过生成的点的范围w = [(g1[1] - g2[1])/(g1[0] -g2[0]),-1,g1[1] - (g1[1] - g2[1])/(g1[0] -g2[0])*g1[0]]result.append(w) # 完美分割线for i in range(num):x = [random.random()*20,random.random()*20]y = w[0]*x[0]+w[1]*x[1]+w[2]if y<0:x.append(-1)elif y>0:x.append(1)else:continue#print(x,y)result.append(x)return result

调用PLA算法,并做可视化:

def visualizePLA(all,w = []):x = np.linspace(0,20,50)  # 在1到10之间产生50组数据(数据之间呈等差数列)y = - all[0][2]/all[0][1]  - all[0][0]/all[0][1]*x  # 最开始的线plt.figure()plt.plot(x,y,color="black")if len(w)!=0:z = - w[2] / w[1] - w[0] / w[1] * xplt.plot(x,z,color="orange",linestyle="--")posx = []posy = []negx = []negy = []for i in  range(1,len(all)):if all[i][-1] == -1:negx.append(all[i][0])negy.append(all[i][1])else:posx.append(all[i][0])posy.append(all[i][1])plt.scatter(negx,negy,marker='x',c='r')plt.scatter(posx,posy,marker='o',c='g')plt.show()

最终效果:

在这里插入图片描述


http://chatgpt.dhexx.cn/article/bH6WF0UD.shtml

相关文章

机器学习基石——作业2解答

机器学习基石——作业2解答 这里的 μ 指的是某个h(x)≈f(x)&#xff0c;对应的 Eout(h) 。其中目标函数 f(x) 是确定性的&#xff0c;没有噪声干扰。如果加上噪声&#xff0c;目标函数变为课中讲的概率分布 P(y│x) &#xff0c;表示为 P(y│x){λ1−λyf(x)otherwize \begi…

台湾大学林轩田机器学习基石课程学习笔记3 -- Types of Learning

红色石头的个人网站&#xff1a;redstonewill.com 上节课我们主要介绍了解决线性分类问题的一个简单的方法&#xff1a;PLA。PLA能够在平面中选择一条直线将样本数据完全正确分类。而对于线性不可分的情况&#xff0c;可以使用Pocket Algorithm来处理。本节课将主要介绍一下机器…

林軒田《机器学习基石》课程总结

最近发布了一系列台湾大学资讯工程系林軒田&#xff08;Hsuan-Tien Lin&#xff09;教授开设的《机器学习基石》的课程总结&#xff0c;分为4个部分&#xff0c;点击标题可查看&#xff1a; 机器什么时候能够学习&#xff1f;&#xff08;When Can Machines Learn&#xff1f;…

台大林轩田《机器学习基石》:作业三python实现

台大林轩田《机器学习基石》&#xff1a;作业一python实现 台大林轩田《机器学习基石》&#xff1a;作业二python实现 台大林轩田《机器学习基石》&#xff1a;作业三python实现 台大林轩田《机器学习基石》&#xff1a;作业四python实现 完整代码&#xff1a; https://github…

机器学习基石系列三

课程关联与可学习 核心问题 上界限制 增长上限 上界证明&#xff08;不太懂&#xff09; - step three

林轩田 《机器学习基石》学习笔记

参考资料&#xff1a; 除了redstone的笔记较好之外&#xff0c;还有豆瓣的https://www.douban.com/doulist/3381853/的笔记也比较好 -------------------------------------- 1. 什么时候机器可以学习&#xff1f; 2. 为什么要要使用机器学习&#xff1f; 3. 机器怎么可以学习到…

【机器学习】机器学习基石-林轩田-1-机器学习介绍

机器学习基石-1-机器学习介绍 本节内容What is Machine Learning&#xff1f;What is skill?Why use machine learning?When use machine learning?Key Essence of Machine LearningFun TimeApplications of Machine LearningComponents of Machine Learning相关术语Leanin…

机器学习基石 作业0

机器学习基石 作业0 1 Probability and Statistics2 Linear Algebra3 Caculus网上没找到作业0的答案,这是自己做的版本,有一些可能会有错误,欢迎讨论。 1 Probability and Statistics 用数学归纳法。N=1时满足,假定N=n满足,当N=n+1同样满足。得证。 10个挑4个正面 C 10 4…

机器学习基石 作业三

机器学习基石 作业三 代入计算 线性回归得到的映射函数 H H H的性质问题。显然映射多次与映射一次效果一样。其它的可以根据 H H H的性质,秩为d+1,显然不可逆。特征值的部分不是非常清楚,大概是根据 I − H I-H I−H的迹等于 N − ( d + 1 ) N-(d+1) N−(d+1)得到的。3. PLA…

机器学习基石笔记

文章目录 一. 机器学习什么时候用二. 机器学习的基本流程三. 什么是机器学习四. 机器学习的可行性NFL定理从统计学中找到可行的方法统计学与机器学习产生联系 一. 机器学习什么时候用 事物本身存在某种潜在规律某些问题难以使用普通编程解决有大量的数据样本可供使用 二. 机器…

机器学习基石 作业二

机器学习基石 作业二 1 计算一下本来预测对与预测错时加上噪音导致的错误率然后相加即可。 2 选择一个 λ \lambda λ的值让 μ \mu μ的系数为0。 3 根据VC bound 公式带入计算即可,N=46000的时候error最接近0.05。下面的代码可以计算不同的N与目标error之间的差距。 def …

机器学习基石2-Learning to Answer Yes-No

注&#xff1a; 文章中所有的图片均来自台湾大学林轩田《机器学习基石》课程。 笔记原作者&#xff1a;红色石头 微信公众号&#xff1a;AI有道 上节课&#xff0c;简述了机器学习的定义及其重要性&#xff0c;并用流程图的形式介绍了机器学习的整个过程&#xff1a;根据模型\(…

机器学习基石-林轩田-第一周笔记

Lecture 01 - The Learning Problem When Can Machine Learn ?Why Can Machine Learn ?How Can Machine Learn ?How Can Machine Learn Better ? What is Machine Learning 什么是“学习”&#xff1f;学习就是人类通过观察、积累经验&#xff0c;掌握某项技能或能力。就…

机器学习基石16:三个重要原则(Three Learning Principles)

本节介绍了机器学习中三个重要原则&#xff0c;包括奥卡姆剃刀原理&#xff0c;样本偏差&#xff0c;数据窥探&#xff1b;并对16课程所学知识进行了总结。 系列文章 机器学习基石01&#xff1a;机器学习简介 机器学习基石02&#xff1a;感知器算法&#xff08;Perceptron Alg…

机器学习基石1(ML基本概念和VC dimension)

文章目录 一、什么是机器学习?二、什么时候可以使用机器学习?三、感知机perceptron四、机器学习的输入形式五、机器真的可以学习吗&#xff1f;六、vc dimension 一、什么是机器学习? 其实第一个问题和第二个问题是穿插到一块儿回答的&#xff0c;首先机器学习要解决的是常规…

Wireshark抓包数据

首先官网下载Wireshark&#xff0c;下载好后&#xff0c;用浏览器打开桂林生活网&#xff0c;无需注册&#xff0c;输入账号密码。 打开Wireshark&#xff0c;用命令提示符查看本机ip 在Wireshark的过滤搜索中输入ip10.34.152.44&#xff0c;找到http类型的数据查看&#xff0…

Wireshark抓包数据分析

文章目录 准备数据链路层实作一 熟悉 Ethernet 帧结构实作二 了解子网内/外通信时的 MAC 地址实作三 掌握 ARP 解析过程 网络层实作一 熟悉 IP 包结构实作二 IP 包的分段与重组实作三 考察 TTL 事件 传输层实作一 熟悉 TCP 和 UDP 段结构实作二 分析 TCP 建立和释放连接 应用层…

网络数据包分析与抓取

多年的网络数据包分析与抓取经验&#xff0c;闲话少说&#xff0c;上干货。先列举数据包的种类&#xff1a;1、Http数据包&#xff1b;2、UDP数据包&#xff1b;3、TCP数据包&#xff1b;4、ARP数据包&#xff1b;其实数据包的概念是很泛的&#xff0c;在软件可逆领域&#xff…

如何进行数据的抓包

抓包 抓包就是对网络传输中发送与接收的数据包进行截获、重发、编辑、转存等操作。 前提&#xff1a;抓取的数据包是从网卡设备中进行抓取的&#xff1b; win wiresharkLinux tcpdump命令 从上图我们就可以了解到tcpdump就是我们使用的一个工具&#xff1b; 我们在使用它时有…

WireShark基本抓包数据分析

WireShark抓包数据分析&#xff1a; 1、TCP报文格式 源端口、目的端口&#xff1a;16位长。标识出远端和本地的端口号。 顺序号&#xff1a;32位长。表明了发送的数据报的顺序。 确认号&#xff1a;32位长。希望收到的下一个数据报的序列号。 TCP协议数据报头DE 头长&#xff…