我的文本数据和文本标签放在csv文件中的,如何将文本数据设置小批次和打乱,再传入RNN或者LSTM网络

简介: 在 TorchText 中,可以使用 TabularDataset 类来读取 CSV 文件中的数据,并使用 BucketIterator 来对数据进行小批次和打乱处理。假设我们的 CSV 文件中有以下内容。

在 TorchText 中,可以使用 TabularDataset 类来读取 CSV 文件中的数据,并使用 BucketIterator 来对数据进行小批次和打乱处理。下面是一个简单示例:


import torch
import torch.nn as nn
from torchtext.legacy.data import TabularDataset, Field, BucketIterator
# 定义 Field 对象
TEXT = Field(tokenize='spacy', include_lengths=True)
LABEL = Field(sequential=False, unk_token=None)
# 读取 CSV 文件
train_data, test_data = TabularDataset.splits(
    path='data', train='train.csv', test='test.csv',
    format='csv', fields=[('text', TEXT), ('label', LABEL)]
)
# 构建词汇表
TEXT.build_vocab(train_data)
# 创建 Iterator 对象
train_iterator, test_iterator = BucketIterator.splits(
    (train_data, test_data), batch_size=32,
    shuffle=True, sort_within_batch=True,
    sort_key=lambda x: len(x.text)
)
# 定义模型
class RNN(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super(RNN, self).__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, num_layers=2, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
    def forward(self, text, text_lengths):
        embedded = self.embedding(text)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'), batch_first=True, enforce_sorted=False) # 可变长度序列压缩为固定长度
        packed_output, hidden = self.rnn(packed_embedded)
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True) # 固定长度输出还原为可变长度序列
        hidden = torch.sum(hidden, dim=0)
        return self.fc(hidden)
# 初始化模型和优化器
model = RNN(len(TEXT.vocab), 100, 256, len(LABEL.vocab))
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
# 训练
for epoch in range(10):
    for batch in train_iterator:
        optimizer.zero_grad()
        text, text_lengths = batch.text
        label = batch.label
        output = model(text, text_lengths)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
# 测试
with torch.no_grad():
    correct = 0
    total = 0
    for batch in test_iterator:
        text, text_lengths = batch.text
        label = batch.label
        output = model(text, text_lengths)
        predictions = torch.argmax(output, dim=1)
        correct += (predictions == label).sum().item()
        total += len(label)
accuracy = correct / total
print(f"Test accuracy: {accuracy}")


在上面的代码中,我们首先定义了 TEXTLABEL 两个 Field 对象,并使用 TabularDataset.splits() 函数读取 CSV 文件中的数据。然后使用 BucketIterator.splits() 函数创建了训练集和测试集的迭代器对象,并设置了小批次、打乱等参数。接下来定义了一个简单的 RNN 模型,并使用交叉熵损失函数进行训练和测试。

需要注意的是,在创建 BucketIterator 对象时,我们设置了 sort_within_batch=Truesort_key=lambda x: len(x.text),这样可以保证在同一个 batch 中按照文本长度重新排序,避免出现填充过多的情况。在模型的 forward() 函数中,我们使用了 pack_padded_sequence()pad_packed_sequence() 函数来处理可变长度的输入和输出,确保固定长度输出和有效计算。



--------------------上面 CSV 文件中的数据,文件内容结构 是什么样子的?


假设我们的 CSV 文件中有以下内容:


text,label
This is a positive sentence,1
This is a negative sentence,0
Another positive sentence,1
Another negative sentence,0


其中第一行是列名,第一列是文本数据,第二列是标签。每一行代表一个样本。

在 TorchText 中,使用 TabularDataset 类读取 CSV 文件时,需要指定列名和列对应的 Field 对象,示例如下:


from torchtext.legacy.data import TabularDataset, Field
# 定义 Field 对象
TEXT = Field()
LABEL = Field()
# 读取 CSV 文件
train_data, test_data = TabularDataset.splits(
    path='data', train='train.csv', test='test.csv',
    format='csv', fields=[('text', TEXT), ('label', LABEL)]
)


这里将 CSV 文件中的 text 列对应到了 TEXT 对象,将 label 列对应到了 LABEL 对象。可以根据实际需求,设置不同的参数,如何进行分词、是否支持序列等。


--------------------其中的 path='data'是什么含义?


在 TorchText 中,TabularDatasetField 对象需要读取数据文件并根据其内容处理数据。path='data' 表示数据文件所在的目录为当前工作目录下的 data 目录,即数据文件的完整路径为 './data/train.csv''./data/test.csv'

