Как стать автором
Обновить

«Вспомнить все» или решение проблемы катастрофической забывчивости для чайников

Время на прочтение29 мин
Количество просмотров8.3K

Эта моя статья будет посвящена проблеме катастрофической забывчивости и новейшим методам ее решения. Будут приведены примеры реализации этих методов, которые легко адаптировать под почти любую конфигурацию нейронной сети.

Сначала напомним, что это, собственно, за проблема. Если вдруг так оказалось, что вам нужно обучать нейронную сетку сначала на одном датасете, а затем на другом, то вы обнаружите, что по мере обучения на втором датасете сетка быстро забывает первый датасет, то есть теряет навык, полученный при обучении на нем. Или же если вы используете transfer learning и доучиваете готовую сетку на своих примерах, то будет наблюдаться тот же эффект – сетка успешно доучится на ваших данных, но при этом существенно утеряет предыдущие навыки, то есть то, ради чего весь transfer learning и затевался. Если вдруг датасетов, на которых надо последовательно учиться, не два а, к примеру, пять, то к концу обучения на пятом сетка забудет первый датасет практически полностью.

С этой проблемой мы имеем дело с середины прошлого века, то есть с момента, когда начали проводиться эксперименты по обучению нейронных сетей. За прошедшее время об проблему обломали немало зубов, и она успела впитаться в коллективное бессознательное как нечто трудноразрешимое, с чем приходится мириться или придумывать разные трюки для того, чтобы хоть как-то снизить ее негативное влияние.

Однако четыре года назад группой «британских ученых» (на самом деле это были парни из британского отделения DeepMind) был сделан значительный прорыв – они придумали метод эластичного закрепления весов (EWC). Если объяснять на пальцах, то его суть в следующем. Если навык обученной нейронной сетки заключен в весах ее связей (и не только связей, а вообще параметров), то «не все они одинаково полезны». То есть какие-то связи более важны для выученного навыка, какие-то менее. Идея в том, чтобы, когда сетка уже обучена навыку (датасету), при дальнейшем обучении другим навыкам не давать важным весам уходить далеко от эталонных значений, то есть полученных после обучения первому навыку. Реализуется это добавлением слагаемого-регуляризатора в функцию стоимости, так что при изменении параметра регуляризатор тянет его назад, в сторону, противоположную изменению. И чем «важнее» параметр, тем сильнее его тянет обратно. Получается, что веса сети как будто привязаны к эталонным значениям резинками разной упругости.

Все это выглядит просто и замечательно, однако основная проблема – как понять, насколько каждый вес или параметр сети важен или наоборот – не важен. Тут «британские ученые» тоже не оплошали и предложили использовать в качестве «важности» весов диагональные элементы информационной матрицы Фишера. И это отлично сработало! То есть, если после обучения на первом датасете посчитать таким способом важности весов, а потом, при обучении на втором датасете, добавить регуляризатор в функцию стоимости, то сетка научится второму датасету не потеряв навыка первого. Ну почти не потеряв. И таким образом можно дообучить сетку последовательно хоть на десятке датасетов, добавляя соответствующие регуляризаторы. И все навыки будут сохраняться пока будет хватать емкости сетки.

К счастью, потом еще один математик доказал, что правильнее не регуляризаторы добавлять, а важности весов суммировать, что существенно проще. К сожалению, у метода все же есть пара недостатков: чтоб все работало, выход сети должен быть вероятностным распределением (то есть softmax), и потом нужно строить дополнительный граф вычислений для диагонали информационной матрицы Фишера и считать его для каждого примера датасета. Хорошая новость в том, что tensorflow или torch все сделает за вас, плохая в том, что посчитать придется много.

Далее привожу пример реализации EWC для глубокой сетки из нескольких полносвязных слоев. Последовательное обучение сетки нескольким датасетам проводится следующим образом: берем датасет, открываем урок для сетки, учим ее, закрываем урок, считая важности на текущем датасете и, суммируя с предыдущими значениями важности, повторяем все для следующего датасета.

Пример реализации EWC тут.
# код рассчитан на python 3.7 и tensorflow 1.15.0

import datetime
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def _weight_variable(shape):
    # для весов полносвязного слоя инициализируем значения по Каймину Ге (Хе)
    stddev = 2. / np.sqrt(shape[0])
    initial = tf.random.truncated_normal(shape=shape, mean=0.0, stddev=stddev)
    return tf.Variable(initial, dtype=tf.float32)


def _bias_variable(shape):
    # смещения инициализируем нулями
    initial = tf.constant(0.0, shape=shape)
    return tf.Variable(initial, dtype=tf.float32)


