Text-Mining-Based Enterprise Hidden-Danger Inspection Quality Analysis Model (Part 2)

Summary: Text-Mining-Based Enterprise Hidden-Danger Inspection Quality Analysis Model (Part 2)

5.2 Splitting the Dataset and Creating Data Loaders


df_train, df_test = train_test_split(train, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)
df_train.shape, df_val.shape, df_test.shape
((10800, 9), (600, 9), (600, 9))
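The split above is purely random. Since records flagged as hidden dangers (label 1) are a small minority, a stratified split keeps the label ratio consistent across the three subsets; a minimal sketch, assuming the same `train` DataFrame and `label` column used above:

# Stratified alternative: preserve the 0/1 label ratio in every subset
df_train, df_rest = train_test_split(
    train, test_size=0.1, random_state=RANDOM_SEED, stratify=train['label'])
df_val, df_test = train_test_split(
    df_rest, test_size=0.5, random_state=RANDOM_SEED, stratify=df_rest['label'])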
def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = EnterpriseDataset(
        texts=df['text'].values,
        labels=df['label'].values,
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
#         num_workers=4 # multiple worker processes; commented out because it can cause issues on Windows
    )
BATCH_SIZE = 4
train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
next(iter(train_data_loader))
F:\ProgramData\Anaconda3\lib\site-packages\transformers\tokenization_utils_base.py:2271: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
  warnings.warn(
{'texts': ['指示标识不清楚[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
  '发现本月有灭火器过期,已安排购买灭火器更换[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火检查[SEP]灭火器材配置及有效情况。',
  '安全出口标志灯有一个有故障,已买回安装改正。[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
  '堵了消防通道[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;'],
 'input_ids': tensor([[ 101, 2900, 4850, 3403, 6399,  679, 3926, 3504,  102, 2339,  689,  120,
          1314, 1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,  102, 7344, 4125,
          2337, 3389,  102, 2128, 1059, 1139, 1366,  510, 4541, 3141, 6858, 6887,
          3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403,
          2562,  510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039,  102,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 1355, 4385, 3315, 3299, 3300, 4127, 4125, 1690, 6814, 3309, 8024,
          2347, 2128, 2961, 6579,  743, 4127, 4125, 1690, 3291, 2940,  102, 1555,
          6588, 3302, 1218, 3136, 3152, 1310, 5102,  102, 3867, 7344, 3466, 3389,
           102, 7344, 4125, 3466, 3389,  102, 4127, 4125, 1690, 3332, 6981, 5390,
          1350, 3300, 3126, 2658, 1105,  511,  102,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 2128, 1059, 1139, 1366, 3403, 2562, 4128, 3300,  671,  702, 3300,
          3125, 7397, 8024, 2347,  743, 1726, 2128, 6163, 3121, 3633,  511,  102,
          2339,  689,  120, 1314, 1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,
           102, 7344, 4125, 2337, 3389,  102, 2128, 1059, 1139, 1366,  510, 4541,
          3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141,
          2900, 4850, 3403, 2562,  510, 2418, 2593, 4212, 3209, 3221, 1415, 2130,
          1962, 8039,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 1843,  749, 3867, 7344, 6858, 6887,  102, 2339,  689,  120, 1314,
          1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,  102, 7344, 4125, 2337,
          3389,  102, 2128, 1059, 1139, 1366,  510, 4541, 3141, 6858, 6887, 3221,
          1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562,
           510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039,  102,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'labels': tensor([0, 0, 0, 0])}
data = next(iter(train_data_loader))
data.keys()
dict_keys(['texts', 'input_ids', 'attention_mask', 'labels'])
print(data['input_ids'].shape)
print(data['attention_mask'].shape)
print(data['labels'].shape)
torch.Size([4, 160])
torch.Size([4, 160])
torch.Size([4])
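The loaders above keep the original row order and use a single worker. For training it usually helps to shuffle each epoch; a small variation of create_data_loader (the shuffle parameter is an addition, not part of the original code):

def create_data_loader(df, tokenizer, max_len, batch_size, shuffle=False):
    ds = EnterpriseDataset(
        texts=df['text'].values,
        labels=df['label'].values,
        tokenizer=tokenizer,
        max_len=max_len
    )
    # shuffle=True only for the training loader; leave num_workers at its default on Windows
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE, shuffle=True)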


6 Building the Enterprise Hidden-Danger Recognition Model with Hugging Face


# bert_model = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
bert_model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
bert_model
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1)-(11): 11 more BertLayer blocks, identical in structure to layer (0) above
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
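The `encoding` displayed below was created in Part 1 by tokenizing a short sample sentence. A minimal sketch that produces a dictionary of the same form (the sample text here is a stand-in, and max_length=32 matches the printout below):

sample_text = "这是一条用于演示的示例文本"  # hypothetical sample; the original sentence comes from Part 1
encoding = tokenizer.encode_plus(
    sample_text,
    max_length=32,
    add_special_tokens=True,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt',
)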
encoding
{'input_ids': tensor([[ 101,  791, 1921, 3193,  677,  130, 4157, 1288, 6629, 2414, 8024, 2769,
         1762, 2110,  739, 7564, 6378, 5298, 3563, 1798, 4638,  886, 4500,  119,
          102,    0,    0,    0,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0, 0, 0]])}
last_hidden_state, pooled_output = bert_model(
    input_ids=encoding['input_ids'], 
    attention_mask=encoding['attention_mask'],
    return_dict = False
)
last_hidden_state # vector representation of every token
tensor([[[ 0.8880,  0.1987,  1.3610,  ..., -0.5096,  0.3742, -0.2368],
         [-0.0747,  0.3148,  1.4699,  ..., -1.0238, -0.0518, -0.0557],
         [ 1.0133, -0.6058,  1.0152,  ...,  0.3536,  1.1091, -0.1179],
         ...,
         [ 0.4613,  0.4155, -0.4329,  ...,  0.1605, -0.3617, -0.2294],
         [ 0.4403,  0.4763, -0.5568,  ...,  0.2216, -0.3297, -0.3064],
         [ 0.4437,  0.3844, -0.4880,  ...,  0.0670, -0.5105, -0.2472]]],
       grad_fn=<NativeLayerNormBackward>)
pooled_output
tensor([[ 0.9999,  0.9998,  0.9989,  0.9629,  0.3075, -0.1866, -0.9904,  0.8628,
          0.9710, -0.9993,  1.0000,  1.0000,  0.9312, -0.9394,  0.9998, -0.9999,
          0.0417,  0.9999,  0.9458,  0.3190,  1.0000, -1.0000, -0.9062, -0.9048,
          0.1764,  0.9983,  0.9346, -0.8122, -0.9999,  0.9996,  0.7879,  0.9999,
          0.8475, -1.0000, -1.0000,  0.9413, -0.8260,  0.9889, -0.4976, -0.9857,
         -0.9955, -0.9580,  0.5833, -0.9996, -0.8932,  0.8563, -1.0000, -0.9999,
          0.9719,  0.9999, -0.7430, -0.9993,  0.9756, -0.9754,  0.2991,  0.8933,
         -0.9991,  0.9987,  1.0000,  0.4156,  0.9992, -0.9452, -0.8020, -0.9999,
          1.0000, -0.9964, -0.9900,  0.4365,  1.0000,  1.0000, -0.9400,  0.8794,
          1.0000,  0.9105, -0.6616,  1.0000, -0.9999,  0.6892, -1.0000, -0.9817,
          1.0000,  0.9957, -0.8844, -0.8248, -0.9921, -0.9999, -0.9998,  1.0000,
          0.5228,  0.1297,  0.9932, -0.9999, -1.0000,  0.9993, -0.9996, -0.9948,
         -0.9561,  0.9996, -0.5785, -0.9386, -0.2035,  0.9086, -0.9999, -0.9993,
          0.9959,  0.9984,  0.6953, -0.9995,  1.0000,  0.8610, -1.0000, -0.4507,
         -1.0000,  0.2384, -0.9812,  0.9998,  0.9504,  0.5421,  0.9995, -0.9998,
          0.9320, -0.9941, -0.9718, -0.9910,  0.9822,  1.0000,  0.9997, -0.9990,
          1.0000,  1.0000,  0.8608,  0.9964, -0.9997,  0.9799,  0.5985, -0.9098,
          0.5329, -0.6345,  1.0000,  0.9872,  0.9970, -0.9719,  0.9988, -0.9933,
          1.0000, -0.9999,  0.9973, -1.0000, -0.6550,  0.9996,  0.8899,  1.0000,
          0.2969,  0.9999, -0.9983, -0.9991,  0.9906, -0.6590,  0.9872, -1.0000,
          0.7658,  0.7876, -0.8556,  0.6304, -1.0000,  1.0000, -0.7938,  1.0000,
          0.9898,  0.2216, -0.9942, -0.9969,  0.8345, -0.9998, -0.9779,  0.9914,
          0.5227,  0.9992, -0.9893, -0.9889,  0.2325, -0.9887, -0.9999,  0.9885,
          0.0340,  0.9284,  0.5197,  0.4143,  0.8315,  0.1585, -0.5348,  1.0000,
          0.2361,  0.9985,  0.9999, -0.3446,  0.1012, -0.9924, -1.0000, -0.7542,
          0.9999, -0.2807, -0.9999,  0.9490, -1.0000,  0.9906, -0.7288, -0.5263,
         -0.9545, -0.9999,  0.9998, -0.9286, -0.9997, -0.5303,  0.8886,  0.5605,
         -0.9989, -0.3324,  0.9804, -0.9075,  0.9905, -0.9800, -0.9946,  0.6856,
         -0.9393,  0.9929,  0.9874,  1.0000,  0.9997, -0.0714, -0.9440,  1.0000,
          0.1676, -1.0000,  0.5573, -0.9611,  0.8835,  0.9999, -0.9980,  0.9294,
          1.0000,  0.7968,  1.0000, -0.7065, -0.9793, -0.9997,  1.0000,  0.9922,
          0.9999, -0.9984, -0.9995, -0.1701, -0.5426, -1.0000, -1.0000, -0.6334,
          0.9969,  0.9999, -0.1620, -0.9818, -0.9921, -0.9994,  1.0000, -0.9759,
          1.0000,  0.8570, -0.7434, -0.9164,  0.9438, -0.7311, -0.9986, -0.3936,
         -0.9997, -0.9650, -1.0000,  0.9433, -0.9999, -1.0000,  0.6913,  1.0000,
          0.8762, -1.0000,  0.9997,  0.9764,  0.7094, -0.9294,  0.9522, -1.0000,
          1.0000, -0.9965,  0.9428, -0.9972, -0.9897, -0.7680,  0.9922,  0.9999,
         -0.9999, -0.9597, -0.9922, -0.9807, -0.3632,  0.9936, -0.7280,  0.4117,
         -0.9498, -0.9666,  0.9545, -0.9957, -0.9970,  0.4028,  1.0000, -0.9798,
          1.0000,  0.9941,  1.0000,  0.9202, -0.9942,  0.9996,  0.5352, -0.5836,
         -0.8829, -0.9418,  0.9497, -0.0532,  0.6966, -0.9999,  0.9998,  0.9917,
          0.9612,  0.7289,  0.0167,  0.3179,  0.9627, -0.9911,  0.9995, -0.9996,
         -0.6737,  0.9991,  1.0000,  0.9932,  0.4880, -0.7488,  0.9986, -0.9961,
          0.9995, -1.0000,  0.9999, -0.9940,  0.9705, -0.9970, -0.9856,  1.0000,
          0.9846, -0.7932,  0.9997, -0.9386,  0.9938,  0.9738,  0.8173,  0.9913,
          0.9981,  1.0000, -0.9998, -0.9918, -0.9727, -0.9987, -0.9955, -1.0000,
         -0.1038, -1.0000, -0.9874, -0.9287,  0.5109, -0.9056,  0.1022,  0.7864,
         -0.8197,  0.5724, -0.5905,  0.2713, -0.7239, -0.9976, -0.9844, -1.0000,
         -0.9988,  0.8835,  0.9999, -0.9997,  0.9999, -0.9999, -0.9782,  0.9383,
         -0.5609,  0.7721,  0.9999, -1.0000,  0.9585,  0.9987,  1.0000,  0.9960,
          0.9993, -0.9741, -0.9999, -0.9989, -0.9999, -1.0000, -0.9998,  0.9343,
          0.6337, -1.0000,  0.0902,  0.8980,  1.0000,  0.9964, -0.9985, -0.6136,
         -0.9996, -0.8252,  0.9996, -0.0566, -1.0000,  0.9962, -0.8744,  1.0000,
         -0.8865,  0.9879,  0.8897,  0.9571,  0.9823, -1.0000,  0.9145,  1.0000,
          0.0365, -1.0000, -0.9985, -0.9075, -0.9998,  0.0369,  0.8120,  0.9999,
         -1.0000, -0.9155, -0.9975,  0.7988,  0.9922,  0.9998,  0.9982,  0.9267,
          0.9165,  0.5368,  0.1464,  0.9998,  0.4663, -0.9989,  0.9996, -0.7952,
          0.4527, -1.0000,  0.9998,  0.4073,  0.9999,  0.9159, -0.5480, -0.6822,
         -0.9904,  0.9938,  1.0000, -0.4229, -0.4845, -0.9981, -1.0000, -0.9861,
         -0.0950, -0.4625, -0.9629, -0.9998,  0.6675, -0.5244,  1.0000,  1.0000,
          0.9924, -0.9253, -0.9974,  0.9974, -0.9012,  0.9900, -0.2582, -1.0000,
         -0.9919, -0.9986,  1.0000, -0.9716, -0.9262, -0.9911, -0.2593,  0.5919,
         -0.9999, -0.4994, -0.9962,  0.9818,  1.0000, -0.9996,  0.9918, -0.9970,
          0.7085, -0.1369,  0.8077,  0.9955, -0.3394, -0.5860, -0.6887, -0.9841,
          0.9970,  0.9987, -0.9948, -0.8401,  0.9999,  0.0856,  0.9999,  0.5099,
          0.9466,  0.9567,  1.0000,  0.8771,  1.0000, -0.0815,  1.0000,  0.9999,
         -0.9392,  0.5744,  0.8723, -0.9686,  0.5958,  0.9822,  0.9997,  0.8854,
         -0.1952, -0.9967,  0.9994,  1.0000,  1.0000, -0.3391,  0.9883, -0.4452,
          0.9252,  0.4495,  0.9870,  0.3479,  0.2266,  0.9942,  0.9990, -0.9999,
         -0.9999, -1.0000,  1.0000,  0.9996, -0.6637, -1.0000,  0.9999,  0.4543,
          0.7471,  0.9983,  0.3772, -0.9812,  0.9853, -0.9995, -0.3404,  0.9788,
          0.9867,  0.7564,  0.9995, -0.9997,  0.7990,  1.0000,  0.0752,  0.9999,
          0.2912, -0.9941,  0.9970, -0.9935, -0.9995, -0.9743,  0.9991,  0.9981,
         -0.9273, -0.8402,  0.9996, -0.9999,  0.9999, -0.9998,  0.9724, -0.9939,
          1.0000, -0.9752, -0.9998, -0.3806,  0.8830,  0.8352, -0.8892,  1.0000,
         -0.8875, -0.8107,  0.7083, -0.8909, -0.9931, -0.9630,  0.0800, -1.0000,
          0.7777, -0.9611,  0.5867, -0.9947, -0.9999,  1.0000, -0.9084, -0.9414,
          0.9999, -0.8838, -1.0000,  0.9549, -0.9999, -0.6522,  0.7967, -0.6850,
          0.1524, -1.0000,  0.4800,  0.9999, -0.9998, -0.7089, -0.9129, -0.9864,
          0.6220,  0.8855,  0.9855, -0.8651,  0.3988, -0.2548,  0.9793, -0.7212,
         -0.2582, -0.9999, -0.8692, -0.6282, -0.9999, -0.9999, -1.0000,  1.0000,
          0.9996,  0.9999, -0.5600,  0.7442,  0.9460,  0.9927, -0.9999,  0.4407,
         -0.0461,  0.9937, -0.4887, -0.9994, -0.9198, -1.0000, -0.6905,  0.3538,
         -0.7728,  0.6622,  1.0000,  0.9999, -0.9999, -0.9994, -0.9995, -0.9979,
          0.9998,  0.9999,  0.9996, -0.9072, -0.5844,  0.9997,  0.9689,  0.5231,
         -0.9999, -0.9981, -0.9999,  0.7505, -0.9922, -0.9986,  0.9971,  1.0000,
          0.8730, -1.0000, -0.9533,  1.0000,  0.9997,  1.0000, -0.7768,  0.9999,
         -0.9838,  0.9819, -0.9993,  1.0000, -1.0000,  1.0000,  0.9999,  0.9809,
          0.9984, -0.9928,  0.9776, -0.9998, -0.7407,  0.9298, -0.4495, -0.9902,
          0.8053,  0.9996, -0.9952,  1.0000,  0.9243, -0.2028,  0.8002,  0.9873,
          0.9419, -0.6913, -0.9999,  0.8162,  0.9995,  0.9509,  1.0000,  0.9177,
          0.9996, -0.9839, -0.9998,  0.9914, -0.6991, -0.7821, -0.9998,  1.0000,
          1.0000, -0.9999, -0.9227,  0.7483,  0.1186,  1.0000,  0.9963,  0.9971,
          0.9857,  0.3887,  0.9996, -0.9999,  0.8526, -0.9980, -0.8613,  0.9999,
         -0.9899,  0.9999, -0.9981,  1.0000, -0.9858,  0.9944,  0.9989,  0.9684,
         -0.9968,  1.0000,  0.8246, -0.9956, -0.8348, -0.9374, -0.9999,  0.7827]],
       grad_fn=<TanhBackward>)
last_hidden_state.shape # vector representation of every token
torch.Size([1, 32, 768])
pooled_output.shape # vector representation of the [CLS] token
torch.Size([1, 768])
bert_model.config.hidden_size
768
pooled_output.shape
# sentence-level representation
torch.Size([1, 768])
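Note that pooled_output is not simply last_hidden_state[:, 0]: it is the [CLS] vector passed through the pooler's Linear + Tanh shown in the model printout. A quick sanity check (a sketch using the tensors computed above):

cls_vector = last_hidden_state[:, 0]                  # [CLS] token vector, shape [1, 768]
manual_pooled = torch.tanh(bert_model.pooler.dense(cls_vector))
print(torch.allclose(manual_pooled, pooled_output))   # expected: True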
class EnterpriseDangerClassifier(nn.Module):
    def __init__(self, n_classes):
        super(EnterpriseDangerClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes) # classification head over the two classes
    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict = False
        )
        output = self.drop(pooled_output) # dropout
        return self.out(output)
class_names=[0,1]
model = EnterpriseDangerClassifier(len(class_names))
model = model.to(device)
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
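The custom head above (BERT + Dropout + Linear) mirrors what Hugging Face's built-in classification model does internally. As an alternative sketch (not used in the rest of this article):

from transformers import BertForSequenceClassification

# Same bert-base-chinese backbone with a 2-class classification head attached
hf_model = BertForSequenceClassification.from_pretrained(PRE_TRAINED_MODEL_NAME, num_labels=2)
# outputs = hf_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
# outputs.loss is the cross-entropy loss, outputs.logits the class scores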
data
(output omitted: the same batch dictionary already printed above)
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)
print(input_ids.shape) # batch size x seq length
print(attention_mask.shape) # batch size x seq length
torch.Size([4, 160])
torch.Size([4, 160])
model(input_ids, attention_mask)
tensor([[-0.3011, -0.3009],
        [ 0.2871,  0.1841],
        [ 0.2703, -0.0926],
        [-0.3193, -0.1487]], device='cuda:0', grad_fn=<AddmmBackward>)
F.softmax(model(input_ids, attention_mask), dim=1)
tensor([[0.6495, 0.3505],
        [0.6752, 0.3248],
        [0.7261, 0.2739],
        [0.4528, 0.5472]], device='cuda:0', grad_fn=<SoftmaxBackward>)


7 Model Training


EPOCHS = 10 # number of training epochs
optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
  optimizer,
  num_warmup_steps=0,
  num_training_steps=total_steps
)
loss_fn = nn.CrossEntropyLoss().to(device)
F:\ProgramData\Anaconda3\lib\site-packages\transformers\optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
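The FutureWarning above comes from the transformers implementation of AdamW. The equivalent setup with the PyTorch optimizer, as the warning itself suggests (a sketch; note that torch.optim.AdamW always applies bias correction, whereas correct_bias=False above disables it):

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)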
def train_epoch(
  model, 
  data_loader, 
  loss_fn, 
  optimizer, 
  device, 
  scheduler, 
  n_examples
):
    model = model.train() # switch to training mode
    losses = []
    correct_predictions = 0
    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["labels"].to(device)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, targets)
        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    return correct_predictions.double() / n_examples, np.mean(losses)
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval() # switch to evaluation mode
    losses = []
    correct_predictions = 0
    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
            loss = loss_fn(outputs, targets)
            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())
    return correct_predictions.double() / n_examples, np.mean(losses)
history = defaultdict(list) # record loss and accuracy for each of the 10 epochs
best_accuracy = 0
for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)
    train_acc, train_loss = train_epoch(
        model,
        train_data_loader,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(df_train)
    )
    print(f'Train loss {train_loss} accuracy {train_acc}')
    val_acc, val_loss = eval_model(
        model,
        val_data_loader,
        loss_fn,
        device,
        len(df_val)
    )
    print(f'Val   loss {val_loss} accuracy {val_acc}')
    print()
    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)
    if val_acc > best_accuracy:
        torch.save(model.state_dict(), 'best_model_state.bin')
        best_accuracy = val_acc
Epoch 1/10
----------
Train loss 0.49691140200114914 accuracy 0.8901851851851852
Val   loss 0.40999091763049367 accuracy 0.9
Epoch 2/10
----------
Train loss 0.3062430267758383 accuracy 0.9349999999999999
Val   loss 0.20030112245275328 accuracy 0.9650000000000001
Epoch 3/10
----------
Train loss 0.18264216477097728 accuracy 0.9603703703703703
Val   loss 0.18755523634143173 accuracy 0.9650000000000001
Epoch 4/10
----------
Train loss 0.15700688022613543 accuracy 0.9693518518518518
Val   loss 0.20371213133369262 accuracy 0.9633333333333334
Epoch 5/10
----------
Train loss 0.1627817107436756 accuracy 0.9668518518518519
Val   loss 0.16456402061972766 accuracy 0.9683333333333334
Epoch 6/10
----------
Train loss 0.15311389193888453 accuracy 0.9721296296296296
Val   loss 0.1188539441426595 accuracy 0.9783333333333334
Epoch 7/10
----------
Train loss 0.13947947008179012 accuracy 0.9734259259259259
Val   loss 0.12033098526764661 accuracy 0.9783333333333334
Epoch 8/10
----------
Train loss 0.12078767392419482 accuracy 0.9781481481481481
Val   loss 0.12014915000802527 accuracy 0.9733333333333334
Epoch 9/10
----------
Train loss 0.11557375699952967 accuracy 0.9751851851851852
Val   loss 0.12187736847476724 accuracy 0.9766666666666667
Epoch 10/10
----------
Train loss 0.10247013699765645 accuracy 0.977037037037037
Val   loss 0.11501088156461871 accuracy 0.9766666666666667
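history stores train_acc and val_acc as torch tensors (CUDA tensors when training on GPU), which some matplotlib/torch combinations cannot plot directly. Converting them to plain floats first is a safe guard (an addition, not in the original flow):

# Convert accuracy tensors to Python floats so matplotlib can plot them
history['train_acc'] = [float(a) for a in history['train_acc']]
history['val_acc'] = [float(a) for a in history['val_acc']]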
plt.plot(history['train_acc'], label='train accuracy')
plt.plot(history['val_acc'], label='validation accuracy')
plt.title('Training history')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.ylim([0, 1]);

(Figure: training and validation accuracy curves over the 10 epochs)

# model = EnterpriseDangerClassifier(len(class_names))
# model.load_state_dict(torch.load('best_model_state.bin'))
# model = model.to(device)


8 Model Evaluation


test_acc, _ = eval_model(
  model,
  test_data_loader,
  loss_fn,
  device,
  len(df_test)
)
test_acc.item()
0.9716666666666667
def get_predictions(model, data_loader):
    model = model.eval()
    raw_texts = []
    predictions = []
    prediction_probs = []
    real_values = []
    with torch.no_grad():
        for d in data_loader:
            texts = d["texts"]
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1) # predicted class
            probs = F.softmax(outputs, dim=1) # class probabilities
            raw_texts.extend(texts)
            predictions.extend(preds)
            prediction_probs.extend(probs)
            real_values.extend(targets)
    predictions = torch.stack(predictions).cpu()
    prediction_probs = torch.stack(prediction_probs).cpu()
    real_values = torch.stack(real_values).cpu()
    return raw_texts, predictions, prediction_probs, real_values
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_data_loader
)
print(classification_report(y_test, y_pred, target_names=[str(label) for label in class_names])) # classification report
              precision    recall  f1-score   support

           0       0.99      0.98      0.98       554
           1       0.81      0.83      0.82        46

    accuracy                           0.97       600
   macro avg       0.90      0.90      0.90       600
weighted avg       0.97      0.97      0.97       600
def show_confusion_matrix(confusion_matrix):
    hmap = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
    hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
    hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
    plt.ylabel('True label')
    plt.xlabel('Predicted label');
cm = confusion_matrix(y_test, y_pred)
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
show_confusion_matrix(df_cm)

(Figure: confusion matrix heatmap on the test set)

idx = 2
sample_text = y_texts[idx]
true_label = y_test[idx]
pred_df = pd.DataFrame({
  'class_names': class_names,
  'values': y_pred_probs[idx]
})
print("\n".join(wrap(sample_text)))
print()
print(f'True label: {class_names[true_label]}')
焊锡员工未佩戴防护口罩 工业/危化品类 主要负责人、分管负责人及管理人员履职情况 分管负责人履职情况
分管负责人依法履行安全管理职责(存在职业健康危害的单位需自查职业卫生履职情况)。
True label: 0
sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
plt.ylabel('class')
plt.xlabel('probability')
plt.xlim([0, 1]);

(Figure: predicted class probability bar plot for the sample)


9 Test Set Prediction


sample_text = "电源插头应按规定正确接线"
encoded_text = tokenizer.encode_plus(
  sample_text,
  max_length=MAX_LEN,
  add_special_tokens=True,
  return_token_type_ids=False,
  padding='max_length',      # replaces the deprecated pad_to_max_length=True
  truncation=True,
  return_attention_mask=True,
  return_tensors='pt',
)
input_ids = encoded_text['input_ids'].to(device)
attention_mask = encoded_text['attention_mask'].to(device)
output = model(input_ids, attention_mask)
_, prediction = torch.max(output, dim=1)
print(f'Sample text: {sample_text}')
print(f'Danger label  : {class_names[prediction]}')
Sample text: 电源插头应按规定正确接线
Danger label  : 1
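The example above scores a single sentence. To score the competition's full test file the same way, the records can be batched through the existing helpers; a sketch, assuming a `test` DataFrame with the same `text` column built in Part 1 (the placeholder label column only exists to satisfy EnterpriseDataset):

test['label'] = 0  # placeholder labels required by EnterpriseDataset
predict_loader = create_data_loader(test, tokenizer, MAX_LEN, BATCH_SIZE)

all_preds = []
model = model.eval()
with torch.no_grad():
    for d in predict_loader:
        outputs = model(
            input_ids=d["input_ids"].to(device),
            attention_mask=d["attention_mask"].to(device)
        )
        all_preds.extend(torch.max(outputs, dim=1)[1].cpu().tolist())
test['label'] = all_preds  # predicted 0/1 quality label for each record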

