julio 7, 2021

~ 13 MIN

Pytorch Distribuido

< Blog RSS

Open In Colab

Pytorch Distribuido

En este post aprenderemos a utilizar algunas de las estrategias que nos ofrece Pytorch para entrenar redes neuronales de manera distribuida.

Empezamos entrenando un clasificador de imágenes en el dataset EuroSAT.

import os
from sklearn.model_selection import train_test_split

def setup(path='./data', test_size=0.2, random_state=42):
    """Index an image-folder dataset and split it into train / val subsets.

    ``path`` is expected to contain one sub-directory per class; every entry
    found inside a class directory is treated as one image of that class.

    Args:
        path: Root directory holding one sub-directory per class.
        test_size: Forwarded to ``train_test_split`` as the validation share.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        A tuple ``(classes, train_images, train_labels, val_images,
        val_labels)`` where ``classes`` is the sorted list of class directory
        names, the image lists hold file paths and the label lists hold the
        index of each image's class in ``classes``.
    """
    classes = sorted(os.listdir(path))

    print("Generating images and labels ...")
    images, encoded = [], []
    for ix, label in enumerate(classes):
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # sample order stable so a fixed random_state always produces the
        # same split.
        _images = sorted(os.listdir(f'{path}/{label}'))
        images += [f'{path}/{label}/{img}' for img in _images]
        encoded += [ix]*len(_images)
    print(f'Number of images: {len(images)}')

    # train / val split, stratified so both subsets keep the class balance
    print("Generating train / val splits ...")
    train_images, val_images, train_labels, val_labels = train_test_split(
        images,
        encoded,
        stratify=encoded,
        test_size=test_size,
        random_state=random_state
    )

    print("Training samples: ", len(train_labels))
    print("Validation samples: ", len(val_labels))

    return classes, train_images, train_labels, val_images, val_labels

