July 7, 2021

~ 6 MIN

Pytorch Lightning - Optimizations


We continue talking about optimizing our Pytorch code. We have already seen many techniques we can use and, luckily, most of them are already implemented in Pytorch Lightning, so we don't have to rack our brains too much.
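
Most of these optimizations come down to a handful of arguments on the Trainer and the DataLoaders. As a minimal preview (assuming a LightningModule model and a DataModule dm like the ones defined later in this post), mixed precision, GPU training and the built-in profiler are enabled like this:

import pytorch_lightning as pl

# preview only: `model` (a LightningModule) and `dm` (a LightningDataModule) are defined later in the post
trainer = pl.Trainer(
    gpus=1,            # train on a single GPU
    precision=16,      # mixed precision
    profiler='simple'  # print a profiling report when fit ends
)
trainer.fit(model, dm)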

import os
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
import torch
from skimage import io
from torch.utils.data import DataLoader

class Dataset(torch.utils.data.Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, ix):
        # load the image and keep three of its spectral bands (an RGB-like composite)
        img = io.imread(self.images[ix])[...,(3,2,1)]
        # scale the pixel values, clip to [0, 1] and put channels first
        img = torch.tensor(img / 4000, dtype=torch.float).clip(0,1).permute(2,0,1)
        label = torch.tensor(self.labels[ix], dtype=torch.long)
        return img, label

class DataModule(pl.LightningDataModule):

    def __init__(self, path='./data', batch_size=1024, num_workers=20, test_size=0.2, random_state=42):
        super().__init__()
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.test_size = test_size
        self.random_state = random_state


    def setup(self, stage=None):

        self.classes = sorted(os.listdir(self.path))

        print("Generating images and labels ...")
        images, encoded = [], []
        for ix, label in enumerate(self.classes):
            _images = os.listdir(f'{self.path}/{label}')
            images += [f'{self.path}/{label}/{img}' for img in _images]
            encoded += [ix]*len(_images)
        print(f'Number of images: {len(images)}')

        # train / val split
        print("Generating train / val splits ...")
        train_images, val_images, train_labels, val_labels = train_test_split(
            images,
            encoded,
            stratify=encoded,
            test_size=self.test_size,
            random_state=self.random_state
        )

        print("Training samples: ", len(train_labels))
        print("Validation samples: ", len(val_labels))

        self.train_ds = Dataset(train_images, train_labels)
        self.val_ds = Dataset(val_images, val_labels)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True
        )

dm = DataModule()
dm.setup()

imgs, labels = next(iter(dm.train_dataloader()))
imgs.shape, labels.shape

Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples:  21600
Validation samples:  5400

(torch.Size([1024, 3, 64, 64]), torch.Size([1024]))

import torch.nn.functional as F
import timm

class Model(pl.LightningModule):

    def __init__(self, n_outputs=10, prof=None):
        super().__init__()
        # pretrained EfficientNet-B5 from timm with a new classification head
        self.model = timm.create_model('tf_efficientnet_b5', pretrained=True, num_classes=n_outputs)
        # optional external profiler, stepped once per training batch
        self.prof = prof

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log('loss', loss)
        self.log('acc', acc, prog_bar=True)
        if self.prof is not None:
            self.prof.step()
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def shared_step(self, batch):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (torch.argmax(y_hat, dim=1) == y).sum().item() / y.size(0)
        return loss, acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

model = Model()
dm = DataModule()

trainer = pl.Trainer(gpus=1, precision=16, max_epochs=3)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type         | Params
---------------------------------------
0 | model | EfficientNet | 28.4 M
---------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.445   Total estimated model params size (MB)


Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples:  21600
Validation samples:  5400

We can use a distributed strategy through the accelerator argument. In this case we will use the value dp for a Data Parallel strategy. You can see the rest of the strategies here.

model = Model()
dm = DataModule(batch_size=2048)
trainer = pl.Trainer(gpus=2, accelerator='dp', precision=16, max_epochs=3)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type         | Params
---------------------------------------
0 | model | EfficientNet | 28.4 M
---------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.445   Total estimated model params size (MB)


Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples:  21600
Validation samples:  5400

Although training is slightly slower than with the pure Pytorch code, the flexibility and functionality that Pytorch Lightning provides can be worth it in most cases. You can see an example using Distributed Data Parallel here.
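
If you want to try it yourself, switching to Distributed Data Parallel is, in principle, just another value of the same argument. A minimal sketch (note that in the Pytorch Lightning version used here, ddp generally needs to be launched from a script rather than from a notebook):

model = Model()
dm = DataModule(batch_size=2048)
# 'ddp' runs one process per GPU instead of splitting each batch across GPUs like 'dp'
trainer = pl.Trainer(gpus=2, accelerator='ddp', precision=16, max_epochs=3)
trainer.fit(model, dm)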

Profiling

Pytorch Lightning also offers us alternatives for profiling our code when looking for bottlenecks.

model = Model()
dm = DataModule()

trainer = pl.Trainer(gpus=1, precision=16, max_epochs=1, profiler='simple')
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type         | Params
---------------------------------------
0 | model | EfficientNet | 28.4 M
---------------------------------------
28.4 M    Trainable params
0         Non-trainable params
28.4 M    Total params
113.445   Total estimated model params size (MB)


Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples:  21600
Validation samples:  5400

FIT Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  15.719         	|  100 %          	|
--------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  13.826         	|1              	|  13.826         	|  87.956         	|
run_training_batch                 	|  0.41584        	|22             	|  9.1485         	|  58.2           	|
optimizer_step_and_closure_0       	|  0.41555        	|22             	|  9.1421         	|  58.16          	|
training_step_and_backward         	|  0.24678        	|22             	|  5.4291         	|  34.538         	|
backward                           	|  0.14637        	|22             	|  3.2202         	|  20.486         	|
model_forward                      	|  0.097416       	|22             	|  2.1432         	|  13.634         	|
training_step                      	|  0.097244       	|22             	|  2.1394         	|  13.61          	|
get_train_batch                    	|  0.087274       	|22             	|  1.92           	|  12.215         	|
evaluation_step_and_end            	|  0.12911        	|8              	|  1.0329         	|  6.5708         	|
validation_step                    	|  0.12895        	|8              	|  1.0316         	|  6.563          	|
on_validation_end                  	|  0.26859        	|2              	|  0.53718        	|  3.4174         	|
on_train_start                     	|  0.038434       	|1              	|  0.038434       	|  0.2445         	|
on_train_batch_end                 	|  0.0016715      	|22             	|  0.036773       	|  0.23394        	|
on_validation_start                	|  0.0087554      	|2              	|  0.017511       	|  0.1114         	|
on_validation_batch_end            	|  0.0012855      	|8              	|  0.010284       	|  0.065422       	|
cache_result                       	|  1.3535e-05     	|135            	|  0.0018272      	|  0.011624       	|
on_train_epoch_start               	|  0.001372       	|1              	|  0.001372       	|  0.008728       	|
on_train_end                       	|  0.0011929      	|1              	|  0.0011929      	|  0.0075887      	|
on_batch_start                     	|  2.2384e-05     	|22             	|  0.00049245     	|  0.0031328      	|
on_after_backward                  	|  1.7712e-05     	|22             	|  0.00038966     	|  0.0024789      	|
on_validation_batch_start          	|  3.5063e-05     	|8              	|  0.00028051     	|  0.0017845      	|
on_train_batch_start               	|  1.1261e-05     	|22             	|  0.00024775     	|  0.0015761      	|
on_before_zero_grad                	|  1.0351e-05     	|22             	|  0.00022771     	|  0.0014486      	|
on_batch_end                       	|  1.0082e-05     	|22             	|  0.00022181     	|  0.0014111      	|
training_step_end                  	|  7.1055e-06     	|22             	|  0.00015632     	|  0.00099447     	|
on_train_epoch_end                 	|  0.00014509     	|1              	|  0.00014509     	|  0.000923       	|
validation_step_end                	|  9.4492e-06     	|8              	|  7.5594e-05     	|  0.00048091     	|
on_epoch_end                       	|  1.1085e-05     	|3              	|  3.3254e-05     	|  0.00021155     	|
on_validation_epoch_end            	|  1.5438e-05     	|2              	|  3.0875e-05     	|  0.00019642     	|
on_epoch_start                     	|  9.9017e-06     	|3              	|  2.9705e-05     	|  0.00018897     	|
on_fit_start                       	|  1.8213e-05     	|1              	|  1.8213e-05     	|  0.00011587     	|
on_validation_epoch_start          	|  6.951e-06      	|2              	|  1.3902e-05     	|  8.8441e-05     	|
on_train_dataloader                	|  1.0995e-05     	|1              	|  1.0995e-05     	|  6.9947e-05     	|
on_before_accelerator_backend_setup	|  7.469e-06      	|1              	|  7.469e-06      	|  4.7516e-05     	|
on_val_dataloader                  	|  7.282e-06      	|1              	|  7.282e-06      	|  4.6326e-05     	|

You can see more options here.
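
Note that the Model class defined above also accepts an optional prof argument and calls prof.step() on every training batch. That hook is not used in the code above, but here is a sketch of how it could be wired up with the profiler that ships with Pytorch itself (the schedule values, the ./logs directory and limit_train_batches are arbitrary choices for illustration):

from torch.profiler import profile, schedule, tensorboard_trace_handler

# profile a few training batches and write a trace that TensorBoard can display
with profile(
    schedule=schedule(wait=1, warmup=1, active=3),
    on_trace_ready=tensorboard_trace_handler('./logs')
) as prof:
    model = Model(prof=prof)  # the model steps the profiler inside training_step
    trainer = pl.Trainer(gpus=1, precision=16, max_epochs=1, limit_train_batches=10)
    trainer.fit(model, DataModule())

Besides 'simple', Pytorch Lightning also ships more detailed built-in profilers (for example a function-level 'advanced' one); the options available depend on the version you are using.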

Summary

Optimizing our Pytorch code is very important, and we have many tools and techniques at our disposal to squeeze the most out of our networks. Pytorch Lightning makes our life much easier when it comes to using these techniques transparently, without major changes to our code, whereas in plain Pytorch we would have to dive into the documentation and examples to take advantage of everything covered in the last few posts (ending up with very long code that is hard to follow). Through the Trainer object we can define different distributed training strategies in a simple way, and the profiling options help us find the weak spots in our code so we can fix them.
