DL之CNN:基于CNN-RNN(GRU,2)算法(keras+tensorflow)实现不定长文本识别

本文涉及的产品
车辆物流识别,车辆物流识别 200次/月
个人证照识别,个人证照识别 200次/月
教育场景识别,教育场景识别 200次/月
简介: DL之CNN:基于CNN-RNN(GRU,2)算法(keras+tensorflow)实现不定长文本识别

输出结果


后期更新……





实现代码


后期更新……


image_ocr代码:DL之CNN:利用CNN(keras, CTC loss, {image_ocr})算法实现OCR光学字符识别

https://blog.csdn.net/qq_41185868/article/details/90239954


#DL之CNN:基于CNN-RNN(GRU,2)算法(keras+tensorflow)实现不定长文本识别

#Keras 的 CTC loss函数:位于 https://github.com/fchollet/keras/blob/master/keras/backend/tensorflow_backend.py文件中,内容如下:

import tensorflow as tf

from tensorflow.python.ops import ctc_ops as ctc

def ctc_batch_cost(y_true, y_pred, input_length, label_length):

   """Runs CTC loss algorithm on each batch element.

   # Arguments

       y_true: tensor `(samples, max_string_length)`

           containing the truth labels.

       y_pred: tensor `(samples, time_steps, num_categories)`

           containing the prediction, or output of the softmax.

       input_length: tensor `(samples, 1)` containing the sequence length for

           each batch item in `y_pred`.

       label_length: tensor `(samples, 1)` containing the sequence length for

           each batch item in `y_true`.

   # Returns

       Tensor with shape (samples,1) containing the

           CTC loss of each element.

   """

   label_length = tf.to_int32(tf.squeeze(label_length))

   input_length = tf.to_int32(tf.squeeze(input_length))

   sparse_labels = tf.to_int32(ctc_label_dense_to_sparse(y_true, label_length))

   y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)

   return tf.expand_dims(ctc.ctc_loss(inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1)

# 不定长文本识别

import os

import itertools

import re

import datetime

import cairocffi as cairo

import editdistance

import numpy as np

from scipy import ndimage

import pylab

from keras import backend as K

from keras.layers.convolutional import Conv2D, MaxPooling2D

from keras.layers import Input, Dense, Activation, Reshape, Lambda

from keras.layers.merge import add, concatenate

from keras.layers.recurrent import GRU

from keras.models import Model

from keras.optimizers import SGD

from keras.utils.data_utils import get_file

from keras.preprocessing import image

from keras.callbacks import EarlyStopping,Callback

from keras.backend.tensorflow_backend import set_session

import tensorflow as tf

import matplotlib.pyplot as plt

config = tf.ConfigProto()

config.gpu_options.allow_growth=True

set_session(tf.Session(config=config))

OUTPUT_DIR = 'image_ocr'

np.random.seed(55)

# # 从 Keras 官方文件中 import 相关的函数

# !wget https://raw.githubusercontent.com/fchollet/keras/master/examples/image_ocr.py

from image_ocr import *

#定义必要的参数:

run_name = datetime.datetime.now().strftime('%Y:%m:%d:%H:%M:%S')

start_epoch = 0

stop_epoch  = 200

img_w = 128

img_h = 64

words_per_epoch = 16000

val_split = 0.2

val_words = int(words_per_epoch * (val_split))

# Network parameters

conv_filters = 16

kernel_size = (3, 3)

pool_size = 2

time_dense_size = 32

rnn_size = 512

input_shape = (img_w, img_h, 1)

# 使用这些函数以及对应参数构建生成器,生成不固定长度的验证码

fdir = os.path.dirname(get_file('wordlists.tgz', origin='http://www.mythic-ai.com/datasets/wordlists.tgz', untar=True))

img_gen = TextImageGenerator(monogram_file=os.path.join(fdir, 'wordlist_mono_clean.txt'),

                                bigram_file=os.path.join(fdir, 'wordlist_bi_clean.txt'),

                                minibatch_size=32, img_w=img_w, img_h=img_h,

                                downsample_factor=(pool_size ** 2), val_split=words_per_epoch - val_words )

#构建CNN网络

act = 'relu'

input_data = Input(name='the_input', shape=input_shape, dtype='float32')

inner = Conv2D(conv_filters, kernel_size, padding='same',  activation=act, kernel_initializer='he_normal',

                  name='conv1')(input_data)

inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max1')(inner)

inner = Conv2D(conv_filters, kernel_size, padding='same',  activation=act, kernel_initializer='he_normal',

                  name='conv2')(inner)

inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max2')(inner)

conv_to_rnn_dims = (img_w // (pool_size ** 2), (img_h // (pool_size ** 2)) * conv_filters)

inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)

#减少输入尺寸到RNN:cuts down input size going into RNN:  

inner = Dense(time_dense_size, activation=act, name='dense1')(inner)

#GRU模型:两层双向的算法

# Two layers of bidirecitonal GRUs

# GRU seems to work as well, if not better than LSTM:

gru_1 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='gru1')(inner)

