首页> 标签> 机器学习/深度学习
"机器学习/深度学习"
共 31619 条结果
全部 问答 文章 公开课 课程 电子书 技术圈 体验
MagicBook打开虚拟机提示此主机支持 AMD-V,但 AMD-V 处于禁用状态。
环境:VMware Workstation Pro 12 +CentOS 7 64 位+win10点击开启此虚拟机,提示如下所示:此主机支持 AMD-V,但 AMD-V 处于禁用状态。如果已在 BIOS/固件设置中禁用 AMD-V,或主机自更改此设置后从未重新启动,则 AMD-V 可能被禁用。(1) 确认 BIOS/固件设置中启用了 AMD-V。(2) 如果此 BIOS/固件设置已更改,请重新启动主机。(3) 如果您在安装 VMware Workstation 之后从未重新启动主机,请重新启动。(4) 将主机的 BIOS/固件更新至最新版本。此主机不支持“AMD RVI”硬件辅助的 MMU 虚拟化。模块“CPUIDEarly”启动失败。未能启动虚拟机。解决方法:关机,开机,长按F2,进入BIOS界面,按照下面操作(不同电脑界面可能不同),记住按 F10 保存配置。再打开电脑,开启虚拟机即可。一、Intel平台家用笔记本。(Y、Z、G、N系列)开机按F2进入BIOS,选择Configuration的选项,Intel Virtual Technology的选项设置成Enable的状态。二、AMD平台家用笔记本。(Y、Z、G系列)开机按F2进入BIOS,选择Configuration的选项,SVM Support的选项设置成Enable的状态。三、其他情况可见联想官网:https://iknow.lenovo.com.cn/detail/dc_125894.html注意之后可能还会出现兼容性问题等其他问题均属正常,再进行修改即可
文章
机器学习/深度学习  ·  Linux  ·  虚拟化
2022-06-26
7-5 螺旋方阵
代码思路: 1.先尝试构造出外圈数/* 顺序: up-right-down-left 1 2 3 4 5 16 0 0 0 6 15 0 0 0 7 14 0 0 0 8 13 12 11 10 9 */ #include <stdio.h> #define N 10 int main() { int i,j,k,n,a[N][N]={0},value=1; scanf("%d",&n); //up for(j=0;j<n;j++) { a[0][j]=value++; } //right for(i=1;i<n;i++) { a[i][n-1]=value++; } //down for(j=n-2;j>=0;j--) { a[n-1][j]=value++; } //left for(i=n-2;i>0;i--) { a[i][0]=value++; } for(i=0;i<n;i++) { for(j=0;j<n;j++) printf("%3d ",a[i][j]); printf("\n"); } return 0; } 2.发现规律找出共性,修改代码提高代码的通用性外圈输出完外圈后里面是一个3*3的矩阵,里面的矩阵输出步骤重复。故可以使用一个变量k来控制重复的次数k=n , k=k-2; 循环结束条件为k>1分就奇偶讨论,当k为奇数时,矩阵中间点要单独计算内圈每次外圈循环完后,变成更小的矩阵,每条边的上下限不是常量,是一个变量且变化规律是内缩一格#include <stdio.h> #define N 10 int main() { int i, j, n, a[N][N] = { 0 }, value = 1; scanf("%d", &n); // 改造代码 int k, start, end; //分别表示外圈循环标记,起始和末尾 k = n; start = 0; end = n; while (k > 1) { //up 举列 把所有的与边界值有关替换为start和end控制 for (j = start; j < end; j++) { a[start][j] = value++; } //right for (i = start + 1; i < end; i++) { a[i][end - 1] = value++; } //down for (j = end - 2; j >= start; j--) { a[end - 1][j] = value++; } //left for (i = end - 2; i > start; i--) { a[i][start] = value++; } //控制外圈变量 5 3 1 结束 k = k - 2; //边界需要缩进一格 start = start + 1; //0+1 = 1 end = end - 1; //(n-1)-1 = n-2 } //如果n为奇数则,为矩阵中间数赋值 if (n % 2) a[start][end - 1] = value; for (i = 0; i < n; i++) { for (j = 0; j < n; j++) printf("%3d", a[i][j]); printf("\n"); } return 0; }
文章
机器学习/深度学习
2022-06-26
【笔记】用户指南—备份与恢复—将PolarDB-X与其他阿里云服务集成
与DTS集成DTS是一款集数据迁移、订阅及实时同步功能于一体的数据传输产品。PolarDB-X通过与DTS深度集成,提供了覆盖几乎所有常见数据库类型的数据导入和导出链路,详细内容请参见使用DTS导入和导出数据。与DMS集成DMS是一款集多种服务于一体的的数据管理服务产品。通过DMS可以对PolarDB-X实例进行数据管理、结构管理、用户授权等多种操作,详细内容请参见DMS官方文档。与DAS集成DAS是一款基于机器学习和专家经验实现数据库自感知、自修复、自优化、自运维及自安全的云服务。PolarDB-X目前已接入性能趋势、索引推荐、空间分析、SQL限流、实时性能、锁分析、SQL洞察等能力,详细内容请参见DAS官方文档。与云监控集成云监控是一款云资源监控报警产品。PolarDB-X已接入云监控并提供对计算资源和存储资源的监控告警能力,具体可参考监控与告警
文章
SQL  ·  机器学习/深度学习  ·  运维  ·  监控  ·  Cloud Native  ·  安全  ·  数据管理  ·  分布式数据库  ·  数据库  ·  数据库管理
2022-06-26
论文赏析[EMNLP19]如何在Transformer中融入句法树信息?这里给出了一种解决方案(二)
实验首先是在WSJ测试集上的无监督句法分析结果:可以看到Tree-Transformer效果还是好于之前的ON-LSTM和PRPN的,虽然比在NLI上训练的DIORA略差,但也情有可原,毕竟人家训练集大,而且是全局解码, 甚至还达到了URNNG的效果。而层数选择10层是效果最好的。然后是在WSJ10测试集上的无监督句法分析结果:可以看到,长度很短的时候Tree-Transformer效果就甚至不如PRPN了,和ON-LSTM相比其实也半斤八两。论文并没有分析原因,甚至都没有提这个。然后是采用不同的层做出来的无监督句法分析结果:可以看到,最小递归到第三层的时候结果最好,而看的层数越少,也就是只看高层的,效果非常的差。只看单独一层的效果也不大行,这都说明了高层的表示更加的抽象,其实不大适宜句法信息的表示。而低层又太接近单词层面了,都是表面信息。这其实和最近的一篇解释bert中attention含义的论文结果一致,中间层的attention表示的是句法信息。最后是语言模型的困惑度结果:这里就只和普通的Transformer相比了,结果还是更好的。因为这里得用masked LM做目标函数,所以没法和ON-LSTM、PRPN等语言模型相比。其他关于attention解释性等讨论详见论文,我觉得没有多大意思,attention的可解释性最近争论一直很大,强行解释没有意义。结论本文提出的Tree Transformer用成分先验表示两个单词属于同一个短语的概率,然后和self-attention联合决定两个单词之间的attention。并且提出了一种解码出句法树的算法,但是还存在着一些问题。文中说尝试过用Transformer预训练Tree Transformer,这样loss下降的更低了,拟合的更好,但是解码出的句法树效果更差了。这其实是有道理的,之前见过一篇分析论文,提到了语言模型训练的好,并不一定代表着句法树学的好,这两者不能划等号。所以今后如何选择更好更合适的损失函数,值得研究。这里面还有一些文章可以做,我总感觉本文模型的attention计算方式还是挺牵强的,特别是得分归一化那里,强行将单词左右邻居视为两种不同的角色。下一步工作我可以在上面进行改进,换一种全新的attention计算方式试试,另外损失函数上面考虑到前一篇文章提到的乱序问题,可以尝试用还原词序作为目标任务。
文章
机器学习/深度学习  ·  自然语言处理  ·  算法
2022-06-26
论文赏析[EMNLP19]如何在Transformer中融入句法树信息?这里给出了一种解决方案(一)
论文地址:https://www.aclweb.org/anthology/D19-1098.pdf介绍之前其实有很多工作将句法信息融入到了RNN中,例如ON-LSTM和PRPN,用来隐式建模句法结构信息,同时提升语言模型的准确率。本文尝试将句法信息融入到Transformer中,用来赋予attention更好的解释性。同时可以无监督的预测出句子的句法树,并且相比于一般的Transformer,语言模型的性能有所提高。模型结构上面这张是模型结构,最主要的区别就是在multi-head attention操作基础上新增了一个成分的attention,用来表示一段span能否构成一个短语。比如上图中,“cute dog”构成一个短语,所以第0层中这两个单词的attention较大。而“the cute dog”构成了一个更大的短语,所以第1层中“the”和“dog”的attention较大。回顾self-attention的操作,主要是计算两个单词的向量点积:这里 。但是在本文中,新增加了一个成分先验 C ,其中 表示 和 在一个短语内的概率。然后与原来的self-attention做元素乘即可:注意不同的head之间共享 C 。那么这个成分先验 C 怎么算呢?这里把它拆成若干相邻单词在同一短语内概率的乘积。也就是定义 在同一短语内的概率,那么 就可以表示为:这样只有 中所有单词都有较大概率在同一短语中, 取值才比较大。当然在实现中会取对数,来避免数值太小。那么问题又来了, a 怎么算?首先类似self-attention,计算相邻两个单词属于同一短语的得分:注意这里区分了方向,也就是还存在得分 ,并且两者虽然意义是一样的,但是分数不一定相同。为了防止出现一种问题,也就是所有得分全部相同,然后算出来概率全是1,那就没有意义了,所以要给得分加上限制,也就是归一化。这里选择归一化一个单词和左右邻居两者的得分:然后由于 值不一样,所以取平均:这样的话,如果两个相邻单词互相之间连接的概率很大,就会导致 很大,也就说明了这两个单词大概率属于同一个短语。从第一张模型图中可以看到,成分attention不只计算了一层。低层可以用来表示两两相邻单词之间属于同一短语的概率,而高层可以表示属于更大的短语的概率。注意还得满足一个性质,也就是如果两个单词在低层大概率属于同一个短语,那他们高层肯定更大概率属于一个更大的短语。所以计算方式如下:初始化的时候 都设为0。这样对于每一层都可以得到一个成分先验 。无监督句法分析
文章
机器学习/深度学习  ·  自然语言处理
2022-06-26
最全攻略:利用LightSeq加速你的深度学习模型
前言LightSeq是字节跳动火山翻译团队开源的一款Transformer系列模型加速引擎,分为训练和推理两个部分。其中推理加速引擎早在2019年12月就已经开源,而训练加速引擎也在2021年6月开源。项目地址:https://github.com/bytedance/lightseqLightSeq主要采用了CUDA算子融合、显存优化、参数连续化、层级式解码策略等技术,感兴趣的小伙伴可以阅读此前的文章:训练引擎:支持Transformer全流程训练加速,最高加速3倍!字节跳动LightSeq上新推理引擎:速度超快!字节跳动开源序列推理引擎LightSeq本文详细讲解一下如何使用LightSeq来改造你的PyTorch模型,实现1.5-3倍的训练加速和5-10倍的推理加速。至于TensorFlow模型的加速,目前也已经支持,这里不会详细讲解,可以参考下面NeurST的代码:https://github.com/bytedance/neurst/tree/lightseq整体流程使用LightSeq进行加速的整体流程依次为:接入训练引擎进行模型训练,并保存模型参数。加载模型参数,使用训练引擎的前向传播部分进行模型推理。为了更快的推理速度,还可以将模型参数导出为protobuf或者hdf5格式。使用推理引擎解析第3步中导出的模型,并进行模型推理。模型训练LightSeq提供了封装好的embedding、encoder、decoder、cross entropy和adam类,可以接入到你自己的模型中替换原有的模型。LightSeq还提供了现成的Fairseq、Hugging Face、DeepSpeed DeepSpeed可以用于大规模训练Speed、NeurST等样例。如果你用这几个训练库的话,就可以直接使用。如果你是自己的模型,那也可以手动接入LightSeq。这几个样例代码都在examples/training目录下。自定义模型首先引入所有可能用到的头文件:from lightseq.training import ( LSTransformer, LSTransformerEmbeddingLayer, LSTransformerEncoderLayer, LSTransformerDecoderLayer, LSCrossEntropyLayer, LSAdam, )以新建encoder层为例,主要分为两个步骤:使用LSTransformerEncoderLayer.get_config函数新建config。新建LightSeq的encoder层,即LSTransformerEncoderLayer类,使用config来初始化。一个典型的例子如下:config = LSTransformerEncoderLayer.get_config( model="bert-base", max_batch_tokens=4096, max_seq_len=512, fp16=True, local_rank=0, ) layer = LSTransformerEncoderLayer(config)其中max_batch_tokens指定了训练过程中一个batch最大可能的单词数,max_seq_len指定了句子的最长长度。model提供了四种现成的模型配置:transformer-base、transformer-big、bert-base和bert-big。当然如果你想用自己的模型配置,也可以手动补全所有的参数:config = LSTransformerEncoderLayer.get_config( max_batch_tokens=4096, max_seq_len=512, hidden_size=1024, intermediate_size=4096, nhead=16, attn_prob_dropout_ratio=0.1, activation_dropout_ratio=0.1, hidden_dropout_ratio=0.1, pre_layer_norm=False, activation_fn="gelu", fp16=True, local_rank=0, ) layer = LSTransformerEncoderLayer(config)除了encoder以外,embedding、decoder、cross entropy和adam也可以用同样的方法新建,最后和你自己写的模型一样进行训练即可。此外LightSeq还提供了完整的Transformer类LSTransformer,可以直接新建一整个Transformer:config = LSTransformer.get_config( model="transformer-base", max_batch_tokens=4096, max_seq_len=512, vocab_size=32000, padding_idx=0, num_encoder_layer=6, num_decoder_layer=6, fp16=True, local_rank=0, ) model = LSTransformer(config)示例代码在examples/training/custom中,可以直接运行python run.py查看效果。Hugging Face以Hugging Face官方提供的run_glue.py为例,一般首先都是用AutoModel.from_pretrained函数新建模型model,然后进行训练。为了接入LightSeq,需要将model中的所有encoder层替换为LightSeq版本的encoder层。替换过程分为三个步骤:使用LSTransformerEncoderLayer.get_config函数新建config。获取Hugging Face预训练好的BERT参数。新建LightSeq的encoder层,即LSTransformerEncoderLayer类,使用config和预训练好的参数来初始化。新建encoder层代码参见上一小节。注意在Hugging Face这个例子里,额外给LSTransformerEncoderLayer封装了一层LSHFTransformerEncoderLayer,主要是为了兼容原来的encoder输入形状。示例代码在examples/training/huggingface中,运行sh run_glue.sh和sh run_ner.sh分别可以查看LightSeq在GLUE和NER任务上的加速效果。注意Hugging Face BERT的fine-tune任务很不稳定,经常会不收敛,这时候可以尝试修改运行脚本中的--seed参数。FairseqFairseq主要用于一些生成任务,使用LightSeq加速的原理是一样的,都是需要将各自组件替换为LightSeq对应的组件。LightSeq对Fairseq做了非常完整的替换,将embedding、encoder、decoder、cross entropy和adam全部替换为了LightSeq对应的部分,来达到极致的加速效果。示例代码在examples/training/fairseq目录下,其中fs_cli目录存放着三个启动入口:train、validate和generate,fs_modules目录存放着用LightSeq封装好的几个Transformer组件。直接运行sh ls_fairseq_wmt14en2de.sh即可自动下载数据并运行WMT14英德机器翻译任务。脚本中主要的运行命令如下:lightseq-train /tmp/wmt14_en_de/ \ --task translation \ --arch ls_transformer_wmt_en_de_big_t2t --share-decoder-input-output-embed \ --optimizer ls_adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 --weight-decay 0.0001 \ --criterion ls_label_smoothed_cross_entropy --label-smoothing 0.1 \ --max-tokens 8192 \ --eval-bleu --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \ --eval-bleu-detok moses --eval-bleu-remove-bpe --eval-bleu-print-samples \ --best-checkpoint-metric bleu \ --maximize-best-checkpoint-metric --fp16注意到和一般运行Fairseq的命令不同的地方有这么几个:启动入口从fairseq-train替换为了lightseq-train,这是因为在根目录setup.py里封装了--user-dir用户模块目录。如果还想继续用fairseq-train的话,就需要手动指定--user-dir fs_modules参数。模型结构--arch需要在原来的基础上加上前缀ls_,用来指定使用LightSeq提供的Transformer模型。优化器--optimizer和损失函数--criterion都需要在原来的基础上加上前缀ls_,指定使用LightSeq对应的组件。DeepSpeedDeepSpeed主要用于大规模训练,也提供了Transformer的encoder层CUDA实现,不过效率没有LightSeq高。LightSeq提供了Fairseq+DeepSpeed分布式训练的使用样例,将启动器替换成了deepspeed,手动指定--user-dir目录,还需要指定DeepSpeed的配置文件deepspeed_config,其它参数和上一节Fairseq样例一模一样。使用时运行sh ds_fairseq_wmt14en2de.sh即可,和上一小节一样都是用Fairseq运行WMT14英德机器翻译任务。模型导出在模型训练完之后,直接load保存的checkpoint就可以继续fine-tune或者推理。但是这样调用的是训练引擎的推理部分,也就是模型的前向传播。这部分代码需要频繁在python和c++之间切换,并且前向过程中计算了很多反向传播才需要用到的变量。因此速度不如纯粹的推理引擎快。而要想使用LightSeq的推理引擎,就必须先将checkpoint转变为protobuf或者hdf5的格式。LightSeq提供了每个组件的导出接口,如果你使用了LightSeq的模型组件,那么导出将变得非常容易。只需要引入下面的头文件即可:from lightseq.training import ( export_ls_config, export_ls_embedding, export_ls_encoder, export_ls_decoder, )这四个函数分别可以导出推理引擎所需要的配置信息、embedding参数、encoder参数和decoder参数。而如果有其他部分的参数没包括在这里面(例如输出到词表的映射矩阵),则需要自己进行导出,详见下面的教程。LightSeq对Hugging Face的BERT、BART、GPT2三种模型,以及Fairseq+LightSeq、LightSeq的Transformer模型都提供了模型导出的样例,代码在examples/inference/python/export目录下。其中Hugging Face的模型都是没有采用LightSeq加速训练的预训练模型参数,所以导出更为复杂一些。模型导出的核心思想就是:首先创建一个protobuf对象Transformer或者hdf5的文件对象。然后在checkpoint中提取出参数值,将其赋值给Transformer或者hdf5文件对象中对应的参数。这个过程麻烦的就是提取并且对应赋值的过程,LightSeq提供了一系列方便的操作函数。Fairseq执行python ls_fs_transformer.py可以导出上一章节中Fairseq+LightSeq训练样例得到的模型。以protobuf导出为例,观察代码可以看到主体部分如下(省略了部分参数):file = Transformer() encoder_state_dict, decoder_state_dict = _extract_weight(state_dict) export_ls_embedding(file, encoder_state_dict, is_encoder=True) export_ls_embedding(file, encoder_state_dict, is_encoder=False) export_ls_encoder(file, encoder_state_dict) export_ls_decoder(file, decoder_state_dict) export_fs_weights(file, state_dict) export_ls_config(file)首先需要用户自己将state_dict拆分成encoder和decoder两部分,这主要是因为设计时考虑到有些用户只会用到encoder的导出(例如BERT)。并且LightSeq无法知道用户模型的最外层参数名叫啥,万一不叫encoder,而叫enc之类的呢?所以交给用户自己拆分更加合理。然后分别导出encoder的embedding、decoder的embedding、encoder和decoder参数,这几部分都直接调用LightSeq提供的接口就行了。LightSeq会自动帮你把解析出来的参数导出到定义的Transformer类里。接着需要处理一下Fairseq中与LightSeq无关的一些参数,例如encoder和decoder的layer norm参数等等。export_fs_weights函数需要用户自己实现,核心思想就是找到state_dict中的参数名,将其赋值给Transformer类里对应的变量就行了。最后设置一下Transformer类里所有的配置参数就行了。hdf5的用法类似,LightSeq都将其封装在同样的函数里了,只需要指定save_pb=False即可。Hugging Face执行python hf_bert_export.py、python hf_bart_export.py和python hf_gpt2_export.py三个文件分别可以导出BERT、BART和GPT2的预训练模型。因为Hugging Face的模型参数都是预训练得到的,所以LightSeq无法识别参数名是什么样的,只能用户自己编写导出规则,具体参考上面三个导出样例即可。LightSeq Transformer使用LightSeq提供的Transformer进行训练的话,参数名LightSeq都知道的一清二楚,因此可以直接使用LightSeq提供的导出接口进行转换。过程和上面的Fairseq+LightSeq类似。具体样例可以执行python ls_transformer_export.py,同时得到protobuf和hdf5格式的模型导出文件,并且对比两者生成的结果。这里的checkpoint可以使用上一章节中自定义模型小节中训练得到的模型。自定义模型因为自定义的模型参数LightSeq无法识别参数名,所以需要用户自己编写转换规则。举一个简单的例子,假设用户模型中有个encoder的输出部分的layer norm参数,state_dict中的参数名叫做encoder.layer_norm.weight。那么可以按如下方式进行转换:transformer = Transformer() enc_norm_w = state_dict["encoder.layer_norm.weight"].flatten().tolist() transformer.src_embedding.norm_scale[:] = enc_norm_w模型推理得到导出的protobuf或者hdf5模型后,推理就变得十分简单,核心代码就三行:import lightseq.inference as lsi model = lsi.Transformer("transformer.pb", 8) output = model.infer([[1, 2, 3], [4, 5, 6]])首先定义一个Transformer类用来加载模型参数,指定load的protobuf文路径和batch_size大小。然后调用infer函数进行推理,传入的输入参数必须是list或者numpy类型,且必须是二维。LightSeq在examples/inference/python/test目录下提供了三个Hugging Face模型推理的样例,此外上一小节中examples/inference/python/export中的ls_transformer_export.py代码也包含了导出后推理的过程。最佳实践总结一下,使用LightSeq加速你的深度学习模型,最佳方式无外乎三步:接入LightSeq训练引擎的模型组件,构建模型,进行训练,保存checkpoint。将checkpoint转换为protobuf或者hdf5格式,LightSeq的组件可以调用现成的转换接口,其它的需要自己手写转换规则。调用LightSeq推理引擎,加载上一步中导出的模型,进行快速推理。目前LightSeq已经被广泛应用在字节跳动公司内外各项业务和学术研究上,支持了标准的Transformer、BERT、BART、GPT2、ViT等多种Transformer系列模型。只要你的模型中包含有Transformer的部分组件,例如encoder层,就可以直接调用LightSeq进行加速。
文章
机器学习/深度学习  ·  自然语言处理  ·  并行计算  ·  Shell  ·  PyTorch  ·  TensorFlow  ·  算法框架/工具  ·  C++  ·  Python
2022-06-26
cuBLAS矩阵乘法性能分析(附代码示例)
使用教程矩阵乘法是神经网络中最基础、最重要的一个运算。在用CUDA实现矩阵乘法时,不需要我们手动写,cuBLAS库提供了现成的矩阵乘法算子,例如cublasGemmEx和cublasLtMatmul。其中后者是轻量级版本,API调用更灵活。例如对于整数乘法,cublasLtMatmul支持int8的输入输出,而cublasGemmEx只支持int8输入,int32输出。今天我只给大家讲解cublasGemmEx,主要使用起来相对更简洁一点。官方文档地址:https://docs.nvidia.com/cuda/cublas/index.html#cublas-GemmEx经过翻阅网上各种教程,我找到了一篇我认为写的最好的博客。例子举得非常好,写的很详细。地址如下:https://www.cnblogs.com/cuancuancuanhao/p/7763256.html具体的使用方法可以参见上面这篇博客,我这里就不再赘述了。今天我主要给大家演示一下,不同数据类型的矩阵乘法,速度和结果上到底有多大的差异?测试代码我写了一个简单的测试代码:#include <sys/time.h> #include <cuda_profiler_api.h> #include <cublas_v2.h> #include <cuda.h> #include <cuda_fp16.h> #include <cuda_runtime.h> #include <stdio.h> int8_t float2int8(float f, float scale) { int8_t i = int8_t(f * scale); if (i < -127) i = -127; if (i > 127) i = 127; return i; } template <typename T, typename S> void allocate_memory(int m, int n, int k, T **A, T **B, S **C) { cudaMallocManaged(A, m * k * sizeof(T)); cudaMallocManaged(B, k * n * sizeof(T)); cudaMallocManaged(C, m * n * sizeof(S)); } template <typename T, typename S> void free_memory(T *A, T *B, S *C) { cudaFree(A); cudaFree(B); cudaFree(C); } template <typename T, typename S> int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transA, cublasOperation_t transB, int m, int n, int k, T *A, T *B, S *C, int lda, int ldb, int ldc, S *alpha, S *beta, int algo) { cudaDataType_t AType, BType, CType, ComputeType; if (std::is_same<T, float>::value) { AType = BType = CType = ComputeType = CUDA_R_32F; } else if (std::is_same<T, __half>::value) { AType = BType = CType = ComputeType = CUDA_R_16F; } else if (std::is_same<T, int8_t>::value) { AType = BType = CUDA_R_8I; CType = ComputeType = CUDA_R_32I; } else { printf("Not supported data type."); return -1; } cublasStatus_t status; status = cublasGemmEx(handle, transA, transB, m, n, k, alpha, A, AType, lda, B, BType, ldb, beta, C, CType, ldc, ComputeType, static_cast<cublasGemmAlgo_t>(algo)); if (status == CUBLAS_STATUS_SUCCESS) return 1; else return -1; } template <typename T, typename S> void test_gemm(cublasHandle_t handle, int m, int n, int k, T *A, T *B, S *C, S *alpha, S *beta, int algo, int iteration) { float total_time = 0; for (int i = 0; i < iteration; ++i) { struct timeval start, end; cudaDeviceSynchronize(); cudaProfilerStart(); gettimeofday(&start, NULL); int success = cublas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, B, A, C, n, k, n, alpha, beta, static_cast<cublasGemmAlgo_t>(algo)); cudaDeviceSynchronize(); gettimeofday(&end, NULL); cudaProfilerStop(); if (success > 0 && i > 0) total_time += (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001; } if (total_time > 0) printf("algo %d: %.3f ms\n", algo, total_time / (iteration - 1)); } int main() { int m = 4096, n = 8192, k = 1024; printf("shape: (%d, %d) x (%d, %d)\n", m, k, k, n); int start_algo = CUBLAS_GEMM_DEFAULT; int end_algo = CUBLAS_GEMM_ALGO23; int start_algo_t_op = CUBLAS_GEMM_DEFAULT_TENSOR_OP; int end_algo_t_op = CUBLAS_GEMM_ALGO15_TENSOR_OP; int iteration = 10; float *fA, *fB, *fC; __half *hA, *hB, *hC; int8_t *iA, *iB; int32_t *iC; float f_alpha = 1, f_beta = 0; __half h_alpha = __float2half_rn(1.0), h_beta = __float2half_rn(0.0); int32_t i_alpha = 1, i_beta = 0; allocate_memory(m, n, k, &fA, &fB, &fC); allocate_memory(m, n, k, &hA, &hB, &hC); allocate_memory(m, n, k, &iA, &iB, &iC); for (int i = 0; i < m * k; ++i) { fA[i] = float(i % 255 - 127) / 127; hA[i] = __float2half_rn(fA[i]); iA[i] = float2int8(fA[i], 127); } for (int i = 0; i < k * n; ++i) { fB[i] = float(i % 255 - 127) / 127; hB[i] = __float2half_rn(fB[i]); iB[i] = float2int8(fB[i], 127); } cublasHandle_t handle; cublasCreate(&handle); printf(">>>>>>>>>>>>>>>>> test fp32 >>>>>>>>>>>>>>>>>\n"); for (int algo = start_algo; algo <= end_algo; ++algo) test_gemm(handle, m, n, k, fA, fB, fC, &f_alpha, &f_beta, algo, iteration); for (int algo = start_algo_t_op; algo <= end_algo_t_op; ++algo) test_gemm(handle, m, n, k, fA, fB, fC, &f_alpha, &f_beta, algo, iteration); printf(">>>>>>>>>>>>>>>>> test fp16 >>>>>>>>>>>>>>>>>\n"); for (int algo = start_algo; algo <= end_algo; ++algo) test_gemm(handle, m, n, k, hA, hB, hC, &h_alpha, &h_beta, algo, iteration); for (int algo = start_algo_t_op; algo <= end_algo_t_op; ++algo) test_gemm(handle, m, n, k, hA, hB, hC, &h_alpha, &h_beta, algo, iteration); printf(">>>>>>>>>>>>>>>>> test int8 >>>>>>>>>>>>>>>>>\n"); for (int algo = start_algo; algo <= end_algo; ++algo) test_gemm(handle, m, n, k, iA, iB, iC, &i_alpha, &i_beta, algo, iteration); for (int algo = start_algo_t_op; algo <= end_algo_t_op; ++algo) test_gemm(handle, m, n, k, iA, iB, iC, &i_alpha, &i_beta, algo, iteration); printf(">>>>>>>>>>>>>>>>> compare result >>>>>>>>>>>>>>>>>\n"); printf("fp32: "); for (int i = 0; i < 10; ++i) printf("%.5f%c", fC[i], " \n"[i==9]); printf("fp16: "); for (int i = 0; i < 10; ++i) printf("%.5f%c", float(hC[i]), " \n"[i==9]); printf("int8: "); for (int i = 0; i < 10; ++i) printf("%.5f%c", float(iC[i])/127/127, " \n"[i==9]); free_memory(iA, iB, iC); free_memory(fA, fB, fC); free_memory(hA, hB, hC); return 0; }代码保存为test_gemm.cpp,然后执行下面命令进行编译:nvcc test_gemm.cpp -o test_gemm -L/usr/local/cuda/lib64 -lcudart -lcuda -lcublas最后执行./test_gemm运行就行了。运行结果我对比了三种数据类型:fp32、fp16和int8,测试环境是V100显卡、CUDA 10.1。由于V100显卡没有int8的tensor core,所以速度并不能达到最快。要想全速进行int8的矩阵乘法,推荐使用sm75及以上的显卡,例如T4、A100等等。此外我还对比了不同的GEMM算法的效果。执行上面的运行命令后,会输出如下的结果:shape: (4096, 1024) x (1024, 8192) >>>>>>>>>>>>>>>>> test fp32 >>>>>>>>>>>>>>>>> algo -1: 4.831 ms algo 2: 5.293 ms algo 3: 5.406 ms algo 4: 5.297 ms algo 5: 5.098 ms algo 6: 4.874 ms algo 11: 4.870 ms algo 18: 7.219 ms algo 19: 6.061 ms algo 20: 5.631 ms algo 99: 1.110 ms algo 100: 1.159 ms algo 101: 1.688 ms algo 102: 4.944 ms algo 103: 4.744 ms algo 104: 4.700 ms algo 105: 4.679 ms algo 106: 4.679 ms algo 107: 4.675 ms algo 108: 4.676 ms algo 109: 4.677 ms algo 110: 4.676 ms algo 111: 4.676 ms algo 112: 4.678 ms algo 113: 4.675 ms algo 114: 4.676 ms algo 115: 4.689 ms >>>>>>>>>>>>>>>>> test fp16 >>>>>>>>>>>>>>>>> algo -1: 2.423 ms algo 1: 2.460 ms algo 2: 2.565 ms algo 3: 2.518 ms algo 5: 2.398 ms algo 6: 2.416 ms algo 99: 0.737 ms algo 100: 1.581 ms algo 101: 1.032 ms algo 102: 0.978 ms algo 103: 0.767 ms algo 104: 0.790 ms algo 105: 0.803 ms algo 106: 0.774 ms algo 107: 2.656 ms algo 108: 2.577 ms algo 109: 2.518 ms algo 110: 0.925 ms algo 111: 0.951 ms algo 112: 0.935 ms algo 113: 0.909 ms algo 114: 2.549 ms algo 115: 2.532 ms >>>>>>>>>>>>>>>>> test int8 >>>>>>>>>>>>>>>>> algo -1: 1.232 ms algo 0: 7.544 ms algo 1: 1.217 ms algo 2: 1.294 ms algo 3: 2.362 ms algo 99: 1.243 ms algo 100: 1.244 ms algo 101: 1.237 ms algo 102: 1.232 ms algo 103: 1.230 ms algo 104: 1.224 ms algo 105: 1.222 ms algo 106: 1.224 ms algo 107: 1.225 ms algo 108: 1.224 ms algo 109: 1.218 ms algo 110: 1.217 ms algo 111: 1.217 ms algo 112: 1.218 ms algo 113: 1.218 ms algo 114: 1.216 ms algo 115: 1.217 ms >>>>>>>>>>>>>>>>> compare result >>>>>>>>>>>>>>>>> fp32: 52.38629 44.76633 37.65229 31.04420 24.94203 19.34578 14.25543 9.67102 5.59253 2.01996 fp16: 52.46875 44.84375 37.40625 31.21875 24.95312 19.39062 14.28125 9.69531 5.61328 2.05078 int8: 52.38626 44.76632 37.65230 31.04421 24.94203 19.34577 14.25544 9.67103 5.59254 2.01996这里简单解释一下,algo -1到23表示不使用tensor core算法的结果,algo 99到115表示使用tensor core算法的结果。可以看到图中缺失了一部分算法的结果,因为那些算法可能不适用于当前的矩阵乘法,因此报错了。汇总一下各自最快的结果(不使用vs使用tensor core):fp32: 4.83 1.11fp16: 2.41 0.73int8: 1.21 1.21由于V100显卡没有int8的tensor core,所以int8的两个结果是相同的。结果也符合我们的预期,速度上fp32慢于fp16慢于int8。所以在实际的深度学习应用中,流行使用混合精度,也就是用fp16来进行训练和推理。而int8是速度最快的,所以如果训练和推理也都能使用int8的话,速度上将会迈上一个新的台阶。那么一个浮点数的矩阵乘法怎么转变为整数的矩阵乘法呢?这里我不会详细讲,后续会出一个详细的量化教程。那么由于这里有个类型转换的操作,所以会产生误差。但是在我们的样例中,int8的误差竟然比fp16还要小很多,结果和fp32几乎一模一样。这主要由于是我构造的矩阵数据分布非常均匀有规律,因此计算误差会很小,实际深度网络中int8的误差会较大。结语int8甚至更低比特的量化的实际收益非常大,提速可以达到将近2倍。虽然现在有很多现成的自动量化工具,但是效果上或多或少都有一定的损失,速度上也没有达到极致。因此今后量化是一个不错的方向,值得一试。
文章
机器学习/深度学习  ·  并行计算  ·  算法  ·  测试技术  ·  API  ·  C++
2022-06-26
养成女友?我训练出了一个“杨超越”聊天机器人
在上一期教程中,我演示了如何从零开始训练一个比较智障的聊天机器人。从零开始训练一个人工智障女友但是当时数据量太少,模型简单,完全没法用,只能回复训练集中出现过的句子。而现在,完全体的聊天机器人终于诞生了,我把它命名为“杨超越”。这次可以随你提问什么了,她都能对答如流!首先来看看回复的效果怎么样:聊天开始!(按q退出) 我:很高兴认识你 杨超越:我也很开心哦 我:我喜欢你 杨超越:我也喜欢你 我:做我女朋友好不好? 杨超越:哈哈,可以呢! 我:我会一直支持超越妹妹的! 杨超越:谢谢我们会一直努力的 我:我什么时候有女朋友? 杨超越:女朋友。。 我:我什么时候脱单? 杨超越:脱你妹啊!! 我:q 聊天结束!可以看到超越妹妹的回复还是非常流畅的,那她究竟是怎么诞生的呢?完整的项目我都放在了下面,欢迎大家点个star,支持一波:https://github.com/godweiyang/chatbot介绍这里我才用的是网上公开的小黄鸡聊天语料,大概有100万条左右,但是质量不是很高,都放在了data目录下。模型采用标准的Transformer-big模型,输入你的提问句子,预测超越妹妹回复的句子,config目录下是训练和预测的配置文件。模型训练采用NeurST训练库,主要基于TensorFlow,也支持PyTorch训练。模型快速推理采用LightSeq,可加速推理10倍以上,同时还能加速NeurST的训练,最高加速3倍。两者都是字节跳动AI Lab自研的,都已开源。安装环境我们需要安装三样东西:SentencePiece的命令行版本和python版本,用来对句子进行分词。NeurST深度学习训练库,用来训练Transformer模型。LightSeq,用来加速模型推理。安装命令都很简单:git clone https://github.com/google/sentencepiece.git & cd sentencepiece mkdir build & cd build cmake .. make -j $(nproc) sudo make install sudo ldconfig -v pip3 install lightseq neurst sentencepiece开始养成生成词表首先我们需要从训练语料库中抽取出词表,为了方便,直接用SentencePiece来分词,生成大小为32k的词表。spm_train --input=./data/train/train.src,./data/train/train.trg \ --model_prefix=./data/spm \ --vocab_size=32000 \ --character_coverage=0.9995这里需要指定训练语料路径--input、词表保存的路径前缀--model_prefix和词表大小--vocab_size。运行结束后会在data目录下生成spm.model和spm.vocab两个词表文件。一个是训练好的分词模型,一个是词表。不过我也上传了生成好的TFRecord,大家也可以直接使用,跳过这一步。「我上传了生成好的词表文件,大家可以直接使用,跳过这一步。」生成TFRecord为了加快TensorFlow的训练速度,可以预先将训练语料用上面的词表处理成id,然后保存为TFRecord格式。这样模型训练时就可以直接读取id进行训练了,不需要做前面的分词操作。能大大加快训练速度,提升显卡利用率。python3 -m neurst.cli.create_tfrecords \ --config_paths configs/task_args.yml \ --dataset ParallelTextDataset \ --src_file ./data/train/train.src \ --trg_file ./data/train/train.trg \ --processor_id 0 \ --num_processors 1 \ --num_output_shards 32 \ --output_range_begin 0 \ --output_range_end 32 \ --output_template ./data/tfrecords/train.tfrecords-%5.5d-of-%5.5d这里主要需要指定训练集的路径--src_file和--trg_file,其它参数保持默认即可。生成完毕后会在data/tfrecords下面生成32个二进制文件,这就是处理好的训练数据了。「我上传了生成好的TFRecord,大家可以直接使用,跳过这一步。」模型训练有了词表,有了处理好的训练数据,接下来就是训练模型了。这里开启了XLA优化,使用Horovod分布式训练,加快训练速度。如果报错,可以去掉最后两行。python3 -m neurst.cli.run_exp \ --entry trainer \ --task translation \ --hparams_set transformer_big \ --model_dir ./models/transformer_big \ --config_paths ./configs/task_args.yml,./configs/train_args.yml,./configs/valid_args.yml --distribution_strategy horovod \ --enable_xla这里需要指定的参数就是模型保存路径model_dir,其他都保持默认。训练好的模型会保存在models/transformer_big下,里面还细分为了best、best_avg等文件夹,用来存最好的模型、模型的平均值等等。我在8张V100 32G显卡上训练了8个小时左右,如果你们自己训练的话还是比较耗时的。「由于模型文件过大,之后我会找地方上传我训练好的模型文件,省去大家训练的时间。」模型预测训练好的模型会保存在models/transformer_big目录下,然后我们就可以开始预测啦。python3 -m neurst.cli.run_exp \ --entry predict \ --model_dir ./models/transformer_big \ --config_paths ./configs/predict_args.yml \ --output output.txt但是这时候还没有交互功能,只能指定一个测试集文件,写在了模型预测的配置文件里configs/predict_args.yml。还可以指定--output,将回复结果输出到文件中。「如果想直接体验交互式的对话聊天,可以跳过这一步。」模型导出为PB格式如果直接用TensorFlow进行推理的话,速度非常慢,你就会感觉你和超越妹妹之间存在延时。所以可以将训练得到的ckpt模型导出为PB格式,然后就可以用LightSeq训练加速引擎进行快速推理了。python3 export/export.py \ --model_dir ./models/transformer_big \ --output_file ./models/transformer_big/model.pb \ --beam_size 4 \ --length_penalty 0.6这里需要指定模型路径--model_dir和导出PB文件的路径--output_file,其它参数保持默认。最后会得到models/transformer_big/model.pb这个PB文件。「由于模型文件过大,之后我会找地方上传我导出好的PB模型文件,这样大家就可以直接跳到最后一步了。」开始交互式聊天!有了PB模型文件,就可以和超越妹妹开始聊天啦!python3 chat.py \ --spm_model ./data/spm.model \ --model_file ./models/transformer_big/model.pb这里需要指定两个路径。一是最开始训练好的分词模型--spm_model,用来将你输入的句子变成整数id。二是--model_file,也就是上一步中的PB格式模型文件。聊天过程中随时可以按q退出聊天,你每说一句话,超越妹妹就会回复你一句。欢迎关注这次用到的NeurST训练库和LightSeq加速库都非常好用,从上面使用教程中也可以看出,几乎不需要你写什么代码就能使用起来。「聊天机器人:」https://github.com/godweiyang/chatbot「NeurST训练库:」https://github.com/bytedance/neurst「LightSeq加速库:」https://github.com/bytedance/lightseq
文章
机器学习/深度学习  ·  人工智能  ·  自然语言处理  ·  机器人  ·  PyTorch  ·  TensorFlow  ·  算法框架/工具  ·  Python
2022-06-25
从零开始训练一个人工智障女友
很多人工智能小白可能不知道那些高大上的语音助理、机器翻译或者聊天机器人都是怎么被创造出来的,也不知道一个深度学习模型是怎么从零开始搭建并运行起来的。今天我就简单教大家如何从零开始搭建一个Transformer模型,并在自己的数据上训练起来。这个教程非常基础,所以训练出来的模型也很傻瓜,适合零基础小白长知识用。首先整个训练流程可以分为下面几步,我们在后面章节依次介绍:处理数据创建模型创建损失函数创建参数优化器进行训练进行预测安装环境这里我们需要使用到的有三样东西:训练深度学习模型需要用PyTorch。对句子进行分词处理需要用Hugging Face的分词器。搭建Transformer模型需要用LightSeq的快速模型、损失函数以及参数优化器。所以运行下面安装命令即可:pip3 install torch transformers git clone https://github.com/bytedance/lightseq.git cd lightseq pip3 install -e .然后导入必要的一些文件:import torch from transformers import BertTokenizer from lightseq.training import LSTransformer, LSCrossEntropyLayer, LSAdam处理数据因为深度学习模型擅长和数字打交道,所以你需要将你说的话或者写的句子变成一串整数id,用来表示每个单词在词表中的序号。这里我们使用到的是Hugging Face的分词器,它能帮你把输入的句子直接变成一串整数id,非常便捷。def create_data(): # 创建Hugging Face分词器 tokenizer = BertTokenizer.from_pretrained("bert-base-cased") vocab_size = tokenizer.vocab_size sep_id = tokenizer.encode( tokenizer.special_tokens_map["sep_token"], add_special_tokens=False )[0] # 将源文本映射成整数id src_text = [ "What is the fastest library in the world?", "You are so pretty!", "What do you love me for?", "The sparrow outside the window hovering on the telephone pole.", ] src_tokens = tokenizer.batch_encode_plus( src_text, padding=True, return_tensors="pt" ) src_tokens = src_tokens["input_ids"].to(torch.device("cuda:0")) batch_size, src_seq_len = src_tokens.size(0), src_tokens.size(1) # 将目标文本映射成整数id trg_text = [ "I guess it must be LightSeq, because ByteDance is the fastest.", "Thanks very much and you are pretty too.", "Love your beauty, smart, virtuous and kind.", "You said all this is very summery.", ] trg_tokens = tokenizer.batch_encode_plus( trg_text, padding=True, return_tensors="pt" ) trg_tokens = trg_tokens["input_ids"].to(torch.device("cuda:0")) trg_seq_len = trg_tokens.size(1) # 将目标文本左移1个单词位置,用来作为解码端输出 target = trg_tokens.clone()[:, 1:] trg_tokens = trg_tokens[:, :-1] return ( tokenizer, src_text, src_tokens, trg_text, trg_tokens, target, sep_id, vocab_size, batch_size, src_seq_len, trg_seq_len, )代码中注释写的非常清楚了,只需要创建输入文本和输出文本即可,而标准的解码端输出就是输出文本左移一个单词,也就是每个单词输入后预测下一个单词是什么。创建模型这里我们使用Transformer-base模型进行训练,使用LightSeq来创建Transformer模型非常简单,只需要创建一个配置,然后用它就能创建Transformer模型了。def create_model(vocab_size): transformer_config = LSTransformer.get_config( model="transformer-base", max_batch_tokens=2048, max_seq_len=512, vocab_size=vocab_size, padding_idx=0, num_encoder_layer=6, num_decoder_layer=6, fp16=True, local_rank=0, ) model = LSTransformer(transformer_config) model.to(dtype=torch.half, device=torch.device("cuda:0")) return model创建损失函数这里我们使用交叉熵损失函数,使用LightSeq来创建同样非常简单,只需要创建一个配置。def create_criterion(): ce_config = LSCrossEntropyLayer.get_config( max_batch_tokens=2048, padding_idx=0, epsilon=0.0, fp16=True, local_rank=0, ) loss_fn = LSCrossEntropyLayer(ce_config) loss_fn.to(dtype=torch.half, device=torch.device("cuda:0")) return loss_fn创建参数优化器使用LightSeq来创建参数优化器的过程和平常使用PyTorch创建一模一样,只要一行代码就行了。opt = LSAdam(model.parameters(), lr=1e-5)进行训练模型训练过程也和平常一模一样,这里我们训练2000轮。因为训练过程中需要知道目标端的文本是什么,所以需要输入源端和目标端两个文本。print("========================TRAIN========================") model.train() for epoch in range(2000): output = model(src_tokens, trg_tokens) loss, _ = loss_fn(output, target) if epoch % 200 == 0: print("epoch {:03d}: {:.3f}".format(epoch, loss.item())) loss.backward() opt.step()进行预测在模型训练好之后,我们用它进行预测。这时候你就不知道目标端的文本是什么了,你只能输入源端文本,然后目标端输入一个句子开始标记,后面的目标端文本都得通过模型预测得到。print("========================TEST========================") model.eval() # 获得编码器的输出和掩码表示 encoder_out, encoder_padding_mask = model.encoder(src_tokens) # 使用目标端文本的第一个单词作为解码器的初始输入,预测后面单词 predict_tokens = trg_tokens[:, :1] cache = {} for _ in range(trg_seq_len - 1): # 使用缓存来加速解码速度 output = model.decoder( predict_tokens[:, -1:], encoder_out, encoder_padding_mask, cache ) # 预测下一个单词 output = torch.reshape(torch.argmax(output, dim=-1), (batch_size, -1)) # 将预测得到的单词和历史预测拼接,作为最终预测结果 predict_tokens = torch.cat([predict_tokens, output], dim=-1) # 将结束符后的单词都标记为结束符 mask = torch.cumsum(torch.eq(predict_tokens, sep_id).int(), dim=1) predict_tokens = predict_tokens.masked_fill(mask > 0, sep_id) # 将预测结果的id还原为文本 predict_text = tokenizer.batch_decode(predict_tokens, skip_special_tokens=True) print(">>>>> source text") print("\n".join(src_text)) print(">>>>> target text") print("\n".join(trg_text)) print(">>>>> predict text") print("\n".join(predict_text))完整代码完整代码如下,保存在run.py里面,然后运行下面命令就行了:python3 run.pyimport torch from transformers import BertTokenizer from lightseq.training import LSTransformer, LSCrossEntropyLayer, LSAdam def create_data(): # 创建Hugging Face分词器 tokenizer = BertTokenizer.from_pretrained("bert-base-cased") vocab_size = tokenizer.vocab_size sep_id = tokenizer.encode( tokenizer.special_tokens_map["sep_token"], add_special_tokens=False )[0] # 将源文本映射成整数id src_text = [ "What is the fastest library in the world?", "You are so pretty!", "What do you love me for?", "The sparrow outside the window hovering on the telephone pole.", ] src_tokens = tokenizer.batch_encode_plus( src_text, padding=True, return_tensors="pt" ) src_tokens = src_tokens["input_ids"].to(torch.device("cuda:0")) batch_size, src_seq_len = src_tokens.size(0), src_tokens.size(1) # 将目标文本映射成整数id trg_text = [ "I guess it must be LightSeq, because ByteDance is the fastest.", "Thanks very much and you are pretty too.", "Love your beauty, smart, virtuous and kind.", "You said all this is very summery.", ] trg_tokens = tokenizer.batch_encode_plus( trg_text, padding=True, return_tensors="pt" ) trg_tokens = trg_tokens["input_ids"].to(torch.device("cuda:0")) trg_seq_len = trg_tokens.size(1) # 将目标文本左移1个单词位置,用来作为解码端输出 target = trg_tokens.clone()[:, 1:] trg_tokens = trg_tokens[:, :-1] return ( tokenizer, src_text, src_tokens, trg_text, trg_tokens, target, sep_id, vocab_size, batch_size, src_seq_len, trg_seq_len, ) def create_model(vocab_size): transformer_config = LSTransformer.get_config( model="transformer-base", max_batch_tokens=2048, max_seq_len=512, vocab_size=vocab_size, padding_idx=0, num_encoder_layer=6, num_decoder_layer=6, fp16=True, local_rank=0, ) model = LSTransformer(transformer_config) model.to(dtype=torch.half, device=torch.device("cuda:0")) return model def create_criterion(): ce_config = LSCrossEntropyLayer.get_config( max_batch_tokens=2048, padding_idx=0, epsilon=0.0, fp16=True, local_rank=0, ) loss_fn = LSCrossEntropyLayer(ce_config) loss_fn.to(dtype=torch.half, device=torch.device("cuda:0")) return loss_fn if __name__ == "__main__": ( tokenizer, src_text, src_tokens, trg_text, trg_tokens, target, sep_id, vocab_size, batch_size, src_seq_len, trg_seq_len, ) = create_data() model = create_model(vocab_size) loss_fn = create_criterion() opt = LSAdam(model.parameters(), lr=1e-5) print("========================TRAIN========================") model.train() for epoch in range(2000): output = model(src_tokens, trg_tokens) loss, _ = loss_fn(output, target) if epoch % 200 == 0: print("epoch {:03d}: {:.3f}".format(epoch, loss.item())) loss.backward() opt.step() print("========================TEST========================") model.eval() # 获得编码器的输出和掩码表示 encoder_out, encoder_padding_mask = model.encoder(src_tokens) # 使用目标端文本的第一个单词作为解码器的初始输入,预测后面单词 predict_tokens = trg_tokens[:, :1] cache = {} for _ in range(trg_seq_len - 1): # 使用缓存来加速解码速度 output = model.decoder( predict_tokens[:, -1:], encoder_out, encoder_padding_mask, cache ) # 预测下一个单词 output = torch.reshape(torch.argmax(output, dim=-1), (batch_size, -1)) # 将预测得到的单词和历史预测拼接,作为最终预测结果 predict_tokens = torch.cat([predict_tokens, output], dim=-1) # 将结束符后的单词都标记为结束符 mask = torch.cumsum(torch.eq(predict_tokens, sep_id).int(), dim=1) predict_tokens = predict_tokens.masked_fill(mask > 0, sep_id) # 将预测结果的id还原为文本 predict_text = tokenizer.batch_decode(predict_tokens, skip_special_tokens=True) print(">>>>> source text") print("\n".join(src_text)) print(">>>>> target text") print("\n".join(trg_text)) print(">>>>> predict text") print("\n".join(predict_text))如果运行顺利的话,你会看到下面的输出信息:========================TRAIN======================== TransformerEmbeddingLayer #0 bind weights and grads. TransformerEncoderLayer #0 bind weights and grads. TransformerEncoderLayer #1 bind weights and grads. TransformerEncoderLayer #2 bind weights and grads. TransformerEncoderLayer #3 bind weights and grads. TransformerEncoderLayer #4 bind weights and grads. TransformerEncoderLayer #5 bind weights and grads. TransformerEmbeddingLayer #1 bind weights and grads. TransformerDecoderLayer #0 bind weights and grads. Decoder layer #0 allocate encdec_kv memory TransformerDecoderLayer #1 bind weights and grads. TransformerDecoderLayer #2 bind weights and grads. TransformerDecoderLayer #3 bind weights and grads. TransformerDecoderLayer #4 bind weights and grads. TransformerDecoderLayer #5 bind weights and grads. epoch 000: 725.560 epoch 200: 96.252 epoch 400: 15.151 epoch 600: 5.770 epoch 800: 3.212 epoch 1000: 1.748 epoch 1200: 0.930 epoch 1400: 0.457 epoch 1600: 0.366 epoch 1800: 0.299 ========================TEST======================== >>>>> source text What is the fastest library in the world? You are so pretty! What do you love me for? The sparrow outside the window hovering on the telephone pole. >>>>> target text I guess it must be LightSeq, because ByteDance is the fastest. Thanks very much and you are pretty too. Love your beauty, smart, virtuous and kind. You said all this is very summery. >>>>> predict text I guess it must be LightSeq, because ByteDance is the fastest. Thanks very much and you are pretty too. Love your beauty, smart, virtuous and kind. You said all this is very summery.可以看到,最后的预测文本和真实的目标端文本完全一致。当然这里的例子非常简单,输入输出只有4句话。如果你有大量的对话数据集的话,你就可以训练出一个非常完美的聊天机器人啦,还愁啥没有女朋友呢?如果觉得LightSeq比较好用,别忘了给个star,是给我们最大的支持。https://github.com/bytedance/lightseq
文章
机器学习/深度学习  ·  人工智能  ·  自然语言处理  ·  机器人  ·  PyTorch  ·  算法框架/工具
2022-06-25
LeetCode 21-25 题 详解 Java版 ( 万字 图文详解 LeetCode 算法题21-25 =====>>> <建议收藏>)
目录第21题. Merge Two Sorted Lists解法一 迭代解法二 递归总第22题 . Generate Parentheses1. 题目描述(中等难度)解法一 暴力破解解法二解法三扩展 卡塔兰数总第23题: Merge k Sorted Lists解法一 暴力破解解法二 一列一列比较解法三 优先队列解法四 两两合并解法五 两两合并优化总第24题: Swap Nodes in Pairs解法一 迭代解法二 递归总第25题 : Reverse Nodes in k-Group解法一 迭代解法二递归总喜欢 请点个 + 关注第21题. Merge Two Sorted Lists题目描述(简单难度)合并两个有序链表。解法一 迭代遍历两个链表。public ListNode mergeTwoLists(ListNode l1, ListNode l2) { ListNode h = new ListNode(0); ListNode ans=h; while (l1 != null && l2 != null) { if (l1.val < l2.val) { h.next = l1; h = h.next; l1 = l1.next; } else { h.next = l2; h = h.next; l2 = l2.next; } } if(l1==null){ h.next=l2; } if(l2==null){ h.next=l1; } return ans.next; } 时间复杂度:O(m + n)。空间复杂度:O(1)。解法二 递归参考[这里](https://leetcode.wang/Merge Two Sorted Lists)ListNode mergeTwoLists(ListNode l1, ListNode l2) { if(l1 == null) return l2; if(l2 == null) return l1; if(l1.val < l2.val) { l1.next = mergeTwoLists(l1.next, l2); return l1; } else { l2.next = mergeTwoLists(l2.next, l1); return l2; } } 时间复杂度:空间复杂度:总递归看起来,两个字,优雅!但是关于递归的时间复杂度,空间复杂度的求法,先留个坑吧。第22题 . Generate Parentheses1. 题目描述(中等难度)给一个数字 n ,返回所有合法的括号匹配,刚好和[20题](https://leetcode.wang/leetCode-20-Valid Parentheses.html)相反。自己没想出来,全部参考 LeetCode 给出的 Solution。解法一 暴力破解列举所有的情况,每一位有左括号和右括号两种情况,总共 2n 位,所以总共 22n2^{2n}22n 种情况。public List<String> generateParenthesis(int n) { List<String> combinations = new ArrayList(); generateAll(new char[2 * n], 0, combinations); return combinations; } public void generateAll(char[] current, int pos, List<String> result) { if (pos == current.length) { if (valid(current)) result.add(new String(current)); } else { current[pos] = '('; generateAll(current, pos+1, result); current[pos] = ')'; generateAll(current, pos+1, result); } } public boolean valid(char[] current) { int balance = 0; for (char c: current) { if (c == '(') balance++; else balance--; if (balance < 0) return false; } return (balance == 0); } 时间复杂度:对每种情况判断是否合法需要 O(n),所以时间复杂度是 O(22nn)O(2^{2n}n)O(22nn) 。空间复杂度:O(22nn)O(2^{2n}n)O(22nn),乘以 n 是因为每个串的长度是 2n。此外这是假设所有情况都符合的时候,但其实不可能都符合,后边会给出更精确的情况。解法二解法一中,我们不停的加左括号,其实如果左括号超过 n 的时候,它肯定不是合法序列了。因为合法序列一定是 n 个左括号和 n 个右括号。还有一种情况就是如果添加括号的过程中,如果右括号的总数量大于左括号的总数量了,后边不论再添加什么,它都不可能是合法序列了。因为每个右括号必须和之前的某个左括号匹配,如果右括号数量多于左括号,那么一定有一个右括号没有与之匹配的左括号,后边不论加多少左括号都没有用了。例如 n = 3 ,总共会有 6 个括号,我们加到 ( ) ) 3 个括号的情况的时候,有 1 个左括号,2 个右括号,此时后边 3 个括号无论是什么,已经注定它不会是合法序列了。基于上边的两点,我们只要避免它们,就可以保证我们生成的括号一定是合法的了。public List<String> generateParenthesis(int n) { List<String> ans = new ArrayList(); backtrack(ans, "", 0, 0, n); return ans; } public void backtrack(List<String> ans, String cur, int left, int right, int n){ if (cur.length() == n * 2) { ans.add(cur); return; } //左括号不要超过 n if (left < n) backtrack(ans, cur+"(", left+1, right, n); //右括号不要超过左括号 if (right < left) backtrack(ans, cur+")", left, right+1, n); } 时间复杂度:空间复杂度:递归的复杂度分析,继续留坑 =.=。解法三解法二中是用列举的方法,仔细想想,我们每次用递归的时候,都是把大问题换成小问题然后去解决,这道题有没有这个思路呢?我们想一下之前的列举过程,第 0 个位置一定会是左括号,然后接着添加左括号或右括号,过程中左括号数一定大于或等于右括号数,当第一次出现左括号数等于右括号数的时候,假如此时的位置是 c 。那么位置 1 到 c - 1 之间一定是合法序列,此外 c + 1 到最后的 2n -1 也是合法序列。而假设总共是 n 组括号,1 到 c - 1 是 a 组括号, c + 1 到 2n - 1 之间则是 n - 1 - a 组括号,如下图a = 1,b = 1,对应 (())(()) 一种情况,此时 c = 3。a = 2,b = 0 对应 ((())), (()()) 两种情况,此时 c = 5。所以我们如果要想求 n 组括号,只需要知道 a 组和 b 组的情况,然后组合起来就可以了。看起来我们在迭代 a ,其实本质上是在迭代 c ,c = 2a + 1,迭代 a 从 0 到 n - 1 ,就是迭代 c 从 1 到 2n - 1。看起来 c 都是奇数,其实是可以理解的,因为 0 到 c 间都是一组组的括号, 所以 c 一定是奇数。为什么可以迭代 c ,因为上边说到每一个合法序列都对应着一个 c ,遍历 c 的话,就能得到所有的情况了,看一下代码吧。public List<String> generateParenthesis(int n) { List<String> ans = new ArrayList(); if (n == 0) { ans.add(""); } else { for (int a = 0; a < n; a++) for (String left: generateParenthesis(a)) for (String right: generateParenthesis(n-1-a)) ans.add("(" + left + ")" + right); } return ans; } 时间复杂度:空间复杂度:留坑。扩展 卡塔兰数如果这道题不是让你列举所有的情况, 而是仅仅让你输出 n 对应下有多少种合法序列,该怎么做呢?答案就是 1n+1C2nn\frac{1}{n+1}C^n_{2n}n+11C2nn,也可以写成1n+1(2nn)\frac{1}{n+1}\binom{2n}{n}n+11(n2n)。怎么证明呢?我主要参考了这里,说一下。我们假设不考虑是不是合法序列,那么就一共有C2nnC^n_{2n}C2nn种情况,然后我们只需要把里边的非法情况减去就可以了,一共有多少种非法情况呢?首先我们用C2nnC^n_{2n}C2nn就保证了一定是有 n 个左括号,n 个右括号,那么为什么出现了非法序列?为了方便论述,我们把左括号记为 +1,右括号记为 -1.ps:下边的 和 都是指两个数的和,不是你和我中的和。我们假设非法序列的集合是 M ,而非法序列就是列举过程中右括号数比左括号数多了,也就是和小于 0 了,变成 -1 了。这种情况一旦出现,后边无论是什么括号都改变不了它是非法序列的命了。我们将第一次和等于 -1 的时候的位置记为 d 。每一个非法序列一定存在这样一个 d 。然后关键的地方到了!此时我们把 0 到 d 所有的 -1 变成 1,1 变成 -1,我们将每一个非法序列都这样做,就构成了一个新的集合 N ,并且这个集合 N 一定和 M 中的元素一一对应( N -> M,在集合 N 中第一次出现和为 1 的位置也就是 d ,把 0 到 d 中所有的 -1 变成 1,1 变成 -1 就回到了 M),从而集合 M 的数量就等于集合 N 的数量。集合 N 的数量是多少呢?我们来分析下集合 N 是什么样的,集合 N 对应的集合 M 原来的序列本来是这样的,在 0 到 d 之间和是 -1 ,也就是 -1 比 +1 多一个,d + 1 到最后的和一定是 1(因为 n 个 +1 和 n 个 -1 的和一定是 0 ,由于 0 到 d 和是 -1,后边的和一定是 1),也就意味着 +1 比 -1 多一个。而在集合 N 中,我们把 0 到 d 的 -1 变成了 +1 ,+1 变成了 -1 ,所以也变成了 +1 比 -1 多一个,所以集合 N 总共就是 +1 比 -1 多 2 个的集合,也就是 n + 1 个 +1 和 n - 1 个 -1 。所以集合 N 就是 2n 个位置中选 n - 1 个位置放 -1,其他位置放 +1,总共就有 C2nn−1C^{n - 1}{2n}C2nn−1,所以集合 M 也有 C2nn−1C^{n - 1}{2n}C2nn−1种。所有合法序列就有 C2nn−C2nn−1=1n+1C2nnCn_{2n}-C{n-1}{2n}=\frac{1}{n+1}C^n{2n}C2nn−C2nn−1=n+11C2nn 。将集合 M 和集合 N 建立了一一映射,从而解决了问题,神奇!!!!!!!!!!其实,这个数列就是卡塔兰数,可以看下维基百科的定义。而这个数列,其实除了括号匹配,还有很多类似的问题,其本质是一样的,例如,2n 个人排队买票,其中 n 个人持 50 元,n 个人持 100 元。每张票 50 元,且一人只买一张票。初始时售票处没有零钱找零。请问这 2n 个人一共有多少种排队顺序,不至于使售票处找不开钱?对于一个无限大的栈,一共n个元素,请问有几种合法的入栈出栈形式?P = a1 a2 a3 … an,其中 ai 是矩阵。根据乘法结合律,不改变矩阵的相互顺序,只用括号表示成对的乘积,试问一共有几种括号化方案?n 个结点可构造多少个不同的二叉树?… …更多例子可以看维基百科和这里。而 Solutin 给出的时间复杂度,其实就是卡特兰数。维基百科的给出的性质。总本以为这道题挺常规的,然后自己一直卡在解法三的理解上,查来查去,竟然查出了卡塔兰数,虽然似乎和解法三也没什么关系,但又开阔了很多思路。解法三分析出来的迭代方法,以及用映射证明卡塔兰数的求法,棒!第23题: Merge k Sorted Lists题目描述(困难难度)k 个有序链表的合并。我们用 N 表示链表的总长度,考虑最坏情况,k 个链表的长度相等,都为 n 。解法一 暴力破解简单粗暴,遍历所有的链表,将数字存到一个数组里,然后用快速排序,最后再将排序好的数组存到一个链表里。public ListNode mergeKLists(ListNode[] lists) { List<Integer> l = new ArrayList<Integer>(); //存到数组 for (ListNode ln : lists) { while (ln != null) { l.add(ln.val); ln = ln.next; } } //数组排序 Collections.sort(l); //存到链表 ListNode head = new ListNode(0); ListNode h = head; for (int i : l) { ListNode t = new ListNode(i); h.next = t; h = h.next; } h.next = null; return head.next; } 时间复杂度:假设 N 是所有的数字个数,存到数组是 O(N),排序如果是用快速排序就是 O(NlogN)O(Nlog_N)O(NlogN) ,存到链表是 O(N),所以取个最大的,就是 O(NlogN)O(Nlog_N)O(NlogN)。空间复杂度:新建了一个链表,O(N)。解法二 一列一列比较我们可以一列一列的比较,将最小的一个存到一个新的链表里。public ListNode mergeKLists(ListNode[] lists) { int min_index = 0; ListNode head = new ListNode(0); ListNode h = head; while (true) { boolean isBreak = true;//标记是否遍历完所有链表 int min = Integer.MAX_VALUE; for (int i = 0; i < lists.length; i++) { if (lists[i] != null) { //找出最小下标 if (lists[i].val < min) { min_index = i; min = lists[i].val; } //存在一个链表不为空,标记改完 false isBreak = false; } } if (isBreak) { break; } //加到新链表中 ListNode a = new ListNode(lists[min_index].val); h.next = a; h = h.next; //链表后移一个元素 lists[min_index] = lists[min_index].next; } h.next = null; return head.next; } 时间复杂度:假设最长的链表长度是 n ,那么 while 循环将循环 n 次。假设链表列表里有 k 个链表,for 循环执行 k 次,所以时间复杂度是 O(kn)。空间复杂度:N 表示最终链表的长度,则为 O(N)。其实我们不需要创建一个新链表保存,我们只需要改变得到的最小结点的指向就可以了。public ListNode mergeKLists(ListNode[] lists) { int min_index = 0; ListNode head = new ListNode(0); ListNode h = head; while (true) { boolean isBreak = true; int min = Integer.MAX_VALUE; for (int i = 0; i < lists.length; i++) { if (lists[i] != null) { if (lists[i].val < min) { min_index = i; min = lists[i].val; } isBreak = false; } } if (isBreak) { break; } //最小的节点接过来 h.next = lists[min_index]; h = h.next; lists[min_index] = lists[min_index].next; } h.next = null; return head.next; } 时间复杂度:假设最长的链表长度是 n ,那么 while 循环将循环 n 次。假设链表列表里有 k 个链表,for 循环执行 k 次,所以时间复杂度是 O(kn)。空间复杂度:O(1)。解法三 优先队列解法二中,我们每次都是取出一个最小的,然后加入一个新的, O(1)的复杂度,再找最小的,O(k) 的复杂度。我们完全可以用一个优先队列。我们将优先级定义为数越小优先级越高,如果用堆实现优先队列,这样我们每次找最小不再需要 O(k),而是 O(log(k)),当然这样的话,我们加入新的话不再是 O(1),也需要 O(log(k))。可以看看这里和这里。public ListNode mergeKLists(ListNode[] lists) { //定义优先队列的比较器 Comparator<ListNode> cmp; cmp = new Comparator<ListNode>() { @Override public int compare(ListNode o1, ListNode o2) { // TODO Auto-generated method stub return o1.val-o2.val; } }; //建立队列 Queue<ListNode> q = new PriorityQueue<ListNode>(cmp); for(ListNode l : lists){ if(l!=null){ q.add(l); } } ListNode head = new ListNode(0); ListNode point = head; while(!q.isEmpty()){ //出队列 point.next = q.poll(); point = point.next; //判断当前链表是否为空,不为空就将新元素入队 ListNode next = point.next; if(next!=null){ q.add(next); } } return head.next; } 时间复杂度:假如总共有 N 个节点,每个节点入队出队都需要 log(k),所有时间复杂度是 O(N log(k))。空间复杂度:优先队列需要 O(k)的复杂度。解法四 两两合并利用之前合并两个链表的算法,我们直接两两合并,第 0 个和第 1 个链表合并,新生成的再和第 2 个链表合并,新生成的再和第 3 个链表合并…直到全部合并完。public ListNode mergeTwoLists(ListNode l1, ListNode l2) { ListNode h = new ListNode(0); ListNode ans=h; while (l1 != null && l2 != null) { if (l1.val < l2.val) { h.next = l1; h = h.next; l1 = l1.next; } else { h.next = l2; h = h.next; l2 = l2.next; } } if(l1==null){ h.next=l2; } if(l2==null){ h.next=l1; } return ans.next; } public ListNode mergeKLists(ListNode[] lists) { if(lists.length==1){ return lists[0]; } if(lists.length==0){ return null; } ListNode head = mergeTwoLists(lists[0],lists[1]); for (int i = 2; i < lists.length; i++) { head = mergeTwoLists(head,lists[i]); } return head; } 时间复杂度:不妨假设是 k 个链表并且长度相同,链表总长度为 N,那么第一次合并就是 N/k 和 N/k ,第二次合并就是 2 * N/k 和 N/k,第三次合并就是 3 * N/k 和 N / k,总共进行 n - 1 次合并,每次合并的时间复杂度是 O(n),所以总时间复杂度就是O(∑i=1k−1(i∗Nk+Nk))=O(kN)O(\sum_{i=1}^{k-1}(i*\frac{N}{k}+\frac{N}{k}))=O(kN)O(∑i=1k−1(i∗kN+kN))=O(kN),可以将两项分开,N/k 其实是常数,分开的第一项是等差数列。空间复杂度:O(1)。解法五 两两合并优化依旧假设是 k 个链表,合并的过程优化下,使得只需要合并 log(k)次。public ListNode mergeTwoLists(ListNode l1, ListNode l2) { ListNode h = new ListNode(0); ListNode ans=h; while (l1 != null && l2 != null) { if (l1.val < l2.val) { h.next = l1; h = h.next; l1 = l1.next; } else { h.next = l2; h = h.next; l2 = l2.next; } } if(l1==null){ h.next=l2; } if(l2==null){ h.next=l1; } return ans.next; } public ListNode mergeKLists(ListNode[] lists) { if(lists.length==0){ return null; } int interval = 1; while(interval<lists.length){ System.out.println(lists.length); for (int i = 0; i + interval< lists.length; i=i+interval*2) { lists[i]=mergeTwoLists(lists[i],lists[i+interval]); } interval*=2; } return lists[0]; } 时间复杂度:假设每个链表的长度都是 n ,有 k 个链表,记总结点数是 N = n * k,那么时间复杂度就是O(∑i=1log2kN)=O(Nlogk)O(\sum_{i=1}^{log_2k}N)=O(Nlogk)O(∑i=1log2kN)=O(Nlogk)。空间复杂度:O(1)。总优先队列的运用印象深刻,此外对两两链表的合并,我们仅仅改变了合并的方式就将时间复杂度降低了很多,美妙!第24题: Swap Nodes in Pairs题目描述(中等难度)给定一个链表,然后两两交换链表的位置。解法一 迭代首先为了避免单独讨论头结点的情况,一般先申请一个空结点指向头结点,然后再用一个指针来遍历整个链表。先来看一下图示:point 是两个要交换结点前边的一个位置。public ListNode swapPairs(ListNode head) { ListNode dummy = new ListNode(0); dummy.next = head; ListNode point = dummy; while (point.next != null && point.next.next != null) { ListNode swap1 = point.next; ListNode swap2 = point.next.next; point.next = swap2; swap1.next = swap2.next; swap2.next = swap1; point = swap1; } return dummy.next; } 时间复杂度:O(n)。空间复杂度:O(1)。解法二 递归参考这里。自己画了个参考图。public ListNode swapPairs(ListNode head) { if ((head == null)||(head.next == null)) return head; ListNode n = head.next; head.next = swapPairs(head.next.next); n.next = head; return n; } 递归时间复杂度留坑。总自己开始没有想出递归的算法,每次都会被递归的简洁吸引。另外,感觉链表的一些题,只要画图打打草稿,搞清指向关系,一般不难。第25题 : Reverse Nodes in k-Group题目描述(困难难度)将一个链表,每 k 个倒置,最后一组不足 k 个就不倒置。解法一 迭代关于单链表倒置,我们在第 2 题就讨论过。有了单链表倒置,这道题无非就是用一个循环,每次将 k 个结点取下来,倒置后再接回去,然后再取 k 个,以此循环,到了最后一组如果不足 k 个,不做处理,直接返回头结点就可以了。所以关键就是,指针指来指去,大家不晕掉就好,我做了图示,大家参考一下。为了将头结点也一般化,我们创建一个 dummy 结点,然后整个过程主要运用三个指针, tail 指针表示已经倒置后的链表的尾部,subhead 指针表示要进行倒置的子链表,toNull 指针为了将子链表从原来链表中取下来。一个 while 循环,让 toNull 指针走 k - 1 步使其指向子链表的尾部。中间的 if 语句就是判断当前节点数够不够 k 个了,不够的话直接返回结果就可以了将子链表指向 null ,脱离出来。并且用 temp 保存下一个结点的位置。然后调用倒置函数,将子链表倒置。接下来四步分别是,新链表接到 tail(注意下边的图 tail 是更新后的位置,之前 tail 在 dummy 的位置) 的后边;更新 tail 到新链表的尾部,也就是之前的 subhead (下图 subhead 也是更新后的位置,之前的位置参见上边的图);sub_head 更新到 temp 的位置;toNull 到 sub_head 的位置;然后将新的尾部 tail 把之前断开的链表连起来,接到 sub_head 上。整理下其实就是下边的样子和初始的时候(下边的图)对比一下,发现 tail,subhead 和 toNull 三个指针已经就位,可以愉快的重复上边的步骤了。看下代码吧。public ListNode reverseKGroup(ListNode head, int k) { if (head == null) return null; ListNode sub_head = head; ListNode dummy = new ListNode(0); dummy.next = head; ListNode tail = dummy; ListNode toNull = head; while (sub_head != null) { int i = k; //找到子链表的尾部 while (i - 1 > 0) { toNull = toNull.next; if (toNull == null) { return dummy.next; } i--; } ListNode temp = toNull.next; //将子链表断开 toNull.next = null; ListNode new_sub_head = reverse(sub_head); //将倒置后的链表接到 tail 后边 tail.next = new_sub_head; //更新 tail tail = sub_head; //sub_head 由于倒置其实是新链表的尾部 sub_head = temp; toNull = sub_head; //将后边断开的链表接回来 tail.next = sub_head; } return dummy.next; } public ListNode reverse(ListNode head) { ListNode current_head = null; while (head != null) { ListNode next = head.next; head.next = current_head; current_head = head; head = next; } return current_head; } 时间复杂度:while 循环中本质上我们只是将每个结点访问了一次,加上结点倒置访问的一次,所以总共加起来每个结点其实只访问了 2 次。所以时间复杂度是 O(n)。空间复杂度:O(1)。解法二递归有没有被解法一的各种指针绕晕呢,我们有一个更好的选择,递归,这样看起来就会简洁很多。public ListNode reverseKGroup(ListNode head, int k) { if (head == null) return null; ListNode point = head; //找到子链表的尾部 int i = k; while(i - 1 >0){ point = point.next; if (point == null) { return head; } i--; } ListNode temp = point.next; //将子链表断开 point.next = null; //倒置子链表,并接受新的头结点 ListNode new_head = reverse(head); //head 其实是倒置链表的尾部,然后我们将后边的倒置结果接过来就可以了 //temp 是链表断开后的头指针,可以参考解法一的图示 head.next = reverseKGroup(temp,k); return new_head; } public ListNode reverse(ListNode head) { ListNode current_head = null; while (head != null) { ListNode next = head.next; head.next = current_head; current_head = head; head = next; } return current_head; } 复杂度:递归留坑。总还是那句话,涉及到链表的,我们就画下图,把各个指针的移动理清楚,一般没啥问题。今天我们一起学习了LeetCode 题的算法分析,感谢大家阅读,觉得不错记得收藏哦!喜欢 请点个 + 关注
文章
机器学习/深度学习  ·  人工智能  ·  算法  ·  Java
2022-06-25
1 2 3 4 5 6 7 8 9
...
20
跳转至:
人工智能
2623 人关注 | 9276 讨论 | 68556 内容
+ 订阅
  • React系列八 - 深入理解setState
  • 前端面试 | 18个常见HTML问题与答案
  • 构建WEB项目的 25 个HTML建议
查看更多 >
开发与运维
5243 人关注 | 125834 讨论 | 202512 内容
+ 订阅
  • CSS3——CSS3 新增选择器
  • JDK1.8新特性(七):默认方法,真香,开动!接口?我要升级!!
  • React系列九 - 受控非受控组件
查看更多 >
云计算
21619 人关注 | 57900 讨论 | 39092 内容
+ 订阅
  • 构建WEB项目的 25 个HTML建议
  • React系列二 - 核心JSX语法二
  • React系列(一) -邂逅React开发
查看更多 >
云原生
230264 人关注 | 9544 讨论 | 29730 内容
+ 订阅
  • 页面制作的15个CSS技巧
  • 前端面试 | 18个常见HTML问题与答案
  • React系列五 - 组件化开发(一)
查看更多 >
数据库
249308 人关注 | 44495 讨论 | 62466 内容
+ 订阅
  • PolarDB for PostgreSQL 开源创造营课后练习
  • JDK1.8新特性(六):Stream的终极操作,轻松解决集合分组、汇总等复杂操作
  • JDK1.8新特性(五):Stream,集合操作利器,让你好用到飞起来
查看更多 >