diciembre 2, 2020
~ 6 MIN
Pytorch Lightning - Métericas y Callbacks
< Blog RSSPytorch Lightning - Métricas y Callbacks
En el post anterior aprendimos los conceptos básicos de la librería Pytorch Lightning, la cual nos ofrece mucha funcionalidad a la hora de entrenar nuestras redes neuronales con Pytorch
. En este post exploraremos dos características muy interesantes: métricas
y callbacks
.
import pytorch_lightning as pl
pl.__version__
'1.0.7'
💡 Puedes instalar pythorch lighning con el comando
pip install pytorch-lightning
.
Métricas
Durante el entrenamiento de nuestros modelos es común calcular y trackerar diferentes métricas que nos ayuden a evaluar lo bueno que es un modelo comparado con otros. De esta manera podemos saber si nuestras decisiones de diseño son acertadas. Por ejemplo, en el caso de un clasificador de imágenes nos puede interesar conocer la precisión del modelo (cuántas imágenes clasifica bien) sobre el conjunto de validación. Para ello Pytorch lightning
nos ofrece un conjunto de métricas comunes listas para utilizar.
Para calcular estas métricas es tan sencillo como importarlas, añadirlas a nuestro modelo y logearlas. Vamos a ver un ejemplo siguiendo el caso del post anterior, en el que traducimos nuestro clasificador de imágenes con el dataset MNIST
de Pytorch
a Pytorch Lightning
.
import torch
import torchvision
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, path = '../data', batch_size = 64):
super().__init__()
self.path = path
self.batch_size = batch_size
def setup(self, stage=None):
self.mnist_train = torchvision.datasets.MNIST(
self.path, train=True, download=True, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
)
self.mnist_val = torchvision.datasets.MNIST(
self.path, train=False, download=True, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
)
def train_dataloader(self):
return torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size)
from pytorch_lightning.metrics.functional.classification import accuracy
import torch.nn.functional as F
def block(c_in, c_out, k=3, p=1, s=1, pk=2, ps=2):
return torch.nn.Sequential(
torch.nn.Conv2d(c_in, c_out, k, padding=p, stride=s),
torch.nn.ReLU(),
torch.nn.MaxPool2d(pk, stride=ps)
)
class Modelo(pl.LightningModule):
def __init__(self, n_channels=1, n_outputs=10):
super().__init__()
self.conv1 = block(n_channels, 64)
self.conv2 = block(64, 128)
self.fc = torch.nn.Linear(128*7*7, n_outputs)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.shape[0], -1)
x = self.fc(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('loss', loss)
self.log('acc', accuracy(y_hat, y), prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', accuracy(y_hat, y), prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
Ahora, cuando entrenemos nuestro modelo, podremos ver el valor de la precisión en los datos de entrenamiento y validación.
modelo = Modelo()
dm = MNISTDataModule()
trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dm)
Para que nuestras métricas funcionen en entornos distribuídos, sin embargo, es más recomendable trabajar con la API de métricas que nos ofrece Pytorch Lightning
. De esta manera, cualquier comunicación o sincronización entre GPUs será manejada por la librería.
class Modelo(pl.LightningModule):
def __init__(self, n_channels=1, n_outputs=10):
super().__init__()
self.conv1 = block(n_channels, 64)
self.conv2 = block(64, 128)
self.fc = torch.nn.Linear(128*7*7, n_outputs)
self.train_acc = pl.metrics.Accuracy()
self.val_acc = pl.metrics.Accuracy()
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.shape[0], -1)
x = self.fc(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('loss', loss)
self.log('acc', self.train_acc(y_hat, y), prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', self.val_acc(y_hat, y), prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
modelo = Modelo()
dm = MNISTDataModule()
trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dm)
De esta manera podemos añadir tantas métricas como queramos.
class Modelo(pl.LightningModule):
def __init__(self, n_channels=1, n_outputs=10):
super().__init__()
self.conv1 = block(n_channels, 64)
self.conv2 = block(64, 128)
self.fc = torch.nn.Linear(128*7*7, n_outputs)
self.train_acc = pl.metrics.Accuracy()
self.val_acc = pl.metrics.Accuracy()
self.train_precision = pl.metrics.Precision(num_classes=n_outputs)
self.val_precision = pl.metrics.Precision(num_classes=n_outputs)
self.train_recall = pl.metrics.Recall(num_classes=n_outputs)
self.val_recall = pl.metrics.Recall(num_classes=n_outputs)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.shape[0], -1)
x = self.fc(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('loss', loss)
self.log('acc', self.train_acc(y_hat, y), prog_bar=True)
self.log('precision', self.train_precision(y_hat, y), prog_bar=True)
self.log('recall', self.train_recall(y_hat, y), prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', self.val_acc(y_hat, y), prog_bar=True)
self.log('val_precision', self.val_precision(y_hat, y), prog_bar=True)
self.log('val_recall', self.val_recall(y_hat, y), prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
modelo = Modelo()
dm = MNISTDataModule()
trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dm)
Incluso podemos definir nuestras propias métricas, puedes encontrar más información al respecto aquí.
Callbacks
Otra funcionalidad muy interesante que nos aporta la librería Pytorch Lightning
son las callbacks, funciones que podemos ejectura durante el entrenamiento para modificar su comportamiento de alguna manera. Una callback muy útil es la de early stopping, que parará el entrenamiento cuando se cumplan unas ciertas condiciones. Por ejemplo, si queremos detener el proceso tras 3 epochs seguidas sin mejorar la precisión del modelo en los datos de evaluación.
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
early_stop_callback = EarlyStopping(
monitor='val_acc',
patience=3,
verbose=False,
mode='max'
)
modelo = Modelo()
dm = MNISTDataModule()
trainer = pl.Trainer(
gpus=1,
callbacks=[early_stop_callback]
)
trainer.fit(modelo, dm)
Podemos combinar esta callback con otra que nos guarde el mejor modelo encontrado durante el entrenamiento, de manera que ya no hará falta definir un número de epochs sino que podemos dejar el entrenamiento transcurrir y confiar en las callbacks para obtener el mejor resultado.
from pytorch_lightning.callbacks import ModelCheckpoint
modelo = Modelo()
dm = MNISTDataModule()
# callbacks
early_stop_callback = EarlyStopping(
monitor='val_acc',
patience=3,
verbose=False,
mode='max'
)
checkpoint = ModelCheckpoint(
dirpath='./',
filename='modelo-{val_acc:.5f}',
save_top_k=1,
monitor='val_acc',
mode='max'
)
# entrenamiento
trainer = pl.Trainer(
gpus=1,
callbacks=[
early_stop_callback,
checkpoint
]
)
trainer.fit(modelo, dm)
Una vez terminado el entrenamiento podemos cargar el mejor modelo de la siguiente manera
modelo = Modelo.load_from_checkpoint(checkpoint_path="modelo-val_acc=0.99060.ckpt")
Existen otras callbacks, e incluso puedes crearte las tuyas propias (encuentra toda la información al respecto aquí).
Resumen
En este post hemos hablado de dos de las funcionalidades más interesantes que nos ofrece la librería Pytorch Lightning
. Por una lado, podemos utilizar las métricas implementadas para evaluar nuestros modelos de manera eficiente sin la necesidad de tener que implementar nuestras propias funciones, evitando así posibles fuentes de error. Por otro lado, gracias a las callbacks, podremos modificar el comportamiento por defecto del bucle de entrenamiento. Esto nos permite, entre muchas otras cosas, implementar la técnica de early stopping o guardar diferentes checkpoints durante el entrenamiento. En ambos casos podremos definir nuestras propias métricas y callbacks, si así lo deseamos, que funcionarán también en entornos distribuidos.