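Project notes: this project is a walkthrough of an assignment from Professor Hung-yi Lee's (李宏毅) PaddlePaddle-authorized course.
Course: link
AI Studio project: link
Dataset: link
This project is for reference only. It offers ideas and approaches, not a standard answer. Please do not simply copy it!
Chinese News Headline Classification: a Paddle 2.0 Baseline (Unofficial)
Unofficial, made by 三岁 (Sansui)! (Rough around the edges, but made with care.)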
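Tuning tips
This baseline will not score very high; you need to tune it yourself to improve the results.
Suggestions:
- Change the model. It is currently a simple linear model; try a more complex architecture that suits NLP tasks better (I am not sure of the specifics myself).
- Adjust the learning rate while searching for the best result.
- Continue training from an already trained model to get better results.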
Dataset URL
https://aistudio.baidu.com/aistudio/datasetdetail/75812
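Task description
This is a text classification task on the THUCNews dataset. THUCNews was built by filtering historical data from the Sina News RSS feeds from 2005 to 2011 and contains 740,000 news documents. Participants must write an algorithm that decides which category a news item belongs to based on its headline.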
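Data description
THUCNews was built by filtering historical data from the Sina News RSS feeds from 2005 to 2011. It contains 740,000 news documents (2.19 GB), all in UTF-8 plain text. Starting from the original Sina News taxonomy, the data was reorganized into 14 candidate categories: 财经 (finance), 彩票 (lottery), 房产 (real estate), 股票 (stocks), 家居 (home), 教育 (education), 科技 (technology), 社会 (society), 时尚 (fashion), 时政 (politics), 体育 (sports), 星座 (horoscope), 游戏 (games), 娱乐 (entertainment).
The training set has already been extracted in the format "label ID + \t + label + \t + original title", so the classification task can be run directly on the headlines. We hope you will come up with your own solution.
Training set format: label ID + \t + label + \t + original title. Test set format: original title.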
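As a small illustration of the training-line format described above, the sketch below splits one line into its three tab-separated fields. The function name `parse_train_line` and the sample line are made up for illustration; only the "label ID, tab, label, tab, title" layout comes from the description.

```python
def parse_train_line(line):
    """Split one raw training line into (label_id, label, title)."""
    # maxsplit=2 keeps any tab that might appear inside the title intact
    label_id, label, title = line.rstrip("\n").split("\t", 2)
    return int(label_id), label, title


# Hypothetical sample line in the documented format
sample = "0\t财经\t上证50ETF净申购突增\n"
parsed = parse_train_line(sample)
print(parsed)
```

The test set only has the title on each line, so no parsing beyond stripping the newline is needed there.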
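Submitting your answer
For the exam submission, you must submit the project version that contains the model code together with the result file. The result file is a TXT file named result.txt, and its contents must follow the specified format.
1. Each predicted category line must correspond one-to-one with a line of the original test set; do not reorder them.
2. Check that the output contains exactly 83599 lines; otherwise the score is invalid.
3. Name the output file result.txt, with one category per line, for example: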
游戏
财经
时政
股票
家居
科技
社会
房产
教育
星座
科技
股票
游戏
财经
时政
股票
家居
科技
社会
房产
教育
···
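Before submitting, it can help to check the result file against the rules above. The following is a minimal sketch, not an official checker: the helper names `check_result_lines` and `check_result_file` are hypothetical, and the only facts taken from the rules are the file name, the 83599-line requirement, and the 14 category names.

```python
CATEGORIES = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技",
              "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]
EXPECTED_LINES = 83599  # from the submission rules


def check_result_lines(lines, expected=EXPECTED_LINES):
    """Return a list of problems found in the result lines (empty if none)."""
    problems = []
    if len(lines) != expected:
        problems.append(f"expected {expected} lines, found {len(lines)}")
    valid = set(CATEGORIES)
    for number, line in enumerate(lines, start=1):
        if line not in valid:
            problems.append(f"line {number}: unknown category {line!r}")
            break  # reporting the first bad line is enough
    return problems


def check_result_file(path="result.txt"):
    """Read the file as UTF-8 and run the line checks on it."""
    with open(path, encoding="utf-8") as f:
        return check_result_lines(f.read().splitlines())
```

Calling `check_result_file()` after inference returns an empty list when the line count and category names look right. It cannot tell whether the predictions are in the right order.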
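Approach
From the task description, this is a classic NLP task.
Following the usual workflow for NLP tasks, we need the following steps:
- Process the data and convert the text into numeric token IDs (the input to the word-embedding layer)
- Build the model
- Train on the data
- Load the model and run inference to get the results
Without further ado, let's begin!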
Extracting the dataset
! pip install -U paddlepaddle==2.0.1
! unzip -oq /home/aistudio/data/data75812/新闻文本标签分类.zip
import os

import matplotlib.pyplot as plt
import numpy as np
import paddle
import paddle.nn as nn

print(paddle.__version__)  # show the installed version

# CPU/GPU selection: pass the target device to paddle.set_device().
# device = paddle.set_device('gpu')
2.0.1
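Data processing
First, consider how to turn the text into word vectors:
- Build a dictionary. (It has already been built here, so we just load it; the dictionary-building code will be posted in the comments.)
- Map the dataset through the dictionary to produce a purely numeric ID encoding.
- Print a sample of the encoding to check that it is correct.
- Once the data checks out, pad it: fill with a special token ID so that every sequence has the same length.
- Verify the sequence lengths.
Loading the data (dictionary and dataset)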
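The dictionary-building step is not shown in this write-up, so the following is only a guess at what it could look like: a minimal sketch that assumes character-level tokens and an `<unk>` entry stored under the last ID. The names `build_char_dict` and `encode` and the sample titles are hypothetical.

```python
def build_char_dict(titles):
    """Map each distinct character to an integer ID; '<unk>' gets the last ID."""
    chars = sorted({ch for title in titles for ch in title})
    char_dict = {ch: idx for idx, ch in enumerate(chars)}
    char_dict["<unk>"] = len(char_dict)
    return char_dict


def encode(title, char_dict):
    """Convert a title to a list of IDs, falling back to '<unk>' for unseen characters."""
    unk = char_dict["<unk>"]
    return [char_dict.get(ch, unk) for ch in title]


titles = ["上证50ETF净申购突增", "股票上涨"]  # made-up sample titles
char_dict = build_char_dict(titles)
ids = encode("上证X", char_dict)  # "X" is not in the dictionary
print(ids)
```

Sorting the characters makes the mapping reproducible between runs, which matters because the saved ID files are only meaningful together with the exact dictionary that produced them.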
import ast

# Load the dictionary
def get_dict_len(d_path):
    with open(d_path, 'r', encoding='utf-8') as f:
        # The first line of the file holds a Python dict literal
        line = ast.literal_eval(f.readlines()[0])
    return line

word_dict = get_dict_len('新闻文本标签分类/dict.txt')
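# Load the training and validation sets
# NOTE: `samples` lives outside the function and is shared across calls, so the
# second call below returns the training rows followed by the validation rows,
# and train_dataset and val_dataset end up being the same list.
samples = []

def dataset(datapath):
    with open(datapath, encoding='utf-8') as f:
        for i in f.readlines():
            data = []
            ids = np.array(i[:i.rfind('\t')].split(','))  # token IDs (as strings)
            data.append(ids)
            label = np.array(i[i.rfind('\t') + 1:-1])  # label ID
            data.append(label)
            samples.append(data)
    return samples

train_dataset = dataset('新闻文本标签分类/Train_IDs.txt')
val_dataset = dataset('新闻文本标签分类/Val_IDs.txt')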
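The `dataset` function above appends to a list defined outside the function, so calling it twice hands back the same accumulated list both times. If you want two separate splits, one option is to create the list inside the function. This is a minimal sketch assuming the same "id,id,...<tab>label" line layout; `load_ids_file` and the tiny temporary files are hypothetical stand-ins.

```python
import os
import tempfile


def load_ids_file(path):
    """Read 'id,id,...<tab>label' lines into a new list of (ids, label) pairs."""
    samples = []  # created per call, so the splits stay separate
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue  # skip blank lines
            ids_part, label_part = line.rsplit("\t", 1)
            ids = [int(tok) for tok in ids_part.split(",")]
            samples.append((ids, int(label_part)))
    return samples


# Tiny made-up files standing in for Train_IDs.txt and Val_IDs.txt
tmp_dir = tempfile.mkdtemp()
train_path = os.path.join(tmp_dir, "train.txt")
val_path = os.path.join(tmp_dir, "val.txt")
with open(train_path, "w", encoding="utf-8") as f:
    f.write("1,2,3\t0\n4,5\t1\n")
with open(val_path, "w", encoding="utf-8") as f:
    f.write("6,7\t2\n")

train_samples = load_ids_file(train_path)
val_samples = load_ids_file(val_path)
print(len(train_samples), len(val_samples))
```

With separate splits, any later code that assumes both arrays have the same number of rows would need to change, and the validation accuracy would then be measured only on titles the model has not trained on.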
Initialization
Define the values we will need
# Initial settings
vocab_size = len(word_dict) + 1  # dictionary size plus 1
print(vocab_size)
emb_size = 256    # embedding dimension
seq_len = 30      # sequence length after padding or truncation
batch_size = 32   # batch size
epochs = 2        # number of training epochs
pad_id = word_dict['<unk>']  # ID used for padding

nu = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技", "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]

# Convert a list of IDs back into text
def ids_to_str(ids):
    # print(ids)
    vocab = list(word_dict)  # tokens in dictionary order
    words = []
    for k in ids:
        w = vocab[int(k)]
        words.append(w if isinstance(w, str) else w.decode('ASCII'))
    return " ".join(words)
5308
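`ids_to_str` looks tokens up by their position in `list(word_dict)`, which only works if the key order of the dictionary matches the IDs. A sketch of an alternative that inverts the mapping by value instead, so it does not depend on key order; `build_id_to_token`, `ids_to_text`, and the toy dictionary are hypothetical names for illustration.

```python
def build_id_to_token(token_to_id):
    """Invert a token->ID mapping once so that lookups by ID are O(1)."""
    return {idx: tok for tok, idx in token_to_id.items()}


def ids_to_text(ids, id_to_token, unk="<unk>"):
    """Join the tokens for a sequence of IDs (given as ints or numeric strings)."""
    return " ".join(id_to_token.get(int(i), unk) for i in ids)


toy_dict = {"上": 0, "证": 1, "<unk>": 2}  # made-up stand-in for word_dict
id_to_token = build_id_to_token(toy_dict)
text = ids_to_text(["0", "1", "2"], id_to_token)
print(text)
```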
Inspecting the data
Check that the data looks right and fix any problems right away.
# Look at one sample
for i in train_dataset:
    sent = i[0]
    label = int(i[1])
    print('sentence list id is:', sent)    # token IDs
    print('sentence label id is:', label)  # label ID
    print('--------------------------')    # separator
    print('sentence list is: ', ids_to_str(sent))  # decoded text
    print('sentence label is: ', nu[label])        # decoded label
    break
sentence list id is: ['2976' '385' '2050' '3757' '1147' '3296' '1585' '688' '1180' '2608' '4280' '1887']
sentence label id is: 0
--------------------------
sentence list is:  上 证 5 0 E T F 净 申 购 突 增
sentence label is:  财经
Padding
Pad (or truncate) every sequence to the same length
# Pad the data and inspect it
def create_padded_dataset(dataset):
    padded_sents = []
    labels = []
    for batch_id, data in enumerate(dataset):
        sent, label = data[0], data[1]  # split token IDs and label
        # Truncate to seq_len, then right-pad with pad_id
        padded_sent = np.concatenate([sent[:seq_len], [pad_id] * (seq_len - len(sent))]).astype('int32')
        padded_sents.append(padded_sent)
        labels.append(label)
    return np.array(padded_sents), np.array(labels).astype('int64')

# Build the padded training and validation arrays
train_sents, train_labels = create_padded_dataset(train_dataset)
val_sents, val_labels = create_padded_dataset(val_dataset)

# Reshape the labels to (N, 1); -1 lets NumPy infer the row count
train_labels = train_labels.reshape(-1, 1)
val_labels = val_labels.reshape(-1, 1)

# Check the shapes
print(train_sents.shape)
print(train_labels.shape)
print(val_sents.shape)
print(val_labels.shape)
(832475, 30)
(832475, 1)
(832475, 30)
(832475, 1)
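The padding rule itself is easy to see without NumPy. The following sketch shows the same truncate-then-right-pad behaviour on plain lists; `pad_or_truncate` is a hypothetical helper name and the inputs are made up.

```python
def pad_or_truncate(ids, seq_len, pad_id):
    """Return exactly seq_len IDs: cut long inputs, right-pad short ones."""
    clipped = list(ids[:seq_len])
    return clipped + [pad_id] * (seq_len - len(clipped))


short = pad_or_truncate([7, 8, 9], seq_len=5, pad_id=0)
long_ = pad_or_truncate([1, 2, 3, 4, 5, 6], seq_len=5, pad_id=0)
print(short)  # [7, 8, 9, 0, 0]
print(long_)  # [1, 2, 3, 4, 5]
```

In the project the pad value is the `<unk>` ID rather than 0, so padding and unknown characters share one embedding row.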
Wrapping the data
Subclass paddle.io.Dataset to wrap the data, then build loaders that yield batches ready for training
# Subclass paddle.io.Dataset to wrap the arrays
class IMDBDataset(paddle.io.Dataset):
    '''
    Wraps the padded sentences and labels as a paddle.io.Dataset
    '''
    def __init__(self, sents, labels):
        # Store the arrays
        self.sents = sents
        self.labels = labels

    def __getitem__(self, index):
        # Return one (sentence, label) pair
        data = self.sents[index]
        label = self.labels[index]
        return data, label

    def __len__(self):
        # Number of samples
        return len(self.sents)

# Instantiate the datasets
train_dataset = IMDBDataset(train_sents, train_labels)
val_dataset = IMDBDataset(val_sents, val_labels)

# Wrap them in data loaders
train_loader = paddle.io.DataLoader(train_dataset, return_list=True, shuffle=True, batch_size=batch_size, drop_last=True)
val_loader = paddle.io.DataLoader(val_dataset, return_list=True, shuffle=True, batch_size=batch_size, drop_last=True)
# Look at the contents and shapes of one batch from each loader
for i in train_loader:
    print(i)
    break
for j in val_loader:
    print(j)
    break
[Output omitted: each loader yields a pair of tensors. The first has shape [32, 30] and dtype int32 and holds the token IDs, with short titles right-padded with 5306. The second has shape [32, 1] and dtype int64 and holds the label IDs. Both are on CPUPlace with stop_gradient=True.]
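To make the `shuffle`, `batch_size`, and `drop_last` arguments concrete, here is a small pure-Python sketch of what they mean. It is not how `paddle.io.DataLoader` is implemented; `iter_batches` is a hypothetical helper and the toy data is made up.

```python
import random


def iter_batches(samples, batch_size, shuffle=True, drop_last=True, seed=None):
    """Yield lists of samples, mimicking shuffle / batch_size / drop_last."""
    order = list(range(len(samples)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        chunk = order[start:start + batch_size]
        if drop_last and len(chunk) < batch_size:
            break  # discard the incomplete final batch
        yield [samples[i] for i in chunk]


toy = list(range(10))
batches = list(iter_batches(toy, batch_size=4, seed=0))
print(len(batches))  # 10 samples, batch size 4, last 2 dropped
```

With `drop_last=True` every batch has the same size, at the cost of skipping a few samples each epoch. For a validation loader you may prefer `shuffle=False` and `drop_last=False` so that every sample is scored exactly once.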
Network definition
Define the network used for training. This part is one of the keys to a better score.
# Define the network
class MyNet(paddle.nn.Layer):
    def __init__(self):
        super(MyNet, self).__init__()
        self.emb = paddle.nn.Embedding(vocab_size, emb_size)  # embedding table of shape (vocab_size, emb_size)
        self.fc = paddle.nn.Linear(in_features=emb_size, out_features=96)  # hidden linear layer
        self.fc1 = paddle.nn.Linear(in_features=96, out_features=14)  # classifier over the 14 categories
        self.dropout = paddle.nn.Dropout(0.5)  # dropout for regularization

    def forward(self, x):
        x = self.emb(x)
        x = paddle.mean(x, axis=1)  # average the embeddings over the sequence
        x = self.dropout(x)
        x = self.fc(x)
        x = self.dropout(x)
        x = self.fc1(x)
        return x
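To get a feel for the size of this network, the parameter count can be worked out by hand. The sketch below uses the values printed or set earlier (vocabulary size 5308, embedding size 256) and assumes both linear layers carry a bias term, which is the default for `paddle.nn.Linear`.

```python
vocab_size, emb_size = 5308, 256  # values printed / set earlier
hidden, num_classes = 96, 14

emb_params = vocab_size * emb_size               # embedding table
fc_params = emb_size * hidden + hidden           # weights + bias
fc1_params = hidden * num_classes + num_classes  # weights + bias
total_params = emb_params + fc_params + fc1_params
print(emb_params, fc_params, fc1_params, total_params)
```

Almost all of the parameters sit in the embedding table, so widening the linear layers changes the model size very little.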
# Plotting helper
def draw_process(title, color, iters, data, label):
    plt.title(title, fontsize=24)   # title
    plt.xlabel("iter", fontsize=20) # x axis
    plt.ylabel(label, fontsize=20)  # y axis
    plt.plot(iters, data, color=color, label=label)  # draw the curve
    plt.legend()
    plt.grid()
    plt.show()
Model training
This is the core training step. You can tune the learning rate, the optimizer, and so on, which can sometimes make a surprising difference.
# Train the model
def train(model):
    model.train()
    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())  # optimizer and learning rate
    # Bookkeeping for the plots
    steps = 0
    Iters, total_loss, total_acc = [], [], []

    for epoch in range(epochs):  # epoch loop
        for batch_id, data in enumerate(train_loader):  # batch loop
            steps += 1
            sent = data[0]   # token IDs
            label = data[1]  # labels
            logits = model(sent)  # forward pass
            loss = paddle.nn.functional.cross_entropy(logits, label)  # loss
            acc = paddle.metric.accuracy(logits, label)  # accuracy

            if batch_id % 500 == 0:  # log every 500 batches
                Iters.append(steps)                 # global step count
                total_loss.append(loss.numpy()[0])  # record the loss
                total_acc.append(acc.numpy()[0])    # record the accuracy
                print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))

            # Backward pass and parameter update
            loss.backward()
            opt.step()
            opt.clear_grad()

        # Evaluate once per epoch
        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(val_loader):
            sent = data[0]   # token IDs
            label = data[1]  # labels
            logits = model(sent)  # forward pass
            loss = paddle.nn.functional.cross_entropy(logits, label)
            acc = paddle.metric.accuracy(logits, label)
            accuracies.append(acc.numpy())
            losses.append(loss.numpy())

        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)  # average over the validation batches
        print("[validation] accuracy: {}, loss: {}".format(avg_acc, avg_loss))
        model.train()

        paddle.save(model.state_dict(), str(epoch) + "_model_final.pdparams")  # save a checkpoint for this epoch

    draw_process("training loss", "red", Iters, total_loss, "training loss")  # plot the loss curve
    draw_process("training acc", "green", Iters, total_acc, "training acc")   # plot the accuracy curve

model = MyNet()  # instantiate the model
train(model)     # start training
epoch: 0, batch_id: 0, loss is: [2.6477456]
epoch: 0, batch_id: 500, loss is: [1.8056118]
epoch: 0, batch_id: 1000, loss is: [1.1092072]
epoch: 0, batch_id: 1500, loss is: [1.0716103]
epoch: 0, batch_id: 2000, loss is: [0.6794955]
epoch: 0, batch_id: 2500, loss is: [0.54738545]
epoch: 0, batch_id: 3000, loss is: [0.9065808]
epoch: 0, batch_id: 3500, loss is: [0.63474274]
epoch: 0, batch_id: 4000, loss is: [0.68158776]
epoch: 0, batch_id: 4500, loss is: [1.0516238]
epoch: 0, batch_id: 5000, loss is: [0.9118046]
epoch: 0, batch_id: 5500, loss is: [0.65075576]
epoch: 0, batch_id: 6000, loss is: [0.5605841]
epoch: 0, batch_id: 6500, loss is: [0.56175774]
epoch: 0, batch_id: 7000, loss is: [0.95122683]
epoch: 0, batch_id: 7500, loss is: [0.38649452]
epoch: 0, batch_id: 8000, loss is: [0.2205698]
epoch: 0, batch_id: 8500, loss is: [0.40474647]
epoch: 0, batch_id: 9000, loss is: [0.5931748]
epoch: 0, batch_id: 9500, loss is: [0.3922717]
epoch: 0, batch_id: 10000, loss is: [0.6130478]
epoch: 0, batch_id: 10500, loss is: [0.5300909]
epoch: 0, batch_id: 11000, loss is: [0.6114788]
epoch: 0, batch_id: 11500, loss is: [0.24966809]
epoch: 0, batch_id: 12000, loss is: [0.45669073]
epoch: 0, batch_id: 12500, loss is: [0.29746443]
epoch: 0, batch_id: 13000, loss is: [0.6775298]
epoch: 0, batch_id: 13500, loss is: [0.8836371]
epoch: 0, batch_id: 14000, loss is: [0.27501673]
epoch: 0, batch_id: 14500, loss is: [0.46843478]
epoch: 0, batch_id: 15000, loss is: [0.49367175]
epoch: 0, batch_id: 15500, loss is: [0.500063]
epoch: 0, batch_id: 16000, loss is: [0.31290954]
epoch: 0, batch_id: 16500, loss is: [0.30774388]
epoch: 0, batch_id: 17000, loss is: [0.21738727]
epoch: 0, batch_id: 17500, loss is: [0.2860858]
epoch: 0, batch_id: 18000, loss is: [0.2766972]
epoch: 0, batch_id: 18500, loss is: [0.36017033]
epoch: 0, batch_id: 19000, loss is: [0.43986273]
epoch: 0, batch_id: 19500, loss is: [0.4210134]
epoch: 0, batch_id: 20000, loss is: [0.579644]
epoch: 0, batch_id: 20500, loss is: [0.23016676]
epoch: 0, batch_id: 21000, loss is: [0.21913218]
epoch: 0, batch_id: 21500, loss is: [0.18669227]
epoch: 0, batch_id: 22000, loss is: [0.31480896]
epoch: 0, batch_id: 22500, loss is: [0.37621552]
epoch: 0, batch_id: 23000, loss is: [0.54980826]
epoch: 0, batch_id: 23500, loss is: [0.6016808]
epoch: 0, batch_id: 24000, loss is: [0.25056183]
epoch: 0, batch_id: 24500, loss is: [0.2916811]
epoch: 0, batch_id: 25000, loss is: [0.33430776]
epoch: 0, batch_id: 25500, loss is: [0.74600095]
epoch: 0, batch_id: 26000, loss is: [0.35165167]
[validation] accuracy: 0.884321928024292, loss: 0.3713749647140503
epoch: 1, batch_id: 0, loss is: [0.47405708]
epoch: 1, batch_id: 500, loss is: [0.4443894]
epoch: 1, batch_id: 1000, loss is: [0.35416052]
epoch: 1, batch_id: 1500, loss is: [0.3004715]
epoch: 1, batch_id: 2000, loss is: [0.59477925]
epoch: 1, batch_id: 2500, loss is: [0.5639044]
epoch: 1, batch_id: 3000, loss is: [0.40286714]
epoch: 1, batch_id: 3500, loss is: [0.5387965]
epoch: 1, batch_id: 4000, loss is: [0.11766122]
epoch: 1, batch_id: 4500, loss is: [0.68849707]
epoch: 1, batch_id: 5000, loss is: [0.83928466]
epoch: 1, batch_id: 5500, loss is: [0.2867105]
epoch: 1, batch_id: 6000, loss is: [0.20924558]
epoch: 1, batch_id: 6500, loss is: [0.5582311]
epoch: 1, batch_id: 7000, loss is: [0.63174886]
epoch: 1, batch_id: 7500, loss is: [0.318484]
epoch: 1, batch_id: 8000, loss is: [0.5406461]
epoch: 1, batch_id: 8500, loss is: [0.4790561]
epoch: 1, batch_id: 9000, loss is: [0.52266514]
epoch: 1, batch_id: 9500, loss is: [0.51126254]
epoch: 1, batch_id: 10000, loss is: [0.27308795]
epoch: 1, batch_id: 10500, loss is: [0.22041513]
epoch: 1, batch_id: 11000, loss is: [0.32234907]
epoch: 1, batch_id: 11500, loss is: [0.6857507]
epoch: 1, batch_id: 12000, loss is: [0.40997463]
epoch: 1, batch_id: 12500, loss is: [0.53966033]
epoch: 1, batch_id: 13000, loss is: [0.2620927]
epoch: 1, batch_id: 13500, loss is: [0.21417136]
epoch: 1, batch_id: 14000, loss is: [0.5232475]
epoch: 1, batch_id: 14500, loss is: [0.37579858]
epoch: 1, batch_id: 15000, loss is: [0.3611152]
epoch: 1, batch_id: 15500, loss is: [0.336707]
epoch: 1, batch_id: 16000, loss is: [0.2795578]
epoch: 1, batch_id: 16500, loss is: [0.54298353]
epoch: 1, batch_id: 17000, loss is: [0.26425135]
epoch: 1, batch_id: 17500, loss is: [0.52595145]
epoch: 1, batch_id: 18000, loss is: [0.24938256]
epoch: 1, batch_id: 18500, loss is: [0.30653632]
epoch: 1, batch_id: 19000, loss is: [0.58400965]
epoch: 1, batch_id: 19500, loss is: [0.18243803]
epoch: 1, batch_id: 20000, loss is: [0.28917578]
epoch: 1, batch_id: 20500, loss is: [1.0765818]
epoch: 1, batch_id: 21000, loss is: [0.32550114]
epoch: 1, batch_id: 21500, loss is: [0.16792971]
epoch: 1, batch_id: 22000, loss is: [0.65214527]
epoch: 1, batch_id: 22500, loss is: [0.58119446]
epoch: 1, batch_id: 23000, loss is: [0.43643892]
epoch: 1, batch_id: 23500, loss is: [0.47376677]
epoch: 1, batch_id: 24000, loss is: [0.3279624]
epoch: 1, batch_id: 24500, loss is: [0.50899947]
epoch: 1, batch_id: 25000, loss is: [0.61989105]
epoch: 1, batch_id: 25500, loss is: [0.42433214]
epoch: 1, batch_id: 26000, loss is: [0.26673254]
[validation] accuracy: 0.8882260322570801, loss: 0.35311153531074524
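The batch IDs in the log follow directly from the array size and the loader settings. The arithmetic below uses the row count printed earlier (832475), the batch size of 32, and the logging interval of 500; it only restates numbers already shown in this write-up.

```python
num_samples = 832475  # rows reported for the training array
batch_size = 32
log_every = 500

steps_per_epoch = num_samples // batch_size  # drop_last=True discards the remainder
dropped = num_samples - steps_per_epoch * batch_size
logged_ids = list(range(0, steps_per_epoch, log_every))
print(steps_per_epoch, dropped, logged_ids[-1], len(logged_ids))
```

This gives 26014 steps per epoch with 27 samples left over, and a last logged batch ID of 26000, which matches the final training line of each epoch in the log.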
Loading the inference data
# Load the competition (test) data
test_samples = []

def dataset(datapath):
    with open(datapath, encoding='utf-8') as f:  # open the file
        for i in f.readlines():                  # read line by line
            ids = np.array(i.strip().split(','))  # split into token IDs
            test_samples.append(ids)
    return test_samples

# Pad the competition data
def create_padded_dataset(dataset):
    padded_sents = []
    for batch_id, data in enumerate(dataset):
        sent = data
        # Truncate to seq_len, then right-pad with pad_id
        padded_sent = np.concatenate([sent[:seq_len], [pad_id] * (seq_len - len(sent))]).astype('int32')
        padded_sents.append(padded_sent)
    return np.array(padded_sents)  # convert to an array and return

test_data = dataset('新闻文本标签分类/Test_IDs.txt')  # load the data
test_data = create_padded_dataset(test_data)         # pad the data

# Print the padded test array
print(test_data)
[[4057 1902 1475 ... 5306 5306 5306]
 [2805 5242 3593 ... 5306 5306 5306]
 [1836 3222 4641 ... 5306 5306 5306]
 ...
 [4838 1202 1490 ... 5306 5306 5306]
 [ 805 3757 3757 ... 5306 5306 5306]
 [2805 5242 3593 ... 5306 5306 5306]]
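Running inference
Here you can pick whichever saved model performed best and use it for prediction.
nu = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技", "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]  # label list

# Load the model
model_state_dict = paddle.load('0_model_final.pdparams')  # checkpoint saved after epoch 0
model = MyNet()  # build the network
model.set_state_dict(model_state_dict)
model.eval()

count = 0  # number of titles processed
with open('./result.txt', 'w', encoding='utf-8') as f_train:  # create the result file
    for batch_id, data in enumerate(test_data):
        # Shape (1, seq_len): one title per forward pass. Reshaping to (30, 1)
        # would treat each character as its own one-token sample.
        results = model(paddle.to_tensor(data.reshape(1, seq_len)))  # run inference
        idx = int(np.argmax(results.numpy()[0]))  # index of the highest score
        labels = nu[idx]                          # map the index to its label
        f_train.write(labels + "\n")              # write one category per line
        count += 1
        if count % 500 == 0:  # progress report
            print(count)
print(count)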
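Turning one row of model outputs into a category name is just an argmax followed by a list lookup. The sketch below shows that step on plain Python lists; `logits_to_label` is a hypothetical helper and the scores are made up.

```python
LABELS = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技",
          "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]


def logits_to_label(logits, labels=LABELS):
    """Return the label whose score is highest."""
    if len(logits) != len(labels):
        raise ValueError("need exactly one score per label")
    best = max(range(len(logits)), key=lambda i: logits[i])
    return labels[best]


scores = [0.1] * 14
scores[6] = 2.5  # made-up logits with the highest value at index 6
predicted = logits_to_label(scores)
print(predicted)
```

The order of `LABELS` has to match the label IDs used in training, since the model only ever sees the integer IDs.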
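The results may not be great, but the pipeline runs end to end. If you need anything else, get in touch (leave a comment or @ me in the group) and I will keep improving the project.
Thank you all for your support!
From the self-proclaimed worst coder in the PaddlePaddle community: let's keep working at it together!
Remember: if it's made by 三岁 (Sansui), it must be quality (said with no shame at all).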