pytorch的lstm掩码实现

简介: pytorch的lstm掩码实现
importtorchfromtorchimportnnimporttorch.nn.utils.rnnasrnn_utilsfromtorch.utils.dataimportDataLoaderimporttorch.utils.dataasdata_classMyData(data_.Dataset):
def__init__(self, data, label):
self.data=dataself.label=labeldef__len__(self):
returnlen(self.data)
def__getitem__(self, idx):
tuple_= (self.data[idx], self.label[idx])
returntuple_defcollate_fn(data_tuple):  # data_tuple是一个列表,列表中包含batchsize个元组,每个元组中包含数据和标签data_tuple.sort(key=lambdax: len(x[0]), reverse=True)
data= [sq[0] forsqindata_tuple]
label= [sq[1] forsqindata_tuple]
data_length= [len(sq) forsqindata]
data=rnn_utils.pad_sequence(data, batch_first=True, padding_value=0.0)  # 用零补充,使长度对齐label=rnn_utils.pad_sequence(label, batch_first=True, padding_value=0.0)  # 这行代码只是为了把列表变为tensorreturndata.unsqueeze(-1), label, data_lengthif__name__=='__main__':
EPOCH=1batchsize=7hiddensize=4num_layers=2learning_rate=0.001# 训练数据train_x= [torch.FloatTensor([1, 1, 1, 1, 1, 1, 1]),
torch.FloatTensor([2, 2, 2, 2, 2, 2]),
torch.FloatTensor([3, 3, 3, 3, 3]),
torch.FloatTensor([4, 4, 4, 4]),
torch.FloatTensor([5, 5, 5]),
torch.FloatTensor([6, 6]),
torch.FloatTensor([7])]
# 标签train_y= [torch.rand(7, hiddensize),
torch.rand(6, hiddensize),
torch.rand(5, hiddensize),
torch.rand(4, hiddensize),
torch.rand(3, hiddensize),
torch.rand(2, hiddensize),
torch.rand(1, hiddensize)]
data_=MyData(train_x, train_y)
data_loader=DataLoader(data_, batch_size=batchsize, shuffle=True, collate_fn=collate_fn)
net=nn.LSTM(input_size=1, hidden_size=hiddensize, num_layers=num_layers, batch_first=True)
criteria=nn.MSELoss()
optimizer=torch.optim.Adam(net.parameters(), lr=learning_rate)
# 训练forepochinrange(EPOCH):
forbatch_id, (batch_x, batch_y, batch_x_len) inenumerate(data_loader):
print('pack前:', batch_x.shape)
# print('pack前:',batch_x)batch_x_pack=rnn_utils.pack_padded_sequence(batch_x, batch_x_len, batch_first=True)
batch_y_pack=rnn_utils.pack_padded_sequence(batch_y, batch_x_len, batch_first=True)
print('pack后:', batch_x_pack[0].shape)
# print('pack后:',batch_x_pack)out, (h, c) =net(batch_x_pack)  # out.data's shape (所有序列总长度, hiddensize)print('LSTM输出:', out[0].shape, h.shape, c.shape)
# print('LSTM输出:',out)loss=criteria(out.data, batch_y_pack.data)
out=rnn_utils.pad_packed_sequence(out, batch_first=True)
print('还原pack数据:', out[0].shape)
# print('还原pack数据:',out)optimizer.zero_grad()
loss.backward()
optimizer.step()
print('epoch:{:2d}, batch_id:{:2d}, loss:{:6.4f}'.format(epoch, batch_id, loss))
print('Training done!')
目录
相关文章
|
7月前
|
机器学习/深度学习 自然语言处理 PyTorch
PyTorch应用实战六:利用LSTM实现文本情感分类
PyTorch应用实战六:利用LSTM实现文本情感分类
162 0
|
2天前
|
机器学习/深度学习 算法 PyTorch
在Python中使用LSTM和PyTorch进行时间序列预测
在Python中使用LSTM和PyTorch进行时间序列预测
|
8月前
|
机器学习/深度学习 资源调度 自然语言处理
长短时记忆网络(LSTM)完整实战:从理论到PyTorch实战演示
长短时记忆网络(LSTM)完整实战:从理论到PyTorch实战演示
2232 0
|
2天前
|
机器学习/深度学习 自然语言处理 PyTorch
PyTorch搭建RNN联合嵌入模型(LSTM GRU)实现视觉问答(VQA)实战(超详细 附数据集和源码)
PyTorch搭建RNN联合嵌入模型(LSTM GRU)实现视觉问答(VQA)实战(超详细 附数据集和源码)
91 1
|
2天前
|
机器学习/深度学习 数据采集 自然语言处理
PyTorch搭建LSTM神经网络实现文本情感分析实战(附源码和数据集)
PyTorch搭建LSTM神经网络实现文本情感分析实战(附源码和数据集)
193 0
|
11月前
|
机器学习/深度学习 数据采集 自然语言处理
【Deep Learning A情感文本分类实战】2023 Pytorch+Bert、Roberta+TextCNN、BiLstm、Lstm等实现IMDB情感文本分类完整项目(项目已开源)
亮点:代码开源+结构清晰+准确率高+保姆级解析 🍊本项目使用Pytorch框架,使用上游语言模型+下游网络模型的结构实现IMDB情感分析 🍊语言模型可选择Bert、Roberta 🍊神经网络模型可选择BiLstm、LSTM、TextCNN、Rnn、Gru、Fnn共6种 🍊语言模型和网络模型扩展性较好,方便读者自己对模型进行修改
419 0
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
时间序列pytorch搭建lstm用电量预测 完整代码数据
时间序列pytorch搭建lstm用电量预测 完整代码数据
141 0
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
141 0
股票预测-基金预测 pytorch搭建LSTM网络 黄金价格预测实战
|
9月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch的LSTM物品移动预测算法
实现了一个多层双向LSTM模型,并用于训练一个时间序列预测任务
66 0

相关实验场景

更多