Marker 源码解析(二)(4)

本文涉及的产品
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
简介: Marker 源码解析(二)

Marker 源码解析(二)(3)https://developer.aliyun.com/article/1483806

.\marker\marker\segmentation.py

# 导入所需的库
from concurrent.futures import ThreadPoolExecutor
from typing import List
from transformers import LayoutLMv3ForTokenClassification
# 导入自定义的模块
from marker.bbox import unnormalize_box
from transformers.models.layoutlmv3.image_processing_layoutlmv3 import normalize_box
import io
from PIL import Image
from transformers import LayoutLMv3Processor
import numpy as np
from marker.settings import settings
from marker.schema import Page, BlockType
import torch
from math import isclose
# 设置图像最大像素值,避免部分图像被截断
Image.MAX_IMAGE_PIXELS = None
# 从预训练模型加载 LayoutLMv3Processor
processor = LayoutLMv3Processor.from_pretrained(settings.LAYOUT_MODEL_NAME, apply_ocr=False)
# 定义需要分块的键和不需要分块的键
CHUNK_KEYS = ["input_ids", "attention_mask", "bbox", "offset_mapping"]
NO_CHUNK_KEYS = ["pixel_values"]
# 加载 LayoutLMv3ForTokenClassification 模型
def load_layout_model():
    # 从预训练模型加载 LayoutLMv3ForTokenClassification 模型
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        settings.LAYOUT_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)
    # 设置模型的标签映射
    model.config.id2label = {
        0: "Caption",
        1: "Footnote",
        2: "Formula",
        3: "List-item",
        4: "Page-footer",
        5: "Page-header",
        6: "Picture",
        7: "Section-header",
        8: "Table",
        9: "Text",
        10: "Title"
    }
    model.config.label2id = {v: k for k, v in model.config.id2label.items()}
    return model
# 检测文档块类型
def detect_document_block_types(doc, blocks: List[Page], layoutlm_model, batch_size=settings.LAYOUT_BATCH_SIZE):
    # 获取特征编码、元数据和样本长度
    encodings, metadata, sample_lengths = get_features(doc, blocks)
    # 预测块类型
    predictions = predict_block_types(encodings, layoutlm_model, batch_size)
    # 将预测结果与框匹配
    block_types = match_predictions_to_boxes(encodings, predictions, metadata, sample_lengths, layoutlm_model)
    # 断言块类型数量与块数量相等
    assert len(block_types) == len(blocks)
    return block_types
# 获取临时框
def get_provisional_boxes(pred, box, is_subword, start_idx=0):
    # 从预测结果中获取临时框
    prov_predictions = [pred_ for idx, pred_ in enumerate(pred) if not is_subword[idx]][start_idx:]
    # 从列表中筛选出不是子词的元素,并从指定索引开始切片,得到新的列表
    prov_boxes = [box_ for idx, box_ in enumerate(box) if not is_subword[idx]][start_idx:]
    # 返回处理后的预测结果和框
    return prov_predictions, prov_boxes
