Pull to refresh

Некоторые аспекты качества обучающих последовательностей

Reading time8 min
Views2.1K
На Хабре появился ряд статей о качестве образования и как процесса и как результата (уровень выпускников).

Тема заинтересовала и руки зачесались проверить, а как это устроено у пчелок роботов искусственного интеллекта, влияет ли качество обучающей последовательности на результат.

Была выбрана простая сеть из примеров Keras в которую добавил одну строку. Нас интересует насколько упорядоченность входной обучающей последовательности mnist влияет на результат обучения MLP.

Результат получился неожиданным и странным, пришлось перепроверять многократно, но перейдем к делу и конкретике.

Идея эксперимента проста и обычна — обучаем MLP из keras на общедоступном mnist и получаем ориентир, после обучаем на последовательностях 01234567890123..7890123. Как студентов учат — немного бейсик, немного ассемблер, немного fortran, и т.д. и сравним с исходным обучением. Результат вполне ожидаем, исходная последовательность учит лучше, но порядок такой же. Вот график 64 испытаний


А теперь будем учить сеть так, подаем все картинки с «0», потом все «1», потом все «2» и так до «9» и результат получается никакой!, сеть просто не учится. Интуитивно ожидаешь результат сопоставимый, хуже или лучше — это уже детали, но вот таблица результатов обучения 64 раза

