【Keras+计算机视觉+Tensorflow】OCR文字识别实战(附源码和数据集 超详细必看)

本文涉及的产品
票证核验,票证核验 50次/账号
自定义KV模板,自定义KV模板 500次/账号
OCR统一识别,每月200次
简介: 【Keras+计算机视觉+Tensorflow】OCR文字识别实战(附源码和数据集 超详细必看)

需要源码和数据集请点赞关注收藏后评论区留言私信~~~

一、OCR文字识别简介

利用计算机自动识别字符的技术,是模式识别应用的一个重要领域。人们在生产和生活中,要处理大量的文字、报表和文本。为了减轻人们的劳动,提高处理效率,从上世纪50年代起就开始探讨文字识别方法,并研制出光学字符识别器。

OCR(Optical Character Recognition)图像文字识别是人工智能的重要分支,赋予计算机人眼的功能,使其可以看图识字,图像文字识别系统流程一般分为图像采集、文字检测、文字识别以及结果输出四部分。

二、OCR文字识别项目实战

1:数据集简介

MSRA-TD500该数据集共包含500 张自然场景图像,其分辨率在1296 ´ 864至920 ´ 1280 之间,涵盖了室内商场、标识牌、室外街道、广告牌等大多数场,文本包含中文和英文,有着不同的字体、大小和倾斜方向,部分数据集图像如下图所示。

数据集项目结构如下 分为训练集和测试集

2:项目结构

整体项目结构如下 上面是一些算法和模型比如CRAFT CRNN的定义,下面是测试代码

CRAFT算法实现文本行的检测如图下图所示。首先将完整的文字区域输入CRAFT文字检测网络,得到字符级的文字得分结果热图(Text Score)和字符级文本连接得分热图(Link Score),最后根据连通域得到每个文本行的位置

3:效果展示

开始运行代码

输出运行结果 可以放入不同图片进行测试

三、代码

部分代码如下 需要全部代码和数据集请点赞关注收藏后评论区留言私信~~~

 

"""This script demonstrates how to train the model
on the SynthText90 using multiple GPUs."""
# pylint: disable=invalid-name
import datetime
import argparse
import math
import random
import string
import functools
import itertools
import os
import tarfile
import urllib.request
import numpy as np
import cv2
import imgaug
import tqdm
import tensorflow as tf
import keras_ocr
# pylint: disable=redefined-outer-name
def get_filepaths(data_path, split):
    """Get the list of filepaths for a given split (train, val, or test)."""
    with open(os.path.join(data_path, f'mnt/ramdisk/max/90kDICT32px/annotation_{split}.txt'),
              'r') as text_file:
        filepaths = [
            os.path.join(data_path, 'mnt/ramdisk/max/90kDICT32px',
                         line.split(' ')[0][2:]) for line in text_file.readlines()
        ]
    return filepaths
# pylint: disable=redefined-outer-name
def download_extract_and_process_dataset(data_path):
    """Download and extract the synthtext90 dataset."""
    archive_filepath = os.path.join(data_path, 'mjsynth.tar.gz')
    extraction_directory = os.path.join(data_path, 'mnt')
    if not os.path.isfile(archive_filepath) and not os.path.isdir(extraction_directory):
        print('Downloading the dataset.')
        urllib.request.urlretrieve("https://www.robots.ox.ac.uk/~vgg/data/text/mjsynth.tar.gz",
                                   archive_filepath)
    if not os.path.isdir(extraction_directory):
        print('Extracting files.')
        with tarfile.open(os.path.join(data_path, 'mjsynth.tar.gz')) as tfile:
            tfile.extractall(data_path)
