【PaddleOCR-det-finetune】一:基于PPOCRv3的det检测模型finetune训练

article/2025/9/24 19:22:20

文章目录

  • 基本流程
  • 详细步骤
    • 打标签,构建自己的数据集
    • 下载PPOCRv3训练模型
    • 修改超参数,训练自己数据集
    • 启动训练
    • 导出模型
  • 测试

相关参考手册在PaddleOCR项目工程中的位置:
det模型训练和微调:PaddleOCR\doc\doc_ch\PPOCRv3_det_train.md
模型微调PaddleOCR\doc\doc_ch\finetune.md

在手册PPOCRv3_det_train.md中,提到

finetune训练适用于三种场景

  • 基于CML蒸馏方法的finetune训练,适用于教师模型在使用场景上精度高于PPOCRv3检测模型,且希望得到一个轻量检测模型。
  • 基于PPOCRv3轻量检测模型的finetune训练,无需训练教师模型,希望在PPOCRv3检测模型基础上提升使用场景上的精度。
  • 基于DML蒸馏方法的finetune训练,适用于采用DML方法进一步提升精度的场景。

由于第二种工程量最小,本篇中博客中,我记录的是第二种:
基于PPOCRv3轻量检测模型的finetune训练,无需训练教师模型,希望在PPOCRv3检测模型基础上提升使用场景上的精度。
的det模型finetune过程

也就是使用自己的数据集,在PPOCRv3预训练模型上做微调,提升垂类场景效果

基本流程

  • 首先使用PPOCRLabel工具,打标签,构造基于自己垂类场景的数据集
  • 根据自己数据集的性质和场景需求,修改训练的配置文件configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml参数
  • 然后基于下载下来的学生模型student.pdparams进行训练

详细步骤

打标签,构建自己的数据集

使用PPOCRLabel,指路: 【PaddleOCR-PPOCRLabel】标注工具使用,这篇博客详细说过了

下载PPOCRv3训练模型

在PaddleOCR\doc\doc_ch\finetune.md中的教学:
提取Student参数的方法如下……

但其实下载下来模型已经有提取好了的,所以就不用自己提取了

这里提取学生模型参数,在我看来就是获取准备拿来微调的det模型
参数模型就是student.pdparams这个文件,下载下来就有

#在项目根目录
mkdir student
cd student
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
tar xf ch_PP-OCRv3_det_distill_train.tar

在这里插入图片描述

修改超参数,训练自己数据集

对于其中configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml的参数,需要安装训练的实际数据集中训练集和验证集的位置,在yml文件中修改对应txt路径,具体参数说明,见表和下方注释

参数名称类型默认值含义
det_db_threshfloat0.3DB输出的概率图中,得分大于该阈值的像素点才会被认为是文字像素点
det_db_box_threshfloat0.6检测结果边框内,所有像素点的平均得分大于该阈值时,该结果会被认为是文字区域
det_db_unclip_ratiofloat1.5Vatti clipping算法的扩张系数,使用该方法对文字区域进行扩张
max_batch_sizeint10预测的batch size
use_dilationboolFalse是否对分割结果进行膨胀以获取更优检测效果
det_db_score_modestr“fast”DB的检测结果得分计算方法,支持fastslowfast是根据polygon的外接矩形边框内的所有像素计算平均得分,slow是根据原始polygon内的所有像素计算平均得分,计算速度相对较慢一些,但是更加准确一些。

更多参数可以在PaddleOCR\doc\doc_ch\inference_args.md《PaddleOCR模型推理参数解释》里面找到
我修改后文件命名为ch_PP-OCRv3_det_student_3.7.yml

