从入门AI到手写Transformer-17.整体代码讲解
整理自视频《老袁不说话》。
17.整体代码讲解
代码
import collections
import math
import torch
from torch import nn
import os
import time
import numpy as np
from matplotlib import pyplot as plt
from matplotlib_inline import backend_inline
import hashlib
import os
import tarfile
import zipfile
import requests
from IPython import display
from torch. utils import data
# Registry mapping a dataset name -> (download URL, expected SHA-1 hash).
DATA_HUB = dict()
# Base URL of the d2l data-hosting S3 bucket.
DATA_URL = "http://d2l-data.s3-accelerate.amazonaws.com/"
# English-French translation corpus (Tatoeba fra-eng pairs).
DATA_HUB["fra-eng"] = (
    DATA_URL + "fra-eng.zip",
    "94646ad1522d915e7b0f9296181140edcf86a4f5",
)
def try_gpu(i=0):
    """Return gpu(i) if it exists, otherwise return cpu()."""
    has_gpu_i = torch.cuda.device_count() >= i + 1
    return torch.device(f"cuda:{i}" if has_gpu_i else "cpu")
def bleu(pred_seq, label_seq, k):
    """Compute the BLEU score of pred_seq against label_seq using n-grams up to k."""
    pred_tokens = pred_seq.split(" ")
    label_tokens = label_seq.split(" ")
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    # Brevity penalty: punish predictions shorter than the reference.
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        # Count reference n-grams, then consume them as the prediction matches.
        label_counts = collections.Counter(
            " ".join(label_tokens[i : i + n]) for i in range(len_label - n + 1)
        )
        num_matches = 0
        for i in range(len_pred - n + 1):
            ngram = " ".join(pred_tokens[i : i + n])
            if label_counts[ngram] > 0:
                num_matches += 1
                label_counts[ngram] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score
def count_corpus(tokens):
    """Count token frequencies; `tokens` may be a flat list or a list of token lines."""
    if tokens and isinstance(tokens[0], list):
        # Flatten a 2-D list of lines into a single token stream.
        tokens = [tok for line in tokens for tok in line]
    return collections.Counter(tokens)
def download(name, cache_dir=os.path.join(".", "./data")):
    """Download a file registered in DATA_HUB and return the local file path.

    If a cached copy exists whose SHA-1 matches the registered hash, the
    cached file is reused and no download happens.
    """
    assert name in DATA_HUB, f"{name} 不存在于 {DATA_HUB}"
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, "rb") as f:
            # Hash in 1 MiB chunks so large files never load fully into memory.
            while True:
                chunk = f.read(1048576)
                if not chunk:
                    break
                sha1.update(chunk)
        if sha1.hexdigest() == sha1_hash:
            return fname  # cache hit
    print(f"正在从 {url} 下载 {fname}...")
    r = requests.get(url, stream=True, verify=True)
    with open(fname, "wb") as f:
        # Stream chunks to disk; r.content would buffer the whole body in
        # memory, defeating stream=True.
        for chunk in r.iter_content(chunk_size=1048576):
            f.write(chunk)
    return fname
def download_extract(name, folder=None):
    """Download and extract a zip/tar archive; return the data directory."""
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == ".zip":
        fp = zipfile.ZipFile(fname, "r")
    elif ext in (".tar", ".gz"):
        fp = tarfile.open(fname, "r")
    else:
        assert False, "只有zip/tar文件可以被解压缩"
    # Close the archive handle even if extraction fails (it was leaked before).
    with fp:
        fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir
def read_data_nmt():
    """Load the English-French dataset as one raw text string."""
    data_dir = download_extract("fra-eng")
    # One sentence pair per line, English and French separated by a tab.
    with open(os.path.join(data_dir, "fra.txt"), "r", encoding="utf-8") as f:
        return f.read()
def masked_softmax(X, valid_lens):
    """Softmax over the last axis, masking out elements beyond each valid length."""
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    # Broadcast one length per row of the flattened (batch*queries, keys) view.
    flat_lens = (
        torch.repeat_interleave(valid_lens, shape[1])
        if valid_lens.dim() == 1
        else valid_lens.reshape(-1)
    )
    # Large negative fill drives the masked positions to ~0 after softmax.
    masked = sequence_mask(X.reshape(-1, shape[-1]), flat_lens, value=-1e6)
    return nn.functional.softmax(masked.reshape(shape), dim=-1)
