回归(regression)

article/2025/9/26 3:08:19

1. 回归(regression)

1.1 起源与定义

回归最早是被高尔顿提出的。他通过研究发现:如果父母都比较高一些,那么生出的子女身高会低于父母的平均身高;反之,如果父母双亲都比较矮一些,那么生出的子女身高要高于父母平均身高。他认为,自然界有一种约束力,使得身高的分布不会向高矮两个极端发展,而是趋于回到中心,所以称为回归。
目前,从用法角度将其定义为一种数值(scalar)预测的技术,区别于分类(类别预测技术)。

1.2 不同的用法

1.2.1 解释(Explanation)

回归可用于做实证研究,研究自变量和因变量之间的内在联系和规律,常见于社会科学研究中。

  • 互联网的普及降低了教育不平等程度吗?
  • 大学生就业选择的影响因素有哪些?
  • 医疗电子商务场景下客户满意度的影响因素有哪些?

1.2.2 预测(Prediction)

回归也可用来做预测,根据已知的信息去准确预测未知的事情。

  • 股市预测:根据过去10年股票的变动、新闻咨询、公司并购咨询等,预测股市明天的平均值。
  • 商品推荐:根据用户过去的购买记录和候选的商品信息,预测用户购买某个商品的可能性。
  • 自动驾驶:根据汽车的各个sensor的数据,例如路况和车距等,预测正确的方向盘角度。

1.3 模型的构建

无论目的是解释还是预测,都需要掌握与任务相关的规律(认识世界),即建立合理的模型。
不同的一点是,解释模型只需要基于训练集构建,一般具备解析解(计量经济模型)。 预测模型必须在测试集上做检验和调整,一般不具备解析解,需要通过机器学习的方法去调整参数。因此,同样的模型框架和数据集,最优的解释模型和预测模型很可能是不相同的。
本文主要关注预测模型的构建,不涉及解释模型相关的内容。

2. 基于机器学习的模型构建

我们以Pokemon精灵攻击力预测这个任务为例,梳理机器学习三个步骤的详细内容。

  • 输入:进化前的CP值、物种(Bulbasaur)、血量(HP)、重量(Weight)、高度(Height)
  • 输出:进化后的CP值

2.1 模型假设 - 线性模型

为了方便,我们选择最简单的线性模型来作为完成回归任务的模型框架。我们可以使用单特征或者多特征的线性回归模型,后者会更加复杂,模型集合会更大。

为选择合理的模型框架,提前对数据集进行探索,观察变量间的关系是很有必要的,这将决定最终将哪些变量放入模型,以及是否需要对变量进行再次处理(二次项、取倒数等)。

可以看出,横轴和纵轴主要呈直线关系,也有一些二次关系(可考虑加二次项)。
模型框架(预先设定) + 参数(待估计) = 模型(目标)
目前模型的参数包括各个特征的权重 w i w_i wi 以及偏移量 b b b

2.2 模型评价 - 损失函数

本文阐述的回归任务属于有监督学习场景,因此需要收集足够的输入输出对以指导模型的构建。

有了这些真实的数据,那我们怎么衡量模型的好坏呢?从数学的角度来讲,我们使用损失函数(Loss function) 来衡量模型的好坏。Loss function基于模型预测值和实际值的差异来设置。

在本文中,我们选择常用的均方误差作为损失函数。

2.3 模型调优 - 梯度下降

当模型非凸时,是没有解析解的,只能通过启发式的方式迭代优化,常用的方法是梯度下降。

首先,我们随机选择一个 w 0 w^0 w0,然后计算微分判定移动的方向,再更新对应参数,循环往复,直到找到最低点(两次更新之间差异小于阈值或者达到预先设定好的迭代次数)。
对于有多个待更新参数的模型,步骤是基本一致的,只不过做的是偏微分。

在梯度下降的过程中,会遇到一些问题,导致无法达到最优点。

这些问题如何解决以后会涉及到。

3. 模型构建中的问题和解决

3.1 评价模型的泛用性(Generalization)

好模型不仅要在训练集中表现优异,在未知的数据集(测试集,真实应用场景)中也应该一样。
因此,我们必须要计算模型在测试机上的性能,理想情况下不能有较大的下滑。

