Transformers 4.37 中文文档(五)(6)https://developer.aliyun.com/article/1565246
文档问答
原始文本:
huggingface.co/docs/transformers/v4.37.2/en/tasks/document_question_answering
文档问答,也称为文档视觉问答,是一个涉及提供关于文档图像的问题的答案的任务。支持此任务的模型的输入通常是图像和问题的组合,输出是用自然语言表达的答案。这些模型利用多种模态,包括文本、单词的位置(边界框)和图像本身。
本指南说明了如何:
- 在 DocVQA 数据集 上对 LayoutLMv2 进行微调。
- 使用您微调的模型进行推断。
本教程中演示的任务由以下模型架构支持:
LayoutLM, LayoutLMv2, LayoutLMv3
LayoutLMv2 通过在标记的最终隐藏状态之上添加一个问题-回答头来解决文档问答任务,以预测答案的开始和结束标记的位置。换句话说,这个问题被视为抽取式问答:在给定上下文的情况下,提取哪个信息片段回答问题。上下文来自 OCR 引擎的输出,这里是 Google 的 Tesseract。
在开始之前,请确保您已安装所有必要的库。LayoutLMv2 依赖于 detectron2、torchvision 和 tesseract。
pip install -q transformers datasets
pip install 'git+https://github.com/facebookresearch/detectron2.git' pip install torchvision
sudo apt install tesseract-ocr pip install -q pytesseract
安装完所有依赖项后,请重新启动您的运行时。
我们鼓励您与社区分享您的模型。登录到您的 Hugging Face 账户,将其上传到 🤗 Hub。在提示时,输入您的令牌以登录:
>>> from huggingface_hub import notebook_login >>> notebook_login()
让我们定义一些全局变量。
>>> model_checkpoint = "microsoft/layoutlmv2-base-uncased" >>> batch_size = 4
加载数据
在本指南中,我们使用了一个小样本的预处理 DocVQA,您可以在 🤗 Hub 上找到。如果您想使用完整的 DocVQA 数据集,您可以在 DocVQA 主页 上注册并下载。如果您这样做了,要继续本指南,请查看 如何将文件加载到 🤗 数据集。
>>> from datasets import load_dataset >>> dataset = load_dataset("nielsr/docvqa_1200_examples") >>> dataset DatasetDict({ train: Dataset({ features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'], num_rows: 1000 }) test: Dataset({ features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'], num_rows: 200 }) })
如您所见,数据集已经分为训练集和测试集。查看一个随机示例,以熟悉特征。
>>> dataset["train"].features
这里是各个字段代表的含义:
id
:示例的 idimage
:包含文档图像的 PIL.Image.Image 对象query
:问题字符串 - 自然语言提出的问题,可以是多种语言answers
:人类注释者提供的正确答案列表words
和bounding_boxes
:OCR 的结果,我们这里不会使用answer
:由另一个模型匹配的答案,我们这里不会使用
让我们只保留英文问题,并且删除包含另一个模型预测的 answer
特征。我们还将从注释者提供的答案集中取第一个答案。或者,您可以随机抽样。
>>> updated_dataset = dataset.map(lambda example: {"question": example["query"]["en"]}, remove_columns=["query"]) >>> updated_dataset = updated_dataset.map( ... lambda example: {"answer": example["answers"][0]}, remove_columns=["answer", "answers"] ... )
请注意,本指南中使用的 LayoutLMv2 检查点已经训练了 max_position_embeddings = 512
(您可以在 检查点的 config.json
文件 中找到此信息)。我们可以截断示例,但为了避免答案可能在大型文档的末尾并最终被截断的情况,这里我们将删除几个示例,其中嵌入可能会超过 512。如果您的数据集中大多数文档很长,您可以实现一个滑动窗口策略 - 详细信息请查看 此笔记本。
>>> updated_dataset = updated_dataset.filter(lambda x: len(x["words"]) + len(x["question"].split()) < 512)
此时让我们还从数据集中删除 OCR 特征。这些是 OCR 的结果,用于微调不同的模型。如果我们想要使用它们,它们仍然需要一些处理,因为它们不符合我们在本指南中使用的模型的输入要求。相反,我们可以在原始数据上同时使用 LayoutLMv2Processor 进行 OCR 和标记化。这样我们将得到与模型预期输入匹配的输入。如果您想手动处理图像,请查看LayoutLMv2
模型文档以了解模型期望的输入格式。
>>> updated_dataset = updated_dataset.remove_columns("words") >>> updated_dataset = updated_dataset.remove_columns("bounding_boxes")
最后,如果我们不查看一个图像示例,数据探索就不会完成。
>>> updated_dataset["train"][11]["image"]
预处理数据
文档问答任务是一个多模态任务,您需要确保每种模态的输入都按照模型的期望进行预处理。让我们从加载 LayoutLMv2Processor 开始,它内部结合了一个可以处理图像数据的图像处理器和一个可以编码文本数据的标记器。
>>> from transformers import AutoProcessor >>> processor = AutoProcessor.from_pretrained(model_checkpoint)
预处理文档图像
首先,让我们通过处理器中的image_processor
来为模型准备文档图像。默认情况下,图像处理器将图像调整大小为 224x224,确保它们具有正确的颜色通道顺序,应用 OCR 与 tesseract 获取单词和归一化边界框。在本教程中,所有这些默认设置正是我们所需要的。编写一个函数,将默认图像处理应用于一批图像,并返回 OCR 的结果。
>>> image_processor = processor.image_processor >>> def get_ocr_words_and_boxes(examples): ... images = [image.convert("RGB") for image in examples["image"]] ... encoded_inputs = image_processor(images) ... examples["image"] = encoded_inputs.pixel_values ... examples["words"] = encoded_inputs.words ... examples["boxes"] = encoded_inputs.boxes ... return examples
为了以快速的方式将此预处理应用于整个数据集,请使用map。
>>> dataset_with_ocr = updated_dataset.map(get_ocr_words_and_boxes, batched=True, batch_size=2)
预处理文本数据
一旦我们对图像应用了 OCR,我们需要对数据集的文本部分进行编码,以准备模型使用。这涉及将我们在上一步中获得的单词和框转换为标记级别的input_ids
、attention_mask
、token_type_ids
和bbox
。对于文本预处理,我们将需要处理器中的tokenizer
。
>>> tokenizer = processor.tokenizer
除了上述预处理之外,我们还需要为模型添加标签。对于🤗 Transformers 中的xxxForQuestionAnswering
模型,标签包括start_positions
和end_positions
,指示答案的起始和结束的标记在哪里。
让我们从这里开始。定义一个辅助函数,可以在较大的列表(单词列表)中找到一个子列表(答案拆分为单词)。
此函数将接受两个列表作为输入,words_list
和answer_list
。然后,它将遍历words_list
,检查words_list
中当前单词(words_list[i])是否等于answer_list
的第一个单词(answer_list[0]),以及从当前单词开始且与answer_list
相同长度的words_list
子列表是否等于answer_list
。如果这个条件为真,表示找到了匹配,函数将记录匹配及其起始索引(idx)和结束索引(idx + len(answer_list) - 1)。如果找到了多个匹配,函数将仅返回第一个。如果没有找到匹配,函数将返回(None
,0 和 0)。
>>> def subfinder(words_list, answer_list): ... matches = [] ... start_indices = [] ... end_indices = [] ... for idx, i in enumerate(range(len(words_list))): ... if words_list[i] == answer_list[0] and words_list[i : i + len(answer_list)] == answer_list: ... matches.append(answer_list) ... start_indices.append(idx) ... end_indices.append(idx + len(answer_list) - 1) ... if matches: ... return matches[0], start_indices[0], end_indices[0] ... else: ... return None, 0, 0
为了说明此函数如何找到答案的位置,让我们在一个示例上使用它:
>>> example = dataset_with_ocr["train"][1] >>> words = [word.lower() for word in example["words"]] >>> match, word_idx_start, word_idx_end = subfinder(words, example["answer"].lower().split()) >>> print("Question: ", example["question"]) >>> print("Words:", words) >>> print("Answer: ", example["answer"]) >>> print("start_index", word_idx_start) >>> print("end_index", word_idx_end) Question: Who is in cc in this letter? Words: ['wie', 'baw', 'brown', '&', 'williamson', 'tobacco', 'corporation', 'research', '&', 'development', 'internal', 'correspondence', 'to:', 'r.', 'h.', 'honeycutt', 'ce:', 't.f.', 'riehl', 'from:', '.', 'c.j.', 'cook', 'date:', 'may', '8,', '1995', 'subject:', 'review', 'of', 'existing', 'brainstorming', 'ideas/483', 'the', 'major', 'function', 'of', 'the', 'product', 'innovation', 'graup', 'is', 'to', 'develop', 'marketable', 'nove!', 'products', 'that', 'would', 'be', 'profitable', 'to', 'manufacture', 'and', 'sell.', 'novel', 'is', 'defined', 'as:', 'of', 'a', 'new', 'kind,', 'or', 'different', 'from', 'anything', 'seen', 'or', 'known', 'before.', 'innovation', 'is', 'defined', 'as:', 'something', 'new', 'or', 'different', 'introduced;', 'act', 'of', 'innovating;', 'introduction', 'of', 'new', 'things', 'or', 'methods.', 'the', 'products', 'may', 'incorporate', 'the', 'latest', 'technologies,', 'materials', 'and', 'know-how', 'available', 'to', 'give', 'then', 'a', 'unique', 'taste', 'or', 'look.', 'the', 'first', 'task', 'of', 'the', 'product', 'innovation', 'group', 'was', 'to', 'assemble,', 'review', 'and', 'categorize', 'a', 'list', 'of', 'existing', 'brainstorming', 'ideas.', 'ideas', 'were', 'grouped', 'into', 'two', 'major', 'categories', 'labeled', 'appearance', 'and', 'taste/aroma.', 'these', 'categories', 'are', 'used', 'for', 'novel', 'products', 'that', 'may', 'differ', 'from', 'a', 'visual', 'and/or', 'taste/aroma', 'point', 'of', 'view', 'compared', 'to', 'canventional', 'cigarettes.', 'other', 'categories', 'include', 'a', 'combination', 'of', 'the', 'above,', 'filters,', 'packaging', 'and', 'brand', 'extensions.', 'appearance', 'this', 'category', 'is', 'used', 'for', 'novel', 'cigarette', 'constructions', 'that', 'yield', 'visually', 'different', 'products', 'with', 'minimal', 'changes', 'in', 'smoke', 'chemistry', 'two', 'cigarettes', 'in', 'cne.', 'emulti-plug', 'te', 'build', 'yaur', 'awn', 'cigarette.', 'eswitchable', 'menthol', 'or', 'non', 'menthol', 'cigarette.', '*cigarettes', 'with', 'interspaced', 'perforations', 'to', 'enable', 'smoker', 'to', 'separate', 'unburned', 'section', 'for', 'future', 'smoking.', '«short', 'cigarette,', 'tobacco', 'section', '30', 'mm.', '«extremely', 'fast', 'buming', 'cigarette.', '«novel', 'cigarette', 'constructions', 'that', 'permit', 'a', 'significant', 'reduction', 'iretobacco', 'weight', 'while', 'maintaining', 'smoking', 'mechanics', 'and', 'visual', 'characteristics.', 'higher', 'basis', 'weight', 'paper:', 'potential', 'reduction', 'in', 'tobacco', 'weight.', '«more', 'rigid', 'tobacco', 'column;', 'stiffing', 'agent', 'for', 'tobacco;', 'e.g.', 'starch', '*colored', 'tow', 'and', 'cigarette', 'papers;', 'seasonal', 'promotions,', 'e.g.', 'pastel', 'colored', 'cigarettes', 'for', 'easter', 'or', 'in', 'an', 'ebony', 'and', 'ivory', 'brand', 'containing', 'a', 'mixture', 'of', 'all', 'black', '(black', 'paper', 'and', 'tow)', 'and', 'ail', 'white', 'cigarettes.', '499150498'] Answer: T.F. Riehl start_index 17 end_index 18
然而,一旦示例被编码,它们将看起来像这样:
>>> encoding = tokenizer(example["question"], example["words"], example["boxes"]) >>> tokenizer.decode(encoding["input_ids"]) [CLS] who is in cc in this letter? [SEP] wie baw brown & williamson tobacco corporation research & development ...
我们需要找到编码输入中答案的位置。
token_type_ids
告诉我们哪些标记属于问题,哪些属于文档的单词。tokenizer.cls_token_id
将帮助找到输入开头的特殊标记。word_ids
将帮助将原始words
中找到的答案与完全编码输入中的相同答案进行匹配,并确定编码输入中答案的起始/结束位置。
有了这个想法,让我们创建一个函数来对数据集中的一批示例进行编码:
>>> def encode_dataset(examples, max_length=512): ... questions = examples["question"] ... words = examples["words"] ... boxes = examples["boxes"] ... answers = examples["answer"] ... # encode the batch of examples and initialize the start_positions and end_positions ... encoding = tokenizer(questions, words, boxes, max_length=max_length, padding="max_length", truncation=True) ... start_positions = [] ... end_positions = [] ... # loop through the examples in the batch ... for i in range(len(questions)): ... cls_index = encoding["input_ids"][i].index(tokenizer.cls_token_id) ... # find the position of the answer in example's words ... words_example = [word.lower() for word in words[i]] ... answer = answers[i] ... match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split()) ... if match: ... # if match is found, use `token_type_ids` to find where words start in the encoding ... token_type_ids = encoding["token_type_ids"][i] ... token_start_index = 0 ... while token_type_ids[token_start_index] != 1: ... token_start_index += 1 ... token_end_index = len(encoding["input_ids"][i]) - 1 ... while token_type_ids[token_end_index] != 1: ... token_end_index -= 1 ... word_ids = encoding.word_ids(i)[token_start_index : token_end_index + 1] ... start_position = cls_index ... end_position = cls_index ... # loop over word_ids and increase `token_start_index` until it matches the answer position in words ... # once it matches, save the `token_start_index` as the `start_position` of the answer in the encoding ... for id in word_ids: ... if id == word_idx_start: ... start_position = token_start_index ... else: ... token_start_index += 1 ... # similarly loop over `word_ids` starting from the end to find the `end_position` of the answer ... for id in word_ids[::-1]: ... if id == word_idx_end: ... end_position = token_end_index ... else: ... token_end_index -= 1 ... start_positions.append(start_position) ... end_positions.append(end_position) ... else: ... start_positions.append(cls_index) ... end_positions.append(cls_index) ... encoding["image"] = examples["image"] ... encoding["start_positions"] = start_positions ... encoding["end_positions"] = end_positions ... return encoding
现在我们有了这个预处理函数,我们可以对整个数据集进行编码:
>>> encoded_train_dataset = dataset_with_ocr["train"].map( ... encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["train"].column_names ... ) >>> encoded_test_dataset = dataset_with_ocr["test"].map( ... encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr["test"].column_names ... )
让我们看看编码数据集的特征是什么样子的:
>>> encoded_train_dataset.features {'image': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='uint8', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None), 'start_positions': Value(dtype='int64', id=None), 'end_positions': Value(dtype='int64', id=None)}
Transformers 4.37 中文文档(五)(8)https://developer.aliyun.com/article/1565248