【BERT-多标签文本分类实战】之七——训练-评估-测试与运行主程序

简介: 【BERT-多标签文本分类实战】之七——训练-评估-测试与运行主程序

·请参考本系列目录:【BERT-多标签文本分类实战】之一——实战项目总览

·下载本实战项目资源:>=点击此处=<

[1] 损失函数与评价指标


  多标签文本分类任务,用的损失函数是BCEWithLogitsLoss,不是交叉熵损失函数cross_entropy!!

BCEWithLogitsLosscross_entropy有什么区别?

+

1)cross_entropy它就是算单标签的损失的,大家去看一下它的公式,它对一个文本只取概率最大的那个标签;

+

2)BCEWithLogitsLoss对模型输出取的是sigmoid,而cross_entropy对模型的输出取的是softmax。sigmoid和softmax虽然都是把一组数据放缩到[0,1]区间,但是softmax具有排斥性,放缩后的一组数据之和为1,所以这样一组标签概率只会有一个较大值;而sigmoid也是把一组数据放缩到[0,1]区间,但它更类似于等比例缩放,原来大的数现在还大,可以有多个较大的概率存在,所以sigmoid更适合在多标签文本分类任务中。所以要使用BCEWithLogitsLoss。

  本次实战项目中使用的评价指标有:准确率accuracy、精确率precision、汉明损失hamming_loss。是基于sklearn库实现的。

# 计算多标签准确率、精确率、hm
def APH(y_true, y_pred):
    return metrics.accuracy_score(y_true, y_pred), \
           metrics.precision_score(y_true, y_pred, average='samples'), \
           metrics.hamming_loss(y_true, y_pred)

还有其他评价指标,召回率、F1等等,评价指标还分可为micro和macro,种类较多,可以参考地址:https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics

[2] 采样


  采样是指:把模型输出出来的概率,转化成独热数组,通常使用阈值为0.5的阈值函数,即概率大于0.5的标签采样为1,否则为0。本项目设置阈值为0.4、且只取2个标签。

# 预测多标签的输出,把概率值转化为独热数组
def Predict(outputs, alpha=0.4):
    predic = torch.sigmoid(outputs)
    zero = torch.zeros_like(predic)
    topk = torch.topk(predic, k=2, dim=1, largest=True)[1]
    for i, x in enumerate(topk):
        for y in x:
            if predic[i][y] > alpha:
                zero[i][y] = 1
    return zero.cpu()

[3] 训练


  训练代码如下:

def train(config, model, train_iter, dev_iter, test_iter, is_write):
    start_time = time.time()
    model.train()
    # 普通算法
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    # bert算法
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
    # BertAdam implements weight decay fix,
    # BertAdam doesn't compensate for bias as in the regular Adam optimizer.
    optimizer = AdamW(optimizer_grouped_parameters,lr=config.learning_rate,eps=1e-8)
    # 学习率指数衰减,每次epoch:学习率 = gamma * 学习率
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps = 0,
                                            num_training_steps = len(train_iter) * config.num_epochs)
    total_batch = 0  # 记录进行到多少batch
    dev_best_loss = float('inf')
    last_improve = 0  # 记录上次验证集loss下降的batch数
    flag = False  # 记录是否很久没有效果提升
    if is_write:
        writer = SummaryWriter(
            log_dir="{0}/{1}__{2}__{3}__{4}".format(config.log_path, config.batch_size, config.pad_size,
                                                         config.learning_rate, time.strftime('%m-%d_%H.%M', time.localtime())))
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            model.zero_grad()
            loss = Loss(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                # 每多少轮输出在训练集和验证集上的效果
                true = labels
                predic = Predict(outputs)
                train_oe = OneError(outputs, true)
                train_acc, train_pre, train_hl = APH(true.data.cpu().numpy(), predic.data.cpu().numpy())
                dev_acc, dev_pre, dev_hl, dev_oe, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train=== Loss: {1:>6.2}, Acc: {2:>6.2%}, Pre: {3:>6.2%}, HL: {4:>5.2} OE: {' \
                      '5:>6.2%}, Val=== Loss: {6:>5.2}, Acc: {7:>6.2%}, Pre: {8:>6.2%}, HL: {9:>5.2}, ' \
                      'OE: {10:>6.2%}, Time: {11} {12} '
                print(msg.format(total_batch, loss.item(), train_acc, train_pre, train_hl, train_oe,
                                 dev_loss, dev_acc, dev_pre, dev_hl, dev_oe, time_dif, improve))
                if is_write:
                    writer.add_scalar('loss/train', loss.item(), total_batch)
                    writer.add_scalar("acc/train", train_acc, total_batch)
                    writer.add_scalar("pre/train", train_pre, total_batch)
                    writer.add_scalar("oe/train", train_oe, total_batch)
                    writer.add_scalar("hamming loss/train", train_hl, total_batch)
                    writer.add_scalar("loss/dev", dev_loss, total_batch)
                    writer.add_scalar("acc/dev", dev_acc, total_batch)
                    writer.add_scalar("pre/dev", dev_pre, total_batch)
                    writer.add_scalar("oe/dev", dev_oe, total_batch)
                    writer.add_scalar("hamming loss/dev", dev_hl, total_batch)
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # 验证集loss超过1000batch没下降,结束训练
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        scheduler.step()  # 学习率衰减
        if flag:
            break
    if is_write:
        writer.close()
    return test(config, model, test_iter)

  需要解释的几点:

  1、bert模型采用AdamW做优化,不同层要设置不同的权重衰减值;

  2、writer这个变量主要是做数据可视化的,参考博客:【深度学习】pytorch使用tensorboard可视化实验数据

[4] 评估与测试


def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_pre, test_rec, test_hl, test_loss, test_report = evaluate(config, model, test_iter,
                                                                             test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}, Test Pre: {2:>6.2%}, Test HL: {3:>5.2}, Test OE: {4:>6.2%}'
    print(msg.format(test_loss, test_acc, test_pre, test_rec, test_hl))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
    return test_loss, test_acc, test_pre, test_rec, test_hl
def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = []
    labels_all = []
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            oe = OneError(outputs.data.cpu(), labels.data.cpu())
            loss = Loss(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predic = Predict(outputs.data)
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic.numpy())
    labels_all = labels_all.reshape(-1, config.num_classes)
    predict_all = predict_all.reshape(-1, config.num_classes)
    acc, pre, hl = APH(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=3)
        return acc, pre, hl, oe, loss_total / len(data_iter), report
    return acc, pre, hl, oe, loss_total / len(data_iter)

[5] 运行主程序run.py


if __name__ == '__main__':
    """配置参数
        dataSet     : 数据集名称. required.
        model_name  : 模型名称. required. 可选值['bert']
        is_write    : 是否开启tensorboard的记录绘图模式. 可选值[False, True]
    """
    M = ['bert','bert_RNN','bert_RCNN','bert_DPCNN']
    I = [False, True]
    dataSet = 'Reuters-21578'
    is_write = I[0]
    for model_name in M:
        x = import_module('models.' + model_name)
        config = x.Config(dataSet)
        # 设置numpy的随机种子,以使得结果是确定的
        np.random.seed(1)
        # 为CPU设置种子用于生成随机数,以使得结果是确定的
        torch.manual_seed(1)
        # 为当前GPU设置随机种子,以使得结果是确定的
        torch.cuda.manual_seed_all(1)
        # 保证每次结果一样
        torch.backends.cudnn.deterministic = True
        start_time = time.time()
        print("Loading data...")
        train_data, dev_data, test_data = build_dataset(config)
        train_iter = build_iterator(train_data, config)
        dev_iter = build_iterator(dev_data, config)
        test_iter = build_iterator(test_data, config)
        time_dif = get_time_dif(start_time)
        print("Time usage:", time_dif)
        # train
        model = x.Model(config).to(config.device)
        print(model.parameters)
        print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')
        train(config, model, train_iter, dev_iter, test_iter, is_write)

  代码还是比较好懂的,但是还是有一个整体能运行起来的项目体验更佳。

相关文章
|
1月前
|
存储 监控 网络协议
服务器压力测试是一种评估系统在极端条件下的表现和稳定性的技术
【10月更文挑战第11天】服务器压力测试是一种评估系统在极端条件下的表现和稳定性的技术
109 32
|
1月前
|
机器学习/深度学习 编解码 监控
目标检测实战(六): 使用YOLOv8完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)
这篇文章详细介绍了如何使用YOLOv8进行目标检测任务,包括环境搭建、数据准备、模型训练、验证测试以及模型转换等完整流程。
1225 1
目标检测实战(六): 使用YOLOv8完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)
|
1月前
|
机器学习/深度学习 监控 计算机视觉
目标检测实战(八): 使用YOLOv7完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)
本文介绍了如何使用YOLOv7进行目标检测,包括环境搭建、数据集准备、模型训练、验证、测试以及常见错误的解决方法。YOLOv7以其高效性能和准确率在目标检测领域受到关注,适用于自动驾驶、安防监控等场景。文中提供了源码和论文链接,以及详细的步骤说明,适合深度学习实践者参考。
312 0
目标检测实战(八): 使用YOLOv7完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)
|
1月前
|
机器学习/深度学习 并行计算 数据可视化
目标分类笔记(二): 利用PaddleClas的框架来完成多标签分类任务(从数据准备到训练测试部署的完整流程)
这篇文章介绍了如何使用PaddleClas框架完成多标签分类任务,包括数据准备、环境搭建、模型训练、预测、评估等完整流程。
86 0
目标分类笔记(二): 利用PaddleClas的框架来完成多标签分类任务(从数据准备到训练测试部署的完整流程)
|
1月前
|
机器学习/深度学习 JSON 算法
语义分割笔记(二):DeepLab V3对图像进行分割(自定义数据集从零到一进行训练、验证和测试)
本文介绍了DeepLab V3在语义分割中的应用,包括数据集准备、模型训练、测试和评估,提供了代码和资源链接。
186 0
语义分割笔记(二):DeepLab V3对图像进行分割(自定义数据集从零到一进行训练、验证和测试)
|
1月前
|
机器学习/深度学习 数据采集 算法
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
这篇博客文章介绍了如何使用包含多个网络和多种训练策略的框架来完成多目标分类任务,涵盖了从数据准备到训练、测试和部署的完整流程,并提供了相关代码和配置文件。
47 0
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
|
1月前
|
机器学习/深度学习 XML 并行计算
目标检测实战(七): 使用YOLOX完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)
这篇文章介绍了如何使用YOLOX完成图像目标检测任务的完整流程,包括数据准备、模型训练、验证和测试。
157 0
目标检测实战(七): 使用YOLOX完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)
|
1月前
|
弹性计算 网络协议 Linux
云服务器评估迁移时间与测试传输速度
云服务器评估迁移时间与测试传输速度
|
8天前
|
JSON Java 测试技术
SpringCloud2023实战之接口服务测试工具SpringBootTest
SpringBootTest同时集成了JUnit Jupiter、AssertJ、Hamcrest测试辅助库,使得更容易编写但愿测试代码。
37 3
|
1月前
|
JSON 算法 数据可视化
测试专项笔记(一): 通过算法能力接口返回的检测结果完成相关指标的计算(目标检测)
这篇文章是关于如何通过算法接口返回的目标检测结果来计算性能指标的笔记。它涵盖了任务描述、指标分析(包括TP、FP、FN、TN、精准率和召回率),接口处理,数据集处理,以及如何使用实用工具进行文件操作和数据可视化。文章还提供了一些Python代码示例,用于处理图像文件、转换数据格式以及计算目标检测的性能指标。
59 0
测试专项笔记(一): 通过算法能力接口返回的检测结果完成相关指标的计算(目标检测)