3.2 提高模型的拟合度

若模型过于简单,则模型集合较小,可能无法包含真实的模型,即出现欠拟合问题。
我们可以选择更复杂的模型去优化性能。以使用1元2次方程举例,显著提高了预测性能。

我们还可以在模型中增加调节项(Pokemon种类)来改进模型。
在这里插入图片描述

模型在训练集和测试集的性能表现如下所示:

3.3 防止过拟合(Overfiting)的出现

如果我们继续使用更高次的模型,可能会出现过拟合问题。

我们可以通过加入正则项来防止过拟合问题的出现。

正则项权重变化对模型性能的影响如下所示:

4. 回归 - 代码演示

现在假设有10个x_data和y_data,x和y之间的关系是y_data=b+w*x_data。b,w都是参数,是需要学习出来的。现在我们来练习用梯度下降找到b和w。

import numpy as np
import matplotlib.pyplot as plt
from pylab import mpl# matplotlib没有中文字体,动态解决
plt.rcParams['font.sans-serif'] = ['Simhei']  # 显示中文
mpl.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题# 生成实验数据
x_data = [338., 333., 328., 207., 226., 25., 179., 60., 208., 606.]
y_data = [640., 633., 619., 393., 428., 27., 193., 66., 226., 1591.]
x_d = np.asarray(x_data)
y_d = np.asarray(y_data)
x = np.arange(-200, -100, 1) # 参数的候选项,指偏移项b
y = np.arange(-5, 5, 0.1) # 参数的候选项,指权重w
Z = np.zeros((len(x), len(y)))
X, Y = np.meshgrid(x, y)# 得出每种可能组合下的loss,共需要计算100*100=10000次
for i in range(len(x)):for j in range(len(y)):b = x[i]w = y[j]Z[j][i] = 0  # meshgrid吐出结果:y为行,x为列for n in range(len(x_data)):Z[j][i] += (y_data[n] - b - w * x_data[n]) ** 2Z[j][i] /= len(x_data)

以上代码生成了实验数据,并用穷举法计算出了所有可能组合的loss,其中最小值为10216。
接下来我们尝试使用梯度下降法来快速寻找到较小的loss值。

# linear regression
b=-120
w=-4
lr = 0.000005
iteration = 10000 #先设置为10000b_history = [b]
w_history = [w]
loss_history = []
import time
start = time.time()
for i in range(iteration):m = float(len(x_d))y_hat = w * x_d  +bloss = np.dot(y_d - y_hat, y_d - y_hat) / mgrad_b = -2.0 * np.sum(y_d - y_hat) / mgrad_w = -2.0 * np.dot(y_d - y_hat, x_d) / m# update paramb -= lr * grad_bw -= lr * grad_wb_history.append(b)w_history.append(w)loss_history.append(loss)if i % 1000 == 0:print("Step %i, w: %0.4f, b: %.4f, Loss: %.4f" % (i, w, b, loss))
end = time.time()
print("大约需要时间:",end-start)
# Step 0, w: 1.6534, b: -119.9839, Loss: 3670819.0000
# Step 1000, w: 2.4733, b: -120.1721, Loss: 11492.1941
# Step 9000, w: 2.4776, b: -121.6771, Loss: 11435.5676

可以发现,梯度下降法可以快速从初始值迭代到合适的参数组合,接近最优参数。但我们发现,达到最优值的过程却非常缓慢。使用下面的代码可以对寻优过程进行可视化。

# plot the figure
plt.contourf(x, y, Z, 50, alpha=0.5, cmap=plt.get_cmap('jet'))  # 填充等高线
plt.plot([-188.4], [2.67], 'x', ms=12, mew=3, color="orange") # 最优参数
plt.plot(b_history, w_history, 'o-', ms=3, lw=1.5, color='black')
plt.xlim(-200, -100)
plt.ylim(-5, 5)
plt.xlabel(r'$b$')
plt.ylabel(r'$w$')
plt.title("线性回归")
plt.show()

如下图所示,参数最终寻优的方向是正确的,但是因为迭代次数不够所以提前停止。
在这里插入图片描述

我们将迭代次数更改为10万次,结果如下所示:
在这里插入图片描述

