PaddleOCR学习(二)PaddleOCR检测模型训练

article/2025/10/29 18:18:34

这一部分主要介绍,如何使用自己的数据库去训练PaddleOCR的文本检测模型。

官方教程https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md

一、准备训练数据

首先你需要有自己的数据,如果没有自己的数据,推荐使用ICDAR2015的数据库,上网搜即可找到,内含1000个训练样本和500个测试样本,包括图片与标准数据(txt格式)。

如何标注自己的数据大家可以自行去网上搜索一下,PaddleOCR自带标注工具PPOCRLabel:https://github.com/PaddlePaddle/PaddleOCR/tree/develop/PPOCRLabel

不过因为我不是用PPOCRLabel进行的标注,而是采用了另一种更麻烦的方法进行标注,所以这里就不班门弄斧了,如果使用PPOCRLabel的过程中出了问题,也可以考虑采用我的方法:

(1)首先由于我的数据中涉及到了倾斜文本(弯曲文本我还没有了解过有没有什么特别好的检测模型,目前主流的检测模型可能也只到倾斜文本),所以我使用的是roLabelImg工具进行的标注;

(2)使用rolabelImg工具标注图片获得倾斜文本框,输出xml文件;

(3)将xml文件转换为txt文件,具体转换算法我放在本文最后xmltotxt.py:

需要注意的是,txt中的内容格式应该是: x 1 , y 1 , x 2 , y 2 , x 3 , y 3 , x 4 , y 4 , t e x t x_1,y_1,x_2,y_2,x_3,y_3,x_4,y_4,text x1,y1,x2,y2,x3,y3,x4,y4,text。对于roLablelImg标注的数据,角点坐标都保留两位小数,但是PaddleOCR中是按整数进行的计算,所以后面需要一点细微的修改。

此时获得,一张图片对应一个标注txt文件中的内容应该像以下内容:
在这里插入图片描述
(4)现在获得的应该是一个包含所有图片的文件夹与一个包含相同数量与图片同名txt文件的文件夹,接下来需要将该文件夹先分成训练用样本和测试用样本,为了后续方便,先新建以下结构的文件夹:

在这里插入图片描述
DatasetRes是我自己的数据集的名字,将标注好的数据按一定比例分别放进train_imgs和test_imgs中(具体的比例不好说,我也是新手,我觉得可以参考ICDAR的比例,训练:测试=2:1)。

然后,打开train_data/gen_label.py,修改其中的模式、图片路径、标注路径、输出结果路径:

gen_label的效果是,将所有标注txt,总合成一个总的txt文件,记得分别对测试数据和训练数据运行gen_label,获得两个label.txt文件。

切记,输出完之后,尽量不要修改文件夹或者txt文件的名称。

