III. Environment Setup
ERNIE learns real-world semantic knowledge by modeling the words, entities, and entity relations found in massive corpora. Where BERT learns from raw language signals, ERNIE models prior semantic-knowledge units directly, which strengthens its semantic representations. With the Transformer as the basic network component and masked language modeling plus Next Sentence Prediction as training objectives, pre-training yields a general-purpose semantic representation that is then combined with a simple output layer and applied to downstream NLP tasks. This example shows how to use ERNIE for a text classification task.
1. Install PaddleHub
!pip install -U -q paddlehub
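Before continuing, it can be worth confirming that the installed PaddleHub is a 2.x release, since the fine-tuning API used below belongs to that line; a minimal check:

# confirm PaddleHub imports correctly and report its version
# (the fine-tuning code below assumes a 2.x release)
import paddlehub as hub
print(hub.__version__)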
2. Download the model
In addition to ERNIE, PaddleHub also provides BERT and other models. The modules that currently support text classification, together with how to load them, are listed below:
| Model | PaddleHub Module |
| --- | --- |
| ERNIE, Chinese | hub.Module(name='ernie') |
| ERNIE tiny, Chinese | hub.Module(name='ernie_tiny') |
| ERNIE 2.0 Base, English | hub.Module(name='ernie_v2_eng_base') |
| ERNIE 2.0 Large, English | hub.Module(name='ernie_v2_eng_large') |
| BERT-Base, English Cased | hub.Module(name='bert-base-cased') |
| BERT-Base, English Uncased | hub.Module(name='bert-base-uncased') |
| BERT-Large, English Cased | hub.Module(name='bert-large-cased') |
| BERT-Large, English Uncased | hub.Module(name='bert-large-uncased') |
| BERT-Base, Multilingual Cased | hub.Module(name='bert-base-multilingual-cased') |
| BERT-Base, Multilingual Uncased | hub.Module(name='bert-base-multilingual-uncased') |
| BERT-Base, Chinese | hub.Module(name='bert-base-chinese') |
| BERT-wwm, Chinese | hub.Module(name='chinese-bert-wwm') |
| BERT-wwm-ext, Chinese | hub.Module(name='chinese-bert-wwm-ext') |
| RoBERTa-wwm-ext, Chinese | hub.Module(name='roberta-wwm-ext') |
| RoBERTa-wwm-ext-large, Chinese | hub.Module(name='roberta-wwm-ext-large') |
| RBT3, Chinese | hub.Module(name='rbt3') |
| RBTL3, Chinese | hub.Module(name='rbtl3') |
| ELECTRA-Small, English | hub.Module(name='electra-small') |
| ELECTRA-Base, English | hub.Module(name='electra-base') |
| ELECTRA-Large, English | hub.Module(name='electra-large') |
| ELECTRA-Base, Chinese | hub.Module(name='chinese-electra-base') |
| ELECTRA-Small, Chinese | hub.Module(name='chinese-electra-small') |
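Any module name from the table can be passed to `hub.Module` without changing the rest of the pipeline; a minimal sketch (using `num_classes=10` to match the dataset defined later in this example):

# load an alternative backbone for sequence classification;
# only the name string changes, the task and classification head stay the same
import paddlehub as hub
model = hub.Module(name='ernie', task='seq-cls', num_classes=10)

This example uses roberta-wwm-ext-large, installed next.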
!hub install roberta-wwm-ext-large
[2022-09-10 14:56:17,391] [    INFO] - Module roberta-wwm-ext-large already installed in /home/aistudio/.paddlehub/modules/roberta_wwm_ext_large
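To see which modules are already installed locally, the PaddleHub CLI also provides a listing command:

# list the modules installed under ~/.paddlehub/modules
!hub list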
IV. Model Fine-tuning
# select which GPU card to use
!export CUDA_VISIBLE_DEVICES=0

import paddlehub as hub
model = hub.Module(name='roberta-wwm-ext-large', task='seq-cls', num_classes=10)
[2022-09-10 14:56:20,673] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/roberta-wwm-ext-large/roberta_chn_large.pdparams
W0910 14:56:20.677435 25230 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0910 14:56:20.681991 25230 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
1. Custom dataset
from typing import Dict, List, Optional, Union, Tuple
import os

from paddlehub.env import DATA_HOME
from paddlehub.utils.download import download_data
from paddlehub.datasets.base_nlp_dataset import TextClassificationDataset
from paddlehub.text.bert_tokenizer import BertTokenizer
from paddlehub.text.tokenizer import CustomTokenizer


class MyDataset(TextClassificationDataset):
    def __init__(self,
                 tokenizer: Union[BertTokenizer, CustomTokenizer],
                 max_seq_len: int = 128,
                 mode: str = 'train'):
        # the data files live under ./yt, one file per split
        base_path = "yt"
        if mode == 'train':
            data_file = 'train_ok.csv'
        elif mode == "val":
            data_file = 'val_ok.csv'
        super().__init__(
            base_path=base_path,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            data_file=data_file,
            label_list=['多问', '临床表现(病症表现)', '用法', '适用症', '定义', '病因', '治疗方法', '无法确定', '作用', '方法'],
            is_file_with_header=True)
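Since `is_file_with_header=True` is set and the test split is read with `sep='\t'` later on, it can help to peek at the raw training file and confirm its layout before wiring it into the dataset; a small sketch, assuming the file exists at the path used above:

# print the header row and the first two records of the training file
# to confirm the tab-separated, header-first layout expected by MyDataset
with open('yt/train_ok.csv', encoding='utf-8') as f:
    for _ in range(3):
        print(f.readline().rstrip('\n'))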
2. Reader definition
Next, create a reader for text classification. The reader preprocesses the dataset's samples: it first tokenizes the text and then arranges it into the specific format that is fed to the model for training. `ClassifyReader` takes the following three parameters:

- `dataset`: the PaddleHub Dataset to read from;
- `vocab_path`: the path to the vocabulary file of the corresponding ERNIE/BERT model;
- `max_seq_len`: the maximum sequence length for the ERNIE model. Shorter sequences are padded up to `max_seq_len`, and longer sequences are truncated to `max_seq_len` (see the sketch after this list for one way to choose this value).
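The code below uses `max_seq_len=221`; one way to arrive at such a value is to look at the length distribution of the training texts. A small sketch, assuming the training file is tab-separated with a `text` column like the test file read later:

# inspect character lengths of the training texts to pick a reasonable max_seq_len
import pandas as pd

train_df = pd.read_csv('yt/train_ok.csv', sep='\t')
lengths = train_df['text'].astype(str).str.len()
print(lengths.max(), lengths.quantile(0.99))  # choose a cutoff that covers nearly all samples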
train_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=221, mode='train')
val_dataset = MyDataset(tokenizer=model.get_tokenizer(), max_seq_len=221, mode='val')
[2022-09-10 14:56:26,339] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/roberta-wwm-ext-large/vocab.txt
[2022-09-10 14:56:26,739] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/roberta-wwm-ext-large/vocab.txt
3. Choose the optimization strategy and runtime configuration
# select which GPU card to use
!export CUDA_VISIBLE_DEVICES=0
import paddle

# Adam with a small learning rate, as is typical when fine-tuning pretrained transformers
optimizer = paddle.optimizer.Adam(learning_rate=5e-5, parameters=model.parameters())
# checkpoints, including the best model on the eval set, are written to checkpoint_dir
trainer = hub.Trainer(model, optimizer, checkpoint_dir='test_ernie_text_cls', use_gpu=True)
trainer.train(train_dataset=train_dataset, eval_dataset=val_dataset, epochs=5, batch_size=32, save_interval=1)
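After training, the weights can be re-checked against the validation split with the Trainer's `evaluate` method (a short sketch; the metrics reported depend on the PaddleHub version):

# report the evaluation metrics of the current weights on the validation set
trainer.evaluate(val_dataset, batch_size=32)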
V. Model Prediction
Once fine-tuning finishes, the model that performed best on the validation set is saved under the `${CHECKPOINT_DIR}/best_model` directory, where `${CHECKPOINT_DIR}` is the checkpoint directory chosen for fine-tuning (here, `test_ernie_text_cls`).
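To confirm the checkpoint was actually written, the directory contents can be listed (a small sketch; exact file names may vary by PaddleHub version):

# the best checkpoint should contain model.pdparams among its files
import os
print(os.listdir('test_ernie_text_cls/best_model'))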
1. Read the data
import numpy as np
import pandas as pd

test = pd.read_csv('yt/test.csv', sep='\t')
test_list = np.array(test.text).tolist()
print(len(test_list))
2000
test_data = []
for item in test_list:
    test_data.append([item])
print(test_data[:10])
[['四磨汤口服液的成分是什么,可以改善腹痛腹泻的情况吗?'], ['阴茎上长泡是怎么回事?'], ['21金维他多维元素片适用于怎样的营养补给'], ['臀部护理是不是应该有什么步骤和用品呢?都应该怎么进行?'], ['您的答复我收到了,谢谢!我还想在问一下,如果胸腔积液继续增加有什么好的办法解决吗?今天下午感觉胸闷又排了五百多毫升我观察比前几次浓度增加了。'], ['晕痛定胶囊这个药物能不能长期吃的啊,不知道治疗偏头痛怎么样呢'], ['医师你好,我女儿右腿左侧胎生血管瘤,现在还显瘀青色,有人建议激光治疗,激光对她太疼,有什么更好的治疗?内蒙古婴儿的血管瘤长手上危害大不大?'], ['月见草油胶丸会导致大便稀?需要停止用药吗?'], ['宝宝经常性便秘怎么办?'], ['有没有人知道在深圳哪个药房才可以买到益肾蠲痹丸这个药物呢?']]
label_list = ['多问', '临床表现(病症表现)', '用法', '适用症', '定义', '病因', '治疗方法', '无法确定', '作用', '方法']
label_map = {}
for i in range(10):
    label_map[i] = label_list[i]
print(label_map)
{0: '多问', 1: '临床表现(病症表现)', 2: '用法', 3: '适用症', 4: '定义', 5: '病因', 6: '治疗方法', 7: '无法确定', 8: '作用', 9: '方法'}
2. Model prediction
import paddlehub as hub

# load the fine-tuned checkpoint; the backbone and label_map passed here must match
# the ones used to produce the checkpoint, and label_map converts class indices back
# to label names in the prediction output
model = hub.Module(
    name='ernie_tiny',
    version='2.0.1',
    task='seq-cls',
    load_checkpoint='./test_ernie_text_cls/best_model/model.pdparams',
    label_map=label_map)
results = model.predict(test_data, max_seq_len=221, batch_size=1, use_gpu=True)
[2022-09-10 15:03:03,990] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-tiny/ernie_tiny.pdparams
[2022-09-10 15:03:07,737] [    INFO] - Loaded parameters from /home/aistudio/test_ernie_text_cls/best_model/model.pdparams
[2022-09-10 15:03:07,749] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-tiny/vocab.txt
[2022-09-10 15:03:07,751] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-tiny/spm_cased_simp_sampled.model
[2022-09-10 15:03:07,753] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-tiny/dict.wordseg.pickle
print(results[0])
print(len(results))
定义
2000
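As a quick sanity check before writing the submission file, the first few predictions can be paired with their input texts (a small sketch using the variables defined above):

# print a few (predicted label, input text) pairs to eyeball the model's behavior
for text, label in zip(test_list[:5], results[:5]):
    print(label, '\t', text)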
3. Save the results
f = open('result.csv', 'w')
f.write('label\n')
for idx in range(len(results)):
    f.write(results[idx] + '\n')
f.close()
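Before uploading, it can help to verify the submission file; a small sketch, assuming the expected format is a single `label` column with one row per test sample:

# confirm result.csv has the 'label' header and 2000 prediction rows
import pandas as pd
sub = pd.read_csv('result.csv')
print(sub.shape, sub.columns.tolist())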
VI. Submission
Download the result file and submit it to receive your score.
- Note: you can also split the data into train and eval sets and select the best-performing model.