Global:debug: falseuse_gpu: trueepoch_num: 135  # 总的epoch数目log_smooth_window: 20print_batch_step: 10save_model_dir: ./output/ch_PP-OCR_V3_det_11_9/save_epoch_step: 100  # 每100个global_step 保存一次模型eval_batch_step:   # 每200个global_step 验证一次模型- 0- 200 # 400cal_metric_during_train: false     # 设置是否在训练过程中评估指标,此时评估的是模型在当前batch下的指标pretrained_model: nullcheckpoints: nullsave_inference_dir: ./output/det_db_inference/ # nulluse_visualdl: True  # falseinfer_img:  DATA2/predict01.jpg    # doc/imgs_en/img_10.jpgsave_res_path: ./output/det19/predicts_ppocrv3_distillation.txt   # ./checkpoints/det_db/predicts_db.txt
#  save_res_path: ./output/det2/predicts_ppocrv3_distillation.txt   # ./checkpoints/det_db/predicts_db.txtdistributed: trueArchitecture:model_type: det # 网络类型algorithm: DB # 模型名称Transform: # 设置变换方式Backbone:name: MobileNetV3scale: 0.5model_name: large # 网络大小disable_se: TrueNeck:name: RSEFPNout_channels: 96shortcut: TrueHead:name: DBHeadk: 50 # DBHead二值化系数Loss:name: DBLossbalance_loss: true # DBLossloss中是否对正负样本数量进行均衡(使用OHEM)main_loss_type: DiceLoss # DBLossloss中shrink_map所采用的的lossalpha: 5 # DBLossloss中shrink_map_loss的系数beta: 10 # DBLossloss中threshold_map_loss的系数ohem_ratio: 3
Optimizer: # 主要修改部分name: Adambeta1: 0.9beta2: 0.999lr: # 设置学习率下降方式name: Cosine # 使用cosine下降策略learning_rate: 0.00005  # 0.001warmup_epoch: 2regularizer: # 正则化name: L2factor: 5.0e-05 # 正则化系数
PostProcess:name: DBPostProcessthresh: 0.42   # 输出的概率图中,得分大于该阈值的像素点才会被认为是文字像素点box_thresh: 0.52  # 检测结果边框内,所有像素点的平均得分大于该阈值时,该结果会被认为是文字区域max_candidates: 1000unclip_ratio: 2.6  # 算法的扩张系数,使用该方法对文字区域进行扩张
Metric:name: DetMetricmain_indicator: hmean
Train:dataset:name: SimpleDataSetdata_dir: ./train_data/det/train/ # ./train_data/icdar2015/text_localization/label_file_list:- ./train_data/det/train0.txt # ./train_data/icdar2015/text_localization/train_icdar2015_label.txt- ./train_data/det/train1.txt- ./train_data/det/train2.txt- ./train_data/det/train3.txtratio_list: [1.0, 1.0, 1.0, 1.0]
#    ratio_list: [1.0]transforms:- DecodeImage:img_mode: BGRchannel_first: false- DetLabelEncode: null- IaaAugment:augmenter_args:- type: Fliplr # 翻转args:p: 0.5- type: Affine # 仿射args:rotate:- -10- 10- type: Resize # 调整大小args:size:- 0.5- 3- EastRandomCropData:size:- 960- 960max_tries: 50keep_ratio: true- MakeBorderMap:shrink_ratio: 0.4thresh_min: 0.3thresh_max: 0.7- MakeShrinkMap:shrink_ratio: 0.4min_text_size: 8- NormalizeImage:   # 图像归一化scale: 1./255.  # 线性变换参数mean:- 0.485- 0.456- 0.406std:- 0.229- 0.224- 0.225order: hwc- ToCHWImage: null- KeepKeys:keep_keys:- image- threshold_map- threshold_mask- shrink_map- shrink_maskloader:shuffle: truedrop_last: falsebatch_size_per_card: 2num_workers: 0 # 4
Eval:dataset:name: SimpleDataSetdata_dir: ./train_data/det/val/ # ./train_data/icdar2015/text_localization/label_file_list:- ./train_data/det/val0.txt # ./train_data/icdar2015/text_localization/test_icdar2015_label.txt\- ./train_data/det/val1.txt- ./train_data/det/val2.txt- ./train_data/det/val3.txt
#    ratio_list: [1.0, 1.0, 1.0, 1.0]   #
#    ratio_list: [1.0]transforms:- DecodeImage:img_mode: BGRchannel_first: false- DetLabelEncode: null- DetResizeForTest: null
#        image_shape:
#        - 736
#        - 736
#        resize_long: 960
#        limit_side_len: 736
#        limit_type: min
#        keep_ratio: true- NormalizeImage:scale: 1./255.mean:- 0.485- 0.456- 0.406std:- 0.229- 0.224- 0.225order: hwc- ToCHWImage: null- KeepKeys:keep_keys:- image- shape- polys- ignore_tagsloader:shuffle: falsedrop_last: falsebatch_size_per_card: 1num_workers: 0 # 2