class Model:
    def __init__(self, shape, session):
        """
        :param shape:   структура сети - список из чисел нейронов в каждом слое сети
                        от входа к выходу справа налево, например, [784, 100, 10]
        :param session: tensorflow-сессия для расчетов сети
        """

        self.session = session
        self._shape = shape
        depth = len(shape) - 1
        if depth < 1:
            raise ValueError("Недопустимая структура сети!")

        # заглушки для входных данных
        self.x = tf.placeholder(tf.float32, shape=[None, shape[0]])
        self.labels = tf.placeholder(tf.float32, shape=[None, shape[-1]])

        # все веса слоев сети будем хранить в списке
        self.var_list = []
        for ins, outs in zip(shape[:-1], shape[1:]):
            self.var_list.append(_weight_variable([ins, outs]))
            self.var_list.append(_bias_variable([outs]))

        # инициализируем веса сети
        for v in self.var_list:
            session.run(v.initializer)

        # список для хранения важностей весов сети
        self.wb_importance = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]

        # строим вычислительный граф
        x, y, z = self.x, None, None
        for i in range(depth):
            z = tf.matmul(x, self.var_list[i * 2]) + self.var_list[i * 2 + 1]
            y = tf.nn.softmax(z) if i == depth-1 else tf.nn.leaky_relu(z)
            x = y

        # функция стоимости
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z, labels=self.labels))

        # точность (accuracy)
        self.correct_preds = tf.equal(tf.argmax(z, axis=1), tf.argmax(self.labels, axis=1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_preds, tf.float32))

        # вычисляем градиенты вероятности правильной метки
        self.prob_grads = tf.gradients(tf.math.log(y[0, tf.argmax(self.labels[0])]), self.var_list)

        self.train_step = None

    def open_lesson(self, learning_rate=1.0, lmbda=0.0):
        """
        Открытие урока обучения сети на отдельном датасете
        :param learning_rate: скорость обучения для SGD
        :param lmbda:         коэффициент влияния важностей - насколько сильно
                              важности тянут веса к эталонным значениям
        """
        loss = self.loss

        if hasattr(self, "star_vars") and lmbda != 0:
            # добавляем к функции стоимости слагаемые-регуляризаторы
            for v in range(len(self._shape)*2-2):
                loss += tf.reduce_sum(tf.multiply(
                    tf.constant(lmbda / 2. * self.wb_importance[v], tf.float32),
                    tf.square(self.var_list[v] - tf.constant(self.star_vars[v], tf.float32))
                ))
        # устанавливаем шаг оптимизатора
        self.train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    def close_lesson(self, closing_inputs=None, closing_labels=None):
        """
        Закрытие урока обучения сети на отдельном датасете. Расчет и накопление важностей весов.
        :param closing_set: датасет, на котором будут рассчитаны важности весов после обучения
        :return:
        """

        # рассчитываем важности весов на закрывающем датасете
        addendum = self._compute_fisher(closing_inputs, closing_labels)
        # добавляем рассчитанные важности к сохраненным
        for i, a in zip(self.wb_importance, addendum):
            i += a
        # запоминаем текущие эталонные веса сети после обучения
        self._store_weights_and_biases()

    def _store_weights_and_biases(self):
        self.star_vars = [v.eval() for v in self.var_list]

    def _compute_fisher(self, closing_inputs, closing_labels):
        # вычисление диагональных элементов информационной
        # матрицы Фишера для весов сети на заданном датасете
        num_samples = len(closing_inputs)

        # инициализируем значения нулями
        fisher = [np.zeros(self.var_list[v].shape, dtype=np.float32) for v in range(len(self.var_list))]

        for i in range(num_samples):
            # вычисляем первые производные логарифма вероятности
            feed_dict = {self.x: closing_inputs[i:i+1], self.labels: closing_labels[i: i+1]}
            derivatives = self.session.run(self.prob_grads, feed_dict=feed_dict)
            # возводим их в квадрат т. к. это диагональные элементы
            for f, d in zip(fisher, derivatives):
                f += np.square(d)

        # усредняем по количеству примеров в датасете.
        # такой подход применим если мы хотим чтоб каждый урок оказывал одинаковый вклад в важности.
        # если же мы хотим чтоб каждый пример при закрытии урока оказывал одинаковое влияние, то
        # нужно делить на некую подобранную константу, а не на число примеров.
        for f in fisher:
            f /= num_samples

        return fisher


# функция случайным образом переставляет входы одинаково для всех примеров датасета
def permute_mnist(mnist):
    perm_inds = list(range(mnist.train.images.shape[1]))
    np.random.shuffle(perm_inds)
    mnist2 = deepcopy(mnist)
    sets = ["train", "validation", "test"]
    for set_name in sets:
        this_set = getattr(mnist2, set_name)
        this_set._images = np.transpose(np.array([this_set.images[:, c] for c in perm_inds]))
    return mnist2


def train_model(model, train_set, test_sets, batch_size=100, epochs=1):
    """
    Обучение модели
    :param model:       обучаемая модель
    :param train_set:   обучающий датасет
    :param test_sets:   список датасетов, на которых будет считаться средняя точность
    :param batch_size:  размер батча
    :param epochs:      количество эпох обучения
    :return:            средняя точность на тестовых датасетах после обучения
    """
    num_iters = int(np.ceil(len(train_set.train.labels) * epochs / batch_size))
    for idx in range(num_iters):
        train_batch = train_set.train.next_batch(batch_size)
        feed_dict = {model.x: train_batch[0], model.labels: train_batch[1]}
        model.train_step.run(feed_dict=feed_dict)
        print(f'\rTraining  {idx + 1}/{num_iters} done.', end='')

    print(f'\rTraining  {num_iters}/{num_iters} iterations done.')

    accuracy = 0.
    for t, test_set in enumerate(test_sets):
        feed_dict = {model.x: test_set.test.images, model.labels: test_set.test.labels}
        accuracy += model.accuracy.eval(feed_dict=feed_dict)
    accuracy /= len(test_sets)
    print(f'Evaluating on {len(test_sets)} test sets done. Accuracy {accuracy}')
    return accuracy


def continual_learning(net_struct, data_sets, session, lr, lmbda):
    """
    Последовательное обучение на нескольких обучающих наборах
    :param net_struct: структура сети
    :param data_sets:  список обучающих датасетов для последовательного обучения
    :param session:    tf-сессия
    :param lr:         скорость обучения
    :param lmbda:      степень влияния важностей на обучение
    :return:           список усредненных по выученным датасетам оценок
    """
    model = Model(net_struct, session)
    test_sets = []
    accuracies = []
    for data_set in data_sets:
        test_sets.append(data_set)
        model.open_lesson(lr, lmbda)
        accuracy = train_model(model, data_set, test_sets, 100, 4)
        accuracies.append(accuracy)
        model.close_lesson(data_set.validation.images, data_set.validation.labels)
    del model
    return accuracies

# считываем данные MNIST
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# создаем tf-сессию
sess = tf.InteractiveSession()

# создаем 10 различных обучающих наборов для последовательного обучения
mnist0 = mnist
mnist1 = permute_mnist(mnist)
mnist2 = permute_mnist(mnist)
mnist3 = permute_mnist(mnist)
mnist4 = permute_mnist(mnist)
mnist5 = permute_mnist(mnist)
mnist6 = permute_mnist(mnist)
mnist7 = permute_mnist(mnist)
mnist8 = permute_mnist(mnist)
mnist9 = permute_mnist(mnist)

start_time = datetime.datetime.now()

# определим параметры обучения
data_sets = [mnist0, mnist1, mnist2, mnist3, mnist4, mnist5, mnist6, mnist7, mnist8, mnist9]
net_struct = [784, 300, 150, 10]
lmbda = 15.
learning_rate = 0.2

accuracies = continual_learning(net_struct, data_sets, sess, lr=learning_rate, lmbda=lmbda)
print ('Total time spent', datetime.datetime.now() - start_time)

dataset_num = range(1, len(accuracies) + 1)

# нарисуем график деградации средней точности на всех выученных датасетах
plt.figure(figsize=(7, 3.5))
plt.ylim(0.40, 1.)
plt.xlim(1, len(accuracies))
plt.ylabel('Total accuracy')
plt.xlabel('Number of tasks')
plt.plot(dataset_num, accuracies, marker=".")
#plt.legend()
plt.show()

