小样本学习在文心ERNIE3.0多分类任务应用--提示学习

简介: 小样本学习在文心ERNIE3.0多分类任务应用--提示学习

小样本学习在文心ERNIE3.0多分类任务应用(提示学习)

项目链接:
https://aistudio.baidu.com/aistudio/projectdetail/4438610?contributionType=1

0.小样本学习简介

二分类/多分类任务在商品分类、网页分类、新闻分类、医疗文本分类等现实场景中有着广泛应用。现有的主流解决方案是在大规模预训练语言模型进行微调,因为下游任务和预训练任务训练目标不同,想要取得较好的分类效果往往需要大量标注数据,因此学界和业界开始研究如何在小样本学习(Few-shot Learning)场景下取得更好的学习效果。

提示学习(Prompt Learning) 的主要思想是通过任务转换使得下游任务和预训练任务尽可能相似,充分利用预训练语言模型学习到的特征,从而降低样本需求量。除此之外,我们往往还需要在原有的输入文本上拼接一段“提示”,来引导预训练模型输出期望的结果。

我们以Ernie为例,回顾一下这类预训练语言模型的训练任务。 与考试中的完形填空相似,给定一句文本,遮盖掉其中的部分字词,要求语言模型预测出这些遮盖位置原本的字词。

因此,我们也将多分类任务转换为与完形填空相似的形式。例如影评情感分类任务,标签分为1-正向,0-负向两类。

  • 在经典的微调方式中,需要学习的参数是以[CLS]向量为输入,以负向/正向为输出的随机初始化的分类器。
  • 在提示学习中,我们通过构造提示,将原有的分类任务转化为完形填空。如下图所示,通过提示我[MASK]喜欢。,原有1-正向,0-负向的标签被转化为了预测空格是很还是不。此时的分类器也不再是随机初始化,而是利用了这两个字的预训练向量来初始化,充分利用了预训练模型学习到的参数。

在这里插入图片描述

对于标注样本充足的场景可以直接使用预训练模型微调实现文本多分类,对于尚无标注或者标注样本较少的任务场景我们推荐使用小样本学习,以取得比微调方法更好的效果。

下边通过新闻分类的例子展示如何使用小样本学习来进行文本分类。

0.1 环境要求

python >= 3.6

paddlepaddle >= 2.3

paddlenlp >= 2.4.0 【预计9月份上线】

0.1.1提前尝鲜获取最新版本

!pip install git+

pip 从 git 源码仓库直接 install 。要求是这个github仓库内要有setup.py文件

查看url:

在这里插入图片描述

安装git仓库中的包

pip install git+<git仓库地址>
pip install git+<git仓库地址>@<分支名称>

用到它的场景就是比如你有一个代码已经上传到github了,要分发给别人用,你就懒得再下载下来再导出成tar.gz或者whl文件,是可以直接让他从github网址安装

pip install git+http://127.0.0.1/xxx/demo.git --user
pip install git+https://github.com/shadowsocks/shadowsocks.git@master

等价于:

# 两步走的安装(安装完还需要自己删除git文件)
git clone http://127.0.0.1/XXX/demo.git
#change dir
cd demo
# install
python setup.py install --user
# windows环境下加--user 不然容易报错
#直接pip git 无法安装成功,那就采用第二种方案
# !pip install git+https://github.com/PaddlePaddle/PaddleNLP.git@develop

先把paddlenlp develop分支下下载到本地,进行压缩上传到aistudio。

$ git clone https://github.com/PaddlePaddle/PaddleNLP.git
Cloning into 'PaddleNLP'...
remote: Enumerating objects: 27619, done.
remote: Counting objects: 100% (125/125), done.
remote: Compressing objects: 100% (106/106), done.
remote: Total 27619 (delta 46), reused 78 (delta 18), pack-reused 27494
Receiving objects: 100% (27619/27619), 73.69 MiB | 3.03 MiB/s, done.
Resolving deltas: 100% (18182/18182), done.
Updating files: 100% (3115/3115), done.

获得的paddlenl文件夹在C:\Users\admin路径下,然后进行压缩上传到aistudio
win10压缩参考下面链接:(或者直接用压缩软件压缩)
https://blog.csdn.net/sinat_39620217/article/details/126290315

