Skip to content

Commit 56edd6e

Browse files
authored
Merge pull request flatland-association#21 from flatland-association/feature/graph-env-first-steps
Update to breaking changes in flatland-rl #257 (feature/graph env first steps).
2 parents 582aade + 2ee986d commit 56edd6e

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

flatland_baselines/deadlock_avoidance_heuristic/observation/full_env_observation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from flatland.envs.rail_env import RailEnv
33

44

5-
class FullEnvObservation(ObservationBuilder[RailEnv]):
5+
class FullEnvObservation(ObservationBuilder[RailEnv, RailEnv]):
66
"""
77
Returns full env as observation.
88
"""
9+
910
def __init__(self):
1011
pass
1112

@@ -15,5 +16,5 @@ def get(self, handle: AgentHandle = 0) -> ObservationType:
1516
def reset(self):
1617
pass
1718

18-
def set_env(self,env):
19+
def set_env(self, env):
1920
self.env = env

flatland_baselines/deadlock_avoidance_heuristic/policy/deadlock_avoidance_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def callback(self, handle, agent, position, direction, action, possible_transiti
9090
@_enable_flatland_deadlock_avoidance_policy_lru_cache(maxsize=100000)
9191
def _is_no_switch_cell(self, position) -> bool:
9292
for new_dir in range(4):
93-
possible_transitions = self.env.rail.get_transitions(*position, new_dir)
93+
possible_transitions = self.env.rail.get_transitions((position, new_dir))
9494
num_transitions = fast_count_nonzero(possible_transitions)
9595
if num_transitions > 1:
9696
return False

flatland_baselines/deadlock_avoidance_heuristic/utils/flatland/shortest_distance_walker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def walk(self, handle, position, direction):
3939
if self.distance_map is None:
4040
self.distance_map = self.env.distance_map.get()
4141

42-
possible_transitions = self.env.rail.get_transitions(*position, direction)
42+
possible_transitions = self.env.rail.get_transitions((position, direction))
4343
num_transitions = fast_count_nonzero(possible_transitions)
4444
if num_transitions == 1:
4545
new_direction = fast_argmax(possible_transitions)

0 commit comments

Comments
 (0)