使用数据集中文诗词数据集对预训练GPT-3预训练生成模型-中文-base 二次训练生成的模型进行评估,代码如下:
#!/usr/bin/python3
# -*- coding=utf-8 -*-
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.metainfo import Trainers
def cfg_modify_fn(cfg):
cfg.evaluation.metrics = {
'type': 'text-gen-metric',
# 'target_text': 'input_ids',
}
return cfg
dataset_dict = MsDataset.load('chinese-poetry-collection')
eval_dataset = dataset_dict['test'].remap_columns({'text1': 'src_txt'})
print(eval_dataset)
max_epochs = 10
tmp_dir = "./gpt3_poetry"
# 构造 trainer 并进行评估
kwargs = dict(
# 由于使用的模型训练后的目录,因此不需要传入cfg_modify_fn
model=f'{tmp_dir}/output',
eval_dataset=eval_dataset,
cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(name=Trainers.nlp_base_trainer, default_args=kwargs)
trainer.evaluate()
报错如下:
Traceback (most recent call last):
File "test2.py", line 81, in <module>
result = trainer.evaluate()
File "/root/anaconda3/envs/modelscope/lib/python3.7/site-packages/modelscope/trainers/trainer.py", line 623, in evaluate
metric_classes)
File "/root/anaconda3/envs/modelscope/lib/python3.7/site-packages/modelscope/trainers/trainer.py", line 1091, in evaluation_loop
data_loader_iters=self._eval_iters_per_epoch)
File "/root/anaconda3/envs/modelscope/lib/python3.7/site-packages/modelscope/trainers/utils/inference.py", line 56, in single_gpu_test
evaluate_batch(trainer, data, metric_classes, vis_closure)
File "/root/anaconda3/envs/modelscope/lib/python3.7/site-packages/modelscope/trainers/utils/inference.py", line 182, in evaluate_batch
metric_cls.add(batch_result, data)
File "/root/anaconda3/envs/modelscope/lib/python3.7/site-packages/modelscope/metrics/text_generation_metric.py", line 36, in add
ground_truths = inputs[self.target_text]
KeyError: 'tgts'