Marker 源码解析(二)(4)

简介: Marker 源码解析(二)

Marker 源码解析(二)(3)https://developer.aliyun.com/article/1483806

.\marker\marker\segmentation.py

# 导入所需的库
from concurrent.futures import ThreadPoolExecutor
from typing import List
from transformers import LayoutLMv3ForTokenClassification
# 导入自定义的模块
from marker.bbox import unnormalize_box
from transformers.models.layoutlmv3.image_processing_layoutlmv3 import normalize_box
import io
from PIL import Image
from transformers import LayoutLMv3Processor
import numpy as np
from marker.settings import settings
from marker.schema import Page, BlockType
import torch
from math import isclose
# 设置图像最大像素值,避免部分图像被截断
Image.MAX_IMAGE_PIXELS = None
# 从预训练模型加载 LayoutLMv3Processor
processor = LayoutLMv3Processor.from_pretrained(settings.LAYOUT_MODEL_NAME, apply_ocr=False)
# 定义需要分块的键和不需要分块的键
CHUNK_KEYS = ["input_ids", "attention_mask", "bbox", "offset_mapping"]
NO_CHUNK_KEYS = ["pixel_values"]
# 加载 LayoutLMv3ForTokenClassification 模型
def load_layout_model():
    # 从预训练模型加载 LayoutLMv3ForTokenClassification 模型
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        settings.LAYOUT_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)
    # 设置模型的标签映射
    model.config.id2label = {
        0: "Caption",
        1: "Footnote",
        2: "Formula",
        3: "List-item",
        4: "Page-footer",
        5: "Page-header",
        6: "Picture",
        7: "Section-header",
        8: "Table",
        9: "Text",
        10: "Title"
    }
    model.config.label2id = {v: k for k, v in model.config.id2label.items()}
    return model
# 检测文档块类型
def detect_document_block_types(doc, blocks: List[Page], layoutlm_model, batch_size=settings.LAYOUT_BATCH_SIZE):
    # 获取特征编码、元数据和样本长度
    encodings, metadata, sample_lengths = get_features(doc, blocks)
    # 预测块类型
    predictions = predict_block_types(encodings, layoutlm_model, batch_size)
    # 将预测结果与框匹配
    block_types = match_predictions_to_boxes(encodings, predictions, metadata, sample_lengths, layoutlm_model)
    # 断言块类型数量与块数量相等
    assert len(block_types) == len(blocks)
    return block_types
# 获取临时框
def get_provisional_boxes(pred, box, is_subword, start_idx=0):
    # 从预测结果中获取临时框
    prov_predictions = [pred_ for idx, pred_ in enumerate(pred) if not is_subword[idx]][start_idx:]
    # 从列表中筛选出不是子词的元素,并从指定索引开始切片,得到新的列表
    prov_boxes = [box_ for idx, box_ in enumerate(box) if not is_subword[idx]][start_idx:]
    # 返回处理后的预测结果和框
    return prov_predictions, prov_boxes
