完全解析!Bert & Transformer 阅读理解源码详解

简介: 完全解析!Bert & Transformer 阅读理解源码详解

接上一篇:

你所不知道的 Transformer!

超详细的 Bert 文本分类源码解读 | 附源码

中文情感分类单标签


参考论文:

https://arxiv.org/abs/1706.03762

https://arxiv.org/abs/1810.04805


在本文中,我将以run_squad.py以及SQuAD数据集为例介绍阅读理解的源码,官方代码基于tensorflow-gpu 1.x,若为tensorflow 2.x版本,会有各种错误,建议切换版本至1.14。


当然,注释好的源代码在这里:

https://github.com/sherlcok314159/ML/tree/main/nlp/code


章节

  • Demo传参
  • 数据篇
  • 番外句子分类
  • 创造实例
  • 实例转换
  • 模型构造
  • 写入预测


Demo传参

python bert/run_squad.py \
  --vocab_file=uncased_L-12_H-768_A-12/vocab.txt \
  --bert_config_file=uncased_L-12_H-768_A-12/bert_config.json \
  --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt \
  --do_train=True \
  --train_file=SQUAD_DIR/train-v2.0.json \
  --train_batch_size=8 \
  --learning_rate=3e-5 \
  --num_train_epochs=1.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=/tmp/squad2.0_base/ \
  --version_2_with_negative=True


阅读源码最重要的一点不是拿到就读,而是跑通源码里面的小demo,因为你跑通demo就意味着你对代码的一些基础逻辑和参数有了一定的了解。


前面的参数都十分常规,如果不懂,建议看我的文本分类的讲解。这里讲一下比较特殊的最后一个参数,我们做的任务是阅读理解,如果有答案缺失,在SQuAD1.0是不可以的,但是在SQuAD允许,这也就是True的意思。


需要注意,不同人的文件路径都是不一样的,你不能照搬我的,要改成自己的路径。


数据篇


其实阅读理解任务模型是跟文本分类几乎是一样的,大的差异在于两者对于数据的处理,所以本篇文章重点在于如何将原生的数据转换为阅读理解任务所能接受的数据,至于模型构造篇,请看文本分类:


https://github.com/sherlcok314159/ML/blob/main/nlp/tasks/text.md


番外句子分类


想必很多人看到SquadExample类的_repr_方法都很疑惑,这里处理好一个example,为什么后面还要进行处理?看英文注释会发现这个类其实跟阅读理解没关系,它只是处理之后对于句子分类任务的,自然在run_squad.py里面没被调用。_repr_方法只是在有start_position的时候进行字符串的拼接。


image.png


创造实例


用于训练的数据集是json文件,需要用json库读入。


训练集的样式如下,可见data是最外层的

{
    "data": [
        {
            "title": "University_of_Notre_Dame",
            "paragraphs": [
                {
                    "context": "Architecturally, the school has a Catholic character.",
                    "qas": [
                        {
                            "answers": [
                                {
                                    "answer_start": 515,
                                    "text": "Saint Bernadette Soubirous"
                                }
                            ],
                            "question": "To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?",
                            "id": "5733be284776f41900661182"
                        }
                    ]
                }
            ]
        },
        {
            "title":"...",
            "paragraphs":[
                {
                    "context":"...",
                    "qas":[
                        {
                            "answers":[
                                {
                                    "answer_start":..,
                                    "text":"...",
                                }
                            ],
                            "question":"...",
                            "id":"..."
                        },
                    ]
                }
            ]
        }
    ]
}


image.png


input_data是一个大列表,然后每一个元素样式如下

{'paragraphs': [{...}, {...}, {...}, {...}, {...}, {...}, {...}, {...}, {...}, ...], 'title': 'University_of_Notre_Dame'}


is_whitespace方法是用来判断是否是一个空格,在切分字符然后加入doc_tokens会用到。



image.png


然后我们层层剥开,然后遍历context的内容,它是一个字符串,所以遍历的时候会遍历每一个字母,字符会被进行判断,如果是空格,则加入doc_tokens,char_to_word_offset表示切分后的索引列表,每一个元素表示一个词有几个字符组成。

