基于CRNN的文本识别

article/2025/8/23 23:49:35

文章目录

    • 0. 前言
    • 1. 数据集准备
    • 2.构建网络
    • 3.数据读取
    • 4.训练模型

0. 前言

至于CRNN网络的细节这里就不再多言了,网上有很多关于crnn的介绍,这里直接讲一下代码的实现流程

1. 数据集准备

CRNN是识别文本的网络,所以我们首先需要构建数据集,使用26个小写字母以及0到9十个数字,一共有36个字符,从这36个字符中随机选择4到9个字符(这里要说明一下,网上很多关于crnn的训练集中每张图片中的字符个数是一样的,这就具有很大的局限性。所以本文使用4到9随机选择字符个数构建图片。)

生成数据集代码如下:

import cv2
import numpy as np
import random
import imgaug.augmenters as iaadef get_img():zfu=['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z','0','1','2','3','4','5','6','7','8','9']# zfu=[str(i) for i in range(10)]# zfu=[str(i) for i in range(10)]k=random.randint(4,9)select=random.choices(zfu,k=k)lab=[zfu.index(i) for i in select]select="".join(select)font=cv2.FONT_HERSHEY_COMPLEXsrc=np.ones(shape=(50,250,3)).astype('uint8')*255src=cv2.putText(src,select,(20,27),font,1,(0,0,0),2)seq = iaa.Sequential([# iaa.Flipud(0.5),  # flip up and down (vertical)# iaa.Fliplr(0.5),  # flip left and right (horizontal)iaa.Multiply((0.5, 1.5)),  # change brightness, doesn't affect BBs(bounding boxes)iaa.GaussianBlur(sigma=(0, 1.0)),  # 标准差为0到3之间的值iaa.Crop(percent=(0, 0.06)),iaa.Grayscale(alpha=(0, 1)),iaa.Affine(#translate_px={"x": (0, 15), "y": (0, 15)},  # 平移scale=(0.95, 1.05),  # 尺度变换mode=iaa.ia.ALL,cval=(100, 255)),iaa.Resize({"height": 32, "width": 200})])# img是numpy格式,无归一化src=np.expand_dims(src,axis=0)src = seq(images=src)[0]# cv2.imshow('a21',src)# cv2.waitKey(0)return src,labf_train=open('train.txt','w')
f_val=open('val.txt','w')for i in range(10000):img,lab=get_img()lab=[str(i) for i in lab]lab=" ".join(lab)path='train_data/'+str(i)+'.jpg'cv2.imwrite(path,img)f_train.write(path+' '+lab+'\n')print(i)
for i in range(1000):img,lab=get_img()lab=[str(i) for i in lab]lab=" ".join(lab)path='val_data/'+str(i)+'.jpg'cv2.imwrite(path,img)f_val.write(path+' '+lab+'\n')print(i)

运行上述代码之前首先需要手动新建两个空文件夹用于存放训练图像和验证图像,文件夹名字分别是:train_data和val_data。运行完上述代码以后会在train_data文件夹中保存10000张训练图像,在val_data文件夹中保存1000张验证图像。此外还会生成两个txt文件,分别为train.txt和val.txt。
txt文本中存放的是图片的路径及包含字符的类别,如下所示:

在这里插入图片描述
部分训练图像如下所示:

在这里插入图片描述

2.构建网络

构建crnn网络的代码如下所示:

# crnn.py
import argparse, os
import torch
import torch.nn as nnclass BidirectionalLSTM(nn.Module):def __init__(self, nInput_size, nHidden, nOut):super(BidirectionalLSTM, self).__init__()self.lstm = nn.LSTM(nInput_size, nHidden, bidirectional=True)self.linear = nn.Linear(nHidden * 2, nOut)def forward(self, input):recurrent, (hidden, cell) = self.lstm(input)T, b, h = recurrent.size()t_rec = recurrent.view(T * b, h)output = self.linear(t_rec)  # [T * b, nOut]output = output.view(T, b, -1)  # 输出变换为[seq,batch,类别总数]return outputclass CNN(nn.Module):def __init__(self, imageHeight, nChannel):super(CNN, self).__init__()assert imageHeight % 32 == 0, 'image Height has to be a multiple of 32'self.depth_conv0 = nn.Conv2d(in_channels=nChannel, out_channels=nChannel, kernel_size=3, stride=1, padding=1,groups=nChannel)self.point_conv0 = nn.Conv2d(in_channels=nChannel, out_channels=64, kernel_size=1, stride=1, padding=0,groups=1)self.relu0 = nn.ReLU(inplace=True)self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)self.depth_conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, groups=64)self.point_conv1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, stride=1, padding=0, groups=1)self.relu1 = nn.ReLU(inplace=True)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.depth_conv2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, groups=128)self.point_conv2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1)self.batchNorm2 = nn.BatchNorm2d(256)self.relu2 = nn.ReLU(inplace=True)self.depth_conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=256)self.point_conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1, padding=0, groups=1)self.relu3 = nn.ReLU(inplace=True)self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))self.depth_conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, groups=256)self.point_conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1)self.batchNorm4 = nn.BatchNorm2d(512)self.relu4 = nn.ReLU(inplace=True)self.depth_conv5 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, groups=512)self.point_conv5 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1)self.relu5 = nn.ReLU(inplace=True)self.pool5 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))# self.conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=2, stride=1, padding=0)self.depth_conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=2, stride=1, padding=0, groups=512)self.point_conv6 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0, groups=1)self.batchNorm6 = nn.BatchNorm2d(512)self.relu6 = nn.ReLU(inplace=True)def forward(self, input):depth0 = self.depth_conv0(input)point0 = self.point_conv0(depth0)relu0 = self.relu0(point0)pool0 = self.pool0(relu0)# print(pool0.size())depth1 = self.depth_conv1(pool0)point1 = self.point_conv1(depth1)relu1 = self.relu1(point1)pool1 = self.pool1(relu1)# print(pool1.size())depth2 = self.depth_conv2(pool1)point2 = self.point_conv2(depth2)batchNormal2 = self.batchNorm2(point2)relu2 = self.relu2(batchNormal2)# print(relu2.size())depth3 = self.depth_conv3(relu2)point3 = self.point_conv3(depth3)relu3 = self.relu3(point3)pool3 = self.pool3(relu3)# print(pool3.size())depth4 = self.depth_conv4(pool3)point4 = self.point_conv4(depth4)batchNormal4 = self.batchNorm4(point4)relu4 = self.relu4(batchNormal4)# print(relu4.size())depth5 = self.depth_conv5(relu4)point5 = self.point_conv5(depth5)relu5 = self.relu5(point5)pool5 = self.pool5(relu5)# print(pool5.size())depth6 = self.depth_conv6(pool5)point6 = self.point_conv6(depth6)batchNormal6 = self.batchNorm6(point6)relu6 = self.relu6(batchNormal6)# print(relu6.size())return relu6class CRNN(nn.Module):def __init__(self, imgHeight, nChannel, nClass, nHidden):super(CRNN, self).__init__()self.cnn = nn.Sequential(CNN(imgHeight, nChannel))self.lstm = nn.Sequential(BidirectionalLSTM(512, nHidden, nHidden),BidirectionalLSTM(nHidden, nHidden, nClass),)def forward(self, input):conv = self.cnn(input)# pytorch框架输出结构为BCHWbatch, channel, height, width = conv.size()assert height == 1, "the output height must be 1."# 将height==1的维度去掉-->BCWconv = conv.squeeze(dim=2)# 调整各个维度的位置(B,C,W)->(W,B,C),对应lstm的输入(seq,batch,input_size)conv = conv.permute(2, 0, 1)output = self.lstm(conv)return outputif __name__ == "__main__":x = torch.rand(1, 1, 32, 100)model = CRNN(imgHeight=32, nChannel=1, nClass=11, nHidden=256)y = model(x)print(y.shape)

3.数据读取

读取训练数据的代码如下所示:

import os
import torch
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import imgaug.augmenters as iaaclass CRNNDataSet(Dataset):def __init__(self, lines,train=True,img_width=100):super(CRNNDataSet, self).__init__()self.lines=linesself.train=trainself.img_width=img_widthself.T=img_width//4+1def __getitem__(self, index):image_path = self.lines[index].strip().split()[0]label = self.lines[index].strip().split()[1:]image = cv2.imread(image_path,0)# 图像预处理if self.train:image=self.get_random_data(image)else:image = cv2.resize(image,(self.img_width,32))# cv2.imshow('a21',image)# cv2.waitKey(0)# 标签格式转换为IntTensorlabel_max=np.ones(shape=(self.T),dtype=np.int32)*-1label = np.array([int(i) for i in label])label_max[0:len(label)]=label#归一化image=(image/255.).astype('float32')image=np.expand_dims(image,axis=0)image=torch.from_numpy(image)label_max=torch.from_numpy(label_max)return image, label_maxdef __len__(self):return len(self.lines)def get_random_data(self,img):"""随机增强图像"""seq = iaa.Sequential([iaa.Multiply((0.8, 1.3)),  # change brightness, doesn't affect BBs(bounding boxes)iaa.GaussianBlur(sigma=(0, 1.0)),  # 标准差为0到3之间的值iaa.Crop(percent=(0, 0.05)),iaa.Affine(scale=(0.95, 1.05),  # 尺度变换rotate=(-4, 4),cval=(100,250),mode=iaa.ia.ALL),iaa.Resize({"height": 32, "width": self.img_width})])img=seq.augment(image=img)return imgif __name__ == '__main__':batch_size = 16lines=open('train.txt','r').readlines()trainData = CRNNDataSet(lines=lines)trainLoader=DataLoader(dataset=trainData,batch_size=batch_size)for data, label in trainLoader:print(data.shape,label)

4.训练模型

训练代码如下所示:

from model import CRNN
from mydataset import CRNNDataSet
from torch.utils.data import DataLoader
import torch
from torch import optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as pltdef decode(preds):char_set = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n','o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z','0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] + [" "]preds=list(preds)pred_text = ''for i,j in enumerate(preds):if j==n_class-1:continueif i==0:pred_text+=char_set[j]continueif preds[i-1]!=j:pred_text += char_set[j]return pred_text
def getAcc(preds,labs):acc=0char_set = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n','o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z','0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] + [" "]labs=labs.cpu().detach().numpy()preds = preds.cpu().detach().numpy()preds=np.argmax(preds,axis=-1)preds=np.transpose(preds,(1,0))out=[]for pred in preds:out_txt=decode(pred)out.append(out_txt)ll=[]for lab in labs:a=lab[lab!=-1]b=[char_set[i] for i in a]b="".join(b)ll.append(b)for a1,a2 in zip(out,ll):if a1==a2:acc+=1return acc/batch_sizebatch_size=128
n_class = 37train_lines=open('train.txt','r').readlines()
val_lines=open('val.txt','r').readlines()
trainData = CRNNDataSet(lines=train_lines,train=True,img_width=200)
trainLoader = DataLoader(dataset=trainData, batch_size=batch_size, shuffle=True, num_workers=1)
valData = CRNNDataSet(lines=val_lines,train=False,img_width=200)
valLoader = DataLoader(dataset=valData, batch_size=batch_size, shuffle=False, num_workers=1)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = CRNN(imgHeight=32, nChannel=1, nClass=n_class, nHidden=256)
net=net.to(device)loss_func = torch.nn.CTCLoss(blank=n_class - 1)  # 注意,这里的CTCLoss中的 blank是指空白字符的位置,在这里是第65个,也即最后一个
optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.5, 0.999))
#学习率衰减
lr_scheduler  = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)#画图列表
trainLoss=[]
valLoss=[]
trainAcc=[]
valAcc=[]
if __name__ == '__main__':#设置迭代次数200次Epoch=100epoch_step = len(train_lines) / batch_sizefor epoch in range(1, Epoch + 1):net.train()train_total_loss = 0val_total_loss=0train_total_acc = 0val_total_acc = 0with tqdm(total=epoch_step, desc=f'Epoch {epoch}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:for step, (features, label) in enumerate(trainLoader, 1):labels = torch.IntTensor([])for j in range(label.size(0)):labels = torch.cat((labels, label[j]), 0)labels=labels[labels!=-1]features = features.to(device)labels = labels.to(device)loss_func=loss_func.to(device)batch_size = features.size()[0]out = net(features)log_probs = out.log_softmax(2).requires_grad_()targets = labelsinput_lengths = torch.IntTensor([out.size(0)] * int(out.size(1)))target_lengths = torch.where(label!=-1,1,0).sum(dim=-1)loss = loss_func(log_probs, targets, input_lengths, target_lengths)acc=getAcc(out,label)optimizer.zero_grad()loss.backward()optimizer.step()train_total_loss += losstrain_total_acc += accpbar.set_postfix(**{'loss': train_total_loss.item() / (step),'acc': train_total_acc/ (step), })pbar.update(1)trainLoss.append(train_total_loss.item()/step)trainAcc.append(train_total_acc/step)#保存模型torch.save(net.state_dict(), 'model.pth')#验证net.eval()for step, (features, label) in enumerate(valLoader, 1):with torch.no_grad():labels = torch.IntTensor([])for j in range(label.size(0)):labels = torch.cat((labels, label[j]), 0)labels = labels[labels != -1]features = features.to(device)labels = labels.to(device)loss_func = loss_func.to(device)batch_size = features.size()[0]out = net(features)log_probs = out.log_softmax(2).requires_grad_()targets = labelsinput_lengths = torch.IntTensor([out.size(0)] * int(out.size(1)))target_lengths = torch.where(label != -1, 1, 0).sum(dim=-1)loss = loss_func(log_probs, targets, input_lengths, target_lengths)acc = getAcc(out, label)val_total_loss+=lossval_total_acc+=accvalLoss.append(val_total_loss.item()/step)valAcc.append(val_total_acc/step)lr_scheduler.step()# print(trainLoss)# print(valLoss)"""绘制loss acc曲线图"""plt.figure()plt.plot(trainLoss, 'r')plt.plot(valLoss, 'b')plt.title('Training and validation loss')plt.xlabel("Epochs")plt.ylabel("Loss")plt.legend(["Loss", "Validation Loss"])plt.savefig('loss.png')plt.figure()plt.plot(trainAcc, 'r')plt.plot(valAcc, 'b')plt.title('Training and validation acc')plt.xlabel("Epochs")plt.ylabel("Acc")plt.legend(["Acc", "Validation Acc"])plt.savefig('acc.png')# plt.show()