# 获取页面编码信息,输入参数为页面和页面块对象
def get_page_encoding(page, page_blocks: Page):
    # 如果页面块中的所有行数为0,则返回空列表
    if len(page_blocks.get_all_lines()) == 0:
        return [], []
    # 获取页面块的边界框、宽度和高度
    page_box = page_blocks.bbox
    pwidth = page_blocks.width
    pheight = page_blocks.height
    # 获取页面块的像素图,并转换为 PNG 格式
    pix = page.get_pixmap(dpi=settings.LAYOUT_DPI, annots=False, clip=page_blocks.bbox)
    png = pix.pil_tobytes(format="PNG")
    png_image = Image.open(io.BytesIO(png))
    # 如果图像太大,则缩小以适应模型
    rgb_image = png_image.convert('RGB')
    rgb_width, rgb_height = rgb_image.size
    # 确保图像大小与 PDF 页面的比例正确
    assert isclose(rgb_width / pwidth, rgb_height / pheight, abs_tol=2e-2)
    # 获取页面块中的所有行
    lines = page_blocks.get_all_lines()
    boxes = []
    text = []
    for line in lines:
        box = line.bbox
        # 处理边界框溢出的情况
        if box[0] < page_box[0]:
            box[0] = page_box[0]
        if box[1] < page_box[1]:
            box[1] = page_box[1]
        if box[2] > page_box[2]:
            box[2] = page_box[2]
        if box[3] > page_box[3]:
            box[3] = page_box[3]
        # 处理边界框宽度或高度为0或负值的情况
        if box[2] <= box[0]:
            print("Zero width box found, cannot convert properly")
            raise ValueError
        if box[3] <= box[1]:
            print("Zero height box found, cannot convert properly")
            raise ValueError
        boxes.append(box)
        text.append(line.prelim_text)
    # 将边界框归一化为模型(缩放为1000x1000)
    boxes = [normalize_box(box, pwidth, pheight) for box in boxes]
    for box in boxes:
        # 验证所有边界框都是有效的
        assert(len(box) == 4)
        assert(max(box)) <= 1000
        assert(min(box)) >= 0
    # 使用 processor 处理 RGB 图像,传入文本、框、返回偏移映射等参数
    encoding = processor(
        rgb_image,
        text=text,
        boxes=boxes,
        return_offsets_mapping=True,
        truncation=True,
        return_tensors="pt",
        stride=settings.LAYOUT_CHUNK_OVERLAP,
        padding="max_length",
        max_length=settings.LAYOUT_MODEL_MAX,
        return_overflowing_tokens=True
    )
    # 从 encoding 中弹出 offset_mapping 和 overflow_to_sample_mapping
    offset_mapping = encoding.pop('offset_mapping')
    overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
    # 将 encoding 中的 bbox、input_ids、attention_mask、pixel_values 转换为列表
    bbox = list(encoding["bbox"])
    input_ids = list(encoding["input_ids"])
    attention_mask = list(encoding["attention_mask"])
    pixel_values = list(encoding["pixel_values"])
    # 断言各列表长度相等
    assert len(bbox) == len(input_ids) == len(attention_mask) == len(pixel_values) == len(offset_mapping)
    # 将各列表中的元素组成字典,放入 list_encoding 列表中
    list_encoding = []
    for i in range(len(bbox)):
        list_encoding.append({
            "bbox": bbox[i],
            "input_ids": input_ids[i],
            "attention_mask": attention_mask[i],
            "pixel_values": pixel_values[i],
            "offset_mapping": offset_mapping[i]
        })
    # 其他数据包括原始框、pwidth 和 pheight
    other_data = {
        "original_bbox": boxes,
        "pwidth": pwidth,
        "pheight": pheight,
    }
    # 返回 list_encoding 和 other_data
    return list_encoding, other_data
# 获取文档的特征信息
def get_features(doc, blocks):
    # 初始化编码、元数据和样本长度列表
    encodings = []
    metadata = []
    sample_lengths = []
    # 遍历每个块
    for i in range(len(blocks)):
        # 调用函数获取页面编码和其他数据
        encoding, other_data = get_page_encoding(doc[i], blocks[i])
        # 将页面编码添加到编码列表中
        encodings.extend(encoding)
        # 将其他数据添加到元数据列表中
        metadata.append(other_data)
        # 记录当前页面编码的长度
        sample_lengths.append(len(encoding))
    # 返回编码、元数据和样本长度
    return encodings, metadata, sample_lengths
# 预测块类型
def predict_block_types(encodings, layoutlm_model, batch_size):
    # 初始化所有预测结果列表
    all_predictions = []
    # 按批次处理编码
    for i in range(0, len(encodings), batch_size):
        # 计算当前批次的起始和结束索引
        batch_start = i
        batch_end = min(i + batch_size, len(encodings))
        # 获取当前批次的编码
        batch = encodings[batch_start:batch_end]
        # 构建模型输入
        model_in = {}
        for k in ["bbox", "input_ids", "attention_mask", "pixel_values"]:
            model_in[k] = torch.stack([b[k] for b in batch]).to(layoutlm_model.device)
        model_in["pixel_values"] = model_in["pixel_values"].to(layoutlm_model.dtype)
        # 进入推理模式
        with torch.inference_mode():
            # 使用模型进行推理
            outputs = layoutlm_model(**model_in)
            logits = outputs.logits
        # 获取预测结果
        predictions = logits.argmax(-1).squeeze().tolist()
        if len(predictions) == settings.LAYOUT_MODEL_MAX:
            predictions = [predictions]
        # 将预测结果添加到所有预测结果列表中
        all_predictions.extend(predictions)
    # 返回所有预测结果
    return all_predictions
