septiembre 4, 2020
~ 14 MIN
Traducción de texto con Atención
< Blog RSSMecanismos de Atención
En el post anterior aprendimos a implementar una arquitectura de red neuronal conocida como seq2seq
, que utiliza dos redes neuronales (el encoder
y el decoder
) para poder trabajar con secuencias de longitud arbitraria tanto a sus entradas como en las salidas. Este modelo nos permite llevar a cabo tareas tales como la traducción de texto entre dos idiomas, resumir un texto, responder preguntas, etc.
Si bien este modelo nos dio buenos resultados, podemos mejorarlo. Si prestamos atención a la arquitectura que desarrollamos, el decoder
(encargado de generar la secuencia de salida) es inicializado con el último estado oculto del encoder
, el cual tiene la responsabilidad de codificar el significado de toda la frase original. Esto puede ser complicado, sobre todo al trabajar con secuencias muy largas, y para solventar este problema podemos utilizar un mecanismo de atención
que no solo reciba el último estado oculto si no también tenga acceso a todas las salidas del encoder
de manera que el decoder
sea capaz de "focalizar su atención" en aquellas partes más importantes. Por ejemplo, para traducir la primera palabra es lógico pensar que lo más importante será la primera palabra y sus adyacentes en la frase original, pero usar el último estado oculto del encoder
puede no ser suficiente para mantener estas relaciones a largo plazo. Permitir al decoder
acceder a esta información puede resultar en mejores prestaciones.
💡 En la práctica, los mecanismos de atención dan muy buenos resultados en tareas que envuelvan datos secuenciales (como aplicaciones de lenguaje). De hecho, los mejores modelos a día de hoy para tareas de
NLP
no están basados en redes recurrentes sino en arquitecturas que únicamente implementan mecanismos de atención en varias capas. Estas redes neuronales son conocidas comoTransformers
.
El dataset
Vamos a resolver exactamente el mismo caso que en el post anterior, así que todo lo que hace referencia al procesado de datos lo dejaremos igual.
import unicodedata
import re
def unicodeToAscii(s):
return ''.join(
c for c in unicodedata.normalize('NFD', s)
if unicodedata.category(c) != 'Mn'
)
def normalizeString(s):
s = unicodeToAscii(s.lower().strip())
s = re.sub(r"([.!?])", r" \1", s)
s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
return s
def read_file(file, reverse=False):
# Read the file and split into lines
lines = open(file, encoding='utf-8').read().strip().split('\n')
# Split every line into pairs and normalize
pairs = [[normalizeString(s) for s in l.split('\t')[:2]] for l in lines]
return pairs
pairs = read_file('spa.txt')
import random
random.choice(pairs)
['graham greene is my favorite author .',
'graham greene es mi escritor favorito .']
SOS_token = 0
EOS_token = 1
PAD_token = 2
class Lang:
def __init__(self, name):
self.name = name
self.word2index = {"SOS": 0, "EOS": 1, "PAD": 2}
self.word2count = {}
self.index2word = {0: "SOS", 1: "EOS", 2: "PAD"}
self.n_words = 3 # Count SOS, EOS and PAD
def addSentence(self, sentence):
for word in sentence.split(' '):
self.addWord(word)
def addWord(self, word):
if word not in self.word2index:
self.word2index[word] = self.n_words
self.word2count[word] = 1
self.index2word[self.n_words] = word
self.n_words += 1
else:
self.word2count[word] += 1
def indexesFromSentence(self, sentence):
return [self.word2index[word] for word in sentence.split(' ')]
def sentenceFromIndex(self, index):
return [self.index2word[ix] for ix in index]
Para poder aplicar la capa de attention
necesitamos que nuestras frases tengan una longitud máxima definida.
MAX_LENGTH = 10
eng_prefixes = (
"i am ", "i m ",
"he is", "he s ",
"she is", "she s ",
"you are", "you re ",
"we are", "we re ",
"they are", "they re "
)
def filterPairs(pairs, filters, lang=0):
return [p for p in pairs if p[lang].startswith(filters)]
def trimPairs(pairs):
return [p for p in pairs if len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH]
def prepareData(file, filters=None, reverse=False):
pairs = read_file(file, reverse)
print(f"Tenemos {len(pairs)} pares de frases")
if filters is not None:
pairs = filterPairs(pairs, filters, int(reverse))
print(f"Filtramos a {len(pairs)} pares de frases")
pairs = trimPairs(pairs)
print(f"Tenemos {len(pairs)} pares de frases con longitud menor de {MAX_LENGTH}")
# Reverse pairs, make Lang instances
if reverse:
pairs = [list(reversed(p)) for p in pairs]
input_lang = Lang('eng')
output_lang = Lang('spa')
else:
input_lang = Lang('spa')
output_lang = Lang('eng')
for pair in pairs:
input_lang.addSentence(pair[0])
output_lang.addSentence(pair[1])
# add <eos> token
pair[0] += " EOS"
pair[1] += " EOS"
print("Longitud vocabularios:")
print(input_lang.name, input_lang.n_words)
print(output_lang.name, output_lang.n_words)
return input_lang, output_lang, pairs
input_lang, output_lang, pairs = prepareData('spa.txt')
# descomentar para usar el dataset filtrado
#input_lang, output_lang, pairs = prepareData('spa.txt', filters=eng_prefixes)
random.choice(pairs)
Tenemos 124547 pares de frases
Tenemos 95071 pares de frases con longitud menor de 10
Longitud vocabularios:
spa 10881
eng 20659
['she still loved him . EOS', 'ella aun lo amaba . EOS']
output_lang.indexesFromSentence('tengo mucha sed .')
[68, 5028, 135, 4]
output_lang.sentenceFromIndex([3, 1028, 647, 5])
['ve', 'cd', 'mio', 'vete']
En el Dataset
nos aseguraremos de añadir el padding necesario para que todas las frases tengan la misma longitud, lo cual no hace necesario utilizar la función collate
que implementamos en el post anterior.
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
class Dataset(torch.utils.data.Dataset):
def __init__(self, input_lang, output_lang, pairs, max_length):
self.input_lang = input_lang
self.output_lang = output_lang
self.pairs = pairs
self.max_length = max_length
def __len__(self):
return len(self.pairs)
def __getitem__(self, ix):
inputs = torch.tensor(self.input_lang.indexesFromSentence(self.pairs[ix][0]), device=device, dtype=torch.long)
outputs = torch.tensor(self.output_lang.indexesFromSentence(self.pairs[ix][1]), device=device, dtype=torch.long)
# metemos padding a todas las frases hast a la longitud máxima
return torch.nn.functional.pad(inputs, (0, self.max_length - len(inputs)), 'constant', self.input_lang.word2index['PAD']), \
torch.nn.functional.pad(outputs, (0, self.max_length - len(outputs)), 'constant', self.output_lang.word2index['PAD'])
# separamos datos en train-test
train_size = len(pairs) * 80 // 100
train = pairs[:train_size]
test = pairs[train_size:]
dataset = {
'train': Dataset(input_lang, output_lang, train, max_length=MAX_LENGTH),
'test': Dataset(input_lang, output_lang, test, max_length=MAX_LENGTH)
}
len(dataset['train']), len(dataset['test'])
(76056, 19015)
input_sentence, output_sentence = dataset['train'][1]
input_sentence, output_sentence
(tensor([3, 4, 1, 2, 2, 2, 2, 2, 2, 2], device='cuda:0'),
tensor([5, 4, 1, 2, 2, 2, 2, 2, 2, 2], device='cuda:0'))
input_lang.sentenceFromIndex(input_sentence.tolist()), output_lang.sentenceFromIndex(output_sentence.tolist())
(['go', '.', 'EOS', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD'],
['vete', '.', 'EOS', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD'])
dataloader = {
'train': torch.utils.data.DataLoader(dataset['train'], batch_size=64, shuffle=True),
'test': torch.utils.data.DataLoader(dataset['test'], batch_size=256, shuffle=False),
}
inputs, outputs = next(iter(dataloader['train']))
inputs.shape, outputs.shape
(torch.Size([64, 10]), torch.Size([64, 10]))
El modelo
En lo que se refiere al encoder
, seguimos usando exactamente la misma arquitectura. La única diferencia es que, además del último estado oculto, necesitaremos todas sus salidas para que el decoder
pueda usarlas.
class Encoder(torch.nn.Module):
def __init__(self, input_size, embedding_size=100, hidden_size=100, n_layers=2):
super().__init__()
self.hidden_size = hidden_size
self.embedding = torch.nn.Embedding(input_size, embedding_size)
self.gru = torch.nn.GRU(embedding_size, hidden_size, num_layers=n_layers, batch_first=True)
def forward(self, input_sentences):
embedded = self.embedding(input_sentences)
outputs, hidden = self.gru(embedded)
return outputs, hidden
encoder = Encoder(input_size=input_lang.n_words)
encoder_outputs, encoder_hidden = encoder(torch.randint(0, input_lang.n_words, (64, 10)))
# [batch size, seq len, hidden size]
encoder_outputs.shape
torch.Size([64, 10, 100])
# [num layers, batch size, hidden size]
encoder_hidden.shape
torch.Size([2, 64, 100])
El decoder con attention
Vamos a ver un ejemplo de implementación de una capa de atención para nuestro decoder
. En primer lugar tendremos una capa lineal que recibirá como entradas los embeddings
y el estado oculto anterior (concatenados). Esta capa lineal nos dará a la salida tantos valores como elementos tengamos en nuestras secuencias de entrada (recuerda que las hemos forzado a tener una longitud determinada). Después, aplicaremos una función softmax
sobre estos valores obteniendo así una distribución de probabilidad que, seguidamente, multiplicaremos por los outputs del encoder (que también tienen la misma longitud). En esta función de probabilidad, cada elemento tiene un valor entre 0 y 1. Así pues, esta operación dará más importancia a aquellos outputs del encoder
más importantes mientras que al resto les asignará unos valores cercanos a 0. A continuación, concatenaremos estos valores con los embeddings
, de nuevo, y se lo daremos a una nueva capa lineal que combinará estos embeddings
con los outputs del encoder
re-escalados para obtener así los inputs finales de la capa recurrente.
En resumen, usaremos las entradas y estado oculto del decoder
para encontrar unos pesos que re-escalarán las salidas del encoder
, los cuales combinaremos de nuevo con las entradas del decoder
para obtener las representaciones finales de nuestras frases que alimentan la capa recurrente.
class AttnDecoder(torch.nn.Module):
def __init__(self, input_size, embedding_size=100, hidden_size=100, n_layers=2, max_length=MAX_LENGTH):
super().__init__()
self.embedding = torch.nn.Embedding(input_size, embedding_size)
self.gru = torch.nn.GRU(embedding_size, hidden_size, num_layers=n_layers, batch_first=True)
self.out = torch.nn.Linear(hidden_size, input_size)
# attention
self.attn = torch.nn.Linear(hidden_size + embedding_size, max_length)
self.attn_combine = torch.nn.Linear(hidden_size * 2, hidden_size)
def forward(self, input_words, hidden, encoder_outputs):
# sacamos los embeddings
embedded = self.embedding(input_words)
# calculamos los pesos de la capa de atención
attn_weights = torch.nn.functional.softmax(self.attn(torch.cat((embedded.squeeze(1), hidden[0]), dim=1)))
# re-escalamos los outputs del encoder con estos pesos
attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)
output = torch.cat((embedded.squeeze(1), attn_applied.squeeze(1)), 1)
# aplicamos la capa de atención
output = self.attn_combine(output)
output = torch.nn.functional.relu(output)
# a partir de aquí, como siempre. La diferencia es que la entrada a la RNN
# no es directmanete el embedding sino una combinación del embedding
# y las salidas del encoder re-escaladas
output, hidden = self.gru(output.unsqueeze(1), hidden)
output = self.out(output.squeeze(1))
return output, hidden, attn_weights
decoder = AttnDecoder(input_size=output_lang.n_words)
decoder_output, decoder_hidden, attn_weights = decoder(torch.randint(0, output_lang.n_words, (64, 1)), encoder_hidden, encoder_outputs)
# [batch size, vocab size]
decoder_output.shape
C:\Users\sensio\miniconda3\lib\site-packages\ipykernel_launcher.py:16: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
app.launch_new_instance()
torch.Size([64, 20659])
# [num layers, batch size, hidden size]
decoder_hidden.shape
torch.Size([2, 64, 100])
# [num layers, batch size, hidden size]
decoder_hidden.shape
torch.Size([2, 64, 100])
# [batch size, max_length]
attn_weights.shape
torch.Size([64, 10])
Entrenamiento
Vamos a implementar el bucle de entrenamiento. En primer lugar, al tener ahora dos redes neuronales, necesitaremos dos optimizadores (uno para el encoder
y otro para el decoder
). Al encoder
le pasaremos la frase en el idioma original, y obtendremos el estado oculto final. Este estado oculto lo usaremos para inicializar el decoder
que, junto al token <sos>
, generará la primera palabra de la frase traducida. Repetiremos el proceso, utilizando como entrada la anterior salida del decoder, hasta obtener el token <eos>
.
from tqdm import tqdm
import numpy as np
def fit(encoder, decoder, dataloader, epochs=10):
encoder.to(device)
decoder.to(device)
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(1, epochs+1):
encoder.train()
decoder.train()
train_loss = []
bar = tqdm(dataloader['train'])
for batch in bar:
input_sentences, output_sentences = batch
bs = input_sentences.shape[0]
loss = 0
encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()
# obtenemos el último estado oculto del encoder
encoder_outputs, hidden = encoder(input_sentences)
# calculamos las salidas del decoder de manera recurrente
decoder_input = torch.tensor([[output_lang.word2index['SOS']] for b in range(bs)], device=device)
for i in range(output_sentences.shape[1]):
output, hidden, attn_weights = decoder(decoder_input, hidden, encoder_outputs)
loss += criterion(output, output_sentences[:, i].view(bs))
# el siguiente input será la palabra predicha
decoder_input = torch.argmax(output, axis=1).view(bs, 1)
# optimización
loss.backward()
encoder_optimizer.step()
decoder_optimizer.step()
train_loss.append(loss.item())
bar.set_description(f"Epoch {epoch}/{epochs} loss {np.mean(train_loss):.5f}")
val_loss = []
encoder.eval()
decoder.eval()
with torch.no_grad():
bar = tqdm(dataloader['test'])
for batch in bar:
input_sentences, output_sentences = batch
bs = input_sentences.shape[0]
loss = 0
# obtenemos el último estado oculto del encoder
encoder_outputs, hidden = encoder(input_sentences)
# calculamos las salidas del decoder de manera recurrente
decoder_input = torch.tensor([[output_lang.word2index['SOS']] for b in range(bs)], device=device)
for i in range(output_sentences.shape[1]):
output, hidden, attn_weights = decoder(decoder_input, hidden, encoder_outputs)
loss += criterion(output, output_sentences[:, i].view(bs))
# el siguiente input será la palabra predicha
decoder_input = torch.argmax(output, axis=1).view(bs, 1)
val_loss.append(loss.item())
bar.set_description(f"Epoch {epoch}/{epochs} val_loss {np.mean(val_loss):.5f}")
fit(encoder, decoder, dataloader, epochs=30)
0%| | 0/1189 [00:00<?, ?it/s]C:\Users\sensio\miniconda3\lib\site-packages\ipykernel_launcher.py:16: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
app.launch_new_instance()
Epoch 1/30 loss 34.73953: 100%|██████████████| 1189/1189 [00:47<00:00, 25.01it/s]
Epoch 1/30 val_loss 46.46139: 100%|██████████████| 75/75 [00:04<00:00, 18.29it/s]
Epoch 2/30 loss 27.59262: 100%|██████████████| 1189/1189 [00:47<00:00, 25.13it/s]
Epoch 2/30 val_loss 42.88374: 100%|██████████████| 75/75 [00:04<00:00, 18.24it/s]
Epoch 3/30 loss 24.26190: 100%|██████████████| 1189/1189 [00:47<00:00, 25.17it/s]
Epoch 3/30 val_loss 40.64831: 100%|██████████████| 75/75 [00:04<00:00, 18.23it/s]
Epoch 4/30 loss 21.97039: 100%|██████████████| 1189/1189 [00:47<00:00, 25.00it/s]
Epoch 4/30 val_loss 39.81519: 100%|██████████████| 75/75 [00:04<00:00, 18.11it/s]
Epoch 5/30 loss 20.13848: 100%|██████████████| 1189/1189 [00:47<00:00, 24.90it/s]
Epoch 5/30 val_loss 38.86808: 100%|██████████████| 75/75 [00:04<00:00, 17.94it/s]
Epoch 6/30 loss 18.60935: 100%|██████████████| 1189/1189 [00:47<00:00, 24.91it/s]
Epoch 6/30 val_loss 38.34785: 100%|██████████████| 75/75 [00:04<00:00, 17.98it/s]
Epoch 7/30 loss 17.34123: 100%|██████████████| 1189/1189 [00:47<00:00, 24.96it/s]
Epoch 7/30 val_loss 38.11022: 100%|██████████████| 75/75 [00:04<00:00, 17.86it/s]
Epoch 8/30 loss 16.24881: 100%|██████████████| 1189/1189 [00:47<00:00, 24.97it/s]
Epoch 8/30 val_loss 37.97426: 100%|██████████████| 75/75 [00:04<00:00, 17.94it/s]
Epoch 9/30 loss 15.29087: 100%|██████████████| 1189/1189 [00:47<00:00, 24.95it/s]
Epoch 9/30 val_loss 38.15921: 100%|██████████████| 75/75 [00:04<00:00, 18.03it/s]
Epoch 10/30 loss 14.46075: 100%|█████████████| 1189/1189 [00:47<00:00, 24.95it/s]
Epoch 10/30 val_loss 38.56259: 100%|█████████████| 75/75 [00:04<00:00, 17.95it/s]
Epoch 11/30 loss 13.72210: 100%|█████████████| 1189/1189 [00:48<00:00, 24.76it/s]
Epoch 11/30 val_loss 38.83263: 100%|█████████████| 75/75 [00:04<00:00, 17.89it/s]
Epoch 12/30 loss 13.08170: 100%|█████████████| 1189/1189 [00:47<00:00, 24.82it/s]
Epoch 12/30 val_loss 38.89201: 100%|█████████████| 75/75 [00:04<00:00, 17.93it/s]
Epoch 13/30 loss 12.50561: 100%|█████████████| 1189/1189 [00:47<00:00, 24.89it/s]
Epoch 13/30 val_loss 39.27608: 100%|█████████████| 75/75 [00:04<00:00, 17.98it/s]
Epoch 14/30 loss 11.96744: 100%|█████████████| 1189/1189 [00:47<00:00, 24.87it/s]
Epoch 14/30 val_loss 39.73010: 100%|█████████████| 75/75 [00:04<00:00, 17.99it/s]
Epoch 15/30 loss 11.50962: 100%|█████████████| 1189/1189 [00:47<00:00, 24.85it/s]
Epoch 15/30 val_loss 40.08352: 100%|█████████████| 75/75 [00:04<00:00, 17.72it/s]
Epoch 16/30 loss 11.11470: 100%|█████████████| 1189/1189 [00:47<00:00, 24.87it/s]
Epoch 16/30 val_loss 40.64367: 100%|█████████████| 75/75 [00:04<00:00, 17.66it/s]
Epoch 17/30 loss 10.72270: 100%|█████████████| 1189/1189 [00:48<00:00, 24.72it/s]
Epoch 17/30 val_loss 40.76052: 100%|█████████████| 75/75 [00:04<00:00, 17.77it/s]
Epoch 18/30 loss 10.36965: 100%|█████████████| 1189/1189 [00:47<00:00, 24.82it/s]
Epoch 18/30 val_loss 40.93134: 100%|█████████████| 75/75 [00:04<00:00, 17.86it/s]
Epoch 19/30 loss 10.05808: 100%|█████████████| 1189/1189 [00:47<00:00, 24.83it/s]
Epoch 19/30 val_loss 41.70704: 100%|█████████████| 75/75 [00:04<00:00, 17.78it/s]
Epoch 20/30 loss 9.77267: 100%|██████████████| 1189/1189 [00:47<00:00, 24.80it/s]
Epoch 20/30 val_loss 41.96183: 100%|█████████████| 75/75 [00:04<00:00, 17.85it/s]
Epoch 21/30 loss 9.52629: 100%|██████████████| 1189/1189 [00:48<00:00, 24.75it/s]
Epoch 21/30 val_loss 42.15135: 100%|█████████████| 75/75 [00:04<00:00, 17.96it/s]
Epoch 22/30 loss 9.28994: 100%|██████████████| 1189/1189 [00:47<00:00, 24.78it/s]
Epoch 22/30 val_loss 42.74523: 100%|█████████████| 75/75 [00:04<00:00, 17.66it/s]
Epoch 23/30 loss 9.06947: 100%|██████████████| 1189/1189 [00:48<00:00, 24.72it/s]
Epoch 23/30 val_loss 43.25123: 100%|█████████████| 75/75 [00:04<00:00, 17.78it/s]
Epoch 24/30 loss 8.85993: 100%|██████████████| 1189/1189 [00:48<00:00, 24.77it/s]
Epoch 24/30 val_loss 43.17854: 100%|█████████████| 75/75 [00:04<00:00, 17.83it/s]
Epoch 25/30 loss 8.68826: 100%|██████████████| 1189/1189 [00:47<00:00, 24.88it/s]
Epoch 25/30 val_loss 43.50802: 100%|█████████████| 75/75 [00:04<00:00, 18.03it/s]
Epoch 26/30 loss 8.48245: 100%|██████████████| 1189/1189 [00:47<00:00, 24.78it/s]
Epoch 26/30 val_loss 44.17614: 100%|█████████████| 75/75 [00:04<00:00, 17.86it/s]
Epoch 27/30 loss 8.32397: 100%|██████████████| 1189/1189 [00:48<00:00, 24.74it/s]
Epoch 27/30 val_loss 44.76594: 100%|█████████████| 75/75 [00:04<00:00, 17.82it/s]
Epoch 28/30 loss 8.18032: 100%|██████████████| 1189/1189 [00:48<00:00, 24.76it/s]
Epoch 28/30 val_loss 44.82165: 100%|█████████████| 75/75 [00:04<00:00, 17.96it/s]
Epoch 29/30 loss 8.02375: 100%|██████████████| 1189/1189 [00:47<00:00, 24.89it/s]
Epoch 29/30 val_loss 44.95020: 100%|█████████████| 75/75 [00:04<00:00, 18.27it/s]
Epoch 30/30 loss 7.88312: 100%|██████████████| 1189/1189 [00:48<00:00, 24.76it/s]
Epoch 30/30 val_loss 45.47675: 100%|█████████████| 75/75 [00:04<00:00, 17.77it/s]
Generando traducciones
Una vez tenemos nuestro modelo entrenado, podemos utilizarlo para traducir frases del inglés al castellano de la siguiente manera.
input_sentence, output_sentence = dataset['train'][100]
input_lang.sentenceFromIndex(input_sentence.tolist()), output_lang.sentenceFromIndex(output_sentence.tolist())
(['really', '?', 'EOS', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD'],
['', 'en', 'serio', '?', 'EOS', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD'])
def predict(input_sentence):
# obtenemos el último estado oculto del encoder
encoder_outputs, hidden = encoder(input_sentence.unsqueeze(0))
# calculamos las salidas del decoder de manera recurrente
decoder_input = torch.tensor([[output_lang.word2index['SOS']]], device=device)
# iteramos hasta que el decoder nos de el token <eos>
outputs = []
decoder_attentions = torch.zeros(MAX_LENGTH, MAX_LENGTH)
i = 0
while True:
output, hidden, attn_weights = decoder(decoder_input, hidden, encoder_outputs)
decoder_attentions[i] = attn_weights.data
i += 1
decoder_input = torch.argmax(output, axis=1).view(1, 1)
outputs.append(decoder_input.cpu().item())
if decoder_input.item() == output_lang.word2index['EOS']:
break
return output_lang.sentenceFromIndex(outputs), decoder_attentions
output_words, attn = predict(input_sentence)
output_words
C:\Users\sensio\miniconda3\lib\site-packages\ipykernel_launcher.py:16: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
app.launch_new_instance()
['', 'de', 'verdad', '?', 'EOS']
Visualización de atención
Una de las ventajas que nos da la capa de atención es que nos permite visualizar en qué partes de los inputs se fija el modelo para generar cada una de las palabras en el output, dando un grado de explicabilidad a nuestro modelo (una propiedad siempre deseada en nuestro modelos de Machine Learning
).
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
def showAttention(input_sentence, output_words, attentions):
lim1, lim2 = input_sentence.index('EOS')+1, output_words.index('EOS')+1
fig = plt.figure(dpi=100)
ax = fig.add_subplot(111)
cax = ax.matshow(attentions[:lim2, :lim1].numpy(), cmap='bone')
fig.colorbar(cax)
# Set up axes
ax.set_xticklabels([' '] + input_sentence[:lim1], rotation=90)
ax.set_yticklabels([' '] + output_words)
# Show label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.show()
showAttention(input_lang.sentenceFromIndex(input_sentence.tolist()), output_words, attn)
Resumen
En este post hemos visto como introducir mecanismos de atención en nuestra arquitectura encoder-decoder
, los cuales permiten a nuestra red neuronal focalizarse en partes concretas de los inputs a la hora de generar los outputs. Esta nueva capa no solo puede mejorar nuestros modelos sino que además también es interpretable, dándonos una idea del razonamiento detrás de las predicciones de nuestro modelo. Las redes neuronales con mejores prestaciones a día de hoy en tareas de NLP
, los transformers
, están basados enteramente en este tipo de capas de atención.