Commit e315ebe

Refactor Tests + Add Helpers (hill-a#508)
Authored Nov 24, 2019
1 parent 42c2290, commit e315ebe

* Add helpers
* Refactor some tests
* Continue refactoring
* Fix for codacy
* Fixes for travis
* Clean up imports
* Fix syntax error
* Fix VecEnv constructor
* Fix perf check in tests
* Seed identity env + minor updates
* Allow more diff after training again
* Try to fix travis non-determinism
* Add tests for the new helpers
* Codacy fixes
* Fix callback logic
* Address comments
* Address review comments
* Make codacy happy
* Fix docstring indentation
* Update README example
* Remove use_subprocess and update doc

31 files changed: +324, -224 lines
 

‎.gitignore

+1
@@ -12,6 +12,7 @@
 __pycache__/
 _build/
 *.npz
+*.zip

 # Setuptools distribution and build folders.
 /dist/

‎README.md

+3 -1
@@ -113,7 +113,9 @@ from stable_baselines.common.vec_env import DummyVecEnv
 from stable_baselines import PPO2

 env = gym.make('CartPole-v1')
-env = DummyVecEnv([lambda: env])  # The algorithms require a vectorized environment to run
+# Optional: PPO2 requires a vectorized environment to run
+# the env is now wrapped automatically when passing it to the constructor
+# env = DummyVecEnv([lambda: env])

 model = PPO2(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=10000)

‎docs/common/evaluation.rst

+7
@@ -0,0 +1,7 @@
+.. _eval:
+
+Evaluation Helper
+=================
+
+.. automodule:: stable_baselines.common.evaluation
+  :members:

‎docs/guide/examples.rst

+17 -7
@@ -56,6 +56,8 @@ In the following example, we will train, save and load a DQN model on the Lunar
 import gym

 from stable_baselines import DQN
+from stable_baselines.common.evaluation import evaluate_policy
+

 # Create environment
 env = gym.make('LunarLander-v2')
@@ -71,6 +73,9 @@ In the following example, we will train, save and load a DQN model on the Lunar
 # Load the trained agent
 model = DQN.load("dqn_lunar")

+# Evaluate the agent
+mean_reward, n_steps = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
+
 # Enjoy trained agent
 obs = env.reset()
 for i in range(1000):
@@ -98,7 +103,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments

 from stable_baselines.common.policies import MlpPolicy
 from stable_baselines.common.vec_env import SubprocVecEnv
-from stable_baselines.common import set_global_seeds
+from stable_baselines.common import set_global_seeds, make_vec_env
 from stable_baselines import ACKTR

 def make_env(env_id, rank, seed=0):
@@ -123,6 +128,10 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
 # Create the vectorized environment
 env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])

+# Stable Baselines provides you with make_vec_env() helper
+# which does exactly the previous steps for you:
+# env = make_vec_env(env_id, n_envs=num_cpu, seed=0)
+
 model = ACKTR(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=25000)

@@ -340,8 +349,6 @@ A2C policy gradient updates on the model.
 import gym
 import numpy as np

-from stable_baselines.common.policies import MlpPolicy
-from stable_baselines.common.vec_env import DummyVecEnv
 from stable_baselines import A2C

 def mutate(params):
@@ -365,9 +372,8 @@ A2C policy gradient updates on the model.

 # Create env
 env = gym.make('CartPole-v1')
-env = DummyVecEnv([lambda: env])
 # Create policy with a small network
-model = A2C(MlpPolicy, env, ent_coef=0.0, learning_rate=0.1,
+model = A2C('MlpPolicy', env, ent_coef=0.0, learning_rate=0.1,
             policy_kwargs={'net_arch': [8, ]})

 # Use traditional actor-critic policy gradient updates to
@@ -546,6 +552,9 @@ You can also move from learning on one environment to another for `continual lea
     obs, rewards, dones, info = env.step(action)
     env.render()

+# Close the processes
+env.close()
+
 # The number of environments must be identical when changing environments
 env = make_atari_env('SpaceInvadersNoFrameskip-v4', num_env=8, seed=0)

@@ -558,6 +567,7 @@ You can also move from learning on one environment to another for `continual lea
     action, _states = model.predict(obs)
     obs, rewards, dones, info = env.step(action)
     env.render()
+env.close()


 Record a Video
@@ -591,6 +601,7 @@ Record a mp4 video (here using a random agent).
 for _ in range(video_length + 1):
     action = [env.action_space.sample()]
     obs, _, _, _ = env.step(action)
+# Save the video
 env.close()


@@ -606,10 +617,9 @@ Bonus: Make a GIF of a Trained Agent
 import imageio
 import numpy as np

-from stable_baselines.common.policies import MlpPolicy
 from stable_baselines import A2C

-model = A2C(MlpPolicy, "LunarLander-v2").learn(100000)
+model = A2C("MlpPolicy", "LunarLander-v2").learn(100000)

 images = []
 obs = model.env.reset()

‎docs/guide/quickstart.rst

+3 -1
@@ -17,7 +17,9 @@ Here is a quick example of how to train and run PPO2 on a cartpole environment:
 from stable_baselines import PPO2

 env = gym.make('CartPole-v1')
-env = DummyVecEnv([lambda: env])  # The algorithms require a vectorized environment to run
+# Optional: PPO2 requires a vectorized environment to run
+# the env is now wrapped automatically when passing it to the constructor
+# env = DummyVecEnv([lambda: env])

 model = PPO2(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=10000)

‎docs/index.rst

+1
@@ -80,6 +80,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
   common/tf_utils
   common/cmd_utils
   common/schedules
+  common/evaluation

 .. toctree::
   :maxdepth: 1

‎docs/misc/changelog.rst

+11 -1
@@ -13,10 +13,16 @@ Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - The `seed` argument has been moved from `learn()` method to model constructor
   in order to have reproducible results
+- `allow_early_resets` of the `Monitor` wrapper now default to `True`
+- `make_atari_env` now returns a `DummyVecEnv` by default (instead of a `SubprocVecEnv`)
+  this usually improves performance.

 New Features:
 ^^^^^^^^^^^^^
 - Add `n_cpu_tf_sess` to model constructor to choose the number of threads used by Tensorflow
+- Environments are automatically wrapped in a `DummyVecEnv` if needed when passing them to the model constructor
+- Added `stable_baselines.common.make_vec_env` helper to simplify VecEnv creation
+- Added `stable_baselines.common.evaluation.evaluate_policy` helper to simplify model evaluation
 - `VecNormalize` now supports being pickled and unpickled.
 - Add parameter `exploration_initial_eps` to DQN. (@jdossgollin)
 - Add type checking and PEP 561 compliance.
@@ -38,6 +44,7 @@ Deprecations:
 Others:
 ^^^^^^^
 - Add upper bound for Tensorflow version (<2.0.0).
+- Refactored test to remove duplicated code
 - Add pull request template

 Documentation:
@@ -46,8 +53,11 @@ Documentation:
 - Add Snake Game AI project (@pedrohbtp)
 - Add note on the support Tensorflow versions.
 - Remove unnecessary steps required for Windows installation.
+- Remove `DummyVecEnv` creation when not needed
+- Added `make_vec_env` to the examples to simplify VecEnv creation
 - Add QuaRL project (@srivatsankrishnan)
 - Add Pwnagotchi project (@evilsocket)
+- Fix multiprocessing example (@rusu24edward)
 - Fix `result_plotter` example
 - Fix typo in algos.rst, "containes" to "contains" (@SyllogismRXS)

@@ -530,4 +540,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
 @EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
 @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
 @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
-@MarvineGothic @jdossgollin @SyllogismRXS
+@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward
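The two helpers listed above under New Features are meant to be used together. A minimal sketch of the intended workflow (the PPO2 model, env IDs and hyperparameters are illustrative choices, not part of the commit):

import numpy as np
from stable_baselines import PPO2
from stable_baselines.common import make_vec_env
from stable_baselines.common.evaluation import evaluate_policy

# Create 4 parallel CartPole environments (a DummyVecEnv by default)
env = make_vec_env('CartPole-v1', n_envs=4, seed=0)

model = PPO2('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10000)

# Evaluate on a single-environment VecEnv (evaluate_policy only supports one env)
eval_env = make_vec_env('CartPole-v1', n_envs=1, seed=1)
mean_reward, n_steps = evaluate_policy(model, eval_env, n_eval_episodes=10)
print("Mean reward: {:.2f} over {} steps".format(float(np.mean(mean_reward)), n_steps))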

‎docs/modules/a2c.rst

+3 -4
@@ -49,12 +49,11 @@ Train a A2C agent on `CartPole-v1` using 4 processes.
 import gym

 from stable_baselines.common.policies import MlpPolicy
-from stable_baselines.common.vec_env import SubprocVecEnv
+from stable_baselines.common import make_vec_env
 from stable_baselines import A2C

-# multiprocess environment
-n_cpu = 4
-env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])
+# Parallel environments
+env = make_vec_env('CartPole-v1', n_envs=4)

 model = A2C(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=25000)

‎docs/modules/acer.rst

+2 -3
@@ -43,12 +43,11 @@ Example
 import gym

 from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
-from stable_baselines.common.vec_env import SubprocVecEnv
+from stable_baselines.common import make_vec_env
 from stable_baselines import ACER

 # multiprocess environment
-n_cpu = 4
-env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])
+env = make_vec_env('CartPole-v1', n_envs=4)

 model = ACER(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=25000)

‎docs/modules/acktr.rst

+2 -3
@@ -44,12 +44,11 @@ Example
 import gym

 from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
-from stable_baselines.common.vec_env import SubprocVecEnv
+from stable_baselines.common import make_vec_env
 from stable_baselines import ACKTR

 # multiprocess environment
-n_cpu = 4
-env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])
+env = make_vec_env('CartPole-v1', n_envs=4)

 model = ACKTR(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=25000)

‎docs/modules/ddpg.rst

+2 -8
@@ -63,12 +63,10 @@ Example
 import numpy as np

 from stable_baselines.ddpg.policies import MlpPolicy
-from stable_baselines.common.vec_env import DummyVecEnv
-from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise, AdaptiveParamNoiseSpec
+from stable_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise, AdaptiveParamNoiseSpec
 from stable_baselines import DDPG

 env = gym.make('MountainCarContinuous-v0')
-env = DummyVecEnv([lambda: env])

 # the noise objects for DDPG
 n_actions = env.action_space.shape[-1]
@@ -148,7 +146,6 @@ You can easily define a custom architecture for the policy network:
 import gym

 from stable_baselines.ddpg.policies import FeedForwardPolicy
-from stable_baselines.common.vec_env import DummyVecEnv
 from stable_baselines import DDPG

 # Custom MLP policy of two layers of size 16 each
@@ -159,10 +156,7 @@ You can easily define a custom architecture for the policy network:
                                            layer_norm=False,
                                            feature_extraction="mlp")

-# Create and wrap the environment
-env = gym.make('Pendulum-v0')
-env = DummyVecEnv([lambda: env])

-model = DDPG(CustomDDPGPolicy, env, verbose=1)
+model = DDPG(CustomDDPGPolicy, 'Pendulum-v0', verbose=1)
 # Train the agent
 model.learn(total_timesteps=100000)

‎docs/modules/gail.rst

+1 -1
@@ -111,7 +111,7 @@ Example
 # Load the expert dataset
 dataset = ExpertDataset(expert_path='expert_pendulum.npz', traj_limitation=10, verbose=1)

-model = GAIL("MlpPolicy", 'Pendulum-v0', dataset, verbose=1)
+model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1)
 # Note: in practice, you need to train for 1M steps to have a working policy
 model.learn(total_timesteps=1000)
 model.save("gail_pendulum")

‎docs/modules/ppo1.rst

-2
@@ -59,11 +59,9 @@ Example
 import gym

 from stable_baselines.common.policies import MlpPolicy
-from stable_baselines.common.vec_env import DummyVecEnv
 from stable_baselines import PPO1

 env = gym.make('CartPole-v1')
-env = DummyVecEnv([lambda: env])

 model = PPO1(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=25000)

‎docs/modules/ppo2.rst

+2 -3
@@ -61,12 +61,11 @@ Train a PPO agent on `CartPole-v1` using 4 processes.
 import gym

 from stable_baselines.common.policies import MlpPolicy
-from stable_baselines.common.vec_env import SubprocVecEnv
+from stable_baselines.common import make_vec_env
 from stable_baselines import PPO2

 # multiprocess environment
-n_cpu = 4
-env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])
+env = make_vec_env('CartPole-v1', n_envs=4)

 model = PPO2(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=25000)

‎docs/modules/sac.rst

-2
@@ -75,11 +75,9 @@ Example
 import numpy as np

 from stable_baselines.sac.policies import MlpPolicy
-from stable_baselines.common.vec_env import DummyVecEnv
 from stable_baselines import SAC

 env = gym.make('Pendulum-v0')
-env = DummyVecEnv([lambda: env])

 model = SAC(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=50000, log_interval=10)

‎docs/modules/td3.rst

-1
@@ -73,7 +73,6 @@ Example
 from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

 env = gym.make('Pendulum-v0')
-env = DummyVecEnv([lambda: env])

 # The noise objects for TD3
 n_actions = env.action_space.shape[-1]

‎docs/modules/trpo.rst

-2
@@ -49,11 +49,9 @@ Example
 import gym

 from stable_baselines.common.policies import MlpPolicy
-from stable_baselines.common.vec_env import DummyVecEnv
 from stable_baselines import TRPO

 env = gym.make('CartPole-v1')
-env = DummyVecEnv([lambda: env])

 model = TRPO(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=25000)

‎setup.py

+3 -1
@@ -80,7 +80,9 @@
 from stable_baselines import PPO2

 env = gym.make('CartPole-v1')
-env = DummyVecEnv([lambda: env])  # The algorithms require a vectorized environment to run
+# Optional: PPO2 requires a vectorized environment to run
+# the env is now wrapped automatically when passing it to the constructor
+# env = DummyVecEnv([lambda: env])

 model = PPO2(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=10000)

‎stable_baselines/bench/monitor.py

+1 -1
@@ -14,7 +14,7 @@ class Monitor(Wrapper):
     EXT = "monitor.csv"
     file_handler = None

-    def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
+    def __init__(self, env, filename, allow_early_resets=True, reset_keywords=(), info_keywords=()):
         """
         A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.

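For context, a minimal sketch of how the Monitor wrapper is typically applied (the CartPole env is only an example); with this change, early resets are now allowed unless explicitly disabled:

import gym
from stable_baselines.bench.monitor import Monitor

env = gym.make('CartPole-v1')
# allow_early_resets now defaults to True;
# pass allow_early_resets=False to restore the previous, stricter behaviour
env = Monitor(env, filename=None)
env.reset()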

‎stable_baselines/common/__init__.py

+1
@@ -6,3 +6,4 @@
 from stable_baselines.common.misc_util import zipsame, set_global_seeds, boolean_flag
 from stable_baselines.common.base_class import BaseRLModel, ActorCriticRLModel, OffPolicyRLModel, SetVerbosity, \
     TensorboardWriter
+from stable_baselines.common.cmd_util import make_vec_env

‎stable_baselines/common/base_class.py

+6 -1
@@ -70,7 +70,12 @@ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base,
                 if isinstance(env, VecEnv):
                     self.n_envs = env.num_envs
                 else:
-                    raise ValueError("Error: the model requires a vectorized environment, please use a VecEnv wrapper.")
+                    # The model requires a VecEnv
+                    # wrap it in a DummyVecEnv to avoid error
+                    self.env = DummyVecEnv([lambda: env])
+                    if self.verbose >= 1:
+                        print("Wrapping the env in a DummyVecEnv.")
+                    self.n_envs = 1
             else:
                 if isinstance(env, VecEnv):
                     if env.num_envs == 1:
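With this change, a plain gym.Env can be passed straight to a model constructor. A minimal sketch (PPO2 and CartPole are illustrative choices):

import gym
from stable_baselines import PPO2

env = gym.make('CartPole-v1')
# No manual DummyVecEnv([lambda: env]) needed any more: the constructor
# detects the non-vectorized env and wraps it in a DummyVecEnv itself
model = PPO2('MlpPolicy', env, verbose=1)  # prints "Wrapping the env in a DummyVecEnv."
model.learn(total_timesteps=1000)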

‎stable_baselines/common/cmd_util.py

+66 -5
@@ -3,6 +3,7 @@
 """

 import os
+import warnings

 import gym

@@ -14,20 +15,80 @@
 from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv


+def make_vec_env(env_id, n_envs=1, seed=None, start_index=0,
+                 monitor_dir=None, wrapper_class=None,
+                 env_kwargs=None, vec_env_cls=None, vec_env_kwargs=None):
+    """
+    Create a wrapped, monitored `VecEnv`.
+    By default it uses a `DummyVecEnv` which is usually faster
+    than a `SubprocVecEnv`.
+
+    :param env_id: (str or Type[gym.Env]) the environment ID or the environment class
+    :param n_envs: (int) the number of environments you wish to have in parallel
+    :param seed: (int) the inital seed for the random number generator
+    :param start_index: (int) start rank index
+    :param monitor_dir: (str) Path to a folder where the monitor files will be saved.
+        If None, no file will be written, however, the env will still be wrapped
+        in a Monitor wrapper to provide additional information about training.
+    :param wrapper_class: (gym.Wrapper or callable) Additional wrapper to use on the environment.
+        This can also be a function with single argument that wraps the environment in many things.
+    :param env_kwargs: (dict) Optional keyword argument to pass to the env constructor
+    :param vec_env_cls: (Type[VecEnv]) A custom `VecEnv` class constructor. Default: None.
+    :param vec_env_kwargs: (dict) Keyword arguments to pass to the `VecEnv` class constructor.
+    :return: (VecEnv) The wrapped environment
+    """
+    env_kwargs = {} if env_kwargs is None else env_kwargs
+    vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
+
+    def make_env(rank):
+        def _init():
+            if isinstance(env_id, str):
+                env = gym.make(env_id)
+                if len(env_kwargs) > 0:
+                    warnings.warn("No environment class was passed (only an env ID) so `env_kwargs` will be ignored")
+            else:
+                env = env_id(**env_kwargs)
+            if seed is not None:
+                env.seed(seed + rank)
+                env.action_space.seed(seed + rank)
+            # Wrap the env in a Monitor wrapper
+            # to have additional training information
+            monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
+            # Create the monitor folder if needed
+            if monitor_path is not None:
+                os.makedirs(monitor_dir, exist_ok=True)
+            env = Monitor(env, filename=monitor_path)
+            # Optionally, wrap the environment with the provided wrapper
+            if wrapper_class is not None:
+                env = wrapper_class(env)
+            return env
+        return _init
+
+    # No custom VecEnv is passed
+    if vec_env_cls is None:
+        # Default: use a DummyVecEnv
+        vec_env_cls = DummyVecEnv
+
+    return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
+
+
 def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None,
-                   start_index=0, allow_early_resets=True, start_method=None):
+                   start_index=0, allow_early_resets=True,
+                   start_method=None, use_subprocess=False):
     """
-    Create a wrapped, monitored SubprocVecEnv for Atari.
+    Create a wrapped, monitored VecEnv for Atari.

     :param env_id: (str) the environment ID
     :param num_env: (int) the number of environment you wish to have in subprocesses
     :param seed: (int) the inital seed for RNG
     :param wrapper_kwargs: (dict) the parameters for wrap_deepmind function
     :param start_index: (int) start rank index
     :param allow_early_resets: (bool) allows early reset of the environment
-    :return: (Gym Environment) The atari environment
     :param start_method: (str) method used to start the subprocesses.
         See SubprocVecEnv doc for more information
+    :param use_subprocess: (bool) Whether to use `SubprocVecEnv` or `DummyVecEnv` when
+        `num_env` > 1, `DummyVecEnv` is usually faster. Default: False
+    :return: (VecEnv) The atari environment
     """
     if wrapper_kwargs is None:
         wrapper_kwargs = {}
@@ -43,8 +104,8 @@ def _thunk():
     set_global_seeds(seed)

     # When using one environment, no need to start subprocesses
-    if num_env == 1:
-        return DummyVecEnv([make_env(0)])
+    if num_env == 1 or not use_subprocess:
+        return DummyVecEnv([make_env(i + start_index) for i in range(num_env)])

     return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)],
                          start_method=start_method)
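A sketch of how the new make_vec_env arguments and the use_subprocess flag might be exercised (the extra wrapper, log directory and env IDs below are illustrative choices, not part of the diff):

import gym
from stable_baselines.common.cmd_util import make_vec_env, make_atari_env
from stable_baselines.common.vec_env import SubprocVecEnv

# Default: a DummyVecEnv, each env wrapped in a Monitor
# (monitor files are only written when monitor_dir is set)
env = make_vec_env('CartPole-v1', n_envs=4, seed=0,
                   wrapper_class=gym.wrappers.TimeLimit,  # extra wrapper, called as wrapper_class(env)
                   monitor_dir='logs/cartpole')           # hypothetical log folder

# Opt back into subprocesses explicitly
env = make_vec_env('CartPole-v1', n_envs=4, seed=0, vec_env_cls=SubprocVecEnv)

# make_atari_env now returns a DummyVecEnv unless use_subprocess=True
atari_env = make_atari_env('SpaceInvadersNoFrameskip-v4', num_env=4, seed=0,
                           use_subprocess=True)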

‎stable_baselines/common/evaluation.py

+52
@@ -0,0 +1,52 @@
+import numpy as np
+
+from stable_baselines.common.vec_env import VecEnv
+
+
+def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
+                    render=False, callback=None, reward_threshold=None,
+                    return_episode_rewards=False):
+    """
+    Runs policy for `n_eval_episodes` episodes and returns average reward.
+    This is made to work only with one env.
+
+    :param model: (BaseRLModel) The RL agent you want to evaluate.
+    :param env: (gym.Env or VecEnv) The gym environment. In the case of a `VecEnv`
+        this must contain only one environment.
+    :param n_eval_episodes: (int) Number of episode to evaluate the agent
+    :param deterministic: (bool) Whether to use deterministic or stochastic actions
+    :param render: (bool) Whether to render the environement or not
+    :param callback: (callable) callback function to do additional checks,
+        called after each step.
+    :param reward_threshold: (float) Minimum expected reward per episode,
+        this will raise an error if the performance is not met
+    :param return_episode_rewards: (bool) If True, a list of reward per episode
+        will be returned instead of the mean.
+    :return: (float, int) Mean reward per episode, total number of steps
+        returns ([float], int) when `return_episode_rewards` is True
+    """
+    if isinstance(env, VecEnv):
+        assert env.num_envs == 1, "You must pass only one environment when using this function"
+
+    episode_rewards, n_steps = [], 0
+    for _ in range(n_eval_episodes):
+        obs = env.reset()
+        done, state = False, None
+        episode_reward = 0.0
+        while not done:
+            action, state = model.predict(obs, state=state, deterministic=deterministic)
+            obs, reward, done, _info = env.step(action)
+            episode_reward += reward
+            if callback is not None:
+                callback(locals(), globals())
+            n_steps += 1
+            if render:
+                env.render()
+        episode_rewards.append(episode_reward)
+    mean_reward = np.mean(episode_rewards)
+    if reward_threshold is not None:
+        assert mean_reward > reward_threshold, 'Mean reward below threshold: '\
+                                               '{:.2f} < {:.2f}'.format(mean_reward, reward_threshold)
+    if return_episode_rewards:
+        return episode_rewards, n_steps
+    return mean_reward, n_steps
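A short usage sketch of the new evaluate_policy helper (the A2C model, env ID and threshold value are illustrative):

from stable_baselines import A2C
from stable_baselines.common.evaluation import evaluate_policy

model = A2C('MlpPolicy', 'CartPole-v1', seed=0).learn(10000)

# Mean reward over 10 episodes; raises an AssertionError if it falls below the threshold
mean_reward, n_steps = evaluate_policy(model, model.get_env(),
                                       n_eval_episodes=10, reward_threshold=10.0)

# Per-episode rewards instead of the mean
episode_rewards, n_steps = evaluate_policy(model, model.get_env(),
                                           n_eval_episodes=10,
                                           return_episode_rewards=True)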

‎tests/test_action_space.py

+3 -14
@@ -4,6 +4,7 @@
 from stable_baselines import A2C, PPO1, PPO2, TRPO
 from stable_baselines.common.identity_env import IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
 from stable_baselines.common.vec_env import DummyVecEnv
+from stable_baselines.common.evaluation import evaluate_policy

 MODEL_LIST = [
     A2C,
@@ -27,14 +28,8 @@ def test_identity_multidiscrete(model_class):

     model = model_class("MlpPolicy", env)
     model.learn(total_timesteps=1000)
-
-    n_trials = 1000
-    reward_sum = 0
+    evaluate_policy(model, env, n_eval_episodes=5)
     obs = env.reset()
-    for _ in range(n_trials):
-        action, _ = model.predict(obs)
-        obs, reward, _, _ = env.step(action)
-        reward_sum += reward

     assert np.array(model.action_probability(obs)).shape == (2, 1, 10), \
         "Error: action_probability not returning correct shape"
@@ -56,14 +51,8 @@ def test_identity_multibinary(model_class):

     model = model_class("MlpPolicy", env)
     model.learn(total_timesteps=1000)
-
-    n_trials = 1000
-    reward_sum = 0
+    evaluate_policy(model, env, n_eval_episodes=5)
    obs = env.reset()
-    for _ in range(n_trials):
-        action, _ = model.predict(obs)
-        obs, reward, _, _ = env.step(action)
-        reward_sum += reward

     assert model.action_probability(obs).shape == (1, 10), \
         "Error: action_probability not returning correct shape"

‎tests/test_auto_vec_detection.py

+19 -62
@@ -1,10 +1,23 @@
 import pytest
 import numpy as np

-from stable_baselines import A2C, ACER, ACKTR, DDPG, DQN, PPO1, PPO2, SAC, TRPO
+from stable_baselines import A2C, ACER, ACKTR, DDPG, DQN, PPO1, PPO2, SAC, TRPO, TD3
 from stable_baselines.common.vec_env import DummyVecEnv
 from stable_baselines.common.identity_env import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, \
     IdentityEnvMultiDiscrete
+from stable_baselines.common.evaluation import evaluate_policy
+
+
+def check_shape(make_env, model_class, shape_1, shape_2):
+    model = model_class(policy="MlpPolicy", env=DummyVecEnv([make_env]))
+
+    env0 = make_env()
+    env1 = DummyVecEnv([make_env])
+
+    for env, expected_shape in [(env0, shape_1), (env1, shape_2)]:
+        def callback(locals_, _globals):
+            assert np.array(locals_['action']).shape == expected_shape
+        evaluate_policy(model, env, n_eval_episodes=5, callback=callback)


 @pytest.mark.slow
@@ -15,46 +28,18 @@ def test_identity(model_class):

     :param model_class: (BaseRLModel) the RL model
     """
-    model = model_class(policy="MlpPolicy", env=DummyVecEnv([lambda: IdentityEnv(dim=10)]))
-
-    env0 = IdentityEnv(dim=10)
-    env1 = DummyVecEnv([lambda: IdentityEnv(dim=10)])
-
-    n_trials = 100
-    for env, expected_shape in [(env0, ()), (env1, (1,))]:
-        obs = env.reset()
-        for _ in range(n_trials):
-            action, _ = model.predict(obs)
-            assert np.array(action).shape == expected_shape
-            obs, _, _, _ = env.step(action)
-
-    # Free memory
-    del model, env0, env1
+    check_shape(lambda: IdentityEnv(dim=10), model_class, (), (1,))


 @pytest.mark.slow
-@pytest.mark.parametrize("model_class", [A2C, DDPG, PPO1, PPO2, SAC, TRPO])
+@pytest.mark.parametrize("model_class", [A2C, DDPG, PPO1, PPO2, SAC, TRPO, TD3])
 def test_identity_box(model_class):
     """
     test the Box environment vectorisation detection

     :param model_class: (BaseRLModel) the RL model
     """
-    model = model_class(policy="MlpPolicy", env=DummyVecEnv([lambda: IdentityEnvBox(eps=0.5)]))
-
-    env0 = IdentityEnvBox()
-    env1 = DummyVecEnv([lambda: IdentityEnvBox(eps=0.5)])
-
-    n_trials = 100
-    for env, expected_shape in [(env0, (1,)), (env1, (1, 1))]:
-        obs = env.reset()
-        for _ in range(n_trials):
-            action, _ = model.predict(obs)
-            assert np.array(action).shape == expected_shape
-            obs, _, _, _ = env.step(action)
-
-    # Free memory
-    del model, env0, env1
+    check_shape(lambda: IdentityEnvBox(eps=0.5), model_class, (1,), (1, 1))


 @pytest.mark.slow
@@ -65,21 +50,7 @@ def test_identity_multi_binary(model_class):

     :param model_class: (BaseRLModel) the RL model
     """
-    model = model_class(policy="MlpPolicy", env=DummyVecEnv([lambda: IdentityEnvMultiBinary(dim=10)]))
-
-    env0 = IdentityEnvMultiBinary(dim=10)
-    env1 = DummyVecEnv([lambda: IdentityEnvMultiBinary(dim=10)])
-
-    n_trials = 100
-    for env, expected_shape in [(env0, (10,)), (env1, (1, 10))]:
-        obs = env.reset()
-        for _ in range(n_trials):
-            action, _ = model.predict(obs)
-            assert np.array(action).shape == expected_shape
-            obs, _, _, _ = env.step(action)
-
-    # Free memory
-    del model, env0, env1
+    check_shape(lambda: IdentityEnvMultiBinary(dim=10), model_class, (10,), (1, 10))


 @pytest.mark.slow
@@ -90,18 +61,4 @@ def test_identity_multi_discrete(model_class):

     :param model_class: (BaseRLModel) the RL model
     """
-    model = model_class(policy="MlpPolicy", env=DummyVecEnv([lambda: IdentityEnvMultiDiscrete(dim=10)]))
-
-    env0 = IdentityEnvMultiDiscrete(dim=10)
-    env1 = DummyVecEnv([lambda: IdentityEnvMultiDiscrete(dim=10)])
-
-    n_trials = 100
-    for env, expected_shape in [(env0, (2,)), (env1, (1, 2))]:
-        obs = env.reset()
-        for _ in range(n_trials):
-            action, _ = model.predict(obs)
-            assert np.array(action).shape == expected_shape
-            obs, _, _, _ = env.step(action)
-
-    # Free memory
-    del model, env0, env1
+    check_shape(lambda: IdentityEnvMultiDiscrete(dim=10), model_class, (2,), (1, 2))
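The check_shape helper above relies on the callback hook of evaluate_policy, which receives the function's local and global namespaces after every step. A standalone sketch of the same pattern (the printed locals are just an illustration):

from stable_baselines import A2C
from stable_baselines.common.evaluation import evaluate_policy

model = A2C('MlpPolicy', 'CartPole-v1', seed=0)

def print_step(locals_, _globals):
    # 'action' and 'episode_reward' are locals of evaluate_policy at each step
    print(locals_['action'], locals_['episode_reward'])

evaluate_policy(model, model.get_env(), n_eval_episodes=1, callback=print_step)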

‎tests/test_continuous.py

+12 -29
@@ -8,14 +8,14 @@
 from stable_baselines import A2C, ACKTR, SAC, DDPG, PPO1, PPO2, TRPO, TD3
 # TODO: add support for continuous actions
 # from stable_baselines.acer import ACER
-from stable_baselines.common import set_global_seeds
 from stable_baselines.common.vec_env import DummyVecEnv
 from stable_baselines.common.identity_env import IdentityEnvBox
 from stable_baselines.ddpg import AdaptiveParamNoiseSpec, NormalActionNoise
+from stable_baselines.common.evaluation import evaluate_policy
 from tests.test_common import _assert_eq


-N_TRIALS = 1000
+N_EVAL_EPISODES = 20
 NUM_TIMESTEPS = 15000

 MODEL_LIST = [
@@ -44,18 +44,10 @@ def test_model_manipulation(request, model_class):
         env = DummyVecEnv([lambda: IdentityEnvBox(eps=0.5)])

         # create and train
-        model = model_class(policy="MlpPolicy", env=env)
+        model = model_class(policy="MlpPolicy", env=env, seed=0)
         model.learn(total_timesteps=NUM_TIMESTEPS)

-        # predict and measure the acc reward
-        acc_reward = 0
-        set_global_seeds(0)
-        obs = env.reset()
-        for _ in range(N_TRIALS):
-            action, _ = model.predict(obs)
-            obs, reward, _, _ = env.step(action)
-            acc_reward += reward
-        acc_reward = sum(acc_reward) / N_TRIALS
+        acc_reward, _ = evaluate_policy(model, env, n_eval_episodes=N_EVAL_EPISODES)

         # saving
         model_fname = './test_model_{}.zip'.format(request.node.name)
@@ -70,16 +62,9 @@ def test_model_manipulation(request, model_class):
         env = DummyVecEnv([lambda: IdentityEnvBox(eps=0.5)])
         model.set_env(env)

-        # predict the same output before saving
-        loaded_acc_reward = 0
-        set_global_seeds(0)
-        obs = env.reset()
-        for _ in range(N_TRIALS):
-            action, _ = model.predict(obs)
-            obs, reward, _, _ = env.step(action)
-            loaded_acc_reward += reward
-        loaded_acc_reward = sum(loaded_acc_reward) / N_TRIALS
+        loaded_acc_reward, _ = evaluate_policy(model, env, n_eval_episodes=N_EVAL_EPISODES)

+        obs = env.reset()
         with pytest.warns(None) as record:
             act_prob = model.action_probability(obs)

@@ -124,27 +109,25 @@ def test_model_manipulation(request, model_class):
         # loaded_acc_reward = 0
         # set_global_seeds(0)
         # obs = env.reset()
-        # for _ in range(N_TRIALS):
+        # for _ in range(N_EVAL_EPISODES):
         #     action, _ = model.predict(obs)
         #     obs, reward, _, _ = env.step(action)
         #     loaded_acc_reward += reward
-        # loaded_acc_reward = sum(loaded_acc_reward) / N_TRIALS
+        # loaded_acc_reward = sum(loaded_acc_reward) / N_EVAL_EPISODES
         # # assert <10% diff
         # assert abs(acc_reward - loaded_acc_reward) / max(acc_reward, loaded_acc_reward) < 0.1, \
         #     "Error: the prediction seems to have changed between pre learning and post learning"

         # predict new values
-        obs = env.reset()
-        for _ in range(N_TRIALS):
-            action, _ = model.predict(obs)
-            obs, _, _, _ = env.step(action)
+
+        evaluate_policy(model, env, n_eval_episodes=N_EVAL_EPISODES)

         # Free memory
         del model, env

     finally:
-        if os.path.exists("./test_model.zip"):
-            os.remove("./test_model.zip")
+        if os.path.exists(model_fname):
+            os.remove(model_fname)


 def test_ddpg():

‎tests/test_gail.py

+3 -7
@@ -8,8 +8,10 @@
     TD3, TRPO, SAC
 from stable_baselines.common.cmd_util import make_atari_env
 from stable_baselines.common.vec_env import VecFrameStack
+from stable_baselines.common.evaluation import evaluate_policy
 from stable_baselines.gail import ExpertDataset, generate_expert_traj

+
 EXPERT_PATH_PENDULUM = "stable_baselines/gail/dataset/expert_pendulum.npz"
 EXPERT_PATH_DISCRETE = "stable_baselines/gail/dataset/expert_cartpole.npz"

@@ -36,13 +38,7 @@ def test_gail(expert_env):
     model = model.load("GAIL-{}".format(env_id), env=env)
     model.learn(1000)

-    obs = env.reset()
-
-    for _ in range(1000):
-        action, _ = model.predict(obs)
-        obs, _, done, _ = env.step(action)
-        if done:
-            obs = env.reset()
+    evaluate_policy(model, env, n_eval_episodes=5)
     del dataset, model

 @pytest.mark.parametrize("generate_env", [

‎tests/test_identity.py

+3 -19
@@ -5,7 +5,7 @@
 from stable_baselines.ddpg import NormalActionNoise
 from stable_baselines.common.identity_env import IdentityEnv, IdentityEnvBox
 from stable_baselines.common.vec_env import DummyVecEnv
-from stable_baselines.common import set_global_seeds
+from stable_baselines.common.evaluation import evaluate_policy


 # Hyperparameters for learning identity for each RL model
@@ -39,24 +39,16 @@ def test_identity(model_name):
     env = DummyVecEnv([lambda: IdentityEnv(10)])

     model = LEARN_FUNC_DICT[model_name](env)
+    evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=0.9)

-    n_trials = 1000
-    reward_sum = 0
-    set_global_seeds(0)
     obs = env.reset()
-    for _ in range(n_trials):
-        action, _ = model.predict(obs)
-        obs, reward, _, _ = env.step(action)
-        reward_sum += reward
-
     assert model.action_probability(obs).shape == (1, 10), "Error: action_probability not returning correct shape"
     action = env.action_space.sample()
     action_prob = model.action_probability(obs, actions=action)
     assert np.prod(action_prob.shape) == 1, "Error: not scalar probability"
     action_logprob = model.action_probability(obs, actions=action, logp=True)
     assert np.allclose(action_prob, np.exp(action_logprob)), (action_prob, action_logprob)

-    assert reward_sum > 0.9 * n_trials
     # Free memory
     del model, env

@@ -80,14 +72,6 @@ def test_identity_continuous(model_class):
                         action_noise=action_noise, buffer_size=int(1e6))
     model.learn(total_timesteps=20000)

-    n_trials = 1000
-    reward_sum = 0
-    set_global_seeds(0)
-    obs = env.reset()
-    for _ in range(n_trials):
-        action, _ = model.predict(obs)
-        obs, reward, _, _ = env.step(action)
-        reward_sum += reward
-    assert reward_sum > 0.9 * n_trials
+    evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=0.9)
     # Free memory
     del model, env

‎tests/test_lstm_policy.py

+2 -5
@@ -11,6 +11,7 @@
 from stable_baselines.common.vec_env import SubprocVecEnv
 from stable_baselines.common.vec_env.vec_normalize import VecNormalize
 from stable_baselines.ppo2.ppo2 import safe_mean
+from stable_baselines.common.evaluation import evaluate_policy


 class CustomLSTMPolicy1(LstmPolicy):
@@ -85,11 +86,7 @@ def test_lstm_policy(request, model_class, policy):
     model.learn(total_timesteps=100)

     env = model.get_env()
-    # predict and measure the acc reward
-    obs = env.reset()
-    for _ in range(N_TRIALS):
-        action, _ = model.predict(obs)
-        obs, _, _, _ = env.step(action)
+    evaluate_policy(model, env, n_eval_episodes=10)
     # saving
     model.save(model_fname)
     del model, env

‎tests/test_save.py

+20 -41
@@ -7,12 +7,12 @@
 import numpy as np

 from stable_baselines import A2C, ACER, ACKTR, DQN, PPO1, PPO2, TRPO
-from stable_baselines.common import set_global_seeds
 from stable_baselines.common.identity_env import IdentityEnv
 from stable_baselines.common.vec_env import DummyVecEnv
+from stable_baselines.common.evaluation import evaluate_policy
 from stable_baselines.common.policies import MlpPolicy, FeedForwardPolicy

-N_TRIALS = 2000
+N_EVAL_EPISODES = 100

 MODEL_LIST = [
     A2C,
@@ -44,11 +44,11 @@ def test_model_manipulation(request, model_class, storage_method, store_format):
     works and that the action prediction works

     :param model_class: (BaseRLModel) A RL model
-    :param storage_method: (str) Should file be saved to a file ("path") or to a buffer
+    :param storage_method: (str) Should file be saved to a file ("path") or to a buffer
         ("file-like")
     :param store_format: (str) Save format, either "zip" or "cloudpickle".
     """
-
+
     # Use postfix ".model" so we can remove the file later
     model_fname = './test_model_{}.model'.format(request.node.name)
     store_as_cloudpickle = store_format == "cloudpickle"
@@ -57,20 +57,12 @@ def test_model_manipulation(request, model_class, storage_method, store_format):
         env = DummyVecEnv([lambda: IdentityEnv(10)])

         # create and train
-        model = model_class(policy="MlpPolicy", env=env)
-        model.learn(total_timesteps=50000)
+        model = model_class(policy="MlpPolicy", env=env, seed=0)
+        model.learn(total_timesteps=10000)

-        # predict and measure the acc reward
-        acc_reward = 0
-        set_global_seeds(0)
-        obs = env.reset()
-        for _ in range(N_TRIALS):
-            action, _ = model.predict(obs)
-            # Test action probability method
-            model.action_probability(obs)
-            obs, reward, _, _ = env.step(action)
-            acc_reward += reward
-        acc_reward = sum(acc_reward) / N_TRIALS
+        env.envs[0].action_space.seed(0)
+        mean_reward, _ = evaluate_policy(model, env, deterministic=True,
+                                         n_eval_episodes=N_EVAL_EPISODES)

         # test action probability for given (obs, action) pair
         env = model.get_env()
@@ -107,37 +99,24 @@ def test_model_manipulation(request, model_class, storage_method, store_format):
         model.set_env(env)

         # predict the same output before saving
-        loaded_acc_reward = 0
-        set_global_seeds(0)
-        obs = env.reset()
-        for _ in range(N_TRIALS):
-            action, _ = model.predict(obs)
-            obs, reward, _, _ = env.step(action)
-            loaded_acc_reward += reward
-        loaded_acc_reward = sum(loaded_acc_reward) / N_TRIALS
-        assert abs(acc_reward - loaded_acc_reward) < 0.1, "Error: the prediction seems to have changed between " \
-                                                          "loading and saving"
+        env.envs[0].action_space.seed(0)
+        loaded_mean_reward, _ = evaluate_policy(model, env, deterministic=True, n_eval_episodes=N_EVAL_EPISODES)
+        # Allow 10% diff
+        assert abs((mean_reward - loaded_mean_reward) / mean_reward) < 0.1, "Error: the prediction seems to have changed between " \
+                                                                            "loading and saving"

         # learn post loading
         model.learn(total_timesteps=100)

         # validate no reset post learning
-        loaded_acc_reward = 0
-        set_global_seeds(0)
-        obs = env.reset()
-        for _ in range(N_TRIALS):
-            action, _ = model.predict(obs)
-            obs, reward, _, _ = env.step(action)
-            loaded_acc_reward += reward
-        loaded_acc_reward = sum(loaded_acc_reward) / N_TRIALS
-        assert abs(acc_reward - loaded_acc_reward) < 0.1, "Error: the prediction seems to have changed between " \
-                                                          "pre learning and post learning"
+        env.envs[0].action_space.seed(0)
+        loaded_mean_reward, _ = evaluate_policy(model, env, deterministic=True, n_eval_episodes=N_EVAL_EPISODES)
+
+        assert abs((mean_reward - loaded_mean_reward) / mean_reward) < 0.15, "Error: the prediction seems to have changed between " \
+                                                                             "pre learning and post learning"

         # predict new values
-        obs = env.reset()
-        for _ in range(N_TRIALS):
-            action, _ = model.predict(obs)
-            obs, _, _, _ = env.step(action)
+        evaluate_policy(model, env, n_eval_episodes=N_EVAL_EPISODES)

         del model, env


‎tests/test_utils.py

+78
@@ -0,0 +1,78 @@
+import os
+import shutil
+
+import pytest
+import gym
+
+from stable_baselines import A2C
+from stable_baselines.bench.monitor import Monitor
+from stable_baselines.common.evaluation import evaluate_policy
+from stable_baselines.common.cmd_util import make_vec_env
+from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
+
+
+@pytest.mark.parametrize("env_id", ['CartPole-v1', lambda: gym.make('CartPole-v1')])
+@pytest.mark.parametrize("n_envs", [1, 2])
+@pytest.mark.parametrize("vec_env_cls", [None, SubprocVecEnv])
+@pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.TimeLimit])
+def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class):
+    env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls,
+                       wrapper_class=wrapper_class, monitor_dir=None, seed=0)
+
+    assert env.num_envs == n_envs
+
+    if vec_env_cls is None:
+        assert isinstance(env, DummyVecEnv)
+        if wrapper_class is not None:
+            assert isinstance(env.envs[0], wrapper_class)
+        else:
+            assert isinstance(env.envs[0], Monitor)
+    else:
+        assert isinstance(env, SubprocVecEnv)
+    # Kill subprocesses
+    env.close()
+
+
+def test_custom_vec_env():
+    """
+    Stand alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests.
+    """
+    monitor_dir = 'logs/test_make_vec_env/'
+    env = make_vec_env('CartPole-v1', n_envs=1,
+                       monitor_dir=monitor_dir, seed=0,
+                       vec_env_cls=SubprocVecEnv, vec_env_kwargs={'start_method': None})
+
+
+    assert env.num_envs == 1
+    assert isinstance(env, SubprocVecEnv)
+    assert os.path.isdir('logs/test_make_vec_env/')
+    # Kill subprocess
+    env.close()
+    # Cleanup folder
+    shutil.rmtree(monitor_dir)
+
+    # This should fail because DummyVecEnv does not have any keyword argument
+    with pytest.raises(TypeError):
+        make_vec_env('CartPole-v1', n_envs=1, vec_env_kwargs={'dummy': False})
+
+
+def test_evaluate_policy():
+    model = A2C('MlpPolicy', 'Pendulum-v0', seed=0)
+    n_steps_per_episode, n_eval_episodes = 200, 2
+    model.n_callback_calls = 0
+
+    def dummy_callback(locals_, _globals):
+        locals_['model'].n_callback_calls += 1
+
+    _, n_steps = evaluate_policy(model, model.get_env(), n_eval_episodes, deterministic=True,
+                                 render=False, callback=dummy_callback, reward_threshold=None,
+                                 return_episode_rewards=False)
+    assert n_steps == n_steps_per_episode * n_eval_episodes
+    assert n_steps == model.n_callback_calls
+
+    # Reaching a mean reward of zero is impossible with the Pendulum env
+    with pytest.raises(AssertionError):
+        evaluate_policy(model, model.get_env(), n_eval_episodes, reward_threshold=0.0)
+
+    episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True)
+    assert len(episode_rewards) == n_eval_episodes