迭代次数仍然不足,我们继续将迭代次数更改为100万次,结果接近最优,如下所示:
在这里插入图片描述

迭代次数太多会消耗过多的计算资源,我们可以通过调整学习率来加快速度。当我们将学习率设置为之前的两倍(0.00001)时,迭代10万次即可达到接近最优的结果,如下所示;
在这里插入图片描述

但需要注意的是,学习率如果设置得太高,可能会发生振荡,无法收敛。下图是我们将学习率设置为0.00005时的情况。
在这里插入图片描述

总而言之,我们要清楚机器学习的强大能力以及不稳定性,然后学习相关原理进而熟练使用。

参考文献

  1. Datawhale 开源学习资料 李宏毅机器学习
  2. 到底什么是实证研究?

http://chatgpt.dhexx.cn/article/CfjuwtEY.shtml

相关文章

STATA regress回归结果分析

对于STATA回归结果以前一直不清不楚,每次都需要baidu一波,因此今天将结果相关分析记录下: 如上图 上面左侧的表是用来计算下面数据的,分析过程中基本不会用到 右侧从上往下 1.Number of obs 是样本容量 2.F是模型的F检验值&a…

MATLAB regress命令

1 regress命令 用于一元及多元线性回归,本质上是最小二乘法。在Matlab 命令行窗口输入help regress ,会弹出和regress的相关信息,一一整理。 调用格式: B regress(Y,X)[B,BINT] regress(Y,X)[B,BINT,R] regress(Y,X)[B,BINT,R…

MATLAB回归分析命令——regress命令

题目 假设向量y[7613.51 7850.91 8381.86 9142.81 10813.6 8631.43 8124.94 9429.79 10230.81 10163.61 9737.56 8561.06 7781.82 7110.97]; x1[7666 7704 8148 8571 8679 7704 6471 5870 5289 3815 3335 2927 2758 2591]; x2[16.22 16.85 17.93 17.28 17.23 17 19 18.22…

matlab中多元线性回归regress函数精确剖析(附实例代码)

matlab中多元线性回归regress函数精确剖析(附实例代码) 目录 前言 一、何为regress? 二、regress函数中的参数 三、实例分析 总结 前言 regress函数功能十分强大,它可以用来做多元线性回归分析,它不仅能得出线性回归函数中各个系数&#…

Ubuntu 下安装 Yar 扩展遇到的问题以及解决方案

本文为原创,转载请注明出处。 昨天在ubuntu上安装完yar之后发现yar还是不能用,感觉有些不对劲。 通过在微博上请教 Laruence 大神和公司的大大之后,问题解决了,下面就来分享这一成果。 如果还没有做好安装工作,请看…

最新yar扩展安装和使用

先说windows客户端的安装 yar扩展下载地址:https://pecl.php.net/package/yar 下载对应的版本 点击DLL可以查看支持的PHP版本,我本地是phpstudy搭建的环境,php版本是7.19.nts 所以我下载了7.1 Non Thread Safe (NTS) x64 解压后里面的php_yar.dll和php_yar.pdb文件…

yarn使用简介

yarn简介: yarn是facebook发布的一款取代npm的包管理工具。 yarn的特点: 速度超快。Yarn 缓存了每个下载过的包,所以再次使用时无需重复下载。 同时利用并行下载以最大化资源利用率,因此安装速度更快。超级安全。 在执行代码之前…

yar安装使用

1.安装 pecl install yar vim /etc/php.ini 加上extensionyar.so 查看支持的配置&#xff1a; php --re yar - Dependencies { Dependency [ json (Required) ] } - INI { Entry [ yar.packager <PERDIR> ] //打包协议 Current php } …

phpstudy安装yar扩展

最近因为项目需要yar扩展&#xff0c;本地开发环境使用phpstudy搭建&#xff0c;yar不是phpstudy的常用扩展&#xff0c;无法在扩展列表里面找到&#xff0c;所以需要自己安装。 0x01 Yar(Yet Another RPC framework for PHP) 是一个轻量级, 高效的RPC框架, 它提供了一种简单…

Yarn基本介绍(一)

1、简介 Yarn是Hadoop的分布式资源调度平台&#xff0c;负责为集群的运算提供运算资源。如果把分布式计算和单个计算机对应的话&#xff0c;HDFS就相当于计算机的文件系统&#xff0c;Yarn就是计算机的操作系统&#xff0c;MapReduce就是计算机上的应用程序。 2、组成部分 Y…

使用yaf+yar实现基于http的rpc服务

什么是RPC RPC&#xff0c;全称是Remote Procedure Call&#xff0c;远程服务调用&#xff0c;是一种通过网络从远程计算机程序上请求服务&#xff0c;而不需要了解底层网络技术的协议。简单一点来理解就是网络上的一个节点请求另一个节点提供的服务。 什么是YAF Yaf&#x…

yaf yar微服务/hprose微服务 镜像初始化 —— k8s从入门到高并发系列教程 (四)

前面的教程已经在docker镜像 软件 层面上初步安装了企业常用的插件&#xff0c;但目前还没有写任何代码。本教程带你初始化yaf框架&#xff0c;并基于yar框架和hprose跨语言微服务框架打包两个微服务代码&#xff0c;在容器间调用。 yaf是一个用c语言写的&#xff0c;用于php项…

YAR 并行RPC框架研究

前几天,部门召开了PHP技术峰会 学习会议,大家分别对这次会议的PPT 做了简单的介绍, 其中提到了 鸟哥【惠新辰】的一篇PPT《微博LAMP 演变》,如果谁有需要可以去谷歌搜,或者去 http://www.laruence.com/2013/08/15/2913.html 他的博客去看一下,我就不提供下载链接了。 …

Yarn概述

Yarn Yarn是Hadoop的分布式资源调度平台&#xff0c;负责为集群的运算提供运算资源。如果把分布式计算机和单个计算机相对应的话&#xff0c;HDFS就相当于计算机的文件系统&#xff0c;Yarn就是计算机的操作系统&#xff0c;MapReduce就是计算机上的应用程序。 Yarn的基本组成…

Yarn介绍

Yarn介绍 一&#xff0c;介绍二&#xff0c; yarn 框架三&#xff0c;ResourceManager3.1&#xff0c;ApplicationsManager3.2&#xff0c;Scheduler 四&#xff0c;NodeManager五&#xff0c;ApplicationMaster六&#xff0c;客户端提交任务到yarn中运行的流程。 YARN的基本思…

Yarn

应用场景 当部署好hadoop集群后,搭建了YARN集群,开启了hadoop的HDFS和YARN服务,访问主节点IP和8088端口的YARN监控界面,发现这个All Applications界面中的开始执行时间和结束执行时间不对,应该往后加8个小时才对,导致在页面中对任务监控的时候容易出错,所以现在要进行修…

Yar 搭建 RPC 服务

一、安装 Yar pecl install yar pecl install msgpack 二、确保 php 加载 yar 模块 php -m 三、编写服务器端 Server.php &#xff0c; 在浏览器打开 http://.../Server.php 可见API的介绍如下 <?phpclass API {public function some_method($parameter, $options &…

[Yar] yar安装与使用过程中遇到问题总结

yar安装与使用过程中遇到问题总结 Yar 简介官方文档yar安装yar运行时的默认配置yar常量使用范例遇到的问题应用流程 Yar 简介 Yar 是一个轻量级, 高效的RPC框架, 它提供了一种简单方法来让PHP项目之间可以互相远程调用对方的本地方法. 并且Yar也提供了并行调用的能力. 可以支持…

PHP封装curd,ThinkPHP5.0的模型CURD创建Create操作

模型的主要功能包括数据处理和业务逻辑&#xff0c;而这些都离不开数据的CURD操作&#xff0c;因此我们首先来谈下数据的CURD操作&#xff0c;在掌握了数据库Db类的用法后&#xff0c;模型的CURD操作就会很容易理解&#xff0c;因为本质上模型的CURD操作最终调用的还是Db类的操…

浅谈CURD系统和CRQS系统

浅谈CURD系统和CRQS系统 在网上看到关于这个内容的介绍&#xff0c;就想着自己整理一下&#xff0c;方便观看。 三层架构 先从三层架构开始讲&#xff0c;三层架构(3-tier architecture) 通常意义上的三层架构就是将整个业务应用划分为:界面层(User Interface layer)、业务逻辑…