What is CRNN
Overall framework of CRNN (figure):
CRNN = CNN + RNN + CTC
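1) The CNN extracts features for the RNN;
2) The RNN converts the feature sequence produced by the CNN into per-frame predictions;
3) CTC is the transcription layer that produces the final prediction. CTC is designed for cases where the alignment between input and output is unknown, so it suits tasks such as speech recognition and handwritten character recognition; see [2] for details.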
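Because CTC needs no frame-level alignment, training only requires each image's label sequence and its length. The following is a minimal sketch using PyTorch's built-in torch.nn.CTCLoss. The network output is replaced by random scores, and several values are assumptions chosen for illustration: 33 time steps, 37 classes, blank index 0, and the label values themselves.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, C = 33, 2, 37  # assumed: time steps, batch size, classes (index 0 = blank)

# Random scores standing in for the CRNN output, shape [T, N, C]
logits = torch.randn(T, N, C, requires_grad=True)
log_probs = logits.log_softmax(dim=2)  # CTCLoss expects log-probabilities

# Made-up, unaligned label sequences concatenated into one 1-D tensor:
# sample 0 -> [5, 12, 12, 7], sample 1 -> [3, 9]; index 0 is reserved for blank
targets = torch.tensor([5, 12, 12, 7, 3, 9], dtype=torch.long)
target_lengths = torch.tensor([4, 2], dtype=torch.long)
# Each sample uses all T frames; no per-frame alignment is given anywhere
input_lengths = torch.full((N,), T, dtype=torch.long)

ctc = nn.CTCLoss(blank=0)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
loss.backward()  # gradients flow back to the per-frame scores
print(f"CTC loss: {loss.item():.4f}")
```

The loss sums over every frame-level path that collapses to the target. This is why the repeated label 12 needs at least one blank frame between its two occurrences, and why T must be comfortably longer than the label sequence.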
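The details are covered thoroughly in [1], so this post focuses on how the network structure maps onto the PyTorch code.
The CRNN network structure in detail (table):
Note: 1. Transcription is the transcription layer, which converts the per-frame predictions into the final label sequence. In the PyTorch code below it is not part of the CRNN module: forward() returns only the per-frame class scores.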
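The simplest way to turn per-frame predictions into a label sequence is greedy (best-path) decoding. It works in three steps:

1. Take the most likely class at each frame.
2. Merge consecutive repeats.
3. Drop the blanks.

The sketch below is only an illustration of that idea. The function name best_path_decode, the alphabet "-helo" and the frame values are all hypothetical, and blank index 0 is an assumption. Lexicon-based decoding or beam search would replace this step in other setups.

```python
from itertools import groupby
from typing import List, Sequence

BLANK = 0  # assumption: class index 0 is the CTC blank


def best_path_decode(frame_probs: Sequence[Sequence[float]]) -> List[int]:
    """Hypothetical helper: greedy CTC decoding of per-frame class scores."""
    # 1) most likely class at each time step
    best = [max(range(len(p)), key=p.__getitem__) for p in frame_probs]
    # 2) collapse consecutive duplicates, 3) remove blanks
    return [k for k, _ in groupby(best) if k != BLANK]


ALPHABET = "-helo"  # hypothetical alphabet; '-' at index 0 is the blank
frames = [1, 1, 0, 2, 0, 3, 3, 0, 3, 4, 4, 0]  # made-up per-frame winners
probs = [[0.9 if c == k else 0.025 for c in range(len(ALPHABET))] for k in frames]

text = "".join(ALPHABET[k] for k in best_path_decode(probs))
print(text)  # hello
```

The blank frame between the two "l" runs is what keeps them as separate letters. Without it, "ll" would collapse into a single "l".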
PyTorch implementation:
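import torch.nn as nn


class BidirectionalLSTM(nn.Module):
    # Not defined in the original post; assumed minimal version:
    # a bidirectional LSTM followed by a per-time-step linear projection.
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)                      # [T, b, 2 * nHidden]
        T, b, h = recurrent.size()
        output = self.embedding(recurrent.view(T * b, h))   # [T * b, nOut]
        return output.view(T, b, -1)                        # [T, b, nOut]


class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        # assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # Per-layer kernel sizes, paddings, strides and output channels
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            # Conv -> (optional BatchNorm) -> ReLU / LeakyReLU
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Size comments: C x H x W for an nc x 32 x 128 input
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x33
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x34
        convRelu(6, True)  # 512x1x33

        self.cnn = cnn
        # Two stacked BiLSTMs; the last one projects to nclass scores per frame
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)        # [b, c, w]
        conv = conv.permute(2, 0, 1)  # [w, b, c]: each column is one time step
        # rnn features
        output = self.rnn(conv)       # [w, b, nclass]
        return output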
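Note: 1. Every convolutional layer is followed by a ReLU activation (LeakyReLU when leakyRelu=True). The network structure table does not show these activations; see convRelu in the code above.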
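The C×H×W annotations next to the pooling layers hold only for one particular input size. The width of the final feature map is also the number of time steps T fed to the RNN.

The sketch below recomputes every layer's output size. It assumes PyTorch's default dilation=1 and ceil_mode=False, so each axis follows floor((n + 2p − k) / s) + 1. The layer list is transcribed by hand from the ks/ps/ss/nm arrays and the pooling layers of the CRNN code. The helper names out_size and trace are hypothetical.

```python
from typing import List, Optional, Tuple


def out_size(n: int, k: int, s: int, p: int) -> int:
    """Output length of a conv/pool along one axis (dilation 1, floor mode)."""
    return (n + 2 * p - k) // s + 1


# (name, kernel (kh, kw), stride (sh, sw), padding (ph, pw), out channels or None)
LAYERS = [
    ("conv0", (3, 3), (1, 1), (1, 1), 64),
    ("pool0", (2, 2), (2, 2), (0, 0), None),
    ("conv1", (3, 3), (1, 1), (1, 1), 128),
    ("pool1", (2, 2), (2, 2), (0, 0), None),
    ("conv2", (3, 3), (1, 1), (1, 1), 256),
    ("conv3", (3, 3), (1, 1), (1, 1), 256),
    ("pool2", (2, 2), (2, 1), (0, 1), None),
    ("conv4", (3, 3), (1, 1), (1, 1), 512),
    ("conv5", (3, 3), (1, 1), (1, 1), 512),
    ("pool3", (2, 2), (2, 1), (0, 1), None),
    ("conv6", (2, 2), (1, 1), (0, 0), 512),
]


def trace(h: int, w: int) -> List[Tuple[str, Optional[int], int, int]]:
    """Return (layer name, C, H, W) after each layer for an H x W input."""
    c: Optional[int] = None
    shapes = []
    for name, (kh, kw), (sh, sw), (ph, pw), oc in LAYERS:
        h, w = out_size(h, kh, sh, ph), out_size(w, kw, sw, pw)
        c = oc if oc is not None else c  # pooling keeps the channel count
        shapes.append((name, c, h, w))
    return shapes


for name, c, h, w in trace(32, 128):
    print(f"{name}: {c}x{h}x{w}")  # last line: conv6: 512x1x33
```

Only a 32-pixel-high input shrinks to height 1, which is what the assert in forward() checks. The (2, 1) strides in pool2 and pool3 halve the height but keep the width, so horizontal resolution is preserved. The result is 33 time steps for a 128-pixel-wide image and 26 for a 100-pixel-wide one.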
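References
[1] CRNN Algorithm Explained
[2] CTC Algorithm Explained
[3] "Connectionist Temporal Classification"
[4] Evaluation Methods for Chinese Text Detection and Recognition
[5] Sequence Modeling With CTC
Other: CTPN
Recommended reading: How to elegantly use PyTorch's built-in torch.nn.CTCLoss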