Переобучение - это одна из самых частых проблем при обучении глубоких нейронных сетей и в машинном обучении в целом. При переобучении нейронная сеть адаптируется к особенностям обучающего набора данных, а не находит общие закономерности. В результате сеть хорошо работает на обучающем наборе, но плохо на тех данных, которые она не видела в процессе обучения. Таким образом, у модели снижается обобщающая способность.

Один из вариантов решения этой проблемы – остановить процесс обучения нейросети при появлении первых признаков переобучения. В Keras это можно сделать с помощью EarlyStopping Callback, в этой статье я расскажу, как его использовать.

Создаем демонстрационную нейросеть для MNIST

Давайте рассмотрим, как применить EarlyStopping Callback на примере нейронной сети для распознавания рукописных цифр из набора данных MNIST.

На первом этапе нужно подключить интересующий нас Callback совместно с другими модулями Keras:

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import utils
from tensorflow.keras.callbacks import EarlyStopping

Загружаем данные и создаем нейронную сеть:

# Загружаем данные
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Преобразуем в нужный формат
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
Y_train = utils.to_categorical(y_train, 10)
Y_test = utils.to_categorical(y_test, 10)

# Создаем последовательную модель для нейронной сети
model = Sequential()

# Полносвязная нейронная сеть, состоящая из двух слоев
model.add(Dense(800, input_dim=784, activation="relu"))
model.add(Dense(10, activation="softmax"))

# Компилируем модель
model.compile(loss="categorical_crossentropy", 
              optimizer="adam", 
              metrics=["accuracy"])

Как работает Keras EarlyStopping Callback

Чтобы обнаружить переобучение мы сравниваем качество работы сети на обучающем и проверочном наборах данных. Если качество на обучающем наборе данных растёт, а на проверочном снижается – значит переобучение началось.

EarlyStopping Callback в Keras наблюдает за метрикой качества обучения и прерывает обучение процесс обучения, если эта метрика начинает снижаться. Для нашей сети EarlyStopping Callback создается следующим образом:

early_stopping_callback = EarlyStopping(monitor='val_acc', 
                                        patience=2)

Метрика качества обучения задается в параметре monitor. Здесь мы используем долю правильных ответов на проверочном наборе данных ('val_acc').

При обучении нейросети мы используем стохастический градиентный спуск или аналогичные методы, при которых качество решения на некоторых эпохах может снижаться, но после этого снова возрастать. Параметр patience говорит о том, сколько эпох обучения может ухудшаться метрика качества, прежде чем обучение будет остановлено. В примере ранняя остановка обучения произойдет если доля правильных ответов на проверочном наборе данных снижается две эпохи подряд.

Запускаем обучение с ранней остановкой

Запускаем обучение нейронной сети в течение 25 эпох. Наверняка вы помните из предыдущих примеров, что при таком количестве эпох уже возникает переобучение. При вызове метода model.fit в параметре callbacks мы указываем список, состоящий из одного early_stopping_callback, который мы создали на предыдущем этапе.

history = model.fit(X_train, Y_train, 
                    batch_size=200, 
                    epochs=25, 
                    validation_split=0.2, 
                    verbose=2, 
                    callbacks=[early_stopping_callback])

Диагностический вывод сети в процессе обучения:

Train on 48000 samples, validate on 12000 samples
Epoch 1/25
 - 6s - loss: 0.3135 - acc: 0.9122 - val_loss: 0.1637 - val_acc: 0.9538
Epoch 2/25
 - 1s - loss: 0.1272 - acc: 0.9636 - val_loss: 0.1108 - val_acc: 0.9690
Epoch 3/25
 - 1s - loss: 0.0841 - acc: 0.9758 - val_loss: 0.0916 - val_acc: 0.9734
Epoch 4/25
 - 1s - loss: 0.0585 - acc: 0.9832 - val_loss: 0.0876 - val_acc: 0.9725
Epoch 5/25
 - 1s - loss: 0.0429 - acc: 0.9877 - val_loss: 0.0832 - val_acc: 0.9743
Epoch 6/25
 - 1s - loss: 0.0315 - acc: 0.9916 - val_loss: 0.0725 - val_acc: 0.9778
Epoch 7/25
 - 1s - loss: 0.0235 - acc: 0.9945 - val_loss: 0.0723 - val_acc: 0.9778
Epoch 8/25
 - 1s - loss: 0.0173 - acc: 0.9961 - val_loss: 0.0670 - val_acc: 0.9793
Epoch 9/25
 - 1s - loss: 0.0129 - acc: 0.9974 - val_loss: 0.0728 - val_acc: 0.9783
Epoch 10/25
 - 1s - loss: 0.0093 - acc: 0.9986 - val_loss: 0.0715 - val_acc: 0.9789

Мы видим, что обучение остановилось на десятой эпохе вместо 25 эпох. На восьмой эпохе доля правильных ответов на проверочном наборе данных была 0.9793. На девятой эпохе эта доля снизилась до 0.9783. Десятая эпоха привела к небольшому росту доли правильных ответов до 0.9789, но это все равно меньше, чем показатель на восьмой эпохе. Таком образом, у нас в течение двух эпох доля правильных ответов была ниже, чем на восьмой эпохе, поэтому EarlyStopping Callback остановил обучение.

Если вы обучали сеть без диагностического вывода, узнать на какой эпохе произошла остановка можно с помощью early_stopping_callback.stopped_epoch:

print("Обучение остановлено на эпохе", early_stopping_callback.stopped_epoch)
Обучение остановлено на эпохе 9

Напечатано, что остановка произошла на девятой эпохе, но отсчет здесь ведется с нуля.

Визуализация процесса обучения

Давайте построим графики качества обучения нейросети с помощью объекта history.

plt.plot(history.history['acc'], 
         label='Доля верных ответов на обучающем наборе')
plt.plot(history.history['val_acc'], 
         label='Доля верных ответов на проверочном наборе')
plt.xlabel('Эпоха обучения')
plt.ylabel('Доля верных ответов')
plt.legend()
plt.show()

График выглядит следующим образом:

График качества обучения нейросети

Видно, что доля правильных ответов на обучающем наборе данных постоянно растет. Однако на проверочном наборе данных этот показатель растет примерно до пятой или шестой эпохи, а потом почти не меняется. Снижение на девятой и десятой эпохе на графике почти не заметно. Таким образом, ранняя остановка обучения произошла в подходящий момент.

Итоги

EarlyStopping Callback в Keras позволяет остановить процесс обучения нейросети до начала переобучения. При создании EarlyStopping Callback указываем метрику качества, снижение которой будет признаком начала переобучения, и количество эпох, в течение которых метрика качества может снижаться.

В документация на сайте Keras можно найти другие полезные параметры EarlyStopping Callback. Например, restore_best_weights=True позволит после обучения получить сеть с весами, соответствующим лучшему значению метрики качества обучения. По умолчанию этот флаг установлен в False и возвращаются веса на последней эпохе обучения сети. Если значение patience достаточно высоко, то это может быть не лучшая модель.

Также снизить переобучение помогает слой Dropout и методы регуляризации, о них не стоит забывать.

Полезные ссылки

  1. Полный текст примера кода использования EarlyStopping Callback для распознавания рукописных цифр.
  2. Документация на EarlyStopping Callback.