# 将预测结果与框匹配
def match_predictions_to_boxes(encodings, predictions, metadata, sample_lengths, layoutlm_model) -> List[List[BlockType]]:
    # 断言编码、预测结果和样本长度的长度相等
    assert len(encodings) == len(predictions) == sum(sample_lengths)
    assert len(metadata) == len(sample_lengths)
    # 初始化页面起始索引和页面块类型列表
    page_start = 0
    page_block_types = []
    # 返回页面块类型列表
    return page_block_types

.\marker\marker\settings.py

import os
from typing import Optional, List, Dict
from dotenv import find_dotenv
from pydantic import computed_field
from pydantic_settings import BaseSettings
import fitz as pymupdf
import torch
# 定义一个设置类,继承自BaseSettings
class Settings(BaseSettings):
    # General
    TORCH_DEVICE: Optional[str] = None
    # 计算属性,返回TORCH_DEVICE_MODEL
    @computed_field
    @property
    def TORCH_DEVICE_MODEL(self) -> str:
        # 如果TORCH_DEVICE不为None,则返回TORCH_DEVICE
        if self.TORCH_DEVICE is not None:
            return self.TORCH_DEVICE
        # 如果CUDA可用,则返回"cuda"
        if torch.cuda.is_available():
            return "cuda"
        # 如果MPS可用,则返回"mps"
        if torch.backends.mps.is_available():
            return "mps"
        # 否则返回"cpu"
        return "cpu"
    INFERENCE_RAM: int = 40 # 每个GPU的VRAM量(以GB为单位)。
    VRAM_PER_TASK: float = 2.5 # 每个任务分配的VRAM量(以GB为单位)。 峰值标记VRAM使用量约为3GB,但工作程序的平均值较低。
    DEFAULT_LANG: str = "English" # 我们假设文件所在的默认语言,应该是TESSERACT_LANGUAGES中的一个键
    SUPPORTED_FILETYPES: Dict = {
        "application/pdf": "pdf",
        "application/epub+zip": "epub",
        "application/x-mobipocket-ebook": "mobi",
        "application/vnd.ms-xpsdocument": "xps",
        "application/x-fictionbook+xml": "fb2"
    }
    # PyMuPDF
    TEXT_FLAGS: int = pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES
    # OCR
    INVALID_CHARS: List[str] = [chr(0xfffd), "�"]
    OCR_DPI: int = 400
    TESSDATA_PREFIX: str = ""
    TESSERACT_LANGUAGES: Dict = {
        "English": "eng",
        "Spanish": "spa",
        "Portuguese": "por",
        "French": "fra",
        "German": "deu",
        "Russian": "rus",
        "Chinese": "chi_sim",
        "Japanese": "jpn",
        "Korean": "kor",
        "Hindi": "hin",
    }
    TESSERACT_TIMEOUT: int = 20 # 何时放弃OCR
    # 定义拼写检查语言对应的字典
    SPELLCHECK_LANGUAGES: Dict = {
        "English": "en",
        "Spanish": "es",
        "Portuguese": "pt",
        "French": "fr",
        "German": "de",
        "Russian": "ru",
        "Chinese": None,
        "Japanese": None,
        "Korean": None,
        "Hindi": None,
    }
    # 是否在每一页运行 OCR,即使可以提取文本
    OCR_ALL_PAGES: bool = False
    # 用于 OCR 的并行 CPU 工作线程数
    OCR_PARALLEL_WORKERS: int = 2
    # 使用的 OCR 引擎,可以是 "tesseract" 或 "ocrmypdf",ocrmypdf 质量更高但速度较慢
    OCR_ENGINE: str = "ocrmypdf"
    # Texify 模型相关参数
    TEXIFY_MODEL_MAX: int = 384 # Texify 的最大推理长度
    TEXIFY_TOKEN_BUFFER: int = 256 # Texify 的 token 缓冲区大小
    TEXIFY_DPI: int = 96 # 渲染图像的 DPI
    TEXIFY_BATCH_SIZE: int = 2 if TORCH_DEVICE_MODEL == "cpu" else 6 # Texify 的批处理大小,CPU 上较低因为使用 float32
    TEXIFY_MODEL_NAME: str = "vikp/texify"
    # Layout 模型相关参数
    BAD_SPAN_TYPES: List[str] = ["Caption", "Footnote", "Page-footer", "Page-header", "Picture"]
    LAYOUT_MODEL_MAX: int = 512
    LAYOUT_CHUNK_OVERLAP: int = 64
    LAYOUT_DPI: int = 96
    LAYOUT_MODEL_NAME: str = "vikp/layout_segmenter"
    LAYOUT_BATCH_SIZE: int = 8 # 最大 512 个 token 意味着较高的批处理大小
    # Ordering 模型相关参数
    ORDERER_BATCH_SIZE: int = 32 # 可以较高,因为最大 token 数为 128
    ORDERER_MODEL_NAME: str = "vikp/column_detector"
    # 最终编辑模型相关参数
    EDITOR_BATCH_SIZE: int = 4
    EDITOR_MAX_LENGTH: int = 1024
    EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor_t5"
    ENABLE_EDITOR_MODEL: bool = False # 编辑模型可能会产生误报
    EDITOR_CUTOFF_THRESH: float = 0.9 # 忽略概率低于此阈值的预测
    # Ray 相关参数
    RAY_CACHE_PATH: Optional[str] = None # 保存 Ray 缓存的路径
    RAY_CORES_PER_WORKER: int = 1 # 每个 worker 分配的 CPU 核心数
    # 调试相关参数
    DEBUG: bool = False # 启用调试日志
    # 调试数据文件夹路径,默认为 None
    DEBUG_DATA_FOLDER: Optional[str] = None
    # 调试级别,范围从 0 到 2,2 表示记录所有信息
    DEBUG_LEVEL: int = 0
    
    # 计算属性,返回是否使用 CUDA
    @computed_field
    @property
    def CUDA(self) -> bool:
        return "cuda" in self.TORCH_DEVICE
    
    # 计算属性,返回模型数据类型
    @computed_field
    @property
    def MODEL_DTYPE(self) -> torch.dtype:
        if self.TORCH_DEVICE_MODEL == "cuda":
            return torch.bfloat16
        else:
            return torch.float32
    
    # 计算属性,返回用于转换的数据类型
    @computed_field
    @property
    def TEXIFY_DTYPE(self) -> torch.dtype:
        return torch.float32 if self.TORCH_DEVICE_MODEL == "cpu" else torch.float16
    
    # 类配置
    class Config:
        # 从环境文件中查找 local.env 文件
        env_file = find_dotenv("local.env")
        # 额外配置,忽略错误
        extra = "ignore"