gru_1b = GRU(rnn_size, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru1_b')(inner)

gru1_merged = add([gru_1, gru_1b])

gru_2 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='gru2')(gru1_merged)

gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru2_b')(gru1_merged)

#将RNN输出转换为字符激活:transforms RNN output to character activations

inner = Dense(img_gen.get_output_size(), kernel_initializer='he_normal',

                 name='dense2')(concatenate([gru_2, gru_2b]))

y_pred = Activation('softmax', name='softmax')(inner)

Model(inputs=input_data, outputs=y_pred).summary()

labels = Input(name='the_labels', shape=[img_gen.absolute_max_string_len], dtype='float32')

input_length = Input(name='input_length', shape=[1], dtype='int64')

label_length = Input(name='label_length', shape=[1], dtype='int64')

#Keras目前不支持带有额外参数的loss funcs,所以CTC loss是在lambda层中实现的

# Keras doesn't currently support loss funcs with extra parameters, so CTC loss is implemented in a lambda layer

loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])

#clipnorm似乎加快了收敛速度:clipnorm seems to speeds up convergence

sgd = SGD(lr=0.02, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)

model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)

#计算损失发生在其他地方,所以使用一个哑函数来表示损失

# the loss calc occurs elsewhere, so use a dummy lambda func for the loss

model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)

if start_epoch > 0:

   weight_file = os.path.join(OUTPUT_DIR, os.path.join(run_name, 'weights%02d.h5' % (start_epoch - 1)))

   model.load_weights(weight_file)

#捕获softmax的输出,以便在可视化过程中解码输出

# captures output of softmax so we can decode the output during visualization

test_func = K.function([input_data], [y_pred])

# 反馈函数,即运行固定次数后,执行反馈函数可保存模型,并且可视化当前训练的效果

viz_cb = VizCallback(run_name, test_func, img_gen.next_val())

# 执行训练:

model.fit_generator(generator=img_gen.next_train(), steps_per_epoch=(words_per_epoch - val_words),

                       epochs=stop_epoch, validation_data=img_gen.next_val(), validation_steps=val_words,

                       callbacks=[EarlyStopping(patience=10), viz_cb, img_gen], initial_epoch=start_epoch)