zip -q -r paddlenlp.zip PaddleNLP
# !unzip paddlenlp.zip

#这里采用先装老库再覆盖的方案,确保依赖都安装时,不然直接解压执行setup.py会卡主
!pip install --upgrade paddlenlp

%cd PaddleNLP
!python setup.py install --user
%cd ..

0.2数据集格式要求

#获取数据
# !wget https://paddlenlp.bj.bcebos.com/datasets/few-shot/tnews.tar.gz
!tar zxvf tnews.tar.gz
!mv tnews data

数据集格式
对于训练/验证/测试数据集文件,每行数据表示一条样本,包括文本和标签两部分,由tab符\t分隔。格式如下

文登区这些公路及危桥将进入封闭施工,请注意绕行! news_car
普洱茶要如何醒茶? news_culture
...

对于待预测数据文件,每行包含一条待预测样本,无标签。格式如下

互联网时代如何保护个人信息
清秋暮雨读柳词:忍把浮名,换了浅斟低唱丨周末读诗
...

对于分类标签集文件,存储了数据集中所有的标签集合,每行为一个标签名。如果需要自定义标签映射用于分类器初始化,则每行需要包括标签名和相应的映射词,由==分隔。格式如下

news_car'=='汽车
news_culture'=='文化
...

Note 这里的标签映射词定义遵循的规则是,不同映射词尽可能长度一致,映射词和提示需要尽可能构成通顺的语句。越接近自然语句,小样本下模型训练效果越好。如果原标签名已经可以构成通顺语句,也可以不构造映射词,每行一个标签即可。

1.模型训练与预测

这里提示一下:

如果运行程序报错:

Traceback (most recent call last):
  File "train.py", line 23, in <module>
    from paddlenlp.prompt import (
ModuleNotFoundError: No module named 'paddlenlp.prompt'

是因为paddlenlp.prompt在2.4.0版本才会有,请检查上面步骤是否有遗漏,

!export CUDA_VISIBLE_DEVICES=0
!python train.py \
--data_dir ./data/tnews  \
--output_dir ./checkpoints/ \
--prompt "这条新闻标题的主题是" \
--max_seq_length 128  \
--learning_rate 3e-5 \
--ppt_learning_rate 3e-4 \
--do_train \
--do_eval \
--max_steps 1000 \
--eval_steps 100 \
--logging_steps 10 \
--per_device_eval_batch_size 32 \
--per_device_train_batch_size 32 

结果部分展示:

Training Configuration Arguments

[2022-08-18 11:42:58,983] [    INFO] - paddle commit id              :3cc6ae69ed93388b2648bcc819d593130dede752
[2022-08-18 11:42:58,983] [    INFO] - _no_sync_in_gradient_accumulation:True
[2022-08-18 11:42:58,983] [    INFO] - adam_beta1                    :0.9
[2022-08-18 11:42:58,983] [    INFO] - adam_beta2                    :0.999
[2022-08-18 11:42:58,983] [    INFO] - adam_epsilon                  :1e-08
[2022-08-18 11:42:58,983] [    INFO] - alpha_rdrop                   :5.0
[2022-08-18 11:42:58,983] [    INFO] - alpha_rgl                     :0.5
[2022-08-18 11:42:58,983] [    INFO] - current_device                :gpu:0
[2022-08-18 11:42:58,984] [    INFO] - dataloader_drop_last          :False
[2022-08-18 11:42:58,984] [    INFO] - dataloader_num_workers        :0
[2022-08-18 11:42:58,984] [    INFO] - device                        :gpu
[2022-08-18 11:42:58,984] [    INFO] - disable_tqdm                  :False
[2022-08-18 11:42:58,984] [    INFO] - do_eval                       :True
[2022-08-18 11:42:58,984] [    INFO] - do_export                     :False
[2022-08-18 11:42:58,984] [    INFO] - do_predict                    :False
[2022-08-18 11:42:58,984] [    INFO] - do_train                      :True
[2022-08-18 11:42:58,984] [    INFO] - eval_batch_size               :32
[2022-08-18 11:42:58,984] [    INFO] - eval_steps                    :100
[2022-08-18 11:42:58,984] [    INFO] - evaluation_strategy           :IntervalStrategy.STEPS
[2022-08-18 11:42:58,984] [    INFO] - first_max_length              :None
[2022-08-18 11:42:58,984] [    INFO] - fp16                          :False
[2022-08-18 11:42:58,984] [    INFO] - fp16_opt_level                :O1
[2022-08-18 11:42:58,984] [    INFO] - freeze_dropout                :False
[2022-08-18 11:42:58,984] [    INFO] - freeze_plm                    :False
[2022-08-18 11:42:58,984] [    INFO] - gradient_accumulation_steps   :1
[2022-08-18 11:42:58,984] [    INFO] - greater_is_better             :None
[2022-08-18 11:42:58,984] [    INFO] - ignore_data_skip              :False
[2022-08-18 11:42:58,984] [    INFO] - label_names                   :None
[2022-08-18 11:42:58,984] [    INFO] - learning_rate                 :3e-05
[2022-08-18 11:42:58,984] [    INFO] - load_best_model_at_end        :False
[2022-08-18 11:42:58,984] [    INFO] - local_process_index           :0
[2022-08-18 11:42:58,984] [    INFO] - local_rank                    :-1
[2022-08-18 11:42:58,984] [    INFO] - log_level                     :-1
[2022-08-18 11:42:58,984] [    INFO] - log_level_replica             :-1
[2022-08-18 11:42:58,985] [    INFO] - log_on_each_node              :True
[2022-08-18 11:42:58,985] [    INFO] - logging_dir                   :./checkpoints/runs/Aug18_11-42-56_jupyter-691158-4438610
[2022-08-18 11:42:58,985] [    INFO] - logging_first_step            :False
[2022-08-18 11:42:58,985] [    INFO] - logging_steps                 :10
[2022-08-18 11:42:58,985] [    INFO] - logging_strategy              :IntervalStrategy.STEPS
[2022-08-18 11:42:58,985] [    INFO] - lr_scheduler_type             :SchedulerType.LINEAR
[2022-08-18 11:42:58,985] [    INFO] - max_grad_norm                 :1.0
[2022-08-18 11:42:58,985] [    INFO] - max_seq_length                :128
[2022-08-18 11:42:58,985] [    INFO] - max_steps                     :5000
[2022-08-18 11:42:58,985] [    INFO] - metric_for_best_model         :None
[2022-08-18 11:42:58,985] [    INFO] - minimum_eval_times            :None
[2022-08-18 11:42:58,985] [    INFO] - no_cuda                       :False
[2022-08-18 11:42:58,985] [    INFO] - num_train_epochs              :3.0
[2022-08-18 11:42:58,985] [    INFO] - optim                         :OptimizerNames.ADAMW
[2022-08-18 11:42:58,985] [    INFO] - other_max_length              :None
[2022-08-18 11:42:58,985] [    INFO] - output_dir                    :./checkpoints/
[2022-08-18 11:42:58,985] [    INFO] - overwrite_output_dir          :False
[2022-08-18 11:42:58,985] [    INFO] - past_index                    :-1
[2022-08-18 11:42:58,985] [    INFO] - per_device_eval_batch_size    :32
[2022-08-18 11:42:58,985] [    INFO] - per_device_train_batch_size   :32
[2022-08-18 11:42:58,985] [    INFO] - ppt_adam_beta1                :0.9
[2022-08-18 11:42:58,985] [    INFO] - ppt_adam_beta2                :0.999
[2022-08-18 11:42:58,985] [    INFO] - ppt_adam_epsilon              :1e-08
[2022-08-18 11:42:58,985] [    INFO] - ppt_learning_rate             :0.0003
[2022-08-18 11:42:58,985] [    INFO] - ppt_weight_decay              :0.0
[2022-08-18 11:42:58,985] [    INFO] - prediction_loss_only          :False
[2022-08-18 11:42:58,985] [    INFO] - process_index                 :0
[2022-08-18 11:42:58,986] [    INFO] - remove_unused_columns         :True
[2022-08-18 11:42:58,986] [    INFO] - report_to                     :['visualdl']
[2022-08-18 11:42:58,986] [    INFO] - resume_from_checkpoint        :None
[2022-08-18 11:42:58,986] [    INFO] - run_name                      :./checkpoints/
[2022-08-18 11:42:58,986] [    INFO] - save_on_each_node             :False
[2022-08-18 11:42:58,986] [    INFO] - save_steps                    :500
[2022-08-18 11:42:58,986] [    INFO] - save_strategy                 :IntervalStrategy.STEPS
[2022-08-18 11:42:58,986] [    INFO] - save_total_limit              :None
[2022-08-18 11:42:58,986] [    INFO] - scale_loss                    :32768
[2022-08-18 11:42:58,986] [    INFO] - seed                          :42
[2022-08-18 11:42:58,986] [    INFO] - should_log                    :True
[2022-08-18 11:42:58,986] [    INFO] - should_save                   :True
[2022-08-18 11:42:58,986] [    INFO] - task_type                     :multi-class
[2022-08-18 11:42:58,986] [    INFO] - train_batch_size              :32
[2022-08-18 11:42:58,986] [    INFO] - truncate_mode                 :tail
[2022-08-18 11:42:58,986] [    INFO] - use_rdrop                     :False
[2022-08-18 11:42:58,986] [    INFO] - use_rgl                       :False
[2022-08-18 11:42:58,986] [    INFO] - warmup_ratio                  :0.0
[2022-08-18 11:42:58,986] [    INFO] - warmup_steps                  :0
[2022-08-18 11:42:58,986] [    INFO] - weight_decay                  :0.0
[2022-08-18 11:42:58,986] [    INFO] - world_size                    :1
[2022-08-18 11:42:58,989] [    INFO] - ***** Running training *****
[2022-08-18 11:42:58,989] [    INFO] -   Num examples = 240
[2022-08-18 11:42:58,989] [    INFO] -   Num Epochs = 625
[2022-08-18 11:42:58,989] [    INFO] -   Instantaneous batch size per device = 32
[2022-08-18 11:42:58,989] [    INFO] -   Total train batch size (w. parallel, distributed & accumulation) = 32
[2022-08-18 11:42:58,989] [    INFO] -   Gradient Accumulation steps = 1
[2022-08-18 11:42:58,989] [    INFO] -   Total optimization steps = 5000
[2022-08-18 11:42:58,989] [    INFO] -   Total num train samples = 160000

模型保存,以及指标性能

eval_loss: 3.820039987564087, eval_accuracy: 0.5625, eval_runtime: 3.6311, eval_samples_per_second: 66.095, eval_steps_per_second: 2.203, epoch: 125.0


 20%|███████▊                               | 1000/5000 [08:21<22:29,  2.96it/s]
100%|█████████████████████████████████████████████| 8/8 [00:01<00:00,  7.33it/s]
                                                                                [2022-08-18 11:51:20,044] [    INFO] - Saving model checkpoint to ./checkpoints/checkpoint-1000
[2022-08-18 11:51:20,045] [    INFO] - Trainer.model is not a `PretrainedModel`, only saving its state dict.
[2022-08-18 11:51:26,610] [    INFO] - tokenizer config file saved in ./checkpoints/checkpoint-1000/tokenizer_config.json
[2022-08-18 11:51:26,610] [    INFO] - Special tokens file saved in ./checkpoints/checkpoint-1000/special_tokens_map.json
#多卡训练
# !unset CUDA_VISIBLE_DEVICES
# !python -u -m paddle.distributed.launch --gpus 0,1,2,3 train.py \
# --data_dir ./data \
# --output_dir ./checkpoints/ \
# --prompt "这条新闻标题的主题是" \
# --max_seq_length 128  \
# --learning_rate 3e-5 \
# --ppt_learning_rate 3e-4 \
# --do_train \
# --do_eval \
# --max_steps 1000 \
# --eval_steps 100 \
# --logging_steps 10 \
# --per_device_eval_batch_size 32 \
# --per_device_train_batch_size 8 \
# --do_predict \
# --do_export

1.2 预测

在模型训练时开启--do_predict,训练结束后直接进行预测,也可以在训练结束后,通过运行以下命令加载模型参数进行预测:

可配置参数说明:

data_dir: 测试数据路径。数据格式要求详见数据准备,数据应存放在该目录下test.txt文件中,每行一条待预测文本。

output_dir: 日志的保存目录。

resume_from_checkpoint: 训练时模型参数的保存目录,用于加载模型参数。

do_predict: 是否进行预测。

max_seq_length: 最大句子长度,超过该长度的文本将被截断,不足的以Pad补全。提示文本不会被截断。

!python train.py --do_predict --data_dir ./data/tnews --output_dir ./predict_ckpt --resume_from_checkpoint ./checkpoints --max_seq_length 128
[2022-08-18 13:04:45,520] [    INFO] - ***** Running Prediction *****
[2022-08-18 13:04:45,520] [    INFO] -   Num examples = 2010
[2022-08-18 13:04:45,520] [    INFO] -   Pre device batch size = 8
[2022-08-18 13:04:45,520] [    INFO] -   Total Batch size = 8
[2022-08-18 13:04:45,520] [    INFO] -   Total prediction steps = 252
 99%|████████████████████████████████████████▋| 250/252 [00:15<00:00, 16.99it/s]***** test metrics *****
  test_accuracy           =     0.5468
  test_loss               =     3.4095
  test_runtime            = 0:00:16.65
  test_samples_per_second =    120.664
  test_steps_per_second   =     15.128
100%|█████████████████████████████████████████| 252/252 [00:15<00:00, 16.44it/s]

2. 模型导出与部署

2.1 导出

在训练结束后,需要将动态图模型导出为静态图参数用于部署推理。可以在模型训练时开启--do_export在训练结束后直接导出,也可以运行以下命令加载并导出训练后的模型参数,默认导出到在output_dir指定的目录下。

python train.py --do_predict --data_dir ./data --output_dir ./predict_ckpt --resume_from_checkpoint ./ckpt/ --max_seq_length 128

可配置参数说明:

data_dir: 标签数据路径。数据格式要求详见数据准备。

output_dir: 静态图模型参数和日志的保存目录。

resume_from_checkpoint: 训练时模型参数的保存目录,用于加载模型参数。

do_export: 是否将模型导出为静态图,保存路径为output_dir/export。

2.2 模型部署

模型转换与ONNXRuntime预测部署依赖Paddle2ONNX和ONNXRuntime,Paddle2ONNX支持将Paddle静态图模型转化为ONNX模型格式,算子目前稳定支持导出ONNX Opset 7~15,更多细节可参考:Paddle2ONNX。

https://github.com/PaddlePaddle/Paddle2ONNX

如果基于GPU部署,请先确保机器已正确安装NVIDIA相关驱动和基础软件,确保CUDA >= 11.2,CuDNN >= 8.2,并使用以下命令安装所需依赖:

pip install paddle2onnx==1.0.0rc3
python -m pip install onnxruntime-gpu onnx onnxconverter-common

如果基于CPU部署,请使用如下命令安装所需依赖:

pip install paddle2onnx==1.0.0rc3
python -m pip install onnxruntime
CPU端推理样例

python infer.py --model_path_prefix ckpt/export/model --data_dir ./data --batch_size 32 --device cpu

GPU端推理样例

python infer.py --model_path_prefix ckpt/export/model --data_dir ./data --batch_size 32 --device gpu --device_id 0

可配置参数说明:

model_path_prefix: 导出的静态图模型路径及文件前缀。

model_name_or_path: 内置预训练模型名,或者模型参数配置目录路径,用于加载tokenizer。默认为ernie-3.0-base-zh。

data_dir: 待推理数据所在路径,数据应存放在该目录下的data.txt文件。

max_seq_length: 最大句子长度,超过该长度的文本将被截断,不足的以Pad补全。提示文本不会被截断。

batch_size: 每次预测的样本数量。

device: 选择推理设备,包括cpu和gpu。默认为gpu。

device_id: 指定GPU设备ID。

num_threads: 设置CPU使用的线程数。默认为机器上的物理内核数。

3.总结

预训练语言模型的参数空间比较大,如果在下游任务上直接对这些模型进行微调,为了达到较好的模型泛化性,需要较多的训练数据。在实际业务场景中,特别是垂直领域、特定行业中,训练样本数量不足的问题广泛存在,极大地影响这些模型在下游任务的准确度,因此,预训练语言模型学习到的大量知识无法充分地发挥出来。本项目实现基于预训练语言模型的小样本数据调优,从而解决大模型与小训练集不相匹配的问题。

小样本学习是机器学习领域未来很有前景的一个发展方向,它要解决的问题很有挑战性、也很有意义。小样本学习中最重要的一点就是先验知识的利用,如果我们妥善解决了先验知识的利用,能够做到很好的迁移性,想必那时我们距离通用AI也不远了。

最后也可以看出目前在新闻数据做的小样本demo性能结果上还有所欠缺,后续将进行改进。

展望: 后续将完成模型融合环节提升性能,并做可解释性分析。

本人博客:https://blog.csdn.net/sinat_39620217?type=blog

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
lda模型和bert模型的文本主题情感分类实战
lda模型和bert模型的文本主题情感分类实战
244 0
|
机器学习/深度学习 自然语言处理 PyTorch
PyTorch应用实战六:利用LSTM实现文本情感分类
PyTorch应用实战六:利用LSTM实现文本情感分类
316 0
|
2月前
|
机器学习/深度学习 自然语言处理 算法
[大语言模型-论文精读] 大语言模型是单样本URL分类器和解释器
[大语言模型-论文精读] 大语言模型是单样本URL分类器和解释器
34 0
|
6月前
|
机器学习/深度学习 自然语言处理 前端开发
深度学习-[源码+数据集]基于LSTM神经网络黄金价格预测实战
深度学习-[源码+数据集]基于LSTM神经网络黄金价格预测实战
170 0
|
机器学习/深度学习 数据采集 自然语言处理
【Deep Learning A情感文本分类实战】2023 Pytorch+Bert、Roberta+TextCNN、BiLstm、Lstm等实现IMDB情感文本分类完整项目(项目已开源)
亮点:代码开源+结构清晰+准确率高+保姆级解析 🍊本项目使用Pytorch框架,使用上游语言模型+下游网络模型的结构实现IMDB情感分析 🍊语言模型可选择Bert、Roberta 🍊神经网络模型可选择BiLstm、LSTM、TextCNN、Rnn、Gru、Fnn共6种 🍊语言模型和网络模型扩展性较好,方便读者自己对模型进行修改
641 0
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch搭建循环神经网络(RNN)进行文本分类、预测及损失分析(对不同国家的语言单词和姓氏进行分类,附源码和数据集)
PyTorch搭建循环神经网络(RNN)进行文本分类、预测及损失分析(对不同国家的语言单词和姓氏进行分类,附源码和数据集)
341 1
|
7月前
|
机器学习/深度学习 自然语言处理 数据挖掘
预训练语言模型中Transfomer模型、自监督学习、BERT模型概述(图文解释)
预训练语言模型中Transfomer模型、自监督学习、BERT模型概述(图文解释)
209 0
|
7月前
|
机器学习/深度学习 数据采集 自然语言处理
PyTorch搭建LSTM神经网络实现文本情感分析实战(附源码和数据集)
PyTorch搭建LSTM神经网络实现文本情感分析实战(附源码和数据集)
805 0
|
机器学习/深度学习 自然语言处理
【文本分类】《基于提示学习的小样本文本分类方法》
使用P-turning提示学习,进行小样本文本分类。本文值得学习。
206 0
|
人工智能 自然语言处理 PyTorch
NLP文本匹配任务Text Matching [有监督训练]:PointWise(单塔)、DSSM(双塔)、Sentence BERT(双塔)项目实践
NLP文本匹配任务Text Matching [有监督训练]:PointWise(单塔)、DSSM(双塔)、Sentence BERT(双塔)项目实践
NLP文本匹配任务Text Matching [有监督训练]:PointWise(单塔)、DSSM(双塔)、Sentence BERT(双塔)项目实践
下一篇
DataWorks