Machine Translation and the Dataset
Machine translation (MT): automatically translating a piece of text from one language into another. Solving this problem with neural networks is usually called neural machine translation (NMT).
Key characteristics: the output is a sequence of words rather than a single word, and the length of the output sequence may differ from the length of the source sequence.
A plain RNN cannot be applied directly, because in machine translation the output and input sequences may have different lengths.
```python
import os
os.listdir('/home/kesci/input')  # directory inferred from the data path used below
```
['fraeng6506', 'd2l9528']
```python
import sys
```
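The remaining cells rely on a handful of standard imports. A minimal setup sketch; the location and name of the course-provided d2l package are assumptions based on the directory listing above:

```python
sys.path.append('/home/kesci/input/d2l9528')  # assumed path of the course's d2l package
import collections
import time

import torch
import torch.nn as nn
import d2l
```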
Data Preprocessing
Clean the dataset and convert it into minibatches that can be fed to the neural network.
```python
with open('/home/kesci/input/fraeng6506/fra.txt', 'r') as f:
    raw_text = f.read()
```
Go. Va ! CC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #1158250 (Wittydev)
Hi. Salut ! CC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #509819 (Aiji)
Hi. Salut. CC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #4320462 (gillux)
Run! Cours ! CC-BY 2.0 (France) Attribution: tatoeba.org #906328 (papabear) & #906331 (sacredceltic)
Run! Courez ! CC-BY 2.0 (France) Attribution: tatoeba.org #906328 (papabear) & #906332 (sacredceltic)
Who? Qui ? CC-BY 2.0 (France) Attribution: tatoeba.org #2083030 (CK) & #4366796 (gillux)
Wow! Ça alors ! CC-BY 2.0 (France) Attribution: tatoeba.org #52027 (Zifre) & #374631 (zmoo)
Fire! Au feu ! CC-BY 2.0 (France) Attribution: tatoeba.org #1829639 (Spamster) & #4627939 (sacredceltic)
Help! À l'aide ! CC-BY 2.0 (France) Attribution: tatoeba.org #435084 (lukaszpp) & #128430 (sysko)
Jump. Saute. CC-BY 2.0 (France) Attribution: tatoeba.org #631038 (Shishir) & #2416938 (Phoenix)
Stop! Ça suffit ! CC-BY 2.0 (France) Attribution: tato
```python
def preprocess_raw(text):
```
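A sketch of what preprocess_raw might contain, consistent with the output below: lowercase the text, replace non-breaking spaces with ordinary spaces, and insert a space before punctuation that directly follows a word (the exact course implementation may differ):

```python
def preprocess_raw(text):
    # Replace non-breaking spaces (\xa0, \u202f) with ordinary spaces and lowercase everything
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    out = ''
    for i, char in enumerate(text):
        # Add a space before ',', '!' and '.' when they directly follow another character
        if char in (',', '!', '.') and i > 0 and text[i - 1] != ' ':
            out += ' '
        out += char
    return out

text = preprocess_raw(raw_text)
```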
go . va ! cc-by 2 .0 (france) attribution: tatoeba .org #2877272 (cm) & #1158250 (wittydev)
hi . salut ! cc-by 2 .0 (france) attribution: tatoeba .org #538123 (cm) & #509819 (aiji)
hi . salut . cc-by 2 .0 (france) attribution: tatoeba .org #538123 (cm) & #4320462 (gillux)
run ! cours ! cc-by 2 .0 (france) attribution: tatoeba .org #906328 (papabear) & #906331 (sacredceltic)
run ! courez ! cc-by 2 .0 (france) attribution: tatoeba .org #906328 (papabear) & #906332 (sacredceltic)
who? qui ? cc-by 2 .0 (france) attribution: tatoeba .org #2083030 (ck) & #4366796 (gillux)
wow ! ça alors ! cc-by 2 .0 (france) attribution: tatoeba .org #52027 (zifre) & #374631 (zmoo)
fire ! au feu ! cc-by 2 .0 (france) attribution: tatoeba .org #1829639 (spamster) & #4627939 (sacredceltic)
help ! à l'aide ! cc-by 2 .0 (france) attribution: tatoeba .org #435084 (lukaszpp) & #128430 (sysko)
jump . saute . cc-by 2 .0 (france) attribution: tatoeba .org #631038 (shishir) & #2416938 (phoenix)
stop ! ça suffit ! cc-b
Characters are stored in the computer in encoded form. The ordinary space we normally type is \x20, which lies within the visible-character range 0x20–0x7e of standard ASCII.
\xa0, by contrast, belongs to the extended character set of Latin-1 (ISO/IEC 8859-1) and represents a non-breaking space (nbsp). It falls outside the GBK encoding range and is a special character that needs to be removed. The first step of data preprocessing is therefore to clean the data.
Tokenization
Turn each string into a list of its words.
```python
num_examples = 50000
```
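A sketch of the tokenization loop that would produce the output below, assuming each line of the file is tab-separated into an English sentence, a French sentence, and an attribution:

```python
source, target = [], []
for i, line in enumerate(text.split('\n')):
    if i >= num_examples:
        break
    parts = line.split('\t')
    if len(parts) >= 2:
        source.append(parts[0].split(' '))  # English tokens
        target.append(parts[1].split(' '))  # French tokens

source[0:3], target[0:3]
```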
([['go', '.'], ['hi', '.'], ['hi', '.']],
[['va', '!'], ['salut', '!'], ['salut', '.']])
```python
d2l.set_figsize()
```
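The rest of this cell presumably plotted the distribution of sentence lengths; a sketch using matplotlib directly (the styling of the original figure is unknown):

```python
import matplotlib.pyplot as plt

plt.hist([len(l) for l in source], alpha=0.5, label='source')
plt.hist([len(l) for l in target], alpha=0.5, label='target')
plt.xlabel('# tokens per sentence')
plt.legend()
plt.show()
```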
Building the Vocabulary
Turn each list of words into a list of word indices (IDs).
```python
def build_vocab(tokens):
```
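The course's d2l package provides a Vocab class; since its interface is not shown here, the sketch below uses a minimal stand-in with the special-token layout implied by the batches printed later (&lt;pad&gt;=0, &lt;bos&gt;=1, &lt;eos&gt;=2, &lt;unk&gt;=3). The min_freq threshold is an assumption:

```python
class Vocab:
    """Minimal stand-in for the course's Vocab class (not the d2l implementation)."""
    def __init__(self, tokens, min_freq=3):
        counter = collections.Counter(tokens)
        self.idx_to_token = ['<pad>', '<bos>', '<eos>', '<unk>']
        self.idx_to_token += [t for t, c in counter.items() if c >= min_freq]
        self.token_to_idx = {t: i for i, t in enumerate(self.idx_to_token)}
        self.pad, self.bos, self.eos, self.unk = 0, 1, 2, 3

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if isinstance(tokens, (list, tuple)):
            return [self.__getitem__(t) for t in tokens]
        return self.token_to_idx.get(tokens, self.unk)

def build_vocab(tokens):
    tokens = [token for line in tokens for token in line]
    return Vocab(tokens, min_freq=3)

src_vocab = build_vocab(source)
len(src_vocab)
```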
3789
Loading the Dataset
```python
def pad(line, max_len, padding_token):  # make every sentence exactly max_len tokens long
    return line[:max_len] if len(line) > max_len else line + [padding_token] * (max_len - len(line))
```
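The printed list below presumably comes from padding the first tokenized source sentence to length 10, for example:

```python
pad(src_vocab[source[0]], 10, src_vocab.pad)  # ids for 'go .' followed by padding (exact ids depend on the vocabulary)
```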
[38, 4, 0, 0, 0, 0, 0, 0, 0, 0]
```python
def build_array(lines, vocab, max_len, is_source):
```
```python
def load_data_nmt(batch_size, max_len):  # This function is saved in d2l.
```
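Sketches of what the two functions above might contain, following the usual d2l convention of wrapping target sequences in &lt;bos&gt;/&lt;eos&gt; and recording the number of non-padding tokens as the valid length:

```python
from torch.utils import data

def build_array(lines, vocab, max_len, is_source):
    lines = [vocab[line] for line in lines]
    if not is_source:
        # Target sequences get beginning- and end-of-sentence markers
        lines = [[vocab.bos] + line + [vocab.eos] for line in lines]
    array = torch.tensor([pad(line, max_len, vocab.pad) for line in lines])
    valid_len = (array != vocab.pad).sum(1)  # number of real (non-padding) tokens per row
    return array, valid_len

def load_data_nmt(batch_size, max_len):
    src_vocab, tgt_vocab = build_vocab(source), build_vocab(target)
    src_array, src_valid_len = build_array(source, src_vocab, max_len, is_source=True)
    tgt_array, tgt_valid_len = build_array(target, tgt_vocab, max_len, is_source=False)
    train_data = data.TensorDataset(src_array, src_valid_len, tgt_array, tgt_valid_len)
    train_iter = data.DataLoader(train_data, batch_size, shuffle=True)
    return src_vocab, tgt_vocab, train_iter
```

The cell below builds a toy iterator with batch_size=2 and max_len=8; printing the first batch (X, its valid lengths, Y, its valid lengths) gives the tensors shown.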
```python
src_vocab, tgt_vocab, train_iter = load_data_nmt(batch_size=2, max_len=8)
```
X = tensor([[ 63, 16, 6, 59, 2805, 0, 0, 0],
[ 5, 78, 20, 613, 4, 0, 0, 0]], dtype=torch.int32)
Valid lengths for X = tensor([5, 5])
Y = tensor([[ 1, 66, 98, 75, 895, 6, 2, 0],
[ 1, 5, 100, 22, 10, 35, 810, 4]], dtype=torch.int32)
Valid lengths for Y = tensor([7, 8])
Encoder-Decoder
encoder: maps the input to a hidden state
decoder: maps the hidden state to the output
```python
class Encoder(nn.Module):
    ...
class Decoder(nn.Module):
    ...
class EncoderDecoder(nn.Module):
    ...
```
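Encoder and Decoder are abstract interfaces: the encoder's forward and the decoder's init_state/forward are left for subclasses to implement. The wrapper that chains them might look like this (a sketch in the d2l style, not necessarily the exact course code):

```python
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)                 # encode the source sequence
        dec_state = self.decoder.init_state(enc_outputs, *args)  # derive the decoder's initial state
        return self.decoder(dec_X, dec_state)                    # decode conditioned on that state
```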
The encoder-decoder framework can be applied to dialogue systems and other generative tasks.
Sequence to Sequence Model
Model overview (figures in the original notes): the training setup, the prediction setup, and the detailed layer structure.
Encoder
```python
class Seq2SeqEncoder(d2l.Encoder):
```
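A sketch of the encoder, consistent with the shape printout below: an embedding layer followed by a multi-layer LSTM operating on time-major input (the exact course implementation may differ):

```python
class Seq2SeqEncoder(d2l.Encoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        X = self.embedding(X)     # (batch_size, seq_len, embed_size)
        X = X.transpose(0, 1)     # the LSTM expects (seq_len, batch_size, embed_size)
        out, state = self.rnn(X)  # state is the LSTM's (h, c) pair
        return out, state
```

Calling encoder(X) on the (4, 7) batch created below and inspecting output.shape, len(state), state[0].shape and state[1].shape gives the tuple printed underneath.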
```python
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
X = torch.zeros((4, 7), dtype=torch.long)
```
(torch.Size([7, 4, 16]), 2, torch.Size([2, 4, 16]), torch.Size([2, 4, 16]))
Decoder
```python
class Seq2SeqDecoder(d2l.Decoder):
```
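A matching decoder sketch: it reuses the encoder's final (h, c) state as its own initial state and projects the LSTM outputs back to vocabulary size (again, the course implementation may differ in detail):

```python
class Seq2SeqDecoder(d2l.Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super(Seq2SeqDecoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]  # the encoder's final (h, c) state

    def forward(self, X, state):
        X = self.embedding(X).transpose(0, 1)
        out, state = self.rnn(X, state)
        out = self.dense(out).transpose(0, 1)  # back to (batch_size, seq_len, vocab_size)
        return out, state
```

Decoding the same (4, 7) batch and inspecting out.shape, len(state), state[0].shape and state[1].shape yields the tuple printed below.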
```python
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
out, state = decoder(X, decoder.init_state(encoder(X)))
```
(torch.Size([4, 7, 10]), 2, torch.Size([2, 4, 16]), torch.Size([2, 4, 16]))
Loss Function
```python
def SequenceMask(X, X_len, value=0):  # X: one batch of values (e.g. losses); X_len: valid lengths
```
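A sketch of SequenceMask consistent with the two demonstrations below: every position beyond a row's valid length is overwritten with value:

```python
def SequenceMask(X, X_len, value=0):
    maxlen = X.size(1)
    # mask[i, j] is True exactly for the first X_len[i] positions of row i
    mask = torch.arange(maxlen, device=X_len.device)[None, :] < X_len[:, None]
    X[~mask] = value
    return X
```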
```python
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
SequenceMask(X, torch.tensor([1, 2]))
```
tensor([[1, 0, 0],
[4, 5, 0]])
```python
X = torch.ones((2, 3, 4))
SequenceMask(X, torch.tensor([1, 2]), value=-1)
```
tensor([[[ 1., 1., 1., 1.],
[-1., -1., -1., -1.],
[-1., -1., -1., -1.]],
[[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[-1., -1., -1., -1.]]])
```python
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
```
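A sketch of the masked cross-entropy loss: padding positions get weight 0 via SequenceMask, and the per-position losses are averaged over the sequence dimension:

```python
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    # pred: (batch_size, seq_len, vocab_size); label: (batch_size, seq_len); valid_length: (batch_size,)
    def forward(self, pred, label, valid_length):
        weights = torch.ones_like(label)                      # 1 for valid positions, 0 for padding
        weights = SequenceMask(weights, valid_length).float()
        self.reduction = 'none'
        output = super(MaskedSoftmaxCELoss, self).forward(pred.transpose(1, 2), label)
        return (output * weights).mean(dim=1)
```

In the check below, uniform logits give a per-position loss of ln(10) ≈ 2.3026, so valid lengths of 4, 3 and 0 (a reconstruction consistent with the printout) average to 2.3026, 1.7269 and 0.0.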
```python
loss = MaskedSoftmaxCELoss()
loss(torch.ones((3, 4, 10)), torch.ones((3, 4), dtype=torch.long), torch.tensor([4, 3, 0]))
```
tensor([2.3026, 1.7269, 0.0000])
Training
```python
def train_ch7(model, data_iter, lr, num_epochs, device):  # Saved in d2l
```
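A sketch of the training loop: teacher forcing on the target sequence, the masked loss from above, Adam, and gradient clipping (the clipping threshold and print frequency are assumptions):

```python
def train_ch7(model, data_iter, lr, num_epochs, device):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    tic = time.time()
    for epoch in range(1, num_epochs + 1):
        l_sum, num_tokens_sum = 0.0, 0.0
        for X, X_vlen, Y, Y_vlen in data_iter:
            X, X_vlen, Y, Y_vlen = [t.to(device) for t in (X, X_vlen, Y, Y_vlen)]
            # Teacher forcing: the decoder sees Y without <eos> and predicts Y without <bos>
            Y_input, Y_label, Y_vlen = Y[:, :-1], Y[:, 1:], Y_vlen - 1
            optimizer.zero_grad()
            Y_hat, _ = model(X, Y_input, X_vlen, Y_vlen)
            l = loss(Y_hat, Y_label, Y_vlen).sum()
            l.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
            l_sum += l.item()
            num_tokens_sum += Y_vlen.sum().item()
        if epoch % 50 == 0:
            print("epoch {:d},loss {:.3f}, time {:.1f} sec".format(
                epoch, l_sum / num_tokens_sum, time.time() - tic))
            tic = time.time()
```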
```python
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.0
```
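A plausible completion of the setup that builds the model and trains it; batch size, sequence length, learning rate and device choice are assumptions (only the sizes above and the 300 epochs in the log are given):

```python
batch_size, max_len = 64, 10
lr, num_epochs = 0.005, 300
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

src_vocab, tgt_vocab, train_iter = load_data_nmt(batch_size, max_len)
encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
model = EncoderDecoder(encoder, decoder)
train_ch7(model, train_iter, lr, num_epochs, device)
```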
epoch 50,loss 0.096, time 34.0 sec
epoch 100,loss 0.047, time 34.7 sec
epoch 150,loss 0.032, time 33.4 sec
epoch 200,loss 0.027, time 32.9 sec
epoch 250,loss 0.026, time 34.3 sec
epoch 300,loss 0.024, time 33.9 sec
Testing
```python
def translate_ch7(model, src_sentence, src_vocab, tgt_vocab, max_len, device):
```
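A sketch of greedy decoding consistent with the translations below: encode the padded source sentence, start the decoder from &lt;bos&gt;, and repeatedly feed back the most probable token until &lt;eos&gt; (it relies on the Vocab stand-in and model classes sketched above):

```python
def translate_ch7(model, src_sentence, src_vocab, tgt_vocab, max_len, device):
    src_tokens = src_vocab[src_sentence.lower().split(' ')]
    src_len = len(src_tokens)
    if src_len < max_len:
        src_tokens += [src_vocab.pad] * (max_len - src_len)
    enc_X = torch.tensor(src_tokens, device=device).unsqueeze(0)   # add a batch dimension
    enc_valid_length = torch.tensor([src_len], device=device)
    enc_outputs = model.encoder(enc_X, enc_valid_length)
    dec_state = model.decoder.init_state(enc_outputs, enc_valid_length)
    dec_X = torch.tensor([tgt_vocab.bos], device=device).unsqueeze(0)
    predict_tokens = []
    for _ in range(max_len):
        Y, dec_state = model.decoder(dec_X, dec_state)
        dec_X = Y.argmax(dim=2)            # greedy: pick the most probable next token
        py = dec_X.squeeze(dim=0).int().item()
        if py == tgt_vocab.eos:            # stop once the end-of-sentence token is produced
            break
        predict_tokens.append(py)
    return ' '.join(tgt_vocab.idx_to_token[i] for i in predict_tokens)
```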
```python
for sentence in ['Go .', 'Wow !', "I'm OK .", 'I won !']:
    print(sentence + ' => ' + translate_ch7(model, sentence, src_vocab, tgt_vocab, max_len, device))
```
Go . => va !
Wow ! => <unk> !
I'm OK . => je vais bien .
I won ! => j'ai gagné !
Beam Search
Simple greedy search: at each time step, pick the single token with the highest conditional probability.
Viterbi algorithm: choose the sentence with the highest overall score (but the search space is far too large).
Beam search: at each time step keep only the top beam-size candidate sequences, a middle ground between greedy search and exhaustive search.
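A minimal sketch of the idea, where step_log_probs is a hypothetical callable returning next-token log-probabilities for a given prefix (an illustration of the algorithm, not the course's code):

```python
def beam_search(step_log_probs, bos, eos, beam_size=2, max_len=10):
    """step_log_probs(prefix) -> 1-D tensor of log-probabilities over the vocabulary."""
    beams = [(0.0, [bos], False)]           # (accumulated log-prob, token ids, finished?)
    for _ in range(max_len):
        candidates = []
        for score, seq, done in beams:
            if done:                        # finished hypotheses are carried over unchanged
                candidates.append((score, seq, True))
                continue
            top_lp, top_ids = torch.topk(step_log_probs(seq), beam_size)
            for lp, idx in zip(top_lp.tolist(), top_ids.tolist()):
                candidates.append((score + lp, seq + [idx], idx == eos))
        beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:beam_size]  # keep the best k
        if all(done for _, _, done in beams):
            break
    return max(beams, key=lambda c: c[0])[1]
```

With beam_size = 1 this reduces to greedy search; with beam_size equal to the vocabulary size it approaches exhaustive search.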