def sequence_mask(X, valid_len, value=0):
    """Overwrite entries of X past each row's valid length with `value` (in place)."""
    maxlen = X.size(1)
    positions = torch.arange(maxlen, dtype=torch.float32, device=X.device)
    # keep[i, j] is True while position j is within row i's valid length.
    keep = positions[None, :] < valid_len[:, None]
    X[~keep] = value
    return X
def preprocess_nmt(text):
    """Preprocess the English-French text: normalize unicode spaces, lowercase,
    and insert a space before punctuation that follows a non-space character."""
    text = text.replace("\u202f", " ").replace("\xa0", " ").lower()
    pieces = []
    for i, ch in enumerate(text):
        if i > 0 and ch in ",.!?" and text[i - 1] != " ":
            pieces.append(" ")
        pieces.append(ch)
    return "".join(pieces)
def tokenize_nmt(text, num_examples=None):
    """Tokenize the English-French text into (source, target) token-list pairs."""
    source, target = [], []
    for i, line in enumerate(text.split("\n")):
        if num_examples and i > num_examples:
            break
        parts = line.split("\t")
        if len(parts) != 2:
            continue  # skip malformed lines (e.g. the trailing empty line)
        src_line, tgt_line = parts
        source.append(src_line.split(" "))
        target.append(tgt_line.split(" "))
    return source, target
def grad_clipping(net, theta):
    """Rescale gradients in place so their global L2 norm is at most theta."""
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    norm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in params))
    if norm > theta:
        scale = theta / norm
        for p in params:
            p.grad[:] *= scale
def truncate_pad(line, num_steps, padding_token):
    """Force `line` to exactly num_steps tokens by truncating or padding."""
    padded = line + [padding_token] * max(0, num_steps - len(line))
    return padded[:num_steps]
def build_array_nmt(lines, vocab, num_steps):
    """Convert tokenized text sequences into a padded index tensor plus valid lengths.

    Restores the '<eos>'/'<pad>' special tokens that were stripped to empty
    strings in the transcription.
    """
    lines = [vocab[l] for l in lines]
    # Append <eos> so the model can learn where each sequence ends.
    lines = [l + [vocab["<eos>"]] for l in lines]
    array = torch.tensor(
        [truncate_pad(l, num_steps, vocab["<pad>"]) for l in lines]
    )
    # Valid length = number of non-padding tokens per sequence.
    valid_len = (array != vocab["<pad>"]).type(torch.int32).sum(1)
    return array, valid_len
def load_array(data_arrays, batch_size, is_train=True):
    """Wrap tensors in a PyTorch DataLoader (shuffled when training)."""
    dataset = data.TensorDataset(*data_arrays)
    loader = data.DataLoader(dataset, batch_size, shuffle=is_train)
    return loader
def load_data_nmt(batch_size, num_steps, num_examples=600):
    """Return a data iterator and the source/target vocabs of the NMT dataset.

    Restores the reserved special tokens ('<pad>', '<bos>', '<eos>') that were
    stripped to empty strings in the transcription.
    """
    text = preprocess_nmt(read_data_nmt())
    source, target = tokenize_nmt(text, num_examples)
    src_vocab = Vocab(
        source, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"]
    )
    tgt_vocab = Vocab(
        target, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"]
    )
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    data_iter = load_array(data_arrays, batch_size)
    return data_iter, src_vocab, tgt_vocab
def sequence_mask(X, valid_len, value=0):
    """Mask out entries past each row's valid length (in place).

    NOTE(review): this is an exact duplicate of the `sequence_mask` defined
    earlier in this file; it silently rebinds the name and could be removed.
    """
    maxlen = X.size(1)
    mask = (
        torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :]
        < valid_len[:, None]
    )
    X[~mask] = value
    return X
