OnPolicySharedActorCritic
OnPolicySharedActorCritic(
observation_space: gym.Space, action_space: gym.Space, feature_dim: int,
hidden_dim: int = 512, opt_class: Type[th.optim.Optimizer] = th.optim.Adam,
opt_kwargs: Optional[Dict[str, Any]] = None, aux_critic: bool = False,
init_fn: str = 'orthogonal'
)
Actor-Critic network for on-policy algorithms like PPO and A2C.
Args
- observation_space (gym.Space) : Observation space.
- action_space (gym.Space) : Action space.
- feature_dim (int) : Number of input features accepted by the network.
- hidden_dim (int) : Number of units per hidden layer.
- opt_class (Type[th.optim.Optimizer]) : Optimizer class.
- opt_kwargs (Optional[Dict[str, Any]]) : Optimizer keyword arguments.
- aux_critic (bool) : Whether to use an auxiliary critic, as needed by the PPG agent.
- init_fn (str) : Parameter initialization method.
Returns
Actor-Critic network instance.
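The arguments above can be pictured with a minimal PyTorch sketch. It is not the library's implementation: the class name TinySharedActorCritic, the single-layer trunk, the integer action_dim, and the learning rate are all assumptions made for illustration, and "shared" is assumed to mean that the actor and critic heads consume the same features.

```python
import torch as th
from torch import nn


class TinySharedActorCritic(nn.Module):
    """Hypothetical stand-in, not the library class: one shared trunk, separate heads."""

    def __init__(self, feature_dim: int, action_dim: int, hidden_dim: int = 512,
                 aux_critic: bool = False) -> None:
        super().__init__()
        # Trunk shared by the actor and the critic.
        self.trunk = nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.ReLU())
        self.actor = nn.Linear(hidden_dim, action_dim)  # action logits
        self.critic = nn.Linear(hidden_dim, 1)          # state value
        # Optional extra value head, as a PPG-style auxiliary update would use.
        self.aux_critic = nn.Linear(hidden_dim, 1) if aux_critic else None
        # Orthogonal initialization, mirroring init_fn='orthogonal'.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, features: th.Tensor):
        hidden = self.trunk(features)
        return self.actor(hidden), self.critic(hidden).squeeze(-1)


net = TinySharedActorCritic(feature_dim=8, action_dim=3, hidden_dim=16, aux_critic=True)
# opt_class and opt_kwargs combine like this; lr=3e-4 is an assumed value.
optimizer = th.optim.Adam(net.parameters(), lr=3e-4)
logits, values = net(th.zeros(4, 8))  # batch of 4 feature vectors
```

A batch of 4 feature vectors yields logits of shape (4, 3) and values of shape (4,).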
Methods:
.describe
Describe the policy.
.freeze
Freeze all the elements like encoder and dist.
Args
- encoder (nn.Module) : Encoder network.
- dist (Distribution) : Distribution class.
Returns
None.
.forward
Get actions and estimated values for observations.
Args
- obs (th.Tensor) : Observations.
- training (bool) : Whether the policy is in training mode (True or False).
Returns
Sampled actions, estimated values, and log-probabilities of the sampled actions when training is True; otherwise, only deterministic actions.
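The two return modes can be illustrated with a dependency-free sketch for a discrete action space. The helper act, its rng parameter, and the choice of argmax as the deterministic action are assumptions for illustration; the library's own forward operates on tensors and may differ in detail.

```python
import math
import random


def act(logits, value, training=True, rng=random):
    """Hypothetical helper: sample when training, take the argmax otherwise."""
    # Numerically stable softmax over the action logits.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    if not training:
        # Deterministic action: the most probable one.
        return max(range(len(probs)), key=probs.__getitem__)
    # Stochastic action plus the quantities an on-policy update needs.
    action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return action, value, math.log(probs[action])


greedy_action = act([0.1, 2.0, -1.0], value=0.5, training=False)
action, value, log_prob = act([0.1, 2.0, -1.0], value=0.5, rng=random.Random(0))
```

Passing a seeded random.Random keeps the sampled branch reproducible in tests.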
.get_value
Get estimated values for observations.
Args
- obs (th.Tensor) : Observations.
Returns
Estimated values.
.evaluate_actions
Evaluate actions according to the current policy given the observations.
Args
- obs (th.Tensor) : Sampled observations.
- actions (th.Tensor) : Sampled actions.
Returns
Estimated values, log-probabilities of the given actions, and entropy of the action distribution.
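For a discrete policy, the log-probability and entropy terms can be computed from logits as in the sketch below. The function names log_softmax and evaluate are hypothetical, a categorical distribution is assumed, and the value estimate is omitted because it comes from a separate critic head.

```python
import math


def log_softmax(logits):
    """Log-probabilities of a categorical distribution, computed stably."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_norm for x in logits]


def evaluate(logits, action):
    """Return (log-probability of `action`, entropy of the distribution)."""
    log_probs = log_softmax(logits)
    entropy = -sum(math.exp(lp) * lp for lp in log_probs)
    return log_probs[action], entropy


# Four equally likely actions: log-probability is -ln(4), entropy is ln(4).
log_prob, entropy = evaluate([0.0, 0.0, 0.0, 0.0], action=2)
```

In a PPO-style update, the log-probability would feed the policy ratio and the entropy would feed the exploration bonus.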
.get_policy_outputs
Get policy outputs for training.
Args
- obs (th.Tensor) : Observations.
Returns
Policy outputs, such as unnormalized probabilities for Discrete tasks.
.get_dist_and_aux_value
Get the action distribution, estimated values, and auxiliary estimated values for the auxiliary-phase update.
Args
- obs (th.Tensor) : Sampled observations.
Returns
Sample distribution, estimated values, auxiliary estimated values.
.save
Save models.
Args
- path (Path) : Save path.
- pretraining (bool) : Pre-training mode.
- global_step (int) : Global training step.
Returns
None.