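Chinese News Headline Classification with Paddle 2.0

Summary: Chinese news headline classification implemented with Paddle 2.0

Project note: this project is a worked solution to an assignment from Professor Hung-yi Lee's (李宏毅) PaddlePaddle-authorized course.

Course: link

This project on AI Studio: link

Dataset: link


This project is for reference only. It offers ideas and approaches, not a standard answer. Please do not simply copy it!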


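Chinese news headline classification: Paddle 2.0 baseline (unofficial)



Unofficial, made by 三岁! (Simple, but done with care.)


Tuning tips


The baseline score of this project is not high; you will need to tune it yourself to improve the results.


Suggestions:

  • Change the model. It is currently a linear model; try a more complex architecture
    that is better suited to NLP tasks (I am not sure of the specifics myself).
  • Adjust the learning rate to search for the best result.
  • Continue training an existing model to get better results.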


Dataset URL


https://aistudio.baidu.com/aistudio/datasetdetail/75812


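Task description


Text classification on the THUCNews dataset. THUCNews was built by filtering historical data from the Sina News RSS subscription channels from 2005 to 2011 and contains 740,000 news documents. Participants must use an algorithm to decide, from the content of a news headline, which category the news item belongs to.


Data description


THUCNews was built by filtering historical data from the Sina News RSS subscription channels from 2005 to 2011. It contains 740,000 news documents (2.19 GB), all in UTF-8 plain-text format. Starting from the original Sina News taxonomy, the data was reorganized into 14 candidate categories: 财经 (finance), 彩票 (lottery), 房产 (real estate), 股票 (stocks), 家居 (home), 教育 (education), 科技 (technology), 社会 (society), 时尚 (fashion), 时政 (politics), 体育 (sports), 星座 (horoscope), 游戏 (games), 娱乐 (entertainment).


The training set has been extracted in the format "label ID + \t + label + \t + original headline", so the classification task can be run directly on the news headlines. Participants are encouraged to come up with their own solutions.


Training set format: label ID + \t + label + \t + original headline. Test set format: original headline.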
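As a rough illustration of the training-set format described above, the sketch below splits one line into its three tab-separated fields. The function name `parse_train_line` is hypothetical, and the sample line is assembled by hand from the headline and label printed later in this article; the raw file itself is not shown in the original.

```python
def parse_train_line(line):
    """Split 'label_id\\tlabel\\ttitle' into its three fields."""
    # maxsplit=2 keeps any tab characters inside the headline intact
    label_id, label, title = line.rstrip("\n").split("\t", 2)
    return int(label_id), label, title


# Hand-built example line (assumption: raw lines look like this)
sample = "0\t财经\t上证50ETF净申购突增\n"
parsed = parse_train_line(sample)
print(parsed)
```

A test-set line contains only the headline, so no splitting is needed there.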


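Submitting answers


For the exam submission, you need to submit the model code (project version) and the result file. The result file is a TXT file named result.txt, and its fields must be written in the specified format.


1. Each predicted category line must correspond one-to-one with a line of the original test set; the order must not be changed.

2. Check that the output has exactly 83599 lines; otherwise the score is invalid.

3. The output file is named result.txt, with one category per line, for example: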


游戏

财经

时政

股票

家居

科技

社会

房产

教育

星座

科技

股票

游戏

财经

时政

股票

家居

科技

社会

房产

教育

···
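The three submission rules above can be checked before uploading. The following is a minimal sketch, not an official validator: `check_result_lines` is a hypothetical helper that only verifies the line count and that every line is one of the 14 category names.

```python
CATEGORIES = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技",
              "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]
EXPECTED_LINES = 83599  # line count required by the submission rules


def check_result_lines(lines, expected=EXPECTED_LINES):
    """Return a list of problems; an empty list means the content looks valid."""
    problems = []
    if len(lines) != expected:
        problems.append("expected %d lines, found %d" % (expected, len(lines)))
    for number, line in enumerate(lines, start=1):
        if line not in CATEGORIES:
            problems.append("line %d: unknown category %r" % (number, line))
            break  # reporting the first bad line is enough
    return problems


# In-memory demonstration with a reduced expected count.
ok = check_result_lines(["游戏", "财经", "时政"], expected=3)
bad = check_result_lines(["游戏", "finance"], expected=3)
print(ok)
print(bad)
```

For a real file, the lines could be obtained with `open("result.txt", encoding="utf-8").read().splitlines()`. Note that this cannot check rule 1 (ordering); that depends on writing predictions in the same order as the test set.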


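Code approach

From the task description we can see that this is a classic NLP task.

Following the usual workflow for NLP tasks, we need the following steps:


  • Process the data and convert it into token ID sequences (the embedding layer later turns these into word vectors)
  • Build the model
  • Train on the data
  • Load the model and run inference on the data to get the results


Without further ado, let's get started!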


Extracting the dataset


! pip install -U paddlepaddle==2.0.1
! unzip -oq /home/aistudio/data/data75812/新闻文本标签分类.zip
import os
import numpy as np
import matplotlib.pyplot as plt
import paddle
import paddle.nn as nn
print(paddle.__version__)  # show the installed Paddle version
# CPU/GPU selection: pass the target device to paddle.set_device().
# device = paddle.set_device('gpu')
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
2021-03-27 12:21:25,020 - INFO - font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
2021-03-27 12:21:25,357 - INFO - generated new fontManager
2.0.1


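Data processing

First, consider how the text is turned into input for the word vectors.

We start from a dictionary (it has already been built here, so we simply load it; the steps for building it will be posted in the comments).

Using the dictionary, each headline in the dataset is mapped to a purely numeric ID sequence.

Once we have the ID sequences, we print a sample to check that the mapping is correct.

If the data is correct, we pad it: a special token ID fills each sequence so that all sequences have the same length.

Finally, we check the data length.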
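The dictionary used in this project is already built and its construction is not shown here. As a hedged illustration of the idea only, the sketch below builds a character-level dictionary from a toy corpus and encodes a headline with it. The helper names `build_char_dict` and `encode` are hypothetical, and the character-level split is an assumption based on the decoded sample printed further down.

```python
def build_char_dict(titles):
    """Map each distinct character to an integer ID, with '<unk>' as the last entry."""
    chars = sorted({ch for title in titles for ch in title})
    char_dict = {ch: idx for idx, ch in enumerate(chars)}
    char_dict["<unk>"] = len(char_dict)
    return char_dict


def encode(title, char_dict):
    """Convert a headline into a list of IDs; unseen characters map to '<unk>'."""
    unk = char_dict["<unk>"]
    return [char_dict.get(ch, unk) for ch in title]


titles = ["上证50ETF", "净申购突增"]   # toy corpus for illustration only
char_dict = build_char_dict(titles)
known = encode("上证", char_dict)
with_unknown = encode("上X", char_dict)  # 'X' is not in the toy dictionary
print(known, with_unknown)
```

Characters missing from the dictionary fall back to the `'<unk>'` ID, which is the same token this project later reuses for padding.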


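Reading the data (dictionary and dataset)


import ast
# Load the dictionary
def get_dict_len(d_path):
    with open(d_path, 'r', encoding='utf-8') as f:
        # The first line holds a Python dict literal; literal_eval parses it without executing code
        line = ast.literal_eval(f.readlines()[0])
    return line
word_dict = get_dict_len('新闻文本标签分类/dict.txt')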
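# Read the training and validation sets
# Note: `set` is a module-level list (and shadows the built-in `set`). Both calls below
# append to it and return the same object, so train_dataset and val_dataset are one
# combined list; this is why both splits report 832475 rows further down.
set = []
def dataset(datapath):  # read one ID file
    with open(datapath)as f:
        for i in f.readlines():
            data = []
            dataset = i[:i.rfind('\t')].split(',')  # comma-separated token IDs of the headline
            dataset = np.array(dataset)
            data.append(dataset)
            label = np.array(i[i.rfind('\t')+1:-1])  # label ID after the last tab
            data.append(label)
            set.append(data)
    return set
train_dataset = dataset('新闻文本标签分类/Train_IDs.txt')
val_dataset = dataset('新闻文本标签分类/Val_IDs.txt')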
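The loader above accumulates every file it reads into one shared module-level list. If separate training and validation splits are wanted, one option is to create the list inside the function. The sketch below is an alternative under that assumption, not the code that produced the results in this article; `read_id_file` and the temporary files are hypothetical.

```python
import os
import tempfile

import numpy as np


def read_id_file(path):
    """Read lines of the form '12,7,5<TAB>3' into a fresh list of (ids, label) pairs."""
    samples = []  # created on every call, so files never share state
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue
            ids_part, label_part = line.rsplit("\t", 1)
            ids = np.array([int(t) for t in ids_part.split(",")], dtype="int64")
            samples.append((ids, int(label_part)))
    return samples


# Demonstration with two small made-up files.
with tempfile.TemporaryDirectory() as tmp:
    train_path = os.path.join(tmp, "train_ids.txt")
    val_path = os.path.join(tmp, "val_ids.txt")
    with open(train_path, "w", encoding="utf-8") as f:
        f.write("1,2,3\t0\n4,5\t6\n")
    with open(val_path, "w", encoding="utf-8") as f:
        f.write("7,8,9,10\t13\n")
    demo_train = read_id_file(train_path)
    demo_val = read_id_file(val_path)

print(len(demo_train), len(demo_val))  # 2 1
```

With this variant the row counts, training log and validation accuracy would differ from the ones recorded in this article.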


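Data initialization


Define the values we will need.


# Initial settings
vocab_size = len(word_dict) + 1  # dictionary size plus 1
print(vocab_size)
emb_size = 256  # embedding dimension
seq_len = 30  # fixed sequence length (shorter headlines are padded, longer ones truncated)
batch_size = 32  # batch size
epochs = 2  # number of training epochs
pad_id = word_dict['<unk>']  # ID used for padding (the '<unk>' token)
nu=["财经","彩票","房产","股票","家居","教育","科技","社会","时尚","时政","体育","星座","游戏","娱乐"]  # category names, indexed by label ID
# Convert a list of IDs back to text
def ids_to_str(ids):
    # print(ids)
    words = []
    for k in ids:
        w = list(word_dict)[int(k)]  # assumes the dictionary keys are stored in ID order
        words.append(w if isinstance(w, str) else w.decode('ASCII'))
    return " ".join(words)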
5308


Inspecting the data


Check that the data is correct, and fix any anomalies right away.

# Inspect one sample
for i in  train_dataset:
    sent = i[0]
    label = int(i[1])
    print('sentence list id is:', sent)  # token IDs
    print('sentence label id is:', label)  # label ID
    print('--------------------------')  # separator
    print('sentence list is: ', ids_to_str(sent))  # IDs converted back to text
    print('sentence label is: ', nu[label])  # label ID converted to category name
    break


sentence list id is: ['2976' '385' '2050' '3757' '1147' '3296' '1585' '688' '1180' '2608'
 '4280' '1887']
sentence label id is: 0
--------------------------
sentence list is:  上 证 5 0 E T F 净 申 购 突 增
sentence label is:  财经


Padding the data


Pad all sequences to the same length.

# Pad the data and inspect it
def create_padded_dataset(dataset):
    padded_sents = []
    labels = []
    for batch_id, data in enumerate(dataset):  # iterate over samples
        sent, label = data[0], data[1]  # split token IDs and label
        padded_sent = np.concatenate([sent[:seq_len], [pad_id] * (seq_len - len(sent))]).astype('int32')  # truncate to seq_len, then pad with pad_id
        # print(padded_sent)
        padded_sents.append(padded_sent)  # store the padded IDs
        labels.append(label)  # store the label
    # print(padded_sents)
    return np.array(padded_sents), np.array(labels).astype('int64')  # return as arrays
# Build the padded train and val arrays
train_sents, train_labels = create_padded_dataset(train_dataset)  # training set
val_sents, val_labels = create_padded_dataset(val_dataset)  # validation set
train_labels = train_labels.reshape(-1, 1)  # labels as a column of shape (N, 1)
val_labels = val_labels.reshape(-1, 1)
# Check the array shapes
print(train_sents.shape)
print(train_labels.shape)
print(val_sents.shape)
print(val_labels.shape)


(832475, 30)
(832475, 1)
(832475, 30)
(832475, 1)
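To make the padding step concrete, here is a small standalone sketch of the same truncate-then-pad idea on plain integer lists. `pad_or_truncate` is a hypothetical helper, and `pad_id=0` is used only for the demonstration (the project pads with the `'<unk>'` ID).

```python
import numpy as np


def pad_or_truncate(ids, seq_len, pad_id):
    """Return an int32 array of exactly seq_len IDs."""
    ids = np.asarray(ids, dtype="int32")[:seq_len]   # cut off anything past seq_len
    out = np.full(seq_len, pad_id, dtype="int32")    # start from all padding
    out[:len(ids)] = ids                             # copy the real IDs to the front
    return out


short = pad_or_truncate([5, 6, 7], seq_len=5, pad_id=0)
long_ = pad_or_truncate([1, 2, 3, 4, 5, 6, 7], seq_len=5, pad_id=0)
print(short.tolist())  # [5, 6, 7, 0, 0]
print(long_.tolist())  # [1, 2, 3, 4, 5]
```

Every output has the same length, which is what allows the samples to be stacked into a single `(N, seq_len)` array.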


Wrapping the data


Subclass paddle.io.Dataset to wrap the data and produce a format that can be used for training.

# Subclass paddle.io.Dataset to serve the data
class IMDBDataset(paddle.io.Dataset):
    '''
    Wraps the padded sentences and labels as a paddle.io.Dataset
    '''
    def __init__(self, sents, labels):
        # store the arrays
        self.sents = sents
        self.labels = labels
    def __getitem__(self, index):
        # return one (data, label) pair
        data = self.sents[index]
        label = self.labels[index]
        return data, label
    def __len__(self):
        # number of samples
        return len(self.sents)
# Instantiate the datasets
train_dataset = IMDBDataset(train_sents, train_labels)
val_dataset = IMDBDataset(val_sents, val_labels)
# Wrap them in data loaders
train_loader = paddle.io.DataLoader(train_dataset, return_list=True,
                                    shuffle=True, batch_size=batch_size, drop_last=True)
val_loader = paddle.io.DataLoader(val_dataset, return_list=True,
                                    shuffle=True, batch_size=batch_size, drop_last=True)


# Inspect the content and shape of one batch from each loader
for i in train_loader:
    print(i)
    break
for j in val_loader:
    print(j)
    break
[Tensor(shape=[32, 30], dtype=int32, place=CPUPlace, stop_gradient=True,
       [[4041, 4370, 3449, 3536, 103 , 2896, 4133, 312 , 1974, 3933, 2380, 805 , 3956, 4805, 3129, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1440, 3740, 1169, 2663, 4401, 4591, 4874, 2734, 989 , 1980, 5016, 450 , 335 , 1562, 2543, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [580 , 3844, 3513, 1231, 4111, 1894, 737 , 1318, 3536, 4805, 3956, 4075, 141 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2573, 536 , 1230, 3757, 610 , 2018, 1974, 39  , 1629, 121 , 4625, 294 , 450 , 1991, 3149, 4389, 1146, 1736, 588 , 3388, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4829, 4419, 3415, 1230, 4910, 3814, 1876, 3509, 1592, 5059, 2207, 2139, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1546, 1221, 1117, 4386, 3449, 1562, 2088, 4770, 1299, 4500, 41  , 2976, 725 , 1006, 2053, 897 , 2315, 3786, 2559, 828 , 3682, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2185, 4673, 1546, 2991, 1120, 5025, 782 , 5025, 1674, 3717, 1006, 2099, 4807, 78  , 4749, 1932, 5283, 1375, 4725, 3185, 2358, 2100, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2140, 4935, 3388, 278 , 3287, 4059, 775 , 1304, 4315, 698 , 3375, 3966, 3980, 1472, 1472, 2140, 4935, 3388, 5303, 939 , 2336, 2830, 4711, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [5072, 2886, 4647, 3957, 5276, 2139, 4646, 5053, 4073, 4954, 1006, 4038, 2896, 3886, 756 , 4289, 2700, 4242, 4954, 2018, 2336, 2412, 2764, 4711, 5306, 5306, 5306, 5306, 5306, 5306],
        [1546, 1231, 1230, 385 , 4774, 5269, 939 , 2845, 1147, 2358, 3947, 4774, 872 , 1592, 2896, 123 , 5059, 1177, 3947, 4191, 4841, 754 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1180, 2646, 2155, 2776, 2886, 1257, 2302, 2748, 39  , 1230, 478 , 1006, 1425, 2263, 1278, 5078, 959 , 5102, 4578, 671 , 3430, 4954, 4910, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2891, 5257, 4426, 4932, 189 , 1695, 1347, 1724, 4328, 3344, 1688, 3449, 5115, 379 , 1347, 2244, 5216, 3070, 5072, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [67  , 2788, 2873, 898 , 4207, 1347, 12  , 372 , 1737, 1006, 3468, 383 , 1836, 5115, 4608, 4790, 1620, 760 , 3313, 2244, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2099, 4807, 3379, 200 , 3933, 472 , 4415, 312 , 2078, 3222, 44  , 3222, 3924, 2373, 3398, 643 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2311, 3967, 720 , 2014, 2873, 311 , 4346, 2961, 4401, 725 , 1425, 1006, 1505, 3430, 4647, 926 , 4554, 4702, 4246, 2358, 3115, 5279, 123 , 1230, 679 , 5306, 5306, 5306, 5306, 5306],
        [1521, 2571, 1079, 4554, 1070, 534 , 2088, 2140, 5229, 1425, 3242, 846 , 3933, 3714, 99  , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2916, 123 , 1844, 5059, 123 , 1747, 3040, 1006, 5205, 1688, 1347, 601 , 3041, 3144, 3269, 4059, 2986, 4863, 1006, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3888, 2153, 4813, 3053, 1741, 1648, 2757, 1177, 2033, 2991, 5283, 123 , 2779, 2651, 1053, 1522, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4444, 5283, 1138, 3114, 3890, 3489, 1028, 3717, 936 , 389 , 2886, 2031, 316 , 3187, 2031, 2623, 643 , 4911, 3468, 1253, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3740, 2925, 3023, 2851, 4389, 3092, 3576, 725 , 1736, 2300, 3114, 1006, 2122, 1076, 3973, 3092, 3951, 2664, 1059, 3440, 415 , 3099, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1441, 312 , 134 , 4697, 1896, 1449, 3973, 4955, 3449, 1498, 1199, 2032, 2359, 4822, 1006, 4883, 4389, 4038, 4552, 4509, 2347, 690 , 1094, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2777, 422 , 1902, 2428, 621 , 3313, 3973, 5014, 5140, 3086, 4822, 1006, 3809, 3305, 3343, 5161, 1230, 1995, 3684, 954 , 2336, 2830, 4711, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [813 , 903 , 4554, 3449, 1195, 3790, 4067, 1932, 2347, 3082, 4625, 2061, 3191, 992 , 1006, 1819, 3040, 4650, 1395, 729 , 5125, 5202, 2939, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3952, 3493, 385 , 225 , 3449, 1613, 4822, 3534, 3191, 2896, 3927, 698 , 3375, 1006, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1230, 831 , 1347, 2244, 1588, 3813, 2044, 3094, 1076, 4626, 1006, 1231, 1230, 3853, 4366, 2511, 2605, 3726, 5303, 939 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [5052, 2293, 3449, 3446, 1094, 2976, 4922, 2099, 1221, 4034, 1290, 3323, 3430, 3099, 4109, 4579, 1006, 1713, 3058, 4370, 1613, 4191, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1359, 4922, 2748, 3933, 2099, 397 , 2858, 1006, 4438, 221 , 611 , 4159, 2642, 939 , 4784, 664 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4554, 1667, 477 , 2891, 1819, 2354, 1819, 3040, 1006, 2873, 898 , 3740, 1408, 2176, 3371, 123 , 5151, 2886, 3040, 1275, 2336, 2830, 4711, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4746, 3242, 5010, 3430, 2401, 4426, 4373, 1695, 2776, 775 , 1006, 1502, 3952, 2428, 1935, 3687, 809 , 416 , 1503, 4500, 1854, 2352, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2351, 3287, 3813, 2032, 4554, 1519, 1655, 4038, 3951, 2958, 2886, 2140, 1006, 4246, 4536, 3449, 1476, 2572, 4207, 4401, 1505, 2953, 3468, 377 , 5306, 5306, 5306, 5306, 5306, 5306],
        [3712, 3583, 3973, 2312, 4426, 3305, 2979, 1897, 3513, 4059, 1695, 1006, 5293, 4382, 2199, 1076, 4412, 3559, 1215, 2640, 1343, 4785, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2830, 2567, 1472, 134 , 3040, 1275, 3951, 377 , 420 , 1753, 1598, 690 , 3682, 4500, 1006, 3135, 3853, 4862, 3253, 377 , 2263, 5105, 3060, 5306, 5306, 5306, 5306, 5306, 5306, 5306]]), Tensor(shape=[32, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[3 ],
        [6 ],
        [3 ],
        [5 ],
        [6 ],
        [3 ],
        [6 ],
        [2 ],
        [10],
        [3 ],
        [10],
        [6 ],
        [13],
        [6 ],
        [10],
        [6 ],
        [6 ],
        [4 ],
        [9 ],
        [10],
        [10],
        [13],
        [10],
        [3 ],
        [0 ],
        [3 ],
        [3 ],
        [13],
        [13],
        [10],
        [13],
        [10]])]
[Tensor(shape=[32, 30], dtype=int32, place=CPUPlace, stop_gradient=True,
       [[2607, 5278, 1979, 2932, 40  , 2813, 2361, 3114, 4111, 3099, 1221, 103 , 2079, 3951, 2050, 3757, 141 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2050, 2751, 3403, 1214, 516 , 1006, 4059, 2125, 2380, 233 , 1521, 805 , 366 , 2336, 2176, 2830, 4711, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2312, 487 , 2185, 4832, 4426, 2099, 1811, 1695, 1413, 4813, 3053, 3222, 4523, 3820, 2143, 1020, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3740, 4334, 377 , 1299, 4062, 4442, 536 , 3487, 3398, 4863, 1850, 4480, 1006, 2896, 4673, 2776, 1230, 3114, 3786, 4442, 3507, 1902, 2428, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1445, 4075, 1006, 610 , 805 , 3757, 3634, 2453, 1521, 736 , 1661, 1394, 4874, 3822, 1006, 421 , 3424, 3296, 610 , 610 , 3757, 316 , 4863, 3702, 2192, 5306, 5306, 5306, 5306, 5306],
        [2185, 2685, 4863, 5257, 3430, 2813, 2233, 684 , 846 , 892 , 1006, 3593, 3966, 3951, 4343, 2079, 892 , 4352, 4242, 3091, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3131, 1809, 2052, 4359, 3449, 1199, 2401, 1441, 2768, 4073, 1724, 4191, 1301, 3956, 3757, 2050, 2751, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [5032, 410 , 4835, 3449, 2099, 44  , 989 , 4073, 1724, 4191, 1521, 1521, 3642, 2751, 1006, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2050, 3757, 3880, 4945, 2515, 1112, 4224, 1282, 3379, 4477, 834 , 2013, 4874, 3823, 617 , 1090, 4060, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3886, 2843, 2412, 1722, 1230, 3092, 4197, 1006, 699 , 1839, 380 , 1834, 1521, 3757, 1631, 4237, 518 , 3813, 2768, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1440, 3740, 3520, 2832, 3888, 2886, 1993, 3952, 2427, 1215, 2550, 4248, 4328, 4099, 5103, 2337, 3468, 4456, 3191, 4062, 5072, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [802 , 97  , 1876, 2768, 4191, 4785, 1318, 3991, 1006, 3165, 4191, 3509, 1318, 4504, 736 , 3757, 3757, 3757, 3757, 3375, 5019, 4959, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2233, 4370, 1347, 726 , 3886, 3142, 3259, 260 , 1445, 746 , 3238, 1025, 332 , 993 , 1006, 1301, 1661, 2845, 1836, 5115, 3738, 2199, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [5267, 4953, 1472, 1876, 2873, 3951, 377 , 4841, 754 , 125 , 3224, 1006, 3951, 3967, 2983, 2886, 4038, 5135, 684 , 123 , 1521, 1301, 2846, 389 , 4841, 5306, 5306, 5306, 5306, 5306],
        [1993, 1837, 5281, 3992, 1425, 3740, 224 , 804 , 3534, 3191, 2099, 4807, 3735, 5067, 1006, 4449, 2375, 2375, 4945, 2515, 2436, 1253, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [224 , 4382, 3379, 200 , 3449, 1230, 3996, 805 , 141 , 1006, 3379, 200 , 4576, 2680, 3430, 3042, 1081, 3537, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2820, 682 , 2825, 2759, 1230, 294 , 4389, 3069, 3355, 2896, 1215, 2825, 4222, 3244, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4274, 3740, 3413, 3449, 134 , 377 , 603 , 3886, 2873, 123 , 4289, 3020, 1230, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1293, 3714, 40  , 1472, 1094, 1440, 1669, 3966, 2756, 4432, 1521, 4591, 1094, 4591, 2052, 1006, 4951, 1418, 4019, 1425, 3740, 4775, 1839, 3430, 738 , 5306, 5306, 5306, 5306, 5306],
        [1164, 2453, 1185, 4162, 3430, 1546, 3740, 3398, 2052, 3559, 1221, 2050, 3956, 3757, 805 , 141 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2312, 536 , 1747, 3164, 2986, 542 , 3023, 3907, 1006, 4456, 4009, 3296, 3634, 1521, 3757, 5059, 736 , 736 , 3757, 3757, 2751, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3142, 4886, 3430, 4954, 5177, 4242, 4382, 3952, 4931, 795 , 1006, 2099, 2886, 4651, 1562, 2986, 2155, 1521, 4591, 3966, 601 , 3041, 2151, 377 , 5306, 5306, 5306, 5306, 5306, 5306],
        [3400, 872 , 1893, 3016, 3933, 2263, 2781, 3114, 692 , 3222, 1620, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2176, 2830, 3449, 3668, 3131, 3402, 2727, 224 , 264 , 4370, 4389, 1318, 1641, 2932, 1940, 4805, 2886, 4207, 4225, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2549, 3101, 2099, 690 , 4111, 3682, 3537, 534 , 4167, 3137, 4954, 1006, 1785, 3869, 823 , 3924, 3473, 3881, 927 , 730 , 592 , 476 , 3207, 241 , 5306, 5306, 5306, 5306, 5306, 5306],
        [4227, 1562, 4027, 4954, 1521, 610 , 375 , 3889, 2896, 2239, 4370, 4141, 3000, 56  , 1006, 4697, 200 , 269 , 926 , 1413, 4540, 5238, 1017, 3468, 2014, 964 , 5306, 5306, 5306, 5306],
        [2358, 1025, 1708, 993 , 332 , 1006, 1862, 1006, 2358, 1025, 3956, 2079, 1709, 720 , 3676, 4050, 3357, 1472, 2941, 2254, 2412, 1029, 3222, 1725, 1028, 3165, 5306, 5306, 5306, 5306],
        [2820, 682 , 3191, 3440, 1146, 3174, 4328, 2982, 2825, 2759, 1117, 3069, 3355, 617 , 2813, 2742, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2612, 3069, 1214, 3951, 1521, 3956, 4805, 3115, 1314, 2050, 3757, 3757, 366 , 1006, 1534, 2401, 5202, 1521, 4805, 366 , 3157, 2336, 2830, 4711, 5306, 5306, 5306, 5306, 5306, 5306],
        [4057, 62  , 1765, 531 , 1991, 3149, 5269, 736 , 3757, 1521, 736 , 1991, 3149, 4389, 2018, 4389, 2253, 1694, 4073, 1200, 5116, 4073, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4746, 4283, 5295, 3449, 2099, 1794, 376 , 2040, 1663, 3564, 3187, 2986, 1006, 4111, 2896, 690 , 1117, 2776, 4500, 2078, 2040, 698 , 1214, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4289, 2166, 698 , 2100, 1006, 1343, 1681, 1094, 4863, 123 , 5162, 384 , 61  , 2380, 1645, 3388, 2336, 736 , 4711, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306]]), Tensor(shape=[32, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[6 ],
        [8 ],
        [4 ],
        [10],
        [6 ],
        [5 ],
        [3 ],
        [3 ],
        [6 ],
        [6 ],
        [13],
        [0 ],
        [12],
        [10],
        [13],
        [3 ],
        [7 ],
        [9 ],
        [10],
        [3 ],
        [6 ],
        [10],
        [4 ],
        [9 ],
        [10],
        [10],
        [3 ],
        [7 ],
        [7 ],
        [5 ],
        [10],
        [8 ]])]


Network definition


Define the network used for training. This part is one of the keys to improving the score.

# Define the network
class MyNet(paddle.nn.Layer):
    def __init__(self):
        super(MyNet, self).__init__()
        self.emb = paddle.nn.Embedding(vocab_size, emb_size)  # embedding layer: a learned (vocab_size, emb_size) lookup matrix
        self.fc = paddle.nn.Linear(in_features=emb_size, out_features=96)  # linear layer
        self.fc1 = paddle.nn.Linear(in_features=96, out_features=14)  # classifier over the 14 categories
        self.dropout = paddle.nn.Dropout(0.5)  # dropout for regularization
    def forward(self, x):
        x = self.emb(x)
        x = paddle.mean(x, axis=1)  # average the embeddings over the sequence dimension
        x = self.dropout(x)
        x = self.fc(x)
        x = self.dropout(x)
        x = self.fc1(x)
        return x
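The `paddle.mean(x, axis=1)` step in `forward` averages over all 30 positions, padding included. One possible refinement, shown here as a NumPy sketch and not as part of the original network, is a masked mean that ignores padding positions. `masked_mean` is a hypothetical helper, and the demo uses 0 as the padding ID.

```python
import numpy as np


def masked_mean(embedded, token_ids, pad_id):
    """Average embeddings over non-padding positions only.

    embedded:  float array of shape (batch, seq_len, emb_size)
    token_ids: int array of shape (batch, seq_len)
    """
    mask = (token_ids != pad_id).astype(embedded.dtype)        # 1.0 for real tokens
    summed = (embedded * mask[:, :, None]).sum(axis=1)         # sum of real-token vectors
    counts = np.maximum(mask.sum(axis=1, keepdims=True), 1.0)  # avoid division by zero
    return summed / counts


token_ids = np.array([[3, 4, 0, 0]])  # two real tokens followed by two pads
embedded = np.array([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0], [9.0, 9.0]]])
pooled_masked = masked_mean(embedded, token_ids, pad_id=0)
pooled_plain = embedded.mean(axis=1)
print(pooled_masked)  # [[2. 3.]]
print(pooled_plain)   # [[5.5 6. ]]
```

Because this project pads with the `'<unk>'` ID, a mask built this way would also hide genuine unknown characters; a dedicated padding ID would avoid that.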


# Plotting helper
def draw_process(title,color,iters,data,label):
    plt.title(title, fontsize=24)  # title
    plt.xlabel("iter", fontsize=20)  # x-axis label
    plt.ylabel(label, fontsize=20)  # y-axis label
    plt.plot(iters, data,color=color,label=label)   # draw the curve
    plt.legend()
    plt.grid()
    plt.show()


Model training


This is the key stage of training. You can adjust the learning rate, the optimizer, and so on, which can sometimes make a surprising difference.


# Train the model
def train(model):
    model.train()
    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())  # optimizer and learning rate
    # initial values
    steps = 0
    Iters, total_loss, total_acc = [], [], []
    for epoch in range(epochs):  # epoch loop
        for batch_id, data in enumerate(train_loader):  # batch loop
            steps += 1
            sent = data[0]  # token IDs
            label = data[1]  # labels
            logits = model(sent)  # forward pass
            loss = paddle.nn.functional.cross_entropy(logits, label)  # loss
            acc = paddle.metric.accuracy(logits, label)  # accuracy
            if batch_id % 500 == 0:  # log every 500 batches
                Iters.append(steps)  # record the step count
                total_loss.append(loss.numpy()[0])  # record the loss
                total_acc.append(acc.numpy()[0])  # record the accuracy
                print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))  # print progress
            # parameter update
            loss.backward()
            opt.step()
            opt.clear_grad()
        # evaluate once per epoch
        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(val_loader):  # iterate over validation batches
            sent = data[0]  # token IDs
            label = data[1]  # labels
            logits = model(sent)  # forward pass
            loss = paddle.nn.functional.cross_entropy(logits, label)  # loss
            acc = paddle.metric.accuracy(logits, label)  # accuracy
            accuracies.append(acc.numpy())  # collect per-batch metrics
            losses.append(loss.numpy())
        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)  # mean accuracy and loss
        print("[validation] accuracy: {}, loss: {}".format(avg_acc, avg_loss))  # print validation metrics
        model.train()
        paddle.save(model.state_dict(),str(epoch)+"_model_final.pdparams")  # save a checkpoint for this epoch
    draw_process("training loss","red",Iters,total_loss,"training loss")  # plot the loss curve
    draw_process("training acc","green",Iters,total_acc,"training acc")  # plot the accuracy curve
model = MyNet()  # instantiate the model
train(model)  # start training
epoch: 0, batch_id: 0, loss is: [2.6477456]
epoch: 0, batch_id: 500, loss is: [1.8056118]
epoch: 0, batch_id: 1000, loss is: [1.1092072]
epoch: 0, batch_id: 1500, loss is: [1.0716103]
epoch: 0, batch_id: 2000, loss is: [0.6794955]
epoch: 0, batch_id: 2500, loss is: [0.54738545]
epoch: 0, batch_id: 3000, loss is: [0.9065808]
epoch: 0, batch_id: 3500, loss is: [0.63474274]
epoch: 0, batch_id: 4000, loss is: [0.68158776]
epoch: 0, batch_id: 4500, loss is: [1.0516238]
epoch: 0, batch_id: 5000, loss is: [0.9118046]
epoch: 0, batch_id: 5500, loss is: [0.65075576]
epoch: 0, batch_id: 6000, loss is: [0.5605841]
epoch: 0, batch_id: 6500, loss is: [0.56175774]
epoch: 0, batch_id: 7000, loss is: [0.95122683]
epoch: 0, batch_id: 7500, loss is: [0.38649452]
epoch: 0, batch_id: 8000, loss is: [0.2205698]
epoch: 0, batch_id: 8500, loss is: [0.40474647]
epoch: 0, batch_id: 9000, loss is: [0.5931748]
epoch: 0, batch_id: 9500, loss is: [0.3922717]
epoch: 0, batch_id: 10000, loss is: [0.6130478]
epoch: 0, batch_id: 10500, loss is: [0.5300909]
epoch: 0, batch_id: 11000, loss is: [0.6114788]
epoch: 0, batch_id: 11500, loss is: [0.24966809]
epoch: 0, batch_id: 12000, loss is: [0.45669073]
epoch: 0, batch_id: 12500, loss is: [0.29746443]
epoch: 0, batch_id: 13000, loss is: [0.6775298]
epoch: 0, batch_id: 13500, loss is: [0.8836371]
epoch: 0, batch_id: 14000, loss is: [0.27501673]
epoch: 0, batch_id: 14500, loss is: [0.46843478]
epoch: 0, batch_id: 15000, loss is: [0.49367175]
epoch: 0, batch_id: 15500, loss is: [0.500063]
epoch: 0, batch_id: 16000, loss is: [0.31290954]
epoch: 0, batch_id: 16500, loss is: [0.30774388]
epoch: 0, batch_id: 17000, loss is: [0.21738727]
epoch: 0, batch_id: 17500, loss is: [0.2860858]
epoch: 0, batch_id: 18000, loss is: [0.2766972]
epoch: 0, batch_id: 18500, loss is: [0.36017033]
epoch: 0, batch_id: 19000, loss is: [0.43986273]
epoch: 0, batch_id: 19500, loss is: [0.4210134]
epoch: 0, batch_id: 20000, loss is: [0.579644]
epoch: 0, batch_id: 20500, loss is: [0.23016676]
epoch: 0, batch_id: 21000, loss is: [0.21913218]
epoch: 0, batch_id: 21500, loss is: [0.18669227]
epoch: 0, batch_id: 22000, loss is: [0.31480896]
epoch: 0, batch_id: 22500, loss is: [0.37621552]
epoch: 0, batch_id: 23000, loss is: [0.54980826]
epoch: 0, batch_id: 23500, loss is: [0.6016808]
epoch: 0, batch_id: 24000, loss is: [0.25056183]
epoch: 0, batch_id: 24500, loss is: [0.2916811]
epoch: 0, batch_id: 25000, loss is: [0.33430776]
epoch: 0, batch_id: 25500, loss is: [0.74600095]
epoch: 0, batch_id: 26000, loss is: [0.35165167]
[validation] accuracy: 0.884321928024292, loss: 0.3713749647140503
epoch: 1, batch_id: 0, loss is: [0.47405708]
epoch: 1, batch_id: 500, loss is: [0.4443894]
epoch: 1, batch_id: 1000, loss is: [0.35416052]
epoch: 1, batch_id: 1500, loss is: [0.3004715]
epoch: 1, batch_id: 2000, loss is: [0.59477925]
epoch: 1, batch_id: 2500, loss is: [0.5639044]
epoch: 1, batch_id: 3000, loss is: [0.40286714]
epoch: 1, batch_id: 3500, loss is: [0.5387965]
epoch: 1, batch_id: 4000, loss is: [0.11766122]
epoch: 1, batch_id: 4500, loss is: [0.68849707]
epoch: 1, batch_id: 5000, loss is: [0.83928466]
epoch: 1, batch_id: 5500, loss is: [0.2867105]
epoch: 1, batch_id: 6000, loss is: [0.20924558]
epoch: 1, batch_id: 6500, loss is: [0.5582311]
epoch: 1, batch_id: 7000, loss is: [0.63174886]
epoch: 1, batch_id: 7500, loss is: [0.318484]
epoch: 1, batch_id: 8000, loss is: [0.5406461]
epoch: 1, batch_id: 8500, loss is: [0.4790561]
epoch: 1, batch_id: 9000, loss is: [0.52266514]
epoch: 1, batch_id: 9500, loss is: [0.51126254]
epoch: 1, batch_id: 10000, loss is: [0.27308795]
epoch: 1, batch_id: 10500, loss is: [0.22041513]
epoch: 1, batch_id: 11000, loss is: [0.32234907]
epoch: 1, batch_id: 11500, loss is: [0.6857507]
epoch: 1, batch_id: 12000, loss is: [0.40997463]
epoch: 1, batch_id: 12500, loss is: [0.53966033]
epoch: 1, batch_id: 13000, loss is: [0.2620927]
epoch: 1, batch_id: 13500, loss is: [0.21417136]
epoch: 1, batch_id: 14000, loss is: [0.5232475]
epoch: 1, batch_id: 14500, loss is: [0.37579858]
epoch: 1, batch_id: 15000, loss is: [0.3611152]
epoch: 1, batch_id: 15500, loss is: [0.336707]
epoch: 1, batch_id: 16000, loss is: [0.2795578]
epoch: 1, batch_id: 16500, loss is: [0.54298353]
epoch: 1, batch_id: 17000, loss is: [0.26425135]
epoch: 1, batch_id: 17500, loss is: [0.52595145]
epoch: 1, batch_id: 18000, loss is: [0.24938256]
epoch: 1, batch_id: 18500, loss is: [0.30653632]
epoch: 1, batch_id: 19000, loss is: [0.58400965]
epoch: 1, batch_id: 19500, loss is: [0.18243803]
epoch: 1, batch_id: 20000, loss is: [0.28917578]
epoch: 1, batch_id: 20500, loss is: [1.0765818]
epoch: 1, batch_id: 21000, loss is: [0.32550114]
epoch: 1, batch_id: 21500, loss is: [0.16792971]
epoch: 1, batch_id: 22000, loss is: [0.65214527]
epoch: 1, batch_id: 22500, loss is: [0.58119446]
epoch: 1, batch_id: 23000, loss is: [0.43643892]
epoch: 1, batch_id: 23500, loss is: [0.47376677]
epoch: 1, batch_id: 24000, loss is: [0.3279624]
epoch: 1, batch_id: 24500, loss is: [0.50899947]
epoch: 1, batch_id: 25000, loss is: [0.61989105]
epoch: 1, batch_id: 25500, loss is: [0.42433214]
epoch: 1, batch_id: 26000, loss is: [0.26673254]
[validation] accuracy: 0.8882260322570801, loss: 0.35311153531074524
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data
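A quick sanity check on the log above: the first logged loss, 2.6477456, is close to ln 14 ≈ 2.639, which is the cross-entropy of a uniform guess over 14 classes. The NumPy sketch below computes softmax cross-entropy from raw logits to show where that number comes from. It is an illustration, not Paddle's implementation, and `softmax_cross_entropy` is a hypothetical helper.

```python
import numpy as np


def softmax_cross_entropy(logits, label):
    """Cross-entropy of one sample, computed from raw logits."""
    shifted = logits - np.max(logits)                       # stabilise the exponentials
    log_probs = shifted - np.log(np.sum(np.exp(shifted)))   # log-softmax
    return float(-log_probs[label])


num_classes = 14
uniform_logits = np.zeros(num_classes)   # every class equally likely
baseline = softmax_cross_entropy(uniform_logits, label=3)
print(round(baseline, 4))  # 2.6391

confident_logits = np.zeros(num_classes)
confident_logits[3] = 5.0                # strong preference for the true class
confident = softmax_cross_entropy(confident_logits, label=3)
print(confident < baseline)  # True
```

A loss well above this baseline at the start of training would usually point to a problem with labels or initialization.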


Reading the inference data


# Read the competition (test) data
set = []
def dataset(datapath):
    with open(datapath)as f:  # open the file
        for i in f.readlines():  # read line by line
            dataset = np.array(i.split(','))  # split into token IDs
            set.append(dataset)  # store the sample
    return set
# Pad the competition data
def create_padded_dataset(dataset):
    padded_sents = []
    labels = []
    for batch_id, data in enumerate(dataset):  # iterate over samples
        # print(data)
        sent = data  # token IDs
        padded_sent = np.concatenate([sent[:seq_len], [pad_id] * (seq_len - len(sent))]).astype('int32')  # truncate and pad
        # print(padded_sent)
        padded_sents.append(padded_sent)  # store the padded IDs
    # print(padded_sents)
    return np.array(padded_sents)  # return as an array
test_data = dataset('新闻文本标签分类/Test_IDs.txt')  # read the test data
# print()
# Pad the test data
test_data = create_padded_dataset(test_data)  # padding
# Inspect the padded data
print(test_data)


[[4057 1902 1475 ... 5306 5306 5306]
 [2805 5242 3593 ... 5306 5306 5306]
 [1836 3222 4641 ... 5306 5306 5306]
 ...
 [4838 1202 1490 ... 5306 5306 5306]
 [ 805 3757 3757 ... 5306 5306 5306]
 [2805 5242 3593 ... 5306 5306 5306]]


Running inference


Here you can pick the checkpoint that performed best and use it for prediction.
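Picking the checkpoint can be done by comparing the validation accuracies printed during training. The sketch below uses the two values from the training log in this article and the file-name pattern from the `paddle.save` call; the dictionary itself is entered by hand for illustration.

```python
# Validation accuracy per epoch, copied from the training log.
val_accuracy = {0: 0.884321928024292, 1: 0.8882260322570801}

# The epoch with the highest validation accuracy wins.
best_epoch = max(val_accuracy, key=val_accuracy.get)
best_checkpoint = str(best_epoch) + "_model_final.pdparams"
print(best_checkpoint)  # 1_model_final.pdparams
```

By this measure the epoch 1 checkpoint scores slightly higher than the epoch 0 checkpoint.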

nu=["财经","彩票","房产","股票","家居","教育","科技","社会","时尚","时政","体育","星座","游戏","娱乐"]  # category names, indexed by label ID
# Load the model
model_state_dict = paddle.load('0_model_final.pdparams')  # load the saved weights (epoch 0 checkpoint)
model = MyNet()  # build the network
model.set_state_dict(model_state_dict)
model.eval()
# print(type(test_data[0]))
count = 0  # number of predictions written
with open('./result.txt', 'w', encoding='utf-8') as f_train:  # create the result file
    for batch_id, data in enumerate(test_data):  # iterate over test samples
        results = model(paddle.to_tensor(data.reshape(1, seq_len)))  # run inference on a batch of one headline, shape (1, seq_len)
        for probs in results:
            # map the prediction to its category name
            idx = int(np.argmax(probs.numpy()))  # index of the highest score
            labels = nu[idx]  # look up the category name
            f_train.write(labels+"\n")  # write one line
            count +=1
            break
        if count%500==0:  # progress report
            print(count)
print(count)
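The loop above sends one headline at a time through the model. Feeding several headlines per call is a common way to speed this up while keeping the output order unchanged. The sketch below shows only the batching logic in NumPy; `iter_batches` is a hypothetical helper and `fake_predict` is a made-up stand-in for the model, not a Paddle API.

```python
import numpy as np


def iter_batches(array, batch_size):
    """Yield consecutive slices of at most batch_size rows, in order."""
    for start in range(0, len(array), batch_size):
        yield array[start:start + batch_size]


def fake_predict(batch):
    """Hypothetical stand-in for the model: deterministic dummy scores."""
    # One row of 14 scores per input row; the "class" is the first token ID modulo 14.
    scores = np.zeros((len(batch), 14))
    scores[np.arange(len(batch)), batch[:, 0] % 14] = 1.0
    return scores


demo = np.arange(70).reshape(7, 10)   # 7 fake padded headlines of length 10
predicted = []
for batch in iter_batches(demo, batch_size=3):
    predicted.extend(np.argmax(fake_predict(batch), axis=1).tolist())
print(predicted)  # one prediction per input row, in the original order
```

With a real model, `fake_predict(batch)` would be replaced by a forward pass on a tensor built from `batch`, and each predicted index would be mapped to its category name before being written to result.txt.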


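The results may not be great, but the code runs end to end. If you have other needs, contact me (leave a comment or @ me in the group) and I will keep improving the project.

Thanks for your support!


Reputedly the worst coder in the PaddlePaddle community; let's keep working hard together!

Remember: anything made by 三岁 is a quality product (shameless series)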