def transpose_qkv(X, num_heads):
    """Reshape (batch, seq, hiddens) -> (batch*heads, seq, hiddens/heads)
    so every attention head runs as an independent batch element."""
    batch, seq_len = X.shape[0], X.shape[1]
    X = X.reshape(batch, seq_len, num_heads, -1).permute(0, 2, 1, 3)
    return X.reshape(batch * num_heads, seq_len, -1)
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a sequence-to-sequence model with teacher forcing.

    Restores the '<bos>' special token that was stripped to an empty string
    in the transcription, and the mangled summary f-string.
    """
    def xavier_init_weights(m):
        # Xavier-initialize Linear and GRU weight matrices.
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = Animator(xlabel="epoch", ylabel="loss", xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = Timer()
        metric = Accumulator(2)  # sum of losses, number of tokens
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            # Teacher forcing: decoder input is <bos> followed by the target
            # shifted right by one position.
            bos = torch.tensor(
                [tgt_vocab["<bos>"]] * Y.shape[0], device=device
            ).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()
            grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(
        f"loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} "
        f"tokens/sec on {str(device)}"
    )
def predict_seq2seq(
    net,
    src_sentence,
    src_vocab,
    tgt_vocab,
    num_steps,
    device,
    save_attention_weights=False,
):
    """Greedily decode a translation of src_sentence, one token at a time.

    Restores the '<eos>'/'<pad>'/'<bos>' special tokens that were stripped
    to empty strings in the transcription.
    """
    net.to(device)
    net.eval()
    # Encode: lowercase, tokenize, append <eos>, then pad/truncate to num_steps.
    src_tokens = src_vocab[src_sentence.lower().split(" ")] + [src_vocab["<eos>"]]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab["<pad>"])
    enc_X = torch.unsqueeze(
        torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0
    )
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    # Decoding starts from the <bos> token.
    dec_X = torch.unsqueeze(
        torch.tensor([tgt_vocab["<bos>"]], dtype=torch.long, device=device), dim=0
    )
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)
        # Greedy choice: feed the most likely token back as the next input.
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        if save_attention_weights:
            attention_weight_seq.append(net.decoder.attention_weights)
        if pred == tgt_vocab["<eos>"]:
            break
        output_seq.append(pred)
    return " ".join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq
def transpose_output(X, num_heads):
    """Invert transpose_qkv: (batch*heads, seq, d) -> (batch, seq, heads*d)."""
    seq_len, head_dim = X.shape[1], X.shape[2]
    X = X.reshape(-1, num_heads, seq_len, head_dim).permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], seq_len, num_heads * head_dim)
def use_svg_display():
    """Use the SVG format to display plots in Jupyter."""
    backend_inline.set_matplotlib_formats("svg")
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """Configure scales, labels, limits, legend, and grid of a matplotlib axes."""
    axes.set_xscale(xscale)
    axes.set_yscale(yscale)
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_xlim(xlim)
    axes.set_ylim(ylim)
    if legend:
        axes.legend(legend)
    axes.grid()
def set_figsize(figsize=(3.5, 2.5)):
    """Set the default matplotlib figure size (and switch to SVG output)."""
    use_svg_display()
    plt.rcParams["figure.figsize"] = figsize
def dropout_layer(X, dropout):
    """Apply inverted dropout: zero each element with probability `dropout`
    and scale survivors by 1/(1-dropout) so the expectation is unchanged.

    Fix: the random mask is now created on X's device — the original built it
    on the CPU, which crashes for GPU tensors.
    """
    assert 0 <= dropout <= 1
    if dropout == 1:
        return torch.zeros_like(X)
    if dropout == 0:
        return X
    mask = (torch.rand(X.shape, device=X.device) > dropout).float()
    return mask * X / (1.0 - dropout)
class Accumulator:
    """Accumulate running sums over n variables."""

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        """Add one value per tracked variable."""
        self.data = [total + float(delta) for total, delta in zip(self.data, args)]

    def reset(self):
        """Zero out every accumulator."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
class Timer:
    """Record the durations of multiple runs."""

    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        """Start (or restart) the timer."""
        self.tik = time.time()

    def stop(self):
        """Stop the timer, record the elapsed time, and return it."""
        elapsed = time.time() - self.tik
        self.times.append(elapsed)
        return elapsed

    def avg(self):
        """Return the average recorded time."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the total recorded time."""
        return sum(self.times)

    def cumsum(self):
        """Return the running cumulative sums of the recorded times."""
        return np.array(self.times).cumsum().tolist()
