手把手教你搭建Bert文本分类模型,快点看过来吧!

简介: 手把手教你搭建Bert文本分类模型,快点看过来吧!

1 赛题名称


基于文本挖掘的企业隐患排查质量分析模型


90.png


2 赛题背景


企业自主填报安全生产隐患,对于将风险消除在事故萌芽阶段具有重要意义。企业在填报隐患时,往往存在不认真填报的情况,“虚报、假报”隐患内容,增大了企业监管的难度。采用大数据手段分析隐患内容,找出不切实履行主体责任的企业,向监管部门进行推送,实现精准执法,能够提高监管手段的有效性,增强企业安全责任意识。


3 赛题任务


本赛题提供企业填报隐患数据,参赛选手需通过智能化手段识别其中是否存在“虚报、假报”的情况

看清赛题很关键,大家需要好好理解赛题目标之后,再去做题,可以避免很多弯路。

数据简介

本赛题数据集为脱敏后的企业填报自查隐患记录。


4 数据说明


训练集数据包含“【id、level_1(一级标准)、level_2(二级标准)、level_3(三级标准)、level_4(四级标准)、content(隐患内容)和label(标签)】”共7个字段。

其中“id”为主键,无业务意义;“一级标准、二级标准、三级标准、四级标准”为《深圳市安全隐患自查和巡查基本指引(2016年修订版)》规定的排查指引,一级标准对应不同隐患类型,二至四级标准是对一级标准的细化,企业自主上报隐患时,根据不同类型隐患的四级标准开展隐患自查工作;“隐患内容”为企业上报的具体隐患;“标签”标识的是该条隐患的合格性,“1”表示隐患填报不合格,“0”表示隐患填报合格。

预测结果文件results.csv


列名 说明
id 企业号
label 正负样本分类


  • 文件名:results.csv,utf-8编码
  • 参赛者以csv/json等文件格式,提交模型结果,平台进行在线评分,实时排名。


5 评测标准


本赛题采用F1 -score作为模型评判标准。


91.png


精确率P、召回率 R和 F1-score计算公式如下所示:


92.png


6 数据分析


  • 查看数据集

    93.png


训练集数据包含“【id、level_1(一级标准)、level_2(二级标准)、level_3(三级标准)、level_4(四级标准)、content(隐患内容)和label(标签)】”共7个字段。测试集没有label字段


  • 标签分布
    我们看下数据标签数量分布,看看有多少在划水哈哈_

sns.countplot(train.label)
plt.xlabel('label count')


94.png


在训练集12000数据中,其中隐患填报合格的有10712条,隐患填报不合格的有1288条,差不多是9:1的比例,说明我们分类任务标签分布式极其不均衡的。

  • 文本长度分布
    我们将level_content的文本拼接在一起

train['text']=train['content']+' '+train['level_1']+' '+train['level_2']+' '+train['level_3']+' '+train['level_4']
test['text']=test['content']+' '+test['level_1']+' '+test['level_2']+' '+test['level_3']+' '+test['level_4']
train['text_len']=train['text'].map(len)
test['text'].map(len).describe()


然后查看下文本最大长度分布

count    18000.000000
mean        64.762167
std         22.720117
min         27.000000
25%         50.000000
50%         60.000000
75%         76.000000
max        504.000000
Name: text, dtype: float64

train['text_len'].plot(kind='kde')


95.png


7 基于BERT的企业隐患排查质量分析模型


完整代码可以联系作者获取


7.1 导入工具包

import random
import numpy as np
import pandas as pd
from bert4keras.backend import keras, set_gelu
from bert4keras.tokenizers import Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from keras.layers import Lambda, Dense

Using TensorFlow backend.


7.2 设置参数

set_gelu('tanh')  # 切换gelu版本

num_classes = 2
maxlen = 128
batch_size = 32
config_path = '../model/albert_small_zh_google/albert_config_small_google.json'
checkpoint_path = '../model/albert_small_zh_google/albert_model.ckpt'
dict_path = '../model/albert_small_zh_google/vocab.txt'
# 建立分词器
tokenizer = Tokenizer(dict_path, do_lower_case=True)


