Machine Translation and Datasets
Machine translation (MT) automatically translates a piece of text from one language to another; solving this problem with neural networks is usually called neural machine translation (NMT).
Its main feature: the output is a sequence of words rather than a single word, and the output sequence may differ in length from the source sequence.
import os
os.listdir('/home/kesci/input/')
['fraeng6506', 'd2l9528', 'd2l6239']
import sys
sys.path.append('/home/kesci/input/d2l9528/')
import collections
import d2l
import zipfile
from d2l.data.base import Vocab
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torch import optim
Data Preprocessing
Clean the dataset and convert it into minibatches that the neural network can take as input.
with open('/home/kesci/input/fraeng6506/fra.txt', 'r') as f:
    raw_text = f.read()
print(raw_text[0:1000])
Go.	Va !	CC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #1158250 (Wittydev)
Hi.	Salut !	CC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #509819 (Aiji)
Hi.	Salut.	CC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #4320462 (gillux)
Run!	Cours !	CC-BY 2.0 (France) Attribution: tatoeba.org #906328 (papabear) & #906331 (sacredceltic)
Run!	Courez !	CC-BY 2.0 (France) Attribution: tatoeba.org #906328 (papabear) & #906332 (sacredceltic)
Who?	Qui ?	CC-BY 2.0 (France) Attribution: tatoeba.org #2083030 (CK) & #4366796 (gillux)
Wow!	Ça alors !	CC-BY 2.0 (France) Attribution: tatoeba.org #52027 (Zifre) & #374631 (zmoo)
Fire!	Au feu !	CC-BY 2.0 (France) Attribution: tatoeba.org #1829639 (Spamster) & #4627939 (sacredceltic)
Help!	À l'aide !	CC-BY 2.0 (France) Attribution: tatoeba.org #435084 (lukaszpp) & #128430 (sysko)
Jump.	Saute.	CC-BY 2.0 (France) Attribution: tatoeba.org #631038 (Shishir) & #2416938 (Phoenix)
Stop!	Ça suffit !	CC-BY 2.0 (France) Attribution: tato
def preprocess_raw(text):
    # replace non-breaking spaces with ordinary spaces
    text = text.replace('\u202f', ' ').replace('\xa0', ' ')
    out = ''
    for i, char in enumerate(text.lower()):
        # insert a space before punctuation that is not already preceded by one
        if char in (',', '!', '.') and i > 0 and text[i-1] != ' ':
            out += ' '
        out += char
    return out

text = preprocess_raw(raw_text)
print(text[0:1000])
go .	va !	cc-by 2 .0 (france) attribution: tatoeba .org #2877272 (cm) & #1158250 (wittydev)
hi .	salut !	cc-by 2 .0 (france) attribution: tatoeba .org #538123 (cm) & #509819 (aiji)
hi .	salut .	cc-by 2 .0 (france) attribution: tatoeba .org #538123 (cm) & #4320462 (gillux)
run !	cours !	cc-by 2 .0 (france) attribution: tatoeba .org #906328 (papabear) & #906331 (sacredceltic)
run !	courez !	cc-by 2 .0 (france) attribution: tatoeba .org #906328 (papabear) & #906332 (sacredceltic)
who?	qui ?	cc-by 2 .0 (france) attribution: tatoeba .org #2083030 (ck) & #4366796 (gillux)
wow !	ça alors !	cc-by 2 .0 (france) attribution: tatoeba .org #52027 (zifre) & #374631 (zmoo)
fire !	au feu !	cc-by 2 .0 (france) attribution: tatoeba .org #1829639 (spamster) & #4627939 (sacredceltic)
help !	à l'aide !	cc-by 2 .0 (france) attribution: tatoeba .org #435084 (lukaszpp) & #128430 (sysko)
jump .	saute .	cc-by 2 .0 (france) attribution: tatoeba .org #631038 (shishir) & #2416938 (phoenix)
stop !	ça suffit !	cc-b
Characters are stored in a computer in encoded form. The ordinary space we type is \x20, which lies within the printable ASCII range 0x20~0x7e.
\xa0, however, belongs to the extended character set of latin1 (ISO/IEC 8859-1) and represents the non-breaking space (nbsp). It falls outside the range of the gbk encoding and is a special character that has to be removed. Cleaning the data in this way is the first step of preprocessing.
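As a quick illustration (a minimal sketch, not part of the original notebook; the sample string is made up), these characters render like spaces but compare unequal to '\x20':

# \xa0 (non-breaking space) and \u202f (narrow no-break space) look like
# ordinary spaces when printed but are distinct code points.
sample = 'Va\u202f!\xa0Salut.'                 # hypothetical sample string
print('\u202f' == '\x20', '\xa0' == '\x20')    # -> False False
print(sample.replace('\u202f', ' ').replace('\xa0', ' '))  # -> Va ! Salut.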
Tokenization
From a string to a list of words.
num_examples = 50000
source, target = [], []
for i, line in enumerate(text.split('\n')):
    if i > num_examples:
        break
    parts = line.split('\t')
    if len(parts) >= 2:  # keep only lines with both a source and a target sentence
        source.append(parts[0].split(' '))
        target.append(parts[1].split(' '))

source[0:3], target[0:3]
([['go', '.'], ['hi', '.'], ['hi', '.']], [['va', '!'], ['salut', '!'], ['salut', '.']])
d2l.set_figsize()
d2l.plt.hist([[len(l) for l in source], [len(l) for l in target]],
             label=['source', 'target'])
d2l.plt.legend(loc='upper right');
[Histogram: distribution of sentence lengths in the source and target corpora]
Building the Vocabulary
From a list of words to a list of word ids.
def build_vocab(tokens):
    tokens = [token for line in tokens for token in line]  # flatten into one token list
    return d2l.data.base.Vocab(tokens, min_freq=3, use_special_tokens=True)

src_vocab = build_vocab(source)
len(src_vocab)
3789
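For reference, here is a rough, hypothetical sketch of what a vocabulary class along the lines of d2l's Vocab does; the real implementation differs in details, but the idea is the same: count tokens, reserve low ids for special tokens, and map tokens to ids (rare or unknown tokens map to unk).

class MiniVocab:
    """Hypothetical simplification of d2l's Vocab, for illustration only."""
    def __init__(self, tokens, min_freq=0, use_special_tokens=False):
        counter = collections.Counter(tokens)
        if use_special_tokens:
            # reserve ids 0..3 for <pad>, <bos>, <eos>, <unk>
            self.pad, self.bos, self.eos, self.unk = 0, 1, 2, 3
            self.idx_to_token = ['<pad>', '<bos>', '<eos>', '<unk>']
        else:
            self.unk = 0
            self.idx_to_token = ['<unk>']
        self.idx_to_token += [t for t, c in counter.items() if c >= min_freq]
        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if isinstance(tokens, (list, tuple)):
            return [self.__getitem__(t) for t in tokens]
        return self.token_to_idx.get(tokens, self.unk)

vocab = MiniVocab(['go', '.', 'go', 'hi', '.'], min_freq=2, use_special_tokens=True)
vocab[['go', '.', 'hi']]  # -> [4, 5, 3]: 'hi' is below min_freq, so it maps to unk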
Loading the Dataset
def pad(line, max_len, padding_token):
    if len(line) > max_len:  # truncate lines longer than max_len
        return line[:max_len]
    return line + [padding_token] * (max_len - len(line))  # pad shorter lines

pad(src_vocab[source[0]], 10, src_vocab.pad)
[38, 4, 0, 0, 0, 0, 0, 0, 0, 0]
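The call above exercises the padding branch; a sentence longer than max_len is truncated instead. A made-up check (not in the original notebook):

# a hypothetical "sentence" of 12 token ids, cut down to max_len = 10
pad(list(range(12)), 10, src_vocab.pad)  # -> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]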
def build_array(lines, vocab, max_len, is_source):
    lines = [vocab[line] for line in lines]
    if not is_source:
        # target sentences get <bos> and <eos> markers
        lines = [[vocab.bos] + line + [vocab.eos] for line in lines]
    array = torch.tensor([pad(line, max_len, vocab.pad) for line in lines])
    valid_len = (array != vocab.pad).sum(1)  # count non-padding tokens along dimension 1
    return array, valid_len
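Before wiring everything into a DataLoader, a small hypothetical check (not in the original notebook) makes the source/target difference visible: with is_source=False, build_array wraps each sentence in <bos>/<eos> before padding, and valid_len counts those markers as real tokens.

demo_vocab = build_vocab(target)  # hypothetical demo variable
arr, valid_len = build_array(target[0:2], demo_vocab, 8, is_source=False)
print(arr)        # each row starts with demo_vocab.bos and contains demo_vocab.eos
print(valid_len)  # number of non-pad entries, including <bos> and <eos>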
def load_data_nmt(batch_size, max_len):  # This function is saved in d2l.
    src_vocab, tgt_vocab = build_vocab(source), build_vocab(target)
    src_array, src_valid_len = build_array(source, src_vocab, max_len, True)
    tgt_array, tgt_valid_len = build_array(target, tgt_vocab, max_len, False)
    train_data = data.TensorDataset(src_array, src_valid_len, tgt_array, tgt_valid_len)
    train_iter = data.DataLoader(train_data, batch_size, shuffle=True)
    return src_vocab, tgt_vocab, train_iter
src_vocab, tgt_vocab, train_iter = load_data_nmt(batch_size=2, max_len=8)
for X, X_valid_len, Y, Y_valid_len in train_iter:
    print('X =', X.type(torch.int32), '\nValid lengths for X =', X_valid_len,
          '\nY =', Y.type(torch.int32), '\nValid lengths for Y =', Y_valid_len)
    break
X = tensor([[   5,   24,    3,    4,    0,    0,    0,    0],
        [  12, 1388,    7,    3,    4,    0,    0,    0]], dtype=torch.int32)
Valid lengths for X = tensor([4, 5])
Y = tensor([[   1,   23,   46,    3,    3,    4,    2,    0],
        [   1,   15,  137,   27, 4736,    4,    2,    0]], dtype=torch.int32)
Valid lengths for Y = tensor([7, 7])
Encoder-Decoder
encoder: maps the input to a hidden state
decoder: maps the hidden state to the output
[Figure: the Encoder-Decoder architecture]
class Encoder(nn.Module):
    """The base encoder interface: encode the input X into a state."""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError

class Decoder(nn.Module):
    """The base decoder interface: turn the encoder output into an initial state, then decode."""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError

class EncoderDecoder(nn.Module):
    """Chain the two: encode enc_X, initialize the decoder state, decode dec_X."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
The encoder-decoder architecture can be applied to dialogue systems and other generative tasks.
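As a minimal sketch of how these interfaces could be filled in (ToyEncoder and ToyDecoder are hypothetical names, not part of d2l), the final hidden state of a GRU encoder can initialize a GRU decoder:

class ToyEncoder(Encoder):
    def __init__(self, vocab_size, embed_size, num_hiddens):
        super(ToyEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)

    def forward(self, X, *args):
        # X: (batch_size, seq_len) of token ids
        return self.rnn(self.embedding(X))  # (outputs, final hidden state)

class ToyDecoder(Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens):
        super(ToyDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]  # use the encoder's final hidden state

    def forward(self, X, state):
        out, state = self.rnn(self.embedding(X), state)
        return self.dense(out), state  # per-step logits over the target vocabulary

model = EncoderDecoder(ToyEncoder(len(src_vocab), 8, 16),
                       ToyDecoder(len(tgt_vocab), 8, 16))
logits, _ = model(X, Y)   # X, Y: the batch printed above
logits.shape              # -> torch.Size([2, 8, len(tgt_vocab)])

In a real NMT model the decoder would be fed the target sequence shifted by one position during training (teacher forcing); this sketch only checks that shapes flow through the interface.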