Skip to content

Commit c6acd1e

Browse files
authoredFeb 3, 2020
Fix check_env, Monitor.close and add Makefile (hill-a#673)
* Fix `check_env` and add Makefile * Fixed doc build * Fixed and typed Monitor
1 parent 34d2bee commit c6acd1e

File tree

8 files changed

+99
-37
lines changed

8 files changed

+99
-37
lines changed
 

‎.github/PULL_REQUEST_TEMPLATE.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@
2424
- [ ] My change requires a change to the documentation.
2525
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
2626
- [ ] I have updated the documentation accordingly.
27-
- [ ] I have ensured `pytest` and `pytype` both pass.
27+
- [ ] I have ensured `pytest` and `pytype` both pass (by running `make pytest` and `make type`).
2828

2929
<!--- This Template is an edited version of the one from https://github.com/evilsocket/pwnagotchi/ -->

‎CONTRIBUTING.md

+28-9
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,17 @@ from stable_baselines import PPO2
5757

5858
In general, we recommend using pycharm to format everything in an efficient way.
5959

60-
Please documentation each function/method using the following template:
60+
Please document each function/method and [type](https://google.github.io/pytype/user_guide.html) them using the following template:
6161

6262
```python
6363

64-
def my_function(arg1, arg2):
64+
def my_function(arg1: type1, arg2: type2) -> returntype:
6565
"""
6666
Short description of the function.
6767
68-
:param arg1: (arg1 type) describe what is arg1
69-
:param arg2: (arg2 type) describe what is arg2
70-
:return: (return type) describe what is returned
68+
:param arg1: (type1) describe what is arg1
69+
:param arg2: (type2) describe what is arg2
70+
:return: (returntype) describe what is returned
7171
"""
7272
...
7373
return my_variable
@@ -77,7 +77,7 @@ def my_function(arg1, arg2):
7777

7878
Before proposing a PR, please open an issue, where the feature will be discussed. This prevent from duplicated PR to be proposed and also ease the code review process.
7979

80-
Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a , @araffin or @erniejunior ).
80+
Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @erniejunior, @AdamGleave or @Miffyli).
8181
A PR must pass the Continuous Integration tests (travis + codacy) to be merged with the master branch.
8282

8383
Note: in rare cases, we can create exception for codacy failure.
@@ -88,15 +88,34 @@ All new features must add tests in the `tests/` folder ensuring that everything
8888
We use [pytest](https://pytest.org/).
8989
Also, when a bug fix is proposed, tests should be added to avoid regression.
9090

91-
To run tests with `pytest` and type checking with `pytype`:
91+
To run tests with `pytest`:
9292

9393
```
94-
./scripts/run_tests.sh
94+
make pytest
9595
```
9696

97+
Type checking with `pytype`:
98+
99+
```
100+
make type
101+
```
102+
103+
Build the documentation:
104+
105+
```
106+
make doc
107+
```
108+
109+
Check documentation spelling (you need to install `sphinxcontrib.spelling` package for that):
110+
111+
```
112+
make spelling
113+
```
114+
115+
97116
## Changelog and Documentation
98117

99-
Please do not forget to update the changelog and add documentation if needed.
118+
Please do not forget to update the changelog (`docs/misc/changelog.rst`) and add documentation if needed.
100119
A README is present in the `docs/` folder for instructions on how to build the documentation.
101120

102121

‎Makefile

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Run pytest and coverage report
2+
pytest:
3+
./scripts/run_tests.sh
4+
5+
# Type check
6+
type:
7+
pytype
8+
9+
# Build the doc
10+
doc:
11+
cd docs && make html
12+
13+
# Check the spelling in the doc
14+
spelling:
15+
cd docs && make spelling
16+
17+
# Clean the doc build folder
18+
clean:
19+
cd docs && make clean

‎README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ Some of the baselines examples use [MuJoCo](http://www.mujoco.org) (multi-joint
190190
All unit tests in baselines can be run using pytest runner:
191191
```
192192
pip install pytest pytest-cov
193-
pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=.
193+
make pytest
194194
```
195195

196196
## Projects Using Stable-Baselines

‎docs/misc/changelog.rst

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ Breaking Changes:
1515
New Features:
1616
^^^^^^^^^^^^^
1717
- Parallelized updating and sampling from the replay buffer in DQN. (@flodorner)
18-
1918
- Docker build script, `scripts/build_docker.sh`, can push images automatically.
2019

2120
Bug Fixes:
@@ -30,9 +29,10 @@ Bug Fixes:
3029
- Fixed a bug in PPO2, ACER, A2C, and ACKTR where repeated calls to `learn(total_timesteps)` reset
3130
the environment on every call, potentially biasing samples toward early episode timesteps.
3231
(@shwang)
33-
34-
- Fixed by adding lazy property `ActorCriticRLModel.runner`. Subclasses now use lazily-generated
32+
- Fixed by adding lazy property `ActorCriticRLModel.runner`. Subclasses now use lazily-generated
3533
`self.runner` instead of reinitializing a new Runner every time `learn()` is called.
34+
- Fixed a bug in `check_env` where it would fail on high dimensional action spaces
35+
- Fixed `Monitor.close()` that was not calling the parent method
3636

3737
Deprecations:
3838
^^^^^^^^^^^^^
@@ -41,6 +41,7 @@ Others:
4141
^^^^^^^
4242
- Removed redundant return value from `a2c.utils::total_episode_reward_logger`. (@shwang)
4343
- Cleanup and refactoring in `common/identity_env.py` (@shwang)
44+
- Added a Makefile to simplify common development tasks (build the doc, type check, run the tests)
4445

4546
Documentation:
4647
^^^^^^^^^^^^^^

‎stable_baselines/bench/monitor.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,33 @@
55
import os
66
import time
77
from glob import glob
8+
from typing import Tuple, Dict, Any, List, Optional
89

10+
import gym
911
import pandas
10-
from gym.core import Wrapper
12+
import numpy as np
1113

1214

13-
class Monitor(Wrapper):
15+
class Monitor(gym.Wrapper):
1416
EXT = "monitor.csv"
1517
file_handler = None
1618

17-
def __init__(self, env, filename, allow_early_resets=True, reset_keywords=(), info_keywords=()):
19+
def __init__(self,
20+
env: gym.Env,
21+
filename: Optional[str],
22+
allow_early_resets: bool = True,
23+
reset_keywords=(),
24+
info_keywords=()):
1825
"""
1926
A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.
2027
21-
:param env: (Gym environment) The environment
22-
:param filename: (str) the location to save a log file, can be None for no log
28+
:param env: (gym.Env) The environment
29+
:param filename: (Optional[str]) the location to save a log file, can be None for no log
2330
:param allow_early_resets: (bool) allows the reset of the environment before it is done
2431
:param reset_keywords: (tuple) extra keywords for the reset call, if extra parameters are needed at reset
2532
:param info_keywords: (tuple) extra information to log, from the information return of environment.step
2633
"""
27-
Wrapper.__init__(self, env=env)
34+
super(Monitor, self).__init__(env=env)
2835
self.t_start = time.time()
2936
if filename is None:
3037
self.file_handler = None
@@ -53,12 +60,12 @@ def __init__(self, env, filename, allow_early_resets=True, reset_keywords=(), in
5360
self.total_steps = 0
5461
self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()
5562

56-
def reset(self, **kwargs):
63+
def reset(self, **kwargs) -> np.ndarray:
5764
"""
5865
Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True
5966
6067
:param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords
61-
:return: ([int] or [float]) the first observation of the environment
68+
:return: (np.ndarray) the first observation of the environment
6269
"""
6370
if not self.allow_early_resets and not self.needs_reset:
6471
raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, "
@@ -68,16 +75,16 @@ def reset(self, **kwargs):
6875
for key in self.reset_keywords:
6976
value = kwargs.get(key)
7077
if value is None:
71-
raise ValueError('Expected you to pass kwarg %s into reset' % key)
78+
raise ValueError('Expected you to pass kwarg {} into reset'.format(key))
7279
self.current_reset_info[key] = value
7380
return self.env.reset(**kwargs)
7481

75-
def step(self, action):
82+
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, Any]]:
7683
"""
7784
Step the environment with the given action
7885
79-
:param action: ([int] or [float]) the action
80-
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
86+
:param action: (np.ndarray) the action
87+
:return: (Tuple[np.ndarray, float, bool, Dict[Any, Any]]) observation, reward, done, information
8188
"""
8289
if self.needs_reset:
8390
raise RuntimeError("Tried to step environment that needs reset")
@@ -105,34 +112,35 @@ def close(self):
105112
"""
106113
Closes the environment
107114
"""
115+
super(Monitor, self).close()
108116
if self.file_handler is not None:
109117
self.file_handler.close()
110118

111-
def get_total_steps(self):
119+
def get_total_steps(self) -> int:
112120
"""
113121
Returns the total number of timesteps
114122
115123
:return: (int)
116124
"""
117125
return self.total_steps
118126

119-
def get_episode_rewards(self):
127+
def get_episode_rewards(self) -> List[float]:
120128
"""
121129
Returns the rewards of all the episodes
122130
123131
:return: ([float])
124132
"""
125133
return self.episode_rewards
126134

127-
def get_episode_lengths(self):
135+
def get_episode_lengths(self) -> List[int]:
128136
"""
129137
Returns the number of timesteps of all the episodes
130138
131139
:return: ([int])
132140
"""
133141
return self.episode_lengths
134142

135-
def get_episode_times(self):
143+
def get_episode_times(self) -> List[float]:
136144
"""
137145
Returns the runtime in seconds of all the episodes
138146
@@ -148,7 +156,7 @@ class LoadMonitorResultsError(Exception):
148156
pass
149157

150158

151-
def get_monitor_files(path):
159+
def get_monitor_files(path: str) -> List[str]:
152160
"""
153161
get all the monitor files in the given path
154162
@@ -158,12 +166,12 @@ def get_monitor_files(path):
158166
return glob(os.path.join(path, "*" + Monitor.EXT))
159167

160168

161-
def load_results(path):
169+
def load_results(path: str) -> pandas.DataFrame:
162170
"""
163171
Load all Monitor logs from a given directory path matching ``*monitor.csv`` and ``*monitor.json``
164172
165173
:param path: (str) the directory path containing the log file(s)
166-
:return: (Pandas DataFrame) the logged data
174+
:return: (pandas.DataFrame) the logged data
167175
"""
168176
# get both csv and (old) json files
169177
monitor_files = (glob(os.path.join(path, "*monitor.json")) + get_monitor_files(path))

‎stable_baselines/common/env_checker.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _check_spaces(env: gym.Env) -> None:
135135
assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces
136136

137137

138-
def _check_render(env: gym.Env, warn=True, headless=False) -> None:
138+
def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None:
139139
"""
140140
Check the declared render modes and the `render()`/`close()`
141141
method of the environment.
@@ -163,7 +163,7 @@ def _check_render(env: gym.Env, warn=True, headless=False) -> None:
163163
env.close()
164164

165165

166-
def check_env(env: gym.Env, warn=True, skip_render_check=True) -> None:
166+
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
167167
"""
168168
Check that an environment follows Gym API.
169169
This is particularly useful when using a custom environment.
@@ -205,8 +205,8 @@ def check_env(env: gym.Env, warn=True, skip_render_check=True) -> None:
205205

206206
# Check for the action space, it may lead to hard-to-debug issues
207207
if (isinstance(action_space, spaces.Box) and
208-
(np.abs(action_space.low) != np.abs(action_space.high)
209-
or np.abs(action_space.low) > 1 or np.abs(action_space.high) > 1)):
208+
(np.any(np.abs(action_space.low) != np.abs(action_space.high))
209+
or np.any(np.abs(action_space.low) > 1) or np.any(np.abs(action_space.high) > 1))):
210210
warnings.warn("We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
211211
"cf https://stable-baselines.readthedocs.io/en/master/guide/rl_tips.html")
212212

‎tests/test_envs.py

+15
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,21 @@ def test_custom_envs(env_class):
3838
check_env(env)
3939

4040

41+
def test_high_dimension_action_space():
42+
"""
43+
Test for continuous action space
44+
with more than one action.
45+
"""
46+
env = gym.make('Pendulum-v0')
47+
# Patch the action space
48+
env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32)
49+
# Patch to avoid error
50+
def patched_step(_action):
51+
return env.observation_space.sample(), 0.0, False, {}
52+
env.step = patched_step
53+
check_env(env)
54+
55+
4156
@pytest.mark.parametrize("new_obs_space", [
4257
# Small image
4358
spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),

0 commit comments

Comments
 (0)
Please sign in to comment.