7.3 定义模型

# 加载预训练模型
bert = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='albert',
    return_keras_model=False,
)

output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
output = Dense(
    units=num_classes,
    activation='softmax',
    kernel_initializer=bert.initializer
)(output)
model = keras.models.Model(bert.model.input, output)
model.summary()

Model: "model_2"
    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to                     
    ==================================================================================================
    Input-Token (InputLayer)        (None, None)         0                                            
    __________________________________________________________________________________________________
    Input-Segment (InputLayer)      (None, None)         0                                            
    __________________________________________________________________________________________________
    Embedding-Token (Embedding)     (None, None, 128)    2704384     Input-Token[0][0]                
    __________________________________________________________________________________________________
    Embedding-Segment (Embedding)   (None, None, 128)    256         Input-Segment[0][0]              
    __________________________________________________________________________________________________
    Embedding-Token-Segment (Add)   (None, None, 128)    0           Embedding-Token[0][0]            
                                                                     Embedding-Segment[0][0]          
    __________________________________________________________________________________________________
    Embedding-Position (PositionEmb (None, None, 128)    65536       Embedding-Token-Segment[0][0]    
    __________________________________________________________________________________________________
    Embedding-Norm (LayerNormalizat (None, None, 128)    256         Embedding-Position[0][0]         
    __________________________________________________________________________________________________
    Embedding-Mapping (Dense)       (None, None, 384)    49536       Embedding-Norm[0][0]             
    __________________________________________________________________________________________________
    Transformer-MultiHeadSelfAttent (None, None, 384)    591360      Embedding-Mapping[0][0]          
                                                                     Embedding-Mapping[0][0]          
                                                                     Embedding-Mapping[0][0]          
                                                                     Transformer-FeedForward-Norm[0][0
                                                                     Transformer-FeedForward-Norm[0][0
                                                                     Transformer-FeedForward-Norm[0][0
                                                                     Transformer-FeedForward-Norm[1][0
                                                                     Transformer-FeedForward-Norm[1][0
                                                                     Transformer-FeedForward-Norm[1][0
                                                                     Transformer-FeedForward-Norm[2][0
                                                                     Transformer-FeedForward-Norm[2][0
                                                                     Transformer-FeedForward-Norm[2][0
                                                                     Transformer-FeedForward-Norm[3][0
                                                                     Transformer-FeedForward-Norm[3][0
                                                                     Transformer-FeedForward-Norm[3][0
                                                                     Transformer-FeedForward-Norm[4][0
                                                                     Transformer-FeedForward-Norm[4][0
                                                                     Transformer-FeedForward-Norm[4][0
    __________________________________________________________________________________________________
    Transformer-MultiHeadSelfAttent (None, None, 384)    0           Embedding-Mapping[0][0]          
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward-Norm[0][0
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward-Norm[1][0
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward-Norm[2][0
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward-Norm[3][0
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward-Norm[4][0
                                                                     Transformer-MultiHeadSelfAttentio
    __________________________________________________________________________________________________
    Transformer-MultiHeadSelfAttent (None, None, 384)    768         Transformer-MultiHeadSelfAttentio
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-MultiHeadSelfAttentio
    __________________________________________________________________________________________________
    Transformer-FeedForward (FeedFo (None, None, 384)    1181568     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-MultiHeadSelfAttentio
    __________________________________________________________________________________________________
    Transformer-FeedForward-Add (Ad (None, None, 384)    0           Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward[0][0]    
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward[1][0]    
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward[2][0]    
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward[3][0]    
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward[4][0]    
                                                                     Transformer-MultiHeadSelfAttentio
                                                                     Transformer-FeedForward[5][0]    
    __________________________________________________________________________________________________
    Transformer-FeedForward-Norm (L (None, None, 384)    768         Transformer-FeedForward-Add[0][0]
                                                                     Transformer-FeedForward-Add[1][0]
                                                                     Transformer-FeedForward-Add[2][0]
                                                                     Transformer-FeedForward-Add[3][0]
                                                                     Transformer-FeedForward-Add[4][0]
                                                                     Transformer-FeedForward-Add[5][0]
    __________________________________________________________________________________________________
    CLS-token (Lambda)              (None, 384)          0           Transformer-FeedForward-Norm[5][0
    __________________________________________________________________________________________________
    dense_7 (Dense)                 (None, 2)            770         CLS-token[0][0]                  
    ==================================================================================================
    Total params: 4,595,202
    Trainable params: 4,595,202
    Non-trainable params: 0
    __________________________________________________________________________________________________

# 派生为带分段线性学习率的优化器。
# 其中name参数可选,但最好填入,以区分不同的派生优化器。
# AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR')
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=Adam(1e-5),  # 用足够小的学习率
#     optimizer=AdamLR(learning_rate=1e-4, lr_schedule={
#         1000: 1,
#         2000: 0.1
#     }),
    metrics=['accuracy'],
)


7.4 生成数据

def load_data(valid_rate=0.3):
    train_file = "../data/train.csv"
    test_file = "../data/test.csv"
    df_train_data = pd.read_csv("../data/train.csv")
    df_test_data = pd.read_csv("../data/test.csv")
    train_data, valid_data, test_data = [], [], []
    for row_i, data in df_train_data.iterrows():
        id, level_1, level_2, level_3, level_4, content, label = data
        id, text, label = id, str(level_1) + '\t' + str(level_2) + '\t' + \
        str(level_3) + '\t' + str(level_4) + '\t' + str(content), label
        if random.random() > valid_rate:
            train_data.append( (id, text, int(label)) )
        else:
            valid_data.append( (id, text, int(label)) )
    for row_i, data in df_test_data.iterrows():
        id, level_1, level_2, level_3, level_4, content = data
        id, text, label = id, str(level_1) + '\t' + str(level_2) + '\t' + \
        str(level_3) + '\t' + str(level_4) + '\t' + str(content), 0
        test_data.append( (id, text, int(label)) )
    return train_data, valid_data, test_data

train_data, valid_data, test_data = load_data(valid_rate=0.3)

valid_data

[(5,
      '工业/危化品类(现场)—2016版\t(一)消防检查\t2、防火检查\t8、易燃易爆危险物品和场所防火防爆措施的落实情况以及其他重要物资的防火安全情况;\t防爆柜里面稀释剂,机油费混装',
      0),
     (3365,
      '三小场所(现场)—2016版\t(一)消防安全\t2、消防通道和疏散\t2、疏散通道、安全出口设置应急照明灯和疏散指示标志。\t4楼消防楼梯安全出口指示牌坏',
      0),
     ...]

len(train_data)

8403

class data_generator(DataGenerator):
    """数据生成器
    """
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids, batch_labels = [], [], []
        for is_end, (id, text, label) in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_labels.append([label])
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_labels = sequence_padding(batch_labels)
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids, batch_segment_ids, batch_labels = [], [], []

# 转换数据集
train_generator = data_generator(train_data, batch_size)
valid_generator = data_generator(valid_data, batch_size)

valid_data

[(5,
      '工业/危化品类(现场)—2016版\t(一)消防检查\t2、防火检查\t8、易燃易爆危险物品和场所防火防爆措施的落实情况以及其他重要物资的防火安全情况;\t防爆柜里面稀释剂,机油费混装',
      0),
     (8,
      '工业/危化品类(现场)—2016版\t(一)消防检查\t2、防火检查\t2、安全疏散通道、疏散指示标志、应急照明和安全出口情况;\t已整改',
      1),
     (3365,
      '三小场所(现场)—2016版\t(一)消防安全\t2、消防通道和疏散\t2、疏散通道、安全出口设置应急照明灯和疏散指示标志。\t4楼消防楼梯安全出口指示牌坏',
      0),
     ...]


7.5 训练和验证

evaluator = Evaluator()
model.fit(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=2,
        callbacks=[evaluator]
    )

model.load_weights('best_model.weights')
# print(u'final test acc: %05f\n' % (evaluate(test_generator)))
print(u'final test acc: %05f\n' % (evaluate(valid_generator)))

final test acc: 0.981651

print(u'final test acc: %05f\n' % (evaluate(train_generator)))


完整代码可以联系作者获取


相关文章
|
1月前
|
自然语言处理 PyTorch 算法框架/工具
掌握从零到一的进阶攻略:让你轻松成为BERT微调高手——详解模型微调全流程,含实战代码与最佳实践秘籍,助你应对各类NLP挑战!
【10月更文挑战第1天】随着深度学习技术的进步,预训练模型已成为自然语言处理(NLP)领域的常见实践。这些模型通过大规模数据集训练获得通用语言表示,但需进一步微调以适应特定任务。本文通过简化流程和示例代码,介绍了如何选择预训练模型(如BERT),并利用Python库(如Transformers和PyTorch)进行微调。文章详细说明了数据准备、模型初始化、损失函数定义及训练循环等关键步骤,并提供了评估模型性能的方法。希望本文能帮助读者更好地理解和实现模型微调。
70 2
掌握从零到一的进阶攻略:让你轻松成为BERT微调高手——详解模型微调全流程,含实战代码与最佳实践秘籍,助你应对各类NLP挑战!
|
1月前
|
机器学习/深度学习 自然语言处理 知识图谱
|
1月前
|
机器学习/深度学习 自然语言处理 算法
[大语言模型-工程实践] 手把手教你-基于BERT模型提取商品标题关键词及优化改进
[大语言模型-工程实践] 手把手教你-基于BERT模型提取商品标题关键词及优化改进
109 0
|
2月前
|
搜索推荐 算法
模型小,还高效!港大最新推荐系统EasyRec:零样本文本推荐能力超越OpenAI、Bert
【9月更文挑战第21天】香港大学研究者开发了一种名为EasyRec的新推荐系统,利用语言模型的强大文本理解和生成能力,解决了传统推荐算法在零样本学习场景中的局限。EasyRec通过文本-行为对齐框架,结合对比学习和协同语言模型调优,提升了推荐准确性。实验表明,EasyRec在多个真实世界数据集上的表现优于现有模型,但其性能依赖高质量文本数据且计算复杂度较高。论文详见:http://arxiv.org/abs/2408.08821
63 7
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
【AI大模型】BERT模型:揭秘LLM主要类别架构(上)
【AI大模型】BERT模型:揭秘LLM主要类别架构(上)
|
3月前
|
机器学习/深度学习 存储 自然语言处理
【NLP-新闻文本分类】3 Bert模型的对抗训练
详细介绍了使用BERT模型进行新闻文本分类的过程,包括数据集预处理、使用预处理数据训练BERT语料库、加载语料库和词典后用原始数据训练BERT模型,以及模型测试。
68 1
|
3月前
|
算法 异构计算
自研分布式训练框架EPL问题之帮助加速Bert Large模型的训练如何解决
自研分布式训练框架EPL问题之帮助加速Bert Large模型的训练如何解决
|
4月前
|
机器学习/深度学习 人工智能 自然语言处理
算法金 | 秒懂 AI - 深度学习五大模型:RNN、CNN、Transformer、BERT、GPT 简介
**RNN**,1986年提出,用于序列数据,如语言模型和语音识别,但原始模型有梯度消失问题。**LSTM**和**GRU**通过门控解决了此问题。 **CNN**,1989年引入,擅长图像处理,卷积层和池化层提取特征,经典应用包括图像分类和物体检测,如LeNet-5。 **Transformer**,2017年由Google推出,自注意力机制实现并行计算,优化了NLP效率,如机器翻译。 **BERT**,2018年Google的双向预训练模型,通过掩码语言模型改进上下文理解,适用于问答和文本分类。
157 9
|
3月前
|
数据采集 人工智能 数据挖掘
2021 第五届“达观杯” 基于大规模预训练模型的风险事件标签识别】3 Bert和Nezha方案
2021第五届“达观杯”基于大规模预训练模型的风险事件标签识别比赛中使用的NEZHA和Bert方案,包括预训练、微调、模型融合、TTA测试集数据增强以及总结和反思。
43 0
|
4月前
|
数据采集 自然语言处理 PyTorch
AIGC之BERT模型
7月更文挑战第5天