Для иллюстрации проблемы катастрофической забывчивости в этом коде можно выставить значение лямбда в 0, и увидеть, как сильно начнет падать средняя точность

Примерно через три месяца изобрели другой способ считать важности весов Synaptic Intelligence (SI) – это насколько менялась функция стоимости при изменениях веса в процессе обучения. В качестве достоинства метода авторы указывают, что важности считаются прямо в процессе обучения сетки. Авторы также утверждают, что с вычисленными таким способом важностями навыки сохраняются лучше, чем у EWC. Но строгих доказательств не приводят, кроме эксперимента, проведенного на конкретных значениях гиперпараметров сети.

В этом способе хорошо то, что он тоже работает, а также, что его использование не зависит от типа выхода сети (то есть выход не обязательно должен быть распределением/softmax). Вычислительно метод существенно легче, чем оригинальный EWC. Но также требует построения дополнительного графа вычислений и неотделимо сопровождает процесс обучения, то есть имеет пропорциональную обучению вычислительную стоимость, что не есть хорошо.

Пример реализации SI тут.
# код рассчитан на python 3.7 и tensorflow 1.15.0

import datetime
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


_epsilon = 0.1


def _weight_variable(shape):
    # для весов полносвязного слоя инициализируем значения по Каймину Ге (Хе)
    stddev = 2. / np.sqrt(shape[0])
    initial = tf.random.truncated_normal(shape=shape, mean=0.0, stddev=stddev)
    return tf.Variable(initial, dtype=tf.float32)


def _bias_variable(shape):
    # смещения инициализируем нулями
    initial = tf.constant(0.0, shape=shape)
    return tf.Variable(initial, dtype=tf.float32)


class Model:
    def __init__(self, shape, session):
        """
        :param shape:   структура сети - список из чисел нейронов в каждом слое сети
                        от входа к выходу справа налево, например, [784, 100, 10]
        :param session: tensorflow-сессия для расчетов сети
        """

        self.session = session
        self._shape = shape
        depth = len(shape) - 1
        if depth < 1:
            raise ValueError("Недопустимая структура сети!")

        # заглушки для входных данных
        self.x = tf.placeholder(tf.float32, shape=[None, shape[0]])
        self.labels = tf.placeholder(tf.float32, shape=[None, shape[-1]])

        # все веса слоев сети будем хранить в списке
        self.var_list = []
        for ins, outs in zip(shape[:-1], shape[1:]):
            self.var_list.append(_weight_variable([ins, outs]))
            self.var_list.append(_bias_variable([outs]))

        # инициализируем веса сети
        for v in self.var_list:
            session.run(v.initializer)

        # список для накопления важностей весов сети за текущий урок
        self._accums = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]

        # список для хранения важностей весов сети за все завершенные уроки
        self.wb_importance = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]

        # строим вычислительный граф
        x, y, z = self.x, None, None
        for i in range(depth):
            z = tf.matmul(x, self.var_list[i * 2]) + self.var_list[i * 2 + 1]
            y = tf.nn.softmax(z) if i == depth-1 else tf.nn.leaky_relu(z)
            x = y

        # функция стоимости
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z, labels=self.labels))

        # точность (accuracy)
        self.correct_preds = tf.equal(tf.argmax(z, axis=1), tf.argmax(self.labels, axis=1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_preds, tf.float32))

        self.grads = tf.gradients(self.loss, self.var_list)

        self._train_step = None

    def open_lesson(self, learning_rate=1.0, lmbda=0.0):
        """
        Открытие урока обучения сети на отдельном датасете
        :param learning_rate: скорость обучения для SGD
        :param lmbda:         коэффициент влияния важностей - насколько сильно
                              важности тянут веса к эталонным значениям
        """
        loss = self.loss

        if hasattr(self, "star_vars"):
            if lmbda != 0:
                # добавляем к функции стоимости слагаемые-регуляризаторы
                for v in range(len(self._shape)*2-2):
                    loss += tf.reduce_sum(tf.multiply(
                        tf.constant(lmbda / 2. * self.wb_importance[v], tf.float32),
                        tf.square(self.var_list[v] - tf.constant(self.star_vars[v], tf.float32))
                    ))
        else:
            # запоминаем текущие веса
            self.star_vars = [v.eval() for v in self.var_list]

        # устанавливаем шаг оптимизатора
        self._train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

        #
        self.cur_vars = [v.eval() for v in self.var_list]

    def train_step_run(self, feed_dict):
        self._train_step.run(feed_dict=feed_dict)

        # считаем обновленные оптимизатором значения параметров сети
        new_vars = [v.eval() for v in self.var_list]

        # рассчитаем градиенты
        grads = self.session.run(self.grads, feed_dict=feed_dict)

        # аккумулируем изменение функции стоимости (производная
        # по параметру, умноженная на изменение параметра)
        for acc, grad, prev_var, new_var in zip(self._accums, grads, self.cur_vars, new_vars):
            acc -= grad * (new_var - prev_var)

        # мастера tensorflow вероятно смогут из train_step.run вытащить
        # уже подсчитанные градиенты по весам и сами изменения весов чтобы
        # не считать их еще раз и сэкономить вычислительные ресурсы

        # сохраним новые значения весов для следующей итерации
        self.cur_vars = new_vars

    def close_lesson(self):
        """
        Закрытие урока обучения сети. Накопление важностей весов.
        :return:
        """

        # рассчитаем квадраты смещений параметров
        deltas = [np.square(v.eval() - prev_v) + _epsilon
                  for v, prev_v in zip(self.var_list, self.star_vars)]

        # добавляем рассчитанные важности к сохраненным
        for i, a, d in zip(self.wb_importance, self._accums, deltas):
            i += a / d

        # запоминаем текущие эталонные веса сети после обучения
        self.star_vars = [v.eval() for v in self.var_list]


# функция случайным образом переставляет входы одинаково для всех примеров датасета
def permute_mnist(mnist):
    perm_inds = list(range(mnist.train.images.shape[1]))
    np.random.shuffle(perm_inds)
    mnist2 = deepcopy(mnist)
    sets = ["train", "validation", "test"]
    for set_name in sets:
        this_set = getattr(mnist2, set_name)
        this_set._images = np.transpose(np.array([this_set.images[:, c] for c in perm_inds]))
    return mnist2


