【Keras+计算机视觉+Tensorflow】OCR文字识别实战(附源码和数据集 超详细必看)

本文涉及的产品
自定义KV模板,自定义KV模板 500次/账号
票据凭证识别,票据凭证识别 200次/月
个人证照识别,个人证照识别 200次/月
简介: 【Keras+计算机视觉+Tensorflow】OCR文字识别实战(附源码和数据集 超详细必看)

需要源码和数据集请点赞关注收藏后评论区留言私信~~~

一、OCR文字识别简介

利用计算机自动识别字符的技术,是模式识别应用的一个重要领域。人们在生产和生活中,要处理大量的文字、报表和文本。为了减轻人们的劳动,提高处理效率,从上世纪50年代起就开始探讨文字识别方法,并研制出光学字符识别器。

OCR(Optical Character Recognition)图像文字识别是人工智能的重要分支,赋予计算机人眼的功能,使其可以看图识字,图像文字识别系统流程一般分为图像采集、文字检测、文字识别以及结果输出四部分。

二、OCR文字识别项目实战

1:数据集简介

MSRA-TD500该数据集共包含500 张自然场景图像,其分辨率在1296 ´ 864至920 ´ 1280 之间,涵盖了室内商场、标识牌、室外街道、广告牌等大多数场,文本包含中文和英文,有着不同的字体、大小和倾斜方向,部分数据集图像如下图所示。

数据集项目结构如下 分为训练集和测试集

2:项目结构

整体项目结构如下 上面是一些算法和模型比如CRAFT CRNN的定义,下面是测试代码

CRAFT算法实现文本行的检测如图下图所示。首先将完整的文字区域输入CRAFT文字检测网络,得到字符级的文字得分结果热图(Text Score)和字符级文本连接得分热图(Link Score),最后根据连通域得到每个文本行的位置

3:效果展示

开始运行代码

输出运行结果 可以放入不同图片进行测试

三、代码

部分代码如下 需要全部代码和数据集请点赞关注收藏后评论区留言私信~~~

 

"""This script demonstrates how to train the model
on the SynthText90 using multiple GPUs."""
# pylint: disable=invalid-name
import datetime
import argparse
import math
import random
import string
import functools
import itertools
import os
import tarfile
import urllib.request
import numpy as np
import cv2
import imgaug
import tqdm
import tensorflow as tf
import keras_ocr
# pylint: disable=redefined-outer-name
def get_filepaths(data_path, split):
    """Get the list of filepaths for a given split (train, val, or test)."""
    with open(os.path.join(data_path, f'mnt/ramdisk/max/90kDICT32px/annotation_{split}.txt'),
              'r') as text_file:
        filepaths = [
            os.path.join(data_path, 'mnt/ramdisk/max/90kDICT32px',
                         line.split(' ')[0][2:]) for line in text_file.readlines()
        ]
    return filepaths
# pylint: disable=redefined-outer-name
def download_extract_and_process_dataset(data_path):
    """Download and extract the synthtext90 dataset."""
    archive_filepath = os.path.join(data_path, 'mjsynth.tar.gz')
    extraction_directory = os.path.join(data_path, 'mnt')
    if not os.path.isfile(archive_filepath) and not os.path.isdir(extraction_directory):
        print('Downloading the dataset.')
        urllib.request.urlretrieve("https://www.robots.ox.ac.uk/~vgg/data/text/mjsynth.tar.gz",
                                   archive_filepath)
    if not os.path.isdir(extraction_directory):
        print('Extracting files.')
        with tarfile.open(os.path.join(data_path, 'mjsynth.tar.gz')) as tfile:
            tfile.extractall(data_path)