acc和loss图如下所示:
在这里插入图片描述
在这里插入图片描述
经过验证发现准确率可达95%以上,效果不错。
整体项目下载地址:项目下载


http://chatgpt.dhexx.cn/article/nudLNScN.shtml

相关文章

CRNN论文翻译——中文版

文章作者:Tyan 博客:noahsnail.com | CSDN | 简书 翻译论文汇总:https://github.com/SnailTyan/deep-learning-papers-translation An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Applicatio…

CRNN论文笔记

0. 前言 在这篇论文《An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition》所讲述的内容便是大名鼎鼎的CRNN网络,中实现了端到端的文本识别。 论文地址 Github地址 该网络具有如下的特点: 1)该模…

CRNN详解

一.概述 常用文字识别算法主要有两个框架: CNNRNNCTC(CRNNCTC)CNNSeq2SeqAttention 本文介绍第一种方法。 CRNN是一种卷积循环神经网络结构,用于解决基于图像的序列识别问题,特别是场景文字识别问题。 文章认为文字识别是对序列的预测方法…

CRNN模型

介绍: 是目前较为流行的图文识别模型,可识别较长的文本序列, 它利用BLSTM和CTC部件学习字符图像中的上下文关系, 从而有效提升文本识别准确率,使得模型更加鲁棒。 CRNN是一种卷积循环神经网络结构,用于解决…

CRNN

CRNN详解:https://blog.csdn.net/bestrivern/article/details/91050960 https://www.cnblogs.com/skyfsm/p/10335717.html 1 概述 传统的OCR识别过程分为两步:单字切割和分类任务。现在更流行的是基于深度学习的端到端的文字识别,即我们不需…

论文阅读 - CRNN

文章目录 1 概述2 模型介绍2.1 输入2.2 Feature extraction2.3 Sequence modeling2.4 Transcription2.4.1 训练部分2.4.2 预测部分 3 模型效果参考资料 1 概述 CRNN(Convolutional Recurrent Neural Network)是2015年华科的白翔老师团队提出的,直至今日&#xff0c…

文本识别网络CRNN

文本识别网络CRNN 简介网络结构CNN层LSTM层CTC Loss 代码实现 简介 CRNN,全称Convolutional Recurrent Neural Network,卷积循环神经网络。 它是一种基于图像的序列识别网络,可以对不定长的文字序列进行端到端的识别。 它集成了卷积神经网络…

CRNN——文本识别算法

常用文字识别算法主要有两个框架: CNNRNNCTC(CRNNCTC)CNNSeq2SeqAttention 文章认为文字识别是对序列的预测方法,所以采用了对序列预测的RNN网络。通过CNN将图片的特征提取出来后采用RNN对序列进行预测,最后通过一个CTC的翻译层得到最终结果…

