CRNN项目实践

CRNN项目实战

之前写过一篇文章利用CRNN进行文字识别,当时重点讲的CRNN网络结构和CNN部分的代码实现,因为缺少文字数据集没有进行真正的训练,这次正好有一批不定长的字符验证码,正好CRNN主要就是用于端到端地对不定长的文本序列进行识别,当然是字符和文字都是可以用的,所以这里进行了一次实战。

主要是参考github项目:https://github.com/meijieru/crnn.pytorch

关于lmdb

lmdb安装

首先关于lmdb这个数据库,python有两个包,一个是lmdb,另一个是python-lmdb。

使用pycharm的包安装功能可以看到关于lmdb的描述

Universal Python binding for the LMDB 'Lightning' Database Version 1.3.0

关于python-lmdb的描述

simple lmdb bindings written using ctypes Version 1.0.0

所以理论上我们安装前者肯定是可以用的,但是经过亲身实践,

在pip环境中使用pip install lmdb确实可以正常使用;

但是在conda环境中,使用conda install lmdb安装完成之后却无法导入包。

所以又使用:conda install python-lmdb安装,安装完之后却可以使用,非常奇怪。

后发现原因大概率是版本问题,使用pip可以安装lmdb=1.3.0的最新版本,而conda只能安装lmdb=0.9.x的版本,所以目前在conda中只能使用python-lmdb暂替使用。

制作适用CRNNlmdb数据集

github项目中关于如何训练自己的数据集写的不是很清楚,如果我们直接运行train.py会遇到各种问题,首先第一个问题就是数据集的问题,lmdbDataset中的初始化

1
2
3
4
5
6
7
8
self.env = lmdb.open(
root,
max_readers=1,
readonly=True,
lock=False,
readahead=False,
meminit=False
)

这里会报错,因为这里读取的路径下需要有lmdb格式的数据,所以在这之前我们需要生成lmdb格式的数据集。

相关代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- coding: utf-8 -*-
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import glob
import numpy as np


def checkImageIsValid(imageBin):
if imageBin is None:
return False
imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
if img is None:
return False
imgH, imgW = img.shape[0], img.shape[1]
if imgH * imgW == 0:
return False
return True


def writeCache(env, cache): # 在python3环境下运行
with env.begin(write=True) as txn:
for k, v in cache.items():
if type(v) is str:
txn.put(k.encode(), v.encode())
continue
txn.put(k.encode(), v)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
"""
Create LMDB dataset for CRNN training.
ARGS:
outputPath : LMDB output path
imagePathList : list of image path
labelList : list of corresponding groundtruth texts
lexiconList : (optional) list of lexicon lists
checkValid : if true, check the validity of every image
"""
assert (len(imagePathList) == len(labelList)), 'len(x) != len(y)'
nSamples = len(imagePathList)
print('...................')
# env = lmdb.open(outputPath, map_size=104857600) # 最大100MB
env = lmdb.open(outputPath, map_size=10485760)

cache = {}
cnt = 1
for i in range(nSamples):
imagePath = imagePathList[i]
label = labelList[i]
if not os.path.exists(imagePath):
print('%s does not exist' % imagePath)
continue
with open(imagePath, 'rb') as f:
imageBin = f.read()
if checkValid:
if not checkImageIsValid(imageBin):
print('%s is not a valid image' % imagePath)
continue

# .mdb数据库文件保存了两种数据,一种是图片数据,一种是标签数据,它们各有其key
imageKey = f'image-{cnt}'
labelKey = f'label-{cnt}'
cache[imageKey] = imageBin
cache[labelKey] = label

if lexiconList:
lexiconKey = f'lexicon-{cnt}'
cache[lexiconKey] = ' '.join(lexiconList[i])
if cnt % 1000 == 0:
writeCache(env, cache)
cache = {}
print('Written %d / %d' % (cnt, nSamples))
cnt += 1
nSamples = cnt - 1
cache['num-samples'] = str(nSamples)
writeCache(env, cache)
print('Created dataset with %d samples' % nSamples)


def read_text(path):
with open(path) as f:
text = f.read()
text = text.strip()

return text


if __name__ == '__main__':

# lmdb 输出目录
# outputPath = 'data/train/lmdb_data' # 训练数据
outputPath = 'data/val/lmdb_data' # 验证数据