def get_image_generator(filepaths, augmenter, width, height):
    """Get an image generator for a list of SynthText90 filepaths."""
    filepaths = filepaths.copy()
    for filepath in itertools.cycle(filepaths):
        text = filepath.split(os.sep)[-1].split('_')[1].lower()
        image = cv2.imread(filepath)
        if image is None:
            print(f'An error occurred reading: {filepath}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = keras_ocr.tools.fit(image,
                                    width=width,
                                    height=height,
                                    cval=np.random.randint(low=0, high=255, size=3).astype('uint8'))
        if augmenter is not None:
            image = augmenter.augment_image(image)
        if filepath == filepaths[-1]:
            random.shuffle(filepaths)
        yield image, text
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('--model_id',
                        default='recognizer',
                        help='The name to use for saving model checkpoints.')
    parser.add_argument(
        '--data_path',
        default='.',
        help='The path to the directory containing the dataset and where we will put our logs.')
    parser.add_argument(
        '--logs_path',
        default='./logs',
        help=(
            'The path to where logs and checkpoints should be stored. '
            'If a checkpoint matching "model_id" is found, training will resume from that point.'))
    parser.add_argument('--batch_size', default=16, help='The training batch size to use.')
    parser.add_argument('--no-file-verification', dest='verify_files', action='store_false')
    parser.set_defaults(verify_files=True)
    args = parser.parse_args()
    weights_path = os.path.join(args.logs_path, args.model_id + '.h5')
    csv_path = os.path.join(args.logs_path, args.model_id + '.csv')
    download_extract_and_process_dataset(args.data_path)
    with tf.distribute.MirroredStrategy().scope():
        recognizer = keras_ocr.recognition.Recognizer(alphabet=string.digits +
                                                      string.ascii_lowercase,
                                                      height=31,
                                                      width=200,
                                                      stn=False,
                                                      optimizer=tf.keras.optimizers.RMSprop(),
                                                      weights=None)
    if os.path.isfile(weights_path):
        print('Loading saved weights and creating new version.')
        dt_string = datetime.datetime.now().isoformat()
        weights_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.h5')
        csv_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.csv')
        recognizer.model.load_weights(weights_path)
    augmenter = imgaug.augmenters.Sequential([
        imgaug.augmenters.Multiply((0.9, 1.1)),
        imgaug.augmenters.GammaContrast(gamma=(0.5, 3.0)),
        imgaug.augmenters.Invert(0.25, per_channel=0.5)
    ])
    os.makedirs(args.logs_path, exist_ok=True)
    training_filepaths, validation_filepaths = [
        get_filepaths(data_path=args.data_path, split=split) for split in ['train', 'val']
    ]
    if args.verify_files:
        assert all(
            os.path.isfile(filepath) for
            filepath in tqdm.tqdm(training_filepaths + validation_filepaths,
                                  desc='Checking filepaths.')), 'Some files appear to be missing.'
    (training_image_generator, training_steps), (validation_image_generator, validation_steps) = [
        (get_image_generator(
            filepaths=filepaths,
            augmenter=augmenter,
            width=recognizer.model.input_shape[2],
            height=recognizer.model.input_shape[1],
        ), math.ceil(len(filepaths) / args.batch_size))
        for filepaths, augmenter in [(training_filepaths, augmenter), (validation_filepaths, None)]
    ]
    training_generator, validation_generator = [
        tf.data.Dataset.from_generator(
            functools.partial(recognizer.get_batch_generator,
                              image_generator=image_generator,
                              batch_size=args.batch_size),
            output_types=((tf.float32, tf.int64, tf.float64, tf.int64), tf.float64),
            output_shapes=((tf.TensorShape([None, 31, 200, 1]), tf.TensorShape([None, recognizer.training_model.input_shape[1][1]]), 
                            tf.TensorShape([None,
                                            1]), tf.TensorShape([None,
                                                                 1])), tf.TensorShape([None, 1])))
        for image_generator in [training_image_generator, validation_image_generator]
    ]
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         min_delta=0,
                                         patience=10,
                                         restore_best_weights=False),
        tf.keras.callbacks.ModelCheckpoint(weights_path, monitor='val_loss', save_best_only=True),
        tf.keras.callbacks.CSVLogger(csv_path)
    ]
    recognizer.training_model.fit(
        x=training_generator,
        steps_per_epoch=training_steps,
        validation_steps=validation_steps,
        validation_data=validation_generator,
        callbacks=callbacks,
        epochs=1000,
    )