OCR论文笔记系列(一): CRNN文字识别

👨‍💻作者简介:大数据专业硕士在读,CSDN人工智能领域博客专家,阿里云专家博主,专注大数据与人工智能知识分享,公众号:GoAI的学习小屋,免费分享书籍、简历、导图等资料,更有交流群分享AI和大数据,加群方式公众号回复“加群”或➡️点击链接。 🎉专栏推荐:➡️点…

CRNN——卷积循环神经网络结构

CRNN——卷积循环神经网络结构 简介构成CNNMap-to-Sequence 图解RNNctcloss序列合并机制推理过程编解码过程 代码实现 简介 CRNN 全称为 Convolutional Recurrent Neural Network,是一种卷积循环神经网络结构,主要用于端到端地对不定长的文本序列进行识…

java bean的生命周期

文章转载来自博客园:https://www.cnblogs.com/kenshinobiy/p/4652008.html Spring 中bean 的生命周期短暂吗? 在spring中,从BeanFactory或ApplicationContext取得的实例为Singleton,也就是预设为每一个Bean的别名只能维持一个实例&#xf…

Spring创建Bean的生命周期

1.Bean 的创建生命周期 UserService.class —> 无参构造方法(推断构造方法) —> 普通对象 —> 依赖注入(为带有Autowired的属性赋值) —> 初始化前(执行带有PostConstruct的方法) —> 初始…

Bean的生命周期(不要背了记思想)

文章内容引用自 咕泡科技 咕泡出品,必属精品 文章目录 1. 应付面试2 可以跟着看源码的图3 学习Bean 的生命周期之前你应该知道什么4 Bean 的完整生命周期 1. 应付面试 你若是真的为面试而来,请把下面这段背下来,应付面试足矣 spring的bean的…

简述 Spring Bean的生命周期

“请你描述下 Spring Bean 的生命周期?”,这是面试官考察 Spring 的常用问题,可见是 Spring 中很重要的知识点。 其实要记忆该过程,还是需要我们先去理解,本文将从以下两方面去帮助理解 Bean 的生命周期: 生…

【Spring源码】讲讲Bean的生命周期

1、前言 面试官:“看过Spring源码吧,简单说说Spring中Bean的生命周期” 大神仙:“基本生命周期会经历实例化 -> 属性赋值 -> 初始化 -> 销毁”。 面试官:“......” 2、Bean的生命周期 如果是普通Bean的生命周期&am…

Spring中bean的生命周期(易懂版)

bean的生命周期 写在前面的话bean的生命周期代码演示 bean的更完整的生命周期添加后置处理器的代码演示 写在前面的话 关于bean的生命周期有很多的文章,但是大多数都是长篇的理论,说来说去也不是很好理解,再次我就整理了一篇比较好理解的bea…

面试官:讲一下Spring Bean的生命周期?

1. 引言 “请你描述下 Spring Bean 的生命周期?”,这是面试官考察 Spring 的常用问题,可见是 Spring 中很重要的知识点。 其实要记忆该过程,还是需要我们先去理解,本文将从以下两方面去帮助理解 Bean 的生命周期&…

Spring Bean的生命周期(非常详细)

生命周期图 文章目录 前言一、生命周期流程图:二、各种接口方法分类三、演示 前言 Spring作为当前Java最流行、最强大的轻量级框架,受到了程序员的热烈欢迎。准确的了解Spring Bean的生命周期是非常必要的。我们通常使用ApplicationContext作为Spring容…

Spring 中Bean的生命周期

目录 Bean的生命周期 五个阶段 下面是一个bean对象创建到销毁经历过的方法。 图示​ 问答 普通Java类是在哪一步变成beanDefinition的 推荐视频: 阿里专家耗时一年,终于把Spring源码AOP、IOC、Ben生命周期、事物、设计模式以及循环依赖讲全了_哔哩…

Spring之Bean的生命周期详解

通过前面多个接口的介绍了解了Bean对象生命周期相关的方法,本文就将这些接口的方法串起来,来了解Bean的完整的生命周期。而介绍Bean的生命周期也是面试过程中经常会碰到的一个问题,如果不注意就跳坑里啦~~ Spring之Bean对象的初始化和销毁方法…