Натив и эксперименты
step 0
('Test accuracy:', 0.9708)
('Test accuracy:', 0.97689999999999999)
('Test accuracy:', 0.1009)
step 1
('Test accuracy:', 0.97689999999999999)
('Test accuracy:', 0.97219999999999995)
('Test accuracy:', 0.1009)
step 2
('Test accuracy:', 0.97330000000000005)
('Test accuracy:', 0.97609999999999997)
('Test accuracy:', 0.1028)
step 3
('Test accuracy:', 0.97040000000000004)
('Test accuracy:', 0.97160000000000002)
('Test accuracy:', 0.1135)
step 4
('Test accuracy:', 0.97370000000000001)
('Test accuracy:', 0.97050000000000003)
('Test accuracy:', 0.098199999999999996)
step 5
('Test accuracy:', 0.96999999999999997)
('Test accuracy:', 0.96909999999999996)
('Test accuracy:', 0.1009)
step 6
('Test accuracy:', 0.97589999999999999)
('Test accuracy:', 0.97540000000000004)
('Test accuracy:', 0.1028)
step 7
('Test accuracy:', 0.97360000000000002)
('Test accuracy:', 0.97350000000000003)
('Test accuracy:', 0.1135)
step 8
('Test accuracy:', 0.97740000000000005)
('Test accuracy:', 0.97109999999999996)
('Test accuracy:', 0.1135)
step 9
('Test accuracy:', 0.97260000000000002)
('Test accuracy:', 0.97089999999999999)
('Test accuracy:', 0.1135)
step 10
('Test accuracy:', 0.96930000000000005)
('Test accuracy:', 0.9708)
('Test accuracy:', 0.1028)
step 11
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.97099999999999997)
('Test accuracy:', 0.1135)
step 12
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.97519999999999996)
('Test accuracy:', 0.1009)
step 13
('Test accuracy:', 0.97719999999999996)
('Test accuracy:', 0.97370000000000001)
('Test accuracy:', 0.1135)
step 14
('Test accuracy:', 0.97489999999999999)
('Test accuracy:', 0.97189999999999999)
('Test accuracy:', 0.1135)
step 15
('Test accuracy:', 0.9758)
('Test accuracy:', 0.97219999999999995)
('Test accuracy:', 0.10489999999999999)
step 16
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.97529999999999994)
('Test accuracy:', 0.1135)
step 17
('Test accuracy:', 0.97819999999999996)
('Test accuracy:', 0.97170000000000001)
('Test accuracy:', 0.1009)
step 18
('Test accuracy:', 0.97850000000000004)
('Test accuracy:', 0.97260000000000002)
('Test accuracy:', 0.1009)
step 19
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.97589999999999999)
('Test accuracy:', 0.0974)
step 20
('Test accuracy:', 0.97699999999999998)
('Test accuracy:', 0.97319999999999995)
('Test accuracy:', 0.1135)
step 21
('Test accuracy:', 0.97309999999999997)
('Test accuracy:', 0.97260000000000002)
('Test accuracy:', 0.1009)
step 22
('Test accuracy:', 0.97560000000000002)
('Test accuracy:', 0.97519999999999996)
('Test accuracy:', 0.1135)
step 23
('Test accuracy:', 0.97619999999999996)
('Test accuracy:', 0.97450000000000003)
('Test accuracy:', 0.1009)
step 24
('Test accuracy:', 0.97689999999999999)
('Test accuracy:', 0.97430000000000005)
('Test accuracy:', 0.1028)
step 25
('Test accuracy:', 0.97609999999999997)
('Test accuracy:', 0.97599999999999998)
('Test accuracy:', 0.1135)
step 26
('Test accuracy:', 0.97840000000000005)
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.1028)
step 27
('Test accuracy:', 0.96909999999999996)
('Test accuracy:', 0.97019999999999995)
('Test accuracy:', 0.1135)
step 28
('Test accuracy:', 0.9738)
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.1009)
step 29
('Test accuracy:', 0.97460000000000002)
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.1135)
step 30
('Test accuracy:', 0.97640000000000005)
('Test accuracy:', 0.97170000000000001)
('Test accuracy:', 0.1042)
step 31
('Test accuracy:', 0.97409999999999997)
('Test accuracy:', 0.95650000000000002)
('Test accuracy:', 0.089200000000000002)
step 32
('Test accuracy:', 0.97689999999999999)
('Test accuracy:', 0.97109999999999996)
('Test accuracy:', 0.1135)
step 33
('Test accuracy:', 0.97370000000000001)
('Test accuracy:', 0.97340000000000004)
('Test accuracy:', 0.1009)
step 34
('Test accuracy:', 0.97699999999999998)
('Test accuracy:', 0.97150000000000003)
('Test accuracy:', 0.1135)
step 35
('Test accuracy:', 0.97250000000000003)
('Test accuracy:', 0.97140000000000004)
('Test accuracy:', 0.1009)
step 36
('Test accuracy:', 0.97589999999999999)
('Test accuracy:', 0.96950000000000003)
('Test accuracy:', 0.1055)
step 37
('Test accuracy:', 0.97519999999999996)
('Test accuracy:', 0.96509999999999996)
('Test accuracy:', 0.1135)
step 38
('Test accuracy:', 0.97299999999999998)
('Test accuracy:', 0.9728)
('Test accuracy:', 0.1028)
step 39
('Test accuracy:', 0.96909999999999996)
('Test accuracy:', 0.97240000000000004)
('Test accuracy:', 0.1009)
step 40
('Test accuracy:', 0.97399999999999998)
('Test accuracy:', 0.96479999999999999)
('Test accuracy:', 0.1135)
step 41
('Test accuracy:', 0.97799999999999998)
('Test accuracy:', 0.97319999999999995)
('Test accuracy:', 0.1135)
step 42
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.96340000000000003)
('Test accuracy:', 0.1009)
step 43
('Test accuracy:', 0.97740000000000005)
('Test accuracy:', 0.97170000000000001)
('Test accuracy:', 0.1009)
step 44
('Test accuracy:', 0.97160000000000002)
('Test accuracy:', 0.97389999999999999)
('Test accuracy:', 0.1135)
step 45
('Test accuracy:', 0.97599999999999998)
('Test accuracy:', 0.97360000000000002)
('Test accuracy:', 0.1033)
step 46
('Test accuracy:', 0.97389999999999999)
('Test accuracy:', 0.97019999999999995)
('Test accuracy:', 0.1135)
step 47
('Test accuracy:', 0.97650000000000003)
('Test accuracy:', 0.97619999999999996)
('Test accuracy:', 0.10290000000000001)
step 48
('Test accuracy:', 0.97409999999999997)
('Test accuracy:', 0.9647)
('Test accuracy:', 0.1009)
step 49
('Test accuracy:', 0.97240000000000004)
('Test accuracy:', 0.97450000000000003)
('Test accuracy:', 0.1135)
step 50
('Test accuracy:', 0.97570000000000001)
('Test accuracy:', 0.97040000000000004)
('Test accuracy:', 0.1135)
step 51
('Test accuracy:', 0.97250000000000003)
('Test accuracy:', 0.97219999999999995)
('Test accuracy:', 0.1135)
step 52
('Test accuracy:', 0.97230000000000005)
('Test accuracy:', 0.97309999999999997)
('Test accuracy:', 0.1135)
step 53
('Test accuracy:', 0.9758)
('Test accuracy:', 0.97230000000000005)
('Test accuracy:', 0.1135)
step 54
('Test accuracy:', 0.97770000000000001)
('Test accuracy:', 0.97260000000000002)
('Test accuracy:', 0.089200000000000002)
step 55
('Test accuracy:', 0.97340000000000004)
('Test accuracy:', 0.96919999999999995)
('Test accuracy:', 0.1135)
step 56
('Test accuracy:', 0.97170000000000001)
('Test accuracy:', 0.97070000000000001)
('Test accuracy:', 0.1028)
step 57
('Test accuracy:', 0.97670000000000001)
('Test accuracy:', 0.97330000000000005)
('Test accuracy:', 0.1135)
step 58
('Test accuracy:', 0.97589999999999999)
('Test accuracy:', 0.97370000000000001)
('Test accuracy:', 0.1033)
step 59
('Test accuracy:', 0.9748)
('Test accuracy:', 0.97419999999999995)
('Test accuracy:', 0.10290000000000001)
step 60
('Test accuracy:', 0.97409999999999997)
('Test accuracy:', 0.97099999999999997)
('Test accuracy:', 0.1009)
step 61
('Test accuracy:', 0.9758)
('Test accuracy:', 0.97450000000000003)
('Test accuracy:', 0.1135)
step 62
('Test accuracy:', 0.97529999999999994)
('Test accuracy:', 0.97260000000000002)
('Test accuracy:', 0.1028)
step 63
('Test accuracy:', 0.97240000000000004)
('Test accuracy:', 0.96809999999999996)
('Test accuracy:', 0.1135)