"""This script is what was used to generate the
backgrounds.zip and fonts.zip files.
"""
# pylint: disable=invalid-name,redefined-outer-name
import json
import urllib.request
import urllib.parse
import concurrent
import shutil
import zipfile
import glob
import os
import numpy as np
import tqdm
import cv2
import keras_ocr
if __name__ == '__main__':
    fonts_commit = 'a0726002eab4639ee96056a38cd35f6188011a81'
    fonts_sha256 = 'e447d23d24a5bbe8488200a058cd5b75b2acde525421c2e74dbfb90ceafce7bf'
    fonts_source_zip_filepath = keras_ocr.tools.download_and_verify(
        url=f'https://github.com/google/fonts/archive/{fonts_commit}.zip',
        cache_dir='.',
        sha256=fonts_sha256)
    shutil.rmtree('fonts-raw', ignore_errors=True)
    with zipfile.ZipFile(fonts_source_zip_filepath) as zfile:
        zfile.extractall(path='fonts-raw')
    retained_fonts = []
    sha256s = []
    basenames = []
    # The blacklist includes fonts that, at least for the English alphabet, were found
    # to be illegible (e.g., thin fonts) or render in unexpected ways (e.g., mathematics
    # fonts).
    blacklist = [
        'AlmendraDisplay-Regular.ttf', 'RedactedScript-Bold.ttf', 'RedactedScript-Regular.ttf',
        'Sevillana-Regular.ttf', 'Mplus1p-Thin.ttf', 'Stalemate-Regular.ttf', 'jsMath-cmsy10.ttf',
        'Codystar-Regular.ttf', 'AdventPro-Thin.ttf', 'RoundedMplus1c-Thin.ttf',
        'EncodeSans-Thin.ttf', 'AlegreyaSans-ThinItalic.ttf', 'AlegreyaSans-Thin.ttf',
        'FiraSans-Thin.ttf', 'FiraSans-ThinItalic.ttf', 'WorkSans-Thin.ttf',
        'Tomorrow-ThinItalic.ttf', 'Tomorrow-Thin.ttf', 'Italianno-Regular.ttf',
        'IBMPlexSansCondensed-Thin.ttf', 'IBMPlexSansCondensed-ThinItalic.ttf',
        'Lato-ExtraLightItalic.ttf', 'LibreBarcode128Text-Regular.ttf',
        'LibreBarcode39-Regular.ttf', 'LibreBarcode39ExtendedText-Regular.ttf',
        'EncodeSansExpanded-ExtraLight.ttf', 'Exo-Thin.ttf', 'Exo-ThinItalic.ttf',
        'DrSugiyama-Regular.ttf', 'Taviraj-ThinItalic.ttf', 'SixCaps.ttf', 'IBMPlexSans-Thin.ttf',
        'IBMPlexSans-ThinItalic.ttf', 'AdobeBlank-Regular.ttf',
        'FiraSansExtraCondensed-ThinItalic.ttf', 'HeptaSlab[wght].ttf', 'Karla-Italic[wght].ttf',
        'Karla[wght].ttf', 'RalewayDots-Regular.ttf', 'FiraSansCondensed-ThinItalic.ttf',
        'jsMath-cmex10.ttf', 'LibreBarcode39Text-Regular.ttf', 'LibreBarcode39Extended-Regular.ttf',
        'EricaOne-Regular.ttf', 'ArimaMadurai-Thin.ttf', 'IBMPlexSerif-ExtraLight.ttf',
        'IBMPlexSerif-ExtraLightItalic.ttf', 'IBMPlexSerif-ThinItalic.ttf', 'IBMPlexSerif-Thin.ttf',
        'Exo2-Thin.ttf', 'Exo2-ThinItalic.ttf', 'BungeeOutline-Regular.ttf', 'Redacted-Regular.ttf',
        'JosefinSlab-ThinItalic.ttf', 'GothicA1-Thin.ttf', 'Kanit-ThinItalic.ttf', 'Kanit-Thin.ttf',
        'AlegreyaSansSC-ThinItalic.ttf', 'AlegreyaSansSC-Thin.ttf', 'Chathura-Thin.ttf',
        'Blinker-Thin.ttf', 'Italiana-Regular.ttf', 'Miama-Regular.ttf', 'Grenze-ThinItalic.ttf',
        'LeagueScript-Regular.ttf', 'BigShouldersDisplay-Thin.ttf', 'YanoneKaffeesatz[wght].ttf',
        'BungeeHairline-Regular.ttf', 'JosefinSans-Thin.ttf', 'JosefinSans-ThinItalic.ttf',
        'Monofett.ttf', 'Raleway-ThinItalic.ttf', 'Raleway-Thin.ttf', 'JosefinSansStd-Light.ttf',
        'LibreBarcode128-Regular.ttf'
    ]
    for filepath in tqdm.tqdm(sorted(glob.glob('fonts-raw/**/**/**/*.ttf')),
                              desc='Filtering fonts.'):
        sha256 = keras_ocr.tools.sha256sum(filepath)
        basename = os.path.basename(filepath)
        # We check the sha256 and filenames because some of the fonts
        # in the repository are duplicated (see TRIVIA.md).
        if sha256 in sha256s or basename in basenames or basename in blacklist:
            continue
        sha256s.append(sha256)
        basenames.append(basename)
        retained_fonts.append(filepath)
    retained_font_families = set([filepath.split(os.sep)[-2] for filepath in retained_fonts])
    added = []
    with zipfile.ZipFile(file='fonts.zip', mode='w') as zfile:
        for font_family in tqdm.tqdm(retained_font_families, desc='Saving ZIP file.'):
            # We want to keep all the metadata files plus
            # the retained font files. And we don't want
            # to add the same file twice.
            files = [
                input_filepath for input_filepath in glob.glob(f'fonts-raw/**/**/{font_family}/*')
                if input_filepath not in added and
                (input_filepath in retained_fonts or os.path.splitext(input_filepath)[1] != '.ttf')
            ]
            added.extend(files)
            for input_filepath in files:
                zfile.write(filename=input_filepath,
                            arcname=os.path.join(*input_filepath.split(os.sep)[-2:]))
    print('Finished saving fonts file.')
    # pylint: disable=line-too-long
    url = (
        'https://commons.wikimedia.org/w/api.php?action=query&generator=categorymembers&gcmtype=file&format=json'
        '&gcmtitle=Category:Featured_pictures_on_Wikimedia_Commons&prop=imageinfo&gcmlimit=50&iiprop=url&iiurlwidth=1024'
    )
    gcmcontinue = None
    max_responses = 300
    responses = []
    for responseCount in tqdm.tqdm(range(max_responses)):
        current_url = url
        if gcmcontinue is not None:
            current_url += f'&continue=gcmcontinue||&gcmcontinue={gcmcontinue}'
        with urllib.request.urlopen(url=current_url) as response:
            current = json.loads(response.read())
            responses.append(current)
            gcmcontinue = None if 'continue' not in current else current['continue']['gcmcontinue']
        if gcmcontinue is None:
            break
    print('Finished getting list of images.')
    # We want to avoid animated images as well as icon files.
    image_urls = []
    for response in responses:
        image_urls.extend(
            [page['imageinfo'][0]['thumburl'] for page in response['query']['pages'].values()])
    image_urls = [url for url in image_urls if url.lower().endswith('.jpg')]
    shutil.rmtree('backgrounds', ignore_errors=True)
    os.makedirs('backgrounds')
    assert len(image_urls) == len(set(image_urls)), 'Duplicates found!'
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        futures = [
            executor.submit(keras_ocr.tools.download_and_verify,
                            url=url,
                            cache_dir='./backgrounds',
                            verbose=False) for url in image_urls
        ]
        for _ in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            pass
    for filepath in glob.glob('backgrounds/*.JPG'):
        os.rename(filepath, filepath.lower())
    print('Filtering images by aspect ratio and maximum contiguous contour.')
    image_paths = np.array(sorted(glob.glob('backgrounds/*.jpg')))
    def compute_metrics(filepath):
        image = keras_ocr.tools.read(filepath)
        aspect_ratio = image.shape[0] / image.shape[1]
        contour, _ = keras_ocr.tools.get_maximum_uniform_contour(image, fontsize=40)
        area = cv2.contourArea(contour) if contour is not None else 0
        return aspect_ratio, area
    metrics = np.array([compute_metrics(filepath) for filepath in tqdm.tqdm(image_paths)])
    filtered_paths = image_paths[(metrics[:, 0] < 3 / 2) & (metrics[:, 0] > 2 / 3) &
                                 (metrics[:, 1] > 1e6)]
    detector = keras_ocr.detection.Detector()
    paths_with_text = [
        filepath for filepath in tqdm.tqdm(filtered_paths) if len(
            detector.detect(
                images=[keras_ocr.tools.read_and_fit(filepath, width=640, height=640)])[0]) > 0
    ]
    filtered_paths = np.array([path for path in filtered_paths if path not in paths_with_text])
    filtered_basenames = list(map(os.path.basename, filtered_paths))
    basename_to_url = {
        os.path.basename(urllib.parse.urlparse(url).path).lower(): url
        for url in image_urls
    }
    filtered_urls = [basename_to_url[basename.lower()] for basename in filtered_basenames]
    assert len(filtered_urls) == len(filtered_paths)
    removed_paths = [filepath for filepath in image_paths if filepath not in filtered_paths]
    for filepath in removed_paths:
        os.remove(filepath)
    with open('backgrounds/urls.txt', 'w') as f:
        f.write('\n'.join(filtered_urls))
    with zipfile.ZipFile(file='backgrounds.zip', mode='w') as zfile:
        for filepath in tqdm.tqdm(filtered_paths.tolist() + ['backgrounds/urls.txt'],
                                  desc='Saving ZIP file.'):
            zfile.write(filename=filepath, arcname=os.path.basename(filepath.lower()))