# 获取页面编码信息,输入参数为页面和页面块对象
def get_page_encoding(page, page_blocks: Page):
    # 如果页面块中的所有行数为0,则返回空列表
    if len(page_blocks.get_all_lines()) == 0:
        return [], []
    # 获取页面块的边界框、宽度和高度
    page_box = page_blocks.bbox
    pwidth = page_blocks.width
    pheight = page_blocks.height
    # 获取页面块的像素图,并转换为 PNG 格式
    pix = page.get_pixmap(dpi=settings.LAYOUT_DPI, annots=False, clip=page_blocks.bbox)
    png = pix.pil_tobytes(format="PNG")
    png_image = Image.open(io.BytesIO(png))
    # 如果图像太大,则缩小以适应模型
    rgb_image = png_image.convert('RGB')
    rgb_width, rgb_height = rgb_image.size
    # 确保图像大小与 PDF 页面的比例正确
    assert isclose(rgb_width / pwidth, rgb_height / pheight, abs_tol=2e-2)
    # 获取页面块中的所有行
    lines = page_blocks.get_all_lines()
    boxes = []
    text = []
    for line in lines:
        box = line.bbox
        # 处理边界框溢出的情况
        if box[0] < page_box[0]:
            box[0] = page_box[0]
        if box[1] < page_box[1]:
            box[1] = page_box[1]
        if box[2] > page_box[2]:
            box[2] = page_box[2]
        if box[3] > page_box[3]:
            box[3] = page_box[3]
        # 处理边界框宽度或高度为0或负值的情况
        if box[2] <= box[0]:
            print("Zero width box found, cannot convert properly")
            raise ValueError
        if box[3] <= box[1]:
            print("Zero height box found, cannot convert properly")
            raise ValueError
        boxes.append(box)
        text.append(line.prelim_text)
    # 将边界框归一化为模型(缩放为1000x1000)
    boxes = [normalize_box(box, pwidth, pheight) for box in boxes]
    for box in boxes:
        # 验证所有边界框都是有效的
        assert(len(box) == 4)
        assert(max(box)) <= 1000
        assert(min(box)) >= 0
    # 使用 processor 处理 RGB 图像,传入文本、框、返回偏移映射等参数
    encoding = processor(
        rgb_image,
        text=text,
        boxes=boxes,
        return_offsets_mapping=True,
        truncation=True,
        return_tensors="pt",
        stride=settings.LAYOUT_CHUNK_OVERLAP,
        padding="max_length",
        max_length=settings.LAYOUT_MODEL_MAX,
        return_overflowing_tokens=True
    )
    # 从 encoding 中弹出 offset_mapping 和 overflow_to_sample_mapping
    offset_mapping = encoding.pop('offset_mapping')
    overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
    # 将 encoding 中的 bbox、input_ids、attention_mask、pixel_values 转换为列表
    bbox = list(encoding["bbox"])
    input_ids = list(encoding["input_ids"])
    attention_mask = list(encoding["attention_mask"])
    pixel_values = list(encoding["pixel_values"])
    # 断言各列表长度相等
    assert len(bbox) == len(input_ids) == len(attention_mask) == len(pixel_values) == len(offset_mapping)
    # 将各列表中的元素组成字典,放入 list_encoding 列表中
    list_encoding = []
    for i in range(len(bbox)):
        list_encoding.append({
            "bbox": bbox[i],
            "input_ids": input_ids[i],
            "attention_mask": attention_mask[i],
            "pixel_values": pixel_values[i],
            "offset_mapping": offset_mapping[i]
        })
    # 其他数据包括原始框、pwidth 和 pheight
    other_data = {
        "original_bbox": boxes,
        "pwidth": pwidth,
        "pheight": pheight,
    }
    # 返回 list_encoding 和 other_data
    return list_encoding, other_data
# 获取文档的特征信息
def get_features(doc, blocks):
    # 初始化编码、元数据和样本长度列表
    encodings = []
    metadata = []
    sample_lengths = []
    # 遍历每个块
    for i in range(len(blocks)):
        # 调用函数获取页面编码和其他数据
        encoding, other_data = get_page_encoding(doc[i], blocks[i])
        # 将页面编码添加到编码列表中
        encodings.extend(encoding)
        # 将其他数据添加到元数据列表中
        metadata.append(other_data)
        # 记录当前页面编码的长度
        sample_lengths.append(len(encoding))
    # 返回编码、元数据和样本长度
    return encodings, metadata, sample_lengths
# 预测块类型
def predict_block_types(encodings, layoutlm_model, batch_size):
    # 初始化所有预测结果列表
    all_predictions = []
    # 按批次处理编码
    for i in range(0, len(encodings), batch_size):
        # 计算当前批次的起始和结束索引
        batch_start = i
        batch_end = min(i + batch_size, len(encodings))
        # 获取当前批次的编码
        batch = encodings[batch_start:batch_end]
        # 构建模型输入
        model_in = {}
        for k in ["bbox", "input_ids", "attention_mask", "pixel_values"]:
            model_in[k] = torch.stack([b[k] for b in batch]).to(layoutlm_model.device)
        model_in["pixel_values"] = model_in["pixel_values"].to(layoutlm_model.dtype)
        # 进入推理模式
        with torch.inference_mode():
            # 使用模型进行推理
            outputs = layoutlm_model(**model_in)
            logits = outputs.logits
        # 获取预测结果
        predictions = logits.argmax(-1).squeeze().tolist()
        if len(predictions) == settings.LAYOUT_MODEL_MAX:
            predictions = [predictions]
        # 将预测结果添加到所有预测结果列表中
        all_predictions.extend(predictions)
    # 返回所有预测结果
    return all_predictions