image.png

切分后的doc_tokens会去掉空白部分,同时会包括英文逗号。一个单词会有很多字符,每个字符对应的索引会存在char_to_word_offset,例如,前面都是0,代表这些字符都是第一个单词的,所以都是0,换句话说就是第一个单词很长。

doc_tokens = ['Architecturally,', 'the', 'school', 'has', 'a', 'Catholic', 'character.', 'Atop', 'the',"..."]
char_to_word_offset = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]


接下来进行qas内容的遍历,每个元素称为qa,进行id和question内容的分配,后面都是初始化一些参数


image.png


qa里面还有一个is_impossible,用于判断是否有答案


image.png


确保有答案之后,刚刚读入了问题,现在读入与答案相关的部分,读入的时候注意start_position和end_position是相对于doc_tokens的


image.png


接下来对答案部分进行双重检验,actual_text是根据doc_tokens和始末位置拼接好的内容,然后对orig_answer_text进行空格切分,最后用find方法判断orig_answer_text是否被包含在actual_text里面。


image.png


这个是针对is_impossible来说的,如果没有答案,则把始末位置全部变成-1。


image.png


然后将example变成SquadExample的实例化对象,将example加入大列表——examples并返回,至此实例创建完成。


image.png


实例转换


把json文件变成实例之后,我们还差一步便可以把数据塞进模型进行训练了,那就是将实例转化为变量。


先对question_text进行简单的空格切分变为query_tokens


image.png


如果问题过长,就进行截断操作


image.png


接下来对doc_tokens进行空格切分以及词切分,变成all_doc_tokens,需要注意的是orig_to_tok_index代表的是doc_tokens在all_doc_tokens的索引,取最近的一个,而tok_to_orig_index代表的是all_doc_tokens在doc_tokens索引


image.png


对tok_start_position和tok_end_position进行初始化,记住,这两个是相对于all_doc_tokens来说的,一定要与start_position和end_position区分开来,它们是相对于doc_tokens来说的


image.png


接下来先介绍_improve_answer_span方法,这个方法是用来处理特殊的情况的,举个例子,假如说你的文本是"The Japanese electronics industry is the lagest in the world.",你的问题是"What country is the top exporter of electornics?" 那答案其实应该是Japan,可是呢,你用空格和词切分的时候会发现Japanese已经在词表中可查,这意味着不会对它进行再切分,会直接将它返回,这种情况下可能需要这个方法救场。


image.png


因为是监督学习,答案已经给出,所以呢,这个方法干的事情就是词切分后的tokens进行再一次切分,如果发现切分之后会有更好的答案,就返回新的始末点,否则就返回原来的。


对tok_start_position和tok_end_position进行进一步赋值


image.png


计算max_tokens_for_doc,与文本分类类似,需要减去[CLS]和两个[SEP]的位置,这里不同的是还要减去问题的长度,因为这里算的是文本的长度。


tokens = [CLS] query tokens [SEP] context [SEP]


image.png


很多时候文章长度大于maximum_sequence_length的时候,这个时候我们要对文章进行切片处理,把它按照一定长度进行切分,每一个切片称为一个doc_span,start代表从哪开始,length代表一个的长度。


image.png


doc_spans储存很多个doc_span。这里对窗口的长度有所限制,规定了start_offset不能比doc_stride大,这是第二个窗口的起点,从这个角度或许可以理解doc_stride代表平滑的长度。


image.png


接下来的操作跟文本分类有些类似,添加[CLS],然后添加问题和[SEP],这些在segment_ids里面都为0。


image.png


下面讲_check_is_max_context方法,这个方法是用来判断某个词是否具有完备的上下文关系,源代码给了一个例子:


Span A: the man went to the


Span B: to the store and bought


Span C: and bought a gallon of ...


那么对于bought来说,它在Span B和Span C中都有出现,那么,哪一个上下文关系最全呢?其实我们凭直觉应该可以猜到应该是Span C,因为Span B中bought出现在句末,没有下文。当然了,我们还是得用公式计算一下


