diciembre 2, 2020

~ 9 MIN

Pytorch Lightning - Introducción

< Blog RSS

Open In Colab

Pytorch Lightning

Si has ido siguiendo los diferentes posts de este blog, es posible que hayas entrenado varias redes neuronales utilizando la librería Pytorch. De ser así, quizás has tenido la sensación de estar repitiendo el mismo código una y otra vez, sobre todo en lo referente al bucle de entrenamiento. Además, es posible que hayas tenido problemas intentando implementar funcionalidad más avanzada, asegurándote que todo funciona como debería sin errores. ¿No sería estupendo tener una librería que implementase por nosotros todo este códigp boilerplate sin perder la flexibilidad que nos ofrece Pytorch? Pues estás de enhorabuena, porque tal librería existe, y se llama Pytorch Lightning ⚡️.

import pytorch_lightning as pl

pl.__version__
'1.0.7'

💡 Puedes instalar pythorch lighning con el comando pip install pytorch-lightning.

En este post aprenderemos los conceptos básicos de esta librería, entrenando un modelo simple para clasificación de imágenes con el dataset MNIST como ya hemos hecho en varios posts anteriores, así podremos compara directamente ambas opciones.

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import numpy as np
# preparamos los datos

dataloader = {
    'train': torch.utils.data.DataLoader(torchvision.datasets.MNIST('../data', train=True, download=True,
                       transform=torchvision.transforms.Compose([
                            torchvision.transforms.ToTensor(),
                            torchvision.transforms.Normalize((0.1307,), (0.3081,))
                            ])
                      ), batch_size=2048, shuffle=True, pin_memory=True),
    'test': torch.utils.data.DataLoader(torchvision.datasets.MNIST('../data', train=False,
                   transform=torchvision.transforms.Compose([
                        torchvision.transforms.ToTensor(),
                        torchvision.transforms.Normalize((0.1307,), (0.3081,))
                        ])
                     ), batch_size=2048, shuffle=False, pin_memory=True)
}

Pytorch

Para entrenar nuestro modelo usando puro Pytorch, primero definimos nuestra red neuronal creando una clase que derive de torch.nn.Module en la que tenemos que definir la función __init__, con las diferentes capas del modelo, y forward, con todas las operaciones necesarias para calcular las salidas de la red a partir de las entradas.

# definimos el modelo

def block(c_in, c_out, k=3, p=1, s=1, pk=2, ps=2):
    return torch.nn.Sequential(
        torch.nn.Conv2d(c_in, c_out, k, padding=p, stride=s),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(pk, stride=ps)
    )

class CNN(torch.nn.Module):
  def __init__(self, n_channels=1, n_outputs=10):
    super().__init__()
    self.conv1 = block(n_channels, 64)
    self.conv2 = block(64, 128)
    self.fc = torch.nn.Linear(128*7*7, n_outputs)

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.shape[0], -1)
    x = self.fc(x)
    return x

Una vez tenemos nuestro modelo, necesitamos definir la lógica de entrenamiento. En este paso, Pytorch nos da total libertad para hacerlo de la manera en la que queramos. Una opción que hemos utilizado en posts anteriores es definir una función fit, a la cual le pasaremos nuestro modelo y datos, y que se encargará de todo.

# entrenamos el modelo

device = "cuda" if torch.cuda.is_available() else "cpu"

def fit(model, dataloader, epochs=5):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(1, epochs+1):
        model.train()
        train_loss, train_acc = [], []
        bar = tqdm(dataloader['train'])
        for batch in bar:
            X, y = batch
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_hat = model(X)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            acc = (y == torch.argmax(y_hat, axis=1)).sum().item() / len(y)
            train_acc.append(acc)
            bar.set_description(f"loss {np.mean(train_loss):.5f} acc {np.mean(train_acc):.5f}")
        bar = tqdm(dataloader['test'])
        val_loss, val_acc = [], []
        model.eval()
        with torch.no_grad():
            for batch in bar:
                X, y = batch
                X, y = X.to(device), y.to(device)
                y_hat = model(X)
                loss = criterion(y_hat, y)
                val_loss.append(loss.item())
                acc = (y == torch.argmax(y_hat, axis=1)).sum().item() / len(y)
                val_acc.append(acc)
                bar.set_description(f"val_loss {np.mean(val_loss):.5f} val_acc {np.mean(val_acc):.5f}")
        print(f"Epoch {epoch}/{epochs} loss {np.mean(train_loss):.5f} val_loss {np.mean(val_loss):.5f} acc {np.mean(train_acc):.5f} val_acc {np.mean(val_acc):.5f}")