def train_model(model, train_set, test_sets, batch_size=100, epochs=1):
    """
    Обучение модели
    :param model:       обучаемая модель
    :param train_set:   обучающий датасет
    :param test_sets:   список датасетов, на которых будет считаться средняя точность
    :param batch_size:  размер батча
    :param epochs:      количество эпох обучения
    :return:            средняя точность на тестовых датасетах после обучения
    """
    num_iters = int(np.ceil(len(train_set.train.labels) * epochs / batch_size))
    for idx in range(num_iters):
        train_batch = train_set.train.next_batch(batch_size)
        feed_dict = {model.x: train_batch[0], model.labels: train_batch[1]}
        model.train_step_run(feed_dict=feed_dict)
        print(f'\rTraining  {idx + 1}/{num_iters} done.', end='')

    print(f'\rTraining  {num_iters}/{num_iters} iterations done.')

    accuracy = 0.
    for t, test_set in enumerate(test_sets):
        feed_dict = {model.x: test_set.test.images, model.labels: test_set.test.labels}
        accuracy += model.accuracy.eval(feed_dict=feed_dict)
    accuracy /= len(test_sets)
    print(f'Evaluating on {len(test_sets)} test sets done. Accuracy {accuracy}')
    return accuracy


def continual_learning(net_struct, data_sets, session, lr, lmbda):
    """
    Последовательное обучение на нескольких обучающих наборах
    :param net_struct: структура сети
    :param data_sets:  список обучающих датасетов для последовательного обучения
    :param session:    tf-сессия
    :param lr:         скорость обучения
    :param lmbda:      степень влияния важностей на обучение
    :return:           список усредненных по выученным датасетам оценок
    """
    model = Model(net_struct, session)
    test_sets = []
    accuracies = []
    for data_set in data_sets:
        test_sets.append(data_set)
        model.open_lesson(lr, lmbda)
        accuracy = train_model(model, data_set, test_sets, 100, 4)
        accuracies.append(accuracy)
        model.close_lesson()
    del model
    return accuracies

# считываем данные MNIST
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# создаем tf-сессию
sess = tf.InteractiveSession()

# создаем 10 различных обучающих наборов для последовательного обучения
mnist0 = mnist
mnist1 = permute_mnist(mnist)
mnist2 = permute_mnist(mnist)
mnist3 = permute_mnist(mnist)
mnist4 = permute_mnist(mnist)
mnist5 = permute_mnist(mnist)
mnist6 = permute_mnist(mnist)
mnist7 = permute_mnist(mnist)
mnist8 = permute_mnist(mnist)
mnist9 = permute_mnist(mnist)

start_time = datetime.datetime.now()

# определим параметры обучения
data_sets = [mnist0, mnist1, mnist2, mnist3, mnist4, mnist5, mnist6, mnist7, mnist8, mnist9]
net_struct = [784, 300, 150, 10]
lmbda = 0.1
learning_rate = 0.02

accuracies = continual_learning(net_struct, data_sets, sess, lr=learning_rate, lmbda=lmbda)
print ('Total time spent', datetime.datetime.now() - start_time)

dataset_num = range(1, len(accuracies) + 1)

# нарисуем график деградации средней точности на всех выученных датасетах
plt.figure(figsize=(7, 3.5))
plt.ylim(0.40, 1.)
plt.xlim(1, len(accuracies))
plt.ylabel('Total accuracy')
plt.xlabel('Number of tasks')
plt.plot(dataset_num, accuracies, marker=".")
#plt.legend()
plt.show()

Еще примерно через полгода важность весов предложили считать по степени зависимости выходных сигналов обученной сетки от весов на каждом примере датасета, что довольно логично. Этот метод назвали Memory Aware Synapses (MAS). Метод показывает довольно хорошие результаты – авторы утверждают, что лучше, чем у обоих предыдущих, но также без строгого доказательства. Метод не требует softmax-выхода сети. Полноценная версия MAS вычислительно сравнима с EWC. Однако, в отличие от SI, отделима от обучения.

Интересно, что при использовании нейросетки с полносвязными слоями и ReLU-активациями в методе MAS важность связи между нейронами равна произведению выхода нейрона-источника на выход нейрона-приемника. И это прямо один в один обучение по Хеббу, только в хеббовом обучении так считается изменение веса связи, а в MAS это важность связи.

Пример реализации полноценного MAS тут.
# код рассчитан на python 3.7 и tensorflow 1.15.0

import datetime
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def _weight_variable(shape):
    # для весов полносвязного слоя инициализируем значения по Каймину Ге (Хе)
    stddev = 2. / np.sqrt(shape[0])
    initial = tf.random.truncated_normal(shape=shape, mean=0.0, stddev=stddev)
    return tf.Variable(initial, dtype=tf.float32)


def _bias_variable(shape):
    # смещения инициализируем нулями
    initial = tf.constant(0.0, shape=shape)
    return tf.Variable(initial, dtype=tf.float32)