def get_image_generator(filepaths, augmenter, width, height):
    """Get an image generator for a list of SynthText90 filepaths."""
    filepaths = filepaths.copy()
    for filepath in itertools.cycle(filepaths):
        text = filepath.split(os.sep)[-1].split('_')[1].lower()
        image = cv2.imread(filepath)
        if image is None:
            print(f'An error occurred reading: {filepath}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = keras_ocr.tools.fit(image,
                                    width=width,
                                    height=height,
                                    cval=np.random.randint(low=0, high=255, size=3).astype('uint8'))
        if augmenter is not None:
            image = augmenter.augment_image(image)
        if filepath == filepaths[-1]:
            random.shuffle(filepaths)
        yield image, text
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('--model_id',
                        default='recognizer',
                        help='The name to use for saving model checkpoints.')
    parser.add_argument(
        '--data_path',
        default='.',
        help='The path to the directory containing the dataset and where we will put our logs.')
    parser.add_argument(
        '--logs_path',
        default='./logs',
        help=(
            'The path to where logs and checkpoints should be stored. '
            'If a checkpoint matching "model_id" is found, training will resume from that point.'))
    parser.add_argument('--batch_size', default=16, help='The training batch size to use.')
    parser.add_argument('--no-file-verification', dest='verify_files', action='store_false')
    parser.set_defaults(verify_files=True)
    args = parser.parse_args()
    weights_path = os.path.join(args.logs_path, args.model_id + '.h5')
    csv_path = os.path.join(args.logs_path, args.model_id + '.csv')
    download_extract_and_process_dataset(args.data_path)
    with tf.distribute.MirroredStrategy().scope():
        recognizer = keras_ocr.recognition.Recognizer(alphabet=string.digits +
                                                      string.ascii_lowercase,
                                                      height=31,
                                                      width=200,
                                                      stn=False,
                                                      optimizer=tf.keras.optimizers.RMSprop(),
                                                      weights=None)
    if os.path.isfile(weights_path):
        print('Loading saved weights and creating new version.')
        dt_string = datetime.datetime.now().isoformat()
        weights_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.h5')
        csv_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.csv')
        recognizer.model.load_weights(weights_path)
    augmenter = imgaug.augmenters.Sequential([
        imgaug.augmenters.Multiply((0.9, 1.1)),
        imgaug.augmenters.GammaContrast(gamma=(0.5, 3.0)),
        imgaug.augmenters.Invert(0.25, per_channel=0.5)
    ])
    os.makedirs(args.logs_path, exist_ok=True)
    training_filepaths, validation_filepaths = [
        get_filepaths(data_path=args.data_path, split=split) for split in ['train', 'val']
    ]
    if args.verify_files:
        assert all(
            os.path.isfile(filepath) for
            filepath in tqdm.tqdm(training_filepaths + validation_filepaths,
                                  desc='Checking filepaths.')), 'Some files appear to be missing.'
    (training_image_generator, training_steps), (validation_image_generator, validation_steps) = [
        (get_image_generator(
            filepaths=filepaths,
            augmenter=augmenter,
            width=recognizer.model.input_shape[2],
            height=recognizer.model.input_shape[1],
        ), math.ceil(len(filepaths) / args.batch_size))
        for filepaths, augmenter in [(training_filepaths, augmenter), (validation_filepaths, None)]
    ]
    training_generator, validation_generator = [
        tf.data.Dataset.from_generator(
            functools.partial(recognizer.get_batch_generator,
                              image_generator=image_generator,
                              batch_size=args.batch_size),
            output_types=((tf.float32, tf.int64, tf.float64, tf.int64), tf.float64),
            output_shapes=((tf.TensorShape([None, 31, 200, 1]), tf.TensorShape([None, recognizer.training_model.input_shape[1][1]]), 
                            tf.TensorShape([None,
                                            1]), tf.TensorShape([None,
                                                                 1])), tf.TensorShape([None, 1])))
        for image_generator in [training_image_generator, validation_image_generator]
    ]
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         min_delta=0,
                                         patience=10,
                                         restore_best_weights=False),
        tf.keras.callbacks.ModelCheckpoint(weights_path, monitor='val_loss', save_best_only=True),
        tf.keras.callbacks.CSVLogger(csv_path)
    ]
    recognizer.training_model.fit(
        x=training_generator,
        steps_per_epoch=training_steps,
        validation_steps=validation_steps,
        validation_data=validation_generator,
        callbacks=callbacks,
        epochs=1000,
    )