# 将预测结果与框匹配
def match_predictions_to_boxes(encodings, predictions, metadata, sample_lengths, layoutlm_model) -> List[List[BlockType]]:
    # 断言编码、预测结果和样本长度的长度相等
    assert len(encodings) == len(predictions) == sum(sample_lengths)
    assert len(metadata) == len(sample_lengths)
    # 初始化页面起始索引和页面块类型列表
    page_start = 0
    page_block_types = []
    # 返回页面块类型列表
    return page_block_types

.\marker\marker\settings.py

import os
from typing import Optional, List, Dict
from dotenv import find_dotenv
from pydantic import computed_field
from pydantic_settings import BaseSettings
import fitz as pymupdf
import torch
# 定义一个设置类,继承自BaseSettings
class Settings(BaseSettings):
    # General
    TORCH_DEVICE: Optional[str] = None
    # 计算属性,返回TORCH_DEVICE_MODEL
    @computed_field
    @property
    def TORCH_DEVICE_MODEL(self) -> str:
        # 如果TORCH_DEVICE不为None,则返回TORCH_DEVICE
        if self.TORCH_DEVICE is not None:
            return self.TORCH_DEVICE
        # 如果CUDA可用,则返回"cuda"
        if torch.cuda.is_available():
            return "cuda"
        # 如果MPS可用,则返回"mps"
        if torch.backends.mps.is_available():
            return "mps"
        # 否则返回"cpu"
        return "cpu"
    INFERENCE_RAM: int = 40 # 每个GPU的VRAM量(以GB为单位)。
    VRAM_PER_TASK: float = 2.5 # 每个任务分配的VRAM量(以GB为单位)。 峰值标记VRAM使用量约为3GB,但工作程序的平均值较低。
    DEFAULT_LANG: str = "English" # 我们假设文件所在的默认语言,应该是TESSERACT_LANGUAGES中的一个键
    SUPPORTED_FILETYPES: Dict = {
        "application/pdf": "pdf",
        "application/epub+zip": "epub",
        "application/x-mobipocket-ebook": "mobi",
        "application/vnd.ms-xpsdocument": "xps",
        "application/x-fictionbook+xml": "fb2"
    }
    # PyMuPDF
    TEXT_FLAGS: int = pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES
    # OCR
    INVALID_CHARS: List[str] = [chr(0xfffd), "�"]
    OCR_DPI: int = 400
    TESSDATA_PREFIX: str = ""
    TESSERACT_LANGUAGES: Dict = {
        "English": "eng",
        "Spanish": "spa",
        "Portuguese": "por",
        "French": "fra",
        "German": "deu",
        "Russian": "rus",
        "Chinese": "chi_sim",
        "Japanese": "jpn",
        "Korean": "kor",
        "Hindi": "hin",
    }
    TESSERACT_TIMEOUT: int = 20 # 何时放弃OCR
    # 定义拼写检查语言对应的字典
    SPELLCHECK_LANGUAGES: Dict = {
        "English": "en",
        "Spanish": "es",
        "Portuguese": "pt",
        "French": "fr",
        "German": "de",
        "Russian": "ru",
        "Chinese": None,
        "Japanese": None,
        "Korean": None,
        "Hindi": None,
    }
    # 是否在每一页运行 OCR,即使可以提取文本
    OCR_ALL_PAGES: bool = False
    # 用于 OCR 的并行 CPU 工作线程数
    OCR_PARALLEL_WORKERS: int = 2
    # 使用的 OCR 引擎,可以是 "tesseract" 或 "ocrmypdf",ocrmypdf 质量更高但速度较慢
    OCR_ENGINE: str = "ocrmypdf"
    # Texify 模型相关参数
    TEXIFY_MODEL_MAX: int = 384 # Texify 的最大推理长度
    TEXIFY_TOKEN_BUFFER: int = 256 # Texify 的 token 缓冲区大小
    TEXIFY_DPI: int = 96 # 渲染图像的 DPI
    TEXIFY_BATCH_SIZE: int = 2 if TORCH_DEVICE_MODEL == "cpu" else 6 # Texify 的批处理大小,CPU 上较低因为使用 float32
    TEXIFY_MODEL_NAME: str = "vikp/texify"
    # Layout 模型相关参数
    BAD_SPAN_TYPES: List[str] = ["Caption", "Footnote", "Page-footer", "Page-header", "Picture"]
    LAYOUT_MODEL_MAX: int = 512
    LAYOUT_CHUNK_OVERLAP: int = 64
    LAYOUT_DPI: int = 96
    LAYOUT_MODEL_NAME: str = "vikp/layout_segmenter"
    LAYOUT_BATCH_SIZE: int = 8 # 最大 512 个 token 意味着较高的批处理大小
    # Ordering 模型相关参数
    ORDERER_BATCH_SIZE: int = 32 # 可以较高,因为最大 token 数为 128
    ORDERER_MODEL_NAME: str = "vikp/column_detector"
    # 最终编辑模型相关参数
    EDITOR_BATCH_SIZE: int = 4
    EDITOR_MAX_LENGTH: int = 1024
    EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor_t5"
    ENABLE_EDITOR_MODEL: bool = False # 编辑模型可能会产生误报
    EDITOR_CUTOFF_THRESH: float = 0.9 # 忽略概率低于此阈值的预测
    # Ray 相关参数
    RAY_CACHE_PATH: Optional[str] = None # 保存 Ray 缓存的路径
    RAY_CORES_PER_WORKER: int = 1 # 每个 worker 分配的 CPU 核心数
    # 调试相关参数
    DEBUG: bool = False # 启用调试日志
    # 调试数据文件夹路径,默认为 None
    DEBUG_DATA_FOLDER: Optional[str] = None
    # 调试级别,范围从 0 到 2,2 表示记录所有信息
    DEBUG_LEVEL: int = 0
    
    # 计算属性,返回是否使用 CUDA
    @computed_field
    @property
    def CUDA(self) -> bool:
        return "cuda" in self.TORCH_DEVICE
    
    # 计算属性,返回模型数据类型
    @computed_field
    @property
    def MODEL_DTYPE(self) -> torch.dtype:
        if self.TORCH_DEVICE_MODEL == "cuda":
            return torch.bfloat16
        else:
            return torch.float32
    
    # 计算属性,返回用于转换的数据类型
    @computed_field
    @property
    def TEXIFY_DTYPE(self) -> torch.dtype:
        return torch.float32 if self.TORCH_DEVICE_MODEL == "cpu" else torch.float16
    
    # 类配置
    class Config:
        # 从环境文件中查找 local.env 文件
        env_file = find_dotenv("local.env")
        # 额外配置,忽略错误
        extra = "ignore"