# 训练图片路径,标签是txt格式,名字跟图片名字要一致,如123.jpg对应标签需要是123.txt
# input_path = 'data/train/origin_data/*.jpg'
input_path = 'data/val/origin_data/*.jpg'

imagePathList = glob.glob(input_path)
print('------------', len(imagePathList), '------------')
imgLabelLists = []
for p in imagePathList:
try:
# imgLabelLists.append((p, read_text(p.replace('.jpg', '.txt'))))
imgLabelLists.append((p, p.split('_')[2].replace('.jpg', '')))
except Exception as _e:
print(_e)
continue

# sort by labelList
imgLabelList = sorted(imgLabelLists, key=lambda x: len(x[1]))
imgPaths = [p[0] for p in imgLabelList]
txtLists = [p[1] for p in imgLabelList]

createDataset(outputPath, imgPaths, txtLists, lexiconList=None, checkValid=True)

代码执行完成会在相应目录下生成data.mdb、lock.mdb两个文件。代码很简单,一看就懂,因为我的原始数据集,标签包含在图片名称中,不是另外存储在txt文件中,所以对相应代码进行了改动,另外把python2相关的东西改成了可以用python3运行。

另外有一点:

env = lmdb.open(outputPath, map_size=104857600) # 最大100MB

map_size需要根据自己的数据集设置大小(单位是B),运行完生成的data.mdb的大小就是设置的大小,如果设置的比较大造成空间的浪费,设置的比较小可能会不够用(默认应该是10MB)。

很多资料都写这里设置1T,如果你的电脑硬盘空间不够,就会报错(报的错误是乱码)。

另外,如果你还遇到了其他乱码报错,大概率是路径错误。

参考文章:https://www.cnblogs.com/yanghailin/p/14519525.html

CTCLoss

在train.py中有这么一行代码,

from warpctc_pytorch import CTCLoss

初次使用的话一般是显示没有这个包的,而pytorch(version>=1.1)其实是有CTCLoss模块的

from torch.nn import CTCLoss

所以如果你的pytorch版本满足,就无需额外安装warp_ctc_pytorch了,替换一下导入代码即可。如果你的版本比较低,还是需要手动安装这个包的,如果是Windows环境下,比较麻烦的就是需要安装cmake来编译文件。不再赘述。

需要用到的warp_ctc_pytorch: https://github.com/SeanNaren/warp-ctc

参考文章:https://blog.csdn.net/weixin_40437821/article/details/105473032

然后简单介绍下,pytorchCTCLoss的用法。

初始化

1
ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')

类初始化参数说明:

blank:空白标签所在的label值,默认为0,需要根据实际的标签定义进行设定

reduction:处理output losses的方式,string类型,可选’none’ 、 ‘mean’ 及 ‘sum’,’none’表示对output losses不做任何处理,’mean’ 则对output losses取平均值处理,’sum’则是对output losses求和处理,默认为’mean’ 。

计算损失

1
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)

CTCLoss()对象调用形参说明:

log_probsshape为(T, N, C)的模型输出张量,其中,T表示CTCLoss的输入长度也即输出序列长度,N表示训练的batch size长度,C则表示包含有空白标签的所有要预测的字符集总长度,log_probs一般需要经过torch.nn.functional.log_softmax处理后再送入到CTCLoss中;

targets:shape为(N, S) 或(sum(target_lengths))的张量,其中第一种类型,N表示训练的batch size长度,S则为标签长度,第二种类型,则为所有标签长度之和,但是需要注意的是targets不能包含有空白标签;

input_lengths:shape为(N)的张量或元组,但每一个元素的长度必须等于T即输出序列长度,一般来说模型输出序列固定后则该张量或元组的元素值均相同;

target_lengths:shape为(N)的张量或元组,其每一个元素指示每个训练输入序列的标签长度,但标签长度是可以变化的;

这里最重要的就是初始化blank参数的设置和计算损失时,log_probs参数需要先进行log_softmax,这也是我们在这个项目中需要调整的点,如果我们直接从warp_ctc_pytorch更换为pytorch内置的CTCLoss,然后其他的不改动的话,是训练不出来结果的。

改动点:

1
2
3
criterion = CTCLoss(blank=0, reduction='mean')  # 初始化

cost = criterion(preds.log_softmax(2), text, preds_size, length) / batch_size # 损失计算,共有两处,训练和验证

比较巧合的是,在这个项目中,0的位置就是为空白字符预留的,而且blank的默认值也为0,所以不改动也是可以的。

训练

配置训练数据路径(trainroot)、验证数据路径(valroot)、预训练权重路径(pretrained),将lr(学习率)设置为0.001,nepoch=200。

训练的过程中会报很多错误,因为这个GitHub项目可能部分代码写的比较粗糙,另一方面也是因为python2python3,Linux和windows的环境问题。

我遇到了以下错误:

1、trainRoot,valRoot需要改下大小写

2、TypeError: Won't implicitly convert Unicode to bytes; use .encode()

按照错误提示加上encode
txn.get(‘num-samples’.encode())
label_byte = txn.get(label_key.encode())
imgbuf = txn.get(img_key.encode())

3、ValueError: sampler option is mutually exclusive with shuffle

1
2
3
4
5
6
7
8
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=opt.batchSize,
shuffle=False, # sampler不为None,shuffle就需要为False
sampler=sampler,
num_workers=int(opt.workers),
collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio)
)

因为本来shuffle参数为True,需要更改为False。

参考文章:https://blog.csdn.net/hjxu2016/article/details/111300972

4、TypeError: cannot pickle 'Environment' object

这个是因为在windows环境下,多进程训练有问题,将workers参数设置为0即可。

参考文章:https://blog.csdn.net/weixin_43272781/article/details/112757371

5、AttributeError: module 'torch' has no attribute 'longTensor'

在dataset.py脚本中,应该是torch.LongTensor。

6、TypeError: randint() takes 3 positional arguments but 4 were given

1
random.randint(0, len(self), self.batch_size)

我猜这里是笔误了,应该是len(self)-self.batch_size。

7、train.py中

image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)

推测最后一个参数应该是opt.imgW。

8、RuntimeError: The expanded size of the tensor (64) must match the existing size (63) at non-singleton dimension 0. Target sizes: [64]. Tensor sizes: [63]

嘿嘿,这个问题是因为我改了torch.range代码,因为pycharm提示说这个方法被废弃了,要求使用torch.arange

1
2
batch_index = random_start + torch.arange(0, self.batch_size-1)
index[i * self.batch_size:(i+1)*self.batch_size] = batch_index

torch.range(start=1, end=6)的结果是会包含end的,而torch.arange(start=1, end=6)的结果并不包含end。所以这里就不需要减1了。

1
batch_index = random_start + torch.arange(0, self.batch_size)

同样的,后面取末尾元素的时候也要去掉减1

1
tail_index = random_start + torch.arange(0, tail)

参考文章:https://blog.csdn.net/lunhuicnm/article/details/106712026

9、RuntimeError: set_sizes_contiguous is not allowed on a Tensor created from .data or .detach().

1
2
v.data.resize_(data.size()).copy_(data)
v.resize_(data.size()).copy_(data)

将上面的替换为下面的

参考文章:https://blog.csdn.net/weixin_45292103/article/details/102736742

10、还有一个关于utils.py中编码器(encode()方法的问题),

本来有一行代码是:_str = unicode(_str, 'utf-8')

这是python2中的语法,python3并不需要,但是因为前面制作lmdb数据集的时候,我们的标签进行了encode()处理,也就是这一步:txn.put(k.encode(), v.encode())

导致后面训练的时候,在对标签进行编码时,标签传过来是这样一种形式:”b’jdvfl0k’”

所以可以加上这行代码:_str = _str.replace("b'", "").replace("'", '')

11、在train.py的验证方法里,有这么一段代码:

1
2
3
_, preds = preds.max(2)
preds = preds.squeeze(2)
preds = preds.transpose(1, 0).contiguous().view(-1)

我在实际运行时发现preds.squeeze(2)会报错,然后调试发现preds此时的shape为(26, 64),所以应该无需squeeze,可以将这行代码注释。

另外,还要根据情况配置displayInterval、valInterval、saveInterval参数

