diff --git a/RandomBird/FlappyAgent.py b/RandomBird/FlappyAgent.py deleted file mode 100644 index 9f3ec84..0000000 --- a/RandomBird/FlappyAgent.py +++ /dev/null @@ -1,9 +0,0 @@ -import numpy as np - -def FlappyPolicy(state, screen): -    action=None -    if(np.random.randint(0,2)<1): -        action=119 -    return action - - diff --git a/Rocamora-Ardevol/FlappyAgent.py b/Rocamora-Ardevol/FlappyAgent.py new file mode 100644 index 0000000..b4fe2c3 --- /dev/null +++ b/Rocamora-Ardevol/FlappyAgent.py @@ -0,0 +1,60 @@ +import numpy as np + +# Q tables, lazily loaded from disk on the first call +Qlearning = dict() +Qsarsa = dict() + +def FlappyPolicy(state, screen): +    """ +    Returns an action at each timestep depending on the game state: +    None to do nothing; +    119 to flap. +    """ +    action = actTDLambda(state) + +    return 119 if action else None + + +def actQlearning(state): +    """Greedy action from the Q table learned with Q-learning / TD(0).""" +    global Qlearning + +    if not Qlearning: +        # np.save wraps the dict in a 0-d object array; [()] unwraps it +        Qlearning = np.load('Qlearning.npy', allow_pickle=True)[()] + +    s1, s2, s3 = toDiscreteRef(state) +    key = str(s1)+'|'+str(s2)+'|'+str(s3) + +    q = Qlearning.get(key) +    if q is None:  # unseen state: do nothing +        return 0 + +    return q[0] < q[1] + + +def actTDLambda(state): +    """Greedy action from the Q table learned with SARSA / TD(lambda).""" +    global Qsarsa + +    if not Qsarsa: +        Qsarsa = np.load('Qsarsa.npy', allow_pickle=True)[()] + +    s1, s2, s3 = toDiscreteRef(state) +    key = str(s1)+'|'+str(s2)+'|'+str(s3) + +    q = Qsarsa.get(key) +    if q is None:  # unseen state: do nothing +        return 0 + +    return q[0] < q[1] + + +def toDiscreteRef(state): +    """ +    Discretizes the game state into the three variables used as Q-table keys: +    the vertical offset to the next pipe's bottom (10 px bins), the horizontal +    distance to the next pipe (20 px bins) and the player velocity (bins of 2). +    """ +    s1 = state['next_pipe_bottom_y'] - state['player_y'] +    s2 = state['next_pipe_dist_to_player'] +    s3 = state['player_vel'] + +    return int(s1-s1%10), int(s2-s2%20), int(s3-s3%2) diff --git a/Rocamora-Ardevol/Qlearning.npy b/Rocamora-Ardevol/Qlearning.npy new file mode 100644 index 0000000..b2db723 Binary files /dev/null and b/Rocamora-Ardevol/Qlearning.npy differ diff --git a/Rocamora-Ardevol/Qsarsa.npy b/Rocamora-Ardevol/Qsarsa.npy new file mode 100644 index 0000000..b2db723 Binary files /dev/null and b/Rocamora-Ardevol/Qsarsa.npy differ diff --git a/Rocamora-Ardevol/README.md b/Rocamora-Ardevol/README.md new file mode 100644 index 0000000..24b6ea8 --- /dev/null +++ b/Rocamora-Ardevol/README.md @@ -0,0 +1,19 @@ +# Implementations +## Q-learning approach with temporal differences TD(0) +This algorithm approximates the Q matrix of the optimal policy by choosing actions ε-greedily and updating each entry with the best next-state Q value, independently of the policy actually followed (off-policy). + +## SARSA approach with temporal differences TD($\lambda$) +The SARSA algorithm infers the Q matrix of the optimal policy by choosing actions ε-greedily and updating Q with the value of the action actually taken (on-policy). + +Being on-policy makes it compatible with the TD($\lambda$) value estimator, whose eligibility traces propagate reward information much faster and therefore speed up convergence. + +## Q-learning using a neural network as an approximating function +This approach approximates the Q values with a neural network trained from an experience-replay memory. + +**This case does not work at the moment** because its hyperparameters have not yet been properly tuned. Once it works, it will extend naturally to the full set of PLE state variables.
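+
+As a quick illustration, the two tabular updates behind these agents reduce to the sketch below (generic `s`, `a`, `r` variables; the helper names are illustrative, not taken from the training notebooks):
+
+```python
+from collections import defaultdict
+
+Q = defaultdict(lambda: [0.0, 0.0])  # state key -> [Q(s, no-op), Q(s, flap)]
+alpha, gamma = 0.1, 0.9              # learning rate and discount factor
+
+def q_learning_update(s, a, r, s_next):
+    # Off-policy TD(0): bootstrap with the greedy value of the next state.
+    Q[s][a] += alpha * (r + gamma * max(Q[s_next]) - Q[s][a])
+
+def sarsa_update(s, a, r, s_next, a_next):
+    # On-policy SARSA: bootstrap with the action actually taken next.
+    Q[s][a] += alpha * (r + gamma * Q[s_next][a_next] - Q[s][a])
+```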
+ +# Acknowledgements +The theoretical foundations for this work are based on Emmanuel Rachelson's course on machine learning. + +Some implementation details and hyperparameters are based on the work of: +https://github.com/chncyhn/flappybird-qlearning-bot \ No newline at end of file diff --git a/Rocamora-Ardevol/Train/Neural network training.ipynb b/Rocamora-Ardevol/Train/Neural network training.ipynb new file mode 100644 index 0000000..32f9f17 --- /dev/null +++ b/Rocamora-Ardevol/Train/Neural network training.ipynb @@ -0,0 +1,287 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ple.games.flappybird import FlappyBird\n", + "from ple import PLE\n", + "\n", + "import numpy as np\n", + "#from FlappyAgent import FlappyPolicy\n", + "\n", + "import matplotlib.pyplot as plt\n", + "#%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from keras.models import Sequential\n", + "from keras.layers.core import Dense, Dropout, Activation\n", + "from keras.optimizers import RMSprop, sgd, Adam\n", + "from keras.layers.recurrent import LSTM\n", + "import numpy as np\n", + "import random\n", + "import h5py\n", + "from IPython.display import clear_output\n", + "from collections import deque" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "game = FlappyBird(graphics=\"fixed\") # use \"fancy\" for full background, random bird color and random pipe color, use \"fixed\" (default) for black background and constant bird and pipe colors.\n", + "p = PLE(game, fps=30, frame_skip=1, num_steps=1, force_fps=True, display_screen=True)\n", + "# Note: if you want to see your agent act in real time, set force_fps to False. But don't use this setting for learning, just for display purposes." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Declare functions\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def convstate(state):\n", + " \"\"\"\n", + " Calculate new state variables from game state\n", + " \"\"\"\n", + " s = np.zeros((3))\n", + " s[0] = state['next_pipe_bottom_y'] - state['player_y']\n", + " s[1] = state['next_pipe_dist_to_player']\n", + " s[2] = state['player_vel']\n", + " \n", + " s[0] = (s[0] - (210 - 40)/2) / ((210 + 40)/2)\n", + " s[1] = (s[1] - (420 - 420)/2) / ((420 + 420)/2) \n", + " s[2] = (s[2] - (10 - 10)/2) / ((10 + 10)/2)\n", + " \n", + " return s.reshape((1,3))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def epsilon_greedy(s):\n", + " \n", + " if(np.random.rand()<=epsilon): # random action\n", + " return np.random.choice([0,1], p=[0.9,.1])\n", + " \n", + " else: \n", + " qval = model.predict(s)\n", + " return np.argmax(qval)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class ReplayMemory:\n", + " \"\"\"\n", + " self.memory contains the old state, the action, the reward, the new state and wether it is a final status, \n", + " concatenated in an array.\n", + " \"\"\"\n", + " def __init__ (self, size):\n", + " self.size = size\n", + " self.index = 0\n", + " self.currentsize = 0\n", + " self.memory = np.zeros((size,9))\n", + " \n", + " def insert (self, state):\n", + " if self.currentsize < self.size:\n", + " self.currentsize += 1\n", + " self.memory[self.index,:] = state[:]\n", + " self.index += 1\n", + " self.index = self.index % self.size\n", + " \n", + " def sample (self, batchSize):\n", + " batchSize = min(self.currentsize, batchSize)\n", + " ind = np.random.choice(self.currentsize, size=batchSize, replace=False)\n", + " return self.memory[ind,:]\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Declare model\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = Sequential()\n", + "\n", + "model.add(Dense(100, kernel_initializer='lecun_uniform', input_shape=(3,)))\n", + "model.add(Activation('relu'))\n", + "#model.add(Dropout(0.5)) \n", + "model.add(Dense(100, kernel_initializer='lecun_uniform'))\n", + "model.add(Activation('relu'))\n", + "#model.add(Dropout(0.5))\n", + "model.add(Dense(2, kernel_initializer='lecun_uniform'))\n", + "model.add(Activation('linear'))\n", + "#model.compile(loss='mse', optimizer=\"rmsprop\")\n", + "adam = Adam(lr=1e-2)\n", + "model.compile(loss='mse', optimizer=adam)\n", + "\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Hyperparameters\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nb_games = 1000\n", + "gamma = .99 # discount factor\n", + "epsilon = .1 # epsilon-greddy\n", + "batchSize = 32\n", + "replay = ReplayMemory(10000)\n", + "replay_pos = ReplayMemory(10000)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Train network\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "# Some control variables\n", + "cumulated = np.zeros((nb_games))\n", + "\n", + "# Start the game\n", + "p.init()\n", + "r = 0\n", + "step = 0\n", + "\n", + "for i in range(nb_games):\n", + " p.reset_game()\n", + " \n", + " # Control print\n", + " if i%100 == 0:\n", + " print(i, epsilon, np.mean(cumulated[i-50:i]))\n", + " \n", + " # Decrease exploration ratio\n", + " epsilon *= 0.98\n", + " \n", + " # 0) Retrieve initial state\n", + " \n", + " s = convstate(game.getGameState())\n", + " \n", + " while(not p.game_over()):\n", + " \n", + " # 1) Choose action greedily\n", + " a = epsilon_greedy(s)\n", + " action = 119 if a else None\n", + " \n", + " # Execute \n", + " r = p.act(action)\n", + " cumulated[i] += r\n", + " \n", + " clipped_r = max( min( r, 1 ), -1 ) # Clip the reward values\n", + " \n", + " ss = convstate(game.getGameState())\n", + "\n", + " replay.insert(np.concatenate((s,[[a]],[[r]],ss,[[p.game_over()]]),axis=1))\n", + " \n", + " # 2) Update Q \n", + " \n", + " if step > 1000: # and step % 100 == 99:\n", + " \n", + " train_x = np.zeros((batchSize,3))\n", + " train_y = np.zeros((batchSize,2))\n", + " for idx,entry in enumerate(replay.sample(batchSize)):\n", + " currentS = entry[0:3].copy().reshape(1,3)\n", + " nextS = entry[5:8].copy().reshape(1,3)\n", + " act = entry[3]\n", + " rew = entry[4]\n", + " ending = entry[8]\n", + "\n", + " currentQ = model.predict(currentS)\n", + " nextQmax = np.max(model.predict(nextS))\n", + " currentQ[0][a] = rew + gamma * nextQmax * (1-ending)\n", + "\n", + " train_x[idx,:] = currentS[0,:]\n", + " train_y[idx,:] = currentQ[0,:]\n", + "\n", + " model.fit(train_x, train_y, batch_size=1, nb_epoch=1, verbose=0)\n", + " \n", + " \n", + " # 3) Redeclare state\n", + " s = ss\n", + " \n", + " step += 1\n", + " " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Rocamora-Ardevol/Train/Q - learning training.ipynb b/Rocamora-Ardevol/Train/Q - learning training.ipynb new file mode 100644 index 0000000..90b70c2 --- /dev/null +++ b/Rocamora-Ardevol/Train/Q - learning training.ipynb @@ -0,0 +1,509 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "from ple.games.flappybird import FlappyBird\n", + "from ple import PLE\n", + "\n", + "import numpy as np\n", + "#from FlappyAgent import FlappyPolicy\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "game = FlappyBird(graphics=\"fixed\") # use \"fancy\" for full background, random bird color and random pipe color, use \"fixed\" (default) for black background and constant bird and pipe colors.\n", + "p = PLE(game, fps=30, frame_skip=1, num_steps=1, force_fps=True, display_screen=True)\n", + "# Note: if you want to see you agent act in real time, set force_fps to False. But don't use this setting for learning, just for display purposes." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Declare functions\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "def convstate(state):\n", + " \"\"\"\n", + " Calculate new state variables from game state\n", + " \"\"\"\n", + " s1 = state['next_pipe_bottom_y'] - state['player_y']\n", + " s2 = state['next_pipe_dist_to_player']\n", + " s3 = state['player_vel']\n", + " \n", + " return int(s1-s1%10), int(s2-s2%20), int(s3-s3%2)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "def epsilon_greedy(key):\n", + " if(np.random.rand()<=epsilon): # random action\n", + " return np.random.choice([0,1], p =[0.8,0.2])\n", + " \n", + " else: \n", + " return np.argmax(Q.get(key, [0]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Reinit variables\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "Q = dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# Metaparameters\n", + "nb_games = 20000\n", + "alpha = 0.1 #0.7\n", + "epsilon = 0.1 #0.4\n", + "gamma = 0.9" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Run training\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/sergi/anaconda3/envs/flappy/lib/python3.6/site-packages/numpy/core/fromnumeric.py:2957: RuntimeWarning: Mean of empty slice.\n", + " out=out, **kwargs)\n", + "/home/sergi/anaconda3/envs/flappy/lib/python3.6/site-packages/numpy/core/_methods.py:80: RuntimeWarning: invalid value encountered in double_scalars\n", + " ret = ret.dtype.type(ret / rcount)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 0.1 0.1 nan\n", + "100 0.08000000000000002 0.0995 -5.0\n", + "200 0.06400000000000002 0.09900250000000001 -4.98\n", + "300 0.051200000000000016 0.0985074875 -4.98\n", + "400 0.04096000000000002 0.09801495006250001 -4.74\n", + "500 0.03276800000000001 0.09752487531218751 -4.74\n", + "600 0.026214400000000013 0.09703725093562657 -4.38\n", + "700 0.02097152000000001 0.09655206468094843 -3.96\n", + "800 0.016777216000000008 0.09606930435754368 -3.68\n", + "900 0.013421772800000007 0.09558895783575597 -3.56\n", + "1000 0.010737418240000006 0.09511101304657718 -3.96\n", + "1100 0.01 0.09463545798134429 -3.18\n", + "1200 0.01 0.09416228069143756 -1.8\n", + "1300 0.01 0.09369146928798038 -2.34\n", + "1400 0.01 0.09322301194154048 -2.46\n", + "1500 0.01 0.09275689688183278 -3.0\n", + "1600 0.01 0.09229311239742362 -2.14\n", + "1700 0.01 0.0918316468354365 -1.06\n", + "1800 0.01 0.09137248860125932 -1.2\n", + "1900 0.01 0.09091562615825302 -0.32\n", + "2000 0.01 0.09046104802746176 0.32\n", + "2100 0.01 0.09000874278732444 -1.48\n", + "2200 0.01 0.08955869907338782 -1.02\n", + "2300 0.01 0.08911090557802087 0.08\n", + "2400 0.01 0.08866535105013076 -1.58\n", + "2500 0.01 0.08822202429488012 -1.26\n", + "2600 0.01 0.08778091417340572 -1.78\n", + "2700 0.01 0.0873420096025387 -0.24\n", + "2800 0.01 0.086905299554526 -1.12\n", + "2900 0.01 0.08647077305675337 3.28\n", + "3000 0.01 0.0860384191914696 -0.2\n", + "3100 0.01 0.08560822709551225 -1.34\n", + "3200 0.01 0.08518018596003468 -1.22\n", + "3300 0.01 0.08475428503023451 0.66\n", + "3400 0.01 0.08433051360508334 -1.0\n", + "3500 0.01 0.08390886103705793 1.14\n", + "3600 0.01 0.08348931673187264 1.78\n", + "3700 0.01 0.08307187014821328 1.62\n", + "3800 0.01 0.08265651079747222 -1.7\n", + "3900 0.01 0.08224322824348486 2.52\n", + "4000 0.01 0.08183201210226744 1.9\n", + "4100 0.01 0.0814228520417561 2.92\n", + "4200 0.01 0.08101573778154732 0.36\n", + "4300 0.01 0.08061065909263958 0.86\n", + "4400 0.01 0.08020760579717638 1.24\n", + "4500 0.01 0.0798065677681905 2.14\n", + "4600 0.01 0.07940753492934956 1.8\n", + "4700 0.01 0.07901049725470281 0.48\n", + "4800 0.01 0.0786154447684293 1.98\n", + "4900 0.01 0.07822236754458715 4.54\n", + "5000 0.01 0.07783125570686422 4.3\n", + "5100 0.01 0.0774420994283299 6.76\n", + "5200 0.01 0.07705488893118825 6.36\n", + "5300 0.01 0.07666961448653231 11.02\n", + "5400 0.01 0.07628626641409965 4.94\n", + "5500 0.01 0.07590483508202915 1.24\n", + "5600 0.01 0.075525310906619 2.76\n", + "5700 0.01 0.07514768435208591 12.38\n", + "5800 0.01 0.07477194593032548 3.54\n", + "5900 0.01 0.07439808620067385 2.68\n", + "6000 0.01 0.07402609576967048 1.32\n", + "6100 0.01 0.07365596529082213 6.16\n", + "6200 0.01 0.07328768546436802 8.14\n", + "6300 0.01 0.07292124703704618 7.34\n", + "6400 0.01 0.07255664080186094 13.44\n", + "6500 0.01 0.07219385759785164 8.78\n", + "6600 0.01 0.07183288830986238 13.3\n", + 
"6700 0.01 0.07147372386831308 11.4\n", + "6800 0.01 0.07111635524897152 4.04\n", + "6900 0.01 0.07076077347272666 11.98\n", + "7000 0.01 0.07040696960536302 4.18\n", + "7100 0.01 0.0700549347573362 8.18\n", + "7200 0.01 0.06970466008354953 5.26\n", + "7300 0.01 0.06935613678313178 2.72\n", + "7400 0.01 0.06900935609921612 4.12\n", + "7500 0.01 0.06866430931872003 4.38\n", + "7600 0.01 0.06832098777212643 3.74\n", + "7700 0.01 0.0679793828332658 2.12\n", + "7800 0.01 0.06763948591909946 4.58\n", + "7900 0.01 0.06730128848950397 -0.16\n", + "8000 0.01 0.06696478204705644 1.24\n", + "8100 0.01 0.06662995813682115 0.3\n", + "8200 0.01 0.06629680834613705 0.44\n", + "8300 0.01 0.06596532430440637 4.76\n", + "8400 0.01 0.06563549768288433 4.16\n", + "8500 0.01 0.06530732019446991 1.76\n", + "8600 0.01 0.06498078359349757 0.92\n", + "8700 0.01 0.06465587967553008 1.72\n", + "8800 0.01 0.06433260027715243 0.74\n", + "8900 0.01 0.06401093727576666 -1.52\n", + "9000 0.01 0.06369088258938783 -0.48\n", + "9100 0.01 0.0633724281764409 0.08\n", + "9200 0.01 0.06305556603555869 -0.16\n", + "9300 0.01 0.0627402882053809 0.06\n", + "9400 0.01 0.062426586764353996 0.68\n", + "9500 0.01 0.062114453830532226 0.14\n", + "9600 0.01 0.061803881561379566 -0.58\n", + "9700 0.01 0.061494862153572666 -1.6\n", + "9800 0.01 0.0611873878428048 -2.52\n", + "9900 0.01 0.060881450903590775 -2.3\n", + "10000 0.01 0.06057704364907282 -2.28\n", + "10100 0.01 0.06027415843082746 -1.84\n", + "10200 0.01 0.05997278763867332 -2.0\n", + "10300 0.01 0.059672923700479955 2.76\n", + "10400 0.01 0.05937455908197756 6.82\n", + "10500 0.01 0.05907768628656767 7.08\n", + "10600 0.01 0.05878229785513483 6.36\n", + "10700 0.01 0.05848838636585915 5.0\n", + "10800 0.01 0.05819594443402985 2.8\n", + "10900 0.01 0.057904964711859706 3.84\n", + "11000 0.01 0.05761543988830041 4.72\n", + "11100 0.01 0.05732736268885891 2.08\n", + "11200 0.01 0.05704072587541461 3.4\n", + "11300 0.01 0.05675552224603753 3.5\n", + "11400 0.01 0.05647174463480734 2.12\n", + "11500 0.01 0.056189385911633305 3.66\n", + "11600 0.01 0.05590843898207514 4.8\n", + "11700 0.01 0.05562889678716477 0.42\n", + "11800 0.01 0.055350752303228945 0.44\n", + "11900 0.01 0.0550739985417128 -0.38\n", + "12000 0.01 0.05479862854900423 0.88\n", + "12100 0.01 0.05452463540625921 0.96\n", + "12200 0.01 0.05425201222922791 -2.0\n", + "12300 0.01 0.05398075216808177 -1.76\n", + "12400 0.01 0.053710848407241364 -1.12\n", + "12500 0.01 0.053442294165205156 -1.22\n", + "12600 0.01 0.05317508269437913 -1.18\n", + "12700 0.01 0.052909207280907235 -1.36\n", + "12800 0.01 0.0526446612445027 -2.22\n", + "12900 0.01 0.052381437938280186 -0.08\n", + "13000 0.01 0.052119530748588785 -1.62\n", + "13100 0.01 0.05185893309484584 -1.6\n", + "13200 0.01 0.05159963842937161 0.5\n", + "13300 0.01 0.051341640237224755 -2.12\n", + "13400 0.01 0.05108493203603863 -0.76\n", + "13500 0.01 0.05082950737585844 1.64\n", + "13600 0.01 0.05057535983897914 -0.1\n", + "13700 0.01 0.050322483039784247 -0.36\n", + "13800 0.01 0.050070870624585324 0.6\n", + "13900 0.01 0.049820516271462396 1.56\n", + "14000 0.01 0.049571413690105086 -0.06\n", + "14100 0.01 0.04932355662165456 1.68\n", + "14200 0.01 0.04907693883854629 1.76\n", + "14300 0.01 0.048831554144353556 3.94\n", + "14400 0.01 0.04858739637363179 2.64\n", + "14500 0.01 0.04834445939176363 2.0\n", + "14600 0.01 0.04810273709480481 2.92\n", + "14700 0.01 0.04786222340933079 1.08\n", + "14800 0.01 0.04762291229228413 1.3\n", + "14900 0.01 0.04738479773082271 
-1.24\n", + "15000 0.01 0.04714787374216859 -0.7\n", + "15100 0.01 0.046912134373457745 -0.92\n", + "15200 0.01 0.04667757370159046 0.3\n", + "15300 0.01 0.046444185833082505 2.16\n", + "15400 0.01 0.046211964903917095 3.62\n", + "15500 0.01 0.04598090507939751 0.12\n", + "15600 0.01 0.045751000554000526 3.7\n", + "15700 0.01 0.04552224555123052 1.66\n", + "15800 0.01 0.04529463432347437 5.44\n", + "15900 0.01 0.045068161151857 5.68\n", + "16000 0.01 0.04484282034609772 5.32\n", + "16100 0.01 0.04461860624436723 2.0\n", + "16200 0.01 0.044395513213145395 1.68\n", + "16300 0.01 0.04417353564707967 1.94\n", + "16400 0.01 0.04395266796884427 2.84\n", + "16500 0.01 0.04373290462900005 4.04\n", + "16600 0.01 0.04351424010585505 3.96\n", + "16700 0.01 0.043296668905325776 6.14\n", + "16800 0.01 0.04308018556079915 1.28\n", + "16900 0.01 0.04286478463299515 3.04\n", + "17000 0.01 0.042650460709830175 1.1\n", + "17100 0.01 0.04243720840628103 0.42\n", + "17200 0.01 0.042225022364249624 3.8\n", + "17300 0.01 0.04201389725242838 3.84\n", + "17400 0.01 0.041803827766166236 3.26\n", + "17500 0.01 0.0415948086273354 2.54\n", + "17600 0.01 0.04138683458419873 0.32\n", + "17700 0.01 0.04117990041127773 0.4\n", + "17800 0.01 0.04097400090922134 1.52\n", + "17900 0.01 0.04076913090467523 0.88\n", + "18000 0.01 0.04056528525015186 5.06\n", + "18100 0.01 0.0403624588239011 7.94\n", + "18200 0.01 0.04016064652978159 6.56\n", + "18300 0.01 0.03995984329713268 3.56\n", + "18400 0.01 0.03976004408064702 1.78\n", + "18500 0.01 0.03956124386024378 3.94\n", + "18600 0.01 0.039363437640942564 -1.24\n", + "18700 0.01 0.03916662045273785 0.94\n", + "18800 0.01 0.03897078735047416 1.18\n", + "18900 0.01 0.03877593341372179 1.64\n", + "19000 0.01 0.03858205374665318 2.14\n", + "19100 0.01 0.03838914347791991 3.66\n", + "19200 0.01 0.038197197760530315 3.16\n", + "19300 0.01 0.038006211771727666 4.28\n", + "19400 0.01 0.037816180712869026 2.98\n", + "19500 0.01 0.03762709980930468 6.58\n", + "19600 0.01 0.03743896431025816 2.66\n", + "19700 0.01 0.037251769488706864 2.82\n", + "19800 0.01 0.03706551064126333 0.46\n", + "19900 0.01 0.03688018308805701 4.18\n" + ] + } + ], + "source": [ + "# Some control variables\n", + "cumulated = np.zeros((nb_games))\n", + "\n", + "# Start the game\n", + "p.init()\n", + "reward = 0\n", + "\n", + "for i in range(nb_games):\n", + " p.reset_game()\n", + " \n", + " # Control print\n", + " if i%100 == 0:\n", + " print(i, epsilon, alpha, np.mean(cumulated[i-50:i]))\n", + " \n", + " # Decrease exploration ratio\n", + " epsilon = max(epsilon*0.8, 0.01)\n", + " alpha *= 0.995\n", + " \n", + " if i%1000 == 999:\n", + " np.save('Q_%d' % i ,Q)\n", + " np.save('cumulated_%d' % i, cumulated)\n", + " \n", + " # 0) Retrieve initial state\n", + " s1, s2, s3 = convstate(game.getGameState())\n", + " current_key = str(s1)+'|'+str(s2)+'|'+str(s3)\n", + " \n", + " while(not p.game_over()):\n", + " \n", + " # 1) Choose action greedily\n", + " a = epsilon_greedy(current_key)\n", + " \n", + " action = None\n", + " if a==1:\n", + " action = 119\n", + " \n", + " \n", + " # Execute\n", + " reward = p.act(action)\n", + " cumulated[i] += reward\n", + " \n", + " ss1, ss2, ss3 = convstate(game.getGameState())\n", + " next_key = str(ss1)+'|'+str(ss2)+'|'+str(ss3)\n", + " \n", + "\n", + " \n", + " # 2) Update Q value\n", + " if Q.get(current_key) == None:\n", + " Q[current_key] = [0,0]\n", + " \n", + " maxQ = max(Q.get(next_key, [0]))\n", + " \n", + " \n", + " Q[current_key][a] = (1-alpha)*Q[current_key][a] + alpha*( 
reward + gamma*maxQ )\n", + " \n", + " # Update values and map key\n", + " s1 = ss1\n", + " s2 = ss2\n", + " s3 = ss3\n", + " current_key = next_key\n", + " \n", + " \n", + "\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Q" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Postprocess\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Mean evolution\n", + "vallist = list()\n", + "\n", + "for idx, val in enumerate(cumulated):\n", + " if idx < 50:\n", + " pass\n", + " else: \n", + " vallist.append(np.mean( cumulated[idx-50:idx] ))\n", + "\n", + "plt.plot(vallist)\n", + "\n", + "vallist = list()\n", + "\n", + "for idx, val in enumerate(cumulated):\n", + " if idx < 500:\n", + " pass\n", + " else: \n", + " vallist.append(np.mean( cumulated[idx-500:idx] ))\n", + "\n", + "plt.plot(vallist)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#c1 = cumulated\n", + "#c2 = cumulated\n", + "c3 = cumulated" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Mean evolution\n", + "vallist = list()\n", + "\n", + "result = np.concatenate((c1,c2,c3))\n", + "\n", + "for idx, val in enumerate(result):\n", + " if idx < 50:\n", + " pass\n", + " else: \n", + " vallist.append(np.mean( result[idx-50:idx] ))\n", + "\n", + "plt.plot(vallist)\n", + "\n", + "vallist = list()\n", + "\n", + "for idx, val in enumerate(result):\n", + " if idx < 500:\n", + " pass\n", + " else: \n", + " vallist.append(np.mean( result[idx-500:idx] ))\n", + "\n", + "plt.plot(vallist)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Rocamora-Ardevol/Train/SARSA - TD lambda training.ipynb b/Rocamora-Ardevol/Train/SARSA - TD lambda training.ipynb new file mode 100644 index 0000000..e31e941 --- /dev/null +++ b/Rocamora-Ardevol/Train/SARSA - TD lambda training.ipynb @@ -0,0 +1,279 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ple.games.flappybird import FlappyBird\n", + "from ple import PLE\n", + "\n", + "import numpy as np\n", + "#from FlappyAgent import FlappyPolicy\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "game = FlappyBird(graphics=\"fixed\") # use \"fancy\" for full background, random bird color and random pipe color, use \"fixed\" (default) for black background and constant bird and pipe colors.\n", + "p = PLE(game, fps=30, frame_skip=1, num_steps=1, force_fps=True, display_screen=True)\n", + "# Note: if you want to see you agent act in real time, set force_fps to False. But don't use this setting for learning, just for display purposes." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Declare functions\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def convstate(state):\n", + " \"\"\"\n", + " Calculate new state variables from game state\n", + " \"\"\"\n", + " s1 = state['next_pipe_bottom_y'] - state['player_y']\n", + " s2 = state['next_pipe_dist_to_player']\n", + " s3 = state['player_vel']\n", + " \n", + " return int(s1-s1%10), int(s2-s2%20), int(s3-s3%2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def epsilon_greedy(key):\n", + " if(np.random.rand()<=epsilon): # random action\n", + " return np.random.choice([0,1])\n", + " \n", + " else: \n", + " return np.argmax(Q.get(key, [0]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def update_trace(key,action):\n", + " # Update the trace\n", + " global epsTrace\n", + " epsTrace = { k: list(map(lambda x: x*gamma*lamb, v)) for k,v in epsTrace.items() }\n", + " \n", + " if epsTrace.get(key) == None:\n", + " epsTrace[key] = [0,0]\n", + " \n", + " # Remember the current state\n", + " epsTrace[key][action] = 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def propagate(delta):\n", + " for k,v in epsTrace.items():\n", + " Q[k][0] = Q[k][0] + alpha*epsTrace[k][0]*delta\n", + " Q[k][1] = Q[k][1] + alpha*epsTrace[k][1]*delta" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Reinit variables\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Q = dict()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Metaparameters\n", + "nb_games = 18000\n", + "alpha = 0.1 #0.7\n", + "epsilon = 0.1 #0.4\n", + "gamma = 0.9\n", + "lamb = 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Run training\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "nbgames = 18000\n", + "\n", + "# Some control variables\n", + "cumulated = np.zeros((nb_games))\n", + "\n", + "# Start the game\n", + "p.init()\n", + "reward = 0\n", + "\n", + "for i in range(nb_games):\n", + " p.reset_game()\n", + " epsTrace = dict()\n", + " \n", + " # Control print\n", + " if i%100 == 0:\n", + " print(i, epsilon, alpha, np.mean(cumulated[i-50:i]))\n", + " \n", + " # Decrease exploration ratio\n", + " epsilon = max(epsilon * 0.95, 0.01)\n", + " alpha *= 0.995\n", + " \n", + " if i%1000 == 999:\n", + " np.save('Qsarsa_more_%d' % i ,Q)\n", + " np.save('cumulated_sarsa_more_%d' % i, cumulated)\n", + " \n", + " # 0) Retrieve initial state \n", + " s1, s2, s3 = convstate(game.getGameState())\n", + " current_key = str(s1)+'|'+str(s2)+'|'+str(s3)\n", + " \n", + " if Q.get(current_key) == None:\n", + " Q[current_key] = [0,0]\n", + " \n", + " # Choose action greedily\n", + " a = epsilon_greedy(current_key)\n", + " \n", + " while(not p.game_over()):\n", + " \n", + " # Translate action\n", + " action = None\n", + " if a==1:\n", + " action = 119\n", + " \n", + " # 1) Execute\n", + " reward = p.act(action)\n", + " cumulated[i] += reward\n", + " \n", + " ss1, ss2, ss3 = convstate(game.getGameState())\n", + " next_key = str(ss1)+'|'+str(ss2)+'|'+str(ss3)\n", + "\n", + " # 2) Choose new action greedily\n", + " aa = epsilon_greedy(next_key)\n", + " \n", + " # 3) Update Q value\n", + " # Update trace\n", + " update_trace(current_key, a)\n", + " \n", + " # Update Q\n", + " if Q.get(next_key) == None:\n", + " Q[next_key] = [0,0]\n", + " \n", + " delta = reward + gamma*Q[next_key][aa] - Q[current_key][a]\n", + " propagate(delta)\n", + " \n", + " # Update values and map key\n", + " s1 = ss1\n", + " s2 = ss2\n", + " s3 = ss3\n", + " a = aa\n", + " current_key = next_key \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "Postprocess\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Mean evolution\n", + "vallist = list()\n", + "\n", + "for idx, val in enumerate(cumulated):\n", + " if idx < 50:\n", + " pass\n", + " else: \n", + " vallist.append(np.mean( cumulated[idx-50:idx] ))\n", + "\n", + "plt.plot(vallist)\n", + "\n", + "vallist = list()\n", + "\n", + "for idx, val in enumerate(cumulated):\n", + " if idx < 500:\n", + " pass\n", + " else: \n", + " vallist.append(np.mean( cumulated[idx-500:idx] ))\n", + "\n", + "plt.plot(vallist)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/RandomBird/run.py b/Rocamora-Ardevol/run.py similarity index 100% rename from RandomBird/run.py rename to Rocamora-Ardevol/run.py