Автор: Александр Гончаренко

Редактура: Марк Страхов

Stohastic Weight Averaging

Давайте для начала разберёмся с методом FGE (**Fast Geometric Ensembling)**. Его суть заключается в использовании циклического learning rate scheduler на последних эпохах и сохранении веса каждый раз, когда learning rate достигает минимума.

Untitled

Затем мы создаем ансамбль из моделей с сохраненными весами. Все эти модели будут иметь различные веса, но примерно один loss. Такая методика позволяет нам быстро построить ансамбль, метрики у которого лучше, чем у отдельной модели (как показывает статья). У данного подхода есть и недостаток: время предикта у ансамбля кратно больше, чем у отдельной нейросети.

Более быстрым методом являетcя SWA (Stoсhastic Weight Averaging). Он предлагает сделать то же самое, что и FGE, но вместо создания ансамбля — использует модель, в которой веса будут усреднены. Важно отметить: мы должны также, как и в FGE, использовать либо циклический learning rate scheduler, либо любой другой, который не просто двигается в сторону локального минимума, а позволяет найти разные точки около него. Этот метод дает примерно такое же качество, как и FGE, только предсказания происходят кратно быстрее (так как у нас будет не ансамбль из $n$ моделей, а одна модель, соответственно и быстрее в $n$ раз). На рисунке ниже изображены веса, полученные при помощи FGE($w_1,w_2,w_3$) и SWA ($w_{swa}$).

Untitled

Как использовать?

Этот алгоритм очень легко добавить к вашему циклу обучения, если вы используете pytorch lightning. Вам нужно лишь передать callback в trainer.

Важные параметры:

swa_lrs — какой learning rate использовать. Если поставить None, то будет использоваться learning rate оптимизатора;

swa_epoch_start — с какой эпохи начинать сохранять веса.

swa_callback = StochasticWeightAveraging(swa_lrs=5e-4, swa_epoch_start=1)
trainer = pl.Trainer(callbacks=[swa_callback, ...]

На обычном pytorch использовать SWA тоже не сложно, он поддерживается с версии 1.6:

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
      for input, target in loader:
          optimizer.zero_grad()
          loss_fn(model(input), target).backward()
          optimizer.step()
      if epoch > swa_start:
          swa_model.update_parameters(model)
          swa_scheduler.step()
      else:
          scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data 
preds = swa_model(test_input)

Exponential Weight Averaging

Для начала давайте вспомним, что из себя представляет Exponentially Weighted Moving Average:

*$EWMA_t=\alpha*x_t+(1-\alpha)EWMA_{t-1},$