Pull to refresh

Как я обучал нейросеть для реализации функции оценки положения на Russian AI Cup CodeBall 2018

Reading time 12 min
Views 4.2K
Имея возможность качественно оценить положение в игре в какой-то момент времени и возможность симулировать игровой мир, при создании бота, для одного из решений, остается лишь стремиться совершать такие действия, которые приводят к улучшению этой оценки в ближайшем будущем.

Функция оценки положения — возвращает вещественное значение где меньшее означает худшее. На вход такой функции я подавал только положение и вектор скорости мяча. Изначально эта функция была реализована довольно простыми формулами и парой if-ов. Однако это дало хорошую основу для накрутки на localrunner-е множества логов для последующего обучения нейросети. Так я прокрутил 300 игр (по 18000 тиков) локально, что в сумме дало около 12ГБ логов и плюс к этому 145 логов игр топов было скачано с сервера (5.7гб).

Далее нужно было выделить из этих логов обучающую и тестовую выборки. Делал я это следующим образом: отталкиваясь от забитого гола смотрел в «прошлое» на 300 тиков (5 секунд) и шагом в 5 тиков каждое положение и скорость мяча + эталонную оценку брал за пример.

Важный момент: эталонная оценка (выход) здесь вычислялась по формуле

$O = S/exp(T/60)$

где S = -1 если мяч залетает в «мои» ворота и 1 в обратном случае, а T это время в тиках оставшееся до гола.

Еще один менее важный, но тоже момент: поле игры симметрично и соответственно эталонная оценка тоже должна быть обратно симметричной если смотреть с точки зрения противника. Т.е. если что-то оценивается с «моей» точки зрения как X то тоже самое положение должно оцениваться с точки зрения противника как -X. Это означает что если «сложить пополам» все пространство входа нейросети по любому параметру то сеть будет обучаться лучше, условно говоря, «в 2 раза», а главное она будет выдавать гарантированно обратно симметричный ответ (что есть, как минимум, просто красиво). Я «складывал» по скорости мяча по оси Z. Проще говоря, если мяч летит от «моих» ворот то смотрю со «своей» точки зрения, иначе — с точки зрения противника. Получается что для нейросети мяч всегда летит в положительную сторону по Z. Точно так же можно поступить и для продольной симметрии (по оси X), правда в этом случае продолжаем смотреть с точки зрения той же команды, но, как бы, в зеркале расположенном в плоскости с нормалью (1, 0, 0).

Итак, вот код подготовки тестовой и обучающей выборки из логов на Python:

import json
from pprint import pprint
import glob
import numpy as np
import random

xtrain = []
ytrain = []

xtest = []
ytest = []

f1 = r"F:\Home\Projects\MailRuAI\Codeball2018\LocalRunner\logs_archive\logs_01/*.txt"
f2 = r"F:\Home\Projects\MailRuAI\Codeball2018\LocalRunner\logs_archive\logs_02/*.txt"
f3 = r"F:\Home\Projects\MailRuAI\Codeball2018\LocalRunner\logs_archive\logs_03/*.txt"
f7 = r"F:\Home\Projects\MailRuAI\Codeball2018\downloaded_games/*.txt"
for file in (glob.glob(f1) + glob.glob(f2) + glob.glob(f3) + glob.glob(f7)):

	with open(file) as f:
		content = f.readlines()

	print(len(content))
	print(file)
	sumofscores = 0
	lastscore0 = 0
	lastscore1 = 0
	ticksbackward = 300
	ticksbackwardinc = 5
	for x in range(0, len(content)):
		data = json.loads(content[x])
		if "scores" in data and sum(data["scores"]) > sumofscores:
			sumofscores = sum(data["scores"])
			value = 0

			if data["scores"][0] > lastscore0:
				lastscore0 = data["scores"][0]
				value = 1
			if data["scores"][1] > lastscore1:
				lastscore1 = data["scores"][1]
				value = -1

			for y in range(ticksbackwardinc, ticksbackward, ticksbackwardinc):
				dataY = json.loads(content[x - y])
				if "scores" in dataY and sum(dataY["scores"]) == sumofscores - 1:
					sign = 1
					if dataY['ball']['velocity']['z'] < 0:
						sign = -1
					signX = 1
					if dataY['ball']['velocity']['x'] * sign < 0:
						signX = -1
					inputs = np.zeros(6)

					inputs[0] = dataY['ball']['velocity']['x'] * sign * signX
					inputs[1] = dataY['ball']['velocity']['y'] 
					inputs[2] = dataY['ball']['velocity']['z'] * sign
					inputs[3] = dataY['ball']['position']['x'] * sign * signX
					inputs[4] = dataY['ball']['position']['y'] 
					inputs[5] = dataY['ball']['position']['z'] * sign

					outputs = np.zeros(2)
					outputs[0] = value*sign
					outputs[1] = y
					if (random.random() > 0.2):
						xtrain.append(inputs)
						ytrain.append(outputs)
					else:
						xtest.append(inputs)
						ytest.append(outputs)
				else:
					print("exceeded")

	print(len(xtrain))
	print(len(xtest))

