
Commit c280a92

RL Zoo as a package (#290)
* RL Zoo as a package
* Fixes for the CI
* Add basic cli
* Add hyperparams to package data
* Update pytype command
* Tmp fix import errors for pytype
* Fix test requirements
* Add dependencies and remove unused wrapper
1 parent 8600d80 commit c280a92

40 files changed: +725 -656 lines changed

.coveragerc

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 branch = False
 omit =
     tests/*
-    utils/plot.py
+    rl_zoo/utils/plot.py
 
 [report]
 exclude_lines =

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ jobs:
           pip install opencv-python-headless
           # install parking-env to test HER (pinned so it works with gym 0.21)
           pip install highway-env==1.5.0
+          pip install -e .
       - name: Type check
         run: |
           make type

.github/workflows/trained_agents.yml

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ jobs:
           pip install highway-env==1.5.0
           # Add support for pickle5 protocol
           pip install pickle5
+          pip install -e .
       - name: Check trained agents
         run: |
           make check-trained-agents

.gitignore

Lines changed: 9 additions & 0 deletions
@@ -18,3 +18,12 @@ runs
 hub
 *.mp4
 *.json
+
+# Setuptools distribution and build folders.
+/dist/
+/build
+keys/
+*.egg-info
+.cache
+*.lprof
+*.prof

CHANGELOG.md

Lines changed: 15 additions & 0 deletions
@@ -1,3 +1,18 @@
+## Release 1.6.2 (2022-10-02)
+
+### Breaking Changes
+- RL Zoo is now a python package
+- low pass filter was removed
+
+### New Features
+- RL Zoo cli: `rl_zoo train` and `rl_zoo enjoy`
+
+### Bug fixes
+
+### Documentation
+
+### Other
+
 ## Release 1.6.1 (2022-09-30)
 
 **Progress bar and custom yaml file**
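Note: the new `rl_zoo train` / `rl_zoo enjoy` cli mentioned in the changelog comes from the "Add basic cli" part of this commit, but the cli module itself is not among the files shown in this section. The following is only a minimal sketch of how such a console entry point could dispatch to the packaged scripts; the module name (rl_zoo/cli.py), the existence of rl_zoo.train.train(), and the dispatch logic are assumptions, not the commit's actual implementation.

# rl_zoo/cli.py -- hypothetical sketch, assuming rl_zoo.train.train() mirrors rl_zoo.enjoy.enjoy()
import sys

from rl_zoo.enjoy import enjoy
from rl_zoo.train import train


def main() -> None:
    # First positional argument selects the sub-command ("train" or "enjoy"),
    # the remaining arguments are forwarded to that script's own argparse parser.
    script = sys.argv[1] if len(sys.argv) > 1 else ""
    # Drop the sub-command so the downstream parser only sees its own flags
    sys.argv = [sys.argv[0]] + sys.argv[2:]
    if script == "train":
        train()
    elif script == "enjoy":
        enjoy()
    else:
        raise ValueError(f"Unknown command: {script!r} (expected 'train' or 'enjoy')")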

Makefile

Lines changed: 15 additions & 3 deletions
@@ -1,4 +1,4 @@
-LINT_PATHS = *.py tests/ scripts/ utils/
+LINT_PATHS = *.py tests/ scripts/ rl_zoo/
 
 # Run pytest and coverage report
 pytest:
@@ -10,7 +10,7 @@ check-trained-agents:
 
 # Type check
 type:
-	pytype -j auto ${LINT_PATHS}
+	pytype -j auto rl_zoo/ tests/ scripts/ -d import-error
 
 lint:
 	# stop the build if there are Python syntax errors or undefined names
@@ -42,4 +42,16 @@ docker-cpu:
 docker-gpu:
 	USE_GPU=True ./scripts/build_docker.sh
 
-.PHONY: docker lint type pytest
+# PyPi package release
+release:
+	python setup.py sdist
+	python setup.py bdist_wheel
+	twine upload dist/*
+
+# Test PyPi package release
+test-release:
+	python setup.py sdist
+	python setup.py bdist_wheel
+	twine upload --repository-url https://test.pypi.org/legacy/ dist/*
+
+.PHONY: lint format check-codestyle commit-checks doc spelling docker type pytest
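Note: the new release and test-release targets assume a setup.py at the repository root (added elsewhere in this commit, not shown in this section). Below is only a minimal sketch of what such a packaging script could look like, with the hyperparameter yaml files shipped as package data ("Add hyperparams to package data") and a console entry point for the cli; the metadata, dependency list, and entry-point target are assumptions, not the commit's actual setup.py.

# setup.py -- hypothetical sketch, not the file added by this commit
from setuptools import find_packages, setup

setup(
    name="rl_zoo",
    version="1.6.2",
    packages=find_packages(include=["rl_zoo", "rl_zoo.*"]),
    # Ship the yaml hyperparameter files inside the sdist/wheel so the
    # installed package can find them without the git checkout.
    package_data={"rl_zoo": ["hyperparams/*.yml"]},
    install_requires=["stable-baselines3", "sb3-contrib", "huggingface-sb3", "pyyaml"],
    entry_points={"console_scripts": ["rl_zoo=rl_zoo.cli:main"]},
)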

enjoy.py

Lines changed: 2 additions & 272 deletions
@@ -1,274 +1,4 @@
-import argparse
-import importlib
-import os
-import sys
-
-import numpy as np
-import torch as th
-import yaml
-from huggingface_sb3 import EnvironmentName
-from stable_baselines3.common.utils import set_random_seed
-
-import utils.import_envs  # noqa: F401 pylint: disable=unused-import
-from utils import ALGOS, create_test_env, get_saved_hyperparams
-from utils.callbacks import tqdm
-from utils.exp_manager import ExperimentManager
-from utils.load_from_hub import download_from_hub
-from utils.utils import StoreDict, get_model_path
-
-
-def main():  # noqa: C901
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--env", help="environment ID", type=EnvironmentName, default="CartPole-v1")
-    parser.add_argument("-f", "--folder", help="Log folder", type=str, default="rl-trained-agents")
-    parser.add_argument("--algo", help="RL Algorithm", default="ppo", type=str, required=False, choices=list(ALGOS.keys()))
-    parser.add_argument("-n", "--n-timesteps", help="number of timesteps", default=1000, type=int)
-    parser.add_argument("--num-threads", help="Number of threads for PyTorch (-1 to use default)", default=-1, type=int)
-    parser.add_argument("--n-envs", help="number of environments", default=1, type=int)
-    parser.add_argument("--exp-id", help="Experiment ID (default: 0: latest, -1: no exp folder)", default=0, type=int)
-    parser.add_argument("--verbose", help="Verbose mode (0: no output, 1: INFO)", default=1, type=int)
-    parser.add_argument(
-        "--no-render", action="store_true", default=False, help="Do not render the environment (useful for tests)"
-    )
-    parser.add_argument("--deterministic", action="store_true", default=False, help="Use deterministic actions")
-    parser.add_argument("--device", help="PyTorch device to be use (ex: cpu, cuda...)", default="auto", type=str)
-    parser.add_argument(
-        "--load-best", action="store_true", default=False, help="Load best model instead of last model if available"
-    )
-    parser.add_argument(
-        "--load-checkpoint",
-        type=int,
-        help="Load checkpoint instead of last model if available, "
-        "you must pass the number of timesteps corresponding to it",
-    )
-    parser.add_argument(
-        "--load-last-checkpoint",
-        action="store_true",
-        default=False,
-        help="Load last checkpoint instead of last model if available",
-    )
-    parser.add_argument("--stochastic", action="store_true", default=False, help="Use stochastic actions")
-    parser.add_argument(
-        "--norm-reward", action="store_true", default=False, help="Normalize reward if applicable (trained with VecNormalize)"
-    )
-    parser.add_argument("--seed", help="Random generator seed", type=int, default=0)
-    parser.add_argument("--reward-log", help="Where to log reward", default="", type=str)
-    parser.add_argument(
-        "--gym-packages",
-        type=str,
-        nargs="+",
-        default=[],
-        help="Additional external Gym environment package modules to import (e.g. gym_minigrid)",
-    )
-    parser.add_argument(
-        "--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor"
-    )
-    parser.add_argument(
-        "--custom-objects", action="store_true", default=False, help="Use custom objects to solve loading issues"
-    )
-    parser.add_argument(
-        "-P",
-        "--progress",
-        action="store_true",
-        default=False,
-        help="if toggled, display a progress bar using tqdm and rich",
-    )
-    args = parser.parse_args()
-
-    # Going through custom gym packages to let them register in the global registory
-    for env_module in args.gym_packages:
-        importlib.import_module(env_module)
-
-    env_name: EnvironmentName = args.env
-    algo = args.algo
-    folder = args.folder
-
-    try:
-        _, model_path, log_path = get_model_path(
-            args.exp_id,
-            folder,
-            algo,
-            env_name,
-            args.load_best,
-            args.load_checkpoint,
-            args.load_last_checkpoint,
-        )
-    except (AssertionError, ValueError) as e:
-        # Special case for rl-trained agents
-        # auto-download from the hub
-        if "rl-trained-agents" not in folder:
-            raise e
-        else:
-            print("Pretrained model not found, trying to download it from sb3 Huggingface hub: https://huggingface.co/sb3")
-            # Auto-download
-            download_from_hub(
-                algo=algo,
-                env_name=env_name,
-                exp_id=args.exp_id,
-                folder=folder,
-                organization="sb3",
-                repo_name=None,
-                force=False,
-            )
-            # Try again
-            _, model_path, log_path = get_model_path(
-                args.exp_id,
-                folder,
-                algo,
-                env_name,
-                args.load_best,
-                args.load_checkpoint,
-                args.load_last_checkpoint,
-            )
-
-    print(f"Loading {model_path}")
-
-    # Off-policy algorithm only support one env for now
-    off_policy_algos = ["qrdqn", "dqn", "ddpg", "sac", "her", "td3", "tqc"]
-
-    if algo in off_policy_algos:
-        args.n_envs = 1
-
-    set_random_seed(args.seed)
-
-    if args.num_threads > 0:
-        if args.verbose > 1:
-            print(f"Setting torch.num_threads to {args.num_threads}")
-        th.set_num_threads(args.num_threads)
-
-    is_atari = ExperimentManager.is_atari(env_name.gym_id)
-
-    stats_path = os.path.join(log_path, env_name)
-    hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=args.norm_reward, test_mode=True)
-
-    # load env_kwargs if existing
-    env_kwargs = {}
-    args_path = os.path.join(log_path, env_name, "args.yml")
-    if os.path.isfile(args_path):
-        with open(args_path) as f:
-            loaded_args = yaml.load(f, Loader=yaml.UnsafeLoader)  # pytype: disable=module-attr
-            if loaded_args["env_kwargs"] is not None:
-                env_kwargs = loaded_args["env_kwargs"]
-    # overwrite with command line arguments
-    if args.env_kwargs is not None:
-        env_kwargs.update(args.env_kwargs)
-
-    log_dir = args.reward_log if args.reward_log != "" else None
-
-    env = create_test_env(
-        env_name.gym_id,
-        n_envs=args.n_envs,
-        stats_path=stats_path,
-        seed=args.seed,
-        log_dir=log_dir,
-        should_render=not args.no_render,
-        hyperparams=hyperparams,
-        env_kwargs=env_kwargs,
-    )
-
-    kwargs = dict(seed=args.seed)
-    if algo in off_policy_algos:
-        # Dummy buffer size as we don't need memory to enjoy the trained agent
-        kwargs.update(dict(buffer_size=1))
-        # Hack due to breaking change in v1.6
-        # handle_timeout_termination cannot be at the same time
-        # with optimize_memory_usage
-        if "optimize_memory_usage" in hyperparams:
-            kwargs.update(optimize_memory_usage=False)
-
-    # Check if we are running python 3.8+
-    # we need to patch saved model under python 3.6/3.7 to load them
-    newer_python_version = sys.version_info.major == 3 and sys.version_info.minor >= 8
-
-    custom_objects = {}
-    if newer_python_version or args.custom_objects:
-        custom_objects = {
-            "learning_rate": 0.0,
-            "lr_schedule": lambda _: 0.0,
-            "clip_range": lambda _: 0.0,
-        }
-
-    model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, device=args.device, **kwargs)
-
-    obs = env.reset()
-
-    # Deterministic by default except for atari games
-    stochastic = args.stochastic or is_atari and not args.deterministic
-    deterministic = not stochastic
-
-    episode_reward = 0.0
-    episode_rewards, episode_lengths = [], []
-    ep_len = 0
-    # For HER, monitor success rate
-    successes = []
-    lstm_states = None
-    episode_start = np.ones((env.num_envs,), dtype=bool)
-
-    generator = range(args.n_timesteps)
-    if args.progress:
-        generator = tqdm(generator)
-
-    try:
-        for _ in generator:
-            action, lstm_states = model.predict(
-                obs,
-                state=lstm_states,
-                episode_start=episode_start,
-                deterministic=deterministic,
-            )
-            obs, reward, done, infos = env.step(action)
-
-            episode_start = done
-
-            if not args.no_render:
-                env.render("human")
-
-            episode_reward += reward[0]
-            ep_len += 1
-
-            if args.n_envs == 1:
-                # For atari the return reward is not the atari score
-                # so we have to get it from the infos dict
-                if is_atari and infos is not None and args.verbose >= 1:
-                    episode_infos = infos[0].get("episode")
-                    if episode_infos is not None:
-                        print(f"Atari Episode Score: {episode_infos['r']:.2f}")
-                        print("Atari Episode Length", episode_infos["l"])
-
-                if done and not is_atari and args.verbose > 0:
-                    # NOTE: for env using VecNormalize, the mean reward
-                    # is a normalized reward when `--norm_reward` flag is passed
-                    print(f"Episode Reward: {episode_reward:.2f}")
-                    print("Episode Length", ep_len)
-                    episode_rewards.append(episode_reward)
-                    episode_lengths.append(ep_len)
-                    episode_reward = 0.0
-                    ep_len = 0
-
-                # Reset also when the goal is achieved when using HER
-                if done and infos[0].get("is_success") is not None:
-                    if args.verbose > 1:
-                        print("Success?", infos[0].get("is_success", False))
-
-                    if infos[0].get("is_success") is not None:
-                        successes.append(infos[0].get("is_success", False))
-                        episode_reward, ep_len = 0.0, 0
-
-    except KeyboardInterrupt:
-        pass
-
-    if args.verbose > 0 and len(successes) > 0:
-        print(f"Success rate: {100 * np.mean(successes):.2f}%")
-
-    if args.verbose > 0 and len(episode_rewards) > 0:
-        print(f"{len(episode_rewards)} Episodes")
-        print(f"Mean reward: {np.mean(episode_rewards):.2f} +/- {np.std(episode_rewards):.2f}")
-
-    if args.verbose > 0 and len(episode_lengths) > 0:
-        print(f"Mean episode length: {np.mean(episode_lengths):.2f} +/- {np.std(episode_lengths):.2f}")
-
-    env.close()
-
+from rl_zoo.enjoy import enjoy
 
 if __name__ == "__main__":
-    main()
+    enjoy()
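Note: the removed main() body is presumably moved into rl_zoo/enjoy.py as enjoy() (one of the 40 changed files, not shown in this section), leaving enjoy.py as a thin shim. Assuming that layout and that enjoy() still parses sys.argv with argparse, the packaged entry point can be driven the same way as the old script:

# Hypothetical usage sketch: assumes rl_zoo.enjoy.enjoy() reads its flags from sys.argv
import sys

from rl_zoo.enjoy import enjoy

# Equivalent to `python enjoy.py --algo ppo --env CartPole-v1 --no-render`
sys.argv = ["enjoy.py", "--algo", "ppo", "--env", "CartPole-v1", "--no-render"]
enjoy()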

hyperparams/her.yml

Lines changed: 3 additions & 3 deletions
@@ -59,7 +59,7 @@ FetchSlide-v1:
 FetchPickAndPlace-v1:
   env_wrapper:
     - sb3_contrib.common.wrappers.TimeFeatureWrapper
-    # - utils.wrappers.DoneOnSuccessWrapper:
+    # - rl_zoo.wrappers.DoneOnSuccessWrapper:
     #     reward_offset: 0
     #     n_successes: 4
     # - stable_baselines3.common.monitor.Monitor
@@ -96,7 +96,7 @@ FetchReach-v1:
 NeckGoalEnvRelativeSparse-v2:
   model_class: 'sac'
   # env_wrapper:
-  #   - utils.wrappers.HistoryWrapper:
+  #   - rl_zoo.wrappers.HistoryWrapper:
   #       horizon: 2
   # - sb3_contrib.common.wrappers.TimeFeatureWrapper
   n_timesteps: !!float 1e6
@@ -122,7 +122,7 @@ NeckGoalEnvRelativeSparse-v2:
 NeckGoalEnvRelativeDense-v2:
   model_class: 'sac'
   env_wrapper:
-    - utils.wrappers.HistoryWrapperObsDict:
+    - rl_zoo.wrappers.HistoryWrapperObsDict:
         horizon: 2
   # - sb3_contrib.common.wrappers.TimeFeatureWrapper
   n_timesteps: !!float 1e6
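Note: the env_wrapper entries in these yaml files are dotted import paths that the zoo resolves at run time, which is why every utils.wrappers.* reference has to become rl_zoo.wrappers.* once the code lives inside the package. A minimal sketch of that resolution step is given below; the helper name is hypothetical and the zoo's own loader additionally handles the per-wrapper keyword arguments (e.g. horizon: 2) from the yaml.

import importlib
from typing import Type

import gym


def load_wrapper_class(dotted_path: str) -> Type[gym.Wrapper]:
    # "rl_zoo.wrappers.HistoryWrapperObsDict" -> module "rl_zoo.wrappers", attribute "HistoryWrapperObsDict"
    module_name, class_name = dotted_path.rsplit(".", maxsplit=1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Example: resolve the wrapper referenced by the her.yml entry above
# (assumes the rl_zoo package from this commit is installed)
wrapper_cls = load_wrapper_class("rl_zoo.wrappers.HistoryWrapperObsDict")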