其中的label_file_list参数对应的txt,记得修改成服务器保存数据的实际路径
如果有多个txt,可以用逗号并列
在这里插入图片描述

启动训练

# 单卡训练
python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student_3.7.yml -o Global.pretrained_model=student/ch_PP-OCRv3_det_distill_train/student.pdparams # 如果要使用多GPU分布式训练,请使用如下命令:
python3  -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml \-o Global.pretrained_model=./student \Global.save_model_dir=./output/

注意写对yml文件里面的数据集和label文件路径,以及ratio_list: [1.0]不然可能会报错:


AssertionError: The length of ratio_list should be the same as the file_list.

导出模型

我训练了3h,训练模型格式还要进行export为推理模型格式,才可用例程代码推理

python tools/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student_3.7.yml -o Global.pretrained_model=output/ch_PP-OCR_V3_det_3.7/best_accuracy.pdparams

在这里插入图片描述

测试

与微调前的v3模型相比,进行det推理测试,看看自训练模型效果是否有改善


http://chatgpt.dhexx.cn/article/inEkEDRD.shtml

相关文章

模型微调(Finetune)

参考:https://zhuanlan.zhihu.com/p/35890660 ppt下载地址:https://github.com/jiangzhubo/What-is-Fine-tuning 一.什么是模型微调 给定预训练模型(Pre_trained model),基于模型进行微调(Fine Tune)。相…

fine-tuning

微调(fine-tuning) 在平时的训练中,我们通常很难拿到大量的数据,并且由于大量的数据,如果一旦有调整,重新训练网络是十分复杂的,而且参数不好调整,数量也不够,所以我们可…

大模型的三大法宝:Finetune, Prompt Engineering, Reward

编者按:基于基础通用模型构建领域或企业特有模型是目前趋势。本文简明介绍了最大化挖掘语言模型潜力的三大法宝——Finetune, Prompt Engineering和RLHF——的基本概念,并指出了大模型微调面临的工具层面的挑战。 以下是译文,Enjoy! 作者 | B…

RCNN网络源码解读(Ⅲ) --- finetune训练过程

目录 0.回顾 1.finetune二分类代码解释(finetune.py) 1.1 load_data(定义获取数据的方法) 1.2 CustomFineTuneDataset类 1.3 custom_batch_sampler类( custom_batch_sampler.py) 1.4 训练train_mod…

FinSH

