July 7, 2021
~ 6 MIN
Pytorch Lightning - Optimizations
We keep talking about optimizing our Pytorch code. We have already seen many techniques we can use and, fortunately, most of them are already implemented in Pytorch Lightning, so we don't have to rack our brains too much.
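As a rough illustration (a hypothetical sketch, not code from this post: the flag values are only examples), several of the techniques from previous posts boil down to simple flags on the Trainer:
import pytorch_lightning as pl

# Illustrative only: mixed precision, cuDNN benchmarking and gradient
# accumulation exposed as Trainer flags (names valid in Lightning 1.3/1.4).
trainer = pl.Trainer(
    gpus=1,
    precision=16,               # automatic mixed precision
    benchmark=True,             # torch.backends.cudnn.benchmark
    accumulate_grad_batches=2,  # gradient accumulation
    max_epochs=3,
)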
import os

from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
import torch
from skimage import io
from torch.utils.data import DataLoader


class Dataset(torch.utils.data.Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, ix):
        # read the image and keep only three bands (RGB)
        img = io.imread(self.images[ix])[..., (3, 2, 1)]
        # normalize, clip to [0, 1] and move channels first
        img = torch.tensor(img / 4000, dtype=torch.float).clip(0, 1).permute(2, 0, 1)
        label = torch.tensor(self.labels[ix], dtype=torch.long)
        return img, label


class DataModule(pl.LightningDataModule):
    def __init__(self, path='./data', batch_size=1024, num_workers=20, test_size=0.2, random_state=42):
        super().__init__()
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.test_size = test_size
        self.random_state = random_state

    def setup(self, stage=None):
        # one folder per class
        self.classes = sorted(os.listdir(self.path))
        print("Generating images and labels ...")
        images, encoded = [], []
        for ix, label in enumerate(self.classes):
            _images = os.listdir(f'{self.path}/{label}')
            images += [f'{self.path}/{label}/{img}' for img in _images]
            encoded += [ix] * len(_images)
        print(f'Number of images: {len(images)}')
        # train / val split
        print("Generating train / val splits ...")
        train_images, val_images, train_labels, val_labels = train_test_split(
            images,
            encoded,
            stratify=encoded,
            test_size=self.test_size,
            random_state=self.random_state
        )
        print("Training samples: ", len(train_labels))
        print("Validation samples: ", len(val_labels))
        self.train_ds = Dataset(train_images, train_labels)
        self.val_ds = Dataset(val_images, val_labels)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True
        )
dm = DataModule()
dm.setup()
imgs, labels = next(iter(dm.train_dataloader()))
imgs.shape, labels.shape
Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples: 21600
Validation samples: 5400
(torch.Size([1024, 3, 64, 64]), torch.Size([1024]))
import torch.nn.functional as F
import timm


class Model(pl.LightningModule):
    def __init__(self, n_outputs=10, prof=None):
        super().__init__()
        self.model = timm.create_model('tf_efficientnet_b5', pretrained=True, num_classes=n_outputs)
        self.prof = prof

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log('loss', loss)
        self.log('acc', acc, prog_bar=True)
        if self.prof is not None:
            # advance the (optional) profiler at every training step
            self.prof.step()
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def shared_step(self, batch):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (torch.argmax(y_hat, axis=1) == y).sum().item() / y.size(0)
        return loss, acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


model = Model()
dm = DataModule()
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=3)
trainer.fit(model, dm)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
| Name | Type | Params
---------------------------------------
0 | model | EfficientNet | 28.4 M
---------------------------------------
28.4 M Trainable params
0 Non-trainable params
28.4 M Total params
113.445 Total estimated model params size (MB)
Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples: 21600
Validation samples: 5400
We can use a distributed strategy through the accelerator parameter. In this case we will use the value dp for a Data Parallel strategy. You can see the rest of the strategies here.
model = Model()
dm = DataModule(batch_size=2048)
trainer = pl.Trainer(gpus=2, accelerator='dp', precision=16, max_epochs=3)
trainer.fit(model, dm)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
| Name | Type | Params
---------------------------------------
0 | model | EfficientNet | 28.4 M
---------------------------------------
28.4 M Trainable params
0 Non-trainable params
28.4 M Total params
113.445 Total estimated model params size (MB)
Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples: 21600
Validation samples: 5400
Although training is slightly slower than with pure Pytorch code, the flexibility and functionality that Pytorch Lightning provides can be worth it in most cases. You can see an example using Distributed Data Parallel here.
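For reference, a minimal sketch of what that would look like with the Model and DataModule defined above (the values are illustrative; with the accelerator names used by Lightning 1.3/1.4, 'ddp' launches one process per GPU, so it is normally run as a script rather than from a notebook, with 'ddp_spawn' as the notebook-friendly variant):
# Hypothetical sketch: Distributed Data Parallel with the same Model / DataModule.
model = Model()
dm = DataModule(batch_size=1024)  # with ddp the batch size is per process
trainer = pl.Trainer(gpus=2, accelerator='ddp', precision=16, max_epochs=3)
trainer.fit(model, dm)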
Profiling
Pytorch Lightning also gives us options when it comes to profiling our code in search of bottlenecks.
model = Model()
dm = DataModule()
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=1, profiler='simple')
trainer.fit(model, dm)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
| Name | Type | Params
---------------------------------------
0 | model | EfficientNet | 28.4 M
---------------------------------------
28.4 M Trainable params
0 Non-trainable params
28.4 M Total params
113.445 Total estimated model params size (MB)
Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples: 21600
Validation samples: 5400
FIT Profiler Report
Action | Mean duration (s) |Num calls | Total time (s) | Percentage % |
--------------------------------------------------------------------------------------------------------------------------------------
Total | - |_ | 15.719 | 100 % |
--------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch | 13.826 |1 | 13.826 | 87.956 |
run_training_batch | 0.41584 |22 | 9.1485 | 58.2 |
optimizer_step_and_closure_0 | 0.41555 |22 | 9.1421 | 58.16 |
training_step_and_backward | 0.24678 |22 | 5.4291 | 34.538 |
backward | 0.14637 |22 | 3.2202 | 20.486 |
model_forward | 0.097416 |22 | 2.1432 | 13.634 |
training_step | 0.097244 |22 | 2.1394 | 13.61 |
get_train_batch | 0.087274 |22 | 1.92 | 12.215 |
evaluation_step_and_end | 0.12911 |8 | 1.0329 | 6.5708 |
validation_step | 0.12895 |8 | 1.0316 | 6.563 |
on_validation_end | 0.26859 |2 | 0.53718 | 3.4174 |
on_train_start | 0.038434 |1 | 0.038434 | 0.2445 |
on_train_batch_end | 0.0016715 |22 | 0.036773 | 0.23394 |
on_validation_start | 0.0087554 |2 | 0.017511 | 0.1114 |
on_validation_batch_end | 0.0012855 |8 | 0.010284 | 0.065422 |
cache_result | 1.3535e-05 |135 | 0.0018272 | 0.011624 |
on_train_epoch_start | 0.001372 |1 | 0.001372 | 0.008728 |
on_train_end | 0.0011929 |1 | 0.0011929 | 0.0075887 |
on_batch_start | 2.2384e-05 |22 | 0.00049245 | 0.0031328 |
on_after_backward | 1.7712e-05 |22 | 0.00038966 | 0.0024789 |
on_validation_batch_start | 3.5063e-05 |8 | 0.00028051 | 0.0017845 |
on_train_batch_start | 1.1261e-05 |22 | 0.00024775 | 0.0015761 |
on_before_zero_grad | 1.0351e-05 |22 | 0.00022771 | 0.0014486 |
on_batch_end | 1.0082e-05 |22 | 0.00022181 | 0.0014111 |
training_step_end | 7.1055e-06 |22 | 0.00015632 | 0.00099447 |
on_train_epoch_end | 0.00014509 |1 | 0.00014509 | 0.000923 |
validation_step_end | 9.4492e-06 |8 | 7.5594e-05 | 0.00048091 |
on_epoch_end | 1.1085e-05 |3 | 3.3254e-05 | 0.00021155 |
on_validation_epoch_end | 1.5438e-05 |2 | 3.0875e-05 | 0.00019642 |
on_epoch_start | 9.9017e-06 |3 | 2.9705e-05 | 0.00018897 |
on_fit_start | 1.8213e-05 |1 | 1.8213e-05 | 0.00011587 |
on_validation_epoch_start | 6.951e-06 |2 | 1.3902e-05 | 8.8441e-05 |
on_train_dataloader | 1.0995e-05 |1 | 1.0995e-05 | 6.9947e-05 |
on_before_accelerator_backend_setup | 7.469e-06 |1 | 7.469e-06 | 4.7516e-05 |
on_val_dataloader | 7.282e-06 |1 | 7.282e-06 | 4.6326e-05 |
You can see more options here.
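For example (a hypothetical sketch, depending on your Lightning and Pytorch versions), besides 'simple' there is an 'advanced' profiler with per-function timings, and torch.profiler can be driven manually through the prof argument of our Model, which calls prof.step() at every training step:
# Built-in per-function profiler (illustrative values).
trainer = pl.Trainer(gpus=1, precision=16, max_epochs=1, profiler='advanced')

# Manual torch.profiler driven by the Model's `prof` hook.
import torch.profiler

with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
) as prof:
    model = Model(prof=prof)
    dm = DataModule()
    trainer = pl.Trainer(gpus=1, precision=16, max_epochs=1)
    trainer.fit(model, dm)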
Summary
Optimizing our Pytorch code is very important, and we have many tools and techniques at our disposal to squeeze the most out of our networks. Pytorch Lightning makes our life much easier when it comes to using these techniques transparently, without big changes to our code, whereas in pure Pytorch we would have to dive into the documentation and examples to take advantage of everything we have seen in the latest posts (resulting in code that is very long and hard to understand). Through the Trainer object we can easily define different distributed training strategies, and the profiling options will help us find the weak spots in our code so we can fix them.