References:
一文读懂CRNN+CTC文字识别 - 知乎
CTC loss - 知乎
1. Background
The two mainstream approaches to text recognition:
1.1 CRNN+CTC based
1.2 CNN+Seq2Seq+Attention based
2. How CRNN+CTC works
CRNN+CTC architecture diagram
Below is the network structure of the PaddleOCR recognition model with MobileNetV3 as the backbone, as printed from the model.
Model input: 256*3*32*100 (NCHW: batch size 256, 3 channels, height 32, width 100)
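As a rough illustration of how such a batch could be assembled (the plain resize to 32x100 and the normalization to [-1, 1] below are assumptions made for this sketch, not necessarily PaddleOCR's exact preprocessing):

import numpy as np
import cv2

def preprocess(img_bgr, h=32, w=100):
    # resize a cropped text-line image to 32*100 and normalize roughly to [-1, 1]
    img = cv2.resize(img_bgr, (w, h)).astype(np.float32)   # HWC = 32*100*3
    img = (img / 255.0 - 0.5) / 0.5
    return img.transpose(2, 0, 1)                          # CHW = 3*32*100

# stack 256 text-line crops into an NCHW batch: 256*3*32*100
batch = np.stack([preprocess(np.zeros((48, 160, 3), np.uint8)) for _ in range(256)])
print(batch.shape)   # (256, 3, 32, 100)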
BaseModel(
  (backbone): MobileNetV3(
    (conv1): ConvBNLayer((conv): Conv2D(3, 8, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)(bn): BatchNorm())
    (blocks): Sequential(
      (0): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(8, 8, kernel_size=[3, 3], padding=1, groups=8, data_format=NCHW)(bn): BatchNorm())(linear_conv): ConvBNLayer((conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (1): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(8, 32, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(32, 32, kernel_size=[3, 3], stride=[2, 1], padding=1, groups=32, data_format=NCHW)(bn): BatchNorm())(linear_conv): ConvBNLayer((conv): Conv2D(32, 16, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (2): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(40, 40, kernel_size=[3, 3], padding=1, groups=40, data_format=NCHW)(bn): BatchNorm())(linear_conv): ConvBNLayer((conv): Conv2D(40, 16, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (3): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(40, 40, kernel_size=[5, 5], stride=[2, 1], padding=2, groups=40, data_format=NCHW)(bn): BatchNorm())(mid_se): SEModule((avg_pool): AdaptiveAvgPool2D(output_size=1)(conv1): Conv2D(40, 10, kernel_size=[1, 1], data_format=NCHW)(conv2): Conv2D(10, 40, kernel_size=[1, 1], data_format=NCHW))(linear_conv): ConvBNLayer((conv): Conv2D(40, 24, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (4): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)(bn): BatchNorm())(mid_se): SEModule((avg_pool): AdaptiveAvgPool2D(output_size=1)(conv1): Conv2D(64, 16, kernel_size=[1, 1], data_format=NCHW)(conv2): Conv2D(16, 64, kernel_size=[1, 1], data_format=NCHW))(linear_conv): ConvBNLayer((conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (5): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)(bn): BatchNorm())(mid_se): SEModule((avg_pool): AdaptiveAvgPool2D(output_size=1)(conv1): Conv2D(64, 16, kernel_size=[1, 1], data_format=NCHW)(conv2): Conv2D(16, 64, kernel_size=[1, 1], data_format=NCHW))(linear_conv): ConvBNLayer((conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (6): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(24, 120, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(120, 120, kernel_size=[3, 3], padding=1, groups=120, data_format=NCHW)(bn): BatchNorm())(linear_conv): ConvBNLayer((conv): Conv2D(120, 40, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (7): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(40, 104, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(104, 104, kernel_size=[3, 3], padding=1, groups=104, data_format=NCHW)(bn): BatchNorm())(linear_conv): ConvBNLayer((conv): Conv2D(104, 40, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (8): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)(bn): BatchNorm())(linear_conv): ConvBNLayer((conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (9): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)(bn): BatchNorm())(linear_conv): ConvBNLayer((conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (10): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(40, 240, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(240, 240, kernel_size=[3, 3], padding=1, groups=240, data_format=NCHW)(bn): BatchNorm())(mid_se): SEModule((avg_pool): AdaptiveAvgPool2D(output_size=1)(conv1): Conv2D(240, 60, kernel_size=[1, 1], data_format=NCHW)(conv2): Conv2D(60, 240, kernel_size=[1, 1], data_format=NCHW))(linear_conv): ConvBNLayer((conv): Conv2D(240, 56, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (11): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(336, 336, kernel_size=[3, 3], padding=1, groups=336, data_format=NCHW)(bn): BatchNorm())(mid_se): SEModule((avg_pool): AdaptiveAvgPool2D(output_size=1)(conv1): Conv2D(336, 84, kernel_size=[1, 1], data_format=NCHW)(conv2): Conv2D(84, 336, kernel_size=[1, 1], data_format=NCHW))(linear_conv): ConvBNLayer((conv): Conv2D(336, 56, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (12): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(336, 336, kernel_size=[5, 5], stride=[2, 1], padding=2, groups=336, data_format=NCHW)(bn): BatchNorm())(mid_se): SEModule((avg_pool): AdaptiveAvgPool2D(output_size=1)(conv1): Conv2D(336, 84, kernel_size=[1, 1], data_format=NCHW)(conv2): Conv2D(84, 336, kernel_size=[1, 1], data_format=NCHW))(linear_conv): ConvBNLayer((conv): Conv2D(336, 80, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (13): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)(bn): BatchNorm())(mid_se): SEModule((avg_pool): AdaptiveAvgPool2D(output_size=1)(conv1): Conv2D(480, 120, kernel_size=[1, 1], data_format=NCHW)(conv2): Conv2D(120, 480, kernel_size=[1, 1], data_format=NCHW))(linear_conv): ConvBNLayer((conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
      (14): ResidualUnit((expand_conv): ConvBNLayer((conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())(bottleneck_conv): ConvBNLayer((conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)(bn): BatchNorm())(mid_se): SEModule((avg_pool): AdaptiveAvgPool2D(output_size=1)(conv1): Conv2D(480, 120, kernel_size=[1, 1], data_format=NCHW)(conv2): Conv2D(120, 480, kernel_size=[1, 1], data_format=NCHW))(linear_conv): ConvBNLayer((conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm()))
    )
    (conv2): ConvBNLayer((conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)(bn): BatchNorm())
    (pool): MaxPool2D(kernel_size=2, stride=2, padding=0)  # backbone output: 256*480*1*25, i.e. the image is encoded as a 480*25 feature sequence
  )
  (neck): SequenceEncoder(  # the neck receives the 256*480*1*25 backbone feature map
    (encoder_reshape): Im2Seq()  # Im2Seq turns it into a 256*25*480 sequence of vectors: 25 time steps, each of dimension 480
    (encoder): EncoderWithRNN(
      (lstm): LSTM(480, 96, num_layers=2
        (0): BiRNN((cell_fw): LSTMCell(480, 96)(cell_bw): LSTMCell(480, 96))
        (1): BiRNN((cell_fw): LSTMCell(192, 96)(cell_bw): LSTMCell(192, 96))  # after the bidirectional LSTM the output is 256*25*192
      )
    )
  )
  (head): CTCHead(  # input 256*25*192, output 256*25*96; 25 is the sequence length (the maximum number of predicted characters), 96 is the size of the character dictionary plus the CTC blank symbol
    (fc): Linear(in_features=192, out_features=96, dtype=float32)
  )
)
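To make the shape flow from backbone output to CTC head concrete, here is a rough PyTorch equivalent of the neck (Im2Seq plus a 2-layer bidirectional LSTM) and the CTC head. It only illustrates the tensor shapes; the real model is implemented in PaddlePaddle.

import torch
import torch.nn as nn

feat = torch.randn(256, 480, 1, 25)            # backbone output: N*C*H*W = 256*480*1*25

# Im2Seq: drop the height axis and move width to the time dimension
seq = feat.squeeze(2).permute(0, 2, 1)         # 256*25*480 (batch, time, channels)

# EncoderWithRNN: 2-layer bidirectional LSTM, hidden size 96 -> output dim 2*96 = 192
lstm = nn.LSTM(480, 96, num_layers=2, bidirectional=True, batch_first=True)
enc, _ = lstm(seq)                             # 256*25*192

# CTCHead: a per-time-step linear classifier over the 96-entry output alphabet
fc = nn.Linear(192, 96)
logits = fc(enc)                               # 256*25*96
print(logits.shape)                            # torch.Size([256, 25, 96])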
CTC loss: the final step. CTC treats the 25 per-frame predictions as an alignment path and runs dynamic programming over such paths; a fixed collapse rule maps each path to a label by merging consecutive repeated characters and then removing blanks.
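For example, a minimal greedy CTC decode that applies this collapse rule to the per-frame argmax predictions could look like the sketch below; the blank index 0 and the toy character set are assumptions for illustration.

import numpy as np

def ctc_greedy_decode(logits, charset, blank=0):
    ids = logits.argmax(axis=-1)               # best class at each of the T time steps
    out, prev = [], blank
    for i in ids:
        if i != blank and i != prev:            # merge consecutive repeats, then drop blanks
            out.append(charset[i])
        prev = i
    return "".join(out)

charset = ["<blank>", "h", "e", "l", "o"]
# the per-step argmax below spells "h h e - l l - l o", which collapses to "hello"
ids = [1, 1, 2, 0, 3, 3, 0, 3, 4]
T, C = len(ids), len(charset)
logits = np.full((T, C), -5.0)
logits[np.arange(T), ids] = 5.0
print(ctc_greedy_decode(logits, charset))       # hello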
Forward pass: the forward algorithm computes p(label | x), the sum of the probabilities of every alignment path that collapses to the correct label; the loss is the negative log of this sum.
The probability of each such path is the product, over time steps, of the probability the network assigns to that path's character at that step.
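A tiny worked example of exactly these two statements, with a made-up 3-symbol alphabet (blank, 'a', 'b'), T = 3 time steps and an invented probability table: enumerate every alignment path, keep the ones that collapse to the label, and sum the product of per-step probabilities.

import itertools
import numpy as np

probs = np.array([[0.1, 0.7, 0.2],    # t = 0: P(blank), P(a), P(b)
                  [0.2, 0.2, 0.6],    # t = 1
                  [0.3, 0.1, 0.6]])   # t = 2
label = (1, 2)                         # target "ab"

def collapse(path, blank=0):
    out, prev = [], blank
    for s in path:
        if s != blank and s != prev:
            out.append(s)
        prev = s
    return tuple(out)

total = 0.0
for path in itertools.product(range(3), repeat=3):          # all 3^3 = 27 alignment paths
    if collapse(path) == label:
        total += np.prod([probs[t, s] for t, s in enumerate(path)])
print(total, -np.log(total))                                 # p(label|x) ≈ 0.558 and the CTC loss -log p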
The backward pass computes the gradients used to update the network parameters (including the LSTM). The backward (beta) algorithm computes, at time t, the summed probability of all path suffixes that start at the s-th position of the blank-extended label and reach a valid end point.
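In practice this probability is not found by enumeration but by the forward (alpha) and backward (beta) dynamic-programming recursions over the blank-extended label. A sketch on the same made-up example as above, where both recursions recover the identical p(label | x):

import numpy as np

probs = np.array([[0.1, 0.7, 0.2],
                  [0.2, 0.2, 0.6],
                  [0.3, 0.1, 0.6]])
blank = 0
label = [1, 2]
l_ext = [blank] + [c for ch in label for c in (ch, blank)]   # blank-extended label: [0, 1, 0, 2, 0]
T, S = probs.shape[0], len(l_ext)

def can_skip(s):
    # a path may jump from s-2 to s only if l_ext[s] is not blank and differs from l_ext[s-2]
    return s >= 2 and l_ext[s] != blank and l_ext[s] != l_ext[s - 2]

alpha = np.zeros((T, S))                                      # alpha[t, s]: prob. of prefixes ending at position s at time t
alpha[0, 0] = probs[0, l_ext[0]]
alpha[0, 1] = probs[0, l_ext[1]]
for t in range(1, T):
    for s in range(S):
        a = alpha[t - 1, s] + (alpha[t - 1, s - 1] if s >= 1 else 0.0)
        if can_skip(s):
            a += alpha[t - 1, s - 2]
        alpha[t, s] = a * probs[t, l_ext[s]]

beta = np.zeros((T, S))                                       # beta[t, s]: prob. of suffixes starting at position s at time t
beta[T - 1, S - 1] = probs[T - 1, l_ext[S - 1]]
beta[T - 1, S - 2] = probs[T - 1, l_ext[S - 2]]
for t in range(T - 2, -1, -1):
    for s in range(S - 1, -1, -1):
        b = beta[t + 1, s] + (beta[t + 1, s + 1] if s + 1 < S else 0.0)
        if s + 2 < S and can_skip(s + 2):
            b += beta[t + 1, s + 2]
        beta[t, s] = b * probs[t, l_ext[s]]

p_forward = alpha[T - 1, S - 1] + alpha[T - 1, S - 2]         # p(label|x) from the forward recursion
p_backward = beta[0, 0] + beta[0, 1]                          # the same value from the backward recursion
print(p_forward, p_backward)                                  # both match the brute-force sum above (≈ 0.558)

These alpha/beta quantities are what the gradient of the CTC loss with respect to the per-step probabilities is assembled from during backpropagation.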