np.save("F:/Home/Projects/MailRuAI/Codeball2018/nnet/xtrain_BR.npy", np.asarray(xtrain))
np.save("F:/Home/Projects/MailRuAI/Codeball2018/nnet/ytrain_BR.npy", np.asarray(ytrain))
np.save("F:/Home/Projects/MailRuAI/Codeball2018/nnet/xtest_BR.npy", np.asarray(xtest))
np.save("F:/Home/Projects/MailRuAI/Codeball2018/nnet/ytest_BR.npy", np.asarray(ytest))


Самые внимательные, наверно, уже заметили, что outputs содержит два выхода и вообще не то что я выше описал, но не переживайте, это рудимент и далее следует преобразование перед самим обучением:

import numpy as np
from keras.datasets import boston_housing
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Concatenate, Add
import random
import datetime
np.set_printoptions(edgeitems=50)

xtrain = np.load("F:/Home/Projects/MailRuAI/Codeball2018/nnet/xtrain_BR.npy")
ytrain = np.load("F:/Home/Projects/MailRuAI/Codeball2018/nnet/ytrain_BR.npy")
xtest = np.load("F:/Home/Projects/MailRuAI/Codeball2018/nnet/xtest_BR.npy")
ytest = np.load("F:/Home/Projects/MailRuAI/Codeball2018/nnet/ytest_BR.npy")

ytrain = np.exp(-(ytrain[:,1])/60) * ytrain[:,0]
ytest = np.exp(-(ytest[:,1])/60) * ytest[:,0]

inp = Input(shape=(xtrain.shape[1],))
d1 = Dense(6, activation='relu')(inp)
d2 = Dense(6, activation='linear')(inp)
d3 = Dense(6, activation='sigmoid')(inp)
added = Concatenate()([d1, d2, d3])
d21 = Dense(3, activation='relu')(added)
d22 = Dense(3, activation='linear')(added)
d23 = Dense(3, activation='sigmoid')(added)
added2 = Concatenate()([d21, d22, d23])
d31 = Dense(3, activation='relu')(added2)
d32 = Dense(3, activation='linear')(added2)
d33 = Dense(3, activation='sigmoid')(added2)
added3 = Concatenate()([d31, d32, d33])
out = Dense(1)(added3)
model = Model(inputs=inp, outputs=out)
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

#model.load_weights("F:/Home/Projects/MailRuAI/Codeball2018/nnet/WEXP_B36F.dat")

