Marker 源码解析(二)(2)

简介: Marker 源码解析(二)

Marker 源码解析(二)(1)https://developer.aliyun.com/article/1483801

.\marker\marker\ordering.py

# 导入必要的模块
from copy import deepcopy
from typing import List
import torch
import sys, os
from marker.extract_text import convert_single_page
from transformers import LayoutLMv3ForSequenceClassification, LayoutLMv3Processor
from PIL import Image
import io
from marker.schema import Page
from marker.settings import settings
# 从设置中加载 LayoutLMv3Processor 模型
processor = LayoutLMv3Processor.from_pretrained(settings.ORDERER_MODEL_NAME)
# 加载 LayoutLMv3ForSequenceClassification 模型
def load_ordering_model():
    model = LayoutLMv3ForSequenceClassification.from_pretrained(
        settings.ORDERER_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)
    model.eval()
    return model
# 获取推理数据
def get_inference_data(page, page_blocks: Page):
    # 深拷贝页面块的边界框
    bboxes = deepcopy([block.bbox for block in page_blocks.blocks])
    # 初始化单词列表
    words = ["."] * len(bboxes)
    # 获取页面的像素图像
    pix = page.get_pixmap(dpi=settings.LAYOUT_DPI, annots=False, clip=page_blocks.bbox)
    # 将像素图像转换为 PNG 格式
    png = pix.pil_tobytes(format="PNG")
    # 将 PNG 数据转换为 RGB 图像
    rgb_image = Image.open(io.BytesIO(png)).convert("RGB")
    # 获取页面块的边界框和宽高
    page_box = page_blocks.bbox
    pwidth = page_blocks.width
    pheight = page_blocks.height
    # 调整边界框的值
    for box in bboxes:
        if box[0] < page_box[0]:
            box[0] = page_box[0]
        if box[1] < page_box[1]:
            box[1] = page_box[1]
        if box[2] > page_box[2]:
            box[2] = page_box[2]
        if box[3] > page_box[3]:
            box[3] = page_box[3]
        # 将边界框的值转换为相对于页面宽高的比例
        box[0] = int(box[0] / pwidth * 1000)
        box[1] = int(box[1] / pheight * 1000)
        box[2] = int(box[2] / pwidth * 1000)
        box[3] = int(box[3] / pheight * 1000)
    return rgb_image, bboxes, words
# 批量推理
def batch_inference(rgb_images, bboxes, words, model):
    # 对 RGB 图像、单词和边界框进行编码
    encoding = processor(
        rgb_images,
        text=words,
        boxes=bboxes,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128
    )
    # 将像素值转换为模型的数据类型
    encoding["pixel_values"] = encoding["pixel_values"].to(model.dtype)
    # 进入推断模式,不进行梯度计算
    with torch.inference_mode():
        # 将指定的键对应的值移动到模型所在设备上
        for k in ["bbox", "input_ids", "pixel_values", "attention_mask"]:
            encoding[k] = encoding[k].to(model.device)
        # 使用模型进行推理,获取输出
        outputs = model(**encoding)
        # 获取模型输出的预测结果
        logits = outputs.logits
    # 获取预测结果中概率最大的类别索引,并转换为列表
    predictions = logits.argmax(-1).squeeze().tolist()
    # 如果预测结果是整数,则转换为列表
    if isinstance(predictions, int):
        predictions = [predictions]
    # 将预测结果转换为类别标签
    predictions = [model.config.id2label[p] for p in predictions]
    # 返回预测结果
    return predictions