创作不易 觉得有帮助请点赞关注收藏~~~

相关文章
|
4月前
|
机器学习/深度学习 监控 算法
基于计算机视觉(opencv)的运动计数(运动辅助)系统-源码+注释+报告
基于计算机视觉(opencv)的运动计数(运动辅助)系统-源码+注释+报告
113 3
|
4月前
|
数据采集 TensorFlow 算法框架/工具
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
本教程详细介绍了如何使用TensorFlow 2.3训练自定义图像分类数据集,涵盖数据集收集、整理、划分及模型训练与测试全过程。提供完整代码示例及图形界面应用开发指导,适合初学者快速上手。[教程链接](https://www.bilibili.com/video/BV1rX4y1A7N8/),配套视频更易理解。
102 0
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
|
6月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
【Tensorflow+Keras】keras实现条件生成对抗网络DCGAN--以Minis和fashion_mnist数据集为例
如何使用TensorFlow和Keras实现条件生成对抗网络(CGAN)并以MNIST和Fashion MNIST数据集为例进行演示。
79 3
|
6月前
|
UED 存储 数据管理
深度解析 Uno Platform 离线状态处理技巧:从网络检测到本地存储同步,全方位提升跨平台应用在无网环境下的用户体验与数据管理策略
【8月更文挑战第31天】处理离线状态下的用户体验是现代应用开发的关键。本文通过在线笔记应用案例,介绍如何使用 Uno Platform 优雅地应对离线状态。首先,利用 `NetworkInformation` 类检测网络状态;其次,使用 SQLite 实现离线存储;然后,在网络恢复时同步数据;最后,通过 UI 反馈提升用户体验。
149 0
|
6月前
|
机器学习/深度学习 TensorFlow 数据处理
分布式训练在TensorFlow中的全面应用指南:掌握多机多卡配置与实践技巧,让大规模数据集训练变得轻而易举,大幅提升模型训练效率与性能
【8月更文挑战第31天】本文详细介绍了如何在Tensorflow中实现多机多卡的分布式训练,涵盖环境配置、模型定义、数据处理及训练执行等关键环节。通过具体示例代码,展示了使用`MultiWorkerMirroredStrategy`进行分布式训练的过程,帮助读者更好地应对大规模数据集与复杂模型带来的挑战,提升训练效率。
161 0
|
7月前
|
文字识别 API 开发工具
印刷文字识别使用问题之如何提高OCR的识别率
印刷文字识别产品,通常称为OCR(Optical Character Recognition)技术,是一种将图像中的印刷或手写文字转换为机器编码文本的过程。这项技术广泛应用于多个行业和场景中,显著提升文档处理、信息提取和数据录入的效率。以下是印刷文字识别产品的一些典型使用合集。
|
7月前
|
文字识别 前端开发 API
印刷文字识别操作报错合集之通过HTTPS连接到OCR服务的API时报错,该如何处理
在使用印刷文字识别(OCR)服务时,可能会遇到各种错误。例如:1.Java异常、2.配置文件错误、3.服务未开通、4.HTTP错误码、5.权限问题(403 Forbidden)、6.调用拒绝(Refused)、7.智能纠错问题、8.图片质量或格式问题,以下是一些常见错误及其可能的原因和解决方案的合集。
|
6月前
|
机器学习/深度学习 文字识别 算法
百度飞桨(PaddlePaddle) - PaddleHub OCR 文字识别简单使用
百度飞桨(PaddlePaddle) - PaddleHub OCR 文字识别简单使用
431 0
|
7月前
|
机器学习/深度学习 人工智能 文字识别
文本,文字扫描01,OCR文本识别技术展示,一个安卓App,一个简单的设计,文字识别可以应用于人工智能,机器学习,车牌识别,身份证识别,银行卡识别,PaddleOCR+SpringBoot+Andr
文本,文字扫描01,OCR文本识别技术展示,一个安卓App,一个简单的设计,文字识别可以应用于人工智能,机器学习,车牌识别,身份证识别,银行卡识别,PaddleOCR+SpringBoot+Andr
|
7月前
|
JSON 文字识别 数据格式
文本,文字识别,Flask实现内部接口开发,OCR外部接口的开发,如何开发一个识别接口,通过post调用,参数是图片的路径,内部调用,直接传图片路径就行
文本,文字识别,Flask实现内部接口开发,OCR外部接口的开发,如何开发一个识别接口,通过post调用,参数是图片的路径,内部调用,直接传图片路径就行

热门文章

最新文章