# 创建一个 Settings 对象实例
settings = Settings()

.\marker\scripts\verify_benchmark_scores.py

# 导入 json 模块和 argparse 模块
import json
import argparse
# 验证分数的函数,接收一个文件路径作为参数
def verify_scores(file_path):
    # 打开文件并加载 JSON 数据
    with open(file_path, 'r') as file:
        data = json.load(file)
    # 获取 multicolcnn.pdf 文件的分数
    multicolcnn_score = data["marker"]["files"]["multicolcnn.pdf"]["score"]
    # 获取 switch_trans.pdf 文件的分数
    switch_trans_score = data["marker"]["files"]["switch_trans.pdf"]["score"]
    # 如果其中一个分数小于等于 0.4,则抛出 ValueError 异常
    if multicolcnn_score <= 0.4 or switch_trans_score <= 0.4:
        raise ValueError("One or more scores are below the required threshold of 0.4")
# 如果当前脚本被直接执行
if __name__ == "__main__":
    # 创建 ArgumentParser 对象,设置描述信息
    parser = argparse.ArgumentParser(description="Verify benchmark scores")
    # 添加一个参数,指定文件路径,类型为字符串
    parser.add_argument("file_path", type=str, help="Path to the json file")
    # 解析命令行参数
    args = parser.parse_args()
    # 调用 verify_scores 函数,传入文件路径参数
    verify_scores(args.file_path)


