Автор: Александр Гончаренко Редактура: Шамиль Мамедов, Иван
Давайте для начала посмотрим на минимумы лосс функции:
Минимумы находятся на одном уровне, но левый лежит в “резкой” яме.
Давайте представим, что мы обучили нейросеть. Хорошая генерализующая способность будет достигнута при таких весах $w$, когда на тренировочной выборке ошибка $L_{train}$ не сильно отличается от ошибки на тестовой $L_{test}$.
Мы знаем, что $L_{train}(w)$ и $L_{test}(w)$ будут отличаться. Но мы ожидаем, что локальные минимумы этих функций будут находиться рядом, а функции станут похожими.
Картинка ниже изображает вышесказанное. Если мы на нее посмотрим, то мы поймем, что в случае с “резкой” ямой $L_{train}$ и $L_{test}$ будут отличаться гораздо сильнее, чем в случае с “нерезкой” ямой, а следовательно, и обобщающая способность будет ниже.
Авторы статьи Sharpness-Aware Minimization предлагают не просто искать минимум функции, а искать его еще и в “нерезкой” яме. Давайте разбираться, что это все значит.
Введем термин “неровность” (в оригинале — sharpness). Интуитивно кажется, что функцию можно считать идеально “ровной” в точке, при которой в ее окрестности все значения функции равны. Соответственно, чем сильнее различие значений функции в окрестности, тем “неровнее” эта функция. Давайте формализуем:
То есть мы находим разницу между максимальным значением функции в окрестности (окружности с радиусом $p$) и значением функции в данной точке.
Теперь мы хотим минимизировать не только значение лосс функции, но и неровность. Если мы их сложим, то получим новый лосс, который и будем минимизировать:
Можно доказать, что для $p>0$ с большой долей вероятности выполняется следующее неравенство (достаточно объемное его доказательство можно найти в приложении к оригинальной статье):
где $L_S$ — лосс на тренировочной выборке , а $L_D$ — лосс на генеральной совокупности.
Из этой формулы видно, что минимизация $L_{sam}$ ведет к уменьшению лосс функции на генеральной совокупности (конечно, если мы не забудем про регуляризацию).
Основная проблема в том, что $w+\epsilon$ — веса с наибольшим лоссом в окрестности, а они нам неизвестны.