diciembre 2, 2020

~ 4 MIN

Pytorch Lightning - Hyperparámetros

< Blog RSS

Open In Colab

Pytorch Lightning - Hyperparámetros

En posts anteriores hemos estado aprendiendo a utilizar la librería de Pytorch Lightning, que nos ayuda mucho a la hora de entrenar redes neuronales. Tras ver los conceptos fundamentales para empezar a trabajar con esta librería, y funcionalidad más avanzada como el cálculo de métricas y callbacks, en este post aprenderemos sobre como podemos manejar los diferentes hyperparámetros de nuestro modelo.

import pytorch_lightning as pl

pl.__version__
'1.0.7'

💡 Puedes instalar pythorch lighning con el comando pip install pytorch-lightning.

Por defecto, cualquier variable que le pasemos a nuestro LightningModule en la función __init__ será considerado como un hyperparámetro. Pytorch Lightning guardará estas variables en el objeto self.hparams, que podremos utilizar en cualquier lugar, siempre y cuando llamemos a la función self.save_hyperparameters.

import torch
import torchvision

class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, path = '../data', batch_size = 64):
        super().__init__()
        self.path = path
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.mnist_train = torchvision.datasets.MNIST(
            self.path, train=True, download=True, transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,))
                ])
          )
        self.mnist_val = torchvision.datasets.MNIST(
            self.path, train=False, download=True, transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,))
                ])
          )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size)
from pytorch_lightning.metrics.functional.classification import accuracy
import torch.nn.functional as F

def block(c_in, c_out, k=3, p=1, s=1, pk=2, ps=2):
    return torch.nn.Sequential(
        torch.nn.Conv2d(c_in, c_out, k, padding=p, stride=s),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(pk, stride=ps)
    )

class Modelo(pl.LightningModule):

    def __init__(self, n_channels=1, n_outputs=10):
        super().__init__()
        self.save_hyperparameters()
        self.conv1 = block(self.hparams.n_channels, 64)
        self.conv2 = block(64, 128)
        self.fc = torch.nn.Linear(128*7*7, self.hparams.n_outputs)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('loss', loss)
        self.log('acc', accuracy(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', accuracy(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
model = Modelo()

model.hparams
"n_channels": 1
"n_outputs":  10

Si bien podemos trabajar de esta manera, la cosa se complicará cuando tengamos muchos hyperparámetros. Además, existen otros que también es importante guardar pero que no le pasaremos al modelo, como por ejemplo el batch size. Para solventar este problema, es común definir todos los hyperparámetros en un solo dict, que pasaremos al nuestro modelo y guardaremos en la misma función usada anteriormente.

class Modelo(pl.LightningModule):

    def __init__(self, config, n_channels = 1, n_outputs = 10):
        super().__init__()
        self.save_hyperparameters(config)
        self.conv1 = block(n_channels, self.hparams.filters1)
        self.conv2 = block(self.hparams.filters1, self.hparams.filters2)
        self.fc = torch.nn.Linear(self.hparams.filters2*7*7, n_outputs)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('loss', loss)
        self.log('acc', accuracy(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', accuracy(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        return getattr(torch.optim, self.hparams.optimizer)(self.parameters(), lr=self.hparams.lr)
config = {
    'lr': 3e-4,
    'optimizer': 'Adam',
    'batch_size': 64,
    'filters1': 32,
    'filters2': 64
}

modelo = Modelo(config)
dm = MNISTDataModule(batch_size=config['batch_size'])

trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dm)
modelo.hparams
"batch_size": 64
"filters1":   32
"filters2":   64
"lr":         0.0003
"optimizer":  Adam

Podemos acceder a estos hyperparámetros incluso después de cargar un modelo a partir de un checkpoint. De esta manera, siempre sabremos el conjunto de parámetros utilizados para entrenar nuestro modelo.

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

config = {
    'lr': 3e-4,
    'optimizer': 'Adam',
    'batch_size': 64,
    'filters1': 32,
    'filters2': 64
}

modelo = Modelo(config)
dm = MNISTDataModule(batch_size=config['batch_size'])

# callbacks

early_stop_callback = EarlyStopping(
   monitor='val_acc',
   patience=3,
   verbose=False,
   mode='max'
)

checkpoint = ModelCheckpoint(
    dirpath='./',
    filename='modelo-{val_acc:.5f}',
    save_top_k=1,
    monitor='val_acc',
    mode='max'
)

# entrenamiento

trainer = pl.Trainer(
    gpus=1,
    callbacks=[
        early_stop_callback,
        checkpoint
    ]
)

trainer.fit(modelo, dm)
modelo = Modelo.load_from_checkpoint(checkpoint_path="modelo-val_acc=0.99120.ckpt")
modelo.hparams
"batch_size": 64
"filters1":   32
"filters2":   64
"lr":         0.0003
"optimizer":  Adam

Por último, Pytorch Lightning también ofrece una manera de interactuar con estos hyperparámetros iterpretando los argumentos pasados al ejectura un script. Puedes aprender más al respecto aquí.

Resumen

En este post hemos visto cómo podemos manejar los diferentes hyperparámetros utilizados para entrenar nuestros modelos: el learning rate, batch size, el optimizador usado, el número de capas convolucionales o filtros... Todos estos parámetros influyen en el resultado obtenido, y es importante guardar todos estos valores junto al modelo para no repetir trabajo o poder comparar modelos entre sí o con otros. Definiendo un dict con todos los hyperparámetros, lo pasaremos al modelo en su inicialización y simplemente llamando a la función self.save_hyperparameters, Pytorch Lightning se encargará de guardar estos valores, hacerlos accesibles a través del objeto self.hparams y guardarlos y cargarlos en los checkpoints.

< Blog RSS