"""This script is what was used to generate the
backgrounds.zip and fonts.zip files.
"""
# pylint: disable=invalid-name,redefined-outer-name
import json
import urllib.request
import urllib.parse
import concurrent
import shutil
import zipfile
import glob
import os
import numpy as np
import tqdm
import cv2
import keras_ocr
if __name__ == '__main__':
    fonts_commit = 'a0726002eab4639ee96056a38cd35f6188011a81'
    fonts_sha256 = 'e447d23d24a5bbe8488200a058cd5b75b2acde525421c2e74dbfb90ceafce7bf'
    fonts_source_zip_filepath = keras_ocr.tools.download_and_verify(
        url=f'https://github.com/google/fonts/archive/{fonts_commit}.zip',
        cache_dir='.',
        sha256=fonts_sha256)
    shutil.rmtree('fonts-raw', ignore_errors=True)
    with zipfile.ZipFile(fonts_source_zip_filepath) as zfile:
        zfile.extractall(path='fonts-raw')
    retained_fonts = []
    sha256s = []
    basenames = []
    # The blacklist includes fonts that, at least for the English alphabet, were found
    # to be illegible (e.g., thin fonts) or render in unexpected ways (e.g., mathematics
    # fonts).
    blacklist = [
        'AlmendraDisplay-Regular.ttf', 'RedactedScript-Bold.ttf', 'RedactedScript-Regular.ttf',
        'Sevillana-Regular.ttf', 'Mplus1p-Thin.ttf', 'Stalemate-Regular.ttf', 'jsMath-cmsy10.ttf',
        'Codystar-Regular.ttf', 'AdventPro-Thin.ttf', 'RoundedMplus1c-Thin.ttf',
        'EncodeSans-Thin.ttf', 'AlegreyaSans-ThinItalic.ttf', 'AlegreyaSans-Thin.ttf',
        'FiraSans-Thin.ttf', 'FiraSans-ThinItalic.ttf', 'WorkSans-Thin.ttf',
        'Tomorrow-ThinItalic.ttf', 'Tomorrow-Thin.ttf', 'Italianno-Regular.ttf',
        'IBMPlexSansCondensed-Thin.ttf', 'IBMPlexSansCondensed-ThinItalic.ttf',
        'Lato-ExtraLightItalic.ttf', 'LibreBarcode128Text-Regular.ttf',
        'LibreBarcode39-Regular.ttf', 'LibreBarcode39ExtendedText-Regular.ttf',
        'EncodeSansExpanded-ExtraLight.ttf', 'Exo-Thin.ttf', 'Exo-ThinItalic.ttf',
        'DrSugiyama-Regular.ttf', 'Taviraj-ThinItalic.ttf', 'SixCaps.ttf', 'IBMPlexSans-Thin.ttf',
        'IBMPlexSans-ThinItalic.ttf', 'AdobeBlank-Regular.ttf',
        'FiraSansExtraCondensed-ThinItalic.ttf', 'HeptaSlab[wght].ttf', 'Karla-Italic[wght].ttf',
        'Karla[wght].ttf', 'RalewayDots-Regular.ttf', 'FiraSansCondensed-ThinItalic.ttf',
        'jsMath-cmex10.ttf', 'LibreBarcode39Text-Regular.ttf', 'LibreBarcode39Extended-Regular.ttf',
        'EricaOne-Regular.ttf', 'ArimaMadurai-Thin.ttf', 'IBMPlexSerif-ExtraLight.ttf',
        'IBMPlexSerif-ExtraLightItalic.ttf', 'IBMPlexSerif-ThinItalic.ttf', 'IBMPlexSerif-Thin.ttf',
        'Exo2-Thin.ttf', 'Exo2-ThinItalic.ttf', 'BungeeOutline-Regular.ttf', 'Redacted-Regular.ttf',
        'JosefinSlab-ThinItalic.ttf', 'GothicA1-Thin.ttf', 'Kanit-ThinItalic.ttf', 'Kanit-Thin.ttf',
        'AlegreyaSansSC-ThinItalic.ttf', 'AlegreyaSansSC-Thin.ttf', 'Chathura-Thin.ttf',
        'Blinker-Thin.ttf', 'Italiana-Regular.ttf', 'Miama-Regular.ttf', 'Grenze-ThinItalic.ttf',
        'LeagueScript-Regular.ttf', 'BigShouldersDisplay-Thin.ttf', 'YanoneKaffeesatz[wght].ttf',
        'BungeeHairline-Regular.ttf', 'JosefinSans-Thin.ttf', 'JosefinSans-ThinItalic.ttf',
        'Monofett.ttf', 'Raleway-ThinItalic.ttf', 'Raleway-Thin.ttf', 'JosefinSansStd-Light.ttf',
        'LibreBarcode128-Regular.ttf'
    ]
    for filepath in tqdm.tqdm(sorted(glob.glob('fonts-raw/**/**/**/*.ttf')),
                              desc='Filtering fonts.'):
        sha256 = keras_ocr.tools.sha256sum(filepath)
        basename = os.path.basename(filepath)
        # We check the sha256 and filenames because some of the fonts
        # in the repository are duplicated (see TRIVIA.md).
        if sha256 in sha256s or basename in basenames or basename in blacklist:
            continue
        sha256s.append(sha256)
        basenames.append(basename)
        retained_fonts.append(filepath)
    retained_font_families = set([filepath.split(os.sep)[-2] for filepath in retained_fonts])
    added = []
    with zipfile.ZipFile(file='fonts.zip', mode='w') as zfile:
        for font_family in tqdm.tqdm(retained_font_families, desc='Saving ZIP file.'):
            # We want to keep all the metadata files plus
            # the retained font files. And we don't want
            # to add the same file twice.
            files = [
                input_filepath for input_filepath in glob.glob(f'fonts-raw/**/**/{font_family}/*')
                if input_filepath not in added and
                (input_filepath in retained_fonts or os.path.splitext(input_filepath)[1] != '.ttf')
            ]
            added.extend(files)
            for input_filepath in files:
                zfile.write(filename=input_filepath,
                            arcname=os.path.join(*input_filepath.split(os.sep)[-2:]))
    print('Finished saving fonts file.')
    # pylint: disable=line-too-long
    url = (
        'https://commons.wikimedia.org/w/api.php?action=query&generator=categorymembers&gcmtype=file&format=json'
        '&gcmtitle=Category:Featured_pictures_on_Wikimedia_Commons&prop=imageinfo&gcmlimit=50&iiprop=url&iiurlwidth=1024'
    )
    gcmcontinue = None
    max_responses = 300
    responses = []
    for responseCount in tqdm.tqdm(range(max_responses)):
        current_url = url
        if gcmcontinue is not None:
            current_url += f'&continue=gcmcontinue||&gcmcontinue={gcmcontinue}'
        with urllib.request.urlopen(url=current_url) as response:
            current = json.loads(response.read())
            responses.append(current)
            gcmcontinue = None if 'continue' not in current else current['continue']['gcmcontinue']
        if gcmcontinue is None:
            break
    print('Finished getting list of images.')
    # We want to avoid animated images as well as icon files.
    image_urls = []
    for response in responses:
        image_urls.extend(
            [page['imageinfo'][0]['thumburl'] for page in response['query']['pages'].values()])
    image_urls = [url for url in image_urls if url.lower().endswith('.jpg')]
    shutil.rmtree('backgrounds', ignore_errors=True)
    os.makedirs('backgrounds')
    assert len(image_urls) == len(set(image_urls)), 'Duplicates found!'
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        futures = [
            executor.submit(keras_ocr.tools.download_and_verify,
                            url=url,
                            cache_dir='./backgrounds',
                            verbose=False) for url in image_urls
        ]
        for _ in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            pass
    for filepath in glob.glob('backgrounds/*.JPG'):
        os.rename(filepath, filepath.lower())
    print('Filtering images by aspect ratio and maximum contiguous contour.')
    image_paths = np.array(sorted(glob.glob('backgrounds/*.jpg')))
    def compute_metrics(filepath):
        image = keras_ocr.tools.read(filepath)
        aspect_ratio = image.shape[0] / image.shape[1]
        contour, _ = keras_ocr.tools.get_maximum_uniform_contour(image, fontsize=40)
        area = cv2.contourArea(contour) if contour is not None else 0
        return aspect_ratio, area
    metrics = np.array([compute_metrics(filepath) for filepath in tqdm.tqdm(image_paths)])
    filtered_paths = image_paths[(metrics[:, 0] < 3 / 2) & (metrics[:, 0] > 2 / 3) &
                                 (metrics[:, 1] > 1e6)]
    detector = keras_ocr.detection.Detector()
    paths_with_text = [
        filepath for filepath in tqdm.tqdm(filtered_paths) if len(
            detector.detect(
                images=[keras_ocr.tools.read_and_fit(filepath, width=640, height=640)])[0]) > 0
    ]
    filtered_paths = np.array([path for path in filtered_paths if path not in paths_with_text])
    filtered_basenames = list(map(os.path.basename, filtered_paths))
    basename_to_url = {
        os.path.basename(urllib.parse.urlparse(url).path).lower(): url
        for url in image_urls
    }
    filtered_urls = [basename_to_url[basename.lower()] for basename in filtered_basenames]
    assert len(filtered_urls) == len(filtered_paths)
    removed_paths = [filepath for filepath in image_paths if filepath not in filtered_paths]
    for filepath in removed_paths:
        os.remove(filepath)
    with open('backgrounds/urls.txt', 'w') as f:
        f.write('\n'.join(filtered_urls))
    with zipfile.ZipFile(file='backgrounds.zip', mode='w') as zfile:
        for filepath in tqdm.tqdm(filtered_paths.tolist() + ['backgrounds/urls.txt'],
                                  desc='Saving ZIP file.'):
            zfile.write(filename=filepath, arcname=os.path.basename(filepath.lower()))

