VI. Source Code Analysis
1. train.py
```python
import functools
import json
import os
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import numpy as np
import paddle
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    precision_recall_fscore_support,
)
from utils import log_metrics_debug, preprocess_function, read_local_dataset

from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import load_dataset
from paddlenlp.trainer import (
    CompressionArguments,
    EarlyStoppingCallback,
    PdArgumentParser,
    Trainer,
)
from paddlenlp.transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    export_model,
)
from paddlenlp.utils.log import logger

# Supported model list
SUPPORTED_MODELS = [
    "ernie-1.0-large-zh-cw",
    "ernie-1.0-base-zh-cw",
    "ernie-3.0-xbase-zh",
    "ernie-3.0-base-zh",
    "ernie-3.0-medium-zh",
    "ernie-3.0-micro-zh",
    "ernie-3.0-mini-zh",
    "ernie-3.0-nano-zh",
    "ernie-3.0-tiny-base-v2-zh",
    "ernie-3.0-tiny-medium-v2-zh",
    "ernie-3.0-tiny-micro-v2-zh",
    "ernie-3.0-tiny-mini-v2-zh",
    "ernie-3.0-tiny-nano-v2-zh",
    "ernie-3.0-tiny-pico-v2-zh",
    "ernie-2.0-large-en",
    "ernie-2.0-base-en",
    "ernie-3.0-tiny-mini-v2-en",
    "ernie-m-base",
    "ernie-m-large",
]


# Default arguments
# yapf: disable
@dataclass
class DataArguments:
    max_length: int = field(default=128, metadata={"help": "Maximum number of tokens for the model."})
    early_stopping: bool = field(default=False, metadata={"help": "Whether apply early stopping strategy."})
    early_stopping_patience: int = field(default=4, metadata={"help": "Stop training when the specified metric worsens for early_stopping_patience evaluation calls"})
    debug: bool = field(default=False, metadata={"help": "Whether choose debug mode."})
    train_path: str = field(default='./data/train.txt', metadata={"help": "Train dataset file path."})
    dev_path: str = field(default='./data/dev.txt', metadata={"help": "Dev dataset file path."})
    test_path: str = field(default='./data/dev.txt', metadata={"help": "Test dataset file path."})
    label_path: str = field(default='./data/label.txt', metadata={"help": "Label file path."})
    bad_case_path: str = field(default='./data/bad_case.txt', metadata={"help": "Bad case file path."})


@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="ernie-3.0-tiny-medium-v2-zh", metadata={"help": "Build-in pretrained model name or the path to local model."})
    export_model_dir: Optional[str] = field(default=None, metadata={"help": "Path to directory to store the exported inference model."})
# yapf: enable


def main():
    """
    Training a binary or multi classification model
    """
    parser = PdArgumentParser((ModelArguments, DataArguments, CompressionArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if training_args.do_compress:
        training_args.strategy = "dynabert"
    if training_args.do_train or training_args.do_compress:
        training_args.print_config(model_args, "Model")
        training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)

    # Define id2label
    id2label = {}
    label2id = {}
    with open(data_args.label_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            l = line.strip()
            id2label[i] = l
            label2id[l] = i

    # Define model & tokenizer
    if os.path.isdir(model_args.model_name_or_path):
        model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path, label2id=label2id, id2label=id2label
        )
    elif model_args.model_name_or_path in SUPPORTED_MODELS:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path, num_classes=len(label2id), label2id=label2id, id2label=id2label
        )
    else:
        raise ValueError(
            f"{model_args.model_name_or_path} is not a supported model type. Either use a local model path or select a model from {SUPPORTED_MODELS}"
        )
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    # Load and preprocess dataset
    train_ds = load_dataset(read_local_dataset, path=data_args.train_path, label2id=label2id, lazy=False)
    dev_ds = load_dataset(read_local_dataset, path=data_args.dev_path, label2id=label2id, lazy=False)
    trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=data_args.max_length)
    train_ds = train_ds.map(trans_func)
    dev_ds = dev_ds.map(trans_func)
    if data_args.debug:
        test_ds = load_dataset(read_local_dataset, path=data_args.test_path, label2id=label2id, lazy=False)
        test_ds = test_ds.map(trans_func)

    # Define the metric function.
    def compute_metrics(eval_preds):
        pred_ids = np.argmax(eval_preds.predictions, axis=-1)
        metrics = {}
        metrics["accuracy"] = accuracy_score(y_true=eval_preds.label_ids, y_pred=pred_ids)
        for average in ["micro", "macro"]:
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true=eval_preds.label_ids, y_pred=pred_ids, average=average
            )
            metrics[f"{average}_precision"] = precision
            metrics[f"{average}_recall"] = recall
            metrics[f"{average}_f1"] = f1
        return metrics

    def compute_metrics_debug(eval_preds):
        pred_ids = np.argmax(eval_preds.predictions, axis=-1)
        metrics = classification_report(eval_preds.label_ids, pred_ids, output_dict=True)
        return metrics

    # Define the early-stopping callback.
    if data_args.early_stopping:
        callbacks = [EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience)]
    else:
        callbacks = None

    # Define the Trainer
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        criterion=paddle.nn.loss.CrossEntropyLoss(),
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        callbacks=callbacks,
        data_collator=DataCollatorWithPadding(tokenizer),
        compute_metrics=compute_metrics_debug if data_args.debug else compute_metrics,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        for checkpoint_path in Path(training_args.output_dir).glob("checkpoint-*"):
            shutil.rmtree(checkpoint_path)

    # Evaluation / prediction
    if training_args.do_eval:
        if data_args.debug:
            output = trainer.predict(test_ds)
            log_metrics_debug(output, id2label, test_ds, data_args.bad_case_path)
        else:
            eval_metrics = trainer.evaluate()
            trainer.log_metrics("eval", eval_metrics)

    # Model export
    if training_args.do_export:
        if model.init_config["init_class"] in ["ErnieMForSequenceClassification"]:
            input_spec = [paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids")]
        else:
            input_spec = [
                paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
                paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
            ]
        if model_args.export_model_dir is None:
            model_args.export_model_dir = os.path.join(training_args.output_dir, "export")
        export_model(model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir)
        tokenizer.save_pretrained(model_args.export_model_dir)
        id2label_file = os.path.join(model_args.export_model_dir, "id2label.json")
        with open(id2label_file, "w", encoding="utf-8") as f:
            json.dump(id2label, f, ensure_ascii=False)
            logger.info(f"id2label file saved in {id2label_file}")

    # Model compression
    if training_args.do_compress:
        trainer.compress()
        for width_mult in training_args.width_mult_list:
            pruned_infer_model_dir = os.path.join(training_args.output_dir, "width_mult_" + str(round(width_mult, 2)))
            tokenizer.save_pretrained(pruned_infer_model_dir)
            id2label_file = os.path.join(pruned_infer_model_dir, "id2label.json")
            with open(id2label_file, "w", encoding="utf-8") as f:
                json.dump(id2label, f, ensure_ascii=False)
                logger.info(f"id2label file saved in {id2label_file}")

    for path in Path(training_args.output_dir).glob("runs"):
        shutil.rmtree(path)


if __name__ == "__main__":
    main()
```
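The `compute_metrics` function above reduces the logits to class ids with `argmax` and then scores them with scikit-learn. A minimal, self-contained sketch on made-up logits shows the shape of the result (the values below are purely illustrative):

```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Hypothetical logits for 4 samples over 3 classes
predictions = np.array([
    [2.1, 0.3, -1.0],
    [0.1, 1.7, 0.2],
    [0.0, 0.2, 3.3],
    [1.5, 1.4, 0.1],
])
label_ids = np.array([0, 1, 2, 1])

pred_ids = np.argmax(predictions, axis=-1)  # -> [0, 1, 2, 0]
print(accuracy_score(y_true=label_ids, y_pred=pred_ids))  # 0.75
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true=label_ids, y_pred=pred_ids, average="macro"
)
print(precision, recall, f1)
```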
2. utils.py
```python
import numpy as np

from paddlenlp.utils.log import logger


# Preprocessing
def preprocess_function(examples, tokenizer, max_length, is_test=False):
    """
    Builds model inputs from a sequence for sequence classification tasks
    by concatenating and adding special tokens.
    """
    result = tokenizer(examples["text"], max_length=max_length, truncation=True)
    if not is_test:
        result["labels"] = np.array([examples["label"]], dtype="int64")
    return result


# Read the dataset
def read_local_dataset(path, label2id=None, is_test=False):
    """
    Read dataset.
    """
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if is_test:
                sentence = line.strip()
                yield {"text": sentence}
            else:
                items = line.strip().split("\t")
                yield {"text": items[0], "label": label2id[items[1]]}


# Logging
def log_metrics_debug(output, id2label, dev_ds, bad_case_path):
    """
    Log metrics in debug mode.
    """
    predictions, label_ids, metrics = output
    pred_ids = np.argmax(predictions, axis=-1)
    logger.info("-----Evaluate model-------")
    logger.info("Dev dataset size: {}".format(len(dev_ds)))
    logger.info("Accuracy in dev dataset: {:.2f}%".format(metrics["test_accuracy"] * 100))
    logger.info(
        "Macro average | precision: {:.2f} | recall: {:.2f} | F1 score {:.2f}".format(
            metrics["test_macro avg"]["precision"] * 100,
            metrics["test_macro avg"]["recall"] * 100,
            metrics["test_macro avg"]["f1-score"] * 100,
        )
    )
    for i in id2label:
        l = id2label[i]
        logger.info("Class name: {}".format(l))
        i = "test_" + str(i)
        if i in metrics:
            logger.info(
                "Evaluation examples in dev dataset: {}({:.1f}%) | precision: {:.2f} | recall: {:.2f} | F1 score {:.2f}".format(
                    metrics[i]["support"],
                    100 * metrics[i]["support"] / len(dev_ds),
                    metrics[i]["precision"] * 100,
                    metrics[i]["recall"] * 100,
                    metrics[i]["f1-score"] * 100,
                )
            )
        else:
            logger.info("Evaluation examples in dev dataset: 0 (0%)")
        logger.info("----------------------------")
    with open(bad_case_path, "w", encoding="utf-8") as f:
        f.write("Text\tLabel\tPrediction\n")
        for i, (p, l) in enumerate(zip(pred_ids, label_ids)):
            p, l = int(p), int(l)
            if p != l:
                f.write(dev_ds.data[i]["text"] + "\t" + id2label[l] + "\t" + id2label[p] + "\n")
    logger.info("Bad case in dev dataset saved in {}".format(bad_case_path))
```
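To sanity-check these helpers outside the Trainer, one can tokenize the first training example directly. This is a sketch that assumes `./data/train.txt` exists in `text<TAB>label` format and that its first label appears in `label2id` (the two-entry mapping below is a hypothetical subset of `label.txt`):

```python
from paddlenlp.transformers import AutoTokenizer

from utils import preprocess_function, read_local_dataset

label2id = {"Alarm-Update": 0, "Audio-Play": 1}  # hypothetical subset of label.txt
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-tiny-medium-v2-zh")

# read_local_dataset is a generator; take the first example and tokenize it
example = next(read_local_dataset("./data/train.txt", label2id=label2id))
features = preprocess_function(example, tokenizer=tokenizer, max_length=128)
print(features["input_ids"], features["labels"])
```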
VII. Model Prediction
Use Taskflow for model prediction:
- Load the model
- Load the data
- Run prediction
1. Load the model and run a single prediction
```python
from paddlenlp import Taskflow

# Load the exported static model for prediction
cls = Taskflow("text_classification", task_path='checkpoint/export', is_static_model=True)
cls(["回放CCTV2的消费主张"])
```
```
[2023-04-11 17:42:26,315] [ INFO] - We are using <class 'paddlenlp.transformers.ernie.tokenizer.ErnieTokenizer'> to load 'checkpoint/export'.
W0411 17:42:26.472223 349 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0411 17:42:26.475904 349 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
[2023-04-11 17:42:29,395] [ INFO] - Load id2label from checkpoint/export/id2label.json.
[{'predictions': [{'label': 'TVProgram-Play', 'score': 0.9521104350237317}], 'text': '回放CCTV2的消费主张'}]
```
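Each input sentence maps to one dict in the returned list; the top label and its score can be unpacked as follows (a small sketch based on the output structure shown above):

```python
pred = cls(["回放CCTV2的消费主张"])[0]
label = pred["predictions"][0]["label"]  # 'TVProgram-Play'
score = pred["predictions"][0]["score"]  # ~0.95
print(label, score)
```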
2. Read the data to be predicted
Read the test data into a list:
```python
with open('data/test.txt', 'r') as file:
    mytests = file.readlines()
print(mytests[:3])
```
```
['回放CCTV2的消费主张\n', '给我打开玩具房的灯\n', '循环播放赵本山的小品相亲来听\n']
```
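Note that `readlines()` keeps the trailing `\n`, which is why the `text` fields in the batch output below end with a newline; stripping each line first keeps the text clean:

```python
# Variant that strips the trailing newline from each line
with open('data/test.txt', 'r') as file:
    mytests = [line.strip() for line in file]
print(mytests[:3])
```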
3. Batch prediction
```python
result = cls(mytests)
print(result[:3])
```
```
[{'predictions': [{'label': 'TVProgram-Play', 'score': 0.9521104350237317}], 'text': '回放CCTV2的消费主张\n'}, {'predictions': [{'label': 'HomeAppliance-Control', 'score': 0.9970951493859599}], 'text': '给我打开玩具房的灯\n'}, {'predictions': [{'label': 'Audio-Play', 'score': 0.9710607817649783}], 'text': '循环播放赵本山的小品相亲来听\n'}]
```
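As a quick sanity check on the batch output, you can tally the predicted label distribution (a small sketch over the `result` structure shown above):

```python
from collections import Counter

label_counts = Counter(r["predictions"][0]["label"] for r in result)
print(label_counts.most_common(5))
```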
4. Save results in the required format
```python
f = open('/home/aistudio/result.txt', 'w')
f.write("ID,Target\n")
for i in range(len(result)):
    f.write(f"{i+1},{result[i]['predictions'][0]['label']}\n")
f.close()
```
```
!head -n10 /home/aistudio/result.txt
```
```
ID,Target
1,TVProgram-Play
2,HomeAppliance-Control
3,Audio-Play
4,Alarm-Update
5,HomeAppliance-Control
6,FilmTele-Play
7,FilmTele-Play
8,Music-Play
9,Calendar-Query
```
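Equivalently, the save step can use a `with` block so the file is closed automatically even if an exception occurs mid-write:

```python
with open('/home/aistudio/result.txt', 'w') as f:
    f.write("ID,Target\n")
    for i, r in enumerate(result, start=1):
        f.write(f"{i},{r['predictions'][0]['label']}\n")
```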
VIII. Submitting the Results
- Project page (飞桨 AI Studio): 基于PaddleNLP的端到端智能家居对话意图识别
- GitHub: livingbody/Conversational_intention_recognition (dialogue intent recognition based on PaddleNLP)