# 创建一个 Settings 对象实例
settings = Settings()

.\marker\scripts\verify_benchmark_scores.py

# 导入 json 模块和 argparse 模块
import json
import argparse
# 验证分数的函数,接收一个文件路径作为参数
def verify_scores(file_path):
    # 打开文件并加载 JSON 数据
    with open(file_path, 'r') as file:
        data = json.load(file)
    # 获取 multicolcnn.pdf 文件的分数
    multicolcnn_score = data["marker"]["files"]["multicolcnn.pdf"]["score"]
    # 获取 switch_trans.pdf 文件的分数
    switch_trans_score = data["marker"]["files"]["switch_trans.pdf"]["score"]
    # 如果其中一个分数小于等于 0.4,则抛出 ValueError 异常
    if multicolcnn_score <= 0.4 or switch_trans_score <= 0.4:
        raise ValueError("One or more scores are below the required threshold of 0.4")
# 如果当前脚本被直接执行
if __name__ == "__main__":
    # 创建 ArgumentParser 对象,设置描述信息
    parser = argparse.ArgumentParser(description="Verify benchmark scores")
    # 添加一个参数,指定文件路径,类型为字符串
    parser.add_argument("file_path", type=str, help="Path to the json file")
    # 解析命令行参数
    args = parser.parse_args()
    # 调用 verify_scores 函数,传入文件路径参数
    verify_scores(args.file_path)


