5.2 Splitting the Dataset and Creating the Data Loaders
df_train, df_test = train_test_split(train, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)
df_train.shape, df_val.shape, df_test.shape
((10800, 9), (600, 9), (600, 9))
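The split above is purely random; since the hazard class is a small minority, a stratified split would keep the label ratio consistent across the three subsets. A minimal sketch of that alternative, assuming the label column is named `label` as in the loader code below (this is an option, not what the notebook ran):

# Hypothetical stratified variant: preserves the 0/1 label ratio
# in the train, validation and test subsets.
df_train, df_rest = train_test_split(
    train, test_size=0.1, random_state=RANDOM_SEED, stratify=train['label'])
df_val, df_test = train_test_split(
    df_rest, test_size=0.5, random_state=RANDOM_SEED, stratify=df_rest['label'])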
def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = EnterpriseDataset(
        texts=df['text'].values,
        labels=df['label'].values,
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        # num_workers=4  # multiprocessing workers; left disabled on Windows
    )
BATCH_SIZE = 4

train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
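These loaders iterate in a fixed order; in practice the training loader is usually shuffled each epoch. A minimal sketch of a shuffle-aware variant (our own suggestion, not what the notebook ran or measured):

# Hypothetical variant: expose DataLoader's shuffle flag so the training
# loader can present samples in a new random order every epoch.
def create_data_loader(df, tokenizer, max_len, batch_size, shuffle=False):
    ds = EnterpriseDataset(
        texts=df['text'].values,
        labels=df['label'].values,
        tokenizer=tokenizer,
        max_len=max_len,
    )
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

# train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE, shuffle=True)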
next(iter(train_data_loader))
F:\ProgramData\Anaconda3\lib\site-packages\transformers\tokenization_utils_base.py:2271: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert). warnings.warn( {'texts': ['指示标识不清楚[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;', '发现本月有灭火器过期,已安排购买灭火器更换[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火检查[SEP]灭火器材配置及有效情况。', '安全出口标志灯有一个有故障,已买回安装改正。[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;', '堵了消防通道[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;'], 'input_ids': tensor([[ 101, 2900, 4850, 3403, 6399, 679, 3926, 3504, 102, 2339, 689, 120, 1314, 1265, 1501, 5102, 102, 3867, 7344, 3466, 3389, 102, 7344, 4125, 2337, 3389, 102, 2128, 1059, 1139, 1366, 510, 4541, 3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562, 510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 101, 1355, 4385, 3315, 3299, 3300, 4127, 4125, 1690, 6814, 3309, 8024, 2347, 2128, 2961, 6579, 743, 4127, 4125, 1690, 3291, 2940, 102, 1555, 6588, 3302, 1218, 3136, 3152, 1310, 5102, 102, 3867, 7344, 3466, 3389, 102, 7344, 4125, 3466, 3389, 102, 4127, 4125, 1690, 3332, 6981, 5390, 1350, 3300, 3126, 2658, 1105, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 101, 2128, 1059, 1139, 1366, 3403, 2562, 4128, 3300, 671, 702, 3300, 3125, 7397, 8024, 2347, 743, 1726, 2128, 6163, 3121, 3633, 511, 102, 2339, 689, 120, 1314, 1265, 1501, 5102, 102, 3867, 7344, 3466, 3389, 102, 7344, 4125, 2337, 3389, 102, 2128, 1059, 1139, 1366, 510, 4541, 3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562, 510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 101, 1843, 749, 3867, 7344, 6858, 6887, 102, 2339, 689, 120, 1314, 1265, 1501, 5102, 102, 3867, 7344, 3466, 3389, 102, 7344, 4125, 2337, 3389, 102, 2128, 1059, 1139, 1366, 510, 4541, 3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562, 510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'labels': tensor([0, 0, 0, 0])}
data = next(iter(train_data_loader))
data.keys()
dict_keys(['texts', 'input_ids', 'attention_mask', 'labels'])
print(data['input_ids'].shape)
print(data['attention_mask'].shape)
print(data['labels'].shape)
torch.Size([4, 160])
torch.Size([4, 160])
torch.Size([4])
6 Building the Enterprise Hazard Identification Model with Hugging Face
# bert_model = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
bert_model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias'] - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
bert_model
BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(21128, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): BertEncoder( (layer): ModuleList( (0): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (1): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (2): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (3): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (4): BertLayer( (attention): 
BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (5): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (6): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (7): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (8): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): 
LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (9): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (10): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (11): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): BertPooler( (dense): Linear(in_features=768, out_features=768, bias=True) (activation): Tanh() ) )
encoding
{'input_ids': tensor([[ 101, 791, 1921, 3193, 677, 130, 4157, 1288, 6629, 2414, 8024, 2769, 1762, 2110, 739, 7564, 6378, 5298, 3563, 1798, 4638, 886, 4500, 119, 102, 0, 0, 0, 0, 0, 0, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]])}
last_hidden_state, pooled_output = bert_model(
    input_ids=encoding['input_ids'],
    attention_mask=encoding['attention_mask'],
    return_dict=False
)
last_hidden_state  # vector representation of every token
tensor([[[ 0.8880, 0.1987, 1.3610, ..., -0.5096, 0.3742, -0.2368], [-0.0747, 0.3148, 1.4699, ..., -1.0238, -0.0518, -0.0557], [ 1.0133, -0.6058, 1.0152, ..., 0.3536, 1.1091, -0.1179], ..., [ 0.4613, 0.4155, -0.4329, ..., 0.1605, -0.3617, -0.2294], [ 0.4403, 0.4763, -0.5568, ..., 0.2216, -0.3297, -0.3064], [ 0.4437, 0.3844, -0.4880, ..., 0.0670, -0.5105, -0.2472]]], grad_fn=<NativeLayerNormBackward>)
pooled_output
tensor([[ 0.9999, 0.9998, 0.9989, 0.9629, 0.3075, -0.1866, -0.9904, 0.8628, 0.9710, -0.9993, 1.0000, 1.0000, 0.9312, -0.9394, 0.9998, -0.9999, 0.0417, 0.9999, 0.9458, 0.3190, 1.0000, -1.0000, -0.9062, -0.9048, 0.1764, 0.9983, 0.9346, -0.8122, -0.9999, 0.9996, 0.7879, 0.9999, 0.8475, -1.0000, -1.0000, 0.9413, -0.8260, 0.9889, -0.4976, -0.9857, -0.9955, -0.9580, 0.5833, -0.9996, -0.8932, 0.8563, -1.0000, -0.9999, 0.9719, 0.9999, -0.7430, -0.9993, 0.9756, -0.9754, 0.2991, 0.8933, -0.9991, 0.9987, 1.0000, 0.4156, 0.9992, -0.9452, -0.8020, -0.9999, 1.0000, -0.9964, -0.9900, 0.4365, 1.0000, 1.0000, -0.9400, 0.8794, 1.0000, 0.9105, -0.6616, 1.0000, -0.9999, 0.6892, -1.0000, -0.9817, 1.0000, 0.9957, -0.8844, -0.8248, -0.9921, -0.9999, -0.9998, 1.0000, 0.5228, 0.1297, 0.9932, -0.9999, -1.0000, 0.9993, -0.9996, -0.9948, -0.9561, 0.9996, -0.5785, -0.9386, -0.2035, 0.9086, -0.9999, -0.9993, 0.9959, 0.9984, 0.6953, -0.9995, 1.0000, 0.8610, -1.0000, -0.4507, -1.0000, 0.2384, -0.9812, 0.9998, 0.9504, 0.5421, 0.9995, -0.9998, 0.9320, -0.9941, -0.9718, -0.9910, 0.9822, 1.0000, 0.9997, -0.9990, 1.0000, 1.0000, 0.8608, 0.9964, -0.9997, 0.9799, 0.5985, -0.9098, 0.5329, -0.6345, 1.0000, 0.9872, 0.9970, -0.9719, 0.9988, -0.9933, 1.0000, -0.9999, 0.9973, -1.0000, -0.6550, 0.9996, 0.8899, 1.0000, 0.2969, 0.9999, -0.9983, -0.9991, 0.9906, -0.6590, 0.9872, -1.0000, 0.7658, 0.7876, -0.8556, 0.6304, -1.0000, 1.0000, -0.7938, 1.0000, 0.9898, 0.2216, -0.9942, -0.9969, 0.8345, -0.9998, -0.9779, 0.9914, 0.5227, 0.9992, -0.9893, -0.9889, 0.2325, -0.9887, -0.9999, 0.9885, 0.0340, 0.9284, 0.5197, 0.4143, 0.8315, 0.1585, -0.5348, 1.0000, 0.2361, 0.9985, 0.9999, -0.3446, 0.1012, -0.9924, -1.0000, -0.7542, 0.9999, -0.2807, -0.9999, 0.9490, -1.0000, 0.9906, -0.7288, -0.5263, -0.9545, -0.9999, 0.9998, -0.9286, -0.9997, -0.5303, 0.8886, 0.5605, -0.9989, -0.3324, 0.9804, -0.9075, 0.9905, -0.9800, -0.9946, 0.6856, -0.9393, 0.9929, 0.9874, 1.0000, 0.9997, -0.0714, -0.9440, 1.0000, 0.1676, -1.0000, 0.5573, -0.9611, 0.8835, 0.9999, -0.9980, 0.9294, 1.0000, 0.7968, 1.0000, -0.7065, -0.9793, -0.9997, 1.0000, 0.9922, 0.9999, -0.9984, -0.9995, -0.1701, -0.5426, -1.0000, -1.0000, -0.6334, 0.9969, 0.9999, -0.1620, -0.9818, -0.9921, -0.9994, 1.0000, -0.9759, 1.0000, 0.8570, -0.7434, -0.9164, 0.9438, -0.7311, -0.9986, -0.3936, -0.9997, -0.9650, -1.0000, 0.9433, -0.9999, -1.0000, 0.6913, 1.0000, 0.8762, -1.0000, 0.9997, 0.9764, 0.7094, -0.9294, 0.9522, -1.0000, 1.0000, -0.9965, 0.9428, -0.9972, -0.9897, -0.7680, 0.9922, 0.9999, -0.9999, -0.9597, -0.9922, -0.9807, -0.3632, 0.9936, -0.7280, 0.4117, -0.9498, -0.9666, 0.9545, -0.9957, -0.9970, 0.4028, 1.0000, -0.9798, 1.0000, 0.9941, 1.0000, 0.9202, -0.9942, 0.9996, 0.5352, -0.5836, -0.8829, -0.9418, 0.9497, -0.0532, 0.6966, -0.9999, 0.9998, 0.9917, 0.9612, 0.7289, 0.0167, 0.3179, 0.9627, -0.9911, 0.9995, -0.9996, -0.6737, 0.9991, 1.0000, 0.9932, 0.4880, -0.7488, 0.9986, -0.9961, 0.9995, -1.0000, 0.9999, -0.9940, 0.9705, -0.9970, -0.9856, 1.0000, 0.9846, -0.7932, 0.9997, -0.9386, 0.9938, 0.9738, 0.8173, 0.9913, 0.9981, 1.0000, -0.9998, -0.9918, -0.9727, -0.9987, -0.9955, -1.0000, -0.1038, -1.0000, -0.9874, -0.9287, 0.5109, -0.9056, 0.1022, 0.7864, -0.8197, 0.5724, -0.5905, 0.2713, -0.7239, -0.9976, -0.9844, -1.0000, -0.9988, 0.8835, 0.9999, -0.9997, 0.9999, -0.9999, -0.9782, 0.9383, -0.5609, 0.7721, 0.9999, -1.0000, 0.9585, 0.9987, 1.0000, 0.9960, 0.9993, -0.9741, -0.9999, -0.9989, -0.9999, -1.0000, -0.9998, 0.9343, 0.6337, -1.0000, 0.0902, 0.8980, 1.0000, 0.9964, -0.9985, -0.6136, -0.9996, 
-0.8252, 0.9996, -0.0566, -1.0000, 0.9962, -0.8744, 1.0000, -0.8865, 0.9879, 0.8897, 0.9571, 0.9823, -1.0000, 0.9145, 1.0000, 0.0365, -1.0000, -0.9985, -0.9075, -0.9998, 0.0369, 0.8120, 0.9999, -1.0000, -0.9155, -0.9975, 0.7988, 0.9922, 0.9998, 0.9982, 0.9267, 0.9165, 0.5368, 0.1464, 0.9998, 0.4663, -0.9989, 0.9996, -0.7952, 0.4527, -1.0000, 0.9998, 0.4073, 0.9999, 0.9159, -0.5480, -0.6822, -0.9904, 0.9938, 1.0000, -0.4229, -0.4845, -0.9981, -1.0000, -0.9861, -0.0950, -0.4625, -0.9629, -0.9998, 0.6675, -0.5244, 1.0000, 1.0000, 0.9924, -0.9253, -0.9974, 0.9974, -0.9012, 0.9900, -0.2582, -1.0000, -0.9919, -0.9986, 1.0000, -0.9716, -0.9262, -0.9911, -0.2593, 0.5919, -0.9999, -0.4994, -0.9962, 0.9818, 1.0000, -0.9996, 0.9918, -0.9970, 0.7085, -0.1369, 0.8077, 0.9955, -0.3394, -0.5860, -0.6887, -0.9841, 0.9970, 0.9987, -0.9948, -0.8401, 0.9999, 0.0856, 0.9999, 0.5099, 0.9466, 0.9567, 1.0000, 0.8771, 1.0000, -0.0815, 1.0000, 0.9999, -0.9392, 0.5744, 0.8723, -0.9686, 0.5958, 0.9822, 0.9997, 0.8854, -0.1952, -0.9967, 0.9994, 1.0000, 1.0000, -0.3391, 0.9883, -0.4452, 0.9252, 0.4495, 0.9870, 0.3479, 0.2266, 0.9942, 0.9990, -0.9999, -0.9999, -1.0000, 1.0000, 0.9996, -0.6637, -1.0000, 0.9999, 0.4543, 0.7471, 0.9983, 0.3772, -0.9812, 0.9853, -0.9995, -0.3404, 0.9788, 0.9867, 0.7564, 0.9995, -0.9997, 0.7990, 1.0000, 0.0752, 0.9999, 0.2912, -0.9941, 0.9970, -0.9935, -0.9995, -0.9743, 0.9991, 0.9981, -0.9273, -0.8402, 0.9996, -0.9999, 0.9999, -0.9998, 0.9724, -0.9939, 1.0000, -0.9752, -0.9998, -0.3806, 0.8830, 0.8352, -0.8892, 1.0000, -0.8875, -0.8107, 0.7083, -0.8909, -0.9931, -0.9630, 0.0800, -1.0000, 0.7777, -0.9611, 0.5867, -0.9947, -0.9999, 1.0000, -0.9084, -0.9414, 0.9999, -0.8838, -1.0000, 0.9549, -0.9999, -0.6522, 0.7967, -0.6850, 0.1524, -1.0000, 0.4800, 0.9999, -0.9998, -0.7089, -0.9129, -0.9864, 0.6220, 0.8855, 0.9855, -0.8651, 0.3988, -0.2548, 0.9793, -0.7212, -0.2582, -0.9999, -0.8692, -0.6282, -0.9999, -0.9999, -1.0000, 1.0000, 0.9996, 0.9999, -0.5600, 0.7442, 0.9460, 0.9927, -0.9999, 0.4407, -0.0461, 0.9937, -0.4887, -0.9994, -0.9198, -1.0000, -0.6905, 0.3538, -0.7728, 0.6622, 1.0000, 0.9999, -0.9999, -0.9994, -0.9995, -0.9979, 0.9998, 0.9999, 0.9996, -0.9072, -0.5844, 0.9997, 0.9689, 0.5231, -0.9999, -0.9981, -0.9999, 0.7505, -0.9922, -0.9986, 0.9971, 1.0000, 0.8730, -1.0000, -0.9533, 1.0000, 0.9997, 1.0000, -0.7768, 0.9999, -0.9838, 0.9819, -0.9993, 1.0000, -1.0000, 1.0000, 0.9999, 0.9809, 0.9984, -0.9928, 0.9776, -0.9998, -0.7407, 0.9298, -0.4495, -0.9902, 0.8053, 0.9996, -0.9952, 1.0000, 0.9243, -0.2028, 0.8002, 0.9873, 0.9419, -0.6913, -0.9999, 0.8162, 0.9995, 0.9509, 1.0000, 0.9177, 0.9996, -0.9839, -0.9998, 0.9914, -0.6991, -0.7821, -0.9998, 1.0000, 1.0000, -0.9999, -0.9227, 0.7483, 0.1186, 1.0000, 0.9963, 0.9971, 0.9857, 0.3887, 0.9996, -0.9999, 0.8526, -0.9980, -0.8613, 0.9999, -0.9899, 0.9999, -0.9981, 1.0000, -0.9858, 0.9944, 0.9989, 0.9684, -0.9968, 1.0000, 0.8246, -0.9956, -0.8348, -0.9374, -0.9999, 0.7827]], grad_fn=<TanhBackward>)
last_hidden_state.shape  # one vector per token: [batch, seq_len, hidden]
torch.Size([1, 32, 768])
pooled_output.shape  # pooled [CLS] vector, used as the overall sentence representation
torch.Size([1, 768])
bert_model.config.hidden_size
768
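The pooled output is not the raw [CLS] hidden state: BERT's pooler (visible at the bottom of the architecture printout above) passes the first token's hidden state through a dense layer followed by a tanh activation. A quick check, reusing bert_model, last_hidden_state and pooled_output from the cells above:

# The pooler applies Linear + Tanh to the hidden state of the first token ([CLS]).
cls_hidden = last_hidden_state[:, 0, :]                          # shape [1, 768]
manual_pooled = torch.tanh(bert_model.pooler.dense(cls_hidden))
print(torch.allclose(manual_pooled, pooled_output, atol=1e-5))   # expected: True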
class EnterpriseDangerClassifier(nn.Module):
    def __init__(self, n_classes):
        super(EnterpriseDangerClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)  # two classes

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False
        )
        output = self.drop(pooled_output)  # dropout on the pooled [CLS] vector
        return self.out(output)
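For comparison, Hugging Face also ships a ready-made classification head; the hand-written module above is equivalent in spirit to the sketch below (shown only as an alternative, it is not what this notebook trains):

from transformers import BertForSequenceClassification

# Off-the-shelf BERT classifier: dropout + linear head on top of the pooled
# output, configured here for the two hazard classes.
hf_classifier = BertForSequenceClassification.from_pretrained(
    PRE_TRAINED_MODEL_NAME,
    num_labels=2,
)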
class_names=[0,1]
model = EnterpriseDangerClassifier(len(class_names))
model = model.to(device)
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias'] - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
data
{'texts': ['指示标识不清楚[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;', '发现本月有灭火器过期,已安排购买灭火器更换[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火检查[SEP]灭火器材配置及有效情况。', '安全出口标志灯有一个有故障,已买回安装改正。[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;', '堵了消防通道[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;'], 'input_ids': tensor([[ 101, 2900, 4850, 3403, 6399, 679, 3926, 3504, 102, 2339, 689, 120, 1314, 1265, 1501, 5102, 102, 3867, 7344, 3466, 3389, 102, 7344, 4125, 2337, 3389, 102, 2128, 1059, 1139, 1366, 510, 4541, 3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562, 510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 101, 1355, 4385, 3315, 3299, 3300, 4127, 4125, 1690, 6814, 3309, 8024, 2347, 2128, 2961, 6579, 743, 4127, 4125, 1690, 3291, 2940, 102, 1555, 6588, 3302, 1218, 3136, 3152, 1310, 5102, 102, 3867, 7344, 3466, 3389, 102, 7344, 4125, 3466, 3389, 102, 4127, 4125, 1690, 3332, 6981, 5390, 1350, 3300, 3126, 2658, 1105, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 101, 2128, 1059, 1139, 1366, 3403, 2562, 4128, 3300, 671, 702, 3300, 3125, 7397, 8024, 2347, 743, 1726, 2128, 6163, 3121, 3633, 511, 102, 2339, 689, 120, 1314, 1265, 1501, 5102, 102, 3867, 7344, 3466, 3389, 102, 7344, 4125, 2337, 3389, 102, 2128, 1059, 1139, 1366, 510, 4541, 3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562, 510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 101, 1843, 749, 3867, 7344, 6858, 6887, 102, 2339, 689, 120, 1314, 1265, 1501, 5102, 102, 3867, 7344, 3466, 3389, 102, 7344, 4125, 2337, 3389, 102, 2128, 1059, 1139, 1366, 510, 4541, 3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562, 510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'labels': tensor([0, 0, 0, 0])}
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)

print(input_ids.shape)       # batch size x seq length
print(attention_mask.shape)  # batch size x seq length
torch.Size([4, 160])
torch.Size([4, 160])
model(input_ids, attention_mask)
tensor([[-0.3011, -0.3009], [ 0.2871, 0.1841], [ 0.2703, -0.0926], [-0.3193, -0.1487]], device='cuda:0', grad_fn=<AddmmBackward>)
F.softmax(model(input_ids, attention_mask), dim=1)
tensor([[0.6495, 0.3505], [0.6752, 0.3248], [0.7261, 0.2739], [0.4528, 0.5472]], device='cuda:0', grad_fn=<SoftmaxBackward>)
7 Model Training
EPOCHS = 10  # number of training epochs

optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

loss_fn = nn.CrossEntropyLoss().to(device)
F:\ProgramData\Anaconda3\lib\site-packages\transformers\optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning warnings.warn(
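The FutureWarning comes from transformers' own AdamW implementation, which is deprecated. As the message suggests, PyTorch's optimizer can be used instead; note that torch.optim.AdamW has no correct_bias argument (bias correction is always applied). A sketch of the swap:

# Drop-in replacement suggested by the deprecation warning.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)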
def train_epoch(
    model, data_loader, loss_fn, optimizer, device, scheduler, n_examples
):
    model = model.train()  # switch to training mode
    losses = []
    correct_predictions = 0

    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["labels"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, targets)

        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    return correct_predictions.double() / n_examples, np.mean(losses)
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval()  # switch to evaluation mode
    losses = []
    correct_predictions = 0

    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
            loss = loss_fn(outputs, targets)

            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

    return correct_predictions.double() / n_examples, np.mean(losses)
history = defaultdict(list)  # records loss and accuracy for every epoch
best_accuracy = 0

for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)

    train_acc, train_loss = train_epoch(
        model, train_data_loader, loss_fn, optimizer, device, scheduler, len(df_train)
    )
    print(f'Train loss {train_loss} accuracy {train_acc}')

    val_acc, val_loss = eval_model(
        model, val_data_loader, loss_fn, device, len(df_val)
    )
    print(f'Val loss {val_loss} accuracy {val_acc}')
    print()

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)

    if val_acc > best_accuracy:
        torch.save(model.state_dict(), 'best_model_state.bin')
        best_accuracy = val_acc
Epoch 1/10
----------
Train loss 0.49691140200114914 accuracy 0.8901851851851852
Val loss 0.40999091763049367 accuracy 0.9

Epoch 2/10
----------
Train loss 0.3062430267758383 accuracy 0.9349999999999999
Val loss 0.20030112245275328 accuracy 0.9650000000000001

Epoch 3/10
----------
Train loss 0.18264216477097728 accuracy 0.9603703703703703
Val loss 0.18755523634143173 accuracy 0.9650000000000001

Epoch 4/10
----------
Train loss 0.15700688022613543 accuracy 0.9693518518518518
Val loss 0.20371213133369262 accuracy 0.9633333333333334

Epoch 5/10
----------
Train loss 0.1627817107436756 accuracy 0.9668518518518519
Val loss 0.16456402061972766 accuracy 0.9683333333333334

Epoch 6/10
----------
Train loss 0.15311389193888453 accuracy 0.9721296296296296
Val loss 0.1188539441426595 accuracy 0.9783333333333334

Epoch 7/10
----------
Train loss 0.13947947008179012 accuracy 0.9734259259259259
Val loss 0.12033098526764661 accuracy 0.9783333333333334

Epoch 8/10
----------
Train loss 0.12078767392419482 accuracy 0.9781481481481481
Val loss 0.12014915000802527 accuracy 0.9733333333333334

Epoch 9/10
----------
Train loss 0.11557375699952967 accuracy 0.9751851851851852
Val loss 0.12187736847476724 accuracy 0.9766666666666667

Epoch 10/10
----------
Train loss 0.10247013699765645 accuracy 0.977037037037037
Val loss 0.11501088156461871 accuracy 0.9766666666666667
plt.plot(history['train_acc'], label='train accuracy')
plt.plot(history['val_acc'], label='validation accuracy')

plt.title('Training history')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.ylim([0, 1]);
# model = EnterpriseDangerClassifier(len(class_names))
# model.load_state_dict(torch.load('best_model_state.bin'))
# model = model.to(device)
8 Model Evaluation
test_acc, _ = eval_model(
    model,
    test_data_loader,
    loss_fn,
    device,
    len(df_test)
)

test_acc.item()
0.9716666666666667
def get_predictions(model, data_loader):
    model = model.eval()

    raw_texts = []
    predictions = []
    prediction_probs = []
    real_values = []

    with torch.no_grad():
        for d in data_loader:
            texts = d["texts"]
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)  # predicted class
            probs = F.softmax(outputs, dim=1)     # class probabilities

            raw_texts.extend(texts)
            predictions.extend(preds)
            prediction_probs.extend(probs)
            real_values.extend(targets)

    predictions = torch.stack(predictions).cpu()
    prediction_probs = torch.stack(prediction_probs).cpu()
    real_values = torch.stack(real_values).cpu()
    return raw_texts, predictions, prediction_probs, real_values
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
    model,
    test_data_loader
)
print(classification_report(y_test, y_pred, target_names=[str(label) for label in class_names]))  # classification report
              precision    recall  f1-score   support

           0       0.99      0.98      0.98       554
           1       0.81      0.83      0.82        46

    accuracy                           0.97       600
   macro avg       0.90      0.90      0.90       600
weighted avg       0.97      0.97      0.97       600
def show_confusion_matrix(confusion_matrix):
    hmap = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
    hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
    hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
    plt.ylabel('True label')
    plt.xlabel('Predicted label');

cm = confusion_matrix(y_test, y_pred)
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
show_confusion_matrix(df_cm)
idx = 2

sample_text = y_texts[idx]
true_label = y_test[idx]
pred_df = pd.DataFrame({
    'class_names': class_names,
    'values': y_pred_probs[idx]
})
print("\n".join(wrap(sample_text))) print() print(f'True label: {class_names[true_label]}')
焊锡员工未佩戴防护口罩 工业/危化品类 主要负责人、分管负责人及管理人员履职情况 分管负责人履职情况 分管负责人依法履行安全管理职责(存在职业健康危害的单位需自查职业卫生履职情况)。

True label: 0
sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
plt.ylabel('class')
plt.xlabel('probability')
plt.xlim([0, 1]);
9 Predicting on New Text
sample_text = "电源插头应按规定正确接线"
encoded_text = tokenizer.encode_plus(
    sample_text,
    max_length=MAX_LEN,
    add_special_tokens=True,
    return_token_type_ids=False,
    pad_to_max_length=True,
    return_attention_mask=True,
    return_tensors='pt',
)
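pad_to_max_length is the argument that triggered the FutureWarning shown earlier; with current transformers versions the same tensors can be produced with the non-deprecated arguments. A sketch of the equivalent call:

# Equivalent encoding using padding/truncation instead of pad_to_max_length.
encoded_text = tokenizer(
    sample_text,
    max_length=MAX_LEN,
    padding='max_length',
    truncation=True,
    return_token_type_ids=False,
    return_attention_mask=True,
    return_tensors='pt',
)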
input_ids = encoded_text['input_ids'].to(device)
attention_mask = encoded_text['attention_mask'].to(device)

output = model(input_ids, attention_mask)
_, prediction = torch.max(output, dim=1)

print(f'Sample text: {sample_text}')
print(f'Danger label : {class_names[prediction]}')
Sample text: 电源插头应按规定正确接线
Danger label : 1
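For repeated ad-hoc checks, the two cells above can be wrapped into a small helper. A minimal sketch, reusing the tokenizer, model, MAX_LEN, device and class_names defined earlier; the name predict_danger is ours, not part of the notebook:

def predict_danger(text):
    # Tokenize one string, run the fine-tuned classifier, and return the
    # predicted class together with its softmax probability.
    enc = tokenizer.encode_plus(
        text,
        max_length=MAX_LEN,
        add_special_tokens=True,
        return_token_type_ids=False,
        pad_to_max_length=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    model.eval()
    with torch.no_grad():
        logits = model(enc['input_ids'].to(device), enc['attention_mask'].to(device))
        probs = F.softmax(logits, dim=1)
    confidence, pred = torch.max(probs, dim=1)
    return class_names[pred.item()], confidence.item()

predict_danger("电源插头应按规定正确接线")  # returns (predicted label, probability)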