Не знаю как у людей, но тут чисто только MLP и получается, что обучать её можно не абы как, не на всех последовательностях.

Обучать ИИ так, как обучают в некоторых местах: сначала только бейсик, потом только фортран, потом только ассемблер и т.д. не приведет к успеху. / :-) / Если выявленная особенность присуща всем процессам обучения, как людям так и роботам, то все программы вузов нужно тщательно исследовать.

Исходный текст
from keras import backend as K_B
from keras.datasets import mnist
from keras.layers import Input, Dense, Dropout
from keras.models import Sequential
from keras.optimizers import RMSprop
from keras.utils import np_utils
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

def MLP(ind): 

    model = Sequential()
    model.add(Dense(512, activation='relu', input_shape=(width * height,)))
    model.add(Dropout(0.2))
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(num_classes, activation='softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer=RMSprop(),
                  metrics=['accuracy'])

    if (ind == 0):  # начальные веса в каждой тройке испытаний должны быть одинаковыми
        model.save_weights('weights.h5') # натив, сохраняем веса
    else:
        model.load_weights('weights.h5', by_name = False) # эксперимент, восстанавливаем веса

    history = model.fit(X_train, Y_train,
                        shuffle = False,     # добавлена эта строка что бы запретить keras перемешивать батч
                        batch_size=batch_size,
                        epochs=epochs,
                        verbose=0,
                        validation_data=(X_test, Y_test))
    score = model.evaluate(X_test, Y_test, verbose=0)
    print('Test accuracy:', score[1])
    K_B.clear_session() 
# сессию уничтожаем в возрождаем преднамеренно, что бы начальные веса отличались 

    return(score[1])

batch_size = 12
epochs = 12
hidden_size = 512

(X_train, y_train), (X_test, y_test) = mnist.load_data()

num_train, width, height = X_train.shape
num_test = X_test.shape[0]
num_classes = np.unique(y_train).shape[0]

X_train = X_train.astype('float32') 
X_test = X_test.astype('float32')
X_train /= 255.
X_test /= 255.
X_train = X_train.reshape(num_train, height * width)
X_test = X_test.reshape(num_test, height * width)

XX_train = np.copy(X_train)
yy_train = np.copy(y_train)

Y_train = np_utils.to_categorical(y_train, num_classes)
Y_test = np_utils.to_categorical(y_test, num_classes)

steps = 64
st = np.arange(steps, dtype='int')
res_N = np.arange((steps), dtype='float')
res_1 = np.arange((steps), dtype='float')
res_2 = np.arange((steps), dtype='float')

for n in xrange(steps):    
# __  натив

    X_train = np.copy(XX_train)
    y_train = np.copy(yy_train)
    Y_train = np_utils.to_categorical(y_train, num_classes)
    print ' step ', n
    res_N[n] = MLP(0)

# __ 00..0011..1122..2233.. .. 8899..99
    perm = np.arange(num_train, dtype='int')
    cl = np.zeros(num_classes, dtype='int')

    for k in xrange(num_train):
        if (cl[yy_train[k]] * num_classes + yy_train[k] < num_train):
            perm[ cl[yy_train[k]] * num_classes + yy_train[k] ] = k
            cl[yy_train[k]] += 1

    for k in xrange(num_train):
        X_train[k,...] = XX_train[perm[k],...]
    for k in xrange(num_train):
        y_train[k] = yy_train[perm[k]]
    Y_train = np_utils.to_categorical(y_train, num_classes)
    res_2[n] = MLP(2)
# __ 0123..78901..7890123..789
    perm = np.arange(num_train, dtype='int')
    j = 0
    for k in xrange(num_classes):
        for i in xrange(num_train):
            if (yy_train[i] == k):
                perm[j] = i
                j += 1
    for k in xrange(num_train):
        X_train[k,...] = XX_train[perm[k],...]
        y_train[k] = yy_train[perm[k]]
    Y_train = np_utils.to_categorical(y_train, num_classes) 
    res_1[n] = MLP(1)


Другие типы сетей не проверял и эта проверка занимает на моей не очень теслистой тесле много часов.
Tags:
Hubs:
+9
Comments11

Articles

Change theme settings