训练你自己的自然语言处理深度学习模型,Bert预训练模型下游任务训练:情感二分类

简介: 训练你自己的自然语言处理深度学习模型,Bert预训练模型下游任务训练:情感二分类

基础介绍:

Bert模型是一个通用backbone,可以简单理解为一个句子的特征提取工具

更直观来看:我们的自然语言是用各种文字表示的,经过编码器,以及特征提取就可以变为计算机能理解的语言了

下游任务:

提取特征后,我们便可以自定义其他自然语言处理任务了,以下是一个简单的示例(效果可能不好,但算是一个基本流程)

数据格式:

模型训练:

我们来训练处理句子情感分类的模型,代码如下

import torch
from tqdm import tqdm  # 进度条库
from transformers import AdamW  # 优化器
import pandas as pd  # 文件读取
from transformers import BertTokenizer, BertModel  # 导入分词器和模型
# 导入数据
data = pd.read_csv("data/data.csv")
# 定义编码器
token = BertTokenizer.from_pretrained("bert-base-chinese")
# 加载预训练模型
pretrained = BertModel.from_pretrained("bert-base-chinese")
# 创建编码集
encode = []
# 编码句子
for i in tqdm(data["sentence"]):
    out = token.batch_encode_plus(
        batch_text_or_text_pairs=[i],
        truncation=True,
        padding='max_length',
        max_length=17,
        return_tensors='pt',
        return_length=True
    )
    encode.append(out)
# 定义模型
class MODEL(torch.nn.Module):
    def __init__(self):
        super().__init__()  # 确保调用父类构造函数
        self.linear1 = torch.nn.Linear(768, 2)
    def forward(self, input_ids, attention_mask, token_type_ids):
        result = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        result = self.linear1(result.last_hidden_state[:, 0])
        result = result.softmax(dim=1)
        return result
# 创建模型对象
model = MODEL()
# 定义优化器
optimizer = AdamW(model.parameters(), lr=5e-4)
# 定义损失函数
criterion = torch.nn.CrossEntropyLoss()
# 模型训练
for i in range(len(encode)):
    out = model(encode[i]["input_ids"], encode[i]["attention_mask"], encode[i]["token_type_ids"])
    loss = criterion(out, torch.LongTensor([data["label"][i]]))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
# 模型权重保存
torch.save(model.state_dict(), 'model1_weights.pth')

运行后得到了训练后的模型权重文件

模型使用:

可用以下代码进行判断句子情感