如果不指定 path 参数,则默认使用当前工作目录,即数据文件的完整路径应该与程序的运行位置有关。可以通过 os.getcwd() 函数获取当前工作目录。


--------------------上面的 TEXT.build_vocab(train_data) 含义和参数解释

TEXT.build_vocab(train_data) 的作用是根据训练集 train_data 中的文本数据构建词汇表。

具体来说,该函数会遍历 train_data 中所有的文本数据,将其中出现的单词按照一定的规则转化为数字表示,并将其存储到 TEXT.vocab 对象中。这个过程包含以下几个步骤:

  1. 统计训练集中每个单词的出现频率。
  2. 根据指定的参数(如最小出现次数、最大词汇量等)过滤词汇表中的单词。
  3. 把剩余的单词按照出现频率排序,给它们分配一个唯一的数字 ID。
  4. 将词汇表保存到 TEXT.vocab 对象中。


在上面的代码中,train_data 是一个 TabularDataset 对象,其中包含了所有的训练数据。TEXT 是一个 Field 对象,其包含了用于处理文本数据的相关参数和方法。在调用 build_vocab 函数时,没有指定其他参数,因此默认使用 min_freq=1 这个参数来筛选词汇表中的单词。

min_freq 指定了词汇表中单词的最小出现频率。如果一个单词在训练集中的出现次数小于 min_freq,则会被过滤掉,不会出现在词汇表中。例如,如果设置 min_freq=5,则只有在训练集中至少出现 5 次的单词才会被包含在词汇表中。

除了 min_freqbuild_vocab 函数还可以接受其他参数,如 max_size(限制词汇表的最大大小)、vectors(指定预训练词向量文件路径)等,具体可参考 TorchText 文档。


--------------------TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')的含义和用途


TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d') 是根据训练集 train_data 中的文本数据构建词汇表,并加载预训练的 GloVe 词向量。


其中,max_size=10000 指定了词汇表的最大大小为 10000。如果词汇表中单词的数量超过了这个限制,出现频率低的单词会被舍弃。如果没有设置此参数,则默认不会限制词汇表的大小。

vectors='glove.6B.100d' 指定了需要加载的预训练词向量文件路径。在这里,我们使用的是 GloVe 词向量模型的预训练文件,文件名为 glove.6B.100d.txt。该文件中包含了大约 40 万个单词的词向量,每个词向量的维度为 100。

通过加载预训练的词向量,在模型训练时可以将这些词向量作为权重初始化,从而加快模型收敛速度并提高模型的性能。同时,由于预训练词向量已经包含了一定的语义信息,使得模型可以更好地理解和推理自然语言文本。



--------------------上面vectors='glove.6B.100d' 除了用'glove.6B.100d' 还可以有哪些选项?


vectors 参数中,除了使用 glove.6B.100d 这个预训练词向量的文件名之外,还可以使用其它的预训练词向量模型或者将该参数设置为 None。下面是几种常见的预训练词向量模型及其文件名:


  • GloVe 词向量:
  • glove.6B.50d.txt
  • glove.6B.100d.txt
  • glove.6B.200d.txt
  • glove.6B.300d.txt
  • FastText 词向量:
  • cc.en.300.vec.gz
  • wiki-news-300d-1M.vec
  • crawl-300d-2M.vec
  • Word2Vec 词向量:
  • GoogleNews-vectors-negative300.bin


需要注意的是,不同的预训练词向量模型可能包含的单词和词向量维度都不同,因此在使用时需要根据具体情况选择合适的模型和参数。


--------------------如何下载FastText 词向量,以及怎么传给上面说得vectors,给出案例


可以在官网上下载 FastText 词向量的预训练文件。链接为:https://fasttext.cc/docs/en/crawl-vectors.html

下载完成后,可以将其放置在指定的路径下,并将该路径传递给 vectors 参数。

以下是一个示例代码片段,展示了如何使用 FastText 词向量:


import torchtext.vocab as vocab
import torchtext.data as data
TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField()
train_data, test_data = data.TabularDataset.splits(
        path='./data/',
        train='train.csv',
        test='test.csv',
        format='csv',
        skip_header=True,
        fields=[('text', TEXT), ('label', LABEL)]
)
TEXT.build_vocab(train_data, max_size=10000, vectors='cc.en.300.vec.gz')
print(TEXT.vocab.vectors.shape)