# 为文档中的每个块添加列数计数
def add_column_counts(doc, doc_blocks, model, batch_size):
    # 按照批量大小遍历文档块
    for i in range(0, len(doc_blocks), batch_size):
        # 创建当前批量的索引范围
        batch = range(i, min(i + batch_size, len(doc_blocks)))
        # 初始化空列表用于存储 RGB 图像、边界框和单词
        rgb_images = []
        bboxes = []
        words = []
        # 遍历当前批量的页码
        for pnum in batch:
            # 获取推理数据:RGB 图像、页边界框和页单词
            page = doc[pnum]
            rgb_image, page_bboxes, page_words = get_inference_data(page, doc_blocks[pnum])
            rgb_images.append(rgb_image)
            bboxes.append(page_bboxes)
            words.append(page_words)
        # 进行批量推理,获取预测结果
        predictions = batch_inference(rgb_images, bboxes, words, model)
        # 将预测结果与页码对应,更新文档块的列数计数
        for pnum, prediction in zip(batch, predictions):
            doc_blocks[pnum].column_count = prediction
# 对文档块进行排序
def order_blocks(doc, doc_blocks: List[Page], model, batch_size=settings.ORDERER_BATCH_SIZE):
    # 添加列数计数
    add_column_counts(doc, doc_blocks, model, batch_size)
    # 遍历文档块中的每一页
    for page_blocks in doc_blocks:
        # 如果该页的列数大于1
        if page_blocks.column_count > 1:
            # 根据位置重新排序块
            split_pos = page_blocks.x_start + page_blocks.width / 2
            left_blocks = []
            right_blocks = []
            # 遍历该页的每个块
            for block in page_blocks.blocks:
                # 根据位置将块分为左右两部分
                if block.x_start <= split_pos:
                    left_blocks.append(block)
                else:
                    right_blocks.append(block)
            # 更新该页的块顺序
            page_blocks.blocks = left_blocks + right_blocks
    # 返回排序后的文档块
    return doc_blocks

.\marker\marker\postprocessors\editor.py

# 导入必要的库
from collections import defaultdict, Counter
from itertools import chain
from typing import Optional
# 导入 transformers 库中的 AutoTokenizer 类
from transformers import AutoTokenizer
# 导入 settings 模块中的 settings 变量
from marker.settings import settings
# 导入 torch 库
import torch
import torch.nn.functional as F
# 导入 marker.postprocessors.t5 模块中的 T5ForTokenClassification 类和 byt5_tokenize 函数
from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize
# 定义加载编辑模型的函数
def load_editing_model():
    # 如果未启用编辑模型,则返回 None
    if not settings.ENABLE_EDITOR_MODEL:
        return None
    # 从预训练模型中加载 T5ForTokenClassification 模型
    model = T5ForTokenClassification.from_pretrained(
            settings.EDITOR_MODEL_NAME,
            torch_dtype=settings.MODEL_DTYPE,
        ).to(settings.TORCH_DEVICE_MODEL)
    model.eval()
    # 配置模型的标签映射
    model.config.label2id = {
        "equal": 0,
        "delete": 1,
        "newline-1": 2,
        "space-1": 3,
    }
    model.config.id2label = {v: k for k, v in model.config.label2id.items()}
    return model
