36 changes: 36 additions & 0 deletions MartinAymeline/FlappyAgent.py
@@ -0,0 +1,36 @@
import numpy as np
from keras.models import load_model
from collections import deque
from utilities import process_screen


stacked = []
calls = 0
DQN = load_model('model_dqn_new_65000.h5')
possible_actions = [119, None]  # 119 = flap, None = do nothing


def FlappyPolicy(state, screen):
    # Greedy pixel-based policy: `state` is unused, only the screen matters.
    global stacked
    global calls

    calls += 1
    processed_screen = process_screen(screen)

    if calls == 1:
        # On the first call, fill the deque with 4 copies of the current frame
        stacked = deque([processed_screen, processed_screen,
                         processed_screen, processed_screen], maxlen=4)
    else:
        stacked.append(processed_screen)

    # Stack the 4 last frames into a single network input
    x = np.stack(stacked, axis=-1)

    Q = DQN.predict(np.array([x]))

    return possible_actions[np.argmax(Q)]
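utilities.py is not part of this diff, so the exact preprocessing behind process_screen is not visible. As a minimal sketch only, here is what a DeepMind-style preprocessor might look like, assuming an 80x80 grayscale crop of the PLE screen (the crop bounds, output size and skimage dependency are assumptions, not the submitted code):

# Hypothetical sketch -- utilities.py is not included in this diff.
import numpy as np
from skimage.color import rgb2gray
from skimage.transform import resize

def process_screen(screen):
    # Crop the playing area, convert RGB to grayscale, downscale to 80x80
    # and re-quantize to uint8 so frames fit the replay buffer's dtype.
    gray = rgb2gray(screen[60:, 25:310, :])
    small = resize(gray, (80, 80), mode="constant")
    return (255 * small).astype(np.uint8)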


25 changes: 25 additions & 0 deletions MartinAymeline/constantes.py
@@ -0,0 +1,25 @@
class constantes:  # Gathers the main hyperparameters in one place

    # Memory buffer constants
    replay_memory_size = 200000  # number of previous transitions to remember
    mini_batch_size = 32

    # Learning constants
    gamma = 0.99
    total_steps = 200000  # the best network was obtained after 65000 steps
    observation = 5000.
    explore = 1000000.  # frames over which to anneal epsilon
    final_eps = 0.001  # final value of epsilon
    initial_eps = 0.1  # starting value of epsilon

    # Optimizer constants
    alpha = 1e-4  # learning rate
    beta_1 = 0.9
    beta_2 = 0.999

    # Evaluation constants
    evaluation_period = 5000  # evaluate the deep Q-network every 5000 steps
    nb_epochs = total_steps // evaluation_period
    epoch = -1
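These constants describe a standard linear epsilon-annealing schedule: epsilon decays from initial_eps to final_eps over explore frames once the initial observation phase is over. A minimal sketch of how a training loop might compute epsilon at a given step (the helper and the fully random observation phase are assumptions, not code from this PR):

from constantes import constantes as C

def epsilon(step):
    # Hypothetical helper: linear annealing from initial_eps to final_eps.
    if step < C.observation:
        return 1.0  # assumption: act fully at random while filling the buffer
    frac = min((step - C.observation) / C.explore, 1.0)
    return C.initial_eps + frac * (C.final_eps - C.initial_eps)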


152 changes: 152 additions & 0 deletions MartinAymeline/eval.log
@@ -0,0 +1,152 @@
0,0.0,4.69466784611
1,0.0,4.96170904573
2,5.47673034668,3.53810591755
3,0.0,3.62164132165
0,0.0,5.86053436558
1,4.66656112671,3.88620864601
2,7.15516853333,2.93598770647
3,8.48180675507,3.93345115937
4,10.2029304504,3.56317498664
5,11.2960853577,4.7983870147
6,12.9479646683,5.75704039982
7,14.86444664,5.06069977694
8,16.3370666504,4.46964998889
9,17.982843399,6.19302055301
10,18.5910949707,6.50989475939
11,17.6032295227,8.10438881031
12,15.5112314224,9.03463693305
13,15.8485612869,11.2945019369
14,15.7190685272,8.59904841042
15,14.6428871155,9.70675154722
16,13.4239826202,10.5581026454
17,12.485534668,10.801754052
18,11.0798435211,14.4698277258
19,9.79876613617,10.1922431753
20,7.87689256668,15.5221458151
21,7.39563035965,15.4769874784
22,5.57780265808,14.7065343904
23,4.64339876175,16.9289811078
24,4.26625919342,14.2144713002
25,3.76601338387,16.3316570776
26,3.70627355576,14.1546108109
27,5.11464166641,15.4347679658
28,4.40491008759,15.073854647
29,4.24400806427,15.949591948
30,6.30815887451,16.3497812243
31,8.96263790131,15.1154130214
32,11.736158371,16.4718276847
33,12.8968105316,18.8227412404
34,14.7348413467,17.9756738702
35,15.7232618332,18.0161046671
36,17.8804893494,17.6891486946
37,20.2932090759,21.1049183936
38,22.4963302612,20.7092030539
0,24.967710495,19.6716769573
1,38.5622138977,14.2626393644
2,40.7054405212,15.117170123
3,39.2491836548,16.5829738643
4,45.0468063354,11.9271002253
5,45.4236869812,10.9867965833
6,42.0939025879,10.0917820893
7,42.1448974609,9.4055183894
8,39.3510360718,6.49246510544
9,42.3908004761,4.77522978771
10,40.0641403198,4.58435612121
11,34.6684570313,4.10691721313
12,32.4749946594,5.02565120362
13,32.4948348999,5.9498961219
14,38.884727478,4.75203246545
15,43.4196586609,4.60334158404
16,55.7478370667,3.81404406007
17,53.9235687256,4.36071293652
18,69.6009521484,4.48999655189
19,72.2026443481,3.42965011211
20,389.599456787,-11.9877204284
21,2478.59570313,-11.9335284749
22,6405.49511719,-11.9877204284
23,13140.9375,-11.9877204284
24,25132.1035156,-11.9877204284
25,42703.2929688,-12.1962853141
26,61368.8945313,-12.1962853141
27,82755.6875,-12.1962853141
28,104760.78125,-11.2593159839
29,107774.976563,-11.7461346088
30,90664.625,-11.9335284749
31,80297.1796875,-11.2432807355
32,51785.5625,-5.13517710508
33,48012.2539063,-12.043981514
34,41425.4921875,-9.27832739837
35,32831.4921875,-2.11385608523
36,17783.8730469,0.243092634338
37,15080.1923828,1.10150837747
38,16184.1445313,2.04957438321
0,0.0,3.86854620148
1,8.84752559662,4.25661835006
0,0.0,-12.4,-7.0
1,5.01180744171,-14.7,-9.0
2,5.85478305817,-14.0,-9.0
3,5.52176523209,-14.05,-9.0
4,5.50321054459,-13.35,-8.0
5,5.94852733612,-13.95,-9.0
6,6.59453201294,-13.9,-8.0
7,7.23370599747,-11.95,-8.0
8,8.26382637024,-11.55,-2.0
9,8.98322105408,-11.75,-7.0
10,9.68953227997,-11.2,-7.0
11,10.5278759003,-9.3,-2.0
12,12.3057289124,-8.7,-2.0
13,13.1675548553,-9.0,-2.0
14,13.0113019943,-9.1,-2.0
15,14.6724071503,-4.05,4.0
16,16.3890533447,-6.2,-2.0
17,17.4456806183,-5.05,4.0
18,18.2743644714,-2.55,4.0
19,16.9419136047,-1.35,4.0
20,16.3637008667,-4.8,4.0
21,18.7339382172,-4.75,4.0
22,18.3891048431,-3.95,4.0
23,17.445936203,-1.5,4.0
24,14.6130094528,-0.55,4.0
25,12.6413230896,-12.8,-8.0
26,21.5028362274,-16.45,-14.0
27,39.2529563904,-20.0,-20.0
0,24.967710495,1.9,4.0
1,38.5454292297,-0.75,4.0
2,36.9952774048,2.55,4.0
3,28.1553764343,-7.55,-2.0
4,29.2874355316,-10.85,-3.0
0,0.0,-4.9,-4.0
1,7.87420940399,-4.9,-4.0
2,1.58440470695,-5.0,-5.0
3,3.56381487846,-4.8,-4.0
4,2.0504193306,-4.75,-4.0
5,-0.517796576023,-4.8,-4.0
6,1.06802773476,-4.45,-4.0
7,0.590887069702,-3.75,-1.0
8,2.1341612339,-2.35,1.0
9,3.60842895508,-0.1,20.0
10,4.7539973259,9.65,52.0
11,3.8366549015,15.8,57.0
12,10.3156099319,19.45,66.0
13,460.176086426,-4.4,-3.0
14,-864.924438477,-5.0,-5.0
15,62.3476867676,-4.75,-4.0
16,311.228363037,-4.85,-4.0
17,-57.1425170898,-4.65,-3.0
18,-46.8574295044,-5.0,-5.0
19,-28.9334468842,-5.0,-5.0
20,10.1054878235,-4.7,-4.0
21,-0.0738104507327,-4.05,0.0
22,6.38661527634,-4.85,-4.0
23,-9.64006328583,-4.25,-2.0
24,44.112663269,-4.7,-4.0
25,13.0731668472,-4.4,-3.0
26,16.3547077179,-4.85,-4.0
27,-53.566696167,-4.45,-4.0
28,3.15030813217,-4.65,-4.0
29,5.84460353851,-4.35,-2.0
30,-4.33490753174,-4.55,-3.0
31,0.0737409219146,-4.25,-2.0
32,-4.76753520966,-4.35,-3.0
33,-3.23868513107,-4.45,-3.0
0,0.0,-4.85,-4.0
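eval.log is comma-separated with a leading counter that resets at each new evaluation run; the meaning of the remaining two or three columns is not documented in this diff. A defensive way to load it, treating every field as a float (all column interpretations are assumptions):

import csv

rows = []
with open("eval.log") as f:
    for fields in csv.reader(f):
        rows.append([float(v) for v in fields])  # rows have 3 or 4 fields

print(len(rows), "rows; field counts:", sorted({len(r) for r in rows}))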
Binary file added MartinAymeline/model_dqn_new_65000.h5
Binary file not shown.
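Because the trained network ships as a binary .h5 file, its architecture is not visible here. For orientation only, a minimal Keras sketch of a DeepMind-style DQN over 80x80x4 stacked frames with one Q-value per action; every layer size is an assumption and not necessarily what the saved model contains:

# Hypothetical sketch -- the actual layers inside model_dqn_new_65000.h5 are not shown.
from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense
from keras.optimizers import Adam
from constantes import constantes as C

model = Sequential([
    Conv2D(16, (8, 8), strides=4, activation="relu", input_shape=(80, 80, 4)),
    Conv2D(32, (4, 4), strides=2, activation="relu"),
    Flatten(),
    Dense(256, activation="relu"),
    Dense(2),  # Q-values for [flap (119), no-op (None)]
])
model.compile(optimizer=Adam(lr=C.alpha, beta_1=C.beta_1, beta_2=C.beta_2),
              loss="mse")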
63 changes: 63 additions & 0 deletions MartinAymeline/replay_memory.py
@@ -0,0 +1,63 @@
from collections import deque
import numpy as np

# A class for the replay memory. We use the one from the RL4 notebook.

class MemoryBuffer:
    "An experience replay buffer using numpy arrays"

    def __init__(self, length, screen_shape, action_shape):
        self.length = length
        self.screen_shape = screen_shape
        self.action_shape = action_shape
        shape = (length,) + screen_shape
        self.screens_x = np.zeros(shape, dtype=np.uint8)  # starting states
        self.screens_y = np.zeros(shape, dtype=np.uint8)  # resulting states
        shape = (length,) + action_shape
        self.actions = np.zeros(shape, dtype=np.uint8)  # actions
        self.rewards = np.zeros((length, 1), dtype=np.int8)  # rewards
        self.terminals = np.zeros((length, 1), dtype=bool)  # True if the resulting state is terminal
        self.terminals[-1] = True
        self.index = 0  # points one position past the last inserted element
        self.size = 0  # current size of the buffer

    # Add state x, action a, reward r, new state y and terminal flag d
    def append(self, screenx, a, r, screeny, d):
        self.screens_x[self.index] = screenx
        self.actions[self.index] = a
        self.rewards[self.index] = r
        self.screens_y[self.index] = screeny
        self.terminals[self.index] = d
        self.index = (self.index + 1) % self.length
        self.size = np.min([self.size + 1, self.length])

    def stacked_frames_x(self, index):
        # Return the 4 frames ending at `index`, stopping at episode boundaries.
        im_deque = deque(maxlen=4)
        pos = index % self.length
        for i in range(4):
            im = self.screens_x[pos]
            im_deque.appendleft(im)
            test_pos = (pos - 1) % self.length
            if not self.terminals[test_pos]:
                pos = test_pos
        return np.stack(im_deque, axis=-1)

    def stacked_frames_y(self, index):
        # Same as stacked_frames_x, but over the resulting states.
        im_deque = deque(maxlen=4)
        pos = index % self.length
        for i in range(4):
            im = self.screens_y[pos]
            im_deque.appendleft(im)
            test_pos = (pos - 1) % self.length
            if not self.terminals[test_pos]:
                pos = test_pos
        return np.stack(im_deque, axis=-1)

    def minibatch(self, size):
        # Sample `size` distinct transitions as stacked 4-frame tensors.
        indices = np.random.choice(self.size, size=size, replace=False)
        x = np.zeros((size,) + self.screen_shape + (4,))
        y = np.zeros((size,) + self.screen_shape + (4,))
        for i in range(size):
            x[i] = self.stacked_frames_x(indices[i])
            y[i] = self.stacked_frames_y(indices[i])
        return x, self.actions[indices], self.rewards[indices], y, self.terminals[indices]
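A short usage sketch of the buffer, assuming 80x80 uint8 screens (as produced by process_screen) and scalar action indices; the small buffer length is only for the example:

import numpy as np
from replay_memory import MemoryBuffer

buffer = MemoryBuffer(1000, (80, 80), ())

# Store one transition: start screen, action index, reward, next screen, terminal flag.
frame = np.zeros((80, 80), dtype=np.uint8)
buffer.append(frame, 0, 1, frame, False)

# Each sampled state comes back stacked over the 4 preceding frames.
x, a, r, y, d = buffer.minibatch(1)
print(x.shape)  # (1, 80, 80, 4)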
29 changes: 29 additions & 0 deletions MartinAymeline/run.py
@@ -0,0 +1,29 @@
# You're not allowed to change this file
from ple.games.flappybird import FlappyBird
from ple import PLE
import numpy as np
from FlappyAgent import FlappyPolicy

game = FlappyBird(graphics="fixed")  # use "fancy" for full background, random bird color and random pipe color; use "fixed" (default) for black background and constant bird and pipe colors.
p = PLE(game, fps=30, frame_skip=1, num_steps=1, force_fps=False, display_screen=True)
# Note: if you want to see your agent act in real time, set force_fps to False. But don't use this setting for learning, just for display purposes.

p.init()
reward = 0.0

nb_games = 100
cumulated = np.zeros((nb_games))

for i in range(nb_games):
    p.reset_game()

    while not p.game_over():
        state = game.getGameState()
        screen = p.getScreenRGB()
        action = FlappyPolicy(state, screen)  ### Your job is to define this function.

        reward = p.act(action)
        cumulated[i] = cumulated[i] + reward

average_score = np.mean(cumulated)
max_score = np.max(cumulated)
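The training script itself is not part of this PR; to make the intent of the pieces above concrete, here is a generic sketch of the DQN update they are built for, assuming actions are stored as indices into possible_actions (this is not the authors' exact code):

import numpy as np
from constantes import constantes as C

def train_step(model, buffer):
    # Regress Q(x, a) toward r + gamma * max_a' Q(y, a'),
    # clamping the target to r on terminal transitions.
    x, a, r, y, d = buffer.minibatch(C.mini_batch_size)
    q_target = model.predict(x)
    q_next = model.predict(y).max(axis=1)
    for i in range(C.mini_batch_size):
        target = float(r[i]) + (1.0 - float(d[i])) * C.gamma * q_next[i]
        q_target[i, int(a[i])] = target
    model.train_on_batch(x, q_target)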