class Model:
    def __init__(self, shape, session):
        """
        :param shape:   структура сети - список из чисел нейронов в каждом слое сети
                        от входа к выходу справа налево, например, [784, 100, 10]
        :param session: tensorflow-сессия для расчетов сети
        """

        self.session = session
        self._shape = shape
        depth = len(shape) - 1
        if depth < 1:
            raise ValueError("Недопустимая структура сети!")

        # заглушки для входных данных
        self.x = tf.placeholder(tf.float32, shape=[None, shape[0]])
        self.labels = tf.placeholder(tf.float32, shape=[None, shape[-1]])

        # все веса слоев сети будем хранить в списке
        self.var_list = []
        for ins, outs in zip(shape[:-1], shape[1:]):
            self.var_list.append(_weight_variable([ins, outs]))
            self.var_list.append(_bias_variable([outs]))

        # инициализируем веса сети
        for v in self.var_list:
            session.run(v.initializer)

        # список для хранения важностей весов сети
        self.wb_importance = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]

        # строим вычислительный граф
        x, y, z = self.x, None, None
        for i in range(depth):
            z = tf.matmul(x, self.var_list[i * 2]) + self.var_list[i * 2 + 1]
            y = tf.nn.softmax(z) if i == depth-1 else tf.nn.leaky_relu(z)
            x = y

        # функция стоимости
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z, labels=self.labels))

        # точность (accuracy)
        self.correct_preds = tf.equal(tf.argmax(z, axis=1), tf.argmax(self.labels, axis=1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_preds, tf.float32))

        # сумма квадратов выходов сети
        F = tf.reduce_sum(tf.square(y[0]))

        # вычисляем градиенты значений F по весам и берем от них модуль
        self.cur_importances = [tf.abs(grad) for grad in tf.gradients(F, self.var_list)]

        self.train_step = None

    def open_lesson(self, learning_rate=1.0, lmbda=0.0):
        """
        Открытие урока обучения сети на отдельном датасете
        :param learning_rate: скорость обучения для SGD
        :param lmbda:         коэффициент влияния важностей - насколько сильно
                              важности тянут веса к эталонным значениям
        """
        loss = self.loss

        if hasattr(self, "star_vars") and lmbda != 0:
            # добавляем к функции стоимости слагаемые-регуляризаторы
            for v in range(len(self._shape)*2-2):
                loss += tf.reduce_sum(tf.multiply(
                    tf.constant(lmbda / 2. * self.wb_importance[v], tf.float32),
                    tf.square(self.var_list[v] - tf.constant(self.star_vars[v], tf.float32))
                ))

        # устанавливаем шаг оптимизатора
        self.train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    def close_lesson(self, closing_set=None):
        """
        Закрытие урока обучения сети на отдельном датасете. Расчет и накопление важностей весов.
        :param closing_set: датасет, на котором будут рассчитаны важности весов после обучения
        :return:
        """

        # рассчитываем важности весов на закрывающем датасете
        addendum = self._compute_mas(closing_set)

        # добавляем рассчитанные важности к сохраненным
        for i, a in zip(self.wb_importance, addendum):
            i += a

        # запоминаем текущие эталонные веса сети после обучения
        self.star_vars = [v.eval() for v in self.var_list]


    def _compute_mas(self, closing_set):
        # вычисление диагональных элементов информационной
        # матрицы Фишера для весов сети на заданном датасете
        num_samples = len(closing_set)

        # инициализируем значения нулями
        mas = [np.zeros(self.var_list[v].shape, dtype=np.float32) for v in range(len(self.var_list))]

        for i in range(num_samples):
            # вычисляем первые производные логарифма вероятности
            feed_dict = {self.x: closing_set[i:i + 1]}
            cur_importances = self.session.run(self.cur_importances, feed_dict=feed_dict)
            # возводим их в квадрат т. к. это диагональные элементы
            for f, d in zip(mas, cur_importances):
                f += d

        # усредняем по количеству примеров в датасете.
        # такой подход применим если мы хотим чтоб каждый урок оказывал одинаковый вклад в важности.
        # если же мы хотим чтоб каждый пример при закрытии урока оказывал одинаковое влияние, то
        # нужно делить на некую подобранную константу, а не на число примеров.
        for f in mas:
            f /= num_samples

        return mas


# функция случайным образом переставляет входы одинаково для всех примеров датасета
def permute_mnist(mnist):
    perm_inds = list(range(mnist.train.images.shape[1]))
    np.random.shuffle(perm_inds)
    mnist2 = deepcopy(mnist)
    sets = ["train", "validation", "test"]
    for set_name in sets:
        this_set = getattr(mnist2, set_name)
        this_set._images = np.transpose(np.array([this_set.images[:, c] for c in perm_inds]))
    return mnist2


def train_model(model, train_set, test_sets, batch_size=100, epochs=1):
    """
    Обучение модели
    :param model:       обучаемая модель
    :param train_set:   обучающий датасет
    :param test_sets:   список датасетов, на которых будет считаться средняя точность
    :param batch_size:  размер батча
    :param epochs:      количество эпох обучения
    :return:            средняя точность на тестовых датасетах после обучения
    """
    num_iters = int(np.ceil(len(train_set.train.labels) * epochs / batch_size))
    for idx in range(num_iters):
        train_batch = train_set.train.next_batch(batch_size)
        feed_dict = {model.x: train_batch[0], model.labels: train_batch[1]}
        model.train_step.run(feed_dict=feed_dict)
        print(f'\rTraining  {idx + 1}/{num_iters} done.', end='')

    print(f'\rTraining  {num_iters}/{num_iters} iterations done.')

    accuracy = 0.
    for t, test_set in enumerate(test_sets):
        feed_dict = {model.x: test_set.test.images, model.labels: test_set.test.labels}
        accuracy += model.accuracy.eval(feed_dict=feed_dict)
    accuracy /= len(test_sets)
    print(f'Evaluating on {len(test_sets)} test sets done. Accuracy {accuracy}')
    return accuracy


def continual_learning(net_struct, data_sets, session, lr, lmbda):
    """
    Последовательное обучение на нескольких обучающих наборах
    :param net_struct: структура сети
    :param data_sets:  список обучающих датасетов для последовательного обучения
    :param session:    tf-сессия
    :param lr:         скорость обучения
    :param lmbda:      степень влияния важностей на обучение
    :return:           список усредненных по выученным датасетам оценок
    """
    model = Model(net_struct, session)
    test_sets = []
    accuracies = []
    for data_set in data_sets:
        test_sets.append(data_set)
        model.open_lesson(lr, lmbda)
        accuracy = train_model(model, data_set, test_sets, 100, 4)
        accuracies.append(accuracy)
        model.close_lesson(data_set.validation.images)
    del model
    return accuracies

# считываем данные MNIST
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# создаем tf-сессию
sess = tf.InteractiveSession()

# создаем 10 различных обучающих наборов для последовательного обучения
mnist0 = mnist
mnist1 = permute_mnist(mnist)
mnist2 = permute_mnist(mnist)
mnist3 = permute_mnist(mnist)
mnist4 = permute_mnist(mnist)
mnist5 = permute_mnist(mnist)
mnist6 = permute_mnist(mnist)
mnist7 = permute_mnist(mnist)
mnist8 = permute_mnist(mnist)
mnist9 = permute_mnist(mnist)

start_time = datetime.datetime.now()

# определим параметры обучения
data_sets = [mnist0, mnist1, mnist2, mnist3, mnist4, mnist5, mnist6, mnist7, mnist8, mnist9]
net_struct = [784, 300, 150, 10]
lmbda = 5.
learning_rate = 0.2

accuracies = continual_learning(net_struct, data_sets, sess, lr=learning_rate, lmbda=lmbda)
print ('Total time spent', datetime.datetime.now() - start_time)

dataset_num = range(1, len(accuracies) + 1)

# нарисуем график деградации средней точности на всех выученных датасетах
plt.figure(figsize=(7, 3.5))
plt.ylim(0.40, 1.)
plt.xlim(1, len(accuracies))
plt.ylabel('Total accuracy')
plt.xlabel('Number of tasks')
plt.plot(dataset_num, accuracies, marker=".")
#plt.legend()
plt.show()