score = min(num_left_context, num_right_context) + 0.01 * doc_span.length


score_B = min(4, 0) + 0.05 = 0.05


score_C = min(1,3) + 0.05 = 1.05


所以,在Span C中,bought的上下文语义最全,最终该方法会返回True or False,在滑动窗口这个方法中,一个词很可能出现在多个span里面,所以用这个方法判断当前这个词在当前span里面是否具有最完整的上下文


image.png


回到上面,token_to_orig_map是用来记录文章部分在all_doc_tokens的索引,而token_is_max_context是记录文章每一个词在当前span里面是否具有最完整的上下文关系,因为一开始只有一个span,那么一开始每个词肯定都是True。split_token_index用于切分成每一个token,这样可以进行上下文关系判断,至于后面添[SEP]和segment_ids添1这种操作文本分类也有。


image.png


接下来将tokens(精细化切分后的)按照词表转化为id,另外若不足,则把0填充进去这种操作也是很常见的。


image.png


前面是进行判断,如果切了之后答案并不在span里面就直接舍弃,若在里面,因为一开始all_doc_tokens里面没有问题和[CLS],[SEP]时正文的索引是tok_start_position,然后转换为input_ids又有问题以及[CLS],[SEP],所以要得到正文索引需要跳过它们。


image.png


接下来大量的tf.logging只是写入日志信息,同时也是你终端或输出那里看到的。


最终用这些参数实例化InputFeatures对象,然后不断重复,每一个feature对应着一个特殊的id,即为unique_id。



模型构建


这里大致与文本分类差不多,只是文本分类在模型里面直接进行了softmax处理,然后进行最小交叉熵损失,而这次我们没有直接这样做,得到了开头和结尾处的未归一化的概率logits,之后我们直接返回。


然后这次我们是在model_fn_builder方法里面的子方法model_fn里定义compute_loss,其实这里也是经过softmax进行归一化,然后再计算交叉熵损失,最终返回均方误差。


image.png


然后我们计算开头和结尾处的损失,总损失为二者和的平均。


最终我们进行优化。


image.png



写入预测


start_logit & end_logit 代表着未经过softmax的概率,start_logit表示tokens里面以每一个token作为开头的概率,后者类似的。还有一对null_start_logit & null_end_logit,它们两个代表的是SQuAD2.0没有答案的那些,默认全为0。


首先,简单介绍一下_get_best_indexes,这个方法是用来输出由高到低前n_best_size个的概率的索引。


image.png


遍历start_indexes,end_indexes(都是分别经过_get_best_indexes得到),对于答案未缺失的,以具体的logit填入,另外,feature_index代表第几个feature。


image.png


如果答案缺失,则全都为0


image.png


接下来我们进一步转换为具体的文本


image.png


然后进一步清洗数据


image.png


这样还有个问题,词切分会自动小写,与答案还存在一定的偏移,这里介绍get_final_text方法来解决这一问题,比如:


pred_text = steve smith


orig_text = Steve Smith's


这个方法通俗来讲就是获得orig_text(未经过词切分)上正确的截取片段。


然后将其添加到nbest中


image.png


同样会存在没有答案的情况


image.png


接下来会有一个total_scores,它的元素是start_logit和end_logit相加,注意,它们不是数值,是数组,之后就计算total_scores的交叉熵损失作为概率。


image.png


剩下的部分跟文本分类差不多,这里就此略过。

