1
1
import pytest
2
2
import numpy as np
3
3
4
- from stable_baselines import A2C , ACER , ACKTR , DDPG , DQN , PPO1 , PPO2 , SAC , TRPO
4
+ from stable_baselines import A2C , ACER , ACKTR , DDPG , DQN , PPO1 , PPO2 , SAC , TRPO , TD3
5
5
from stable_baselines .common .vec_env import DummyVecEnv
6
6
from stable_baselines .common .identity_env import IdentityEnv , IdentityEnvBox , IdentityEnvMultiBinary , \
7
7
IdentityEnvMultiDiscrete
8
+ from stable_baselines .common .evaluation import evaluate_policy
9
+
10
+
11
+ def check_shape (make_env , model_class , shape_1 , shape_2 ):
12
+ model = model_class (policy = "MlpPolicy" , env = DummyVecEnv ([make_env ]))
13
+
14
+ env0 = make_env ()
15
+ env1 = DummyVecEnv ([make_env ])
16
+
17
+ for env , expected_shape in [(env0 , shape_1 ), (env1 , shape_2 )]:
18
+ def callback (locals_ , _globals ):
19
+ assert np .array (locals_ ['action' ]).shape == expected_shape
20
+ evaluate_policy (model , env , n_eval_episodes = 5 , callback = callback )
8
21
9
22
10
23
@pytest .mark .slow
@@ -15,46 +28,18 @@ def test_identity(model_class):
15
28
16
29
:param model_class: (BaseRLModel) the RL model
17
30
"""
18
- model = model_class (policy = "MlpPolicy" , env = DummyVecEnv ([lambda : IdentityEnv (dim = 10 )]))
19
-
20
- env0 = IdentityEnv (dim = 10 )
21
- env1 = DummyVecEnv ([lambda : IdentityEnv (dim = 10 )])
22
-
23
- n_trials = 100
24
- for env , expected_shape in [(env0 , ()), (env1 , (1 ,))]:
25
- obs = env .reset ()
26
- for _ in range (n_trials ):
27
- action , _ = model .predict (obs )
28
- assert np .array (action ).shape == expected_shape
29
- obs , _ , _ , _ = env .step (action )
30
-
31
- # Free memory
32
- del model , env0 , env1
31
+ check_shape (lambda : IdentityEnv (dim = 10 ), model_class , (), (1 ,))
33
32
34
33
35
34
@pytest .mark .slow
36
- @pytest .mark .parametrize ("model_class" , [A2C , DDPG , PPO1 , PPO2 , SAC , TRPO ])
35
+ @pytest .mark .parametrize ("model_class" , [A2C , DDPG , PPO1 , PPO2 , SAC , TRPO , TD3 ])
37
36
def test_identity_box (model_class ):
38
37
"""
39
38
test the Box environment vectorisation detection
40
39
41
40
:param model_class: (BaseRLModel) the RL model
42
41
"""
43
- model = model_class (policy = "MlpPolicy" , env = DummyVecEnv ([lambda : IdentityEnvBox (eps = 0.5 )]))
44
-
45
- env0 = IdentityEnvBox ()
46
- env1 = DummyVecEnv ([lambda : IdentityEnvBox (eps = 0.5 )])
47
-
48
- n_trials = 100
49
- for env , expected_shape in [(env0 , (1 ,)), (env1 , (1 , 1 ))]:
50
- obs = env .reset ()
51
- for _ in range (n_trials ):
52
- action , _ = model .predict (obs )
53
- assert np .array (action ).shape == expected_shape
54
- obs , _ , _ , _ = env .step (action )
55
-
56
- # Free memory
57
- del model , env0 , env1
42
+ check_shape (lambda : IdentityEnvBox (eps = 0.5 ), model_class , (1 ,), (1 , 1 ))
58
43
59
44
60
45
@pytest .mark .slow
@@ -65,21 +50,7 @@ def test_identity_multi_binary(model_class):
65
50
66
51
:param model_class: (BaseRLModel) the RL model
67
52
"""
68
- model = model_class (policy = "MlpPolicy" , env = DummyVecEnv ([lambda : IdentityEnvMultiBinary (dim = 10 )]))
69
-
70
- env0 = IdentityEnvMultiBinary (dim = 10 )
71
- env1 = DummyVecEnv ([lambda : IdentityEnvMultiBinary (dim = 10 )])
72
-
73
- n_trials = 100
74
- for env , expected_shape in [(env0 , (10 ,)), (env1 , (1 , 10 ))]:
75
- obs = env .reset ()
76
- for _ in range (n_trials ):
77
- action , _ = model .predict (obs )
78
- assert np .array (action ).shape == expected_shape
79
- obs , _ , _ , _ = env .step (action )
80
-
81
- # Free memory
82
- del model , env0 , env1
53
+ check_shape (lambda : IdentityEnvMultiBinary (dim = 10 ), model_class , (10 ,), (1 , 10 ))
83
54
84
55
85
56
@pytest .mark .slow
@@ -90,18 +61,4 @@ def test_identity_multi_discrete(model_class):
90
61
91
62
:param model_class: (BaseRLModel) the RL model
92
63
"""
93
- model = model_class (policy = "MlpPolicy" , env = DummyVecEnv ([lambda : IdentityEnvMultiDiscrete (dim = 10 )]))
94
-
95
- env0 = IdentityEnvMultiDiscrete (dim = 10 )
96
- env1 = DummyVecEnv ([lambda : IdentityEnvMultiDiscrete (dim = 10 )])
97
-
98
- n_trials = 100
99
- for env , expected_shape in [(env0 , (2 ,)), (env1 , (1 , 2 ))]:
100
- obs = env .reset ()
101
- for _ in range (n_trials ):
102
- action , _ = model .predict (obs )
103
- assert np .array (action ).shape == expected_shape
104
- obs , _ , _ , _ = env .step (action )
105
-
106
- # Free memory
107
- del model , env0 , env1
64
+ check_shape (lambda : IdentityEnvMultiDiscrete (dim = 10 ), model_class , (2 ,), (1 , 2 ))
0 commit comments