机器学习模型评测:holdout cross-validation k-fold cross-validation

article/2025/10/2 9:16:40

cross-validation:从 holdout validation 到 k-fold validation

2016年01月15日 11:06:00 Inside_Zhang 阅读数:4445

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/lanchunhui/article/details/50522424

构建机器学习模型的一个重要环节是评价模型在新的数据集上的性能。模型过于简单时,容易发生欠拟合(high bias);模型过于复杂时,又容易发生过拟合(high variance)。为了达到一个合理的 bias-variance 的平衡,此时需要对模型进行认真地评估。本文将介绍两个十分有用的cross-validation技术,holdout cross-validation 以及 k-fold cross-validation,这些方法将帮助我们获得模型关于泛化误差(generalization error)的可信的估计,所谓的泛化误差也即模型在新数据集上的表现。

holdout validation

from sklearn.cross_validation import train_test_split

使用 holdout 方法,我们将初始数据集(initial dataset)分为训练集(training dataset)和测试集(test dataset)两部分。训练集用于模型的训练,测试集进行性能的评价。然而,在实际机器学习的应用中,我们常常需要反复调试和比较不同的参数设置以提高模型在新数据集上的预测性能。这一调参优化的过程就被称为模型的选择(model selection):select the optimal values of tuning parameters (also called hyperparameters)。然而,如果我们重复使用测试集的话,测试集等于称为训练集的一部分,此时模型容易发生过拟合。

一个使用 holdout 方法进行模型选择的较好的方式是将数据集做如下的划分:

  • 训练集(training set);

    The training set is used to fit the different models

  • 评价集(validation set);

    The performance on the validation set is then used for the model selection

  • 测试集(test set);

    The advantage of having a test set that the model hasn’t seen before during the training and model selection steps is that we can obtain a less biased estimate of its ability to generalize to new data. 下图阐述了 holdout cross-validation 的工作流程,其中我们重复地使用 validation set 来评估参数调整时(已经历训练的过程)模型的性能。一旦我们对参数值满意,我们就将在测试集(新的数据集)上评估模型的泛化误差。

 


 

 

holdout 方法的弊端在于性能的评估对training set 和 validation set分割的比例较为敏感

k-fold validation

在 k-fold cross_validation,我们无放回(without replacement)地将训练集分为 k folds(k个部分吧),其中的 k-1 folds 用于模型的训练,1 fold 用于测试。将这一过程重复 k 次,我们便可获得 k 个模型及其性能评价。

我们然后计算基于不同的、独立的folds的模型(s)的平均性能,显见该性能将与holdout method相比,对training set的划分较不敏感。一旦我们找到了令人满意的超参的值,我们将在整个训练集上进行模型的训练。

因为k-fold cross-validation 是无放回的重采样技术,这种方法的优势在于每一个采样数据仅只成为训练或测试集一部分一次,这将产生关于模型性能的评价,比 hold-out 方法较低的variance。下图展示了 k=10k=10 时的 k-fold 方法的工作流程。

 


 

 

k-fold cross-validation 的 kk 值一般取10,对大多数应用而言是一个合理的选择。然而,如果我们处理的是相对较小的训练集的话,增加 kk 的值将会非常实用。因为如果增加 kk,更多的训练数据(N×(1−1k)N×(1−1k))得以使用在每次迭代中,这将导致较低的 bias在评估模型的泛化性能方面。然而,更大的 kk 将会增加cross-validation的运行时间,生成具有更高higher variance 的评价因为此时训练数据彼此非常接近。另一方面,如果我们使用的是大数据集,我们可以选择更小的 kk,比如 k=5k=5,较小的 kk 值将会降低在不同folds上的refitting 以及模型评估时的计算负担

一个 k-fold cross-validation 的特例是 LOO(leave-one-out) cross-validation method。在LOO中,我们取 k=nk=n,如前所述,LOO尤其适用于具有小规模的数据集上。

对于 k-fold cross-validation 的一中改进是 stratified k-fold cv method,该方法可以产生更好的 bias以及variance 评价,尤其当 unequal class proportions。

实践