# Index the dataset stored under ./data and build the train / val split.
classes, train_images, train_labels, val_images, val_labels = setup('./data')
Generating images and labels ...
Number of images: 27000
Generating train / val splits ...
Training samples:  21600
Validation samples:  5400
import torch
from skimage import io

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that loads one image from disk per sample.

    Each item is a ``(image, label)`` pair of tensors: the image is a float
    tensor with values in [0, 1] and the label is a ``torch.long`` scalar.
    """

    def __init__(self, images, labels):
        # images: sequence of file paths; labels: class index per path.
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, ix):
        raw = io.imread(self.images[ix])
        # Keep channels 3, 2, 1 of the last axis, in that order
        # (presumably the R, G, B bands of the multispectral image - confirm).
        bands = raw[..., (3, 2, 1)]
        # Scale by 4000 and clip to [0, 1]; the divisor is presumably an
        # empirical brightness ceiling for this dataset - confirm.
        scaled = torch.tensor(bands / 4000, dtype=torch.float).clip(0, 1)
        # Move the last (channel) axis to the front.
        img = scaled.permute(2, 0, 1)
        label = torch.tensor(self.labels[ix], dtype=torch.long)
        return img, label

# Wrap each split in a Dataset and expose it through a DataLoader.
batch_size = 1024

ds = dict(
    train=Dataset(train_images, train_labels),
    val=Dataset(val_images, val_labels),
)

dl = {
    split: torch.utils.data.DataLoader(
        ds[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),  # only the training split is shuffled
        num_workers=20,
        pin_memory=True,
    )
    for split in ('train', 'val')
}
import timm

# List the timm model names that match the EfficientNet-B5 pattern.
model_names = timm.list_models('tf_efficientnet_b5*')
model_names
['tf_efficientnet_b5', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b5_ns']
import torch.nn.functional as F
import timm

class Model(torch.nn.Module):
    """Image classifier backed by a pretrained timm ``tf_efficientnet_b5``.

    Args:
        n_outputs: Number of classes, passed to timm as ``num_classes``.
        use_amp: Whether the forward pass runs under CUDA autocast.
    """

    def __init__(self, n_outputs=10, use_amp=True):
        super().__init__()
        self.use_amp = use_amp
        self.model = timm.create_model(
            'tf_efficientnet_b5',
            pretrained=True,
            num_classes=n_outputs,
        )

    def forward(self, x, log=False):
        """Return the wrapped network's output for ``x``.

        When ``log`` is true the shape of ``x`` is printed first, which makes
        it possible to see how large a batch each replica receives.
        """
        if log:
            print(x.shape)
        amp_context = torch.cuda.amp.autocast(enabled=self.use_amp)
        with amp_context:
            logits = self.model(x)
        return logits
from tqdm import tqdm
import numpy as np

def step(model, batch, device):
    """Run one forward pass and score it.

    Args:
        model: Callable mapping an input batch to class logits.
        batch: ``(inputs, targets)`` pair.
        device: Device both tensors are moved to before the forward pass.

    Returns:
        ``(loss, acc)``: the cross-entropy loss tensor and the fraction of
        samples whose arg-max prediction equals the target (a Python float).
    """
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    logits = model(inputs)
    predictions = torch.argmax(logits, axis=1)
    n_correct = (predictions == targets).sum().item()
    return F.cross_entropy(logits, targets), n_correct / targets.size(0)

def train_amp(model, dl, optimizer, epochs=10, device="cpu", use_amp = True, prof=None, end=0):
    """Train and validate ``model`` with automatic mixed precision.

    Args:
        model: Module to train; it is moved to ``device`` first.
        dl: Dict with ``'train'`` and ``'val'`` iterables of batches.
        optimizer: Optimizer bound to the model's parameters.
        epochs: Number of epochs to run.
        device: Device the model and every batch are moved to.
        use_amp: Enables autocast and gradient scaling.
        prof: Optional profiler-like object; when truthy, ``prof.step()`` is
            called after each training batch.
        end: Only used with ``prof``: training stops at the first batch whose
            index is >= ``end``.

    Returns:
        Dict with the per-epoch mean of ``loss``, ``acc``, ``val_loss`` and
        ``val_acc``. If profiling stops the run early, the interrupted epoch
        contributes training entries but no validation entries.
    """
    model.to(device)
    hist = {key: [] for key in ('loss', 'acc', 'val_loss', 'val_acc')}
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for epoch in range(1, epochs + 1):
        # ---- training pass ----
        model.train()
        losses, accs = [], []
        progress = tqdm(dl['train'])
        interrupted = False
        for batch_idx, batch in enumerate(progress):
            optimizer.zero_grad()

            # forward under autocast, backward on the scaled loss
            with torch.cuda.amp.autocast(enabled=use_amp):
                loss, acc = step(model, batch, device)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            losses.append(loss.item())
            accs.append(acc)
            progress.set_description(f"training... loss {np.mean(losses):.4f} acc {np.mean(accs):.4f}")

            # profiling: stop once the requested number of batches was seen
            if not prof:
                continue
            if batch_idx >= end:
                interrupted = True
                break
            prof.step()

        hist['loss'].append(np.mean(losses))
        hist['acc'].append(np.mean(accs))
        if interrupted:
            break

        # ---- validation pass ----
        model.eval()
        losses, accs = [], []
        progress = tqdm(dl['val'])
        with torch.no_grad():
            for batch in progress:
                loss, acc = step(model, batch, device)
                losses.append(loss.item())
                accs.append(acc)
                progress.set_description(f"evluating... loss {np.mean(losses):.4f} acc {np.mean(accs):.4f}")
        hist['val_loss'].append(np.mean(losses))
        hist['val_acc'].append(np.mean(accs))

        # ---- epoch summary ----
        summary = ''.join(f' {k} {v[-1]:.4f}' for k, v in hist.items())
        print(f'Epoch {epoch}/{epochs}' + summary)

    return hist
# Baseline: train the classifier for 3 epochs on the default CUDA device.
model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
hist = train_amp(model, dl, optimizer, epochs=3, device="cuda")
training... loss 1.9256 acc 0.6465: 100%|██████████| 22/22 [00:11<00:00,  1.93it/s]
evluating... loss 3.9619 acc 0.4316: 100%|██████████| 6/6 [00:01<00:00,  3.24it/s]
  0%|          | 0/22 [00:00<?, ?it/s]

Epoch 1/3 loss 1.9256 acc 0.6465 val_loss 3.9619 val_acc 0.4316


training... loss 0.1988 acc 0.9366: 100%|██████████| 22/22 [00:10<00:00,  2.04it/s]
evluating... loss 0.4834 acc 0.8799: 100%|██████████| 6/6 [00:01<00:00,  3.19it/s]
  0%|          | 0/22 [00:00<?, ?it/s]

Epoch 2/3 loss 0.1988 acc 0.9366 val_loss 0.4834 val_acc 0.8799


training... loss 0.0702 acc 0.9807: 100%|██████████| 22/22 [00:10<00:00,  2.06it/s]
evluating... loss 0.1988 acc 0.9450: 100%|██████████| 6/6 [00:01<00:00,  3.10it/s]

Epoch 3/3 loss 0.0702 acc 0.9807 val_loss 0.1988 val_acc 0.9450

Usando las técnicas aprendidas en posts anteriores conseguimos entrenar nuestro modelo a unos 12 segundos por epoch. Nada mal para tratarse de una EfficientNetB5. A partir de aquí, las siguientes mejoras van a consistir en entrenar nuestro modelo de manera distribuida, es decir, usando varias GPUs. Vamos a ver algunos ejemplos.

Data Parallel

Esta estrategia consiste en copiar el modelo en cada una de las GPUs disponibles y dividir el batch entre ellas. Si tenemos 2 GPUs cada una verá la mitad de un batch, y por consiguiente deberíamos poder entrenar el doble de rápido (en teoría; en la práctica no es así, ya que al introducir el entrenamiento distribuido se requieren nuevas operaciones, copias entre GPUs y sincronizaciones que necesitan sus propios recursos).

# Build a fresh model and, when more than one GPU is visible, wrap it in
# DataParallel.
model = Model()
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  model = torch.nn.DataParallel(model)
Let's use 2 GPUs!
# Move the (possibly wrapped) model to the GPU.
model.cuda()

# each GPU receives half of the batch!
# log=True makes every replica print the shape of the input it was given.
output = model(torch.randn(32, 3, 32, 32).cuda(), log=True)

output.size()
torch.Size([16, 3, 32, 32])torch.Size([16, 3, 32, 32])






torch.Size([32, 10])

Para aprovechar al máximo esta estrategia, multiplica el tamaño del batch óptimo por el número de GPUs disponibles.

# Two GPUs share every batch, so use twice the single-GPU batch size.
batch_size = 2 * 1024

dl = {
    split: torch.utils.data.DataLoader(
        ds[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),  # only the training split is shuffled
        num_workers=20,
        pin_memory=True,
    )
    for split in ('train', 'val')
}

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
hist = train_amp(model, dl, optimizer, epochs=3, device="cuda")
training... loss 1.9569 acc 0.6176: 100%|██████████| 11/11 [00:07<00:00,  1.52it/s]
evluating... loss 6.6081 acc 0.2944: 100%|██████████| 3/3 [00:02<00:00,  1.26it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/3 loss 1.9569 acc 0.6176 val_loss 6.6081 val_acc 0.2944


training... loss 0.2553 acc 0.9194: 100%|██████████| 11/11 [00:07<00:00,  1.51it/s]
evluating... loss 7.7305 acc 0.4626: 100%|██████████| 3/3 [00:02<00:00,  1.31it/s]
  0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/3 loss 0.2553 acc 0.9194 val_loss 7.7305 val_acc 0.4626


training... loss 0.0791 acc 0.9732: 100%|██████████| 11/11 [00:07<00:00,  1.41it/s]
evluating... loss 1.7874 acc 0.6995: 100%|██████████| 3/3 [00:02<00:00,  1.27it/s]

Epoch 3/3 loss 0.0791 acc 0.9732 val_loss 1.7874 val_acc 0.6995

Perfecto, hemos conseguido reducir en un 25% el tiempo por epoch. Por contra, al usar un batch size más grande, la convergencia es más lenta (se podría ajustar el learning rate para compensar). Como puedes ver esta estrategia no es del todo óptima ya que no hemos conseguido doblar la velocidad. Esto es debido al funcionamiento de la estrategia Data Parallel, que requiere varias copias de datos y sincronización entre GPUs. Podemos mejorarlo con la siguiente estrategia.

Distributed Data Parallel

En esta estrategia vamos a considerar cada GPU como un proceso independiente, lo cual evitará muchas copias de datos y sincronizaciones pesadas. Igual que antes, cada GPU tendrá una copia del modelo y verá una parte proporcional del dataset. Por ejemplo, si tenemos dos GPUs cada una verá la mitad de los datos. Debido al funcionamiento de esta estrategia no puede ejecutarse en un notebook, por lo que se implementa en el siguiente script.

Model Parallel

Por último vamos a ver otra estrategia para entrenar modelos en varias GPUs. En este caso, será el propio modelo el que esté distribuido, algunas capas estarán en una GPU mientras que otras vivirán en otra GPU. Para ello tendremos que especificar en qué GPU vive cada parte del modelo y comunicar los tensores correspondientes.

class ModelParallel(torch.nn.Module):
    """Pretrained timm ``resnet50`` split into two halves on two devices.

    The network's top-level children are cut at index 6: the first six go to
    ``gpu1`` and the remaining ones (including the classifier head) go to
    ``gpu2``.

    Args:
        gpu1: Device that hosts the first half of the network.
        gpu2: Device that hosts the second half; outputs live here.
        n_outputs: Number of classes, passed to timm as ``num_classes``.
        use_amp: Whether the forward pass runs under CUDA autocast.
    """

    def __init__(self, gpu1, gpu2, n_outputs=10, use_amp=True):
        super().__init__()
        resnet = timm.create_model('resnet50', pretrained=True, num_classes=n_outputs)
        stages = list(resnet.children())
        self.backbone1 = torch.nn.Sequential(*stages[:6]).to(gpu1)
        self.backbone2 = torch.nn.Sequential(*stages[6:]).to(gpu2)
        self.use_amp = use_amp
        self.gpu1, self.gpu2 = gpu1, gpu2

    def forward(self, x):
        """Run ``x`` through both halves, moving tensors between devices."""
        with torch.cuda.amp.autocast(enabled=self.use_amp):
            # The input goes to the first device; the intermediate
            # activations are handed over to the second one.
            hidden = self.backbone1(x.to(self.gpu1))
            return self.backbone2(hidden.to(self.gpu2))
# The two devices that will each host one half of the network.
gpu1 = torch.device('cuda:0')
gpu2 = torch.device('cuda:1')

modelParallel = ModelParallel(gpu1, gpu2)

# The input is created on the CPU; the model's forward moves it to gpu1.
output = modelParallel(torch.randn(32, 3, 32, 32))

output.size()
/home/juan/miniconda3/lib/python3.9/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448255797/work/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)





torch.Size([32, 10])
# Show the layers placed on the first GPU.
modelParallel.backbone1
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
  )
  (5): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
  )
)
# Show the layers placed on the second GPU.
modelParallel.backbone2
Sequential(
  (0): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
    (4): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
    (5): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
  )
  (1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
    )
  )
  (2): SelectAdaptivePool2d (pool_type=avg, flatten=True)
  (3): Linear(in_features=2048, out_features=10, bias=True)
)
def step_mp(model, batch, device):
    """Run one forward pass of a model-parallel model and score it.

    Unlike ``step``, the inputs are handed to the model untouched; only the
    targets are moved to ``device``.

    Args:
        model: Callable mapping an input batch to class logits.
        batch: ``(inputs, targets)`` pair.
        device: Device the targets are moved to.

    Returns:
        ``(loss, acc)``: the cross-entropy loss tensor and the fraction of
        samples whose arg-max prediction equals the target (a Python float).
    """
    inputs, targets = batch
    targets = targets.to(device)
    logits = model(inputs)
    predictions = torch.argmax(logits, axis=1)
    n_correct = (predictions == targets).sum().item()
    return F.cross_entropy(logits, targets), n_correct / targets.size(0)

def train_mp(model, dl, optimizer, device, epochs=10, use_amp = True, prof=None, end=0):
    """Train and validate a model-parallel ``model`` with mixed precision.

    The model is expected to place its own inputs on the right device (as
    ``ModelParallel.forward`` does), so batches are scored with ``step_mp``,
    which only moves the targets. This avoids first copying every input batch
    to ``device`` and then copying it again to the device of the first stage.

    Args:
        model: Model-parallel module, already distributed across devices.
        dl: Dict with ``'train'`` and ``'val'`` iterables of batches.
        optimizer: Optimizer bound to the model's parameters.
        device: Device the targets are moved to; it must be the device the
            model's outputs live on.
        epochs: Number of epochs to run.
        use_amp: Enables autocast and gradient scaling.
        prof: Optional profiler-like object; when truthy, ``prof.step()`` is
            called after each training batch.
        end: Only used with ``prof``: training stops at the first batch whose
            index is >= ``end``.

    Returns:
        Dict with the per-epoch mean of ``loss``, ``acc``, ``val_loss`` and
        ``val_acc``. If profiling stops the run early, the interrupted epoch
        contributes training entries but no validation entries.
    """
    hist = {'loss': [], 'acc': [], 'val_loss': [], 'val_acc': []}
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    for e in range(1, epochs+1):
        # train
        model.train()
        l, a = [], []
        bar = tqdm(dl['train'])
        stop=False
        for batch_idx, batch in enumerate(bar):
            optimizer.zero_grad()

            # AMP
            with torch.cuda.amp.autocast(enabled=use_amp):
                loss, acc = step_mp(model, batch, device)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            l.append(loss.item())
            a.append(acc)
            bar.set_description(f"training... loss {np.mean(l):.4f} acc {np.mean(a):.4f}")
            # profiling
            if prof:
                if batch_idx >= end:
                    stop = True
                    break
                prof.step()
        hist['loss'].append(np.mean(l))
        hist['acc'].append(np.mean(a))
        if stop:
            break
        # eval
        model.eval()
        l, a = [], []
        bar = tqdm(dl['val'])
        with torch.no_grad():
            for batch in bar:
                loss, acc = step_mp(model, batch, device)
                l.append(loss.item())
                a.append(acc)
                bar.set_description(f"evluating... loss {np.mean(l):.4f} acc {np.mean(a):.4f}")
        hist['val_loss'].append(np.mean(l))
        hist['val_acc'].append(np.mean(a))
        # log
        log = f'Epoch {e}/{epochs}'
        for k, v in hist.items():
            log += f' {k} {v[-1]:.4f}'
        print(log)

    return hist
# Model parallelism does not split the batch, so go back to the
# single-GPU batch size.
batch_size = 1024

dl = {
    split: torch.utils.data.DataLoader(
        ds[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),  # only the training split is shuffled
        num_workers=20,
        pin_memory=True,
    )
    for split in ('train', 'val')
}

modelParallel = ModelParallel(gpu1, gpu2)
optimizer = torch.optim.Adam(modelParallel.parameters(), lr=1e-3)
# gpu2 hosts the classifier head, so it is the device passed to train_mp.
hist = train_mp(modelParallel, dl, optimizer, gpu2, epochs=3)
training... loss 0.4347 acc 0.8791: 100%|██████████| 22/22 [00:05<00:00,  3.99it/s]
evluating... loss 4.1624 acc 0.4927: 100%|██████████| 6/6 [00:01<00:00,  3.63it/s]
  0%|          | 0/22 [00:00<?, ?it/s]

Epoch 1/3 loss 0.4347 acc 0.8791 val_loss 4.1624 val_acc 0.4927


training... loss 0.1147 acc 0.9665: 100%|██████████| 22/22 [00:05<00:00,  3.92it/s]
evluating... loss 0.7813 acc 0.8133: 100%|██████████| 6/6 [00:01<00:00,  3.53it/s]
  0%|          | 0/22 [00:00<?, ?it/s]

Epoch 2/3 loss 0.1147 acc 0.9665 val_loss 0.7813 val_acc 0.8133


training... loss 0.0705 acc 0.9786: 100%|██████████| 22/22 [00:05<00:00,  3.97it/s]
evluating... loss 0.3387 acc 0.9078: 100%|██████████| 6/6 [00:01<00:00,  3.50it/s]

Epoch 3/3 loss 0.0705 acc 0.9786 val_loss 0.3387 val_acc 0.9078

Se puede combinar con DDP, pero en este caso ya necesitaremos varias máquinas con varias GPUs cada una, dificultando también su correcta implementación. Por suerte, la librería Pytorch Lightning implementa todas estas estrategias de manera transparente para que podamos usarlas sin apenas cambios en nuestro código, pero eso lo veremos en el siguiente post.

< Blog RSS