for x in range(0, 10):
	lostTR, maeTR = model.evaluate(xtrain, ytrain, verbose=0)
	print("Train mae: " + repr(lostTR) + ", " + repr(maeTR))
	lostTS, maeTS = model.evaluate(xtest, ytest, verbose=0)
	print("Test mae:  " + repr(lostTS) + ", " + repr(maeTS))
	while True:

		model.fit(xtrain, ytrain, epochs=1, batch_size=1, verbose=2)

		print("Aim: " + repr(lostTS))
		lostTR2, maeTR2 = model.evaluate(xtrain, ytrain, verbose=0)
		print("Train mae: " + repr(lostTR2) + ", " + repr(maeTR2))
		lostTS2, maeTS2 = model.evaluate(xtest, ytest, verbose=0)
		print("Test mae:  " + repr(lostTS2) + ", " + repr(maeTS2))
		print("Improve number: " + repr(x))
		print(datetime.datetime.now())
		if lostTS > lostTS2:
			print ("imporoved")
			model.save_weights("F:/Home/Projects/MailRuAI/Codeball2018/nnet/WEXP_B36F.dat")
			break

Почему именно 3 внутренних слоя и именно такой конфигурации не спрашивайте — сам не знаю. Однако интуиция и дни опытов привели именно к ней.

И наконец последний вопрос, как использовать уже обученную на Python нейросеть в C# не имея никаких готовых классов? Создать класс! При такой простой конфигурации нейросети и учитывая что нам требуется реализация лишь функции «predict» (т.е. просто прогонка от входа к выходу) это довольно просто. Вот она:

public enum Activation { relu, linear, sigmoid };
    public class layer
    {
        public int Count = 0;
        public List<List<double>> weights = new List<List<double>>();
        public List<double> Ps = new List<double>();
        public List<Activation> funcs = new List<Activation>();
        public List<double> Values = new List<double>();
        public void Add(Activation aact)
        {
            Count++;
            weights.Add(new List<double>());
            Ps.Add(0);
            funcs.Add(aact);
            Values.Add(0);
        }

        public void Add(Activation aact, int acnt)
        {
            for (int i = 0; i < acnt; i++)
                Add(aact);
        }

        public void Calculate(List<double> ainps)
        {
            for (int i = 0; i < Count; i++)
            {
                Values[i] = Ps[i];
                for (int j = 0; j < ainps.Count; j++)
                    Values[i] += weights[i][j] * ainps[j];

                switch (funcs[i])
                {
                    case Activation.linear:
                        break;
                    case Activation.relu:
                        Values[i] = System.Math.Max(0, Values[i]);
                        break;
                    case Activation.sigmoid:
                        Values[i] = (double)(1.0 / (1.0 + System.Math.Exp(-Values[i])));
                        break;
                }
            }
        }
    }

    public class nnet
    {
        public int inputCount = 0;
        public List<layer> layers = new List<layer>();
        public layer outputLayer = null;
        public nnet(int ainputcount, int aoutputcount)
        {
            inputCount = ainputcount;
            outputLayer = new layer();
            outputLayer.Add(Activation.linear, aoutputcount);
        }

        public List<double> predict(List<double> ainput)
        {
            for (int i = 0; i < layers.Count + 1; i++)
            {
                List<double> inps = ainput;
                if (i > 0)
                    inps = layers[i - 1].Values;
                layer lr = outputLayer;
                if (i < layers.Count)
                    lr = layers[i];

                lr.Calculate(inps);
            }
            return outputLayer.Values;
        }
    }

