Skip to content

HerReplayStorage

source

HerReplayStorage(
   observation_space: gym.Space, action_space: gym.Space, device: str = 'cpu',
   storage_size: int = 1000000, num_envs: int = 1, batch_size: int = 1024,
   goal_selection_strategy: str = 'future', num_goals: int = 4,
   reward_fn: Callable = lambdax: x, copy_info_dict: bool = False
)


Hindsight experience replay (HER) storage for off-policy algorithms. Based on: https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/her/her_replay_buffer.py

Args

  • observation_space (gym.Space) : Observation space.
  • action_space (gym.Space) : Action space.
  • device (str) : Device to convert the data.
  • storage_size (int) : The capacity of the storage.
  • num_envs (int) : The number of parallel environments.
  • batch_size (int) : Batch size of samples.
  • goal_selection_strategy (str) : A goal selection strategy of ["future", "final", "episode"].
  • num_goals (int) : The number of goals to sample.
  • reward_fn (Callable) : Function to compute new rewards based on state and goal, whose definition is same as https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/envs/bit_flipping_env.py#L190 copy_info_dict (bool) whether to copy the info dictionary and pass it to compute_reward() method.

Returns

Dict replay storage.

Methods:

.add

source

.add(
   observations: Dict[str, th.Tensor], actions: th.Tensor, rewards: th.Tensor,
   terminateds: th.Tensor, truncateds: th.Tensor, infos: Dict[str, Any],
   next_observations: Dict[str, th.Tensor]
)


Add sampled transitions into storage.

Args

  • observations (Dict[str, th.Tensor]) : Observations.
  • actions (th.Tensor) : Actions.
  • rewards (th.Tensor) : Rewards.
  • terminateds (th.Tensor) : Termination flag.
  • truncateds (th.Tensor) : Truncation flag.
  • infos (Dict[str, Any]) : Additional information.
  • next_observations (Dict[str, th.Tensor]) : Next observations.

Returns

None.

.sample

source

.sample()


Sample from the storage.

.update

source

.update(
   *args, **kwargs
)


Update the storage if necessary.