相关文章
|
8月前
|
算法 测试技术 C语言
深入理解HTTP/2:nghttp2库源码解析及客户端实现示例
通过解析nghttp2库的源码和实现一个简单的HTTP/2客户端示例,本文详细介绍了HTTP/2的关键特性和nghttp2的核心实现。了解这些内容可以帮助开发者更好地理解HTTP/2协议,提高Web应用的性能和用户体验。对于实际开发中的应用,可以根据需要进一步优化和扩展代码,以满足具体需求。
807 29
|
8月前
|
前端开发 数据安全/隐私保护 CDN
二次元聚合短视频解析去水印系统源码
二次元聚合短视频解析去水印系统源码
316 4
|
8月前
|
JavaScript 算法 前端开发
JS数组操作方法全景图,全网最全构建完整知识网络!js数组操作方法全集(实现筛选转换、随机排序洗牌算法、复杂数据处理统计等情景详解,附大量源码和易错点解析)
这些方法提供了对数组的全面操作,包括搜索、遍历、转换和聚合等。通过分为原地操作方法、非原地操作方法和其他方法便于您理解和记忆,并熟悉他们各自的使用方法与使用范围。详细的案例与进阶使用,方便您理解数组操作的底层原理。链式调用的几个案例,让您玩转数组操作。 只有锻炼思维才能可持续地解决问题,只有思维才是真正值得学习和分享的核心要素。如果这篇博客能给您带来一点帮助,麻烦您点个赞支持一下,还可以收藏起来以备不时之需,有疑问和错误欢迎在评论区指出~
|
8月前
|
移动开发 前端开发 JavaScript
从入门到精通:H5游戏源码开发技术全解析与未来趋势洞察
H5游戏凭借其跨平台、易传播和开发成本低的优势,近年来发展迅猛。接下来,让我们深入了解 H5 游戏源码开发的技术教程以及未来的发展趋势。
|
8月前
|
存储 前端开发 JavaScript
在线教育网课系统源码开发指南:功能设计与技术实现深度解析
在线教育网课系统是近年来发展迅猛的教育形式的核心载体,具备用户管理、课程管理、教学互动、学习评估等功能。本文从功能和技术两方面解析其源码开发,涵盖前端(HTML5、CSS3、JavaScript等)、后端(Java、Python等)、流媒体及云计算技术,并强调安全性、稳定性和用户体验的重要性。
|
9月前
|
机器学习/深度学习 自然语言处理 算法
生成式 AI 大语言模型(LLMs)核心算法及源码解析:预训练篇
生成式 AI 大语言模型(LLMs)核心算法及源码解析:预训练篇
2216 1
|
8月前
|
负载均衡 JavaScript 前端开发
分片上传技术全解析:原理、优势与应用(含简单实现源码)
分片上传通过将大文件分割成多个小的片段或块,然后并行或顺序地上传这些片段,从而提高上传效率和可靠性,特别适用于大文件的上传场景,尤其是在网络环境不佳时,分片上传能有效提高上传体验。 博客不应该只有代码和解决方案,重点应该在于给出解决方案的同时分享思维模式,只有思维才能可持续地解决问题,只有思维才是真正值得学习和分享的核心要素。如果这篇博客能给您带来一点帮助,麻烦您点个赞支持一下,还可以收藏起来以备不时之需,有疑问和错误欢迎在评论区指出~
|
11月前
|
存储 设计模式 算法
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
行为型模式用于描述程序在运行时复杂的流程控制,即描述多个类或对象之间怎样相互协作共同完成单个对象都无法单独完成的任务,它涉及算法与对象间职责的分配。行为型模式分为类行为模式和对象行为模式,前者采用继承机制来在类间分派行为,后者采用组合或聚合在对象间分配行为。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象行为模式比类行为模式具有更大的灵活性。 行为型模式分为: • 模板方法模式 • 策略模式 • 命令模式 • 职责链模式 • 状态模式 • 观察者模式 • 中介者模式 • 迭代器模式 • 访问者模式 • 备忘录模式 • 解释器模式
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
|
11月前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
结构型模式描述如何将类或对象按某种布局组成更大的结构。它分为类结构型模式和对象结构型模式,前者采用继承机制来组织接口和类,后者釆用组合或聚合来组合对象。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象结构型模式比类结构型模式具有更大的灵活性。 结构型模式分为以下 7 种: • 代理模式 • 适配器模式 • 装饰者模式 • 桥接模式 • 外观模式 • 组合模式 • 享元模式
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
10月前
|
自然语言处理 数据处理 索引
mindspeed-llm源码解析(一)preprocess_data
mindspeed-llm是昇腾模型套件代码仓,原来叫"modelLink"。这篇文章带大家阅读一下数据处理脚本preprocess_data.py(基于1.0.0分支),数据处理是模型训练的第一步,经常会用到。
307 0

推荐镜像

更多
  • DNS