parser.add_argument('--mode',type=str,default="det",   # 模式help='Generate rec_label or det_label, can be set rec or det')parser.add_argument('--root_path',type=str,default="DatasetRes/test_imgs/",   # 图片help='The root directory of images.Only takes effect when mode=det ')parser.add_argument('--input_path',type=str,default="DatasetRes/test_txts/",   # 标注help='Input_label or input path to be converted')parser.add_argument('--output_label',type=str,default="DatasetRes/test_label.txt",  # 输出结果help='Output file name')

另外,gen_label.py中还有两个可能会坑人的地方,都在gen_det_label()函数中,一个是paddleocr对坐标的读取是int类型,如果使用roLabelImg标注,一般获得的是浮点类型的;另一点是gen_det_label()函数在读取文件名时,会自动把文件名的前三位忽视掉(不知道为什么,可能和不同方法获得的标注结果有关,总之会引起错误)。我把修改过的代码放在下面了。

def gen_det_label(root_path, input_dir, out_label):with open(out_label, 'w') as out_file:for label_file in os.listdir(input_dir):img_path = root_path + label_file[:-4] + ".jpg"      # 原先是label_file[3:-4]label = []with open(os.path.join(input_dir, label_file), 'r') as f:for line in f.readlines():tmp = line.strip("\n\r").replace("\xef\xbb\xbf","").split(',')points = tmp[:8]s = []for i in range(0, len(points), 2):b = points[i:i + 2]b = [int(float(t)) for t in b]     # 原来是b=[int(t) for t in b],无法读取小数s.append(b)result = {"transcription": tmp[8], "points": s}label.append(result)out_file.write(img_path + '\t' + json.dumps(label, ensure_ascii=False) + '\n')

如此,就把paddleocr检测模型训练需要的数据集准备好了。总的label.txt文件的内容大致像以下这样:
在这里插入图片描述

二、使用自己的数据集训练检测模型

终于把数据集准备好了,接下来就可以准备开始训练模型了,训练模型用到的是tools/train.py文件,不过没什么需要在这里面修改的。

首先,官方提供了三个backbone预训练模型,分别是MobileNetV3,ResNet8_vd,ResNet50_vd
https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_vd_pretrained.tar
https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar

非常好懂,就是ResNet50_vd非常非常大,没有四块以上GPU建议就不要尝试了。

新建pretrain_models/detect_pretrain_models文件夹,然后将下载的预训练模型解压到detect_pretrain_models下。
在这里插入图片描述
如果你去看教程,他会告诉你运行以下命令,然后你就会一脸懵逼发现什么都没有发生,所以我觉得还是需要再详细解释一下。

python3 tools/train.py -c configs/det/det_mv3_db_v1.1.yml \-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/ \2>&1 | tee train_det.log

实际最后运行的指令应该像这样即可,记得在cmd或者anaconda prompt中cd到paddleocr-develop目录下执行:

python tools/train.py -c configs/det/det_r18_vd_db_v1.1.yml 2>&1 | tee train_det.log

重点,在运行该指令前,打开configs/det/det_r18_vd_db_v1.1.yml进行修改。

# det_r18_vd_db_v1.1.ymlGlobal:algorithm: DB     # 使用的文本检测算法,这里用的是DB,我后来用的east,我将r18对应east的yml文件放在本文最后use_gpu: trueepoch_num: 1200log_smooth_window: 20print_batch_step: 2save_model_dir: ./output/det_r18_vd_db/     # 训练好的模型输出位置save_epoch_step: 200eval_batch_step: [3000, 2000]train_batch_size_per_card: 8test_batch_size_per_card: 1image_shape: [3, 640, 640]reader_yml: ./configs/det/det_db_icdar15_reader.yml       # 记住这个文件,接下来就要改它pretrain_weights: ./pretrain_models/detect_pretrain_models/ResNet18_vd_pretrained/  # 预训练模型的保存路径save_res_path: ./output/det_r18_vd_db/predicts_db.txt     # 预测结果文件的保存路径checkpoints:save_inference_dir:infer_img:
# det_db_icdar15_reader.ymlTrainReader:reader_function: ppocr.data.det.dataset_traversal,TrainReaderprocess_function: ppocr.data.det.east_process,EASTProcessTrainnum_workers: 4 # 量力而行,看自己电脑配置img_set_dir: ./train_data/ # 记得只要写这么长就行了,label.txt文件中,图片文件名包含了DatasetRes/train_imgs/xxx.jpglabel_file_path: ./train_data/DatasetReal/train_label.txt  # 刚才gen_label保存的文件路径background_ratio: 0.125min_crop_side_ratio: 0.1min_text_size: 10EvalReader:reader_function: ppocr.data.det.dataset_traversal,EvalTestReaderprocess_function: ppocr.data.det.east_process,EASTProcessTestimg_set_dir: ./train_data/ # 同理label_file_path: ./train_data/DatasetReal/test_label.txt  # 同理TestReader:reader_function: ppocr.data.det.dataset_traversal,EvalTestReaderprocess_function: ppocr.data.det.east_process,EASTProcessTestimg_set_dir: ./train_data/   # 同理label_file_path: ./train_data/DatasetReal/test_label.txt   # 同理do_eval: True

好了,都改好了,可以执行刚才的命令了:

python tools/train.py -c configs/det/det_r18_vd_db_v1.1.yml 2>&1 | tee train_det.log

训练时会将训练过程打印到train_det.log文件。
在这里插入图片描述

三、整理、评估训练结果

模型训练完之后,到det_r18_vd_db_v1.1.yml文件中的save_model_dir: ./output/det_r18_vd_db/位置去找训练结果,像这样:
在这里插入图片描述
具体每多少epoch输出一次可以在yml文件中设置,不多赘述。

接下来需要将模型转换为可部署文件,在paddleocr-develop目录下运行指令:

python tools/export_model.py -c configs/det/det_r18_vd_db_v1.1.yml -o Global.checkpoints="./output/det_r18_vd_db/best_accuracy" Global.save_inference_dir="./output/det_r18_vd_db/export_model"

记得根据自己的保存路径进行修改。./output/my_det_r18_vd_db/export_model中应该有两个文件:model和params。

如果训练程序中途断了,希望加载训练中断的模型继续训练,可以通过如下指令:

python tools/train.py -c configs/det/det_r18_vd_db_v1.1.yml -o Global.checkpoints="./output/det_r18_vd_db/best_accuracy"

好了,现在有了模型,如何评估模型的有效性可以自己去搜索学习一下,对于目标检测类算法,需要计算Precision、Recall、Hmean,运行以下代码即可:

python tools/eval.py -c configs/det/det_r18_vd_db_v1.1.yml -o Global.checkpoints="./output/det_r18_vd_db/best_accuracy"PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5

即可输出该模型的Precision、Recall、Hmean。

这里需要注意,上述指令是针对DB算法,如果你用的不是DB算法,而是EAST算法,指令需要有所不同,主要是在PostProcess中,EAST和DB的PostProcess的参数不同,所以进行评估时也需要输入不同的参数。如果是EAST算法,指令为:

python tools/eval.py -c configs/det/det_r18_east.yml -o Global.checkpoints="./output/det_east/best_accuracy"  # 自行注意文件夹的不同PostProcess.score_thresh=0.8 PostProcess.cover_thresh=0.1PostProcess.nms_thresh=0.2

最后是用训练好的模型去测试自己的图片看效果,在PaddleOCR学习(一)PaddleOCR安装与测试中我已经介绍过如何调用模型进行图片检测,只要将其中的det_model_dir的默认路径改到./output/det_r18_vd_db/export_model/即可。

不过其实,如果不输出成可部署文件,也可以直接进行图片测试,运行以下指令:

python tools/infer_det.py -c configs/det/det_r18_vd_db_v1.1.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_east/best_accuracy"

或者一次性测试一整个文件夹:

python tools/infer_det.py -c configs/det/det_r18_vd_db_v1.1.yml  -o Global.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_east/best_accuracy"

还可以在测试过程中调整后处理阈值

python tools/infer_det.py -c configs/det/det_r18_vd_db_v1.1.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_east/best_accuracy"PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5

OK,至此检测模型训练完毕,至于如何调参获取更好的训练结果,我也想知道(–_--)

附件

# xmltotxt.py# coding=utf-8import os
import xml.dom.minidom
import cv2 as cv
import mathdef xml_to_txt(indir, outdir):os.chdir(indir)xmls = os.listdir('.')for i, file in enumerate(xmls):file_save = file.split('.')[0] + '.txt'file_txt = os.path.join(outdir, file_save)f_w = open(file_txt, 'w')# actual parsingDOMTree = xml.dom.minidom.parse(file)annotation = DOMTree.documentElementfilename = annotation.getElementsByTagName("path")[0]imgname = filename.childNodes[0].dataimg_temp = imgname.split('\\')[-1]img_temp = os.path.join(image_dir, img_temp)image = cv.imread(imgname)
#        cv.imwrite(img_temp, image)objects = annotation.getElementsByTagName("object")print(file)for object in objects:bbox = object.getElementsByTagName("robndbox")[0]cx = bbox.getElementsByTagName("cx")[0]x = float(cx.childNodes[0].data)print(x)cy = bbox.getElementsByTagName("cy")[0]y = float(cy.childNodes[0].data)print(y)cw = bbox.getElementsByTagName("w")[0]w = float(cw.childNodes[0].data)print(w)ch = bbox.getElementsByTagName("h")[0]h = float(ch.childNodes[0].data)print(h)cangel = bbox.getElementsByTagName("angle")[0]angle = float(cangel.childNodes[0].data)print(angle)cname = object.getElementsByTagName("name")[0]name = cname.childNodes[0].dataprint(name)x1, y1 = rotatePoint(x, y, x - w / 2, y - h / 2, -angle)x2, y2 = rotatePoint(x, y, x + w / 2, y - h / 2, -angle)x3, y3 = rotatePoint(x, y, x + w / 2, y + h / 2, -angle)x4, y4 = rotatePoint(x, y, x - w / 2, y + h / 2, -angle)temp = str('%.2f' % x1) + ',' + str('%.2f' % y1) + ',' + str('%.2f' % x2) + ',' + str('%.2f' % y2) + ',' + \str('%.2f' % x3) + ',' + str('%.2f' % y3) + ',' + \str('%.2f' % x4) + ',' + str('%.2f' % y4) + ',' + name + '\n'f_w.write(temp)f_w.close()# 转换成四点坐标
def rotatePoint(xc, yc, xp, yp, theta):xoff = xp - xc;yoff = yp - yc;cosTheta = math.cos(theta)sinTheta = math.sin(theta)pResx = cosTheta * xoff + sinTheta * yoffpResy = - sinTheta * xoff + cosTheta * yoffreturn xc + pResx, yc + pResyif __name__ == '__main__':image_dir = "./origin_png"  # img目录indir = "./xml"  # xml目录outdir = "./txt"xml_to_txt(indir, outdir)
# det_r18_vd_east.ymlGlobal:algorithm: EAST   # EAST算法是目前比较优秀的文本检测算法use_gpu: trueepoch_num: 1000log_smooth_window: 20print_batch_step: 2save_model_dir: ./output/det_east_real/save_epoch_step: 200eval_batch_step: [3000, 2000]train_batch_size_per_card: 8test_batch_size_per_card: 1image_shape: [3, 512, 512]reader_yml: ./configs/det/det_east_icdar15_reader.ymlpretrain_weights: ./pretrain_models/detect_pretrain_models/ResNet18_vd_pretrained/save_res_path: ./output/det_east_real/predicts_east.txtcheckpoints:save_inference_dir:infer_img:Architecture:function: ppocr.modeling.architectures.det_model,DetModelBackbone:function: ppocr.modeling.backbones.det_resnet_vd,ResNetlayers: 18Head:function: ppocr.modeling.heads.det_east_head,EASTHeadmodel_name: largeLoss:function: ppocr.modeling.losses.det_east_loss,EASTLossOptimizer:function: ppocr.optimizer,AdamDecaybase_lr: 0.001beta1: 0.9beta2: 0.999PostProcess:function: ppocr.postprocess.east_postprocess,EASTPostPocessscore_thresh: 0.8       # 记住这几个参数,后面有用cover_thresh: 0.1nms_thresh: 0.2

http://chatgpt.dhexx.cn/article/gCRl9IOZ.shtml

相关文章

迁移学习的模型训练

用深度学习解决目标检测有两个重要工作: 1、设计、实现、训练和验证模型 模型如果设计模型如何编程实现如何收集足够的数据来训练并验证模型是否符合预期 从头开始设计、实现、训练和验证模型是需要有众多深度学习算法人才做支撑,并且极其耗时耗力 2、…

TF2.0模型训练

TF2.0模型训练 概述数据集介绍1、通过fit方法训练模型准备数据创建模型编译模型训练模型 2、通过fit_generator方法训练模型构建生成器创建模型编译模型训练模型 3、自定义训练准备数据创建模型定义损失函数及优化器训练模型 下一篇TF2.0模型保存 概述 这是TF2.0入门笔记【TF2…

TensorFlow 2.0 —— 模型训练

目录 1、Keras版本模型训练1.1 构造模型(顺序模型、函数式模型、子类模型)1.2 模型训练:model.fit()1.3 模型验证:model.evaluate()1.4 模型预测:model.predict()1.5 使用样本加权和类别加权1.6 回调函数1.6.1 EarlySt…

如何在jupyter上运行Java代码(适用LINUX)

如何在jupyter上运行Java代码 1.下载必须软件 下载JDK且JDK版本必须 ≥ 9 ≥9 ≥9从github上下载ijava 附 : ijava下载链接.装有jupyter,我在LINUX上是直接装的anaconda 安装过程 将下载的ijava压缩包解压出来,并在此路径用该命令 : sudo…

Java单元测试介绍

文章目录 单元测试单元测试基本介绍单元测试快速入门单元测试常用注解 单元测试 单元测试基本介绍 单元测试: 单元测试就是针对最小的功能单元编写测试代码,Java程序最小的功能单元是方法,因此,单元测试就是针对Java方法的测试,…

Jupyter 配置 Java环境,写Java代码,测试成功

本次简单诉说下怎么通过jupyter安装iJava,写Java代码。 安装Java的不说了 我使用的是Java15 然后去:https://github.com/SpencerPark/IJava/releases 下载zip,不要下载其他的 得到就是一个py文件 下面就是一个 python install.py 我这里就…

java调用python执行脚本,附代码

最近有个功能需要java调用python脚本实现一些功能,前期需要做好的准备:配置好python环境,如下: 以下展示的为两种,一种为生成图片,另一种为生成字符串。 package com.msdw.tms.common.utils.py;import ja…

Selenium Java自动化测试环境搭建

IDE用的是Eclipse。 步骤1:因为是基于Java,所以首先要下载与安装JDK(Java Development Kit) 下载: 点击这里下载JDK 安装:按照默认安装一路点next就可以了。 验证:安装完成后,在命…

java单元测试(Junit)

相关代码下载链接: http://download.csdn.net/detail/stevenhu_223/4884357 在有些时候,我们需要对我们自己编写的代码进行单元测试(好处是,减少后期维护的精力和费用),这是一些最基本的模块测试。当然&…

Java单元测试工具:JUnit4(一)——概述及简单例子

(一)JUnit概述及一个简单例子 看了慕课网的JUnit视频教程: http://www.imooc.com/learn/356,总结笔记。 这篇笔记记录JUnit的概述,以及一个快速入门的例子。 1.概述 1.1 什么是JUnit ①JUnit是用于编写可复用测试集的…

Linux下执行Python脚本

1.Linux Python环境 Linux系统一般集成Python,如果没有安装,可以手动安装,联网状态下可直接安装。Fedora下使用yum install,Ubuntu下使用apt-get install,前提都是root权限。安装完毕,可将Python加入环境变…

python pytest脚本执行工具

pytest脚本执行工具 支持获取当前路径下所有.py脚本 添加多个脚本,一起执行 import tkinter as tk from tkinter import filedialog import subprocess import os from datetime import datetimedef select_script():script_path filedialog.askopenfilename(fil…

linux上运行python(简单版)

linux上运行python(简单版) 一、前提准备1.centOS72.挂载yum源[http://t.csdn.cn/Isf0i](http://t.csdn.cn/Isf0i) 二、安装python3三、运行程序 一、前提准备 1.centOS7 2.挂载yum源http://t.csdn.cn/Isf0i 在终端进行安装python3 二、安装python3 …

linux怎么运行python脚本?

linux运行python脚本的方法: 1、命令行执行: 建立一个test.py文档,在其中书写python代码。之后,在命令行执行:python test.py 说明:其中python可以写成python的绝对路径。使用which python进行查询。 注…

java实现远程执行Linux下的shell脚本

java实现远程执行Linux下的shell脚本 背景导入Jar包第一步:远程连接第二步:开启Session第三步:新建测试脚本文件结果报错 背景 最近有个项目,需要在Linux下的服务器内写了一部分Python脚本,业务处理却是在Java内&…

Java运行Python脚本

前段时间遇到了在JavaWeb项目中嵌入运行Python脚本的功能的需求。想到的方案有两种,一种是使用Java技术(Jython或Runtime.exec)运行Python脚本,另一种是搭建一个Python工程对外提供相应http或webservice接口。两种方案我都有实现&…

Java项目分层

MVC模式 在实际的开发中有一种项目的程序组织架构方案叫做MVC模式,按照程序 的功能将他们分成三个层,如下图:Modle层(模型层)、View层(显示层)、Controller层(控制层)。…

java项目收获总结_java开发项目收获心得

1 java开发项目收获心得 it行业现在的发展如日中天,很多人都纷纷走进这个行业,而java作为跨平台的编程语言更是受欢迎。java其实相对其他语言来说的确很有优势,但是也有点缺陷,但是以后发展到什么程度,谁都不知道。那么下面小编给大家说说java开发项目收获心得,希望能对你…

java查看jar包依赖_java项目开发中如何查找到项目依赖的jar包?

不管是java普通工程,还是java web项目,甚至是android项目,依赖包的管理有2种: 1.直接依赖jar包 这种方式简单直白,项目下载后在正确的ide或者稍微做转换就可以运行起来。比如java web工程的WEB-INF/lib下 只要按这个步骤Java Build Path=>Add Libraty=>Web App Libr…

Java小白必看:开发一个编程项目的完整流程(附100套Java编程项目源码+视频)

我相信很多Java新手都会遇到这样一个问题:跟着教材敲代码,很容易;但是让他完整的实现一个应用项目,却不会;不知道从哪里开始,不知道实现一个项目的完整流程是怎样的,看似很简单的一个问题&#…