from sklearn.cross_validation import StratifiedKFold
kfolds = StratifiedKFold(y=y_train, n_folds=10, random_state=1)
scores = []
for i, (train, test) in enumerate(kfolds):# train.shape 将是 test.shape的接近9倍clf.fit(X_train[train], y_train[train])score = clf.score(X_train[test], y_train[test])scores.append(score)print('Fold: %s, class dist.: %s, Acc: %.4f' %(i+1, np.bincount(y_train[train]), score))
print('CV accuracy: %.4f +/- %.4f' % (np.mean(scores), np.std(scores)))

 


http://chatgpt.dhexx.cn/article/itWbSup6.shtml

相关文章

三种模型验证方法:holdout, K-fold, leave one out cross validation(LOOCV)

Cross Validation: A Beginner’s Guide An introduction to LOO, K-Fold, and Holdout model validation By: Caleb Neale, Demetri Workman, Abhinay Dommalapati 源自:https://towardsdatascience.com/cross-validation-a-beginners-guide-5b8ca04962cd 文章目录…

模型检验方法:holdout、k-fold、bootstrap

参考:https://www.cnblogs.com/chay/articles/10745417.html https://www.cnblogs.com/xiaosongshine/p/10557891.html 1.Holdout检验 Holdout 检验是最简单也是最直接的验证方法, 它将原始的样本集合随机划分成训练集和验证集两部分。 比方说&#x…

多种方式Map集合遍历

1.如何遍历Map中的key-value对,代码实现(至少2种) Map集合的遍历(方式1)键找值: package com.B.Container_13.Map;import java.util.HashMap; import java.util.Map; import java.util.Set;//Map集合的遍历(方式1)键找值 public class Map04_01 {publi…

Map集合中的四种遍历方式

1.Map接口的概述 &#xff08;1&#xff09;它是双列集合&#xff1b; &#xff08;2&#xff09;格式&#xff1a;Interface Map<k,v> K:键的类型 V&#xff1a;值得类型 &#xff08;3&#xff09;它的每个元素都包含一个键对象Key和值对象Value&#xff0c;并且他们…

Java中的Map集合以及Map集合遍历实例

文章目录 一、Map集合二、Map集合遍历实例 一、Map集合 Map<K,V>k是键&#xff0c;v是值 1、 将键映射到值的对象&#xff0c;一个映射不能包含重复的键&#xff0c;每个键最多只能映射的一个值 2、 实现类  a) HashMap  b) TreeMap 3、 Map集合和Collection集合的区别…

Map集合的四种遍历

Map集合的四种遍历 这里记录一下map集合的4种遍历&#xff1a; 第一种 得到所有的key–map.keySet() ,根据key拿到value–map.get(key) public static void main(String[] args) {Map<String, String> map new HashMap();map.put("1", "刘备");…

Map集合遍历的三种方式

Map集合遍历的三种方式 遍历Map集合的三种方式 键找值键值对Lambda表达式 方式一 : 键找值 先获取Map集合的全部键的Set集合遍历键的Set集合,然后通过键提取对应值 原理图 键找值涉及到的API 方法名称说明Set keySet()获取所有键的集合V get(Object key)根据键获取值 Map…

java中Map集合的四种遍历方式

java中Map集合的四种遍历方式 Map接口和Collection接口的集合不同,Map集合是双列的,Collection是单列的.Map集合将键映射到值的对象. 双列的集合遍历起来也是比较麻烦些的,特别是嵌套的map集合,这里说下MAP集合的四种遍历方式&#xff0c;并且以嵌套的hashMap集合为例, 遍历一…

如何遍历map集合

Map集合是基于java核心类——java.util中的&#xff1b; Map集合用于储存元素对&#xff0c;Map储存的是一对键值&#xff08;key和value&#xff09;&#xff0c;是通过key映射到它的value values() : 是获取集合中的所有的值----没有键&#xff0c;没有对应关系。 KeySet(…

Map集合常用的三种遍历方式

Map集合使用的是Key - Value的形式存储元素&#xff0c;也就是键值对的形式。Map集合内部的实现分别是HashMap和TreeMap&#xff0c;也就是哈希表和二叉树这两种数据结构。List集合和Set集合都是继承自Collection类&#xff0c;而Map集合就是自己的父类。前者可以直接通过Itera…

