diciembre 2, 2020
~ 4 MIN
Pytorch Lightning - Hyperparámetros
< Blog RSSPytorch Lightning - Hyperparámetros
En posts anteriores hemos estado aprendiendo a utilizar la librería de Pytorch Lightning
, que nos ayuda mucho a la hora de entrenar redes neuronales. Tras ver los conceptos fundamentales para empezar a trabajar con esta librería, y funcionalidad más avanzada como el cálculo de métricas y callbacks, en este post aprenderemos sobre como podemos manejar los diferentes hyperparámetros de nuestro modelo.
import pytorch_lightning as pl
pl.__version__
'1.0.7'
💡 Puedes instalar pythorch lighning con el comando
pip install pytorch-lightning
.
Por defecto, cualquier variable que le pasemos a nuestro LightningModule en la función __init__
será considerado como un hyperparámetro. Pytorch Lightning
guardará estas variables en el objeto self.hparams
, que podremos utilizar en cualquier lugar, siempre y cuando llamemos a la función self.save_hyperparameters
.
import torch
import torchvision
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, path = '../data', batch_size = 64):
super().__init__()
self.path = path
self.batch_size = batch_size
def setup(self, stage=None):
self.mnist_train = torchvision.datasets.MNIST(
self.path, train=True, download=True, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
)
self.mnist_val = torchvision.datasets.MNIST(
self.path, train=False, download=True, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
)
def train_dataloader(self):
return torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size)
from pytorch_lightning.metrics.functional.classification import accuracy
import torch.nn.functional as F
def block(c_in, c_out, k=3, p=1, s=1, pk=2, ps=2):
return torch.nn.Sequential(
torch.nn.Conv2d(c_in, c_out, k, padding=p, stride=s),
torch.nn.ReLU(),
torch.nn.MaxPool2d(pk, stride=ps)
)
class Modelo(pl.LightningModule):
def __init__(self, n_channels=1, n_outputs=10):
super().__init__()
self.save_hyperparameters()
self.conv1 = block(self.hparams.n_channels, 64)
self.conv2 = block(64, 128)
self.fc = torch.nn.Linear(128*7*7, self.hparams.n_outputs)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.shape[0], -1)
x = self.fc(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('loss', loss)
self.log('acc', accuracy(y_hat, y), prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', accuracy(y_hat, y), prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
model = Modelo()
model.hparams
"n_channels": 1
"n_outputs": 10
Si bien podemos trabajar de esta manera, la cosa se complicará cuando tengamos muchos hyperparámetros. Además, existen otros que también es importante guardar pero que no le pasaremos al modelo, como por ejemplo el batch size. Para solventar este problema, es común definir todos los hyperparámetros en un solo dict
, que pasaremos al nuestro modelo y guardaremos en la misma función usada anteriormente.
class Modelo(pl.LightningModule):
def __init__(self, config, n_channels = 1, n_outputs = 10):
super().__init__()
self.save_hyperparameters(config)
self.conv1 = block(n_channels, self.hparams.filters1)
self.conv2 = block(self.hparams.filters1, self.hparams.filters2)
self.fc = torch.nn.Linear(self.hparams.filters2*7*7, n_outputs)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.shape[0], -1)
x = self.fc(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('loss', loss)
self.log('acc', accuracy(y_hat, y), prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', accuracy(y_hat, y), prog_bar=True)
def configure_optimizers(self):
return getattr(torch.optim, self.hparams.optimizer)(self.parameters(), lr=self.hparams.lr)
config = {
'lr': 3e-4,
'optimizer': 'Adam',
'batch_size': 64,
'filters1': 32,
'filters2': 64
}
modelo = Modelo(config)
dm = MNISTDataModule(batch_size=config['batch_size'])
trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dm)
modelo.hparams
"batch_size": 64
"filters1": 32
"filters2": 64
"lr": 0.0003
"optimizer": Adam
Podemos acceder a estos hyperparámetros incluso después de cargar un modelo a partir de un checkpoint. De esta manera, siempre sabremos el conjunto de parámetros utilizados para entrenar nuestro modelo.
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
config = {
'lr': 3e-4,
'optimizer': 'Adam',
'batch_size': 64,
'filters1': 32,
'filters2': 64
}
modelo = Modelo(config)
dm = MNISTDataModule(batch_size=config['batch_size'])
# callbacks
early_stop_callback = EarlyStopping(
monitor='val_acc',
patience=3,
verbose=False,
mode='max'
)
checkpoint = ModelCheckpoint(
dirpath='./',
filename='modelo-{val_acc:.5f}',
save_top_k=1,
monitor='val_acc',
mode='max'
)
# entrenamiento
trainer = pl.Trainer(
gpus=1,
callbacks=[
early_stop_callback,
checkpoint
]
)
trainer.fit(modelo, dm)
modelo = Modelo.load_from_checkpoint(checkpoint_path="modelo-val_acc=0.99120.ckpt")
modelo.hparams
"batch_size": 64
"filters1": 32
"filters2": 64
"lr": 0.0003
"optimizer": Adam
Por último, Pytorch Lightning
también ofrece una manera de interactuar con estos hyperparámetros iterpretando los argumentos pasados al ejectura un script. Puedes aprender más al respecto aquí.
Resumen
En este post hemos visto cómo podemos manejar los diferentes hyperparámetros utilizados para entrenar nuestros modelos: el learning rate, batch size, el optimizador usado, el número de capas convolucionales o filtros... Todos estos parámetros influyen en el resultado obtenido, y es importante guardar todos estos valores junto al modelo para no repetir trabajo o poder comparar modelos entre sí o con otros. Definiendo un dict
con todos los hyperparámetros, lo pasaremos al modelo en su inicialización y simplemente llamando a la función self.save_hyperparameters
, Pytorch Lightning
se encargará de guardar estos valores, hacerlos accesibles a través del objeto self.hparams
y guardarlos y cargarlos en los checkpoints.