@@ -18,7 +18,7 @@ def traj_segment_generator(policy, env, horizon, reward_giver=None, gail=False):
         - ob: (np.ndarray) observations
         - rew: (numpy float) rewards (if gail is used it is the predicted reward)
         - vpred: (numpy float) action logits
-        - new: (numpy bool) dones (is end of episode)
+        - dones: (numpy bool) dones (is end of episode -> True if first timestep of an episode)
         - ac: (np.ndarray) actions
         - prevac: (np.ndarray) previous actions
         - nextvpred: (numpy float) next action logits
@@ -32,7 +32,6 @@ def traj_segment_generator(policy, env, horizon, reward_giver=None, gail=False):
     # Initialize state variables
     step = 0
     action = env.action_space.sample()  # not used, just so we have the datatype
-    new = True
     observation = env.reset()

     cur_ep_ret = 0  # return in current episode
@@ -51,7 +50,7 @@ def traj_segment_generator(policy, env, horizon, reward_giver=None, gail=False):
     actions = np.array([action for _ in range(horizon)])
     prev_actions = actions.copy()
     states = policy.initial_state
-    done = None
+    done = True  # marks if we're on first timestep of an episode

     while True:
         prevac = action
@@ -66,9 +65,20 @@ def traj_segment_generator(policy, env, horizon, reward_giver=None, gail=False):
             else:
                 current_it_timesteps = sum(ep_lens) + current_it_len

-            yield {"ob": observations, "rew": rews, "dones": dones, "true_rew": true_rews, "vpred": vpreds,
-                   "ac": actions, "prevac": prev_actions, "nextvpred": vpred * (1 - new), "ep_rets": ep_rets,
-                   "ep_lens": ep_lens, "ep_true_rets": ep_true_rets, "total_timestep": current_it_timesteps}
+            yield {
+                    "ob": observations,
+                    "rew": rews,
+                    "dones": dones,
+                    "true_rew": true_rews,
+                    "vpred": vpreds,
+                    "ac": actions,
+                    "prevac": prev_actions,
+                    "nextvpred": vpred[0] * (1 - done),
+                    "ep_rets": ep_rets,
+                    "ep_lens": ep_lens,
+                    "ep_true_rets": ep_true_rets,
+                    "total_timestep": current_it_timesteps
+            }
             _, vpred, _, _ = policy.step(observation.reshape(-1, *observation.shape))
             # Be careful!!! if you change the downstream algorithm to aggregate
             # several of these batches, then be sure to do a deepcopy
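For context on how the renamed "dones" key and the "nextvpred" entry are typically consumed downstream, here is a minimal sketch of a GAE(lambda) advantage computation over a yielded segment. The helper name add_vtarg_and_adv and the gamma/lam parameters are assumptions for illustration and are not part of this diff.

import numpy as np

def add_vtarg_and_adv(seg, gamma, lam):
    # Sketch (not part of this diff): GAE(lambda) advantage estimation over a
    # segment dict yielded by traj_segment_generator.
    # "dones" marks the first timestep of an episode, so dones[t + 1] == 1
    # means vpred[t + 1] belongs to a new episode and must not bootstrap.
    dones = np.append(seg["dones"], 0)                  # pad so t + 1 indexing is valid
    vpred = np.append(seg["vpred"], seg["nextvpred"])   # value estimate after the last step
    horizon = len(seg["rew"])
    seg["adv"] = gae = np.empty(horizon, "float32")
    rew = seg["rew"]
    last_gae = 0
    for step in reversed(range(horizon)):
        non_terminal = 1 - dones[step + 1]
        delta = rew[step] + gamma * vpred[step + 1] * non_terminal - vpred[step]
        gae[step] = last_gae = delta + gamma * lam * non_terminal * last_gae
    seg["tdlamret"] = seg["adv"] + seg["vpred"]         # TD(lambda) value targets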