September 12, 2020

~ 14 MIN

Transfer Learning in Convolutional Neural Networks




In previous posts we introduced the convolutional neural network architecture and presented several famous architectures that have shown good performance across many tasks. These networks are made up of many convolutional layers (some have more than 100), which means they have a lot of parameters and training them from scratch can be expensive. However, there is a technique that allows us to obtain good models with fewer requirements: transfer learning. We have already talked about this technique in the context of language models, and the idea is the same: we reuse as many layers as possible from a network already trained on another dataset, and we only train the new layers we need for our specific task.

In this post we will see how to take a neural network pretrained on ImageNet and adapt it to a new classification task with a small dataset.
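In a nutshell, the recipe looks like this. The snippet below is a minimal sketch of the pattern we will build step by step throughout the post: download a pretrained network, freeze its weights, and replace the classification head.

import torch
import torchvision

# download a resnet pretrained on ImageNet
model = torchvision.models.resnet18(pretrained=True)
# freeze the pretrained weights so they are not updated during training
for param in model.parameters():
    param.requires_grad = False
# replace the classification head with a new, trainable layer (5 classes)
model.fc = torch.nn.Linear(model.fc.in_features, 5)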

The dataset

Our goal will be to train a flower classifier. We can download the images from the following url.

import wget 

wget.download('https://mymldatasets.s3.eu-de.cloud-object-storage.appdomain.cloud/flowers.zip')
100% [..........................................] 235962103 / 235962103

'flowers (2).zip'
import zipfile

with zipfile.ZipFile('flowers.zip', 'r') as zip_ref:
    zip_ref.extractall('.')

Once the dataset is extracted, we can see that we have 5 different flower classes, organized into 5 different folders. Each folder contains several examples of flowers of the corresponding category.
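After unzipping, the directory layout looks like this:

flowers/
├── daisy/
├── dandelion/
├── rose/
├── sunflower/
└── tulip/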

import os 

PATH = 'flowers'

classes = os.listdir(PATH)
classes
['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']
imgs, labels = [], []

for i, lab in enumerate(classes):
  paths = os.listdir(f'{PATH}/{lab}')
  print(f'Category: {lab}. Images: {len(paths)}')
  # keep only the jpg files
  paths = [p for p in paths if p[-3:] == "jpg"]
  imgs += [f'{PATH}/{lab}/{img}' for img in paths]
  labels += [i]*len(paths)
Category: daisy. Images: 769
Category: dandelion. Images: 1055
Category: rose. Images: 784
Category: sunflower. Images: 734
Category: tulip. Images: 984

Let's visualize some images from the dataset.

import random 
from skimage import io
import matplotlib.pyplot as plt

fig, axs = plt.subplots(3,5, figsize=(10,6))
for _ax in axs:
  for ax in _ax:
    ix = random.randint(0, len(imgs)-1)
    img = io.imread(imgs[ix])
    ax.imshow(img)
    ax.axis('off')
    ax.set_title(classes[labels[ix]])
plt.show()

[figure: 3×5 grid of random sample images with their class labels]

Let's also create a test split so we can compare different models.

from sklearn.model_selection import train_test_split

train_imgs, test_imgs, train_labels, test_labels = train_test_split(imgs, labels, test_size=0.2, stratify=labels)

len(train_imgs), len(test_imgs)
(3458, 865)

Finally, we create our Dataset and DataLoader objects to feed the images to our models.

import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

class Dataset(torch.utils.data.Dataset):
  def __init__(self, X, y, trans, device):
    self.X = X
    self.y = y
    self.trans = trans
    self.device = device

  def __len__(self):
    return len(self.X)

  def __getitem__(self, ix):
    # load the image
    img = io.imread(self.X[ix])
    # apply transformations
    if self.trans:
      img = self.trans(image=img)["image"]
    return torch.from_numpy(img / 255.).float().permute(2,0,1), torch.tensor(self.y[ix])

We make sure that all the images in the dataset have the same dimensions: 224x224 pixels.

import albumentations as A

trans = A.Compose([
    A.Resize(224, 224)
])

dataset = {
    'train': Dataset(train_imgs, train_labels, trans, device), 
    'test': Dataset(test_imgs, test_labels, trans, device)
}

len(dataset['train']), len(dataset['test'])
(3458, 865)
fig, axs = plt.subplots(3,5, figsize=(10,6))
for _ax in axs:
  for ax in _ax:
    ix = random.randint(0, len(dataset['train'])-1)
    img, lab = dataset['train'][ix]
    ax.imshow(img.permute(1,2,0))
    ax.axis('off')
    ax.set_title(classes[lab])
plt.show()

[figure: 3×5 grid of resized (224x224) training samples with their class labels]

dataloader = {
    'train': torch.utils.data.DataLoader(dataset['train'], batch_size=64, shuffle=True, pin_memory=True), 
    'test': torch.utils.data.DataLoader(dataset['test'], batch_size=256, shuffle=False)
}

imgs, labels = next(iter(dataloader['train']))
imgs.shape
torch.Size([64, 3, 224, 224])

The Model

We will use the resnet architecture, which we discussed in the previous post, to build our classifier. We will keep all of its layers except the last one, which we will replace with a new linear layer that performs the classification into 5 classes.

import torchvision

resnet = torchvision.models.resnet18()
resnet
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
class Model(torch.nn.Module):
  def __init__(self, n_outputs=5, pretrained=False, freeze=False):
    super().__init__()
    # download resnet
    resnet = torchvision.models.resnet18(pretrained=pretrained)
    # keep all the layers except the last one
    self.resnet = torch.nn.Sequential(*list(resnet.children())[:-1])
    if freeze:
      for param in self.resnet.parameters():
        param.requires_grad=False
    # add a new linear layer to carry out the classification
    self.fc = torch.nn.Linear(512, n_outputs)

  def forward(self, x):
    x = self.resnet(x)
    x = x.view(x.shape[0], -1)
    x = self.fc(x)
    return x

  def unfreeze(self):
    for param in self.resnet.parameters():
        param.requires_grad=True
model = Model()
outputs = model(torch.randn(64, 3, 224, 224))
outputs.shape
torch.Size([64, 5])
from tqdm import tqdm
import numpy as np

def fit(model, dataloader, epochs=5, lr=1e-2):
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(1, epochs+1):
        # training loop over the train set
        model.train()
        train_loss, train_acc = [], []
        bar = tqdm(dataloader['train'])
        for batch in bar:
            X, y = batch
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_hat = model(X)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            acc = (y == torch.argmax(y_hat, axis=1)).sum().item() / len(y)
            train_acc.append(acc)
            bar.set_description(f"loss {np.mean(train_loss):.5f} acc {np.mean(train_acc):.5f}")
        # evaluation loop over the test set
        bar = tqdm(dataloader['test'])
        val_loss, val_acc = [], []
        model.eval()
        with torch.no_grad():
            for batch in bar:
                X, y = batch
                X, y = X.to(device), y.to(device)
                y_hat = model(X)
                loss = criterion(y_hat, y)
                val_loss.append(loss.item())
                acc = (y == torch.argmax(y_hat, axis=1)).sum().item() / len(y)
                val_acc.append(acc)
                bar.set_description(f"val_loss {np.mean(val_loss):.5f} val_acc {np.mean(val_acc):.5f}")
        print(f"Epoch {epoch}/{epochs} loss {np.mean(train_loss):.5f} val_loss {np.mean(val_loss):.5f} acc {np.mean(train_acc):.5f} val_acc {np.mean(val_acc):.5f}")

Training from scratch

First, let's train our model from scratch to see what metrics we can get.

model = Model()
fit(model, dataloader, epochs=15)
loss 1.38035 acc 0.41136: 100%|████████| 55/55 [00:20<00:00,  2.72it/s]
val_loss 6.62556 val_acc 0.18881: 100%|██| 4/4 [00:03<00:00,  1.18it/s]
Epoch 1/15 loss 1.38035 val_loss 6.62556 acc 0.41136 val_acc 0.18881

loss 1.14277 acc 0.53977: 100%|████████| 55/55 [00:20<00:00,  2.73it/s]
val_loss 2.67887 val_acc 0.32263: 100%|██| 4/4 [00:03<00:00,  1.18it/s]
Epoch 2/15 loss 1.14277 val_loss 2.67887 acc 0.53977 val_acc 0.32263

loss 1.03443 acc 0.57926: 100%|████████| 55/55 [00:20<00:00,  2.70it/s]
val_loss 2.11420 val_acc 0.35868: 100%|██| 4/4 [00:03<00:00,  1.16it/s]
Epoch 3/15 loss 1.03443 val_loss 2.11420 acc 0.57926 val_acc 0.35868

loss 0.96419 acc 0.61534: 100%|████████| 55/55 [00:20<00:00,  2.71it/s]
val_loss 3.34853 val_acc 0.36685: 100%|██| 4/4 [00:03<00:00,  1.14it/s]
Epoch 4/15 loss 0.96419 val_loss 3.34853 acc 0.61534 val_acc 0.36685

loss 0.89560 acc 0.65142: 100%|████████| 55/55 [00:20<00:00,  2.65it/s]
val_loss 2.53937 val_acc 0.35537: 100%|██| 4/4 [00:03<00:00,  1.15it/s]
Epoch 5/15 loss 0.89560 val_loss 2.53937 acc 0.65142 val_acc 0.35537

loss 0.85376 acc 0.67159: 100%|████████| 55/55 [00:20<00:00,  2.65it/s]
val_loss 7.21106 val_acc 0.23745: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 6/15 loss 0.85376 val_loss 7.21106 acc 0.67159 val_acc 0.23745

loss 0.82979 acc 0.67812: 100%|████████| 55/55 [00:20<00:00,  2.71it/s]
val_loss 2.87065 val_acc 0.28967: 100%|██| 4/4 [00:03<00:00,  1.13it/s]
Epoch 7/15 loss 0.82979 val_loss 2.87065 acc 0.67812 val_acc 0.28967

loss 0.84172 acc 0.66989: 100%|████████| 55/55 [00:20<00:00,  2.65it/s]
val_loss 10.45654 val_acc 0.23889: 100%|█| 4/4 [00:03<00:00,  1.15it/s]
Epoch 8/15 loss 0.84172 val_loss 10.45654 acc 0.66989 val_acc 0.23889

loss 0.77135 acc 0.68920: 100%|████████| 55/55 [00:20<00:00,  2.65it/s]
val_loss 3.13403 val_acc 0.40119: 100%|██| 4/4 [00:03<00:00,  1.15it/s]
Epoch 9/15 loss 0.77135 val_loss 3.13403 acc 0.68920 val_acc 0.40119

loss 0.76101 acc 0.70426: 100%|████████| 55/55 [00:20<00:00,  2.65it/s]
val_loss 4.79418 val_acc 0.28006: 100%|██| 4/4 [00:03<00:00,  1.15it/s]
Epoch 10/15 loss 0.76101 val_loss 4.79418 acc 0.70426 val_acc 0.28006

loss 0.69798 acc 0.73153: 100%|████████| 55/55 [00:20<00:00,  2.65it/s]
val_loss 1.26967 val_acc 0.52445: 100%|██| 4/4 [00:03<00:00,  1.15it/s]
Epoch 11/15 loss 0.69798 val_loss 1.26967 acc 0.73153 val_acc 0.52445

loss 0.66299 acc 0.74205: 100%|████████| 55/55 [00:20<00:00,  2.65it/s]
val_loss 3.41976 val_acc 0.29100: 100%|██| 4/4 [00:03<00:00,  1.15it/s]
Epoch 12/15 loss 0.66299 val_loss 3.41976 acc 0.74205 val_acc 0.29100

loss 0.65538 acc 0.75767: 100%|████████| 55/55 [00:20<00:00,  2.69it/s]
val_loss 2.28475 val_acc 0.42954: 100%|██| 4/4 [00:03<00:00,  1.16it/s]
Epoch 13/15 loss 0.65538 val_loss 2.28475 acc 0.75767 val_acc 0.42954

loss 0.60125 acc 0.77926: 100%|████████| 55/55 [00:20<00:00,  2.69it/s]
val_loss 3.14427 val_acc 0.25994: 100%|██| 4/4 [00:03<00:00,  1.16it/s]
Epoch 14/15 loss 0.60125 val_loss 3.14427 acc 0.77926 val_acc 0.25994

loss 0.60031 acc 0.79432: 100%|████████| 55/55 [00:20<00:00,  2.71it/s]
val_loss 6.78454 val_acc 0.27369: 100%|██| 4/4 [00:03<00:00,  1.15it/s]
Epoch 15/15 loss 0.60031 val_loss 6.78454 acc 0.79432 val_acc 0.27369

As you can see, it is hard to get good metrics: with such a small dataset, the training accuracy keeps improving while the validation metrics stay poor and unstable, a sign that the model is overfitting.

Transfer Learning

Now let's train the same model, but this time using resnet's pretrained weights.

model = Model(pretrained=True, freeze=True)
fit(model, dataloader)
loss 1.05350 acc 0.63040: 100%|████████| 55/55 [00:13<00:00,  4.09it/s]
val_loss 1.09239 val_acc 0.58929: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 1/5 loss 1.05350 val_loss 1.09239 acc 0.63040 val_acc 0.58929

loss 0.64907 acc 0.81222: 100%|████████| 55/55 [00:13<00:00,  4.10it/s]
val_loss 0.80763 val_acc 0.69506: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 2/5 loss 0.64907 val_loss 0.80763 acc 0.81222 val_acc 0.69506

loss 0.54364 acc 0.83210: 100%|████████| 55/55 [00:13<00:00,  4.10it/s]
val_loss 0.73235 val_acc 0.73022: 100%|██| 4/4 [00:03<00:00,  1.18it/s]
Epoch 3/5 loss 0.54364 val_loss 0.73235 acc 0.83210 val_acc 0.73022

loss 0.47409 acc 0.85625: 100%|████████| 55/55 [00:13<00:00,  4.09it/s]
val_loss 0.53516 val_acc 0.82462: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 4/5 loss 0.47409 val_loss 0.53516 acc 0.85625 val_acc 0.82462

loss 0.47282 acc 0.84574: 100%|████████| 55/55 [00:13<00:00,  4.10it/s]
val_loss 0.58925 val_acc 0.80056: 100%|██| 4/4 [00:03<00:00,  1.15it/s]
Epoch 5/5 loss 0.47282 val_loss 0.58925 acc 0.84574 val_acc 0.80056

As you can see, not only do we get a better model in fewer epochs, but each epoch also takes less time to complete. This is because, since most of the network is frozen, we don't need to compute or apply gradients for it, which reduces the computational requirements considerably. Better models, trained faster.
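We can check this by counting trainable parameters. A quick sketch, using the Model class defined above (trainable_params is just a small helper for the comparison):

# compare how many parameters actually receive gradient updates
def trainable_params(m):
  return sum(p.numel() for p in m.parameters() if p.requires_grad)

frozen = Model(pretrained=True, freeze=True)
full = Model(pretrained=True, freeze=False)

print(trainable_params(frozen))  # only the new head: 512*5 + 5 = 2,565
print(trainable_params(full))    # the whole network: ~11.2M parameters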

Fine Tuning

We can still improve a little further if, besides using the ImageNet weights downloaded for resnet, we also train the entire network.

model = Model(pretrained=True, freeze=False)
fit(model, dataloader)
loss 0.78486 acc 0.75256: 100%|████████| 55/55 [00:20<00:00,  2.69it/s]
val_loss 0.66191 val_acc 0.73296: 100%|██| 4/4 [00:03<00:00,  1.18it/s]
Epoch 1/5 loss 0.78486 val_loss 0.66191 acc 0.75256 val_acc 0.73296

loss 0.35873 acc 0.88409: 100%|████████| 55/55 [00:20<00:00,  2.72it/s]
val_loss 0.47476 val_acc 0.82364: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 2/5 loss 0.35873 val_loss 0.47476 acc 0.88409 val_acc 0.82364

loss 0.24850 acc 0.92500: 100%|████████| 55/55 [00:20<00:00,  2.72it/s]
val_loss 0.39330 val_acc 0.84806: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 3/5 loss 0.24850 val_loss 0.39330 acc 0.92500 val_acc 0.84806

loss 0.20048 acc 0.94517: 100%|████████| 55/55 [00:20<00:00,  2.72it/s]
val_loss 0.51556 val_acc 0.82603: 100%|██| 4/4 [00:03<00:00,  1.18it/s]
Epoch 4/5 loss 0.20048 val_loss 0.51556 acc 0.94517 val_acc 0.82603

loss 0.20027 acc 0.94205: 100%|████████| 55/55 [00:20<00:00,  2.72it/s]
val_loss 0.67680 val_acc 0.77118: 100%|██| 4/4 [00:03<00:00,  1.18it/s]
Epoch 5/5 loss 0.20027 val_loss 0.67680 acc 0.94205 val_acc 0.77118

It is common to first train the model for several epochs with the pretrained network frozen, and then continue training while also allowing the weights of the pretrained network to be updated (usually with a smaller learning rate).

model = Model(pretrained=True, freeze=True)
fit(model, dataloader)
model.unfreeze()
fit(model, dataloader, lr=1e-4)
loss 1.04522 acc 0.64773: 100%|████████| 55/55 [00:13<00:00,  4.11it/s]
val_loss 1.00881 val_acc 0.57071: 100%|██| 4/4 [00:03<00:00,  1.18it/s]
Epoch 1/5 loss 1.04522 val_loss 1.00881 acc 0.64773 val_acc 0.57071

loss 0.64234 acc 0.80511: 100%|████████| 55/55 [00:13<00:00,  4.11it/s]
val_loss 0.85559 val_acc 0.66237: 100%|██| 4/4 [00:03<00:00,  1.18it/s]
Epoch 2/5 loss 0.64234 val_loss 0.85559 acc 0.80511 val_acc 0.66237

loss 0.54557 acc 0.82699: 100%|████████| 55/55 [00:13<00:00,  4.10it/s]
val_loss 0.74277 val_acc 0.71299: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 3/5 loss 0.54557 val_loss 0.74277 acc 0.82699 val_acc 0.71299

loss 0.49066 acc 0.84574: 100%|████████| 55/55 [00:13<00:00,  4.09it/s]
val_loss 0.63627 val_acc 0.75607: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 4/5 loss 0.49066 val_loss 0.63627 acc 0.84574 val_acc 0.75607

loss 0.44511 acc 0.85682: 100%|████████| 55/55 [00:13<00:00,  4.09it/s]
val_loss 0.56260 val_acc 0.79451: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 5/5 loss 0.44511 val_loss 0.56260 acc 0.85682 val_acc 0.79451

loss 0.49753 acc 0.81023: 100%|████████| 55/55 [00:20<00:00,  2.71it/s]
val_loss 0.52436 val_acc 0.81814: 100%|██| 4/4 [00:03<00:00,  1.18it/s]
Epoch 1/5 loss 0.49753 val_loss 0.52436 acc 0.81023 val_acc 0.81814

loss 0.48127 acc 0.83068: 100%|████████| 55/55 [00:20<00:00,  2.71it/s]
val_loss 0.48222 val_acc 0.82790: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 2/5 loss 0.48127 val_loss 0.48222 acc 0.83068 val_acc 0.82790

loss 0.44676 acc 0.84176: 100%|████████| 55/55 [00:20<00:00,  2.71it/s]
val_loss 0.46810 val_acc 0.84220: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 3/5 loss 0.44676 val_loss 0.46810 acc 0.84176 val_acc 0.84220

loss 0.43173 acc 0.86335: 100%|████████| 55/55 [00:20<00:00,  2.72it/s]
val_loss 0.45196 val_acc 0.84149: 100%|██| 4/4 [00:03<00:00,  1.17it/s]
Epoch 4/5 loss 0.43173 val_loss 0.45196 acc 0.86335 val_acc 0.84149

loss 0.41360 acc 0.86250: 100%|████████| 55/55 [00:20<00:00,  2.72it/s]
val_loss 0.45011 val_acc 0.84540: 100%|██| 4/4 [00:03<00:00,  1.15it/s]
Epoch 5/5 loss 0.41360 val_loss 0.45011 acc 0.86250 val_acc 0.84540

Another fine-tuning alternative is to train the model with different learning rates: one for the pretrained network and another for the new layers.

optimizer = torch.optim.Adam([
    {'params': model.resnet.parameters(), 'lr': 1e-4},  # pretrained backbone
    {'params': model.fc.parameters(), 'lr': 1e-3}       # new classification head
])
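Note that our fit function builds its own optimizer internally, so using these parameter groups requires a small variant that accepts an external one. A minimal sketch (fit_with_optimizer is a hypothetical helper, shown without the metric logging of fit):

def fit_with_optimizer(model, dataloader, optimizer, epochs=5):
    # same training loop as in fit, but with an optimizer built outside
    model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(1, epochs+1):
        model.train()
        for X, y in dataloader['train']:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            optimizer.step()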

Summary

In this post we have seen how to apply transfer learning with convolutional neural networks. This technique allows us to obtain better models with lower computational requirements and with small datasets. We can download a network pretrained on another dataset (ideally, one similar to ours) and reuse as many of its layers as possible. We can freeze the pretrained network, so that its weights are not updated during training, and use it only as a feature extractor whose output the new layers (which we do train) can take advantage of. Even so, fine tuning (continuing to train the pretrained network) can yield a better model. Transfer learning is a very powerful technique that we should exploit whenever we can to reduce the computational requirements of our models.
