transformers+huggingface训练模型

本文涉及的产品
实时计算 Flink 版,5000CU*H 3个月
检索分析服务 Elasticsearch 版,2核4GB开发者规格 1个月
实时数仓Hologres,5000CU*H 100GB 3个月
简介: 本教程介绍了如何使用 Hugging Face 的 `transformers` 库训练一个 BERT 模型进行情感分析。主要内容包括:导入必要库、下载 Yelp 评论数据集、数据预处理、模型加载与配置、定义训练参数、评估指标、实例化训练器并开始训练,最后保存模型和训练状态。整个过程详细展示了如何利用预训练模型进行微调,以适应特定任务。

[TOC]

transformers+huggingface训练模型

导入必要的库

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import numpy as np
import evaluate

CopyInsert

  • 导入 datasets 用于加载数据集。
  • 导入 transformers 中的组件,以便使用预训练的 BERT 模型和 tokenizer。
  • 导入 numpy 用于数值计算。
  • 导入 evaluate 用于计算模型预测的指标(这里是准确率)。

数据集下载

dataset = load_dataset("yelp_review_full")

CopyInsert

  • 从 Hugging Face 的数据集中下载 Yelp 评论数据集,该数据集包含各种评论和意见。

数据预处理

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

CopyInsert

  • 使用预训练的 BERT model("bert-base-cased")初始化 tokenizer。
  • 定义 tokenize_function 函数,将评论文本编码成模型可接受的格式,设置填充和截断。

应用数据预处理

tokenized_datasets = dataset.map(tokenize_function, batched=True)

CopyInsert

  • 对下载的数据集应用 tokenize_function,批量处理文本数据。

数据抽样

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

CopyInsert

  • 从训练和测试集中各随机抽取 1000 条样本,以加快训练速度和验证模型性能。

模型加载与训练配置

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

CopyInsert

  • 加载预训练的 BERT 模型,并指定输出标签数(5个分类)。
model_dir = "models/bert-base-cased-finetune-yelp"

training_args = TrainingArguments(
    output_dir=model_dir,
    per_device_train_batch_size=16,
    num_train_epochs=5,
    logging_steps=100
)

CopyInsert

  • 定义模型保存路径和训练参数,如每个设备的训练批大小、训练轮数和日志记录的频率。

指标评估

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

CopyInsert

  • 加载准确率评估指标。
  • 定义 compute_metrics 函数,通过计算预测标签和真实标签的比较来评估模型性能。

实例化 Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics
)

CopyInsert

  • 创建训练器 Trainer 的实例,用于处理模型的训练过程和评估。

开始训练

trainer.train()

CopyInsert

  • 运行训练过程。

监控 GPU 使用

# 使用命令行工具: watch -n 1 nvidia-smi

CopyInsert

  • 提供了一个命令行工具提示,以监控 GPU 的使用情况。

保存模型和训练状态

trainer.save_model(model_dir)
trainer.save_state()

CopyInsert

  • 保存训练完成后的模型和状态,以便后续使用。
目录
相关文章
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch CIFAR10图像分类 Swin Transformer篇(一)
Pytorch CIFAR10图像分类 Swin Transformer篇(一)
|
6月前
|
机器学习/深度学习 数据可视化 算法
Pytorch CIFAR10图像分类 Swin Transformer篇(二)
Pytorch CIFAR10图像分类 Swin Transformer篇(二)
|
6月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
关于Tensorflow!目标检测预训练模型的迁移学习
这篇文章主要介绍了使用Tensorflow进行目标检测的迁移学习过程。关于使用Tensorflow进行目标检测模型训练的实战教程,涵盖了从数据准备到模型应用的全过程,特别适合对此领域感兴趣的开发者参考。
72 3
关于Tensorflow!目标检测预训练模型的迁移学习
|
6月前
|
机器学习/深度学习 数据采集 TensorFlow
TensorFlow与迁移学习:利用预训练模型
【4月更文挑战第17天】本文介绍了如何在TensorFlow中运用迁移学习,特别是利用预训练模型提升深度学习任务的性能和效率。迁移学习通过将源任务学到的知识应用于目标任务,减少数据需求、加速收敛并提高泛化能力。TensorFlow Hub提供预训练模型接口,可加载模型进行特征提取或微调。通过示例代码展示了如何加载InceptionV3模型、创建特征提取模型以及进行微调。在实践中,注意源任务与目标任务的相关性、数据预处理和模型调整。迁移学习是提升模型性能的有效方法,TensorFlow的工具使其变得更加便捷。
|
6月前
|
数据采集 自然语言处理
在ModelScope中进行情感分析模型的微调
在ModelScope中进行情感分析模型的微调
140 4
|
存储 缓存 自然语言处理
几个常见的小技巧加快Pytorch训练速度
几个常见的小技巧加快Pytorch训练速度
597 0
几个常见的小技巧加快Pytorch训练速度
|
机器学习/深度学习 存储 PyTorch
Huggingface:导出transformers模型到onnx
上一篇的初体验之后,本篇我们继续探索,将transformers模型导出到onnx。这里主要参考huggingface的官方文档:https://huggingface.co/docs/transformers/v4.20.1/en/serialization#exporting-a-model-to-onnx。
1219 0
YOLOV5模型转onnx并推理
YOLOV5模型转onnx并推理
899 1
|
机器学习/深度学习 算法 PyTorch
pytorch模型转ONNX、并进行比较推理
pytorch模型转ONNX、并进行比较推理
711 0
|
机器学习/深度学习 人工智能 编解码
Transformers回顾 :从BERT到GPT4
人工智能已成为近年来最受关注的话题之一,由于神经网络的发展,曾经被认为纯粹是科幻小说中的服务现在正在成为现实。从对话代理到媒体内容生成,人工智能正在改变我们与技术互动的方式。特别是机器学习 (ML) 模型在自然语言处理 (NLP) 领域取得了重大进展。一个关键的突破是引入了“自注意力”和用于序列处理的Transformers架构,这使得之前主导该领域的几个关键问题得以解决。
5053 0