Skip to content

Commit

Permalink
Merge pull request #1 from PavelCz/daniel_autoreset_tweaks
Browse files Browse the repository at this point in the history
Discard terminal obs by default, set reset reward
  • Loading branch information
PavelCz authored Jan 18, 2023
2 parents 3cad441 + 474ff9b commit f6ffcf5
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/seals/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,30 @@ class AutoResetWrapper(gym.Wrapper):
"""Hides done=True and auto-resets at the end of each episode.
Depending on the flag 'discard_terminal_observation', either discards the terminal
observation or pads with an additional 'reset transition'. The latter is the default
observation or pads with an additional 'reset transition'. The former is the default
behavior.
In the latter case, the action taken during the 'reset transition' will not have an
effect, the reward will always be 0.0, and info an empty dictionary.
effect, the reward will be constant (set by the wrapper argument `reset_reward`,
which has default value 0.0), and info an empty dictionary.
"""

def __init__(self, env, discard_terminal_observation=False):
def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0):
"""Builds the wrapper.
Args:
env: The environment to wrap.
discard_terminal_observation: Defaults to False. If True, the terminal
discard_terminal_observation: Defaults to True. If True, the terminal
observation is discarded and the environment is reset immediately. The
returned observation will then be the start of the next episode. The
overridden observation is stored in `info["terminal_observation"]`.
If False, the terminal observation is returned and the environment is
reset in the next step.
reset_reward: The reward to return for the reset transition. Defaults to
0.0.
"""
super().__init__(env)
self.discard_terminal_observation = discard_terminal_observation
self.reset_reward = reset_reward
self.previous_done = False # Whether the previous step returned done=True.

def step(self, action):
Expand Down Expand Up @@ -64,7 +68,7 @@ def _step_pad(self, action):
if self.previous_done:
self.previous_done = False
# This transition will only reset the environment, the action is ignored.
return self.env.reset(), 0.0, False, {}
return self.env.reset(), self.reset_reward, False, {}

obs, rew, done, info = self.env.step(action)
if done:
Expand Down

0 comments on commit f6ffcf5

Please sign in to comment.