finSH介绍 FinSH 是 RT-Thread 的命令行组件,提供一套供用户在命令行调用的操作接口,主要用于调试或查看系统信息。它可以使用串口 / 以太网 / USB 等与 PC 机进行通信。 命令执行过程 功能: 支持鉴权,可在系统配置中选择打开/关闭。(TODO…

从统一视角看各类高效finetune方法

每天给你送来NLP技术干货! 来自:圆圆的算法笔记 随着预训练模型参数量越来越大,迁移学习的成本越来越高,parameter-efficient tuning成为一个热点研究方向。在以前我们在下游任务使用预训练大模型,一般需要finetune模型…

finetune

finetune的含义是获取预训练好的网络的部分结构和权重,与自己新增的网络部分一起训练。下面介绍几种finetune的方法。 完整代码:https://github.com/toyow/learn_tensorflow/tree/master/finetune 一,如何恢复预训练的网络 方法一&#xf…

11.2 模型finetune

一、Transform Learning 与 Model Finetune 二、pytorch中的Finetune 一、Transfer Learning 与 Model Finetune 1. 什么是Transfer Learning? 迁移学习是机器学习的一个分支,主要研究源域的知识如何应用到目标域当中。迁移学习是一个很大的概念。 怎么理解源域…

飞桨深度学习学院零基础深度学习7日入门-CV疫情特辑学习笔记(四)DAY03 车牌识别

本课分为理论和实战两个部分 理论:卷积神经网络 1.思考全连接神经网络的问题 一般来收机器学习模型实践分为三个步骤,(1)建立模型 (2)选择损失函数 (3)参数调整学习 1.1 模型结构不…

unity sdk(android)-友盟推送SDK接入

注意:一开始想接友盟Unity的SDk,但是导入后缺少各种jar,所以最后还是接了android的,demo文档齐全 官方文档:开发者中心 按照官方文档对接即可, 接入流程 1、项目中com.android.tools.build:gradle配置&…

友盟推送学习

一、首次使用U_Push 1、首先注册友盟账号,进入工作台,选择产品U_Push。 2、创建应用 3、在自己的项目中自动集成SDK 开发环境要求: Android Studio 3.0以上 Android minSdkVersion: 14 Cradle: 4.4以上 在根目录build.gradle中添加mav…

Android 学习之如何集成友盟推送

我是利用Android studio 新建一个空的Android项目。 步骤一 导入第三方库 1.切换Android项目状态为Project状态 2.在main文件下新建 jniLibs文件夹(用来导入PushSDK项目下lib文件中的so文件) 3.在libs文件夹下添加友盟PuskSDK中的 jar 文件&#xff…

用PaddlePaddle(飞浆)实现车牌识别

项目描述:本次实践是一个多分类任务,需要将照片中的每个字符分别进行识别,完成车牌的识别 实践平台:百度AI实训平台-AI Studio、PaddlePaddle1.8.0 动态图 数据集介绍(自己去网上下载车牌识别数据集) 数据…

深度学习(五) CNN卷积神经网络

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 CNN卷积神经网络 前言一、CNN是什么?二、为什么要使用CNN?三、CNN的结构1.图片的结构2.卷积层1.感受野(Receptive Field)2.卷积…

CNN网络实现手写数字(MNIST)识别 代码分析

CNN网络实现手写数字(MNIST)识别 代码分析(自学用) Github代码源文件 本文是学习了使用Pytorch框架的CNN网络实现手写数字(MNIST)识别 #导入需要的包 import numpy as np //第三方库,用于进行科学计算 import torc…

Android删除chartty证书,C/C++知识点之android应用安全分析

本文主要向大家介绍了C/C知识点之android应用安全分析,通过具体的内容向大家展示,希望对大家学习C/C知识点有所帮助。 应用名 :OKEx(OKEx-android.apk) 包名 :com.okinc.okex MD5 :1ffbd328d13e91b661592cdf58516bd2 版…

代码编写过程 - 正确率折线图

获取绘图函数 首先,看到需要画acc和loss图。先去参考现成的,于是打开猫12分类,找到生成折线图的地方。 发现框内的两个函数绘制了折线图。既然是作为函数出现,说明已经有一定的封装,考虑能不能把整个函数搬走用。 由…

李宏毅机器学习课程HW03代码解释

作业3任务是将图片进行分类 从官网上下载数据到data文件里面。此外,将代码分为三个模块,分别是dataset,model以及main。 一、dataset模块 此模块作用是读取图片数据。 重要函数:os.path.join(path,x) 将path和x路径组合在一起 #导入库…

接入友盟厂商push通道遇到的坑

目录 调试友盟Push问题的检查清单 客户端、服务端数据协议 客户端接入方式 小米厂商通道 华为厂商通道 魅族厂商通道 VIVO厂商通道 OPPO厂商通道 支持桌面角标的厂商 吐槽一下集成友盟厂商通道时发现的问题 调试友盟Push问题的检查清单 过滤UmengPushAgent开头的日志…

Android集成友盟消息推送SDK

消息推送SDK快速集成: 申请AppKey ——> 接入Push SDK ——> 基础接口引入 ——> 消息推送测试 ——> 接入完成 1.申请AppKey 2.接入Push SDK 1)、加入依赖 //友盟push相关依赖(必须)implementationcom.umeng.umsdk:push:6.1.0impleme…