引入注意力机制的Seq2seq模型
本节中将注意机制添加到sequence to sequence 模型中,以显式地使用权重聚合states。下图展示encoding 和decoding的模型结构,在时间步为t的时候。此刻attention layer保存着encodering看到的所有信息——即encoding的每一步输出。在decoding阶段,解码器的时刻的隐藏状态被当作query,encoder的每个时间步的hidden states作为key和value进行attention聚合. Attetion model的输出当作成上下文信息context vector,并与解码器输入拼接起来一起送到解码器:
Image Name
下图展示了seq2seq机制的所以层的关系,下面展示了encoder和decoder的layer结构
Image Name
import sys sys.path.append('/home/kesci/input/d2len9900') import d2l
解码器
由于带有注意机制的seq2seq的编码器与之前章节中的Seq2SeqEncoder相同,所以在此处我们只关注解码器。我们添加了一个MLP注意层(MLPAttention),它的隐藏大小与解码器中的LSTM层相同。然后我们通过从编码器传递三个参数来初始化解码器的状态:
- the encoder outputs of all timesteps:encoder输出的各个状态,被用于attetion layer的memory部分,有相同的key和values
- the hidden state of the encoder’s final timestep:编码器最后一个时间步的隐藏状态,被用于初始化decoder 的hidden state
- the encoder valid length: 编码器的有效长度,借此,注意层不会考虑编码器输出中的填充标记(Paddings)
在解码的每个时间步,我们使用解码器的最后一个RNN层的输出作为注意层的query。然后,将注意力模型的输出与输入嵌入向量连接起来,输入到RNN层。虽然RNN层隐藏状态也包含来自解码器的历史信息,但是attention model的输出显式地选择了enc_valid_len以内的编码器输出,这样attention机制就会尽可能排除其他不相关的信息。
class Seq2SeqAttentionDecoder(d2l.Decoder): def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs): super(Seq2SeqAttentionDecoder, self).__init__(**kwargs) self.attention_cell = MLPAttention(num_hiddens,num_hiddens, dropout) self.embedding = nn.Embedding(vocab_size, embed_size) self.rnn = nn.LSTM(embed_size+ num_hiddens,num_hiddens, num_layers, dropout=dropout) self.dense = nn.Linear(num_hiddens,vocab_size) def init_state(self, enc_outputs, enc_valid_len, *args): outputs, hidden_state = enc_outputs # print("first:",outputs.size(),hidden_state[0].size(),hidden_state[1].size()) # Transpose outputs to (batch_size, seq_len, hidden_size) return (outputs.permute(1,0,-1), hidden_state, enc_valid_len) #outputs.swapaxes(0, 1) def forward(self, X, state): enc_outputs, hidden_state, enc_valid_len = state #("X.size",X.size()) X = self.embedding(X).transpose(0,1) # print("Xembeding.size2",X.size()) outputs = [] for l, x in enumerate(X): # print(f"\n{l}-th token") # print("x.first.size()",x.size()) # query shape: (batch_size, 1, hidden_size) # select hidden state of the last rnn layer as query query = hidden_state[0][-1].unsqueeze(1) # np.expand_dims(hidden_state[0][-1], axis=1) # context has same shape as query # print("query enc_outputs, enc_outputs:\n",query.size(), enc_outputs.size(), enc_outputs.size()) context = self.attention_cell(query, enc_outputs, enc_outputs, enc_valid_len) # Concatenate on the feature dimension # print("context.size:",context.size()) x = torch.cat((context, x.unsqueeze(1)), dim=-1) # Reshape x to (1, batch_size, embed_size+hidden_size) # print("rnn",x.size(), len(hidden_state)) out, hidden_state = self.rnn(x.transpose(0,1), hidden_state) outputs.append(out) outputs = self.dense(torch.cat(outputs, dim=0)) return outputs.transpose(0, 1), [enc_outputs, hidden_state, enc_valid_len]
现在我们可以用注意力模型来测试seq2seq。为了与第9.7节中的模型保持一致,我们对vocab_size、embed_size、num_hiddens和num_layers使用相同的超参数。结果,我们得到了相同的解码器输出形状,但是状态结构改变了。
encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) # encoder.initialize() decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2) X = torch.zeros((4, 7),dtype=torch.long) print("batch size=4\nseq_length=7\nhidden dim=16\nnum_layers=2\n") print('encoder output size:', encoder(X)[0].size()) print('encoder hidden size:', encoder(X)[1][0].size()) print('encoder memory size:', encoder(X)[1][1].size()) state = decoder.init_state(encoder(X), None) out, state = decoder(X, state) out.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
batch size=4 seq_length=7 hidden dim=16 num_layers=2 encoder output size: torch.Size([7, 4, 16]) encoder hidden size: torch.Size([2, 4, 16]) encoder memory size: torch.Size([2, 4, 16]) (torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([2, 4, 16]))
训练
与第9.7.4节相似,通过应用相同的训练超参数和相同的训练损失来尝试一个简单的娱乐模型。从结果中我们可以看出,由于训练数据集中的序列相对较短,额外的注意层并没有带来显著的改进。由于编码器和解码器的注意层的计算开销,该模型比没有注意的seq2seq模型慢得多。
import zipfile import torch import requests from io import BytesIO from torch.utils import data import sys import collections class Vocab(object): # This class is saved in d2l. def __init__(self, tokens, min_freq=0, use_special_tokens=False): # sort by frequency and token counter = collections.Counter(tokens) token_freqs = sorted(counter.items(), key=lambda x: x[0]) token_freqs.sort(key=lambda x: x[1], reverse=True) if use_special_tokens: # padding, begin of sentence, end of sentence, unknown self.pad, self.bos, self.eos, self.unk = (0, 1, 2, 3) tokens = ['', '', '', ''] else: self.unk = 0 tokens = [''] tokens += [token for token, freq in token_freqs if freq >= min_freq] self.idx_to_token = [] self.token_to_idx = dict() for token in tokens: self.idx_to_token.append(token) self.token_to_idx[token] = len(self.idx_to_token) - 1 def __len__(self): return len(self.idx_to_token) def __getitem__(self, tokens): if not isinstance(tokens, (list, tuple)): return self.token_to_idx.get(tokens, self.unk) else: return [self.__getitem__(token) for token in tokens] def to_tokens(self, indices): if not isinstance(indices, (list, tuple)): return self.idx_to_token[indices] else: return [self.idx_to_token[index] for index in indices] def load_data_nmt(batch_size, max_len, num_examples=1000): """Download an NMT dataset, return its vocabulary and data iterator.""" # Download and preprocess def preprocess_raw(text): text = text.replace('\u202f', ' ').replace('\xa0', ' ') out = '' for i, char in enumerate(text.lower()): if char in (',', '!', '.') and text[i-1] != ' ': out += ' ' out += char return out with open('/home/kesci/input/fraeng6506/fra.txt', 'r') as f: raw_text = f.read() text = preprocess_raw(raw_text) # Tokenize source, target = [], [] for i, line in enumerate(text.split('\n')): if i >= num_examples: break parts = line.split('\t') if len(parts) >= 2: source.append(parts[0].split(' ')) target.append(parts[1].split(' ')) # Build vocab def build_vocab(tokens): tokens = [token for line in tokens for token in line] return Vocab(tokens, min_freq=3, use_special_tokens=True) src_vocab, tgt_vocab = build_vocab(source), build_vocab(target) # Convert to index arrays def pad(line, max_len, padding_token): if len(line) > max_len: return line[:max_len] return line + [padding_token] * (max_len - len(line)) def build_array(lines, vocab, max_len, is_source): lines = [vocab[line] for line in lines] if not is_source: lines = [[vocab.bos] + line + [vocab.eos] for line in lines] array = torch.tensor([pad(line, max_len, vocab.pad) for line in lines]) valid_len = (array != vocab.pad).sum(1) return array, valid_len src_vocab, tgt_vocab = build_vocab(source), build_vocab(target) src_array, src_valid_len = build_array(source, src_vocab, max_len, True) tgt_array, tgt_valid_len = build_array(target, tgt_vocab, max_len, False) train_data = data.TensorDataset(src_array, src_valid_len, tgt_array, tgt_valid_len) train_iter = data.DataLoader(train_data, batch_size, shuffle=True) return src_vocab, tgt_vocab, train_iter
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.0 batch_size, num_steps = 64, 10 lr, num_epochs, ctx = 0.005, 500, d2l.try_gpu() src_vocab, tgt_vocab, train_iter = load_data_nmt(batch_size, num_steps) encoder = d2l.Seq2SeqEncoder( len(src_vocab), embed_size, num_hiddens, num_layers, dropout) decoder = Seq2SeqAttentionDecoder( len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout) model = d2l.EncoderDecoder(encoder, decoder)
训练和预测
d2l.train_s2s_ch9(model, train_iter, lr, num_epochs, ctx)
epoch 50,loss 0.104, time 54.7 sec epoch 100,loss 0.046, time 54.8 sec epoch 150,loss 0.031, time 54.7 sec epoch 200,loss 0.027, time 54.3 sec epoch 250,loss 0.025, time 54.3 sec epoch 300,loss 0.024, time 54.4 sec epoch 350,loss 0.024, time 54.4 sec epoch 400,loss 0.024, time 54.5 sec epoch 450,loss 0.023, time 54.4 sec epoch 500,loss 0.023, time 54.7 sec
for sentence in ['Go .', 'Good Night !', "I'm OK .", 'I won !']: print(sentence + ' => ' + d2l.predict_s2s_ch9( model, sentence, src_vocab, tgt_vocab, num_steps, ctx))
Go . => va ! Good Night ! => ! I'm OK . => ça va . I won ! => j'ai gagné !