diff --git a/RandomBird/FlappyAgent.py b/RandomBird/FlappyAgent.py
deleted file mode 100644
index 9f3ec84..0000000
--- a/RandomBird/FlappyAgent.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import numpy as np
-
-def FlappyPolicy(state, screen):
- action=None
- if(np.random.randint(0,2)<1):
- action=119
- return action
-
-
diff --git a/Rocamora-Ardevol/FlappyAgent.py b/Rocamora-Ardevol/FlappyAgent.py
new file mode 100644
index 0000000..b4fe2c3
--- /dev/null
+++ b/Rocamora-Ardevol/FlappyAgent.py
@@ -0,0 +1,60 @@
+import numpy as np
+
+Qlearning = dict()
+Qsarsa = dict()
+
+def FlappyPolicy(state, screen):
+    """
+    Returns an action for each timestep depending on the game state:
+    0 to do nothing, 119 to flap.
+    Uses the Q table learned with SARSA / TD(lambda).
+    """
+    action = actTDLambda(state)
+
+    return 119 if action else 0
+
+
+
+def actQlearning(state):
+    """
+    Greedy action read from the Q table learned with Q-learning (TD(0)).
+    Returns True to flap, False/0 to do nothing.
+    """
+    global Qlearning
+
+    # Lazily load the trained Q table on the first call.
+    # allow_pickle is required on recent NumPy versions because the file stores a dict.
+    if not bool(Qlearning):
+        Qlearning = np.load('Qlearning.npy', allow_pickle=True)
+
+    s1, s2, s3 = toDiscreteRef(state)
+    key = str(s1)+'|'+str(s2)+'|'+str(s3)
+
+    # np.save wraps the dict in a 0-d object array, hence the [()] indexing.
+    if Qlearning[()].get(key) is None:
+        return 0
+
+    # Flap only if its Q value beats that of doing nothing.
+    return Qlearning[()][key][0] < Qlearning[()][key][1]
+
+
+def actTDLambda(state):
+    """
+    Greedy action read from the Q table learned with SARSA / TD(lambda).
+    Returns True to flap, False/0 to do nothing.
+    """
+    global Qsarsa
+
+    # Lazily load the trained Q table on the first call.
+    if not bool(Qsarsa):
+        Qsarsa = np.load('Qsarsa.npy', allow_pickle=True)
+
+    s1, s2, s3 = toDiscreteRef(state)
+    key = str(s1)+'|'+str(s2)+'|'+str(s3)
+
+    # np.save wraps the dict in a 0-d object array, hence the [()] indexing.
+    if Qsarsa[()].get(key) is None:
+        return 0
+
+    # Flap only if its Q value beats that of doing nothing.
+    return Qsarsa[()][key][0] < Qsarsa[()][key][1]
+
+
+def toDiscreteRef(state):
+    """
+    Converts the game state into the discrete state representation used by
+    the tabular Q-learning and SARSA agents.
+    """
+
+    s1 = state['next_pipe_bottom_y'] - state['player_y']
+    s2 = state['next_pipe_dist_to_player']
+    s3 = state['player_vel']
+
+    # Bin the vertical offset in steps of 10 px, the horizontal distance in steps
+    # of 20 px and the velocity in steps of 2.
+    return int(s1-s1%10), int(s2-s2%20), int(s3-s3%2)
diff --git a/Rocamora-Ardevol/Qlearning.npy b/Rocamora-Ardevol/Qlearning.npy
new file mode 100644
index 0000000..b2db723
Binary files /dev/null and b/Rocamora-Ardevol/Qlearning.npy differ
diff --git a/Rocamora-Ardevol/Qsarsa.npy b/Rocamora-Ardevol/Qsarsa.npy
new file mode 100644
index 0000000..b2db723
Binary files /dev/null and b/Rocamora-Ardevol/Qsarsa.npy differ
diff --git a/Rocamora-Ardevol/README.md b/Rocamora-Ardevol/README.md
new file mode 100644
index 0000000..24b6ea8
--- /dev/null
+++ b/Rocamora-Ardevol/README.md
@@ -0,0 +1,19 @@
+# Implementations
+## Q-learning approach with temporal differences TD(0)
+This algorithm approximates the Q matrix associated with the optimal policy: actions are chosen $\epsilon$-greedily during training, and each update bootstraps on the best Q value of the next state regardless of the action actually taken (off-policy).
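+
+For reference, this is the tabular update implemented in `Train/Q - learning training.ipynb` (learning rate $\alpha$, discount factor $\gamma$):
+
+$$Q(s,a) \leftarrow (1-\alpha)\,Q(s,a) + \alpha\left[ r + \gamma \max_{a'} Q(s',a') \right]$$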
+
+## SARSA approach with temporal differences TD($\lambda$)
+The SARSA algorithm estimates the Q matrix of the policy it actually follows: actions are chosen $\epsilon$-greedily and each update uses the action taken in the next state (on-policy).
+
+This makes it compatible with the TD($\lambda$) estimator with eligibility traces, which propagates reward information much faster along the visited states and therefore converges faster.
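+
+Concretely, `Train/SARSA - TD lambda training.ipynb` uses replacing eligibility traces $e$ (trace decay $\lambda$, discount $\gamma$, learning rate $\alpha$): at every step the traces are decayed and refreshed, and the TD error is propagated to all traced state-action pairs,
+
+$$e \leftarrow \gamma\lambda\, e, \qquad e(s,a) \leftarrow 1, \qquad \delta = r + \gamma\, Q(s',a') - Q(s,a), \qquad Q \leftarrow Q + \alpha\, \delta\, e$$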
+
+## Q-learning using a neural network as a function approximator
+The Q values are estimated by a neural network instead of a lookup table. This implementation uses experience replay.
+
+**This case does not work at the moment** due to insufficient hyperparameter tuning. Once it works properly, it will extend naturally to using all of PLE's state variables.
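+
+The network (weights $\theta$) is trained on minibatches sampled from the replay memory; each stored transition $(s, a, r, s', \mathrm{done})$ is regressed towards the usual Q-learning target:
+
+$$y = r + \gamma\,(1-\mathrm{done})\,\max_{a'} Q(s', a'; \theta)$$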
+
+# Acknowledgements
+The theoretical foundations for this work are based on Emmanuel Rachelson's course on machine learning.
+
+Some implementation details and hyperparameters are based on the work of:
+https://github.com/chncyhn/flappybird-qlearning-bot
\ No newline at end of file
diff --git a/Rocamora-Ardevol/Train/Neural network training.ipynb b/Rocamora-Ardevol/Train/Neural network training.ipynb
new file mode 100644
index 0000000..32f9f17
--- /dev/null
+++ b/Rocamora-Ardevol/Train/Neural network training.ipynb
@@ -0,0 +1,287 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from ple.games.flappybird import FlappyBird\n",
+ "from ple import PLE\n",
+ "\n",
+ "import numpy as np\n",
+ "#from FlappyAgent import FlappyPolicy\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "#%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from keras.models import Sequential\n",
+ "from keras.layers.core import Dense, Dropout, Activation\n",
+ "from keras.optimizers import RMSprop, sgd, Adam\n",
+ "from keras.layers.recurrent import LSTM\n",
+ "import numpy as np\n",
+ "import random\n",
+ "import h5py\n",
+ "from IPython.display import clear_output\n",
+ "from collections import deque"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "game = FlappyBird(graphics=\"fixed\") # use \"fancy\" for full background, random bird color and random pipe color, use \"fixed\" (default) for black background and constant bird and pipe colors.\n",
+ "p = PLE(game, fps=30, frame_skip=1, num_steps=1, force_fps=True, display_screen=True)\n",
+    "# Note: if you want to see your agent act in real time, set force_fps to False. But don't use this setting for learning, just for display purposes."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Declare functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def convstate(state):\n",
+ " \"\"\"\n",
+ " Calculate new state variables from game state\n",
+ " \"\"\"\n",
+ " s = np.zeros((3))\n",
+ " s[0] = state['next_pipe_bottom_y'] - state['player_y']\n",
+ " s[1] = state['next_pipe_dist_to_player']\n",
+ " s[2] = state['player_vel']\n",
+ " \n",
+ " s[0] = (s[0] - (210 - 40)/2) / ((210 + 40)/2)\n",
+ " s[1] = (s[1] - (420 - 420)/2) / ((420 + 420)/2) \n",
+ " s[2] = (s[2] - (10 - 10)/2) / ((10 + 10)/2)\n",
+ " \n",
+ " return s.reshape((1,3))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def epsilon_greedy(s):\n",
+ " \n",
+ " if(np.random.rand()<=epsilon): # random action\n",
+ " return np.random.choice([0,1], p=[0.9,.1])\n",
+ " \n",
+ " else: \n",
+ " qval = model.predict(s)\n",
+ " return np.argmax(qval)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class ReplayMemory:\n",
+ " \"\"\"\n",
+    "    self.memory contains the old state, the action, the reward, the new state and whether it is a final state, \n",
+ " concatenated in an array.\n",
+ " \"\"\"\n",
+ " def __init__ (self, size):\n",
+ " self.size = size\n",
+ " self.index = 0\n",
+ " self.currentsize = 0\n",
+ " self.memory = np.zeros((size,9))\n",
+ " \n",
+ " def insert (self, state):\n",
+ " if self.currentsize < self.size:\n",
+ " self.currentsize += 1\n",
+ " self.memory[self.index,:] = state[:]\n",
+ " self.index += 1\n",
+ " self.index = self.index % self.size\n",
+ " \n",
+ " def sample (self, batchSize):\n",
+ " batchSize = min(self.currentsize, batchSize)\n",
+ " ind = np.random.choice(self.currentsize, size=batchSize, replace=False)\n",
+ " return self.memory[ind,:]\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Declare model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = Sequential()\n",
+ "\n",
+ "model.add(Dense(100, kernel_initializer='lecun_uniform', input_shape=(3,)))\n",
+ "model.add(Activation('relu'))\n",
+ "#model.add(Dropout(0.5)) \n",
+ "model.add(Dense(100, kernel_initializer='lecun_uniform'))\n",
+ "model.add(Activation('relu'))\n",
+ "#model.add(Dropout(0.5))\n",
+ "model.add(Dense(2, kernel_initializer='lecun_uniform'))\n",
+ "model.add(Activation('linear'))\n",
+ "#model.compile(loss='mse', optimizer=\"rmsprop\")\n",
+ "adam = Adam(lr=1e-2)\n",
+ "model.compile(loss='mse', optimizer=adam)\n",
+ "\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Hyperparameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nb_games = 1000\n",
+ "gamma = .99 # discount factor\n",
+    "epsilon = .1 # epsilon-greedy\n",
+ "batchSize = 32\n",
+ "replay = ReplayMemory(10000)\n",
+ "replay_pos = ReplayMemory(10000)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Train network"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "# Some control variables\n",
+ "cumulated = np.zeros((nb_games))\n",
+ "\n",
+ "# Start the game\n",
+ "p.init()\n",
+ "r = 0\n",
+ "step = 0\n",
+ "\n",
+ "for i in range(nb_games):\n",
+ " p.reset_game()\n",
+ " \n",
+ " # Control print\n",
+ " if i%100 == 0:\n",
+ " print(i, epsilon, np.mean(cumulated[i-50:i]))\n",
+ " \n",
+ " # Decrease exploration ratio\n",
+ " epsilon *= 0.98\n",
+ " \n",
+ " # 0) Retrieve initial state\n",
+ " \n",
+ " s = convstate(game.getGameState())\n",
+ " \n",
+ " while(not p.game_over()):\n",
+ " \n",
+ " # 1) Choose action greedily\n",
+ " a = epsilon_greedy(s)\n",
+ " action = 119 if a else None\n",
+ " \n",
+ " # Execute \n",
+ " r = p.act(action)\n",
+ " cumulated[i] += r\n",
+ " \n",
+ " clipped_r = max( min( r, 1 ), -1 ) # Clip the reward values\n",
+ " \n",
+ " ss = convstate(game.getGameState())\n",
+ "\n",
+    "        replay.insert(np.concatenate((s,[[a]],[[clipped_r]],ss,[[p.game_over()]]),axis=1)) # store the clipped reward\n",
+ " \n",
+ " # 2) Update Q \n",
+ " \n",
+ " if step > 1000: # and step % 100 == 99:\n",
+ " \n",
+ " train_x = np.zeros((batchSize,3))\n",
+ " train_y = np.zeros((batchSize,2))\n",
+ " for idx,entry in enumerate(replay.sample(batchSize)):\n",
+ " currentS = entry[0:3].copy().reshape(1,3)\n",
+ " nextS = entry[5:8].copy().reshape(1,3)\n",
+ " act = entry[3]\n",
+ " rew = entry[4]\n",
+ " ending = entry[8]\n",
+ "\n",
+    "                currentQ = model.predict(currentS)\n",
+    "                nextQmax = np.max(model.predict(nextS))\n",
+    "                currentQ[0][int(act)] = rew + gamma * nextQmax * (1-ending) # target for the sampled action, not the current one\n",
+ "\n",
+ " train_x[idx,:] = currentS[0,:]\n",
+ " train_y[idx,:] = currentQ[0,:]\n",
+ "\n",
+ " model.fit(train_x, train_y, batch_size=1, nb_epoch=1, verbose=0)\n",
+ " \n",
+ " \n",
+ " # 3) Redeclare state\n",
+ " s = ss\n",
+ " \n",
+ " step += 1\n",
+ " "
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/Rocamora-Ardevol/Train/Q - learning training.ipynb b/Rocamora-Ardevol/Train/Q - learning training.ipynb
new file mode 100644
index 0000000..90b70c2
--- /dev/null
+++ b/Rocamora-Ardevol/Train/Q - learning training.ipynb
@@ -0,0 +1,509 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from ple.games.flappybird import FlappyBird\n",
+ "from ple import PLE\n",
+ "\n",
+ "import numpy as np\n",
+ "#from FlappyAgent import FlappyPolicy\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "game = FlappyBird(graphics=\"fixed\") # use \"fancy\" for full background, random bird color and random pipe color, use \"fixed\" (default) for black background and constant bird and pipe colors.\n",
+ "p = PLE(game, fps=30, frame_skip=1, num_steps=1, force_fps=True, display_screen=True)\n",
+    "# Note: if you want to see your agent act in real time, set force_fps to False. But don't use this setting for learning, just for display purposes."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Declare functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def convstate(state):\n",
+ " \"\"\"\n",
+ " Calculate new state variables from game state\n",
+ " \"\"\"\n",
+ " s1 = state['next_pipe_bottom_y'] - state['player_y']\n",
+ " s2 = state['next_pipe_dist_to_player']\n",
+ " s3 = state['player_vel']\n",
+ " \n",
+ " return int(s1-s1%10), int(s2-s2%20), int(s3-s3%2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def epsilon_greedy(key):\n",
+ " if(np.random.rand()<=epsilon): # random action\n",
+ " return np.random.choice([0,1], p =[0.8,0.2])\n",
+ " \n",
+ " else: \n",
+ " return np.argmax(Q.get(key, [0]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Reinit variables"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Q = dict()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Metaparameters\n",
+ "nb_games = 20000\n",
+ "alpha = 0.1 #0.7\n",
+ "epsilon = 0.1 #0.4\n",
+ "gamma = 0.9"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Run training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/sergi/anaconda3/envs/flappy/lib/python3.6/site-packages/numpy/core/fromnumeric.py:2957: RuntimeWarning: Mean of empty slice.\n",
+ " out=out, **kwargs)\n",
+ "/home/sergi/anaconda3/envs/flappy/lib/python3.6/site-packages/numpy/core/_methods.py:80: RuntimeWarning: invalid value encountered in double_scalars\n",
+ " ret = ret.dtype.type(ret / rcount)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0 0.1 0.1 nan\n",
+ "100 0.08000000000000002 0.0995 -5.0\n",
+ "200 0.06400000000000002 0.09900250000000001 -4.98\n",
+ "300 0.051200000000000016 0.0985074875 -4.98\n",
+ "400 0.04096000000000002 0.09801495006250001 -4.74\n",
+ "500 0.03276800000000001 0.09752487531218751 -4.74\n",
+ "600 0.026214400000000013 0.09703725093562657 -4.38\n",
+ "700 0.02097152000000001 0.09655206468094843 -3.96\n",
+ "800 0.016777216000000008 0.09606930435754368 -3.68\n",
+ "900 0.013421772800000007 0.09558895783575597 -3.56\n",
+ "1000 0.010737418240000006 0.09511101304657718 -3.96\n",
+ "1100 0.01 0.09463545798134429 -3.18\n",
+ "1200 0.01 0.09416228069143756 -1.8\n",
+ "1300 0.01 0.09369146928798038 -2.34\n",
+ "1400 0.01 0.09322301194154048 -2.46\n",
+ "1500 0.01 0.09275689688183278 -3.0\n",
+ "1600 0.01 0.09229311239742362 -2.14\n",
+ "1700 0.01 0.0918316468354365 -1.06\n",
+ "1800 0.01 0.09137248860125932 -1.2\n",
+ "1900 0.01 0.09091562615825302 -0.32\n",
+ "2000 0.01 0.09046104802746176 0.32\n",
+ "2100 0.01 0.09000874278732444 -1.48\n",
+ "2200 0.01 0.08955869907338782 -1.02\n",
+ "2300 0.01 0.08911090557802087 0.08\n",
+ "2400 0.01 0.08866535105013076 -1.58\n",
+ "2500 0.01 0.08822202429488012 -1.26\n",
+ "2600 0.01 0.08778091417340572 -1.78\n",
+ "2700 0.01 0.0873420096025387 -0.24\n",
+ "2800 0.01 0.086905299554526 -1.12\n",
+ "2900 0.01 0.08647077305675337 3.28\n",
+ "3000 0.01 0.0860384191914696 -0.2\n",
+ "3100 0.01 0.08560822709551225 -1.34\n",
+ "3200 0.01 0.08518018596003468 -1.22\n",
+ "3300 0.01 0.08475428503023451 0.66\n",
+ "3400 0.01 0.08433051360508334 -1.0\n",
+ "3500 0.01 0.08390886103705793 1.14\n",
+ "3600 0.01 0.08348931673187264 1.78\n",
+ "3700 0.01 0.08307187014821328 1.62\n",
+ "3800 0.01 0.08265651079747222 -1.7\n",
+ "3900 0.01 0.08224322824348486 2.52\n",
+ "4000 0.01 0.08183201210226744 1.9\n",
+ "4100 0.01 0.0814228520417561 2.92\n",
+ "4200 0.01 0.08101573778154732 0.36\n",
+ "4300 0.01 0.08061065909263958 0.86\n",
+ "4400 0.01 0.08020760579717638 1.24\n",
+ "4500 0.01 0.0798065677681905 2.14\n",
+ "4600 0.01 0.07940753492934956 1.8\n",
+ "4700 0.01 0.07901049725470281 0.48\n",
+ "4800 0.01 0.0786154447684293 1.98\n",
+ "4900 0.01 0.07822236754458715 4.54\n",
+ "5000 0.01 0.07783125570686422 4.3\n",
+ "5100 0.01 0.0774420994283299 6.76\n",
+ "5200 0.01 0.07705488893118825 6.36\n",
+ "5300 0.01 0.07666961448653231 11.02\n",
+ "5400 0.01 0.07628626641409965 4.94\n",
+ "5500 0.01 0.07590483508202915 1.24\n",
+ "5600 0.01 0.075525310906619 2.76\n",
+ "5700 0.01 0.07514768435208591 12.38\n",
+ "5800 0.01 0.07477194593032548 3.54\n",
+ "5900 0.01 0.07439808620067385 2.68\n",
+ "6000 0.01 0.07402609576967048 1.32\n",
+ "6100 0.01 0.07365596529082213 6.16\n",
+ "6200 0.01 0.07328768546436802 8.14\n",
+ "6300 0.01 0.07292124703704618 7.34\n",
+ "6400 0.01 0.07255664080186094 13.44\n",
+ "6500 0.01 0.07219385759785164 8.78\n",
+ "6600 0.01 0.07183288830986238 13.3\n",
+ "6700 0.01 0.07147372386831308 11.4\n",
+ "6800 0.01 0.07111635524897152 4.04\n",
+ "6900 0.01 0.07076077347272666 11.98\n",
+ "7000 0.01 0.07040696960536302 4.18\n",
+ "7100 0.01 0.0700549347573362 8.18\n",
+ "7200 0.01 0.06970466008354953 5.26\n",
+ "7300 0.01 0.06935613678313178 2.72\n",
+ "7400 0.01 0.06900935609921612 4.12\n",
+ "7500 0.01 0.06866430931872003 4.38\n",
+ "7600 0.01 0.06832098777212643 3.74\n",
+ "7700 0.01 0.0679793828332658 2.12\n",
+ "7800 0.01 0.06763948591909946 4.58\n",
+ "7900 0.01 0.06730128848950397 -0.16\n",
+ "8000 0.01 0.06696478204705644 1.24\n",
+ "8100 0.01 0.06662995813682115 0.3\n",
+ "8200 0.01 0.06629680834613705 0.44\n",
+ "8300 0.01 0.06596532430440637 4.76\n",
+ "8400 0.01 0.06563549768288433 4.16\n",
+ "8500 0.01 0.06530732019446991 1.76\n",
+ "8600 0.01 0.06498078359349757 0.92\n",
+ "8700 0.01 0.06465587967553008 1.72\n",
+ "8800 0.01 0.06433260027715243 0.74\n",
+ "8900 0.01 0.06401093727576666 -1.52\n",
+ "9000 0.01 0.06369088258938783 -0.48\n",
+ "9100 0.01 0.0633724281764409 0.08\n",
+ "9200 0.01 0.06305556603555869 -0.16\n",
+ "9300 0.01 0.0627402882053809 0.06\n",
+ "9400 0.01 0.062426586764353996 0.68\n",
+ "9500 0.01 0.062114453830532226 0.14\n",
+ "9600 0.01 0.061803881561379566 -0.58\n",
+ "9700 0.01 0.061494862153572666 -1.6\n",
+ "9800 0.01 0.0611873878428048 -2.52\n",
+ "9900 0.01 0.060881450903590775 -2.3\n",
+ "10000 0.01 0.06057704364907282 -2.28\n",
+ "10100 0.01 0.06027415843082746 -1.84\n",
+ "10200 0.01 0.05997278763867332 -2.0\n",
+ "10300 0.01 0.059672923700479955 2.76\n",
+ "10400 0.01 0.05937455908197756 6.82\n",
+ "10500 0.01 0.05907768628656767 7.08\n",
+ "10600 0.01 0.05878229785513483 6.36\n",
+ "10700 0.01 0.05848838636585915 5.0\n",
+ "10800 0.01 0.05819594443402985 2.8\n",
+ "10900 0.01 0.057904964711859706 3.84\n",
+ "11000 0.01 0.05761543988830041 4.72\n",
+ "11100 0.01 0.05732736268885891 2.08\n",
+ "11200 0.01 0.05704072587541461 3.4\n",
+ "11300 0.01 0.05675552224603753 3.5\n",
+ "11400 0.01 0.05647174463480734 2.12\n",
+ "11500 0.01 0.056189385911633305 3.66\n",
+ "11600 0.01 0.05590843898207514 4.8\n",
+ "11700 0.01 0.05562889678716477 0.42\n",
+ "11800 0.01 0.055350752303228945 0.44\n",
+ "11900 0.01 0.0550739985417128 -0.38\n",
+ "12000 0.01 0.05479862854900423 0.88\n",
+ "12100 0.01 0.05452463540625921 0.96\n",
+ "12200 0.01 0.05425201222922791 -2.0\n",
+ "12300 0.01 0.05398075216808177 -1.76\n",
+ "12400 0.01 0.053710848407241364 -1.12\n",
+ "12500 0.01 0.053442294165205156 -1.22\n",
+ "12600 0.01 0.05317508269437913 -1.18\n",
+ "12700 0.01 0.052909207280907235 -1.36\n",
+ "12800 0.01 0.0526446612445027 -2.22\n",
+ "12900 0.01 0.052381437938280186 -0.08\n",
+ "13000 0.01 0.052119530748588785 -1.62\n",
+ "13100 0.01 0.05185893309484584 -1.6\n",
+ "13200 0.01 0.05159963842937161 0.5\n",
+ "13300 0.01 0.051341640237224755 -2.12\n",
+ "13400 0.01 0.05108493203603863 -0.76\n",
+ "13500 0.01 0.05082950737585844 1.64\n",
+ "13600 0.01 0.05057535983897914 -0.1\n",
+ "13700 0.01 0.050322483039784247 -0.36\n",
+ "13800 0.01 0.050070870624585324 0.6\n",
+ "13900 0.01 0.049820516271462396 1.56\n",
+ "14000 0.01 0.049571413690105086 -0.06\n",
+ "14100 0.01 0.04932355662165456 1.68\n",
+ "14200 0.01 0.04907693883854629 1.76\n",
+ "14300 0.01 0.048831554144353556 3.94\n",
+ "14400 0.01 0.04858739637363179 2.64\n",
+ "14500 0.01 0.04834445939176363 2.0\n",
+ "14600 0.01 0.04810273709480481 2.92\n",
+ "14700 0.01 0.04786222340933079 1.08\n",
+ "14800 0.01 0.04762291229228413 1.3\n",
+ "14900 0.01 0.04738479773082271 -1.24\n",
+ "15000 0.01 0.04714787374216859 -0.7\n",
+ "15100 0.01 0.046912134373457745 -0.92\n",
+ "15200 0.01 0.04667757370159046 0.3\n",
+ "15300 0.01 0.046444185833082505 2.16\n",
+ "15400 0.01 0.046211964903917095 3.62\n",
+ "15500 0.01 0.04598090507939751 0.12\n",
+ "15600 0.01 0.045751000554000526 3.7\n",
+ "15700 0.01 0.04552224555123052 1.66\n",
+ "15800 0.01 0.04529463432347437 5.44\n",
+ "15900 0.01 0.045068161151857 5.68\n",
+ "16000 0.01 0.04484282034609772 5.32\n",
+ "16100 0.01 0.04461860624436723 2.0\n",
+ "16200 0.01 0.044395513213145395 1.68\n",
+ "16300 0.01 0.04417353564707967 1.94\n",
+ "16400 0.01 0.04395266796884427 2.84\n",
+ "16500 0.01 0.04373290462900005 4.04\n",
+ "16600 0.01 0.04351424010585505 3.96\n",
+ "16700 0.01 0.043296668905325776 6.14\n",
+ "16800 0.01 0.04308018556079915 1.28\n",
+ "16900 0.01 0.04286478463299515 3.04\n",
+ "17000 0.01 0.042650460709830175 1.1\n",
+ "17100 0.01 0.04243720840628103 0.42\n",
+ "17200 0.01 0.042225022364249624 3.8\n",
+ "17300 0.01 0.04201389725242838 3.84\n",
+ "17400 0.01 0.041803827766166236 3.26\n",
+ "17500 0.01 0.0415948086273354 2.54\n",
+ "17600 0.01 0.04138683458419873 0.32\n",
+ "17700 0.01 0.04117990041127773 0.4\n",
+ "17800 0.01 0.04097400090922134 1.52\n",
+ "17900 0.01 0.04076913090467523 0.88\n",
+ "18000 0.01 0.04056528525015186 5.06\n",
+ "18100 0.01 0.0403624588239011 7.94\n",
+ "18200 0.01 0.04016064652978159 6.56\n",
+ "18300 0.01 0.03995984329713268 3.56\n",
+ "18400 0.01 0.03976004408064702 1.78\n",
+ "18500 0.01 0.03956124386024378 3.94\n",
+ "18600 0.01 0.039363437640942564 -1.24\n",
+ "18700 0.01 0.03916662045273785 0.94\n",
+ "18800 0.01 0.03897078735047416 1.18\n",
+ "18900 0.01 0.03877593341372179 1.64\n",
+ "19000 0.01 0.03858205374665318 2.14\n",
+ "19100 0.01 0.03838914347791991 3.66\n",
+ "19200 0.01 0.038197197760530315 3.16\n",
+ "19300 0.01 0.038006211771727666 4.28\n",
+ "19400 0.01 0.037816180712869026 2.98\n",
+ "19500 0.01 0.03762709980930468 6.58\n",
+ "19600 0.01 0.03743896431025816 2.66\n",
+ "19700 0.01 0.037251769488706864 2.82\n",
+ "19800 0.01 0.03706551064126333 0.46\n",
+ "19900 0.01 0.03688018308805701 4.18\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Some control variables\n",
+ "cumulated = np.zeros((nb_games))\n",
+ "\n",
+ "# Start the game\n",
+ "p.init()\n",
+ "reward = 0\n",
+ "\n",
+ "for i in range(nb_games):\n",
+ " p.reset_game()\n",
+ " \n",
+ " # Control print\n",
+ " if i%100 == 0:\n",
+ " print(i, epsilon, alpha, np.mean(cumulated[i-50:i]))\n",
+ " \n",
+ " # Decrease exploration ratio\n",
+ " epsilon = max(epsilon*0.8, 0.01)\n",
+ " alpha *= 0.995\n",
+ " \n",
+ " if i%1000 == 999:\n",
+ " np.save('Q_%d' % i ,Q)\n",
+ " np.save('cumulated_%d' % i, cumulated)\n",
+ " \n",
+ " # 0) Retrieve initial state\n",
+ " s1, s2, s3 = convstate(game.getGameState())\n",
+ " current_key = str(s1)+'|'+str(s2)+'|'+str(s3)\n",
+ " \n",
+ " while(not p.game_over()):\n",
+ " \n",
+ " # 1) Choose action greedily\n",
+ " a = epsilon_greedy(current_key)\n",
+ " \n",
+ " action = None\n",
+ " if a==1:\n",
+ " action = 119\n",
+ " \n",
+ " \n",
+ " # Execute\n",
+ " reward = p.act(action)\n",
+ " cumulated[i] += reward\n",
+ " \n",
+ " ss1, ss2, ss3 = convstate(game.getGameState())\n",
+ " next_key = str(ss1)+'|'+str(ss2)+'|'+str(ss3)\n",
+ " \n",
+ "\n",
+ " \n",
+ " # 2) Update Q value\n",
+    "        if Q.get(current_key) is None:\n",
+ " Q[current_key] = [0,0]\n",
+ " \n",
+ " maxQ = max(Q.get(next_key, [0]))\n",
+ " \n",
+ " \n",
+ " Q[current_key][a] = (1-alpha)*Q[current_key][a] + alpha*( reward + gamma*maxQ )\n",
+ " \n",
+ " # Update values and map key\n",
+ " s1 = ss1\n",
+ " s2 = ss2\n",
+ " s3 = ss3\n",
+ " current_key = next_key\n",
+ " \n",
+ " \n",
+ "\n",
+ " \n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Q"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Postprocess"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Mean evolution\n",
+ "vallist = list()\n",
+ "\n",
+ "for idx, val in enumerate(cumulated):\n",
+ " if idx < 50:\n",
+ " pass\n",
+ " else: \n",
+ " vallist.append(np.mean( cumulated[idx-50:idx] ))\n",
+ "\n",
+ "plt.plot(vallist)\n",
+ "\n",
+ "vallist = list()\n",
+ "\n",
+ "for idx, val in enumerate(cumulated):\n",
+ " if idx < 500:\n",
+ " pass\n",
+ " else: \n",
+ " vallist.append(np.mean( cumulated[idx-500:idx] ))\n",
+ "\n",
+ "plt.plot(vallist)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#c1 = cumulated\n",
+ "#c2 = cumulated\n",
+ "c3 = cumulated"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Mean evolution\n",
+ "vallist = list()\n",
+ "\n",
+ "result = np.concatenate((c1,c2,c3))\n",
+ "\n",
+ "for idx, val in enumerate(result):\n",
+ " if idx < 50:\n",
+ " pass\n",
+ " else: \n",
+ " vallist.append(np.mean( result[idx-50:idx] ))\n",
+ "\n",
+ "plt.plot(vallist)\n",
+ "\n",
+ "vallist = list()\n",
+ "\n",
+ "for idx, val in enumerate(result):\n",
+ " if idx < 500:\n",
+ " pass\n",
+ " else: \n",
+ " vallist.append(np.mean( result[idx-500:idx] ))\n",
+ "\n",
+ "plt.plot(vallist)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/Rocamora-Ardevol/Train/SARSA - TD lambda training.ipynb b/Rocamora-Ardevol/Train/SARSA - TD lambda training.ipynb
new file mode 100644
index 0000000..e31e941
--- /dev/null
+++ b/Rocamora-Ardevol/Train/SARSA - TD lambda training.ipynb
@@ -0,0 +1,279 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from ple.games.flappybird import FlappyBird\n",
+ "from ple import PLE\n",
+ "\n",
+ "import numpy as np\n",
+ "#from FlappyAgent import FlappyPolicy\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "game = FlappyBird(graphics=\"fixed\") # use \"fancy\" for full background, random bird color and random pipe color, use \"fixed\" (default) for black background and constant bird and pipe colors.\n",
+ "p = PLE(game, fps=30, frame_skip=1, num_steps=1, force_fps=True, display_screen=True)\n",
+    "# Note: if you want to see your agent act in real time, set force_fps to False. But don't use this setting for learning, just for display purposes."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Declare functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def convstate(state):\n",
+ " \"\"\"\n",
+ " Calculate new state variables from game state\n",
+ " \"\"\"\n",
+ " s1 = state['next_pipe_bottom_y'] - state['player_y']\n",
+ " s2 = state['next_pipe_dist_to_player']\n",
+ " s3 = state['player_vel']\n",
+ " \n",
+ " return int(s1-s1%10), int(s2-s2%20), int(s3-s3%2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def epsilon_greedy(key):\n",
+ " if(np.random.rand()<=epsilon): # random action\n",
+ " return np.random.choice([0,1])\n",
+ " \n",
+ " else: \n",
+ " return np.argmax(Q.get(key, [0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def update_trace(key,action):\n",
+ " # Update the trace\n",
+ " global epsTrace\n",
+ " epsTrace = { k: list(map(lambda x: x*gamma*lamb, v)) for k,v in epsTrace.items() }\n",
+ " \n",
+    "    if epsTrace.get(key) is None:\n",
+ " epsTrace[key] = [0,0]\n",
+ " \n",
+ " # Remember the current state\n",
+ " epsTrace[key][action] = 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def propagate(delta):\n",
+ " for k,v in epsTrace.items():\n",
+ " Q[k][0] = Q[k][0] + alpha*epsTrace[k][0]*delta\n",
+ " Q[k][1] = Q[k][1] + alpha*epsTrace[k][1]*delta"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Reinit variables"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Q = dict()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Metaparameters\n",
+ "nb_games = 18000\n",
+ "alpha = 0.1 #0.7\n",
+ "epsilon = 0.1 #0.4\n",
+ "gamma = 0.9\n",
+ "lamb = 1"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Run training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "nbgames = 18000\n",
+ "\n",
+ "# Some control variables\n",
+ "cumulated = np.zeros((nb_games))\n",
+ "\n",
+ "# Start the game\n",
+ "p.init()\n",
+ "reward = 0\n",
+ "\n",
+ "for i in range(nb_games):\n",
+ " p.reset_game()\n",
+ " epsTrace = dict()\n",
+ " \n",
+ " # Control print\n",
+ " if i%100 == 0:\n",
+ " print(i, epsilon, alpha, np.mean(cumulated[i-50:i]))\n",
+ " \n",
+ " # Decrease exploration ratio\n",
+ " epsilon = max(epsilon * 0.95, 0.01)\n",
+ " alpha *= 0.995\n",
+ " \n",
+ " if i%1000 == 999:\n",
+ " np.save('Qsarsa_more_%d' % i ,Q)\n",
+ " np.save('cumulated_sarsa_more_%d' % i, cumulated)\n",
+ " \n",
+ " # 0) Retrieve initial state \n",
+ " s1, s2, s3 = convstate(game.getGameState())\n",
+ " current_key = str(s1)+'|'+str(s2)+'|'+str(s3)\n",
+ " \n",
+    "    if Q.get(current_key) is None:\n",
+ " Q[current_key] = [0,0]\n",
+ " \n",
+ " # Choose action greedily\n",
+ " a = epsilon_greedy(current_key)\n",
+ " \n",
+ " while(not p.game_over()):\n",
+ " \n",
+ " # Translate action\n",
+ " action = None\n",
+ " if a==1:\n",
+ " action = 119\n",
+ " \n",
+ " # 1) Execute\n",
+ " reward = p.act(action)\n",
+ " cumulated[i] += reward\n",
+ " \n",
+ " ss1, ss2, ss3 = convstate(game.getGameState())\n",
+ " next_key = str(ss1)+'|'+str(ss2)+'|'+str(ss3)\n",
+ "\n",
+ " # 2) Choose new action greedily\n",
+ " aa = epsilon_greedy(next_key)\n",
+ " \n",
+ " # 3) Update Q value\n",
+ " # Update trace\n",
+ " update_trace(current_key, a)\n",
+ " \n",
+ " # Update Q\n",
+    "        if Q.get(next_key) is None:\n",
+ " Q[next_key] = [0,0]\n",
+ " \n",
+ " delta = reward + gamma*Q[next_key][aa] - Q[current_key][a]\n",
+ " propagate(delta)\n",
+ " \n",
+ " # Update values and map key\n",
+ " s1 = ss1\n",
+ " s2 = ss2\n",
+ " s3 = ss3\n",
+ " a = aa\n",
+ " current_key = next_key \n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Postprocess"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Mean evolution\n",
+ "vallist = list()\n",
+ "\n",
+ "for idx, val in enumerate(cumulated):\n",
+ " if idx < 50:\n",
+ " pass\n",
+ " else: \n",
+ " vallist.append(np.mean( cumulated[idx-50:idx] ))\n",
+ "\n",
+ "plt.plot(vallist)\n",
+ "\n",
+ "vallist = list()\n",
+ "\n",
+ "for idx, val in enumerate(cumulated):\n",
+ " if idx < 500:\n",
+ " pass\n",
+ " else: \n",
+ " vallist.append(np.mean( cumulated[idx-500:idx] ))\n",
+ "\n",
+ "plt.plot(vallist)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/RandomBird/run.py b/Rocamora-Ardevol/run.py
similarity index 100%
rename from RandomBird/run.py
rename to Rocamora-Ardevol/run.py