验证码识别(二)多输出模型
前面我们使用了Python-tesseract进行验证码的识别。不过我们发现一个问题,稍微复杂一点的验证码它就识别不出来了。
本章我们就来实现一个自己的验证码识别程序。我们使用的是CNN多输出的方式来进行验证码的识别。CNN多输出在很多场景都有用到如:ACGAN,Yolo等。
数据生成 captcha
首先是数据集,我们使用了captcha来生成验证码图片。captcha 是用 python 写的生成验证码的库,它支持图片验证码和语音验证码,我们使用的是它生成图片验证码的功能。由于比较简单,这里就直接给代码。
from captcha.image import ImageCaptcha
import random
import string
import os
# 修改为对应的地址 (先创建好)
train_path=r'E:\DataSets\Digitalverificationcode\train_image'
val_path=r'E:\DataSets\Digitalverificationcode\val_image'
# 生成图片
def gen_image(path,num):
# 字符 A~Z + 0~9
characters=string.digits+string.ascii_uppercase
# print(characters)
width,height,n_len,n_class=160,40,4,len(characters)
# 生成160*40的图片
generator=ImageCaptcha(width=width,height=height)
for i in range(num):
random_str=''.join([random.choice(characters) for j in range(4)])
# 随机生成图片
img=generator.generate_image(random_str)
# 防止重名报错停止运行
try:
img.save(os.path.join(path,'%s.jpg'%random_str))
except:
continue
print('done')
# 数据生成
gen_image(train_path,20000)
gen_image(val_path,1000)
运行上述代码,我们就可以在对应的文件夹中看见如下的图片,并且图片内的验证码就是图片的命名,这会方便我们后面进行读取。当然写成CSV的格式也是可以的。
网络搭建 CNN
接着,我们搭建网络结果,这次搭建的网络和我们之前搭建的网络有所不同,之前搭建的单输入单输出的网络结构,在这里我们要搭建的是单输入多输出的网络结构,因为每张图片其实对应的是4个标签。对应的网络结构图如下:
模型搭建这块,如果你已经学会如何搭建VGG,ResNet等网络。我想对你来说是比较简单的。代码如下:
import tensorflow.keras as keras
import os
import random
import cv2
import numpy as np
def create_model(input):
x=keras.layers.Conv2D(64,kernel_size=(3,3),padding='same',
activation='relu')(input)
for i in range(3):
x=keras.layers.Conv2D(64*(i+1),kernel_size=(3,3),padding='same',
activation='relu')(x)
x=keras.layers.Conv2D(64*(i+1),kernel_size=(3,3),strides=2,
padding='same',activation='relu')(x)
x=keras.layers.BatchNormalization()(x)
x=keras.layers.MaxPool2D()(x)
x=keras.layers.Flatten()(x)
x=keras.layers.Dropout(0.5)(x)
x=[keras.layers.Dense(36,activation='softmax')(x) for i in range(4)]
model=keras.models.Model(inputs=input,outputs=x)
return model
数据加载 generator
接着是加载数据,我们需要读取图片的路径,然后把标签换成对应的字符索引。并把数据封装成生成器的格式,避免内存溢出。最后记得将标签独热编码。
# 读取数据集地址和标签信息
def load_data(path,data_split=0.9):
print('loading_data...')
str='0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
image_paths=[os.path.join(path,p) for p in os.listdir(path)]
# 将 字符转换为 索引 8AFS ==>[8,10,15,28]
label_=[label.split('.')[0] for label in os.listdir(path)]
label=[]
for i in label_:
lab=[]
for j in i:
lab.append(str.index(j))
label.append(lab)
# 打乱数据集
seed_=random.randint(1,100)
random.seed(seed_)
random.shuffle(image_paths)
random.seed(seed_)
random.shuffle(label)
# 分割数据集
num=int(len(image_paths)*data_split)
x_train=image_paths[:num]
y_train=label[:num]
x_test=image_paths[num:]
y_test=label[num:]
print('done')
return x_train,y_train,x_test,y_test
# 数据生成器
def gen_data(image_paths,labels,batch_size):
while True:
data=[]
label=[]
for index,image_path in enumerate(image_paths):
# print(image_path) 图片处理
image=cv2.imread(image_path)
image=cv2.resize(image,(160,40))
data.append(image)
label.append(labels[index])
# 当数组中存在一个批次的数据之后返回
if len(data)==batch_size:
data=np.array(data)
data=data.reshape(-1,40,160,3)/255.
label=np.array(label)
# y 对应4个输出,一个batch的数据 以及36分类
y=np.zeros([4,batch_size,36])
for i in range(batch_size):
for idx,row_i in enumerate(label[i]):
1 =
y=[yy for yy in y]
yield data,y
data=[]
label=[]
使用如下的方式查看生成器生成的数据是否图片和标签一一对应,如果没有问题,就可以编写训练代码了。
训练和测试 train&test
训练代码如下,由于前面我们以及把需要用到的东西写好了,这里我们只需要简单的调用一下就可以进行训练了。
if __name__ == '__main__':
batch_size=32
train_path=r'E:\DataSets\Digitalverificationcode\train_image'
x_train,y_train,x_test,y_test=load_data(train_path)
input=keras.layers.Input((40,160,3))
model=create_model(input)
model.compile(loss='categorical_crossentropy',optimizer=keras.optimizers.Adam(lr=3e-4),
metrics=['acc'])
model.fit_generator(gen_data(x_train,y_train,batch_size),
steps_per_epoch=len(x_train)//batch_size,
validation_data=gen_data(x_test,y_test,batch_size),
validation_steps=len(x_test)//batch_size,
epochs=5)
model.save('yzm_.h5')
训练结束之后,我们就可以使用测试集来进行测试了。注意在图片处理的时候要和训练时一致。
import tensorflow.keras as keras
import os
import cv2
import numpy as np
val_path=r'E:\DataSets\Digitalverificationcode\val_image'
str='0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
model=keras.models.load_model('yzm_.h5')
image_paths=[os.path.join(val_path,path) for path in os.listdir(val_path)]
num=0
for index ,image_path in enumerate(image_paths):
image=cv2.imread(image_path)
img=image.copy()
image=cv2.resize(image,(160,40))
image=image.reshape(-1,40,160,3)/255.
p=model.predict(image)
y=np.argmax(np.array(p),axis=2)[:,0]
char=''.join([str[x] for x in y])
cv2.putText(img,char,(5,15),cv2.FONT_HERSHEY_COMPLEX,.5,
(255,0,255),1)
cv2.imshow('img',img)
cv2.waitKey(0)
测试效果