在上面的代码中,我们首先通过 torchtext.vocabtorchtext.data 模块导入了所需的类和函数。然后,我们定义了 TEXTLABEL 两个字段,分别用于存储文本数据和标签数据。

接着,我们使用 TabularDataset 类加载了训练集和测试集,并指定了数据格式、文件路径等参数。

在调用 TEXT.build_vocab() 方法时,我们将 vectors 参数设置为 FastText 词向量的文件名 cc.en.300.vec.gz。此外,我们还指定了最大词汇表大小为 10000。

最后,我们打印出词向量矩阵的形状,以验证是否成功加载了预训练词向量。



--------------------如何下载Word2Vec 词向量,同上,给出vectors赋值方法


可以在官网上下载 Word2Vec 词向量的预训练文件。链接为:https://code.google.com/archive/p/word2vec/


下载完成后,可以将其放置在指定的路径下,并将该路径传递给 vectors 参数。

以下是一个示例代码片段,展示了如何使用 Word2Vec 词向量:

import torchtext.vocab as vocab
import torchtext.data as data
TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField()
train_data, test_data = data.TabularDataset.splits(
        path='./data/',
        train='train.csv',
        test='test.csv',
        format='csv',
        skip_header=True,
        fields=[('text', TEXT), ('label', LABEL)]
)
TEXT.build_vocab(train_data, max_size=10000, vectors='GoogleNews-vectors-negative300.bin', vectors_format='word2vec')
print(TEXT.vocab.vectors.shape)


在上面的代码中,我们首先通过 torchtext.vocabtorchtext.data 模块导入了所需的类和函数。然后,我们定义了 TEXTLABEL 两个字段,分别用于存储文本数据和标签数据。

接着,我们使用 TabularDataset 类加载了训练集和测试集,并指定了数据格式、文件路径等参数。

在调用 TEXT.build_vocab() 方法时,我们将 vectors 参数设置为 Word2Vec 词向量的文件名 GoogleNews-vectors-negative300.bin,并将 vectors_format 参数设置为 'word2vec'。此外,我们还指定了最大词汇表大小为 10000。

最后,我们打印出词向量矩阵的形状,以验证是否成功加载了预训练词向量。


--------------------同上,GloVe 词向量


可以在官网上下载 GloVe 词向量的预训练文件。链接为:http://nlp.stanford.edu/data/glove.6B.zip

下载完成后,可以将其放置在指定的路径下,并将该路径传递给 vectors 参数。

以下是一个示例代码片段,展示了如何使用 GloVe 词向量:


import torchtext.vocab as vocab
import torchtext.data as data
TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField()
train_data, test_data = data.TabularDataset.splits(
        path='./data/',
        train='train.csv',
        test='test.csv',
        format='csv',
        skip_header=True,
        fields=[('text', TEXT), ('label', LABEL)]
)
TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.300d.txt', vectors_cache='./vector_cache/', unk_init=torch.Tensor.normal_)
print(TEXT.vocab.vectors.shape)


在上面的代码中,我们首先通过 torchtext.vocabtorchtext.data 模块导入了所需的类和函数。然后,我们定义了 TEXTLABEL 两个字段,分别用于存储文本数据和标签数据。

接着,我们使用 TabularDataset 类加载了训练集和测试集,并指定了数据格式、文件路径等参数。

在调用 TEXT.build_vocab() 方法时,我们将 vectors 参数设置为 GloVe 词向量的文件名 glove.6B.300d.txt,并将 vectors_cache 参数设置为缓存路径 './vector_cache/'。此外,      我们还指定了最大词汇表大小为 10000,并使用 torch.Tensor.normal_ 函数进行未登录词的初始化。最后,我们打印出词向量矩阵的形状,以验证是否成功加载了预训练词向量。

