自然语言处理实战第二版(MEAP)(四)(1)https://developer.aliyun.com/article/1517990
7.3 莫尔斯电码
在 ASCII 文本和计算机甚至电话出现之前,还有另一种交流自然语言的方式:莫尔斯电码。莫尔斯电码是一种将点和短划替代自然语言字母和单词的文本编码。这些点和短划在电报线上或无线电波上变成长音和短音的蜂鸣声。莫尔斯电码听起来就像一个非常缓慢的拨号上网连接中的蜂鸣声。在本节后面的 Python 示例中播放音频文件,亲自听一下吧。业余无线电操作员通过敲击单个键向世界各地发送消息。你能想象在计算机键盘上输入文本,而键盘上只有一个键,就像图 7.7 中的 Framework 笔记本的空格键一样吗?
图 7.7 单个关键的笔记本键盘
图 7.8 显示了一个实际的莫尔斯电码键的样子。就像计算机键盘上的键或游戏控制器上的开火按钮一样,莫尔斯电码键只在按下按钮时关闭电气接触。
图 7.8 一把古董莫尔斯电码键
莫尔斯电码是一种设计成只需按下一个键敲出的语言,就像这样。它在电报时代被广泛使用,在电话使得通过电线发送语音和数据成为可能之前。为了在纸上可视化莫尔斯电码,人们用点和线来代表按键的短敲和长敲。按下键时,你短暂地向外发出一个点,而稍微按住键则会发出一个破折号。当你根本不按下该键时则是完全的沉默。所以它和输入文本不太一样。更像是把你的键盘当作游戏手柄上的开火按钮。你可以把莫尔斯电码键想象成视频游戏激光或以按下键的时候才发送能量的任何东西。你甚至可以通过在多人游戏中将武器当作电报机来发送秘密信息。
要在计算机键盘上仅使用一个键进行通信几乎是不可能的,如果不是萨缪尔·莫尔斯创造新的自然语言的工作,就不会有这种可能。莫尔斯在设计莫尔斯电码的语言方面做得非常好,即使像我这样拙笨的业余无线电操作员也可以在紧急情况下使用它。接下来,你将学习这种语言中最重要的两个字母,以便在紧急情况下也能使用它。不用担心,你只需要学习这个语言的两个字母就足够了。这应该足以让你更清楚地理解卷积以及它在自然语言上的工作原理。
图 7.9 莫尔斯电码字典
莫尔斯电码至今仍然在无线电波嘈杂的情况下使用,以便使别人能够理解你的语音。当你真的,真的,真的需要传达信息时,它尤其有用。被困在沉没的潜艇或船内的水下空腔的水手使用莫尔斯电码在金属船体上敲出来与营救者进行交流。在地震或矿井事故后被埋在瓦砾下的人们会用金属管道和钢梁敲击来与救援人员进行通信。如果你懂一点莫尔斯电码,你也许可以通过用莫尔斯电码敲出你的话与别人进行双向对话。
这是一个以莫尔斯电码进行广播的秘密消息的音频数据示例。在接下来的部分中,你将使用手工制作的卷积核处理这个数据。现在,你可能只想播放音频轨道,以便听到莫尔斯电码的声音是什么样子。
代码清单 7.10 下载秘密
>>> from nlpia2.init import maybe_download >>> url = 'https://upload.wikimedia.org/wikipedia/' \ 'commons/7/78/1210secretmorzecode.wav' >>> filepath = maybe_download(url) # #1 >>> filepath '/home/hobs/.nlpia2-data/1210secretmorzecode.wav'
当然,你的.nlpia2-data
目录将位于你的$HOME
目录下,而不是我的。这里是这些示例中使用的所有数据。现在,你可以加载 wav 文件,以创建一个包含音频信号的数值数组,稍后可以用卷积进行处理。
7.3.1 使用卷积解码莫尔斯电码
如果您了解一点 Python,您可以构建一个能够为您解释摩尔斯电码的机器,这样您就不必记住图 7.9 摩尔斯电码字典中所有的点和划线了。在僵尸启示录或“大事件”(加州地震)期间可能会派上用场。只需确保保留能够运行 Python 的计算机或手机。
第 7.11 列 加载秘密摩尔斯电码 WAV 文件
>>> from scipy.io import wavfile >>> sample_rate, audio = wavfile.read(filepath) >>> print(f'sample_rate: {sample_rate}') >>> print(f'audio:\n{audio}')
sample_rate: 4000 audio: [255 0 255 ... 0 255 0]
这个 WAV 文件中的音频信号在哔哔声时在 255 和 0 之间振荡(最大和最小的 uint8
值)。因此,您需要使用 abs()
对信号进行矫正,然后将其标准化,使信号在播放音调时为 1,在没有音调时为 0。您还希望将采样数转换为毫秒,并对信号进行降采样,以便更容易地检查单个值并查看发生了什么。第 7.12 列 居中、标准化和降采样音频数据,并提取此音频数据的前两秒。
第 7.12 列 标准化和降采样音频信号
>>> pd.options.display.max_rows = 7 >>> audio = audio[:sample_rate * 2] # #1 >>> audio = np.abs(audio - audio.max() / 2) - .5 # #2 >>> audio = audio / audio.max() # #3 >>> audio = audio[::sample_rate // 400] # #4 >>> audio = pd.Series(audio, name='audio') >>> audio.index = 1000 * audio.index / sample_rate # #5 >>> audio.index.name = 'time (ms)' >>> print(f'audio:\n{audio}')
现在,您可以使用 audio.plot()
绘制闪亮的新摩尔斯电码点和划线。
第 7.10 图 方波摩尔斯电码秘密消息
您能在图 7.10 中看到点在哪里吗?点是 60 毫秒的静音(信号值为 0),然后是 60 毫秒的音调(信号值为 1),然后再次是 60 秒的静音(信号值为 0)。
要通过卷积检测点,您需要设计一个与低、高、低的模式匹配的核心。唯一的区别是对于低信号,您需要使用负一而不是零,这样数学就会加起来。您希望卷积的输出在检测到点符号时为 1。
第 7.12 列 展示了如何构建点检测核心。
第 7.13 列 点检测核
>>> kernel = [-1] * 24 + [1] * 24 + [-1] * 24 # #1 >>> kernel = pd.Series(kernel, index=2.5 * np.arange(len(kernel))) >>> kernel.index.name = 'Time (ms)' >>> ax = kernel.plot(linewidth=3, ylabel='Kernel weight')
第 7.11 图 摩尔斯电码点检测核心
您可以通过将其与音频信号进行卷积来尝试您手工制作的核心,以查看它是否能够检测到点。目标是使卷积信号在点符号出现时高、接近于 1,在音频中的短脉冲。您还希望您的点检测卷积在点之前或之后的任何短划线或静音处返回低值(接近于零)。
第 7.14 列 点检测器与秘密消息卷积
>>> kernel = np.array(kernel) / sum(np.abs(kernel)) # #1 >>> pad = [0] * (len(kernel) // 2) # #2 >>> isdot = convolve(audio.values, kernel) >>> isdot = np.array(pad[:-1] + list(isdot) + pad) # #3 >>> df = pd.DataFrame() >>> df['audio'] = audio >>> df['isdot'] = isdot - isdot.min() >>> ax = df.plot()
第 7.12 图 手工制作的点检测卷积
看起来手工制作的核心做得不错!卷积输出仅在点符号的中间接近于 1。
现在您了解了卷积的工作原理,可以随意使用 np.convolve()
函数。它运行更快,并为您提供了更多关于填充处理的 mode
选项。
第 7.15 列 NumPy 卷积
>>> isdot = np.convolve(audio.values, kernel, mode='same') # #1 >>> df['isdot'] = isdot - isdot.min() >>> ax = df.plot()
第 7.13 图 NumPy 卷积
Numpy 卷积有三种可能的模式可用于进行卷积,按输出长度递增的顺序依次为:
- valid: 以纯 Python 为例,只输出
len(kernel)-1
个卷积值。 - same: 通过在数组的开始和结尾之外推算信号,输出与输入长度相同的信号。
- full: 输出信号将比输入信号更长。
Numpy 卷积设置为“same”模式似乎在我们的莫尔斯电码音频信号中运作得更好。因此,当在神经网络中进行卷积时,你需要检查你的神经网络库是否使用类似的模式。
建造一个卷积滤波器以在莫尔斯电码音频文件中检测一个单一符号真是一项艰苦的工程。而且这还不是一个自然语言文本的单个字符, 只是 S
字母的三分之一!幸运的是,你辛勤手工制作的日子已经结束了。你可以在神经网络的反向传播中使用它所拥有的强大力量来学习正确的内核以检测解决问题所需的所有不同信号。
7.4 使用 PyTorch 构建 CNN
图 7.14 展示了您如何将文本流入 CNN 网络,然后输出嵌入。与以前的 NLP 流水线一样,需要首先对文本进行标记化。然后您会识别出文本中使用的所有令牌集。您将忽略不想计数的令牌,并为词汇表中的每个单词分配一个整数索引。输入语句有 4 个令牌,因此我们从一个由 4 个整数索引组成的序列开始,每个令牌对应一个索引。
CNN 通常使用单词嵌入来代替单热编码来表示每个单词。您将初始化一个单词嵌入矩阵,该矩阵的行数与词汇表中的单词数量相同,并且如果要使用 300-D 嵌入,则有 300 个列。可以将所有初始单词嵌入设置为零或某些小的随机值。如果要进行知识转移并使用预训练的单词嵌入,则可以在 GloVE、Word2vec、fastText 或任何喜欢的单词嵌入中查找您的令牌。并将这些向量插入到与词汇表索引匹配的行中的嵌入矩阵中。
对于这个四令牌句子,然后可以查找适当的单词嵌入,一旦在单词嵌入矩阵中查找每个嵌入,就会得到一个 4 个嵌入向量的序列。你也会得到额外的填充标记嵌入,它们通常被设置为零,所以它们不会干扰卷积。如果您使用最小的 GloVe 嵌入,那么您的单词嵌入是 50 维的,因此您会得到一个 50 x 4 的数值矩阵,用于这个短句子。
你的卷积层可以使用 1-D 卷积内核处理这 50 个维度中的每一个,稍微挤压一下关于你的句子的这个矩阵的信息。如果你使用了长度为 2 的内核和步幅为 2,你将得到一个大小为 50 x 3 的矩阵来表示四个 50-D 单词向量的序列。
通常使用池化层,通常是最大池化,来进一步减小输出的大小。带有 1-D 内核的最大池化层将把你的三个 50-D 向量的序列压缩成一个单一的 50-D 向量。顾名思义,最大池化将为向量序列中每个通道(维度)的最大和最有影响的输出。最大池化通常相当有效,因为它允许你的卷积为原始文本中每个 n-gram 找到最重要的意义维度。通过多个内核,它们可以分别专门化文本的不同方面,这些方面会影响你的目标变量。
注意
你应该将卷积层的输出称为“编码”,而不是“嵌入”。这两个词都用来描述高维向量,但是“编码”一词暗示着在时间上或序列中的处理。卷积数学在你的单词向量序列中的时间内发生,而“嵌入”向量是单个不变令牌的处理结果。嵌入不编码任何有关单词顺序或序列的信息。编码是对文本含义的更完整的表示,因为它们考虑了单词顺序,就像你的大脑一样。
由 CNN 层输出的编码向量是一个具有你指定的任意大小(长度)的向量。你的编码向量的长度(维度数)与输入文本的长度无关。
图 7.14 CNN 处理层 ^([22])
你将需要利用前几章的所有技能来整理文本,以便将其输入到你的神经网络中。图 7.14 中你的管道的前几个阶段是你在前几章中做的标记和大小写转换。你将利用前面示例中的经验来决定忽略哪些单词,比如停用词、标点符号、专有名词或非常罕见的单词。
根据你手工制作的任意停用词列表过滤和忽略单词通常是一个不好的主意,特别是对于像 CNN 这样的神经网络。词形还原和词干提取通常也不是一个好主意。模型将比你用直觉猜测的更了解你的令牌的统计信息。你在 Kaggle、DataCamp 和其他数据科学网站上看到的大多数示例都会鼓励你手工制作管道的这些部分。你现在知道得更清楚了。
你也不会手工制作卷积内核。你会让反向传播的魔力来处理这些事情。神经网络可以学习模型的大部分参数,例如哪些词要忽略,哪些词应该被合并在一起,因为它们具有相似的含义。实际上,在第六章中,您已经学会了用嵌入向量来表示单词的含义,这些嵌入向量精确地捕捉了它们与其他单词的相似程度。只要有足够的数据来创建这些嵌入向量,您就不再需要处理词形还原和词干提取。
7.4.1 裁剪和填充
CNN 模型需要一致长度的输入文本,以便编码中的所有输出值在向量中处于一致的位置。这确保了你的 CNN 输出的编码向量始终具有相同的维度,无论你的文本是多长或多短。你的目标是创建一个字符串和一个整页文本的向量表示。不幸的是,CNN 不能处理可变长度的文本,所以如果你的文本对于 CNN 来说太长,就会将许多单词和字符在字符串末尾进行 “裁剪”。而且你需要插入填充令牌,称为 padding,来填补那些对于您的 CNN 来说太短的字符串中的空白部分。
请记住,卷积操作始终会减少输入序列的长度,无论其长度多长。卷积操作始终会将输入序列的长度减少一个比内核大小少的数。而任何池化操作,如最大池化,也会一致地减少输入序列的长度。因此,如果您没有进行任何填充或裁剪,长句子会产生比短句子更长的编码向量。而这对于需要具有大小不变性的编码是不起作用的。无论输入的大小如何,你希望你的编码向量始终具有相同的长度。
这是向量的基本属性,即它们在整个你正在处理的向量空间中具有相同数量的维度。你希望你的 NLP 流水线能够在相同的位置或向量维度上找到特定的含义,无论这种情感在文本的哪个位置发生。填充和裁剪可以确保你的 CNN 在位置(时间)和大小(持续时间)上是不变的。基本上,只要这些模式在您的 CNN 可处理的最大长度范围内的任何位置,您的 CNN 就可以在文本的含义中找到这些模式,无论这些模式在文本中的位置如何。
你可以选择任何你喜欢的符号来表示填充标记。许多人使用标记 “”,因为它在任何自然语言字典中都不存在。大多数说英语的自然语言处理工程师都能猜到 “” 的含义。而且你的自然语言处理管道会注意到这些标记在许多字符串的末尾重复出现。这将帮助它在嵌入层中创建适当的 “填充” 情感。如果你对填充情感的样子感到好奇,加载你的嵌入向量,比较 “” 的嵌入和 “blah”(如 “blah blah blah”)的嵌入。你只需要确保使用一致的标记,并告诉你的嵌入层你用于填充标记的令牌是什么。通常将其作为你的 id2token
或 vocab
序列中的第一个标记,以便它具有索引和 id 值 0
。
一旦你告诉大家你的填充标记是什么,你现在需要决定一个一致的填充方法。就像在计算机视觉中一样,你可以在你的令牌序列的任意一侧填充,即开头或结尾。你甚至可以拆分填充,将一半放在开头,另一半放在结尾。只是不要把它们插在单词之间。那会干扰卷积计算。并确保你添加的填充标记的总数能够创建正确长度的序列用于你的 CNN。
在清单 7.16 中,您将加载由 Kaggle 贡献者标记了其新闻价值的 “birdsite”(微博)帖子。稍后您将使用您的 CNN 模型来预测 CNN(有线电视新闻网)是否会在 “miasma.” 中的新闻在自己传播之前 “采取”。
重要提示
我们有意使用能引导您朝着亲社会、真实、注意力集中的行为的词语。弥漫在互联网上的黑暗模式已经引导了科技界的创意中坚力量创建了一个替代的、更真实的宇宙,拥有它自己的词汇。
“Birdsite”:“fedies” 称之为 Twitter
“Fedies”:使用保护您健康和隐私的联合社交媒体应用的用户
“Fediverse” 联合社交媒体应用的替代宇宙(Mastodon,PeerTube)
“Nitter” 是 Twitter 的一个不那么操纵的前端。
“Miasma” 是尼尔·斯蒂芬森对一个爱情的互联网的称呼
清单 7.16 加载新闻帖子
>>> df = pd.read_csv(HOME_DATA_DIR / 'news.csv') >>> df = df[['text', 'target']] # #1 >>> print(df)
text target 0 Our Deeds are the Reason of this #earthquake M... 1 1 Forest fire near La Ronge Sask. Canada 1 2 All residents asked to 'shelter in place' are ... 1 ... ... ... 7610 M1.94 [01:04 UTC]?5km S of Volcano Hawaii. htt... 1 7611 Police investigating after an e-bike collided ... 1 7612 The Latest: More Homes Razed by Northern Calif... 1 [7613 rows x 2 columns]
您可以在上面的例子中看到,一些微博帖子几乎达到了 birdsite 的字符限制。其他则通过较少的词语表达了观点。因此,您需要对这些较短的文本进行填充,以便数据集中的所有示例具有相同数量的令牌。如果您计划在管道的后期过滤掉非常频繁的词或非常罕见的词,您的填充函数也需要填补这些差距。因此,清单 7.17 对这些文本进行了标记化,并过滤掉了其中的一些最常见的标记。
清单 7.17 词汇表中最常见的单词
import re from collections import Counter from itertools import chain HOME_DATA_DIR = Path.home() / '.nlpia2-data' counts = Counter(chain(*[ re.findall(r'\w+', t.lower()) for t in df['text']])) # #1 vocab = [tok for tok, count in counts.most_common(4000)[3:]] # #2 print(counts.most_common(10))
[('t', 5199), ('co', 4740), ('http', 4309), ('the', 3277), ('a', 2200), ('in', 1986)]
你可以看到,令牌 “t” 出现的次数几乎和帖子数(7613)一样多(5199)。这看起来像是由 url 缩短器创建的部分 url,通常用于跟踪这个应用程序上的微博主。如果你希望你的 CNN 专注于人类可能会阅读的内容中的单词的含义,你应该忽略前三个类似 url 的令牌。如果你的目标是构建一个像人类一样阅读和理解语言的 CNN,那么你将创建一个更复杂的分词器和令牌过滤器,以去除人类不关注的任何文本,例如 URL 和地理空间坐标。
一旦你调整好了你的词汇表和分词器,你就可以构建一个填充函数,以便在需要时重复使用。如果你的 pad()
函数足够通用,就像清单 7.18 中一样,你可以将它用于字符串令牌和整数索引。
清单 7.18 多功能填充函数
def pad(sequence, pad_value, seq_len): padded = list(sequence)[:seq_len] padded = padded + [pad_value] * (seq_len - len(padded)) return padded
我们还需要为 CNN 的良好工作进行最后一个预处理步骤。你想要包含你在第六章学到的令牌嵌入。
7.4.2 用单词嵌入进行更好的表示
想象一下,你正在将一小段文本通过你的管道运行。图 7.15 展示了在你将单词序列转换为数字(或向量,提示提示)进行卷积操作之前的样子。
图 7.15 卷积步幅
现在你已经组装了一个令牌序列,你需要很好地表示它们的含义,以便你的卷积能够压缩和编码所有这些含义。在第 5 和 6 章中我们使用的全连接神经网络中,你可以使用 one-hot 编码。但是 one-hot 编码会创建极其庞大、稀疏的矩阵,而现在你可以做得更好。你在第六章学到了一种非常强大的单词表示方式:单词嵌入。嵌入是你的单词的更加信息丰富和密集的向量表示。当你用嵌入来表示单词时,CNN 和几乎任何其他深度学习或 NLP 模型都会表现得更好。图 7.11 展示了如何做到这一点。
图 7.16 用于卷积的单词嵌入
图 7.16 展示了 PyTorch 中 nn.Embedding
层在幕后执行的操作。为了让你了解 1-D 卷积如何在你的数据上滑动,该图显示了一个两个长度的核在你的数据上移动的 3 个步骤。但是一个 1-D 卷积如何在一个 300-D GloVe 单词嵌入序列上工作呢?你只需要为你想要查找模式的每个维度创建一个卷积核(滤波器)。这意味着你的单词向量的每个维度都是卷积层中的一个通道。
不幸的是,许多博客文章和教程可能会误导您关于卷积层的正确尺寸。 许多 PyTorch 初学者认为 Embedding 层的输出可以直接流入卷积层而不需要任何调整大小。 不幸的是,这将创建一个沿着单词嵌入维度而不是单词序列的 1-D 卷积。 因此,您需要转置您的嵌入层输出,以使通道(单词嵌入维度)与卷积通道对齐。
PyTorch 有一个 nn.Embedding
层,您可以在所有深度学习流水线中使用。 如果您希望模型从头开始学习嵌入,您只需要告诉 PyTorch 您需要多少嵌入,这与您的词汇量大小相同。 嵌入层还需要您告诉它为每个嵌入向量分配多少维度。 可选地,您可以定义填充令牌索引 id 号。
代码清单 7.19 从头开始学习嵌入
from torch import nn embedding = nn.Embedding( num_embeddings=2000, # #1 embedding_dim=64, # #2 padding_idx=0)
嵌入层将是您的 CNN 中的第一层。这将把您的令牌 ID 转换成它们自己独特的 64-D 单词向量。在训练期间的反向传播将调整每个单词在每个维度上的权重,以匹配单词可用于谈论新闻灾害的 64 种不同方式。这些嵌入不会像第六章中的 FastText 和 GloVe 向量一样代表单词的完整含义。这些嵌入只有一个好处,那就是确定一条 Tweet 是否包含新闻灾害信息。
最后,您可以训练您的 CNN,看看它在像 Kaggle 灾难推文数据集这样的极窄数据集上的表现如何。 那些花费时间打造 CNN 的小时将以极快的训练时间和令人印象深刻的准确性得到回报。
代码清单 7.20 从头开始学习嵌入
from nlpia2.ch07.cnn.train79 import Pipeline # #1 pipeline = Pipeline( vocab_size=2000, embeddings=(2000, 64), epochs=7, torch_random_state=433994, # #2 split_random_state=1460940, ) pipeline = pipeline.train()
Epoch: 1, loss: 0.66147, Train accuracy: 0.61392, Test accuracy: 0.63648 Epoch: 2, loss: 0.64491, Train accuracy: 0.69712, Test accuracy: 0.70735 Epoch: 3, loss: 0.55865, Train accuracy: 0.73391, Test accuracy: 0.74278 Epoch: 4, loss: 0.38538, Train accuracy: 0.76558, Test accuracy: 0.77165 Epoch: 5, loss: 0.27227, Train accuracy: 0.79288, Test accuracy: 0.77690 Epoch: 6, loss: 0.29682, Train accuracy: 0.82119, Test accuracy: 0.78609 Epoch: 7, loss: 0.23429, Train accuracy: 0.82951, Test accuracy: 0.79003
仅仅经过 7 次通过训练数据集,您就在测试集上实现了 79% 的准确率。 在现代笔记本电脑 CPU 上,这应该不到一分钟。 并且通过最小化模型中的总参数,您将过拟合保持到最低。 与嵌入层相比,CNN 使用的参数非常少。
如果您继续训练一段时间会发生什么?
代码清单 7.21 继续训练
pipeline.epochs = 13 # #1 pipeline = pipeline.train()
Epoch: 1, loss: 0.24797, Train accuracy: 0.84528, Test accuracy: 0.78740 Epoch: 2, loss: 0.16067, Train accuracy: 0.86528, Test accuracy: 0.78871 ... Epoch: 12, loss: 0.04796, Train accuracy: 0.93578, Test accuracy: 0.77690 Epoch: 13, loss: 0.13394, Train accuracy: 0.94132, Test accuracy: 0.77690
哦,这看起来很可疑。 过拟合太严重了 - 在训练集上达到了 94%,在测试集上达到了 78%。 训练集准确率不断上升,最终超过了 90%。 到了第 20 个 epoch,模型在训练集上的准确率达到了 94%。 它甚至比专家人类还要好。 自己阅读几个示例,不看标签,你能得到其中的 94% 吗? 这是前四个示例,经过令牌化后,忽略了词汇表外的词汇,并添加了填充。
pipeline.indexes_to_texts(pipeline.x_test[:4])
['getting in the poor girl <PAD> <PAD> ...', 'Spot Flood Combo Cree LED Work Light Bar Offroad Lamp Full ...', 'ice the meltdown <PAD> <PAD> <PAD> <PAD> ...', 'and burn for bush fires in St http t co <PAD> <PAD> ...']
如果你的答案是[“disaster”, “not”, “not”, “disaster”],那你全部答对了。但继续努力吧。你能做到十九对二十吗?这就是你需要在训练集准确率上击败这个卷积神经网络所需要做到的。这不是什么意外,因为机器人一直在推特上发布听起来像是灾难的推文。有时甚至真实的人类也会对世界事件感到讽刺或煽动性。
是什么导致了这种过拟合?是参数太多了吗?神经网络的"容量"太大了吗?以下是一个好的函数,用于显示 PyTorch 神经网络每层的参数。
>>> def describe_model(model): # #1 ... state = model.state_dict() ... names = state.keys() ... weights = state.values() ... params = model.parameters() >>> df = pd.DataFrame() >>> df['name'] = list(state.keys()) >>> df['all'] = p.numel(), ... df['learned'] = [ ... p.requires_grad # #2 ... for p in params], # #3 ... size=p.size(), ... ) for name, w, p in zip(names, weights, params) ] ) df = df.set_index('name') return df describe_model(pipeline.model) # #4
learned_params all_params size name embedding.weight 128064 128064 (2001, 64) # #1 linear_layer.weight 1856 1856 (1, 1856) linear_layer.bias 1 1 (1,)
当你遇到过拟合问题时,你可以在管道中使用预训练模型来改善其泛化能力。
自然语言处理实战第二版(MEAP)(四)(3)https://developer.aliyun.com/article/1517995