import torch
from transformers import BertTokenizer, BertModel
token = BertTokenizer.from_pretrained('bert-base-chinese')
pretrained = BertModel.from_pretrained('bert-base-chinese')
# 定义模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)
    def forward(self, input_ids, attention_mask, token_type_ids):
        out = pretrained(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        out = self.fc(out.last_hidden_state[:, 0])
        out = out.softmax(dim=1)
        return out
model = Model()
# 加载训练好的模型权重
model.load_state_dict(torch.load('model1_weights.pth'))
sentence = ["衣服一点也不好,差评"]
# 编码
o = token.batch_encode_plus(
        batch_text_or_text_pairs=sentence,
        truncation=True,
        padding='max_length',
        max_length=17,
        return_tensors='pt'
    )
out = model(o['input_ids'], o['attention_mask'], o['token_type_ids'])
if out[0][0] > out[0][1]:
    print("好评")
else:
    print("差评")


相关文章
|
机器学习/深度学习 编解码 人工智能
人脸表情[七种表情]数据集(15500张图片已划分、已标注)|适用于YOLO系列深度学习分类检测任务【数据集分享】
本数据集包含15,500张已划分、已标注的人脸表情图像,覆盖惊讶、恐惧、厌恶、高兴、悲伤、愤怒和中性七类表情,适用于YOLO系列等深度学习模型的分类与检测任务。数据集结构清晰,分为训练集与测试集,支持多种标注格式转换,适用于人机交互、心理健康、驾驶监测等多个领域。
|
3月前
|
机器学习/深度学习 人工智能 文字识别
中药材图像识别数据集(100类,9200张)|适用于YOLO系列深度学习分类检测任务
本数据集包含9200张中药材图像,覆盖100种常见品类,已标注并划分为训练集与验证集,支持YOLO等深度学习模型。适用于中药分类、目标检测、AI辅助识别及教学应用,助力中医药智能化发展。
|
5月前
|
机器学习/深度学习 人工智能 监控
河道塑料瓶识别标准数据集 | 科研与项目必备(图片已划分、已标注)| 适用于YOLO系列深度学习分类检测任务【数据集分享】
随着城市化进程加快和塑料制品使用量增加,河道中的塑料垃圾问题日益严重。塑料瓶作为河道漂浮垃圾的主要类型,不仅破坏水体景观,还威胁水生生态系统的健康。传统的人工巡查方式效率低、成本高,难以满足实时监控与治理的需求。
|
5月前
|
机器学习/深度学习 传感器 人工智能
火灾火焰识别数据集(2200张图片已划分、已标注)|适用于YOLO系列深度学习分类检测任务【数据集分享】
在人工智能和计算机视觉的快速发展中,火灾检测与火焰识别逐渐成为智慧城市、公共安全和智能监控的重要研究方向。一个高质量的数据集往往是推动相关研究的核心基础。本文将详细介绍一个火灾火焰识别数据集,该数据集共包含 2200 张图片,并已按照 训练集(train)、验证集(val)、测试集(test) 划分,同时配有对应的标注文件,方便研究者快速上手模型训练与评估。
火灾火焰识别数据集(2200张图片已划分、已标注)|适用于YOLO系列深度学习分类检测任务【数据集分享】
|
5月前
|
机器学习/深度学习 人工智能 自动驾驶
7种交通场景数据集(千张图片已划分、已标注)|适用于YOLO系列深度学习分类检测任务【数据集分享】
在智能交通与自动驾驶技术快速发展的今天,如何高效、准确地感知道路环境已经成为研究与应用的核心问题。车辆、行人和交通信号灯作为城市交通系统的关键元素,对道路安全与交通效率具有直接影响。然而,真实道路场景往往伴随 复杂光照、遮挡、多目标混杂以及交通信号状态多样化 等挑战,使得视觉识别与检测任务难度显著增加。
|
5月前
|
机器学习/深度学习 人工智能 监控
坐姿标准好坏姿态数据集(图片已划分、已标注)|适用于YOLO系列深度学习分类检测任务【数据集分享】
坐姿标准好坏姿态数据集的发布,填补了计算机视觉领域在“细分健康行为识别”上的空白。它不仅具有研究价值,更在实际应用层面具备广阔前景。从青少年的健康教育,到办公室的智能提醒,再到驾驶员的安全监控和康复训练,本数据集都能发挥巨大的作用。
坐姿标准好坏姿态数据集(图片已划分、已标注)|适用于YOLO系列深度学习分类检测任务【数据集分享】
|
5月前
|
机器学习/深度学习 数据采集 算法
PCB电路板缺陷检测数据集(近千张图片已划分、已标注)| 适用于YOLO系列深度学习检测任务【数据集分享】
在现代电子制造中,印刷电路板(PCB)是几乎所有电子设备的核心组成部分。随着PCB设计复杂度不断增加,人工检测PCB缺陷不仅效率低,而且容易漏检或误判。因此,利用计算机视觉和深度学习技术对PCB缺陷进行自动检测成为行业发展的必然趋势。
PCB电路板缺陷检测数据集(近千张图片已划分、已标注)| 适用于YOLO系列深度学习检测任务【数据集分享】
|
5月前
|
机器学习/深度学习 编解码 人工智能
102类农业害虫数据集(20000张图片已划分、已标注)|适用于YOLO系列深度学习分类检测任务【数据集分享】
在现代农业发展中,病虫害监测与防治 始终是保障粮食安全和提高农作物产量的关键环节。传统的害虫识别主要依赖人工观察与统计,不仅效率低下,而且容易受到主观经验、环境条件等因素的影响,导致识别准确率不足。
|
机器学习/深度学习 人工智能 监控
单车、共享单车已标注数据集(图片已划分、已标注)|适用于深度学习检测任务【数据集分享】
数据是人工智能的“燃料”。一个高质量、标注精准的单车与共享单车数据集,不仅能够推动学术研究的进步,还能为智慧交通、智慧城市的建设提供有力支撑。 在计算机视觉领域,研究者们常常会遇到“数据鸿沟”问题:公开数据集与真实业务需求之间存在不匹配。本次分享的数据集正是为了弥补这一不足,使得研究人员与工程师能够快速切入单车检测领域,加速模型从实验室走向真实应用场景。
|
5月前
|
机器学习/深度学习 自动驾驶 算法
道路表面缺陷数据集(裂缝/井盖/坑洼)(6000张图片已划分、已标注)|适用于YOLO系列深度学习分类检测任务【数据集分享】
随着城市化与交通运输业的快速发展,道路基础设施的健康状况直接关系到出行安全与城市运行效率。长期高强度的使用、气候变化以及施工质量差异,都会导致道路表面出现裂缝、坑洼、井盖下沉及修补不良等缺陷。这些问题不仅影响驾驶舒适度,还可能引发交通事故,增加道路养护成本。
道路表面缺陷数据集(裂缝/井盖/坑洼)(6000张图片已划分、已标注)|适用于YOLO系列深度学习分类检测任务【数据集分享】

热门文章

最新文章