相关文章
|
10月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
267 2
|
6月前
|
算法 测试技术 C语言
深入理解HTTP/2:nghttp2库源码解析及客户端实现示例
通过解析nghttp2库的源码和实现一个简单的HTTP/2客户端示例,本文详细介绍了HTTP/2的关键特性和nghttp2的核心实现。了解这些内容可以帮助开发者更好地理解HTTP/2协议,提高Web应用的性能和用户体验。对于实际开发中的应用,可以根据需要进一步优化和扩展代码,以满足具体需求。
638 29
|
6月前
|
前端开发 数据安全/隐私保护 CDN
二次元聚合短视频解析去水印系统源码
二次元聚合短视频解析去水印系统源码
185 4
|
6月前
|
JavaScript 算法 前端开发
JS数组操作方法全景图,全网最全构建完整知识网络!js数组操作方法全集(实现筛选转换、随机排序洗牌算法、复杂数据处理统计等情景详解,附大量源码和易错点解析)
这些方法提供了对数组的全面操作,包括搜索、遍历、转换和聚合等。通过分为原地操作方法、非原地操作方法和其他方法便于您理解和记忆,并熟悉他们各自的使用方法与使用范围。详细的案例与进阶使用,方便您理解数组操作的底层原理。链式调用的几个案例,让您玩转数组操作。 只有锻炼思维才能可持续地解决问题,只有思维才是真正值得学习和分享的核心要素。如果这篇博客能给您带来一点帮助,麻烦您点个赞支持一下,还可以收藏起来以备不时之需,有疑问和错误欢迎在评论区指出~
|
6月前
|
移动开发 前端开发 JavaScript
从入门到精通:H5游戏源码开发技术全解析与未来趋势洞察
H5游戏凭借其跨平台、易传播和开发成本低的优势,近年来发展迅猛。接下来,让我们深入了解 H5 游戏源码开发的技术教程以及未来的发展趋势。
|
6月前
|
存储 前端开发 JavaScript
在线教育网课系统源码开发指南:功能设计与技术实现深度解析
在线教育网课系统是近年来发展迅猛的教育形式的核心载体,具备用户管理、课程管理、教学互动、学习评估等功能。本文从功能和技术两方面解析其源码开发,涵盖前端(HTML5、CSS3、JavaScript等)、后端(Java、Python等)、流媒体及云计算技术,并强调安全性、稳定性和用户体验的重要性。
|
6月前
|
负载均衡 JavaScript 前端开发
分片上传技术全解析:原理、优势与应用(含简单实现源码)
分片上传通过将大文件分割成多个小的片段或块,然后并行或顺序地上传这些片段,从而提高上传效率和可靠性,特别适用于大文件的上传场景,尤其是在网络环境不佳时,分片上传能有效提高上传体验。 博客不应该只有代码和解决方案,重点应该在于给出解决方案的同时分享思维模式,只有思维才能可持续地解决问题,只有思维才是真正值得学习和分享的核心要素。如果这篇博客能给您带来一点帮助,麻烦您点个赞支持一下,还可以收藏起来以备不时之需,有疑问和错误欢迎在评论区指出~
|
9月前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
创建型模式的主要关注点是“怎样创建对象?”,它的主要特点是"将对象的创建与使用分离”。这样可以降低系统的耦合度,使用者不需要关注对象的创建细节。创建型模式分为5种:单例模式、工厂方法模式抽象工厂式、原型模式、建造者模式。
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
9月前
|
存储 设计模式 算法
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
行为型模式用于描述程序在运行时复杂的流程控制,即描述多个类或对象之间怎样相互协作共同完成单个对象都无法单独完成的任务,它涉及算法与对象间职责的分配。行为型模式分为类行为模式和对象行为模式,前者采用继承机制来在类间分派行为,后者采用组合或聚合在对象间分配行为。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象行为模式比类行为模式具有更大的灵活性。 行为型模式分为: • 模板方法模式 • 策略模式 • 命令模式 • 职责链模式 • 状态模式 • 观察者模式 • 中介者模式 • 迭代器模式 • 访问者模式 • 备忘录模式 • 解释器模式
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
|
9月前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
结构型模式描述如何将类或对象按某种布局组成更大的结构。它分为类结构型模式和对象结构型模式,前者采用继承机制来组织接口和类,后者釆用组合或聚合来组合对象。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象结构型模式比类结构型模式具有更大的灵活性。 结构型模式分为以下 7 种: • 代理模式 • 适配器模式 • 装饰者模式 • 桥接模式 • 外观模式 • 组合模式 • 享元模式
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析

热门文章

最新文章

推荐镜像

更多
  • DNS