重写transformers.Trainer的compute_metrics方法计算评价指标时,形参如何包含自定义的数据

简介:   这个问题苦恼我几个月,之前一直用替代方案。这次实在没替代方案了,transformers源码和文档看了一整天,终于在晚上12点找到了。。。

这个问题苦恼我几个月,之前一直用替代方案。这次实在没替代方案了,transformers源码和文档看了一整天,终于在晚上12点找到了。。。


解决方案:在TrainingArguments中指定label_names、remove_unused_columns、include_inputs_for_metrics三个参数


问题描述


 使用transformers.Trainer就图个快和优雅,它包装了一整套的训练逻辑,让我们不用从数据加载、模型训练、评估、预测、保存模型、计算评价指标等等一整套写完。


 但是显然,模型和任务一复杂的时候,loss的计算、评价指标的实现,我们还是需要重写的。于是问题就来了:


 在计算评价指标时,即重写compute_metrics方法时,形参pred是EvalPrediction类,然而它只提供三个变量:(predictions=all_preds, label_ids=all_labels, inputs=all_inputs),但是我们还需要其他的数据来算评价指标怎么办??


解决方案


  在TrainingArguments中添加以下三个参数


args = TrainingArguments(
  ...
    label_names=['labels','自定义数据名'],
    remove_unused_columns= False, # 在compute_loss 时需要额外输入
    include_inputs_for_metrics= True # compute_metrics 时需要原始输出来计算评价指标
)


 然后你会发现,compute_metrics的形参的label_ids存的就不知原始标签了,现在存的是元组,就是你指定的label_names里面的数据。


详细使用方法


  1、在构建输入的时候,除了PLM模型本身就需要的数据,还要有我们想使用的自定义数据,格式如下:


feature_dict = {
// PLM规定的输入
'input_ids': torch.tensor([f.input_ids for f in features], dtype=torch.long),
'attention_mask': torch.tensor([f.attention_mask for f in features], dtype=torch.long),
'token_type_ids': torch.tensor([f.token_type_ids for f in features], dtype=torch.long),
'labels': torch.tensor([f.labels for f in features], dtype=torch.long),
// 自定义的输入
'自定义名': torch.tensor([f.自定义名 for f in features], dtype=torch.long)
}


 2、在TrainingArguments中设置remove_unused_columns= False,意思是在重写compute_loos方法时,不会删除我们自定义的列。


 这样,在compute_loos方法中,我们就可以使用自定义的列的数据了。但是要注意在把输入喂给model的时候,要把自定义列摘出来,不然会报错:


def compute_loss(self, model, inputs, return_outputs=False):
    # 运行模型
    new_inputs = {k:v for k, v in inputs.items() if k not in ['自定义名']}
    outputs = model(**new_inputs)
    ...


  3、在TrainingArguments中设置label_names=['labels','自定义数据名'],意思是在重写compute_metrics方法时,形参的label_ids属性会存入我们设置的那些列。使用方法:


# 重写评价指标计算
def compute_metrics(pred):
    labels, 自定义数据名 = pred.label_ids
    ...


Prompt最近这么火,一个方向的朋友一定会出现和我一样的问题,看到这篇帖子麻烦评论个1,哈哈哈哈

相关文章
|
2月前
|
机器学习/深度学习 IDE API
【Tensorflow+keras】Keras 用Class类封装的模型如何调试call子函数的模型内部变量
该文章介绍了一种调试Keras中自定义Layer类的call方法的方法,通过直接调用call方法并传递输入参数来进行调试。
22 4
|
3月前
|
算法
创建一个训练函数
【7月更文挑战第22天】创建一个训练函数。
26 4
|
2月前
|
XML JSON 计算机视觉
LangChain 构建问题之智能代理类型中的“预期模型类型”的定义如何解决
LangChain 构建问题之智能代理类型中的“预期模型类型”的定义如何解决
23 0
|
3月前
|
机器学习/深度学习 计算机视觉 Python
`GridSearchCV` 是一种穷举搜索方法,它会对指定的参数网格中的每一个参数组合进行交叉验证,并返回最优的参数组合。
`GridSearchCV` 是一种穷举搜索方法,它会对指定的参数网格中的每一个参数组合进行交叉验证,并返回最优的参数组合。
YOLOv8打印模型结构配置信息并查看网络模型详细参数:参数量、计算量(GFLOPS)
YOLOv8打印模型结构配置信息并查看网络模型详细参数:参数量、计算量(GFLOPS)
|
5月前
|
关系型数据库 MySQL 数据库
定义模型和模型配置
定义模型和模型配置。
36 1
|
11月前
|
机器学习/深度学习 PyTorch 算法框架/工具
base model初始化large model,造成的参数矩阵对不上权重不匹配问题+修改预训练权重形状和上采样
base model初始化large model,造成的参数矩阵对不上权重不匹配问题+修改预训练权重形状和上采样
186 0
|
机器学习/深度学习 编解码 人工智能
ATC 模型转换动态 shape 问题案例
ATC(Ascend Tensor Compiler)是异构计算架构 CANN 体系下的模型转换工具:它可以将开源框架的网络模型(如 TensorFlow 等)以及 Ascend IR 定义的单算子描述文件转换为昇腾 AI 处理器支持的离线模型;模型转换过程中,ATC 会进行算子调度优化、权重数据重排、内存使用优化等具体操作,对原始的深度学习模型进行进一步的调优,从而满足部署场景下的高性能需求,使其能够高效执行在昇腾 AI 处理器上。
201 0
|
机器学习/深度学习 人工智能 PyTorch