Marker Source Code Analysis (1) (2)

Summary: Marker Source Code Analysis (1)

Marker Source Code Analysis (1) (1): https://developer.aliyun.com/article/1483776

.\marker\marker\cleaners\code.py

# Import the required modules and classes
from marker.schema import Span, Line, Page
import re
from typing import List
import fitz as pymupdf

# Check whether the average code line length is below the threshold
def is_code_linelen(lines, thresh=60):
    # Count the total number of word characters across all lines
    total_alnum_chars = sum(len(re.findall(r'\w', line.prelim_text)) for line in lines)
    # Number of newlines (line count minus one, floored at 1)
    total_newlines = max(len(lines) - 1, 1)
    # No word characters at all means this is not code
    if total_alnum_chars == 0:
        return False
    # Average number of word characters per newline
    ratio = total_alnum_chars / total_newlines
    return ratio < thresh
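
# Illustrative example: five lines averaging 40 word characters each give
# total_alnum_chars = 200 over 4 newlines, a ratio of 50; that is below the
# default threshold of 60, and such short lines are typical of code rather
# than prose.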

# Count how many lines start with a comment marker
def comment_count(lines):
    # Regex matching common comment prefixes across languages
    pattern = re.compile(r"^(//|#|'|--|/\*|'''|\"\"\"|--\[\[|<!--|%|%{|\(\*)")
    # Number of lines matching the pattern
    return sum([1 for line in lines if pattern.match(line)])

# Identify code blocks across all pages
def identify_code_blocks(blocks: List[Page]):
    # Initialize the code block count and the aggregated font statistics
    code_block_count = 0
    font_info = None
    # Accumulate font statistics over every page
    for p in blocks:
        # Font statistics for this page
        stats = p.get_font_stats()
        # Start from the first page's stats, then add the rest
        if font_info is None:
            font_info = stats
        else:
            font_info += stats
    try:
        # The most common font across the document
        most_common_font = font_info.most_common(1)[0][0]
    except IndexError:
        # Fall back to None when no font statistics are available
        print("Could not find most common font")
        most_common_font = None

    # Track the previous block so indented continuations can be detected
    last_block = None
    # Walk the text blocks on every page
    for page in blocks:
        try:
            # Leftmost line start position on the page
            min_start = page.get_min_line_start()
        except IndexError:
            # Skip pages with no lines
            continue
        # Examine each text block on the current page
        for block in page.blocks:
            # Only reclassify blocks whose dominant type is "Text"
            if block.most_common_block_type() != "Text":
                last_block = block
                continue
            # Signals used to decide whether this block is code
            is_indent = []
            line_fonts = []
            # Inspect each line of the block
            for line in block.lines:
                # Fonts used by the spans on this line
                fonts = [span.font for span in line.spans]
                line_fonts += fonts
                # x-coordinate where the line starts
                line_start = line.bbox[0]
                # A line starting to the right of the page minimum counts as indented
                if line_start > min_start:
                    is_indent.append(True)
                else:
                    is_indent.append(False)
            # Number of lines in the block that look like comments
            comment_lines = comment_count([line.prelim_text for line in block.lines])
            # All of these must hold for the block to be treated as code
            is_code = [
                len(block.lines) > 3,  # more than 3 lines
                sum([f != most_common_font for f in line_fonts]) > len(line_fonts) * .8,  # more than 80% of spans use a non-dominant font, since code is usually set in a different font from the body text
                is_code_linelen(block.lines),  # the average line is short enough
                (
                    sum(is_indent) > len(block.lines) * .2  # more than 20% of lines are indented
                    or
                    comment_lines > len(block.lines) * .2  # more than 20% of lines are comments
                ),
            ]
            # Alternatively, an almost fully indented block right after a code block is code
            is_code_prev = [
                last_block and last_block.most_common_block_type() == "Code",  # the previous block was code
                sum(is_indent) >= len(block.lines) * .8  # at least 80% of lines are indented
            ]
            # Tag the block as "Code" if either set of conditions holds
            if all(is_code) or all(is_code_prev):
                code_block_count += 1
                block.set_block_type("Code")
            last_block = block
    # Return the number of code blocks found
    return code_block_count
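
# Illustrative example: a 6-line block set in a monospace font (so nearly all
# spans differ from the document's dominant font), with short lines, two of
# them indented, satisfies every is_code condition above and is tagged "Code".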

# Reindent code blocks and collapse each one into a single Span
def indent_blocks(blocks: List[Page]):
    # Counter used to generate ids for the new Span objects
    span_counter = 0
    # Walk the blocks on every page
    for page in blocks:
        for block in page.blocks:
            # Block types of every span in this block
            block_types = [span.block_type for line in block.lines for span in line.spans]
            # Skip blocks that contain no code spans
            if "Code" not in block_types:
                continue
            # Collected (bbox, text) pairs for the block's lines
            lines = []
            # Leftmost x-coordinate and estimated character width
            min_left = 1000  # will contain the x-coordinate of column 0
            col_width = 0  # width of one character
            # Gather the text and geometry of each line
            for line in block.lines:
                text = ""
                # Track the leftmost line start
                min_left = min(line.bbox[0], min_left)
                # Concatenate the text of the line's spans
                for span in line.spans:
                    if col_width == 0 and len(span.text) > 0:
                        col_width = (span.bbox[2] - span.bbox[0]) / len(span.text)
                    text += span.text
                lines.append((pymupdf.Rect(line.bbox), text))

            # Assemble the reindented text of the whole block
            block_text = ""
            blank_line = False
            for line in lines:
                text = line[1]
                # Indent by one space per character-width of horizontal offset
                prefix = " " * int((line[0].x0 - min_left) / col_width)
                current_line_blank = len(text.strip()) == 0
                # Collapse consecutive blank lines into one
                if blank_line and current_line_blank:
                    continue
                block_text += prefix + text + "\n"
                blank_line = current_line_blank

            # Create a new Span holding the whole reindented block
            new_span = Span(
                text=block_text,
                bbox=block.bbox,
                color=block.lines[0].spans[0].color,
                span_id=f"{span_counter}_fix_code",
                font=block.lines[0].spans[0].font,
                block_type="Code"
            )
            span_counter += 1
            # Replace the block's lines with a single line holding the new Span
            block.lines = [Line(spans=[new_span], bbox=block.bbox)]
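
A minimal driver sketch for the two cleaners above (assuming pages is the List[Page] produced by Marker's extraction stage; the variable name is illustrative):

from marker.cleaners.code import identify_code_blocks, indent_blocks

code_count = identify_code_blocks(pages)  # tag likely code blocks as "Code"
indent_blocks(pages)                      # rebuild their indentation in place
print(f"tagged {code_count} code blocks")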

.\marker\marker\cleaners\equations.py

# Import the required libraries
import io
from copy import deepcopy
from functools import partial
from typing import List
import torch
from texify.inference import batch_inference
from texify.model.model import load_model
from texify.model.processor import load_processor
import re
from PIL import Image, ImageDraw

# Import Marker's own modules
from marker.bbox import should_merge_blocks, merge_boxes
from marker.debug.data import dump_equation_debug_data
from marker.settings import settings
from marker.schema import Page, Span, Line, Block, BlockType
import os

# Disable tokenizer parallelism to avoid fork warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Load the Texify processor
processor = load_processor()

# Load the Texify model
def load_texify_model():
    texify_model = load_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
    return texify_model

# Mask out everything in the image except the selected regions
def mask_bbox(png_image, bbox, selected_bboxes):
    # Create a grayscale mask the same size as the image
    mask = Image.new('L', png_image.size, 0)  # 'L' mode for grayscale
    draw = ImageDraw.Draw(mask)
    first_x = bbox[0]
    first_y = bbox[1]
    bbox_height = bbox[3] - bbox[1]
    bbox_width = bbox[2] - bbox[0]
    for box in selected_bboxes:
        # Shift the box into the coordinate frame of the clipped region
        new_box = (box[0] - first_x, box[1] - first_y, box[2] - first_x, box[3] - first_y)
        # Rescale from PDF coordinates to image pixel coordinates
        resized = (
           new_box[0] / bbox_width * png_image.size[0],
           new_box[1] / bbox_height * png_image.size[1],
           new_box[2] / bbox_width * png_image.size[0],
           new_box[3] / bbox_height * png_image.size[1]
        )
        draw.rectangle(resized, fill=255)
    # Composite the image over white using the mask
    result = Image.composite(png_image, Image.new('RGBA', png_image.size, 'white'), mask)
    return result

# Render the page region and mask it down to the selected boxes
def get_masked_image(page, bbox, selected_bboxes):
    # Rasterize the clipped page region
    pix = page.get_pixmap(dpi=settings.TEXIFY_DPI, clip=bbox)
    png = pix.pil_tobytes(format="PNG")
    png_image = Image.open(io.BytesIO(png))
    # Mask out everything outside the selected boxes
    png_image = mask_bbox(png_image, bbox, selected_bboxes)
    png_image = png_image.convert("RGB")
    return png_image

# Run Texify over the images in batches, capping generation length per batch
def get_latex_batched(images, reformat_region_lens, texify_model, batch_size):
    # Nothing to do for an empty image list
    if len(images) == 0:
        return []
    # Preallocate the prediction list
    predictions = [""] * len(images)
    # Process the images batch by batch
    for i in range(0, len(images), batch_size):
        # Set the max length dynamically to save inference time
        min_idx = i
        max_idx = min(min_idx + batch_size, len(images))
        max_length = max(reformat_region_lens[min_idx:max_idx])
        max_length = min(max_length, settings.TEXIFY_MODEL_MAX)
        max_length += settings.TEXIFY_TOKEN_BUFFER
        # Run inference on this batch of images
        model_output = batch_inference(images[min_idx:max_idx], texify_model, processor, max_tokens=max_length)
        # Post-process the model outputs
        for j, output in enumerate(model_output):
            token_count = get_total_texify_tokens(output)
            # If generation hit the token cap, the output is likely truncated, so discard it
            if token_count >= max_length - 1:
                output = ""
            # Index of this image in the flat list
            image_idx = i + j
            predictions[image_idx] = output
    return predictions

# Count the Texify tokens in a piece of text
def get_total_texify_tokens(text):
    tokenizer = processor.tokenizer
    tokens = tokenizer(text)
    return len(tokens["input_ids"])
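
# Illustrative example: get_total_texify_tokens("x^2 + y^2 = z^2") returns the
# number of tokenizer input ids for that string; the count is used both to set
# max_tokens for a batch and to detect generations that hit the cap.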

# Find equation regions on a page (the region-building loop is not shown in this excerpt)
def find_page_equation_regions(pnum, page, block_types):
    i = 0
    # Bounding boxes of the blocks the layout model labeled "Formula"
    equation_boxes = [b.bbox for b in block_types[pnum] if b.block_type == "Formula"]
    reformatted_blocks = set()
    reformat_regions = []
    block_lens = []
    return reformat_regions, block_lens

# Collect the bounding boxes of the blocks inside a region, plus their merged box
def get_bboxes_for_region(page, region):
    bboxes = []
    merged_box = None
    for idx in region:
        block = page.blocks[idx]
        bbox = block.bbox
        if merged_box is None:
            merged_box = bbox
        else:
            merged_box = merge_boxes(merged_box, bbox)
        bboxes.append(bbox)
    return bboxes, merged_box

# Replace the text blocks inside equation regions with LaTeX blocks
def replace_blocks_with_latex(page_blocks: Page, merged_boxes, reformat_regions, predictions, pnum):
    new_blocks = []
    converted_spans = []
    current_region = 0
    idx = 0
    success_count = 0
    fail_count = 0
    # Walk the page's blocks by index
    while idx < len(page_blocks.blocks):
        block = page_blocks.blocks[idx]
        # If no region is pending, or this block comes before the next region,
        # keep it unchanged and move on
        if current_region >= len(reformat_regions) or idx < reformat_regions[current_region][0]:
            new_blocks.append(block)
            idx += 1
            continue

        # Original text of the blocks inside the region
        orig_block_text = " ".join([page_blocks.blocks[i].prelim_text for i in reformat_regions[current_region]])
        # Predicted LaTeX for the region
        latex_text = predictions[current_region]
        # All of these must hold for the prediction to be accepted
        conditions = [
            len(latex_text) > 0,
            get_total_texify_tokens(latex_text) < settings.TEXIFY_MODEL_MAX,  # make sure we didn't run up against the overall token max
            len(latex_text) > len(orig_block_text) * .8,  # the LaTeX should be at least 80% as long as the original text
            len(latex_text.strip()) > 0
        ]
        # Jump past the last block of the region
        idx = reformat_regions[current_region][-1] + 1
        if not all(conditions):
            # Keep the original blocks and record the failure
            fail_count += 1
            converted_spans.append(None)
            for i in reformat_regions[current_region]:
                new_blocks.append(page_blocks.blocks[i])
        else:
            # Build a single line holding the LaTeX text
            success_count += 1
            block_line = Line(
                spans=[
                    Span(
                        text=latex_text,
                        bbox=merged_boxes[current_region],
                        span_id=f"{pnum}_{idx}_fixeq",
                        font="Latex",
                        color=0,
                        block_type="Formula"
                    )
                ],
                bbox=merged_boxes[current_region]
            )
            # Keep a copy of the span for debugging output
            converted_spans.append(deepcopy(block_line.spans[0]))
            # Wrap the line in a new block covering the merged region
            new_blocks.append(Block(
                lines=[block_line],
                bbox=merged_boxes[current_region],
                pnum=pnum
            ))
        # Move on to the next region
        current_region += 1
    # Return the new blocks, the success/failure counts, and the converted spans
    return new_blocks, success_count, fail_count, converted_spans

def replace_equations(doc, blocks: List[Page], block_types: List[List[BlockType]], texify_model, batch_size=settings.TEXIFY_BATCH_SIZE):
    # Counters for failed and successful equation OCR
    unsuccessful_ocr = 0
    successful_ocr = 0

    # Find potential equation regions and measure the text length of each
    reformat_regions = []
    reformat_region_lens = []
    for pnum, page in enumerate(blocks):
        regions, region_lens = find_page_equation_regions(pnum, page, block_types)
        reformat_regions.append(regions)
        reformat_region_lens.append(region_lens)

    # Total number of equations found
    eq_count = sum([len(x) for x in reformat_regions])

    # Render a masked image for every region
    flat_reformat_region_lens = [item for sublist in reformat_region_lens for item in sublist]
    images = []
    merged_boxes = []
    for page_idx, reformat_regions_page in enumerate(reformat_regions):
        page_obj = doc[page_idx]
        for reformat_region in reformat_regions_page:
            bboxes, merged_box = get_bboxes_for_region(blocks[page_idx], reformat_region)
            png_image = get_masked_image(page_obj, merged_box, bboxes)
            images.append(png_image)
            merged_boxes.append(merged_box)

    # Run batched LaTeX prediction over all region images
    predictions = get_latex_batched(images, flat_reformat_region_lens, texify_model, batch_size)

    # Replace the blocks inside each region with the predictions
    page_start = 0
    converted_spans = []
    for page_idx, reformat_regions_page in enumerate(reformat_regions):
        # Slice out this page's predictions and merged boxes
        page_predictions = predictions[page_start:page_start + len(reformat_regions_page)]
        page_boxes = merged_boxes[page_start:page_start + len(reformat_regions_page)]
        # Swap the region blocks for LaTeX blocks
        new_page_blocks, success_count, fail_count, converted_span = replace_blocks_with_latex(
            blocks[page_idx],
            page_boxes,
            reformat_regions_page,
            page_predictions,
            page_idx
        )
        converted_spans.extend(converted_span)
        # Store the updated block list and advance the page offset
        blocks[page_idx].blocks = new_page_blocks
        page_start += len(reformat_regions_page)
        successful_ocr += success_count
        unsuccessful_ocr += fail_count

    # In debug mode, dump the conversion results for comparison
    dump_equation_debug_data(doc, images, converted_spans)

    # Return the updated blocks and the OCR statistics
    return blocks, {"successful_ocr": successful_ocr, "unsuccessful_ocr": unsuccessful_ocr, "equations": eq_count}
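
A hedged end-to-end sketch of this cleaner (assuming doc is the pymupdf document and pages / block_types come from Marker's earlier layout stage; the variable names are illustrative):

from marker.cleaners.equations import load_texify_model, replace_equations

texify_model = load_texify_model()
pages, eq_stats = replace_equations(doc, pages, block_types, texify_model)
print(eq_stats)  # {"successful_ocr": ..., "unsuccessful_ocr": ..., "equations": ...}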

.\marker\marker\cleaners\headers.py

# Import the required modules
import re
from collections import Counter, defaultdict
from itertools import chain
from thefuzz import fuzz
from sklearn.cluster import DBSCAN
import numpy as np
from marker.schema import Page, FullyMergedBlock
from typing import List, Tuple

# Find spans whose text repeats on more than 60% of the pages
def filter_common_elements(lines, page_count):
    # Collect the text of every span longer than 4 characters
    text = [s.text for line in lines for s in line.spans if len(s.text) > 4]
    # Count how often each text occurs
    counter = Counter(text)
    # Keep texts occurring on more than 60% of the pages
    common = [k for k, v in counter.items() if v > page_count * .6]
    # span_ids of the spans containing those common texts
    bad_span_ids = [s.span_id for line in lines for s in line.spans if s.text in common]
    return bad_span_ids

# Find header and footer spans
def filter_header_footer(all_page_blocks, max_selected_lines=2):
    first_lines = []
    last_lines = []
    for page in all_page_blocks:
        nonblank_lines = page.get_nonblank_lines()
        first_lines.extend(nonblank_lines[:max_selected_lines])
        last_lines.extend(nonblank_lines[-max_selected_lines:])

    # Spans that repeat in the first or last lines of pages are headers/footers
    bad_span_ids = filter_common_elements(first_lines, len(all_page_blocks))
    bad_span_ids += filter_common_elements(last_lines, len(all_page_blocks))
    return bad_span_ids

# Classify spans into body text versus outliers via clustering
def categorize_blocks(all_page_blocks: List[Page]):
    # Gather every nonblank span from all pages
    spans = list(chain.from_iterable([p.get_nonblank_spans() for p in all_page_blocks]))
    # Feature matrix: each span's bbox plus its text length
    X = np.array(
        [(*s.bbox, len(s.text)) for s in spans]
    )
    # Cluster the spans with DBSCAN
    dbscan = DBSCAN(eps=.1, min_samples=5)
    dbscan.fit(X)
    labels = dbscan.labels_

    # Total characters per cluster label
    label_chars = defaultdict(int)
    for i, label in enumerate(labels):
        label_chars[label] += len(spans[i].text)

    # The cluster with the most characters is treated as the body text
    most_common_label = None
    most_chars = 0
    for i in label_chars.keys():
        if label_chars[i] > most_chars:
            most_common_label = i
            most_chars = label_chars[i]

    # Mark everything outside the main cluster as 1
    labels = [0 if label == most_common_label else 1 for label in labels]
    # span_ids of the spans outside the main cluster
    bad_span_ids = [spans[i].span_id for i in range(len(spans)) if labels[i] == 1]
    return bad_span_ids

# Replace leading and trailing digits in a string
def replace_leading_trailing_digits(string, replacement):
    # Replace digits at the start of the string
    string = re.sub(r'^\d+', replacement, string)
    # Replace digits at the end of the string
    string = re.sub(r'\d+$', replacement, string)
    return string
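
# Illustrative example: replace_leading_trailing_digits("12 Introduction 3", "")
# returns " Introduction " (digits stripped from both ends); callers strip the
# result to drop page numbers from repeated titles.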

# Find elements whose titles overlap with many other titles
def find_overlap_elements(lst: List[Tuple[str, int]], string_match_thresh=.9, min_overlap=.05) -> List[int]:
    # Ids of the elements that meet the overlap criterion
    result = []
    # The title string of every tuple in the input list
    titles = [l[0] for l in lst]

    for i, (str1, id_num) in enumerate(lst):
        overlap_count = 0  # number of other titles at least 90% similar to this one
        # Compare this title against every other title
        for j, str2 in enumerate(titles):
            if i != j and fuzz.ratio(str1, str2) >= string_match_thresh * 100:
                overlap_count += 1
        # Keep the element if it overlaps with at least 5% of the others (minimum 3)
        if overlap_count >= max(3.0, len(lst) * min_overlap):
            result.append(id_num)
    return result

# Filter out repeated titles (e.g., running chapter headers)
def filter_common_titles(merged_blocks: List[FullyMergedBlock]) -> List[FullyMergedBlock]:
    titles = []
    # Collect the text of every title or section-header block
    for i, block in enumerate(merged_blocks):
        if block.block_type in ["Title", "Section-header"]:
            text = block.text
            # Strip markdown heading markers
            if text.strip().startswith("#"):
                text = re.sub(r'#+', '', text)
            text = text.strip()
            # Strip leading and trailing page numbers
            text = replace_leading_trailing_digits(text, "").strip()
            titles.append((text, i))

    # Indices of blocks whose titles repeat throughout the document
    bad_block_ids = find_overlap_elements(titles)

    new_blocks = []
    # Rebuild the block list without the repeated titles
    for i, block in enumerate(merged_blocks):
        if i in bad_block_ids:
            continue
        new_blocks.append(block)
    return new_blocks
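
A hedged sketch of how these filters might be combined (assuming pages is the List[Page] from earlier stages and merged_blocks the List[FullyMergedBlock] produced after merging; the surrounding pipeline code is not shown in this excerpt):

from marker.cleaners.headers import filter_header_footer, categorize_blocks, filter_common_titles

bad_span_ids = filter_header_footer(pages)   # repeated first/last lines
bad_span_ids += categorize_blocks(pages)     # spans outside the main text cluster
# ...the pipeline then drops spans whose span_id is in bad_span_ids...
merged_blocks = filter_common_titles(merged_blocks)  # remove repeated titles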

Marker Source Code Analysis (1) (3): https://developer.aliyun.com/article/1483782