En nuestra función fit hemos implementado la lógica de entrenamiento y evaluación del modelo, el cálculo de su precisión a la vez que imprimimos por pantalla la información más relevante durante el entrenamiento en una barra de progreso.

model = CNN()
fit(model, dataloader)
loss 0.61840 acc 0.83090: 100%|██████████| 30/30 [00:07<00:00,  3.79it/s]
val_loss 0.20638 val_acc 0.94002: 100%|██████████| 5/5 [00:01<00:00,  4.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1/5 loss 0.61840 val_loss 0.20638 acc 0.83090 val_acc 0.94002


loss 0.14674 acc 0.95707: 100%|██████████| 30/30 [00:07<00:00,  4.08it/s]
val_loss 0.08969 val_acc 0.97255: 100%|██████████| 5/5 [00:01<00:00,  4.50it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 2/5 loss 0.14674 val_loss 0.08969 acc 0.95707 val_acc 0.97255


loss 0.08559 acc 0.97516: 100%|██████████| 30/30 [00:07<00:00,  4.07it/s]
val_loss 0.05989 val_acc 0.98230: 100%|██████████| 5/5 [00:01<00:00,  4.57it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 3/5 loss 0.08559 val_loss 0.05989 acc 0.97516 val_acc 0.98230


loss 0.06402 acc 0.98114: 100%|██████████| 30/30 [00:07<00:00,  4.09it/s]
val_loss 0.05274 val_acc 0.98333: 100%|██████████| 5/5 [00:01<00:00,  4.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 4/5 loss 0.06402 val_loss 0.05274 acc 0.98114 val_acc 0.98333


loss 0.05296 acc 0.98434: 100%|██████████| 30/30 [00:07<00:00,  4.08it/s]
val_loss 0.04791 val_acc 0.98432: 100%|██████████| 5/5 [00:01<00:00,  4.54it/s]

Epoch 5/5 loss 0.05296 val_loss 0.04791 acc 0.98434 val_acc 0.98432

Pero, ¿que pasaría si quisiésemos implementar técnicas como early stopping o guardar el mejor modelo durante el entrenamiento ?, ¿o calcular otras métricas más allá de la precisión?, ¿o entrenar otros modelos más complicados que requieran de bucles más elaborados? En todos estos casos tendríamos que modificar nuestra función fit, corriendo el riesgo de introducir bugs, y resultando en una implementación distinta para cada aplicación en la que trabajemos.

El LightningModule

Para solucionar estos problemas, Pytorch Lightning nos ofrece la clase LightningModule, la cual podemos utilizar para definir nuestros modelos de la siguiente manera:

class Modelo(pl.LightningModule):

    # igual que antes
    def __init__(self, n_channels=1, n_outputs=10):
        super().__init__()
        self.conv1 = block(n_channels, 64)
        self.conv2 = block(64, 128)
        self.fc = torch.nn.Linear(128*7*7, n_outputs)

    # igual que antes
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

    # lógica de entrenamiento
    def training_step(self, batch, batch_idx):
        # no hace falta enviar nada a la gpu
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('loss', loss)
        return loss
        # no necesitamos llamar a loss.backward() ni optimier.step()
        # pytorch lightning se encarga por nosotros

    # optimizador
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Como puedes observar, las clases __init__ y forward son exactamente iguales que en la implementación original. Sin embargo, hemos añadido dos nuevas funciones: training_step, en la que calculamos la salida de la red y devolvemos la función de pérdida, y configure_optimizers, en la que devolvemos el optimizador. Pytorch Lightning se encarga de mover los datos a la GPU si es necesario, así como de llamar a las funciones loss.backward, optimizer.zero_grad y optimizer.step.

El Lightning Trainer

Gracias a la implementación aterior, ahora podemos entrenar nuestro modelo de manera sencilla con el lightning trainer. Primero, instanciaremos un trainer al cual le podemos pasar parámetros tales como el número de epochs. Una vez definido el trainer, podemos entrenar nuestro modelo simplemente llamando a su función fit, pasándole como parámteros nuestro modelo y el dataloader de entrenamiento.

modelo = Modelo()

trainer = pl.Trainer(max_epochs=5)
trainer.fit(modelo, dataloader['train'])

Pytorch Lightning nos da información interesante al principio del entrenamiento, como el hardware disponible y usado así como un resumen de nuestro modelo y su número de parámetros. Cuando empiece el entrenamiento, veremos una barra de progreso indicándonos la epoch en la que nos encontramos y el valor de la función de pérdida.

Entrenando en GPUs

Una de las principales ventajas de Pytorch Lightning es lo sencillo que es entrenar en diferente hardware. Para entrenar nuestro modelo en una GPU, simplemente le pasamo la variabale gpu al trainer.

modelo = Modelo()

trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dataloader['train'])

Como puedes ver, indicando gpus=1 entrenaremos nuestro modelo en una GPU. Y es que si disponemos de más de una GPU, podremos indicar el número y Pytorch Lightning se encargará de distribuir el entrenamiento entre todas ellas 🔥. También podremos indicar un número de nodos, GPUs por nodo e incluso TPUs. Todo ello, sin cambiar ni una línea de código.

Usando datos de validación

Para evaluar nuestro modelo a la vez que lo entrenamos, simplemente tenemos que definir la función validation_step en el LightningModule. Podemos ver cualquier información en la barra de progreso con la función self.log con la variable prog_bar=True.

class Modelo(pl.LightningModule):

    def __init__(self, n_channels=1, n_outputs=10):
        super().__init__()
        self.conv1 = block(n_channels, 64)
        self.conv2 = block(64, 128)
        self.fc = torch.nn.Linear(128*7*7, n_outputs)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Ahora, pasaremos nuestro dataloader de validación también en el trainer.

modelo = Modelo()

trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dataloader['train'], dataloader['test'])

El LightningDataModule

De la misma forma que podemos encapsular nuestro modelo y la lógica de entrenamiento directamente en una sola clase, podemos hacer algo similar para nuestro dataset.

class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, path = '../data', batch_size = 64):
        super().__init__()
        self.path = path
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.mnist_train = torchvision.datasets.MNIST(
            self.path, train=True, download=True, transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,))
                ])
          )
        self.mnist_val = torchvision.datasets.MNIST(
            self.path, train=False, download=True, transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307,), (0.3081,))
                ])
          )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.mnist_val, batch_size=self.batch_size)

