import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader
import torch.utils.data as data_


class MyData(data_.Dataset):
    """Dataset over parallel lists of variable-length sequences and labels."""

    def __init__(self, data, label):
        # data:  list of 1-D FloatTensors of varying length
        # label: list of (seq_len, hidden) target tensors, one per sequence
        self.data = data
        self.label = label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Return the (sequence, label) pair for one sample.
        tuple_ = (self.data[idx], self.label[idx])
        return tuple_


def collate_fn(data_tuple):
    """Pad a batch of variable-length (sequence, label) pairs.

    Sorts the batch by descending sequence length — required by
    ``pack_padded_sequence`` with its default ``enforce_sorted=True`` —
    then zero-pads both inputs and targets.

    Returns:
        data:        (batch, max_len, 1) padded inputs, feature dim appended
        label:       (batch, max_len, hidden) padded targets
        data_length: list of original lengths, descending
    """
    data_tuple.sort(key=lambda x: len(x[0]), reverse=True)
    data = [sq[0] for sq in data_tuple]
    label = [sq[1] for sq in data_tuple]
    data_length = [len(sq) for sq in data]
    data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0.0)
    label = rnn_utils.pad_sequence(label, batch_first=True, padding_value=0.0)
    return data.unsqueeze(-1), label, data_length


if __name__ == '__main__':
    EPOCH = 1
    batchsize = 7
    hiddensize = 4
    num_layers = 2
    learning_rate = 0.001

    # Seven toy sequences of decreasing length (7 down to 1).
    train_x = [torch.FloatTensor([1, 1, 1, 1, 1, 1, 1]),
               torch.FloatTensor([2, 2, 2, 2, 2, 2]),
               torch.FloatTensor([3, 3, 3, 3, 3]),
               torch.FloatTensor([4, 4, 4, 4]),
               torch.FloatTensor([5, 5, 5]),
               torch.FloatTensor([6, 6]),
               torch.FloatTensor([7])]
    # Random per-timestep targets, one row per timestep of each sequence.
    train_y = [torch.rand(7, hiddensize),
               torch.rand(6, hiddensize),
               torch.rand(5, hiddensize),
               torch.rand(4, hiddensize),
               torch.rand(3, hiddensize),
               torch.rand(2, hiddensize),
               torch.rand(1, hiddensize)]

    data_ = MyData(train_x, train_y)
    data_loader = DataLoader(data_, batch_size=batchsize, shuffle=True,
                             collate_fn=collate_fn)
    net = nn.LSTM(input_size=1, hidden_size=hiddensize,
                  num_layers=num_layers, batch_first=True)
    criteria = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    for epoch in range(EPOCH):
        for batch_id, (batch_x, batch_y, batch_x_len) in enumerate(data_loader):
            print('pack前:', batch_x.shape)
            # Pack inputs AND targets with the same lengths so padded
            # timesteps are dropped from both sides of the loss.
            batch_x_pack = rnn_utils.pack_padded_sequence(
                batch_x, batch_x_len, batch_first=True)
            batch_y_pack = rnn_utils.pack_padded_sequence(
                batch_y, batch_x_len, batch_first=True)
            print('pack后:', batch_x_pack[0].shape)
            out, (h, c) = net(batch_x_pack)
            print('LSTM输出:', out[0].shape, h.shape, c.shape)
            # `.data` here is PackedSequence.data (the flat packed tensor),
            # not Tensor.data — the loss is computed only over real timesteps.
            loss = criteria(out.data, batch_y_pack.data)
            out = rnn_utils.pad_packed_sequence(out, batch_first=True)
            print('还原pack数据:', out[0].shape)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() extracts a plain float; formatting a tensor with
            # {:6.4f} is fragile across torch versions.
            print('epoch:{:2d}, batch_id:{:2d}, loss:{:6.4f}'.format(
                epoch, batch_id, loss.item()))
    print('Training done!')