GIRM
GIRM(
observation_space: gym.Space, action_space: gym.Space, device: str = 'cpu',
beta: float = 0.05, kappa: float = 2.5e-05, latent_dim: int = 128, lr: float = 0.001,
batch_size: int = 64, lambd: float = 0.5, lambd_recon: float = 1.0,
lambd_action: float = 1.0, kld_loss_beta: float = 1.0
)
Intrinsic Reward Driven Imitation Learning via Generative Model (GIRM). See paper: http://proceedings.mlr.press/v119/yu20d/yu20d.pdf
Args
- observation_space (Space) : The observation space of environment.
- action_space (Space) : The action space of environment.
- device (str) : Device (cpu, cuda, ...) on which the code should be run.
- beta (float) : The initial weighting coefficient of the intrinsic rewards.
- kappa (float) : The decay rate of the intrinsic reward weighting coefficient beta.
- latent_dim (int) : The dimension of encoding vectors.
- lr (float) : The learning rate.
- batch_size (int) : The batch size for update.
- lambd (float) : The weighting coefficient for combining actions.
- lambd_recon (float) : The weighting coefficient of the reconstruction loss.
- lambd_action (float) : The weighting coefficient of the action loss.
- kld_loss_beta (float) : The weighting coefficient of the KL-divergence loss.
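beta is described as the initial weighting coefficient and kappa as its decay rate, but the schedule itself is not stated here. The sketch below assumes the exponential form beta_t = beta * (1 - kappa) ** step, which is a common choice for intrinsic reward modules. The helper name intrinsic_weight is hypothetical and is not part of the GIRM API.

```python
def intrinsic_weight(beta: float, kappa: float, step: int) -> float:
    """Weight applied to intrinsic rewards at a given global step.

    Assumption: exponential decay, beta_t = beta * (1 - kappa) ** step.
    Hypothetical helper for illustration; not part of the GIRM API.
    """
    if not 0.0 <= kappa < 1.0:
        raise ValueError("kappa must lie in [0, 1)")
    return beta * (1.0 - kappa) ** step


# Weight at a few global steps using the constructor defaults.
for step in (0, 10_000, 100_000):
    print(step, intrinsic_weight(0.05, 2.5e-05, step))
```

Under this assumption the default kappa = 2.5e-05 shrinks the intrinsic reward weight to roughly 78% of its initial value after 10,000 steps.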
Returns
Instance of GIRM.
Methods:
.get_vae_loss
Compute the VAE loss.
Args
- recon_x (th.Tensor) : Reconstructed x.
- x (th.Tensor) : Input x.
- mean (th.Tensor) : Sample mean.
- logvar (th.Tensor) : Log of the sample variance.
Returns
Loss values.
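The arguments (recon_x, x, mean, logvar) match the usual VAE objective: a reconstruction term plus the KL divergence between the approximate posterior N(mean, exp(logvar)) and a standard normal prior. The following NumPy stand-in assumes mean squared error for the reconstruction term and the closed-form Gaussian KL. The real method operates on th.Tensor, and its exact reduction and return format are not specified here, so treat this as a sketch. The function name vae_loss is hypothetical.

```python
import numpy as np


def vae_loss(recon_x: np.ndarray, x: np.ndarray,
             mean: np.ndarray, logvar: np.ndarray):
    """Hypothetical NumPy stand-in for get_vae_loss.

    Assumptions: reconstruction error is the mean squared error, and the
    divergence term is KL(N(mean, exp(logvar)) || N(0, I)).
    Returns (reconstruction_loss, kl_loss) as two scalars.
    """
    # Mean squared reconstruction error over all elements.
    recon_loss = np.mean((recon_x - x) ** 2)
    # Closed-form Gaussian KL, summed over latent dims, averaged over the batch.
    kld = -0.5 * np.sum(1.0 + logvar - mean ** 2 - np.exp(logvar), axis=-1)
    return recon_loss, np.mean(kld)
```

Returning the two terms separately leaves room for a caller to weight the KL term, for example by kld_loss_beta, before summing.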
.compute_irs
Compute the intrinsic rewards for current samples.
Args
- samples (Dict) : The collected samples. A Python dict like
{obs: (n_steps, n_envs, *obs_shape), actions: (n_steps, n_envs, *action_shape), rewards: (n_steps, n_envs), next_obs: (n_steps, n_envs, *obs_shape)}.
- step (int) : The global training step.
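A samples dict with the layout listed above can be assembled as follows. The sizes, the flat observation shape, and the random contents are illustrative assumptions. NumPy arrays are used only to make the shapes concrete; the module itself expects th.Tensor values, which can be obtained with torch.as_tensor.

```python
import numpy as np

# Illustrative sizes only; real values come from the rollout configuration.
n_steps, n_envs = 128, 8
obs_shape = (4,)      # assumed flat vector observation
action_shape = (1,)   # assumed single discrete action index per step

rng = np.random.default_rng(0)
samples = {
    "obs": rng.normal(size=(n_steps, n_envs, *obs_shape)).astype(np.float32),
    "actions": rng.integers(0, 2, size=(n_steps, n_envs, *action_shape)),
    "rewards": np.zeros((n_steps, n_envs), dtype=np.float32),
    "next_obs": rng.normal(size=(n_steps, n_envs, *obs_shape)).astype(np.float32),
}
```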
Returns
The intrinsic rewards.
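The returned intrinsic rewards have one value per (step, environment) pair. The exact formula used by this module is not given here. The sketch below assumes a curiosity-style reward: the mean squared error between the generative model's predicted next observation and the true one, averaged over observation dimensions and scaled by the decayed weight beta * (1 - kappa) ** step. The function name reconstruction_irs and the pred_next_obs input are hypothetical.

```python
import numpy as np


def reconstruction_irs(pred_next_obs: np.ndarray, next_obs: np.ndarray,
                       beta: float, kappa: float, step: int) -> np.ndarray:
    """Hypothetical sketch of an intrinsic reward from prediction error.

    pred_next_obs, next_obs: arrays of shape (n_steps, n_envs, *obs_shape).
    Returns an array of shape (n_steps, n_envs).
    """
    # Average the squared error over every observation dimension.
    obs_axes = tuple(range(2, next_obs.ndim))
    error = np.mean((pred_next_obs - next_obs) ** 2, axis=obs_axes)
    # Assumed exponential decay of the intrinsic reward weight.
    beta_t = beta * (1.0 - kappa) ** step
    return beta_t * error
```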
.add
Add new samples to the intrinsic reward module.
.update
Update the intrinsic reward module if necessary.
Args
- samples (Dict) : The collected samples. A Python dict like
{obs: (n_steps, n_envs, *obs_shape), actions: (n_steps, n_envs, *action_shape), rewards: (n_steps, n_envs), next_obs: (n_steps, n_envs, *obs_shape)}.
Returns
None
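An update step plausibly merges the n_steps and n_envs axes, draws shuffled minibatches of batch_size transitions, and minimizes a weighted sum of the reconstruction, KL-divergence, and action losses. Both helpers below are hypothetical NumPy sketches; in particular, how this module actually combines lambd_recon, lambd_action, and kld_loss_beta is an assumption, not documented behavior.

```python
import numpy as np


def iterate_minibatches(samples, batch_size, rng):
    """Hypothetical helper: flatten (n_steps, n_envs, ...) arrays into
    (n_steps * n_envs, ...) and yield shuffled minibatches."""
    n_steps, n_envs = samples["obs"].shape[:2]
    n = n_steps * n_envs
    # Merge the time and environment axes so each transition is one row.
    flat = {k: v.reshape(n, *v.shape[2:]) for k, v in samples.items()}
    perm = rng.permutation(n)
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]
        yield {k: v[idx] for k, v in flat.items()}


def total_loss(recon_loss, kld_loss, action_loss,
               lambd_recon=1.0, lambd_action=1.0, kld_loss_beta=1.0):
    """Assumed weighted sum of the three loss terms named in the constructor."""
    return (lambd_recon * recon_loss
            + kld_loss_beta * kld_loss
            + lambd_action * action_loss)
```

The last minibatch may be smaller than batch_size when n_steps * n_envs is not a multiple of it; dropping or keeping that remainder is a design choice the sketch leaves to the caller.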