displayInterval默认值是500,我这里训练不定长字符验证码,训练集总共4500张图片,batch_size=64,总共71个批次,所以设置displayInterval=10,每10个批次打印一次损失情况。其他两个参数同理,按照自己的情况调整。

终于把所有问题都解决了,可以正常训练了,但是训练的过程中打印的测试数据让我感觉不太对劲,我发现在对比预测标签和真实标签时,真实标签的形式为:”b’jdvfl0k’”

和上面同样的问题,需要处理下:

1
2
target = target.replace("b'", "").replace("'", "")
gt = gt.replace("b'", "").replace("'", "")

因为使用的预训练权重,训练的比较快,训练过程示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
[0/200][10/71] Loss: 0.07960973680019379
[0/200][20/71] Loss: 0.014807680621743202
[0/200][30/71] Loss: 0.010894499719142914
Start val
66----y----f----g----p---- => 6yfgp, gt:6yfgp
-f-t--r---z---x---d---8--- => ftrzxd8, gt:ftpzxd8
-k--a----y---m-----w------ => kaymw, gt:kaymw
-n-----3-----7-----z------ => n37z, gt:n37z
--u-------y----l----9----- => uyl9, gt:uyl9
ss---y---8---e---l----p--- => sy8elp, gt:sy8elp
--y--d--h--t--f--m----y--- => ydhtfmy, gt:ydtfmy
--zz----zz----l-----1----- => zzl1, gt:zzl1
Test loss: 0.016945159062743187, accuracy: 0.465
[0/200][40/71] Loss: 0.007085741963237524
[0/200][50/71] Loss: 0.007241038139909506
[0/200][60/71] Loss: 0.004275097511708736
[0/200][70/71] Loss: 0.0034677726216614246
Start val
-n----n----3----k----l---- => nn3kl, gt:nn3kl
-c---c----b---s---a---k--- => ccbsak, gt:ccbsak
-u---v---7--k---z--s--4--- => uv7kzs4, gt:uv7kzs4
-5----e------u------------ => 5eu, gt:5eu
-k-----e-----v-------v---- => kevv, gt:kelbv
--yy-d-----t--f--m----y--- => ydtfmy, gt:ydtfmy
-5--q---h---t--m---8--q--- => 5qhtm8q, gt:5qhtm8q
--y-----z----g-------v---- => yzgv, gt:yzgv
Test loss: 0.014375979080796242, accuracy: 0.66
[1/200][10/71] Loss: 0.004606468137353659
[1/200][20/71] Loss: 0.0033515722025185823
[1/200][30/71] Loss: 0.002877553692087531
Start val
-7--d----o---h----c---y--- => 7dohcy, gt:7d0hcy
-6----y----f----g----p---- => 6yfgp, gt:6yfgp
-1----f---j----b-----f---- => 1fjbf, gt:1fjbf
-2---o----u---h--4----z--- => 2ouh4z, gt:2ouh4z
--x----1---b---ww-----n--- => x1bwn, gt:x1bwn
-j---k--------------55---- => jk5, gt:jk15
-1---t--6---x----n----o--- => 1t6xno, gt:1t6xno
--y---x--j---k----a----t-- => yxjkat, gt:yxjkat
Test loss: 0.012275747954845428, accuracy: 0.745
[1/200][40/71] Loss: 0.0018801590194925666
[1/200][50/71] Loss: 0.002281028078868985
[1/200][60/71] Loss: 0.001854069298133254
[1/200][70/71] Loss: 0.0012131230905652046
Start val
dd----y--j-c---v--y---m--- => dyjcvym, gt:dyjcvym
-s----z---9----qq------t-- => sz9qt, gt:sz9qt
--y--qq----qq-----t---h--- => yqqth, gt:yqqth
-h---a----5--s----xx--o--- => ha5sxo, gt:ha5sxo
tt----o----7----a---0----- => to7a0, gt:to4ao
-n-----xx---f---e-----s--- => nxfes, gt:nxfes
-1---t--6---x----n----o--- => 1t6xno, gt:1t6xno
-c---e---d---e---c---m---- => cedecm, gt:cedecm
Test loss: 0.004259577952325344, accuracy: 0.815

两轮训练结束之后准确率就达到了81.5%(200个验证图片)

经过35轮的训练,准确率可以稳定在97.5%左右。