相关文章
|
28天前
|
机器学习/深度学习 数据采集 存储
时间序列预测新突破:深入解析循环神经网络(RNN)在金融数据分析中的应用
【10月更文挑战第7天】时间序列预测是数据科学领域的一个重要课题,特别是在金融行业中。准确的时间序列预测能够帮助投资者做出更明智的决策,比如股票价格预测、汇率变动预测等。近年来,随着深度学习技术的发展,尤其是循环神经网络(Recurrent Neural Networks, RNNs)及其变体如长短期记忆网络(LSTM)和门控循环单元(GRU),在处理时间序列数据方面展现出了巨大的潜力。本文将探讨RNN的基本概念,并通过具体的代码示例展示如何使用这些模型来进行金融数据分析。
178 2
|
1月前
|
监控 安全 网络安全
云计算与网络安全:保护数据的关键策略
【9月更文挑战第34天】在数字化时代,云计算已成为企业和个人存储、处理数据的优选方式。然而,随着云服务的普及,网络安全问题也日益凸显。本文将探讨云计算环境中的网络安全挑战,并提供一系列策略来加强信息安全。从基础的数据加密到复杂的访问控制机制,我们将一探究竟如何在享受云服务便利的同时,确保数据的安全性和隐私性不被侵犯。
63 10
|
2月前
|
存储 安全 网络安全
云计算与网络安全:守护数据,构筑未来
在当今的信息化时代,云计算已成为推动技术革新的重要力量。然而,随之而来的网络安全问题也日益凸显。本文从云服务、网络安全和信息安全等技术领域展开,探讨了云计算在为生活带来便捷的同时,如何通过技术创新和策略实施来确保网络环境的安全性和数据的保密性。
|
3天前
|
机器学习/深度学习 自然语言处理 前端开发
前端神经网络入门:Brain.js - 详细介绍和对比不同的实现 - CNN、RNN、DNN、FFNN -无需准备环境打开浏览器即可测试运行-支持WebGPU加速
本文介绍了如何使用 JavaScript 神经网络库 **Brain.js** 实现不同类型的神经网络,包括前馈神经网络(FFNN)、深度神经网络(DNN)和循环神经网络(RNN)。通过简单的示例和代码,帮助前端开发者快速入门并理解神经网络的基本概念。文章还对比了各类神经网络的特点和适用场景,并简要介绍了卷积神经网络(CNN)的替代方案。
|
6天前
|
存储 安全 网络安全
云计算与网络安全:保护数据的新策略
【10月更文挑战第28天】随着云计算的广泛应用,网络安全问题日益突出。本文将深入探讨云计算环境下的网络安全挑战,并提出有效的安全策略和措施。我们将分析云服务中的安全风险,探讨如何通过技术和管理措施来提升信息安全水平,包括加密技术、访问控制、安全审计等。此外,文章还将分享一些实用的代码示例,帮助读者更好地理解和应用这些安全策略。
|
11天前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:从漏洞到加密,保护数据的关键步骤
【10月更文挑战第24天】在数字化时代,网络安全和信息安全是维护个人隐私和企业资产的前线防线。本文将探讨网络安全中的常见漏洞、加密技术的重要性以及如何通过提高安全意识来防范潜在的网络威胁。我们将深入理解网络安全的基本概念,学习如何识别和应对安全威胁,并掌握保护信息不被非法访问的策略。无论你是IT专业人士还是日常互联网用户,这篇文章都将为你提供宝贵的知识和技能,帮助你在网络世界中更安全地航行。
|
14天前
|
存储 安全 网络安全
云计算与网络安全:如何保护您的数据
【10月更文挑战第21天】在这篇文章中,我们将探讨云计算和网络安全的关系。随着云计算的普及,网络安全问题日益突出。我们将介绍云服务的基本概念,以及如何通过网络安全措施来保护您的数据。最后,我们将提供一些代码示例,帮助您更好地理解这些概念。
|
14天前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化CNN-LSTM网络的数据分类识别算法matlab仿真
本项目展示了基于贝叶斯优化(BO)的CNN-LSTM网络在数据分类中的应用。通过MATLAB 2022a实现,优化前后效果对比明显。核心代码附带中文注释和操作视频,涵盖BO、CNN、LSTM理论,特别是BO优化CNN-LSTM网络的batchsize和学习率,显著提升模型性能。
|
1月前
|
Ubuntu 网络安全 数据安全/隐私保护
阿里云国际版如何设置网络控制面板
阿里云国际版如何设置网络控制面板
|
1月前
|
SQL 安全 测试技术
网络安全与信息安全:保护数据的艺术
【9月更文挑战第36天】在数字化时代,网络安全和信息安全已成为维护个人隐私和企业资产的基石。本文深入探讨了网络安全漏洞、加密技术以及安全意识的重要性,旨在为读者提供一份知识宝典,帮助他们在网络世界中航行而不触礁。我们将从网络安全的基本概念出发,逐步深入到复杂的加密算法,最后强调培养安全意识的必要性。无论你是IT专业人士还是日常互联网用户,这篇文章都将为你打开一扇了解和实践网络安全的大门。
34 2

热门文章

最新文章