class Animator:
    """Incrementally plot data in an animated matplotlib figure."""

    def __init__(
        self,
        xlabel=None,
        ylabel=None,
        legend=None,
        xlim=None,
        ylim=None,
        xscale="linear",
        yscale="linear",
        fmts=("-", "m--", "g-.", "r:"),  # one line style per plotted curve
        nrows=1,
        ncols=1,
        figsize=(3.5, 2.5),
    ):
        if legend is None:
            legend = []
        use_svg_display()
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            # Normalize a single Axes to a list so self.axes[0] always works.
            self.axes = [
                self.axes,
            ]
        # Capture the axis configuration in a lambda so it can be re-applied
        # after every cla() call in add().
        self.config_axes = lambda: set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend
        )
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Append data point(s) and redraw all curves."""
        # Accept scalars or sequences; broadcast a scalar x across all curves.
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        # Clear and redraw the full history of every curve.
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        plt.draw()
        plt.pause(0.001)
class Vocab:
    """Vocabulary for text tokens.

    Restores the '<unk>' special token (index 0) that was stripped to an
    empty string in the transcription.
    """

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """Build the vocab, dropping tokens rarer than min_freq.

        Index 0 is always the unknown token '<unk>'; reserved_tokens follow.
        """
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        counter = count_corpus(tokens)
        # Sort by descending frequency so the cutoff loop below can break early.
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.idx_to_token = ["<unk>"] + reserved_tokens
        self.token_to_idx = {
            token: idx for idx, token in enumerate(self.idx_to_token)
        }
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break  # frequencies are sorted, so all remaining are rarer
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        # Map a token (or nested list of tokens) to indices; unknown -> unk.
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        """Map an index (or list of indices) back to token strings."""
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self):
        """Index of the unknown token."""
        return 0

    @property
    def token_freqs(self):
        return self._token_freqs
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """Softmax cross-entropy that ignores positions beyond each valid length."""

    def forward(self, pred, label, valid_len):
        # Weight 1 for real tokens, 0 for padded positions.
        weights = sequence_mask(torch.ones_like(label), valid_len)
        self.reduction = "none"
        # CrossEntropyLoss expects (batch, vocab, steps), hence the permute.
        unweighted = super().forward(pred.permute(0, 2, 1), label)
        return (unweighted * weights).mean(dim=1)
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project Q/K/V, run all heads in parallel, recombine."""

    def __init__(
        self,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        num_heads,
        dropout,
        bias=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Fold the head dimension into the batch dimension.
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Every head shares the same valid lengths.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0
            )
        output = self.attention(queries, keys, values, valid_lens)
        # Undo the head folding and apply the output projection.
        return self.W_o(transpose_output(output, self.num_heads))
class PositionalEncoding(nn.Module):
    """Inject sinusoidal positional information into token embeddings."""

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute a (1, max_len, num_hiddens) table of sin/cos values.
        self.P = torch.zeros((1, max_len, num_hiddens))
        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        inv_freq = torch.pow(
            10000,
            torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens,
        )
        angles = positions / inv_freq
        self.P[:, :, 0::2] = torch.sin(angles)  # even dimensions
        self.P[:, :, 1::2] = torch.cos(angles)  # odd dimensions

    def forward(self, X):
        X = X + self.P[:, : X.shape[1], :].to(X.device)
        return self.dropout(X)
class PositionWiseFFN(nn.Module):
    """Feed-forward network applied identically at every sequence position."""

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Fix: removed a stray `nn.Softmax()` constructed (and discarded) in
    __init__ — dead code with no effect.
    """

    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        """Return LayerNorm(X + dropout(Y))."""
        return self.ln(self.dropout(Y) + X)
class Encoder(nn.Module):
    """Abstract base encoder for encoder-decoder architectures."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError
class Decoder(nn.Module):
    """Abstract base decoder for encoder-decoder architectures."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        """Turn encoder outputs into the decoder's initial state."""
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
class EncoderDecoder(nn.Module):
    """Compose an encoder and a decoder into one end-to-end model."""

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
class DotProductAttention(nn.Module):
    """Scaled dot-product attention with dropout on the attention weights."""

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Scale by sqrt(d) so score magnitudes stay stable as d grows.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
class AttentionDecoder(Decoder):
    """Base class for decoders that expose their attention weights."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def attention_weights(self):
        raise NotImplementedError
class EncoderBlock(nn.Module):
    """One Transformer encoder block: self-attention and a position-wise FFN,
    each followed by add & norm."""

    def __init__(
        self,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        dropout,
        use_bias=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            use_bias,
        )
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        attn_out = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attn_out)
        return self.addnorm2(Y, self.ffn(Y))
class DecoderBlock(nn.Module):
    """The i-th Transformer decoder block: masked self-attention,
    encoder-decoder attention, and a position-wise FFN, each with add & norm."""

    def __init__(
        self,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        dropout,
        i,  # block index; selects this block's slot in the decoding state
        **kwargs,
    ):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        # Self-attention over the decoded prefix (causally masked in training).
        self.attention1 = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout
        )
        self.addnorm1 = AddNorm(norm_shape, dropout)
        # Encoder-decoder attention over the encoder outputs.
        self.attention2 = MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout
        )
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        # state = [encoder outputs, encoder valid lengths, per-block K/V cache].
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During prediction tokens arrive one step at a time; concatenate them
        # onto this block's cache so self-attention sees the whole prefix.
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Causal mask: position t may only attend to positions 1..t.
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(
                batch_size, 1
            )
        else:
            # At inference the cache holds only already-generated tokens,
            # so no masking is needed.
            dec_valid_lens = None
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
class TransformerEncoder(Encoder):
    """Transformer encoder: embedding + positional encoding + N encoder blocks."""

    def __init__(
        self,
        vocab_size,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
        use_bias=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            blk = EncoderBlock(
                key_size,
                query_size,
                value_size,
                num_hiddens,
                norm_shape,
                ffn_num_input,
                ffn_num_hiddens,
                num_heads,
                dropout,
                use_bias,
            )
            self.blks.add_module("block" + str(i), blk)

    def forward(self, X, valid_lens, *args):
        # Scale embeddings by sqrt(d) before adding positional encodings.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            # Keep each block's self-attention weights for visualization.
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
class TransformerDecoder(AttentionDecoder):
    """Transformer decoder: embedding + positional encoding + N decoder blocks
    + a linear projection onto the vocabulary."""

    def __init__(
        self,
        vocab_size,
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            blk = DecoderBlock(
                key_size,
                query_size,
                value_size,
                num_hiddens,
                norm_shape,
                ffn_num_input,
                ffn_num_hiddens,
                num_heads,
                dropout,
                i,
            )
            self.blks.add_module("block" + str(i), blk)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # The third entry is the per-block key/value cache used at prediction.
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        # Row 0: self-attention weights; row 1: encoder-decoder attention weights.
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights
if __name__ == "__main__":
    # Hyperparameters for a small Transformer on the fra-eng dataset.
    num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
    lr, num_epochs, device = 0.005, 200, try_gpu()
    ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
    key_size, query_size, value_size = 32, 32, 32
    norm_shape = [32]
    train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size, num_steps)
    encoder = TransformerEncoder(
        len(src_vocab),
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
    )
    decoder = TransformerDecoder(
        len(tgt_vocab),
        key_size,
        query_size,
        value_size,
        num_hiddens,
        norm_shape,
        ffn_num_input,
        ffn_num_hiddens,
        num_heads,
        num_layers,
        dropout,
    )
    net = EncoderDecoder(encoder, decoder)
    train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
    # Evaluate on a few held-out sentence pairs with bigram BLEU.
    engs = ["go .", "i lost .", "he's calm .", "i'm home ."]
    fras = ["va !", "j'ai perdu .", "il est calme .", "je suis chez moi ."]
    for eng, fra in zip(engs, fras):
        translation, dec_attention_weight_seq = predict_seq2seq(
            net, eng, src_vocab, tgt_vocab, num_steps, device, True
        )
        # Restored f-strings (the transcription injected spaces into the
        # format specs, e.g. ": .3f").
        print(f"{eng} => {translation}, ", f"bleu {bleu(translation, fra, k=2):.3f}")
输出结果

```text
loss 0.034, 10150.2 tokens/sec on cpu
go . => va !, bleu 1.000
i lost . => je vous en ., bleu 0.000
he's calm . => il est calme ., bleu 1.000
i'm home . => je suis chez moi ., bleu 1.000
```