diciembre 2, 2020
~ 9 MIN
Pytorch Lightning - Introducción
< Blog RSSPytorch Lightning
Si has ido siguiendo los diferentes posts de este blog, es posible que hayas entrenado varias redes neuronales utilizando la librería Pytorch
. De ser así, quizás has tenido la sensación de estar repitiendo el mismo código una y otra vez, sobre todo en lo referente al bucle de entrenamiento. Además, es posible que hayas tenido problemas intentando implementar funcionalidad más avanzada, asegurándote que todo funciona como debería sin errores. ¿No sería estupendo tener una librería que implementase por nosotros todo este códigp boilerplate sin perder la flexibilidad que nos ofrece Pytorch
? Pues estás de enhorabuena, porque tal librería existe, y se llama Pytorch Lightning ⚡️.
import pytorch_lightning as pl
pl.__version__
'1.0.7'
💡 Puedes instalar pythorch lighning con el comando
pip install pytorch-lightning
.
En este post aprenderemos los conceptos básicos de esta librería, entrenando un modelo simple para clasificación de imágenes con el dataset MNIST
como ya hemos hecho en varios posts anteriores, así podremos compara directamente ambas opciones.
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import numpy as np
# preparamos los datos
dataloader = {
'train': torch.utils.data.DataLoader(torchvision.datasets.MNIST('../data', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
), batch_size=2048, shuffle=True, pin_memory=True),
'test': torch.utils.data.DataLoader(torchvision.datasets.MNIST('../data', train=False,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
), batch_size=2048, shuffle=False, pin_memory=True)
}
Pytorch
Para entrenar nuestro modelo usando puro Pytorch
, primero definimos nuestra red neuronal creando una clase que derive de torch.nn.Module
en la que tenemos que definir la función __init__
, con las diferentes capas del modelo, y forward
, con todas las operaciones necesarias para calcular las salidas de la red a partir de las entradas.
# definimos el modelo
def block(c_in, c_out, k=3, p=1, s=1, pk=2, ps=2):
return torch.nn.Sequential(
torch.nn.Conv2d(c_in, c_out, k, padding=p, stride=s),
torch.nn.ReLU(),
torch.nn.MaxPool2d(pk, stride=ps)
)
class CNN(torch.nn.Module):
def __init__(self, n_channels=1, n_outputs=10):
super().__init__()
self.conv1 = block(n_channels, 64)
self.conv2 = block(64, 128)
self.fc = torch.nn.Linear(128*7*7, n_outputs)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.shape[0], -1)
x = self.fc(x)
return x
Una vez tenemos nuestro modelo, necesitamos definir la lógica de entrenamiento. En este paso, Pytorch
nos da total libertad para hacerlo de la manera en la que queramos. Una opción que hemos utilizado en posts anteriores es definir una función fit
, a la cual le pasaremos nuestro modelo y datos, y que se encargará de todo.
# entrenamos el modelo
device = "cuda" if torch.cuda.is_available() else "cpu"
def fit(model, dataloader, epochs=5):
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(1, epochs+1):
model.train()
train_loss, train_acc = [], []
bar = tqdm(dataloader['train'])
for batch in bar:
X, y = batch
X, y = X.to(device), y.to(device)
optimizer.zero_grad()
y_hat = model(X)
loss = criterion(y_hat, y)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
acc = (y == torch.argmax(y_hat, axis=1)).sum().item() / len(y)
train_acc.append(acc)
bar.set_description(f"loss {np.mean(train_loss):.5f} acc {np.mean(train_acc):.5f}")
bar = tqdm(dataloader['test'])
val_loss, val_acc = [], []
model.eval()
with torch.no_grad():
for batch in bar:
X, y = batch
X, y = X.to(device), y.to(device)
y_hat = model(X)
loss = criterion(y_hat, y)
val_loss.append(loss.item())
acc = (y == torch.argmax(y_hat, axis=1)).sum().item() / len(y)
val_acc.append(acc)
bar.set_description(f"val_loss {np.mean(val_loss):.5f} val_acc {np.mean(val_acc):.5f}")
print(f"Epoch {epoch}/{epochs} loss {np.mean(train_loss):.5f} val_loss {np.mean(val_loss):.5f} acc {np.mean(train_acc):.5f} val_acc {np.mean(val_acc):.5f}")
En nuestra función fit
hemos implementado la lógica de entrenamiento y evaluación del modelo, el cálculo de su precisión a la vez que imprimimos por pantalla la información más relevante durante el entrenamiento en una barra de progreso.
model = CNN()
fit(model, dataloader)
loss 0.61840 acc 0.83090: 100%|██████████| 30/30 [00:07<00:00, 3.79it/s]
val_loss 0.20638 val_acc 0.94002: 100%|██████████| 5/5 [00:01<00:00, 4.56it/s]
0%| | 0/30 [00:00<?, ?it/s]
Epoch 1/5 loss 0.61840 val_loss 0.20638 acc 0.83090 val_acc 0.94002
loss 0.14674 acc 0.95707: 100%|██████████| 30/30 [00:07<00:00, 4.08it/s]
val_loss 0.08969 val_acc 0.97255: 100%|██████████| 5/5 [00:01<00:00, 4.50it/s]
0%| | 0/30 [00:00<?, ?it/s]
Epoch 2/5 loss 0.14674 val_loss 0.08969 acc 0.95707 val_acc 0.97255
loss 0.08559 acc 0.97516: 100%|██████████| 30/30 [00:07<00:00, 4.07it/s]
val_loss 0.05989 val_acc 0.98230: 100%|██████████| 5/5 [00:01<00:00, 4.57it/s]
0%| | 0/30 [00:00<?, ?it/s]
Epoch 3/5 loss 0.08559 val_loss 0.05989 acc 0.97516 val_acc 0.98230
loss 0.06402 acc 0.98114: 100%|██████████| 30/30 [00:07<00:00, 4.09it/s]
val_loss 0.05274 val_acc 0.98333: 100%|██████████| 5/5 [00:01<00:00, 4.56it/s]
0%| | 0/30 [00:00<?, ?it/s]
Epoch 4/5 loss 0.06402 val_loss 0.05274 acc 0.98114 val_acc 0.98333
loss 0.05296 acc 0.98434: 100%|██████████| 30/30 [00:07<00:00, 4.08it/s]
val_loss 0.04791 val_acc 0.98432: 100%|██████████| 5/5 [00:01<00:00, 4.54it/s]
Epoch 5/5 loss 0.05296 val_loss 0.04791 acc 0.98434 val_acc 0.98432
Pero, ¿que pasaría si quisiésemos implementar técnicas como early stopping o guardar el mejor modelo durante el entrenamiento ?, ¿o calcular otras métricas más allá de la precisión?, ¿o entrenar otros modelos más complicados que requieran de bucles más elaborados? En todos estos casos tendríamos que modificar nuestra función fit
, corriendo el riesgo de introducir bugs, y resultando en una implementación distinta para cada aplicación en la que trabajemos.
El LightningModule
Para solucionar estos problemas, Pytorch Lightning
nos ofrece la clase LightningModule
, la cual podemos utilizar para definir nuestros modelos de la siguiente manera:
class Modelo(pl.LightningModule):
# igual que antes
def __init__(self, n_channels=1, n_outputs=10):
super().__init__()
self.conv1 = block(n_channels, 64)
self.conv2 = block(64, 128)
self.fc = torch.nn.Linear(128*7*7, n_outputs)
# igual que antes
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.shape[0], -1)
x = self.fc(x)
return x
# lógica de entrenamiento
def training_step(self, batch, batch_idx):
# no hace falta enviar nada a la gpu
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('loss', loss)
return loss
# no necesitamos llamar a loss.backward() ni optimier.step()
# pytorch lightning se encarga por nosotros
# optimizador
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
Como puedes observar, las clases __init__
y forward
son exactamente iguales que en la implementación original. Sin embargo, hemos añadido dos nuevas funciones: training_step
, en la que calculamos la salida de la red y devolvemos la función de pérdida, y configure_optimizers
, en la que devolvemos el optimizador. Pytorch Lightning
se encarga de mover los datos a la GPU si es necesario, así como de llamar a las funciones loss.backward
, optimizer.zero_grad
y optimizer.step
.
El Lightning Trainer
Gracias a la implementación aterior, ahora podemos entrenar nuestro modelo de manera sencilla con el lightning trainer. Primero, instanciaremos un trainer
al cual le podemos pasar parámetros tales como el número de epochs. Una vez definido el trainer
, podemos entrenar nuestro modelo simplemente llamando a su función fit
, pasándole como parámteros nuestro modelo y el dataloader de entrenamiento.
modelo = Modelo()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(modelo, dataloader['train'])
Pytorch Lightning
nos da información interesante al principio del entrenamiento, como el hardware disponible y usado así como un resumen de nuestro modelo y su número de parámetros. Cuando empiece el entrenamiento, veremos una barra de progreso indicándonos la epoch en la que nos encontramos y el valor de la función de pérdida.
Entrenando en GPUs
Una de las principales ventajas de Pytorch Lightning
es lo sencillo que es entrenar en diferente hardware. Para entrenar nuestro modelo en una GPU, simplemente le pasamo la variabale gpu
al trainer
.
modelo = Modelo()
trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dataloader['train'])
Como puedes ver, indicando gpus=1
entrenaremos nuestro modelo en una GPU. Y es que si disponemos de más de una GPU, podremos indicar el número y Pytorch Lightning
se encargará de distribuir el entrenamiento entre todas ellas 🔥. También podremos indicar un número de nodos, GPUs por nodo e incluso TPUs. Todo ello, sin cambiar ni una línea de código.
Usando datos de validación
Para evaluar nuestro modelo a la vez que lo entrenamos, simplemente tenemos que definir la función validation_step
en el LightningModule. Podemos ver cualquier información en la barra de progreso con la función self.log
con la variable prog_bar=True
.
class Modelo(pl.LightningModule):
def __init__(self, n_channels=1, n_outputs=10):
super().__init__()
self.conv1 = block(n_channels, 64)
self.conv2 = block(64, 128)
self.fc = torch.nn.Linear(128*7*7, n_outputs)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.shape[0], -1)
x = self.fc(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('loss', loss)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log('val_loss', loss, prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
Ahora, pasaremos nuestro dataloader de validación también en el trainer
.
modelo = Modelo()
trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dataloader['train'], dataloader['test'])
El LightningDataModule
De la misma forma que podemos encapsular nuestro modelo y la lógica de entrenamiento directamente en una sola clase, podemos hacer algo similar para nuestro dataset.
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, path = '../data', batch_size = 64):
super().__init__()
self.path = path
self.batch_size = batch_size
def setup(self, stage=None):
self.mnist_train = torchvision.datasets.MNIST(
self.path, train=True, download=True, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
)
self.mnist_val = torchvision.datasets.MNIST(
self.path, train=False, download=True, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
)
def train_dataloader(self):
return torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size)
La función setup
se encargará de descargar los datos y procesarlos como sea necesario para generar los datasets. Después, utilizaremos las funciones train_dataloader
y val_dataloader
para generar los dataloaders. Una vez definido el LightningDataModule, podemos entrenar nuestro modelo de manera simple.
modelo = Modelo()
dm = MNISTDataModule()
trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dm)
Resumen
En este post hemos aprendido los conceptos básicos de la librería Pytorch Lightning
, la cual nos ayudará a ser más eficientes a la hora de entrenar modelos en Pytorch
. En primer lugar, el objeto LightningModule
reemplaza al objeto torch.nn.Module
a la hora de definir nuestros modelos. Además, podremos indicarle la lógica para entrenar y validar nuestro modelo de manera simple. En segundo lugar, el objeto LightningDataModule
nos permite encapsular la lógica de descarga y preparación de los datos. Estas dos clases es lo único que necesitaremos para llevar todo el proceso a cabo, por lo que nuestro código quedará mucho más ordenado y también será más reproducible. Gracias a estas definiciones, podremos usar el Lightning Trainer
para entrenar de manera simple nuestro modelo, ya sea en la CPU o GPU sin tener que hacer ningún cambio en el código.
Si bien hemos presentado los fundamentos, existe mucha más funcionalidad interesante en la librería. En próximos posts veremos algunos ejemplos, así como otras librerías que encajan muy bien con Pytorch Lightning
para optimizar nuestro proceso de trabajo.