Map集合遍历方式

Map集合遍历方式一&#xff1a;键找值 先获取Map集合的全部键的Set集合 //Set keymap.keySet();遍历键的Set集合&#xff0c;然后通过键提取对应值map.getValue() Set<String> keysmaps.keySet();for(String key1:keys){int valu1emaps.get(key1);System.out.println(ke…

MAP集合的遍历方式

简单场景&#xff1a;map集合存放为数字星期 如图&#xff1a; 代码&#xff1a; Map<Integer, String> map new HashMap<>(); map.put(1, "星期一"); map.put(2, "星期二"); map.put(3, "星期三"); map.put(4, "星期四&quo…

Map集合遍历的四种方式

1.通过Map.keySet获取key的Set集合&#xff0c;之后在通过key进行遍历 2.通过Map.values获取所有value&#xff0c;之后再进行遍历 3.通过Map.entrySet获取Set集合&#xff0c;之后通过iterator进行遍历 4.直接通过foreach对Map.entrySet获取的Set集合进遍历 案例&#…

lstm结构图_LSTM模型结构的可视化

目录: 1、传统的BP网络和CNN网络 2、LSTM网络 3、LSTM的输入结构 4、pytorch中的LSTM 4.1 pytorch中定义的LSTM模型 4.2 喂给LSTM的数据格式 4.3 LSTM的output格式 5、LSTM和其他网络组合 最近在学习LSTM应用在时间序列的预测上,但是遇到一个很大的问题就是LSTM在传统BP网络上…

LSTM模型详解

&#xff08;一&#xff09;LSTM模型理解 1.长短期记忆模型&#xff08;long-short term memory&#xff09;是一种特殊的RNN模型&#xff0c;是为了解决RNN模型梯度弥散的问题而提出的&#xff1b;在传统的RNN中&#xff0c;训练算法使用的是BPTT&#xff0c;当时间比较长时&…

LSTM模型、双向LSTM模型以及模型输入输出的理解

循环神经网路&#xff08;RNN&#xff09;在工作时一个重要的优点在于&#xff0c;其能够在输入和输出序列之间的映射过程中利用上下文相关信息。然而不幸的是&#xff0c;标准的循环神经网络&#xff08;RNN&#xff09;能够存取的上下文信息范围很有限。这个问题就使得隐含层…

基于LSTM模型实现新闻分类

1、简述LSTM模型 LSTM是长短期记忆神经网络&#xff0c;根据论文检索数据大部分应用于分类、机器翻译、情感识别等场景&#xff0c;在文本中&#xff0c;主要使用tensorflow及keras&#xff0c;搭建LSTM模型实现新闻分类案例。&#xff08;只讨论和实现其模型的应用案例&#…

Pytorch实现的LSTM模型结构

LSTM模型结构 1、LSTM模型结构2、LSTM网络3、LSTM的输入结构4、Pytorch中的LSTM4.1、pytorch中定义的LSTM模型4.2、喂给LSTM的数据格式4.3、LSTM的output格式4.4 LSTM笔记 5、LSTM和其他网络组合 1、LSTM模型结构 BP网络和CNN网络没有时间维&#xff0c;和传统的机器学习算法理…

【RNN架构解析】LSTM 模型

LSTM 模型 前言1. LSTM 内部结构图2. Bi-LSTM 介绍3. LSTM 代码实现4. LSTM 优缺点前言 了解LSTM内部结构及计算公式.掌握Pytorch中LSTM工具的使用.了解LSTM的优势与缺点.LSTM(Long Short-Term Memory)也称长短时记忆结构, 它是传统RNN的变体, 与经典RNN相比能够有效捕捉长序…

Pytorch LSTM模型 参数详解

本文主要依据 Pytorch 中LSTM官方文档&#xff0c;对其中的模型参数、输入、输出进行详细解释。 目录 基本原理 模型参数 Parameters 输入Inputs: input, (h_0, c_0) 输出Outputs: output, (h_n, c_n) 变量Variables 备注 基本原理 首先我们看下面这个LSTM图&#xff0c; 对…