相关文章
|
6月前
|
数据采集 算法 数据可视化
基于Python的k-means聚类分析算法的实现与应用,可以用在电商评论、招聘信息等各个领域的文本聚类及指标聚类,效果很好
本文介绍了基于Python实现的k-means聚类分析算法,并通过微博考研话题的数据清洗、聚类数量评估、聚类分析实现与结果可视化等步骤,展示了该算法在文本聚类领域的应用效果。
202 1
|
4月前
|
机器学习/深度学习 SQL 数据采集
基于tensorflow、CNN网络识别花卉的种类(图像识别)
基于tensorflow、CNN网络识别花卉的种类(图像识别)
108 1
|
5月前
|
机器学习/深度学习 存储 人工智能
文本情感识别分析系统Python+SVM分类算法+机器学习人工智能+计算机毕业设计
使用Python作为开发语言,基于文本数据集(一个积极的xls文本格式和一个消极的xls文本格式文件),使用Word2vec对文本进行处理。通过支持向量机SVM算法训练情绪分类模型。实现对文本消极情感和文本积极情感的识别。并基于Django框架开发网页平台实现对用户的可视化操作和数据存储。
82 0
文本情感识别分析系统Python+SVM分类算法+机器学习人工智能+计算机毕业设计
|
7月前
|
机器学习/深度学习 数据采集 监控
基于CNN卷积神经网络的步态识别matlab仿真,数据库采用CASIA库
**核心程序**: 完整版代码附中文注释,确保清晰理解。 **理论概述**: 利用CNN从视频中学习步态时空特征。 **系统框架**: 1. 数据预处理 2. CNN特征提取 3. 构建CNN模型 4. 训练与优化 5. 识别测试 **CNN原理**: 卷积、池化、激活功能强大特征学习。 **CASIA数据库**: 高质量数据集促进模型鲁棒性。 **结论**: CNN驱动的步态识别展现高精度,潜力巨大,适用于监控和安全领域。
|
6月前
|
数据采集 自然语言处理 数据可视化
基于Python的社交媒体评论数据挖掘,使用LDA主题分析、文本聚类算法、情感分析实现
本文介绍了基于Python的社交媒体评论数据挖掘方法,使用LDA主题分析、文本聚类算法和情感分析技术,对数据进行深入分析和可视化,以揭示文本数据中的潜在主题、模式和情感倾向。
599 0
|
7月前
|
机器学习/深度学习 数据采集 算法
Python基于OpenCV和卷积神经网络CNN进行车牌号码识别项目实战
Python基于OpenCV和卷积神经网络CNN进行车牌号码识别项目实战
|
6月前
|
安全 Apache 数据安全/隐私保护
你的Wicket应用安全吗?揭秘在Apache Wicket中实现坚不可摧的安全认证策略
【8月更文挑战第31天】在当前的网络环境中,安全性是任何应用程序的关键考量。Apache Wicket 是一个强大的 Java Web 框架,提供了丰富的工具和组件,帮助开发者构建安全的 Web 应用程序。本文介绍了如何在 Wicket 中实现安全认证,
67 0
|
6月前
|
机器学习/深度学习 数据采集 TensorFlow
从零到精通:TensorFlow与卷积神经网络(CNN)助你成为图像识别高手的终极指南——深入浅出教你搭建首个猫狗分类器,附带实战代码与训练技巧揭秘
【8月更文挑战第31天】本文通过杂文形式介绍了如何利用 TensorFlow 和卷积神经网络(CNN)构建图像识别系统,详细演示了从数据准备、模型构建到训练与评估的全过程。通过具体示例代码,展示了使用 Keras API 训练猫狗分类器的步骤,旨在帮助读者掌握图像识别的核心技术。此外,还探讨了图像识别在物体检测、语义分割等领域的广泛应用前景。
75 0
|
7月前
|
机器学习/深度学习 数据采集 算法
Python基于KMeans算法进行文本聚类项目实战
Python基于KMeans算法进行文本聚类项目实战
|
7月前
|
机器学习/深度学习 人工智能 自然语言处理
算法金 | 秒懂 AI - 深度学习五大模型:RNN、CNN、Transformer、BERT、GPT 简介
**RNN**,1986年提出,用于序列数据,如语言模型和语音识别,但原始模型有梯度消失问题。**LSTM**和**GRU**通过门控解决了此问题。 **CNN**,1989年引入,擅长图像处理,卷积层和池化层提取特征,经典应用包括图像分类和物体检测,如LeNet-5。 **Transformer**,2017年由Google推出,自注意力机制实现并行计算,优化了NLP效率,如机器翻译。 **BERT**,2018年Google的双向预训练模型,通过掩码语言模型改进上下文理解,适用于问答和文本分类。
204 9

热门文章

最新文章