Наконец недавно автор (этого опуса) экспериментально обнаружил, что в качестве важности веса связи можно использовать просто суммарный по модулю сигнал, прошедший через связь обученной сетки в процессе пропускания через нее обучающего набора. И такой метод EWC-signal (EWC-S) тоже работает, хоть и немногим хуже, чем предыдущие методы EWC, SI и MAS. В смысле же вычислительной сложности этот метод заслуженно можно назвать китайским – настолько он вычислительно дешев. Он не требует ни выхода сети в виде распределения/softmax, ни расчета каких-либо производных или градиентов. Кроме того, с его помощью важности можно считать прямо в процессе завершающих этапов обучения. Однако, для этого метода tensorflow и torch не построят вычисления для каждой важности за вас, и код придется писать руками.

Пример реализации EWC-S тут.
# код рассчитан на python 3.7 и tensorflow 1.15.0

import datetime
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def _weight_variable(shape):
    # для весов полносвязного слоя инициализируем значения по Каймину Ге (Хе)
    stddev = 2. / np.sqrt(shape[0])
    initial = tf.random.truncated_normal(shape=shape, mean=0.0, stddev=stddev)
    return tf.Variable(initial, dtype=tf.float32)


def _bias_variable(shape):
    # смещения инициализируем нулями
    initial = tf.constant(0.0, shape=shape)
    return tf.Variable(initial, dtype=tf.float32)


class Model:
    def __init__(self, shape, session):
        """
        :param shape:   структура сети - список из чисел нейронов в каждом слое сети
                        от входа к выходу справа налево, например, [784, 100, 10]
        :param session: tensorflow-сессия для расчетов сети
        """

        self.session = session
        self._shape = shape
        depth = len(shape) - 1
        if depth < 1:
            raise ValueError("Недопустимая структура сети!")

        # заглушки для входных данных
        self.x = tf.placeholder(tf.float32, shape=[None, shape[0]])
        self.labels = tf.placeholder(tf.float32, shape=[None, shape[-1]])

        # все веса слоев сети будем хранить в списке
        self.var_list = []
        for ins, outs in zip(shape[:-1], shape[1:]):
            self.var_list.append(_weight_variable([ins, outs]))
            self.var_list.append(_bias_variable([outs]))

        # инициализируем веса сети
        for v in self.var_list:
            session.run(v.initializer)

        # список для хранения важностей весов сети
        self.wb_importance = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]

        # строим вычислительный граф
        outputs = []
        x, y, z = self.x, None, None
        for i in range(depth):
            z = tf.matmul(x, self.var_list[i * 2]) + self.var_list[i * 2 + 1]
            y = tf.nn.softmax(z) if i == depth-1 else tf.nn.leaky_relu(z)
            outputs.append(y)
            x = y

        # функция стоимости
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z, labels=self.labels))

        # точность (accuracy)
        self.correct_preds = tf.equal(tf.argmax(z, axis=1), tf.argmax(self.labels, axis=1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_preds, tf.float32))

        # вычисляем суммарный по модулю прошедший сигнал
        self.signals = []

        os = tf.reduce_mean(tf.abs(self.x), axis=0)
        for i in range(depth):
            ws = tf.transpose(tf.multiply(os, tf.transpose(tf.abs(self.var_list[i*2]))))
            self.signals.append(ws)
            os = tf.reduce_mean(tf.abs(outputs[i]), axis=0)
            self.signals.append(os)

        self.train_step = None

    def open_lesson(self, learning_rate=1.0, lmbda=0.0):
        """
        Открытие урока обучения сети на отдельном датасете
        :param learning_rate: скорость обучения для SGD
        :param lmbda:         коэффициент влияния важностей - насколько сильно
                              важности тянут веса к эталонным значениям
        """
        loss = self.loss

        if hasattr(self, "star_vars") and lmbda != 0:
            # добавляем к функции стоимости слагаемые-регуляризаторы
            for v in range(len(self._shape)*2-2):
                loss += tf.reduce_sum(tf.multiply(
                    tf.constant(lmbda / 2. * self.wb_importance[v], tf.float32),
                    tf.square(self.var_list[v] - tf.constant(self.star_vars[v], tf.float32))
                ))

        # устанавливаем шаг оптимизатора
        self.train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    def close_lesson(self, closing_set=None):
        """
        Закрытие урока обучения сети на отдельном датасете. Расчет и накопление важностей весов.
        :param closing_set: датасет, на котором будут рассчитаны важности весов после обучения
        :return:
        """

        # рассчитываем важности весов на закрывающем датасете
        addendum = self.session.run(self.signals, feed_dict={self.x: closing_set})

        # добавляем рассчитанные важности к сохраненным
        for i, a in zip(self.wb_importance, addendum):
            i += a

        # запоминаем текущие эталонные веса сети после обучения
        self.star_vars = [v.eval() for v in self.var_list]


# функция случайным образом переставляет входы одинаково для всех примеров датасета
def permute_mnist(mnist):
    perm_inds = list(range(mnist.train.images.shape[1]))
    np.random.shuffle(perm_inds)
    mnist2 = deepcopy(mnist)
    sets = ["train", "validation", "test"]
    for set_name in sets:
        this_set = getattr(mnist2, set_name)
        this_set._images = np.transpose(np.array([this_set.images[:, c] for c in perm_inds]))
    return mnist2


def train_model(model, train_set, test_sets, batch_size=100, epochs=1):
    """
    Обучение модели
    :param model:       обучаемая модель
    :param train_set:   обучающий датасет
    :param test_sets:   список датасетов, на которых будет считаться средняя точность
    :param batch_size:  размер батча
    :param epochs:      количество эпох обучения
    :return:            средняя точность на тестовых датасетах после обучения
    """
    num_iters = int(np.ceil(len(train_set.train.labels) * epochs / batch_size))
    for idx in range(num_iters):
        train_batch = train_set.train.next_batch(batch_size)
        feed_dict = {model.x: train_batch[0], model.labels: train_batch[1]}
        model.train_step.run(feed_dict=feed_dict)
        print(f'\rTraining  {idx + 1}/{num_iters} done.', end='')

    print(f'\rTraining  {num_iters}/{num_iters} iterations done.')

    accuracy = 0.
    for t, test_set in enumerate(test_sets):
        feed_dict = {model.x: test_set.test.images, model.labels: test_set.test.labels}
        accuracy += model.accuracy.eval(feed_dict=feed_dict)
    accuracy /= len(test_sets)
    print(f'Evaluating on {len(test_sets)} test sets done. Accuracy {accuracy}')
    return accuracy