Остается только подтянуть веса из обученной сети (кстати, веса привожу здесь реально работающие в моей последней версии):

    public class trained_nnet : nnet
    {
        void FillLayer(layer al, double[] atp, double[,] atw)
        {
            al.Ps.Clear();
            al.Ps.AddRange(atp);

            al.weights.Clear();
            for (int i = 0; i < atw.GetLength(0); i++)
            {
                al.weights.Add(new List<double>());
                for (int j = 0; j < atw.GetLength(1); j++)
                {
                    al.weights[i].Add(atw[i, j]);
                }
            }
        }

        public trained_nnet() : base(6, 1)
        {
            layer lr1 = new layer();
            lr1.Add(Activation.relu, 6);
            lr1.Add(Activation.linear, 6);
            lr1.Add(Activation.sigmoid, 6);
            base.layers.Add(lr1);

            layer lr2 = new layer();
            lr2.Add(Activation.relu, 3);
            lr2.Add(Activation.linear, 3);
            lr2.Add(Activation.sigmoid, 3);
            base.layers.Add(lr2);

            layer lr3 = new layer();
            lr3.Add(Activation.relu, 3);
            lr3.Add(Activation.linear, 3);
            lr3.Add(Activation.sigmoid, 3);
            base.layers.Add(lr3);

            double[] t = { 3.6843767166137695, -9.454026222229004, -5.089229106903076, -2.850287437438965, -6.96286153793335, -9.751116752624512, 10.384811401367188, -4.214056968688965, 1.2072025537490845, 1.4019242525100708, -0.13174889981746674, -13.1264066696167, -4.265004634857178, 1.8926845788955688, -0.0813497006893158, -1.4616785049438477, -5.361510753631592, -1.1896661520004272 };
            double[,] t2 = { { 0.1477939784526825, 0.03613739833235741, -0.09796690940856934, 1.942456841468811, -0.3508949875831604, -0.5551134347915649 }, { -0.25495094060897827, 0.049018844962120056, -0.15976546704769135, -1.881699562072754, -1.3928385972976685, 0.017490295693278313 }, { 0.314727246761322, -0.7985705733299255, -0.16902890801429749, 0.7290273308753967, -3.3613057136535645, -0.501738965511322 }, { -0.14706645905971527, 0.013889106921851635, -8.41325855255127, 0.08269797265529633, -0.8194255232810974, 0.054869525134563446 }, { -0.11769858002662659, 0.024719441309571266, -32.9736213684082, -0.06565750390291214, -0.38925793766975403, -0.30816638469696045 }, { -0.09536012262105942, -0.4411015212535858, -0.3092011511325836, 0.061532989144325256, -1.3718899488449097, -0.9904148578643799 }, { 0.03862301632761955, -0.2239271104335785, -0.3054073452949524, 0.013336590491235256, -0.0404842384159565, -0.09027290344238281 }, { -0.317527711391449, -0.14433158934116364, 0.06079907342791557, -0.4572157561779022, 0.2782846987247467, 0.17747753858566284 }, { 0.01980031281709671, 0.015361669473350048, -0.03606397658586502, 0.013219496235251427, -0.03483833745121956, -0.01729537360370159 }, { -0.003958317916840315, 0.09587077051401138, -0.08213665336370468, -0.027169639244675636, 0.032037656754255295, -0.030492693185806274 }, { -0.04885690286755562, -0.06349656730890274, 0.013905149884521961, 0.018028201535344124, 0.012719585560262203, 0.002531017642468214 }, { 0.016520477831363678, -0.00018591046682558954, -0.003657651599496603, 0.06888063997030258, -0.2127065807580948, 0.6427022218704224 }, { -0.5308891534805298, 0.13539844751358032, 0.03864796832203865, 1.5582681894302368, -1.929693341255188, -3.2511842250823975 }, { 0.032178860157728195, 1.1472656726837158, -2.020042896270752, -0.05141841620206833, -0.4635908901691437, 0.2636871039867401 }, { 0.01480827759951353, 0.33971744775772095, -0.15343432128429413, 0.03558071702718735, 3.364596366882324, -0.7852638959884644 }, { 0.0028303645085543394, 1.2297841310501099, -0.4412313997745514, 0.3644706606864929, 2.2155861854553223, -0.43303439021110535 }, { -0.3666411340236664, 0.0464097335934639, 5.143652439117432, -2.2230076789855957, 0.3511424660682678, 1.0514445304870605 }, { 0.014482858590781689, -0.4740144610404968, -1.6240901947021484, 1.7327706813812256, -1.5116417407989502, -1.6811648607254028 } };
            double[] t3 = { -3.09689998626709, -1.2031112909317017, -7.121585369110107, 2.0653932094573975, -2.8601508140563965, -1.6219528913497925, 0.16301754117012024, -6.890131950378418, 3.8225107192993164 };
            double[,] t4 = { { -0.6246452927589417, -0.3575346767902374, 0.6897052526473999, -2.2513232231140137, -0.23217444121837616, 0.17847181856632233, -0.3863859176635742, -0.01201619766652584, 0.050539981573820114, 0.028343766927719116, 0.0034856200218200684, 0.5547005534172058, -0.4277774691581726, -1.0249099731445312, -8.995088577270508, -3.4937169551849365, 0.7673622369766235, -1.6504380702972412 }, { -1.0006977319717407, -0.8660659790039062, -0.0415676049888134, -0.5476861000061035, -0.7828258872032166, -0.05350146442651749, 0.005586389917880297, -0.052493464201688766, 0.07955628633499146, -0.08084911853075027, 0.09794406592845917, -0.031214063987135887, -0.7785998582839966, -0.27977627515792847, -0.4096711277961731, -0.24633635580539703, -1.5932326316833496, -0.5430923104286194 }, { -0.2330777496099472, -0.07477551698684692, -1.0634428262710571, -1.772096872329712, -1.4657013416290283, 0.6256936192512512, -0.1179097518324852, 0.07645376771688461, 0.008837736211717129, 0.030952733010053635, -0.013960030861198902, 1.0339184999465942, 0.20350944995880127, -0.047291483730077744, -4.043337345123291, -0.7629795670509338, -5.41167688369751, -3.7755305767059326 }, { 0.00979659240692854, 0.11435728520154953, -0.4749748706817627, 1.5166815519332886, -5.3047380447387695, 0.9597445130348206, 0.08123911172151566, 0.039479970932006836, -0.01649349369108677, -0.04941410943865776, 0.020120851695537567, -0.16329358518123627, 0.36106961965560913, 0.5348165035247803, 0.11825983971357346, 0.2075480818748474, -1.8661850690841675, 1.4093444347381592 }, { -0.35534173250198364, 0.3471201956272125, -0.2657061517238617, -2.4178225994110107, -3.890836238861084, 0.5999298691749573, -0.10068143904209137, 0.530009388923645, 0.023632165044546127, -0.006245455238968134, 0.031124670058488846, 0.016797777265310287, 1.720144510269165, -0.3200121223926544, 0.17827671766281128, -1.0847045183181763, 0.7679504156112671, 1.1521148681640625 }, { 0.047243088483810425, -0.07313758134841919, -0.13496115803718567, -1.0498348474502563, -2.083388328552246, 0.3018227815628052, 0.019016921520233154, 0.00780009850859642, -0.02416112646460533, -0.012299800291657448, 0.019720694050192833, 0.019809948280453682, -1.637327790260315, 0.09307140856981277, 2.963168144226074, 0.515803337097168, 0.02399904653429985, -3.9851980209350586 }, { -0.6250298023223877, -0.4796958863735199, 0.4311320185661316, -1.4590528011322021, -4.861763000488281, -1.1894060373306274, 0.31154727935791016, -0.028901753947138786, 0.07241783291101456, 0.0573900043964386, -0.16387903690338135, -0.7621306777000427, 2.864539623260498, 1.126343011856079, -0.729159414768219, 15.2516450881958, -0.5845442414283752, -0.2593745291233063 }, { -0.4520488679409027, -0.37348034977912903, -0.22873088717460632, 2.816544532775879, 0.635391891002655, 1.7192658185958862, -0.042334891855716705, -0.012391769327223301, -0.00944773480296135, -0.047271229326725006, 0.045244403183460236, 1.1044175624847412, -2.682516098022461, -1.797003984451294, -5.227936744689941, 0.3994572162628174, -3.361297130584717, -0.16535422205924988 }, { 1.3437395095825195, 0.05596136301755905, -0.6534030437469482, -3.2173333168029785, -3.256056785583496, 3.164973020553589, -0.6149216294288635, 0.3425371050834656, -0.13111716508865356, -0.42127469182014465, -0.0668950155377388, 0.19484268128871918, 2.005012273788452, -3.41219425201416, -0.3146309554576874, -2.1181774139404297, 2.2965285778045654, 5.287317276000977 } };
            double[] t5 = { -1.173705816268921, -1.8888208866119385, -2.566594123840332, 0.1278465986251831, 0.05948356166481972, -0.021375492215156555, -1.554726243019104, -2.2256762981414795, 1.3142614364624023 };
            double[,] t6 = { { -0.023421021178364754, 0.17735084891319275, -0.1922600418329239, -0.11634820699691772, 0.05003879591822624, 0.07409390062093735, -0.131203755736351, -0.11743484437465668, -1.1311017274856567 }, { -0.6256148219108582, -0.08678799867630005, 0.08910120278596878, -0.06354714930057526, 0.05225379019975662, 0.028936704620718956, -2.069547176361084, 0.16652414202690125, 0.4840211570262909 }, { -0.9266191720962524, 0.1542767435312271, -1.511458396911621, -2.2593629360198975, 0.32768234610557556, 0.728438138961792, 1.4113644361495972, -2.9423279762268066, -1.1225157976150513 }, { -0.31864309310913086, -0.06739992648363113, 1.8643943071365356, 0.12609687447547913, 0.003282073885202408, -0.08565603941679001, 0.22951357066631317, -3.9096572399139404, -0.5148558020591736 }, { 0.0030701414216309786, 0.22653144598007202, -0.1772366166114807, 0.01472154725342989, 0.006688127294182777, 0.029435427859425545, -0.049562305212020874, -0.01126908790320158, -0.09357477724552155 }, { -0.003160204039886594, 0.004133348818868399, 0.003914407920092344, 0.013578329235315323, 0.0036796496715396643, 0.028364477679133415, 0.025828130543231964, -0.030584659427404404, -0.0449080727994442 }, { -0.15649960935115814, 0.7045242786407471, 4.971825122833252, 0.26150253415107727, 0.25615766644477844, -0.007457265630364418, 0.4002840220928192, -4.386100769042969, -0.14405106008052826 }, { -1.283564805984497, -1.0451316833496094, -9.010445594787598, -0.23629669845104218, 0.8792487978935242, 0.12951965630054474, 2.7414908409118652, -10.04093074798584, 0.08805646747350693 }, { 0.5142691731452942, 0.27933982014656067, 17.242839813232422, -0.14753387868404388, 0.35601550340652466, -0.03304799646139145, -0.3745580017566681, 3.6696081161499023, 0.18306805193424225 } };
            double[] t7 = { 0.057645831257104874 };
            double[,] t8 = { { 0.02502649463713169, 0.030625218525528908, -0.04921620339155197, -0.06382419914007187, -0.0018273837631568313, -0.002946096006780863, -0.3073849678039551, -0.0770358145236969, 0.44145819544792175 } };

            FillLayer(lr1, t, t2);
            FillLayer(lr2, t3, t4);
            FillLayer(lr3, t5, t6);
            FillLayer(outputLayer, t7, t8);

        }
    }

Вызов нейросети:

        public double StateRatingByNNet()
        {
            double result = 0;
            List<double> xdata = new List<double>();
            double sign = 1;
            if (ball.velocity.Z < 0)
                sign = -1;

            double signX = 1;
            if (ball.velocity.X * sign < 0)
                signX = -1;

            xdata.Add(ball.velocity.X * sign * signX);
            xdata.Add(ball.velocity.Y);
            xdata.Add(ball.velocity.Z * sign);
            xdata.Add(ball.position.X * sign * signX);
            xdata.Add(ball.position.Y);
            xdata.Add(ball.position.Z * sign);
            
            List<double> o = nnet.predict(xdata);
            return result + o[0] * sign;
        }

Спасибо за интерес!
Tags:
Hubs:
+23
Comments 17
Comments Comments 17

Articles