La función setup se encargará de descargar los datos y procesarlos como sea necesario para generar los datasets. Después, utilizaremos las funciones train_dataloader y val_dataloader para generar los dataloaders. Una vez definido el LightningDataModule, podemos entrenar nuestro modelo de manera simple.

modelo = Modelo()
dm = MNISTDataModule()

trainer = pl.Trainer(max_epochs=5, gpus=1)
trainer.fit(modelo, dm)

Resumen

En este post hemos aprendido los conceptos básicos de la librería Pytorch Lightning, la cual nos ayudará a ser más eficientes a la hora de entrenar modelos en Pytorch. En primer lugar, el objeto LightningModule reemplaza al objeto torch.nn.Module a la hora de definir nuestros modelos. Además, podremos indicarle la lógica para entrenar y validar nuestro modelo de manera simple. En segundo lugar, el objeto LightningDataModule nos permite encapsular la lógica de descarga y preparación de los datos. Estas dos clases es lo único que necesitaremos para llevar todo el proceso a cabo, por lo que nuestro código quedará mucho más ordenado y también será más reproducible. Gracias a estas definiciones, podremos usar el Lightning Trainer para entrenar de manera simple nuestro modelo, ya sea en la CPU o GPU sin tener que hacer ningún cambio en el código.

Si bien hemos presentado los fundamentos, existe mucha más funcionalidad interesante en la librería. En próximos posts veremos algunos ejemplos, así como otras librerías que encajan muy bien con Pytorch Lightning para optimizar nuestro proceso de trabajo.

< Blog RSS