创作不易 觉得有帮助请点赞关注收藏~~~

相关文章
|
18天前
|
数据采集 TensorFlow 算法框架/工具
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
本教程详细介绍了如何使用TensorFlow 2.3训练自定义图像分类数据集,涵盖数据集收集、整理、划分及模型训练与测试全过程。提供完整代码示例及图形界面应用开发指导,适合初学者快速上手。[教程链接](https://www.bilibili.com/video/BV1rX4y1A7N8/),配套视频更易理解。
24 0
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
|
2月前
|
机器学习/深度学习 数据挖掘 TensorFlow
解锁Python数据分析新技能,TensorFlow&PyTorch双引擎驱动深度学习实战盛宴
在数据驱动时代,Python凭借简洁的语法和强大的库支持,成为数据分析与机器学习的首选语言。Pandas和NumPy是Python数据分析的基础,前者提供高效的数据处理工具,后者则支持科学计算。TensorFlow与PyTorch作为深度学习领域的两大框架,助力数据科学家构建复杂神经网络,挖掘数据深层价值。通过Python打下的坚实基础,结合TensorFlow和PyTorch的强大功能,我们能在数据科学领域探索无限可能,解决复杂问题并推动科研进步。
54 0
|
3月前
|
机器学习/深度学习 数据采集 TensorFlow
使用TensorFlow进行模型训练:一次实战探索
【8月更文挑战第22天】本文通过实战案例详解使用TensorFlow进行模型训练的过程。首先确保已安装TensorFlow,接着预处理数据,包括加载、增强及归一化。然后利用`tf.keras`构建卷积神经网络模型,并配置训练参数。最后通过回调机制训练模型,并对模型性能进行评估。此流程为机器学习项目提供了一个实用指南。
|
3月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
74 0
|
3月前
|
UED 存储 数据管理
深度解析 Uno Platform 离线状态处理技巧:从网络检测到本地存储同步,全方位提升跨平台应用在无网环境下的用户体验与数据管理策略
【8月更文挑战第31天】处理离线状态下的用户体验是现代应用开发的关键。本文通过在线笔记应用案例,介绍如何使用 Uno Platform 优雅地应对离线状态。首先,利用 `NetworkInformation` 类检测网络状态;其次,使用 SQLite 实现离线存储;然后,在网络恢复时同步数据;最后,通过 UI 反馈提升用户体验。
76 0
|
3月前
|
安全 Apache 数据安全/隐私保护
你的Wicket应用安全吗?揭秘在Apache Wicket中实现坚不可摧的安全认证策略
【8月更文挑战第31天】在当前的网络环境中,安全性是任何应用程序的关键考量。Apache Wicket 是一个强大的 Java Web 框架,提供了丰富的工具和组件,帮助开发者构建安全的 Web 应用程序。本文介绍了如何在 Wicket 中实现安全认证,
40 0
|
3月前
|
机器学习/深度学习 数据采集 TensorFlow
从零到精通:TensorFlow与卷积神经网络(CNN)助你成为图像识别高手的终极指南——深入浅出教你搭建首个猫狗分类器,附带实战代码与训练技巧揭秘
【8月更文挑战第31天】本文通过杂文形式介绍了如何利用 TensorFlow 和卷积神经网络(CNN)构建图像识别系统,详细演示了从数据准备、模型构建到训练与评估的全过程。通过具体示例代码,展示了使用 Keras API 训练猫狗分类器的步骤,旨在帮助读者掌握图像识别的核心技术。此外,还探讨了图像识别在物体检测、语义分割等领域的广泛应用前景。
20 0
|
3月前
|
机器学习/深度学习 TensorFlow 数据处理
分布式训练在TensorFlow中的全面应用指南:掌握多机多卡配置与实践技巧,让大规模数据集训练变得轻而易举,大幅提升模型训练效率与性能
【8月更文挑战第31天】本文详细介绍了如何在Tensorflow中实现多机多卡的分布式训练,涵盖环境配置、模型定义、数据处理及训练执行等关键环节。通过具体示例代码,展示了使用`MultiWorkerMirroredStrategy`进行分布式训练的过程,帮助读者更好地应对大规模数据集与复杂模型带来的挑战,提升训练效率。
59 0
|
6月前
|
机器学习/深度学习 计算机视觉
AIGC核心技术——计算机视觉(CV)预训练大模型
【1月更文挑战第13天】AIGC核心技术——计算机视觉(CV)预训练大模型
587 3
AIGC核心技术——计算机视觉(CV)预训练大模型
|
2月前
|
人工智能 测试技术 API
AI计算机视觉笔记二十 九:yolov10竹签模型,自动数竹签
本文介绍了如何在AutoDL平台上搭建YOLOv10环境并进行竹签检测与计数。首先从官网下载YOLOv10源码并创建虚拟环境,安装依赖库。接着通过官方模型测试环境是否正常工作。然后下载自定义数据集并配置`mycoco128.yaml`文件,使用`yolo detect train`命令或Python代码进行训练。最后,通过命令行或API调用测试训练结果,并展示竹签计数功能。如需转载,请注明原文出处。