# 定义编辑全文的函数
def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_size: int = settings.EDITOR_BATCH_SIZE):
    # 如果模型为空,则直接返回原始文本和空字典
    if not model:
        return text, {}
    # 对文本进行 tokenization
    tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH)
    input_ids = tokenized["input_ids"]
    char_token_lengths = tokenized["char_token_lengths"]
    # 准备 token_masks 列表
    token_masks = []
    # 遍历输入的 input_ids,按照 batch_size 进行分批处理
    for i in range(0, len(input_ids), batch_size):
        # 从 tokenized 中获取当前 batch 的 input_ids
        batch_input_ids = tokenized["input_ids"][i: i + batch_size]
        # 将 batch_input_ids 转换为 torch 张量,并指定设备为 model 的设备
        batch_input_ids = torch.tensor(batch_input_ids, device=model.device)
        # 从 tokenized 中获取当前 batch 的 attention_mask
        batch_attention_mask = tokenized["attention_mask"][i: i + batch_size]
        # 将 batch_attention_mask 转换为 torch 张量,并指定设备为 model 的设备
        batch_attention_mask = torch.tensor(batch_attention_mask, device=model.device)
        
        # 进入推理模式
        with torch.inference_mode():
            # 使用模型进行预测
            predictions = model(batch_input_ids, attention_mask=batch_attention_mask)
        # 将预测结果 logits 移动到 CPU 上
        logits = predictions.logits.cpu()
        # 如果最大概率小于阈值,则假设为不良预测
        # 我们希望保守一点,不要对文本进行过多编辑
        probs = F.softmax(logits, dim=-1)
        max_prob = torch.max(probs, dim=-1)
        cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
        labels = logits.argmax(-1)
        labels[cutoff_prob] = model.config.label2id["equal"]
        labels = labels.squeeze().tolist()
        if len(labels) == settings.EDITOR_MAX_LENGTH:
            labels = [labels]
        labels = list(chain.from_iterable(labels))
        token_masks.extend(labels)
    # 文本中的字符列表
    flat_input_ids = list(chain.from_iterable(input_ids)
    # 去除特殊标记 0,1。保留未知标记,尽管它不应该被使用
    assert len(token_masks) == len(flat_input_ids)
    token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]
    # 确保 token_masks 的长度与文本编码后的长度相等
    assert len(token_masks) == len(list(text.encode("utf-8")))
    # 统计编辑次数的字典
    edit_stats = defaultdict(int)
    # 输出文本列表
    out_text = []
    # 起始位置
    start = 0
    # 遍历文本中的每个字符及其索引
    for i, char in enumerate(text):
        # 获取当前字符对应的 token 长度
        char_token_length = char_token_lengths[i]
        # 获取当前字符对应的 token 的 mask
        masks = token_masks[start: start + char_token_length]
        # 将 mask 转换为标签
        labels = [model.config.id2label[mask] for mask in masks]
        # 如果所有标签都是 "delete",则执行删除操作
        if all(l == "delete" for l in labels):
            # 如果删除的是空格,则保留,否则忽略
            if char.strip():
                out_text.append(char)
            else:
                edit_stats["delete"] += 1
        # 如果标签为 "newline-1",则添加换行符
        elif labels[0] == "newline-1":
            out_text.append("\n")
            out_text.append(char)
            edit_stats["newline-1"] += 1
        # 如果标签为 "space-1",则添加空格
        elif labels[0] == "space-1":
            out_text.append(" ")
            out_text.append(char)
            edit_stats["space-1"] += 1
        # 如果标签为其他情况,则保留字符
        else:
            out_text.append(char)
            edit_stats["equal"] += 1
        # 更新下一个字符的起始位置
        start += char_token_length
    # 将处理后的文本列表转换为字符串
    out_text = "".join(out_text)
    # 返回处理后的文本及编辑统计信息
    return out_text, edit_stats

Marker 源码解析(二)(3)https://developer.aliyun.com/article/1483806

相关文章
|
4天前
|
Linux 网络安全 Windows
网络安全笔记-day8,DHCP部署_dhcp搭建部署,源码解析
网络安全笔记-day8,DHCP部署_dhcp搭建部署,源码解析
|
5天前
HuggingFace Tranformers 源码解析(4)
HuggingFace Tranformers 源码解析
6 0
|
5天前
HuggingFace Tranformers 源码解析(3)
HuggingFace Tranformers 源码解析
7 0
|
5天前
|
开发工具 git
HuggingFace Tranformers 源码解析(2)
HuggingFace Tranformers 源码解析
8 0
|
5天前
|
并行计算
HuggingFace Tranformers 源码解析(1)
HuggingFace Tranformers 源码解析
11 0
|
6天前
PandasTA 源码解析(二十三)
PandasTA 源码解析(二十三)
43 0
|
6天前
PandasTA 源码解析(二十二)(3)
PandasTA 源码解析(二十二)
35 0
|
6天前
PandasTA 源码解析(二十二)(2)
PandasTA 源码解析(二十二)
42 2
|
6天前
PandasTA 源码解析(二十二)(1)
PandasTA 源码解析(二十二)
33 0
|
6天前
PandasTA 源码解析(二十一)(4)
PandasTA 源码解析(二十一)
24 1

推荐镜像

更多