def continual_learning(net_struct, data_sets, session, lr, lmbda):
    """
    Последовательное обучение на нескольких обучающих наборах
    :param net_struct: структура сети
    :param data_sets:  список обучающих датасетов для последовательного обучения
    :param session:    tf-сессия
    :param lr:         скорость обучения
    :param lmbda:      степень влияния важностей на обучение
    :return:           список усредненных по выученным датасетам оценок
    """
    model = Model(net_struct, session)
    test_sets = []
    accuracies = []
    for data_set in data_sets:
        test_sets.append(data_set)
        model.open_lesson(lr, lmbda)
        accuracy = train_model(model, data_set, test_sets, 100, 4)
        accuracies.append(accuracy)
        model.close_lesson(data_set.validation.images)
    del model
    return accuracies

# считываем данные MNIST
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# создаем tf-сессию
sess = tf.InteractiveSession()

# создаем 10 различных обучающих наборов для последовательного обучения
mnist0 = mnist
mnist1 = permute_mnist(mnist)
mnist2 = permute_mnist(mnist)
mnist3 = permute_mnist(mnist)
mnist4 = permute_mnist(mnist)
mnist5 = permute_mnist(mnist)
mnist6 = permute_mnist(mnist)
mnist7 = permute_mnist(mnist)
mnist8 = permute_mnist(mnist)
mnist9 = permute_mnist(mnist)

start_time = datetime.datetime.now()

# определим параметры обучения
data_sets = [mnist0, mnist1, mnist2, mnist3, mnist4, mnist5, mnist6, mnist7, mnist8, mnist9]
net_struct = [784, 300, 150, 10]
lmbda = 0.75
learning_rate = 0.2

accuracies = continual_learning(net_struct, data_sets, sess, lr=learning_rate, lmbda=lmbda)
print ('Total time spent', datetime.datetime.now() - start_time)

dataset_num = range(1, len(accuracies) + 1)

# нарисуем график деградации средней точности на всех выученных датасетах
plt.figure(figsize=(7, 3.5))
plt.ylim(0.40, 1.)
plt.xlim(1, len(accuracies))
plt.ylabel('Total accuracy')
plt.xlabel('Number of tasks')
plt.plot(dataset_num, accuracies, marker=".")
#plt.legend()
plt.show()

Также было экспериментально обнаружено, что, для сохранения выученных навыков нейросетки, можно не только привязывать веса к эталонным на резиночках разной упругости (где упругость пропорциональна важности связи), но и просто замедлять скорость изменения (то есть градиенты) весов, как будто для них увеличивается сила трения при сдвиге пропорционально важности веса связи. Это опять слегка ухудшает способность метода сохранять навыки при последовательном обучении, но сильно экономит память, потому что не надо хранить эталонные веса. Такой метод называется Weight Velocity Attenuation (WVA).

Пример реализации WVA на базе суммарного по модулю сигнала тут.
# код рассчитан на python 3.7 и tensorflow 1.15.0

import datetime
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.training.optimizer import Optimizer


def _weight_variable(shape):
    # для весов полносвязного слоя инициализируем значения по Каймину Ге (Хе)
    stddev = 2. / np.sqrt(shape[0])
    initial = tf.random.truncated_normal(shape=shape, mean=0.0, stddev=stddev)
    return tf.Variable(initial, dtype=tf.float32)


def _bias_variable(shape):
    # смещения инициализируем нулями
    initial = tf.constant(0.0, shape=shape)
    return tf.Variable(initial, dtype=tf.float32)


class _WVA_SGD(tf.train.GradientDescentOptimizer):

    def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
        super(_WVA_SGD, self).__init__(learning_rate, use_locking, name)

    def minimize(self, loss, global_step=None, var_list=None,
                 gate_gradients=Optimizer.GATE_OP, aggregation_method=None,
                 colocate_gradients_with_ops=False, name=None,
                 grad_loss=None, impacts=None):
        """ comments """
        grads_and_vars = self.compute_gradients(
            loss, var_list=var_list, gate_gradients=gate_gradients,
            aggregation_method=aggregation_method,
            colocate_gradients_with_ops=colocate_gradients_with_ops,
            grad_loss=grad_loss)

        vars_with_grad = [v for g, v in grads_and_vars if g is not None]
        if not vars_with_grad:
            raise ValueError(
                "No gradients provided for any variable, check your graph for ops"
                " that do not support gradients, between variables %s and loss %s." %
                ([str(v) for _, v in grads_and_vars], loss))

        if impacts is None:
            processed_grads_and_vars = grads_and_vars
        else:
            impact = iter(impacts)
            processed_grads_and_vars = []
            for g, v in grads_and_vars:
                if g is None:
                    processed_grads_and_vars.append((g, v))
                else:
                    processed_grads_and_vars.append((tf.multiply(g, next(impact)), v))

        return self.apply_gradients(processed_grads_and_vars,
                                    global_step=global_step,
                                    name=name)


class Model:
    def __init__(self, shape, session):
        """
        :param shape:   структура сети - список из чисел нейронов в каждом слое сети
                        от входа к выходу справа налево, например, [784, 100, 10]
        :param session: tensorflow-сессия для расчетов сети
        """

        self.session = session
        self._shape = shape
        depth = len(shape) - 1
        if depth < 1:
            raise ValueError("Недопустимая структура сети!")

        # заглушки для входных данных
        self.x = tf.placeholder(tf.float32, shape=[None, shape[0]])
        self.labels = tf.placeholder(tf.float32, shape=[None, shape[-1]])

        # все веса слоев сети будем хранить в списке
        self.var_list = []
        for ins, outs in zip(shape[:-1], shape[1:]):
            self.var_list.append(_weight_variable([ins, outs]))
            self.var_list.append(_bias_variable([outs]))

        # инициализируем веса сети
        for v in self.var_list:
            session.run(v.initializer)

        # список для хранения важностей весов сети
        self.wb_importance = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]

        # строим вычислительный граф
        outputs = []
        x, y, z = self.x, None, None
        for i in range(depth):
            z = tf.matmul(x, self.var_list[i * 2]) + self.var_list[i * 2 + 1]
            y = tf.nn.softmax(z) if i == depth-1 else tf.nn.leaky_relu(z)
            outputs.append(y)
            x = y

        # функция стоимости
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z, labels=self.labels))

        # точность (accuracy)
        self.correct_preds = tf.equal(tf.argmax(z, axis=1), tf.argmax(self.labels, axis=1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_preds, tf.float32))

        # вычисляем суммарный по модулю прошедший сигнал
        self.signals = []

        os = tf.reduce_mean(tf.abs(self.x), axis=0)
        for i in range(depth):
            ws = tf.transpose(tf.multiply(os, tf.transpose(tf.abs(self.var_list[i*2]))))
            self.signals.append(ws)
            os = tf.reduce_mean(tf.abs(outputs[i]), axis=0)
            self.signals.append(os)

        self.train_step = None

    def open_lesson(self, learning_rate=1.0, lmbda=0.0):
        """
        Открытие урока обучения сети на отдельном датасете
        :param learning_rate: скорость обучения для SGD
        :param lmbda:         коэффициент влияния важностей - насколько сильно
                              важности тянут веса к эталонным значениям
        """
        impacts = [tf.constant(1. / (1. + lmbda * v)) for v in self.wb_importance]

        # устанавливаем шаг оптимизатора
        self.train_step = _WVA_SGD(learning_rate).minimize(self.loss, impacts=impacts)

    def close_lesson(self, closing_set=None):
        """
        Закрытие урока обучения сети на отдельном датасете. Расчет и накопление важностей весов.
        :param closing_set: датасет, на котором будут рассчитаны важности весов после обучения
        :return:
        """

        # рассчитываем важности весов на закрывающем датасете
        addendum = self.session.run(self.signals, feed_dict={self.x: closing_set})

        # добавляем рассчитанные важности к сохраненным
        for i, a in zip(self.wb_importance, addendum):
            i += a