相关文章
|
2月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
102 2
|
19天前
|
存储 设计模式 算法
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
行为型模式用于描述程序在运行时复杂的流程控制,即描述多个类或对象之间怎样相互协作共同完成单个对象都无法单独完成的任务,它涉及算法与对象间职责的分配。行为型模式分为类行为模式和对象行为模式,前者采用继承机制来在类间分派行为,后者采用组合或聚合在对象间分配行为。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象行为模式比类行为模式具有更大的灵活性。 行为型模式分为: • 模板方法模式 • 策略模式 • 命令模式 • 职责链模式 • 状态模式 • 观察者模式 • 中介者模式 • 迭代器模式 • 访问者模式 • 备忘录模式 • 解释器模式
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
|
19天前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
结构型模式描述如何将类或对象按某种布局组成更大的结构。它分为类结构型模式和对象结构型模式,前者采用继承机制来组织接口和类,后者釆用组合或聚合来组合对象。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象结构型模式比类结构型模式具有更大的灵活性。 结构型模式分为以下 7 种: • 代理模式 • 适配器模式 • 装饰者模式 • 桥接模式 • 外观模式 • 组合模式 • 享元模式
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
19天前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
创建型模式的主要关注点是“怎样创建对象?”,它的主要特点是"将对象的创建与使用分离”。这样可以降低系统的耦合度,使用者不需要关注对象的创建细节。创建型模式分为5种:单例模式、工厂方法模式抽象工厂式、原型模式、建造者模式。
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
2月前
|
缓存 监控 Java
Java线程池提交任务流程底层源码与源码解析
【11月更文挑战第30天】嘿,各位技术爱好者们,今天咱们来聊聊Java线程池提交任务的底层源码与源码解析。作为一个资深的Java开发者,我相信你一定对线程池并不陌生。线程池作为并发编程中的一大利器,其重要性不言而喻。今天,我将以对话的方式,带你一步步深入线程池的奥秘,从概述到功能点,再到背景和业务点,最后到底层原理和示例,让你对线程池有一个全新的认识。
61 12
|
1月前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
20天前
|
安全 搜索推荐 数据挖掘
陪玩系统源码开发流程解析,成品陪玩系统源码的优点
我们自主开发的多客陪玩系统源码,整合了市面上主流陪玩APP功能,支持二次开发。该系统适用于线上游戏陪玩、语音视频聊天、心理咨询等场景,提供用户注册管理、陪玩者资料库、预约匹配、实时通讯、支付结算、安全隐私保护、客户服务及数据分析等功能,打造综合性社交平台。随着互联网技术发展,陪玩系统正成为游戏爱好者的新宠,改变游戏体验并带来新的商业模式。
|
2月前
|
存储 安全 Linux
Golang的GMP调度模型与源码解析
【11月更文挑战第11天】GMP 调度模型是 Go 语言运行时系统的核心部分,用于高效管理和调度大量协程(goroutine)。它通过少量的操作系统线程(M)和逻辑处理器(P)来调度大量的轻量级协程(G),从而实现高性能的并发处理。GMP 模型通过本地队列和全局队列来减少锁竞争,提高调度效率。在 Go 源码中,`runtime.h` 文件定义了关键数据结构,`schedule()` 和 `findrunnable()` 函数实现了核心调度逻辑。通过深入研究 GMP 模型,可以更好地理解 Go 语言的并发机制。
|
2月前
|
消息中间件 缓存 安全
Future与FutureTask源码解析,接口阻塞问题及解决方案
【11月更文挑战第5天】在Java开发中,多线程编程是提高系统并发性能和资源利用率的重要手段。然而,多线程编程也带来了诸如线程安全、死锁、接口阻塞等一系列复杂问题。本文将深度剖析多线程优化技巧、Future与FutureTask的源码、接口阻塞问题及解决方案,并通过具体业务场景和Java代码示例进行实战演示。
67 3
|
3月前
|
存储
让星星⭐月亮告诉你,HashMap的put方法源码解析及其中两种会触发扩容的场景(足够详尽,有问题欢迎指正~)
`HashMap`的`put`方法通过调用`putVal`实现,主要涉及两个场景下的扩容操作:1. 初始化时,链表数组的初始容量设为16,阈值设为12;2. 当存储的元素个数超过阈值时,链表数组的容量和阈值均翻倍。`putVal`方法处理键值对的插入,包括链表和红黑树的转换,确保高效的数据存取。
70 5

推荐镜像

更多