【Keras + Computer Vision + TensorFlow】OCR Text Recognition in Practice (with Source Code and Dataset, Super Detailed)


For the source code and dataset, please like, follow, and bookmark, then leave a comment or send a private message~~~

I. Introduction to OCR Text Recognition

Automatically recognizing characters with a computer is an important field of applied pattern recognition. In production and daily life, people process large volumes of text, reports, and documents. To reduce this labor and improve efficiency, research into character recognition methods began in the 1950s and led to the development of optical character readers.

OCR (Optical Character Recognition) is an important branch of artificial intelligence: it gives computers the human eye's ability to read text from images. An image text recognition system generally consists of four stages: image acquisition, text detection, text recognition, and result output.
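
As a quick illustration of how these four stages map onto code, here is a minimal sketch using the keras-ocr package (test.jpg is a placeholder for your own image):

import keras_ocr

pipeline = keras_ocr.pipeline.Pipeline()        # loads a pretrained CRAFT detector and CRNN recognizer
images = [keras_ocr.tools.read('test.jpg')]     # 1. image acquisition
prediction_groups = pipeline.recognize(images)  # 2./3. text detection + recognition
for word, box in prediction_groups[0]:          # 4. result output: (word, box) pairs
    print(word, box)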

II. OCR Text Recognition Project in Practice

1: Dataset Overview

MSRA-TD500 contains 500 natural-scene images with resolutions ranging from 1296 × 864 to 1920 × 1280. They cover most everyday scenes, including indoor malls, signboards, outdoor streets, and billboards. The text is in both Chinese and English, with varied fonts, sizes, and orientations. Some sample images from the dataset are shown below.

The dataset is organized into a training set and a test set, as shown in the structure below.
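
Each image in MSRA-TD500 is paired with a .gt annotation file. As a hedged sketch (assuming the dataset's standard seven-value line format: index, difficulty flag, x, y, width, height, and rotation angle in radians), the annotations can be parsed like this:

def load_msra_gt(gt_path):
    """Parse one MSRA-TD500 .gt file into a list of rotated-box dicts."""
    boxes = []
    with open(gt_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) != 7:
                continue  # skip blank or malformed lines
            _index, difficult, x, y, w, h, theta = parts
            boxes.append({
                'difficult': difficult == '1',  # 1 marks hard-to-read text lines
                'x': float(x), 'y': float(y),   # top-left corner
                'w': float(w), 'h': float(h),
                'theta': float(theta),          # rotation in radians
            })
    return boxes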

2: Project Structure

The overall project structure is shown below: the upper part holds the definitions of the algorithms and models, such as CRAFT and CRNN, and the lower part holds the test code.

CRAFT detects text lines as shown in the figure below. The full text region is fed into the CRAFT text detection network, which outputs a character-level text score heatmap (Text Score) and a character-level link score heatmap (Link Score); the position of each text line is then obtained from the connected components.
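
To see the detection stage in isolation, the CRAFT detector that ships with keras-ocr can be run on its own (a sketch; test.jpg is again a placeholder):

import keras_ocr

detector = keras_ocr.detection.Detector()  # pretrained CRAFT weights
image = keras_ocr.tools.read_and_fit('test.jpg', width=640, height=640)
boxes = detector.detect(images=[image])[0]  # one (4, 2) corner array per detected text line
print(f'Found {len(boxes)} text lines.')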

3: Results

Run the code:

The results are shown below; you can test with different images of your own.
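
To try your own images, a small sketch (assuming matplotlib is installed) draws the recognized words onto the input:

import matplotlib.pyplot as plt
import keras_ocr

pipeline = keras_ocr.pipeline.Pipeline()
image = keras_ocr.tools.read('test.jpg')  # swap in any test image
predictions = pipeline.recognize([image])[0]
fig, ax = plt.subplots(figsize=(10, 10))
keras_ocr.tools.drawAnnotations(image=image, predictions=predictions, ax=ax)
plt.show()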

III. Code

Part of the code is shown below. The first script trains the recognizer on the SynthText90 (MjSynth) dataset across multiple GPUs. For the full code and dataset, please like, follow, and bookmark, then leave a comment or send a private message~~~

 

"""This script demonstrates how to train the model
on the SynthText90 using multiple GPUs."""
# pylint: disable=invalid-name
import datetime
import argparse
import math
import random
import string
import functools
import os
import tarfile
import urllib.request
import numpy as np
import cv2
import imgaug
import tqdm
import tensorflow as tf
import keras_ocr
# pylint: disable=redefined-outer-name
def get_filepaths(data_path, split):
    """Get the list of filepaths for a given split (train, val, or test)."""
    with open(os.path.join(data_path, f'mnt/ramdisk/max/90kDICT32px/annotation_{split}.txt'),
              'r') as text_file:
        filepaths = [
            os.path.join(data_path, 'mnt/ramdisk/max/90kDICT32px',
                         line.split(' ')[0][2:]) for line in text_file.readlines()
        ]
    return filepaths
# pylint: disable=redefined-outer-name
def download_extract_and_process_dataset(data_path):
    """Download and extract the synthtext90 dataset."""
    archive_filepath = os.path.join(data_path, 'mjsynth.tar.gz')
    extraction_directory = os.path.join(data_path, 'mnt')
    if not os.path.isfile(archive_filepath) and not os.path.isdir(extraction_directory):
        print('Downloading the dataset.')
        urllib.request.urlretrieve("https://www.robots.ox.ac.uk/~vgg/data/text/mjsynth.tar.gz",
                                   archive_filepath)
    if not os.path.isdir(extraction_directory):
        print('Extracting files.')
        with tarfile.open(os.path.join(data_path, 'mjsynth.tar.gz')) as tfile:
            tfile.extractall(data_path)
def get_image_generator(filepaths, augmenter, width, height):
    """Get an image generator for a list of SynthText90 filepaths."""
    filepaths = filepaths.copy()
    while True:
        for filepath in filepaths:
            # The ground-truth word is embedded in the filename,
            # e.g. 77_heretofore_35866.jpg -> "heretofore".
            text = filepath.split(os.sep)[-1].split('_')[1].lower()
            image = cv2.imread(filepath)
            if image is None:
                print(f'An error occurred reading: {filepath}')
                continue  # skip unreadable files instead of crashing on cvtColor
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = keras_ocr.tools.fit(image,
                                        width=width,
                                        height=height,
                                        cval=np.random.randint(low=0, high=255,
                                                               size=3).astype('uint8'))
            if augmenter is not None:
                image = augmenter.augment_image(image)
            yield image, text
        # Reshuffle between passes. itertools.cycle would cache the first
        # pass and ignore the shuffle, so we loop over the list manually.
        random.shuffle(filepaths)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train an OCR recognizer on SynthText90.')
    parser.add_argument('--model_id',
                        default='recognizer',
                        help='The name to use for saving model checkpoints.')
    parser.add_argument(
        '--data_path',
        default='.',
        help='The path to the directory containing the dataset and where we will put our logs.')
    parser.add_argument(
        '--logs_path',
        default='./logs',
        help=(
            'The path to where logs and checkpoints should be stored. '
            'If a checkpoint matching "model_id" is found, training will resume from that point.'))
    parser.add_argument('--batch_size', default=16, type=int, help='The training batch size to use.')
    parser.add_argument('--no-file-verification', dest='verify_files', action='store_false')
    parser.set_defaults(verify_files=True)
    args = parser.parse_args()
    weights_path = os.path.join(args.logs_path, args.model_id + '.h5')
    csv_path = os.path.join(args.logs_path, args.model_id + '.csv')
    download_extract_and_process_dataset(args.data_path)
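    # Build the recognizer inside a MirroredStrategy scope so the model is
    # replicated across all visible GPUs for synchronous multi-GPU training.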
    with tf.distribute.MirroredStrategy().scope():
        recognizer = keras_ocr.recognition.Recognizer(alphabet=string.digits +
                                                      string.ascii_lowercase,
                                                      height=31,
                                                      width=200,
                                                      stn=False,
                                                      optimizer=tf.keras.optimizers.RMSprop(),
                                                      weights=None)
    if os.path.isfile(weights_path):
        print('Loading saved weights and creating new version.')
        # Load the existing checkpoint before switching to timestamped
        # paths, so the previous run's files are not overwritten.
        recognizer.model.load_weights(weights_path)
        dt_string = datetime.datetime.now().isoformat()
        weights_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.h5')
        csv_path = os.path.join(args.logs_path, args.model_id + '_' + dt_string + '.csv')
    augmenter = imgaug.augmenters.Sequential([
        imgaug.augmenters.Multiply((0.9, 1.1)),
        imgaug.augmenters.GammaContrast(gamma=(0.5, 3.0)),
        imgaug.augmenters.Invert(0.25, per_channel=0.5)
    ])
    os.makedirs(args.logs_path, exist_ok=True)
    training_filepaths, validation_filepaths = [
        get_filepaths(data_path=args.data_path, split=split) for split in ['train', 'val']
    ]
    if args.verify_files:
        assert all(
            os.path.isfile(filepath) for
            filepath in tqdm.tqdm(training_filepaths + validation_filepaths,
                                  desc='Checking filepaths.')), 'Some files appear to be missing.'
    (training_image_generator, training_steps), (validation_image_generator, validation_steps) = [
        (get_image_generator(
            filepaths=filepaths,
            augmenter=augmenter,
            width=recognizer.model.input_shape[2],
            height=recognizer.model.input_shape[1],
        ), math.ceil(len(filepaths) / args.batch_size))
        for filepaths, augmenter in [(training_filepaths, augmenter), (validation_filepaths, None)]
    ]
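    # Wrap the Python generators in tf.data Datasets. Each batch is
    # ((images, labels, input_length, label_length), ctc_loss_target),
    # the input structure expected by the CTC training model.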
    training_generator, validation_generator = [
        tf.data.Dataset.from_generator(
            functools.partial(recognizer.get_batch_generator,
                              image_generator=image_generator,
                              batch_size=args.batch_size),
            output_types=((tf.float32, tf.int64, tf.float64, tf.int64), tf.float64),
            output_shapes=((tf.TensorShape([None, 31, 200, 1]),
                            tf.TensorShape([None, recognizer.training_model.input_shape[1][1]]),
                            tf.TensorShape([None, 1]),
                            tf.TensorShape([None, 1])),
                           tf.TensorShape([None, 1])))
        for image_generator in [training_image_generator, validation_image_generator]
    ]
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         min_delta=0,
                                         patience=10,
                                         restore_best_weights=False),
        tf.keras.callbacks.ModelCheckpoint(weights_path, monitor='val_loss', save_best_only=True),
        tf.keras.callbacks.CSVLogger(csv_path)
    ]
    recognizer.training_model.fit(
        x=training_generator,
        steps_per_epoch=training_steps,
        validation_steps=validation_steps,
        validation_data=validation_generator,
        callbacks=callbacks,
        epochs=1000,
    )
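
The second script below was used to generate the fonts.zip and backgrounds.zip assets for synthetic training data: it filters the Google Fonts repository down to legible typefaces and downloads text-free background images from Wikimedia Commons.
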
"""This script is what was used to generate the
backgrounds.zip and fonts.zip files.
"""
# pylint: disable=invalid-name,redefined-outer-name
import json
import urllib.request
import urllib.parse
import concurrent.futures
import shutil
import zipfile
import glob
import os
import numpy as np
import tqdm
import cv2
import keras_ocr
if __name__ == '__main__':
    fonts_commit = 'a0726002eab4639ee96056a38cd35f6188011a81'
    fonts_sha256 = 'e447d23d24a5bbe8488200a058cd5b75b2acde525421c2e74dbfb90ceafce7bf'
    fonts_source_zip_filepath = keras_ocr.tools.download_and_verify(
        url=f'https://github.com/google/fonts/archive/{fonts_commit}.zip',
        cache_dir='.',
        sha256=fonts_sha256)
    shutil.rmtree('fonts-raw', ignore_errors=True)
    with zipfile.ZipFile(fonts_source_zip_filepath) as zfile:
        zfile.extractall(path='fonts-raw')
    retained_fonts = []
    sha256s = []
    basenames = []
    # The blacklist includes fonts that, at least for the English alphabet, were found
    # to be illegible (e.g., thin fonts) or render in unexpected ways (e.g., mathematics
    # fonts).
    blacklist = [
        'AlmendraDisplay-Regular.ttf', 'RedactedScript-Bold.ttf', 'RedactedScript-Regular.ttf',
        'Sevillana-Regular.ttf', 'Mplus1p-Thin.ttf', 'Stalemate-Regular.ttf', 'jsMath-cmsy10.ttf',
        'Codystar-Regular.ttf', 'AdventPro-Thin.ttf', 'RoundedMplus1c-Thin.ttf',
        'EncodeSans-Thin.ttf', 'AlegreyaSans-ThinItalic.ttf', 'AlegreyaSans-Thin.ttf',
        'FiraSans-Thin.ttf', 'FiraSans-ThinItalic.ttf', 'WorkSans-Thin.ttf',
        'Tomorrow-ThinItalic.ttf', 'Tomorrow-Thin.ttf', 'Italianno-Regular.ttf',
        'IBMPlexSansCondensed-Thin.ttf', 'IBMPlexSansCondensed-ThinItalic.ttf',
        'Lato-ExtraLightItalic.ttf', 'LibreBarcode128Text-Regular.ttf',
        'LibreBarcode39-Regular.ttf', 'LibreBarcode39ExtendedText-Regular.ttf',
        'EncodeSansExpanded-ExtraLight.ttf', 'Exo-Thin.ttf', 'Exo-ThinItalic.ttf',
        'DrSugiyama-Regular.ttf', 'Taviraj-ThinItalic.ttf', 'SixCaps.ttf', 'IBMPlexSans-Thin.ttf',
        'IBMPlexSans-ThinItalic.ttf', 'AdobeBlank-Regular.ttf',
        'FiraSansExtraCondensed-ThinItalic.ttf', 'HeptaSlab[wght].ttf', 'Karla-Italic[wght].ttf',
        'Karla[wght].ttf', 'RalewayDots-Regular.ttf', 'FiraSansCondensed-ThinItalic.ttf',
        'jsMath-cmex10.ttf', 'LibreBarcode39Text-Regular.ttf', 'LibreBarcode39Extended-Regular.ttf',
        'EricaOne-Regular.ttf', 'ArimaMadurai-Thin.ttf', 'IBMPlexSerif-ExtraLight.ttf',
        'IBMPlexSerif-ExtraLightItalic.ttf', 'IBMPlexSerif-ThinItalic.ttf', 'IBMPlexSerif-Thin.ttf',
        'Exo2-Thin.ttf', 'Exo2-ThinItalic.ttf', 'BungeeOutline-Regular.ttf', 'Redacted-Regular.ttf',
        'JosefinSlab-ThinItalic.ttf', 'GothicA1-Thin.ttf', 'Kanit-ThinItalic.ttf', 'Kanit-Thin.ttf',
        'AlegreyaSansSC-ThinItalic.ttf', 'AlegreyaSansSC-Thin.ttf', 'Chathura-Thin.ttf',
        'Blinker-Thin.ttf', 'Italiana-Regular.ttf', 'Miama-Regular.ttf', 'Grenze-ThinItalic.ttf',
        'LeagueScript-Regular.ttf', 'BigShouldersDisplay-Thin.ttf', 'YanoneKaffeesatz[wght].ttf',
        'BungeeHairline-Regular.ttf', 'JosefinSans-Thin.ttf', 'JosefinSans-ThinItalic.ttf',
        'Monofett.ttf', 'Raleway-ThinItalic.ttf', 'Raleway-Thin.ttf', 'JosefinSansStd-Light.ttf',
        'LibreBarcode128-Regular.ttf'
    ]
    for filepath in tqdm.tqdm(sorted(glob.glob('fonts-raw/**/**/**/*.ttf')),
                              desc='Filtering fonts.'):
        sha256 = keras_ocr.tools.sha256sum(filepath)
        basename = os.path.basename(filepath)
        # We check the sha256 and filenames because some of the fonts
        # in the repository are duplicated (see TRIVIA.md).
        if sha256 in sha256s or basename in basenames or basename in blacklist:
            continue
        sha256s.append(sha256)
        basenames.append(basename)
        retained_fonts.append(filepath)
    retained_font_families = set([filepath.split(os.sep)[-2] for filepath in retained_fonts])
    added = []
    with zipfile.ZipFile(file='fonts.zip', mode='w') as zfile:
        for font_family in tqdm.tqdm(retained_font_families, desc='Saving ZIP file.'):
            # We want to keep all the metadata files plus
            # the retained font files. And we don't want
            # to add the same file twice.
            files = [
                input_filepath for input_filepath in glob.glob(f'fonts-raw/**/**/{font_family}/*')
                if input_filepath not in added and
                (input_filepath in retained_fonts or os.path.splitext(input_filepath)[1] != '.ttf')
            ]
            added.extend(files)
            for input_filepath in files:
                zfile.write(filename=input_filepath,
                            arcname=os.path.join(*input_filepath.split(os.sep)[-2:]))
    print('Finished saving fonts file.')
    # pylint: disable=line-too-long
    url = (
        'https://commons.wikimedia.org/w/api.php?action=query&generator=categorymembers&gcmtype=file&format=json'
        '&gcmtitle=Category:Featured_pictures_on_Wikimedia_Commons&prop=imageinfo&gcmlimit=50&iiprop=url&iiurlwidth=1024'
    )
    gcmcontinue = None
    max_responses = 300
    responses = []
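    # Page through the Wikimedia category-members API; each response
    # carries a 'gcmcontinue' token until the listing is exhausted.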
    for _ in tqdm.tqdm(range(max_responses)):
        current_url = url
        if gcmcontinue is not None:
            current_url += f'&continue=gcmcontinue||&gcmcontinue={gcmcontinue}'
        with urllib.request.urlopen(url=current_url) as response:
            current = json.loads(response.read())
            responses.append(current)
            gcmcontinue = None if 'continue' not in current else current['continue']['gcmcontinue']
        if gcmcontinue is None:
            break
    print('Finished getting list of images.')
    # We want to avoid animated images as well as icon files.
    image_urls = []
    for response in responses:
        image_urls.extend(
            [page['imageinfo'][0]['thumburl'] for page in response['query']['pages'].values()])
    image_urls = [url for url in image_urls if url.lower().endswith('.jpg')]
    shutil.rmtree('backgrounds', ignore_errors=True)
    os.makedirs('backgrounds')
    assert len(image_urls) == len(set(image_urls)), 'Duplicates found!'
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        futures = [
            executor.submit(keras_ocr.tools.download_and_verify,
                            url=url,
                            cache_dir='./backgrounds',
                            verbose=False) for url in image_urls
        ]
        for _ in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            pass
    for filepath in glob.glob('backgrounds/*.JPG'):
        os.rename(filepath, filepath.lower())
    print('Filtering images by aspect ratio and maximum contiguous contour.')
    image_paths = np.array(sorted(glob.glob('backgrounds/*.jpg')))
    def compute_metrics(filepath):
        image = keras_ocr.tools.read(filepath)
        aspect_ratio = image.shape[0] / image.shape[1]
        contour, _ = keras_ocr.tools.get_maximum_uniform_contour(image, fontsize=40)
        area = cv2.contourArea(contour) if contour is not None else 0
        return aspect_ratio, area
    metrics = np.array([compute_metrics(filepath) for filepath in tqdm.tqdm(image_paths)])
    filtered_paths = image_paths[(metrics[:, 0] < 3 / 2) & (metrics[:, 0] > 2 / 3) &
                                 (metrics[:, 1] > 1e6)]
    detector = keras_ocr.detection.Detector()
    paths_with_text = [
        filepath for filepath in tqdm.tqdm(filtered_paths) if len(
            detector.detect(
                images=[keras_ocr.tools.read_and_fit(filepath, width=640, height=640)])[0]) > 0
    ]
    filtered_paths = np.array([path for path in filtered_paths if path not in paths_with_text])
    filtered_basenames = list(map(os.path.basename, filtered_paths))
    basename_to_url = {
        os.path.basename(urllib.parse.urlparse(url).path).lower(): url
        for url in image_urls
    }
    filtered_urls = [basename_to_url[basename.lower()] for basename in filtered_basenames]
    assert len(filtered_urls) == len(filtered_paths)
    removed_paths = [filepath for filepath in image_paths if filepath not in filtered_paths]
    for filepath in removed_paths:
        os.remove(filepath)
    with open('backgrounds/urls.txt', 'w') as f:
        f.write('\n'.join(filtered_urls))
    with zipfile.ZipFile(file='backgrounds.zip', mode='w') as zfile:
        for filepath in tqdm.tqdm(filtered_paths.tolist() + ['backgrounds/urls.txt'],
                                  desc='Saving ZIP file.'):
            zfile.write(filename=filepath, arcname=os.path.basename(filepath.lower()))

Writing this was not easy. If you found it helpful, please like, follow, and bookmark~~~