# функция случайным образом переставляет входы одинаково для всех примеров датасета
def permute_mnist(mnist):
    perm_inds = list(range(mnist.train.images.shape[1]))
    np.random.shuffle(perm_inds)
    mnist2 = deepcopy(mnist)
    sets = ["train", "validation", "test"]
    for set_name in sets:
        this_set = getattr(mnist2, set_name)
        this_set._images = np.transpose(np.array([this_set.images[:, c] for c in perm_inds]))
    return mnist2


def train_model(model, train_set, test_sets, batch_size=100, epochs=1):
    """
    Обучение модели
    :param model:       обучаемая модель
    :param train_set:   обучающий датасет
    :param test_sets:   список датасетов, на которых будет считаться средняя точность
    :param batch_size:  размер батча
    :param epochs:      количество эпох обучения
    :return:            средняя точность на тестовых датасетах после обучения
    """
    num_iters = int(np.ceil(len(train_set.train.labels) * epochs / batch_size))
    for idx in range(num_iters):
        train_batch = train_set.train.next_batch(batch_size)
        feed_dict = {model.x: train_batch[0], model.labels: train_batch[1]}
        model.train_step.run(feed_dict=feed_dict)
        print(f'\rTraining  {idx + 1}/{num_iters} done.', end='')

    print(f'\rTraining  {num_iters}/{num_iters} iterations done.')

    accuracy = 0.
    for t, test_set in enumerate(test_sets):
        feed_dict = {model.x: test_set.test.images, model.labels: test_set.test.labels}
        accuracy += model.accuracy.eval(feed_dict=feed_dict)
    accuracy /= len(test_sets)
    print(f'Evaluating on {len(test_sets)} test sets done. Accuracy {accuracy}')
    return accuracy


def continual_learning(net_struct, data_sets, session, lr, lmbda):
    """
    Последовательное обучение на нескольких обучающих наборах
    :param net_struct: структура сети
    :param data_sets:  список обучающих датасетов для последовательного обучения
    :param session:    tf-сессия
    :param lr:         скорость обучения
    :param lmbda:      степень влияния важностей на обучение
    :return:           список усредненных по выученным датасетам оценок
    """
    model = Model(net_struct, session)
    test_sets = []
    accuracies = []
    for data_set in data_sets:
        test_sets.append(data_set)
        model.open_lesson(lr, lmbda)
        accuracy = train_model(model, data_set, test_sets, 100, 4)
        accuracies.append(accuracy)
        model.close_lesson(data_set.validation.images)
    del model
    return accuracies

# считываем данные MNIST
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# создаем tf-сессию
sess = tf.InteractiveSession()

# создаем 10 различных обучающих наборов для последовательного обучения
mnist0 = mnist
mnist1 = permute_mnist(mnist)
mnist2 = permute_mnist(mnist)
mnist3 = permute_mnist(mnist)
mnist4 = permute_mnist(mnist)
mnist5 = permute_mnist(mnist)
mnist6 = permute_mnist(mnist)
mnist7 = permute_mnist(mnist)
mnist8 = permute_mnist(mnist)
mnist9 = permute_mnist(mnist)

start_time = datetime.datetime.now()

# определим параметры обучения
data_sets = [mnist0, mnist1, mnist2, mnist3, mnist4, mnist5, mnist6, mnist7, mnist8, mnist9]
net_struct = [784, 300, 150, 10]
lmbda = 250.
learning_rate = 0.2

accuracies = continual_learning(net_struct, data_sets, sess, lr=learning_rate, lmbda=lmbda)
print ('Total time spent', datetime.datetime.now() - start_time)

dataset_num = range(1, len(accuracies) + 1)

# нарисуем график деградации средней точности на всех выученных датасетах
plt.figure(figsize=(7, 3.5))
plt.ylim(0.40, 1.)
plt.xlim(1, len(accuracies))
plt.ylabel('Total accuracy')
plt.xlabel('Number of tasks')
plt.plot(dataset_num, accuracies, marker=".")
#plt.legend()
plt.show()

Резюмируем. Если вы планируете учить вашу нейросетку последовательно и разному, и у вас мало времени и памяти, то стоит использовать WVA. Если времени мало, но памяти завались, то стоит посмотреть на EWC-S. Если времени вагон, а памяти мало, то стоит важности весов рассчитывать как в MAS, а использовать как в WVA, то есть сделать гибрид WVA-MAS. Если есть и время, и память, и требуется наилучшее сохранение навыков без компромиссов, и код писать ну очень лениво, то стоит использовать полноценный MAS.

P.S. Подозреваю, что MAS будет выбираться чаще всего именно по последней причине…

P.P.S. У всех перечисленных методов есть одна тонкость – они работают только если для каждого выхода сети в каждом датасете для последовательного обучения есть примеры его (выход) активирующие. Если датасет содержит примеры, активирующие только часть выходов, нужно применять специальные трюки (считать функцию потерь только на активируемых в датасете выходах). Подробности можно посмотреть в статье про Synaptic Intelligence – см. Split MNIST.

Теги:
Хабы:
+14
Комментарии8

Публикации

Истории

Работа

Data Scientist
63 вакансии
Python разработчик
142 вакансии

Ближайшие события