Commit a60e31a

Merge pull request #76 from HumanCompatibleAI/tomato_featurization
Tomato featurization
2 parents 3b7a161 + 26f82c6 commit a60e31a

File tree

20 files changed: +1148 −254 lines


.github/workflows/pythontests.yml

Lines changed: 1 addition & 2 deletions
@@ -33,10 +33,9 @@ jobs:
       run: pip install -e .
     - name: Run tests and generate coverage report
       run: |
-        python -m unittest discover -s testing/ -p "*_test.py"
+        coverage run -m unittest discover -s testing/ -p "*_test.py"
     - name: Upload coverage to Codecov
       uses: codecov/codecov-action@v1
       with:
-        flags: no-planners
         name: codecov-report
         fail_ci_if_error: false

codecov.yml

Lines changed: 5 additions & 0 deletions
@@ -1,4 +1,9 @@
 codecov:
+  status:
+    project:
+      default:
+        target: auto
+        threshold: 1%
   require_ci_to_pass: yes
   max_report_age: off
src/overcooked_ai_py/agents/benchmarking.py

Lines changed: 9 additions & 7 deletions
@@ -8,6 +8,7 @@
 from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld, Action, OvercookedState
 from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
 from overcooked_ai_py.mdp.layout_generator import LayoutGenerator
+from overcooked_ai_py.mdp.overcooked_trajectory import DEFAULT_TRAJ_KEYS


 class AgentEvaluator(object):
@@ -151,20 +152,20 @@ def get_agent_pair_trajs(self, a0, a1=None, num_games=100, game_length=None, sta
         return trajs_0, trajs_1

     @staticmethod
-    def check_trajectories(trajectories, from_json=False):
+    def check_trajectories(trajectories, from_json=False, **kwargs):
         """
         Checks that of trajectories are in standard format and are consistent with dynamics of mdp.
         If the trajectories were saves as json, do not check that they have standard traj keys.
         """
         if not from_json:
             AgentEvaluator._check_standard_traj_keys(set(trajectories.keys()))
         AgentEvaluator._check_right_types(trajectories)
-        AgentEvaluator._check_trajectories_dynamics(trajectories)
+        AgentEvaluator._check_trajectories_dynamics(trajectories, **kwargs)
         # TODO: Check shapes?

     @staticmethod
     def _check_standard_traj_keys(traj_keys_set):
-        default_traj_keys = OvercookedEnv.DEFAULT_TRAJ_KEYS
+        default_traj_keys = DEFAULT_TRAJ_KEYS
         assert traj_keys_set == set(default_traj_keys), "Keys of traj dict did not match standard form.\nMissing keys: {}\nAdditional keys: {}".format(
             [k for k in default_traj_keys if k not in traj_keys_set], [k for k in traj_keys_set if k not in default_traj_keys]
         )
@@ -181,10 +182,11 @@ def _check_right_types(trajectories):
         # TODO: check that are all lists

     @staticmethod
-    def _check_trajectories_dynamics(trajectories):
+    def _check_trajectories_dynamics(trajectories, verbose=True):
         if any(env_params["_variable_mdp"] for env_params in trajectories["env_params"]):
-            print("Skipping trajectory consistency checking because MDP was recognized as variable. "
-                  "Trajectory consistency checking is not yet supported for variable MDPs.")
+            if verbose:
+                print("Skipping trajectory consistency checking because MDP was recognized as variable. "
+                      "Trajectory consistency checking is not yet supported for variable MDPs.")
             return

         _, envs = AgentEvaluator.get_mdps_and_envs_from_trajectories(trajectories)
@@ -241,7 +243,7 @@ def load_trajectories(filename):
     @staticmethod
     def save_traj_as_json(trajectory, filename):
         """Saves the `idx`th trajectory as a list of state action pairs"""
-        assert set(OvercookedEnv.DEFAULT_TRAJ_KEYS) == set(trajectory.keys()), "{} vs\n{}".format(OvercookedEnv.DEFAULT_TRAJ_KEYS, trajectory.keys())
+        assert set(DEFAULT_TRAJ_KEYS) == set(trajectory.keys()), "{} vs\n{}".format(DEFAULT_TRAJ_KEYS, trajectory.keys())
         AgentEvaluator.check_trajectories(trajectory)
         trajectory = AgentEvaluator.make_trajectories_json_serializable(trajectory)
         save_as_json(trajectory, filename)
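
The change above threads a keyword argument through the trajectory checks: check_trajectories now accepts **kwargs and forwards them to _check_trajectories_dynamics, whose new verbose flag gates the "variable MDP" skip message. A minimal usage sketch follows; it is an assumed example rather than code from this PR, and the "saved_trajs" filename is hypothetical.

    # Assumed usage of the new verbose plumbing in AgentEvaluator.check_trajectories.
    from overcooked_ai_py.agents.benchmarking import AgentEvaluator

    # Hypothetical file produced by an earlier evaluation run.
    trajectories = AgentEvaluator.load_trajectories("saved_trajs")

    # verbose=False is forwarded via **kwargs to _check_trajectories_dynamics,
    # which now prints its "variable MDP" skip notice only when verbose is True.
    AgentEvaluator.check_trajectories(trajectories, verbose=False)
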
Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
X X P X X

O ↑1 O

X ↑0 X

X D X S X


Timestep: 1
Joint action taken: ('stay', 'stay') Reward: 0 + shaping_factor * [0, 0]
Action probs by index: [None, None]
X X P X X

O ↑1 O

X ↑0 X

X D X S X


Timestep: 2
Joint action taken: ('←', '→') Reward: 0 + shaping_factor * [0, 0]
Action probs by index: [None, None]
X X P X X

O →1 O

X ←0 X

X D X S X


Timestep: 3
Joint action taken: ('←', '→') Reward: 0 + shaping_factor * [0, 0]
Action probs by index: [None, None]
X X P X X

O →1 O

X ←0 X

X D X S X


[Timesteps 4 through 20 repeat the block above: joint action ('stay', 'stay'), reward 0 + shaping_factor * [0, 0], action probs [None, None], and the same grid state at each step.]
