重写transformers.Trainer的compute_metrics方法计算评价指标时,形参如何包含自定义的数据

简介:   这个问题苦恼我几个月,之前一直用替代方案。这次实在没替代方案了,transformers源码和文档看了一整天,终于在晚上12点找到了。。。

这个问题苦恼我几个月,之前一直用替代方案。这次实在没替代方案了,transformers源码和文档看了一整天,终于在晚上12点找到了。。。


解决方案:在TrainingArguments中指定label_names、remove_unused_columns、include_inputs_for_metrics三个参数


问题描述


 使用transformers.Trainer就图个快和优雅,它包装了一整套的训练逻辑,让我们不用从数据加载、模型训练、评估、预测、保存模型、计算评价指标等等一整套写完。


 但是显然,模型和任务一复杂的时候,loss的计算、评价指标的实现,我们还是需要重写的。于是问题就来了:


 在计算评价指标时,即重写compute_metrics方法时,形参pred是EvalPrediction类,然而它只提供三个变量:(predictions=all_preds, label_ids=all_labels, inputs=all_inputs),但是我们还需要其他的数据来算评价指标怎么办??


解决方案


  在TrainingArguments中添加以下三个参数


args = TrainingArguments(
  ...
    label_names=['labels','自定义数据名'],
    remove_unused_columns= False, # 在compute_loss 时需要额外输入
    include_inputs_for_metrics= True # compute_metrics 时需要原始输出来计算评价指标
)


 然后你会发现,compute_metrics的形参的label_ids存的就不知原始标签了,现在存的是元组,就是你指定的label_names里面的数据。


详细使用方法


  1、在构建输入的时候,除了PLM模型本身就需要的数据,还要有我们想使用的自定义数据,格式如下:


feature_dict = {
// PLM规定的输入
'input_ids': torch.tensor([f.input_ids for f in features], dtype=torch.long),
'attention_mask': torch.tensor([f.attention_mask for f in features], dtype=torch.long),
'token_type_ids': torch.tensor([f.token_type_ids for f in features], dtype=torch.long),
'labels': torch.tensor([f.labels for f in features], dtype=torch.long),
// 自定义的输入
'自定义名': torch.tensor([f.自定义名 for f in features], dtype=torch.long)
}


 2、在TrainingArguments中设置remove_unused_columns= False,意思是在重写compute_loos方法时,不会删除我们自定义的列。


 这样,在compute_loos方法中,我们就可以使用自定义的列的数据了。但是要注意在把输入喂给model的时候,要把自定义列摘出来,不然会报错:


def compute_loss(self, model, inputs, return_outputs=False):
    # 运行模型
    new_inputs = {k:v for k, v in inputs.items() if k not in ['自定义名']}
    outputs = model(**new_inputs)
    ...


  3、在TrainingArguments中设置label_names=['labels','自定义数据名'],意思是在重写compute_metrics方法时,形参的label_ids属性会存入我们设置的那些列。使用方法:


# 重写评价指标计算
def compute_metrics(pred):
    labels, 自定义数据名 = pred.label_ids
    ...


Prompt最近这么火,一个方向的朋友一定会出现和我一样的问题,看到这篇帖子麻烦评论个1,哈哈哈哈

相关文章
|
3月前
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法
|
6月前
|
机器学习/深度学习 计算机视觉 Python
`GridSearchCV` 是一种穷举搜索方法,它会对指定的参数网格中的每一个参数组合进行交叉验证,并返回最优的参数组合。
`GridSearchCV` 是一种穷举搜索方法,它会对指定的参数网格中的每一个参数组合进行交叉验证,并返回最优的参数组合。
YOLOv8打印模型结构配置信息并查看网络模型详细参数:参数量、计算量(GFLOPS)
YOLOv8打印模型结构配置信息并查看网络模型详细参数:参数量、计算量(GFLOPS)
yolov8在进行目标追踪时,model.track()中persist参数的含义
yolov8在进行目标追踪时,model.track()中persist参数的含义
|
8月前
|
数据库 Python
定义模型
定义模型。
77 1
|
8月前
|
JavaScript
函数形状有几种定义方式;操作符infer的作用
函数形状有几种定义方式;操作符infer的作用
55 3
|
机器学习/深度学习 并行计算 异构计算
在torch中,x变量数据经过处理后,变成y变量数据,再传入神经网络,数据是在最开始x上传给gpu还是将y传给gpu?
数据应该在经过处理后变成y变量数据后再传入神经网络,并将其上传到GPU。这样可以确保在传递数据时只传输必要的信息,从而减少内存使用和计算时间,并且在处理后的数据上进行操作可以更好地利用GPU的并行计算能力。
140 0
|
机器学习/深度学习 人工智能 PyTorch
|
PyTorch 数据处理 算法框架/工具
pytorch中自定义数据集加载对象重写Dataset
pytorch中自定义数据集加载对象重写Dataset
329 0
pytorch中自定义数据集加载对象重写Dataset
|
缓存 Python