Pre-training with Intrinsic Rewards
Pre-training
Currently, RLLTE only supports online pre-training via intrinsic reward. To turn on the pre-training mode, it suffices to write a `train.py` like:
train.py
```python
from rllte.agent import PPO
from rllte.env import make_atari_env
from rllte.xplore.reward import RE3

if __name__ == "__main__":
    # env setup
    device = "cuda:0"
    env = make_atari_env(device=device)
    # create agent and turn on pre-training mode
    agent = PPO(env=env,
                device=device,
                tag="ppo_atari",
                pretraining=True)
    # create intrinsic reward
    re3 = RE3(envs=env, device=device)
    # set the reward module
    agent.set(reward=re3)
    # start training
    agent.train(num_train_steps=5000)
```
and you'll see the pre-training mode is on:
```
[08/04/2023 05:05:54 PM] - [INFO.] - Invoking RLLTE Engine...
[08/04/2023 05:05:54 PM] - [INFO.] - ================================================================================
[08/04/2023 05:05:54 PM] - [INFO.] - Tag               : ppo_atari
[08/04/2023 05:05:54 PM] - [INFO.] - Device            : NVIDIA GeForce RTX 3090
[08/04/2023 05:05:54 PM] - [DEBUG] - Agent             : PPO
[08/04/2023 05:05:54 PM] - [DEBUG] - Encoder           : MnihCnnEncoder
[08/04/2023 05:05:54 PM] - [DEBUG] - Policy            : OnPolicySharedActorCritic
[08/04/2023 05:05:54 PM] - [DEBUG] - Storage           : VanillaRolloutStorage
[08/04/2023 05:05:54 PM] - [DEBUG] - Distribution      : Categorical
[08/04/2023 05:05:54 PM] - [DEBUG] - Augmentation      : False
[08/04/2023 05:05:54 PM] - [DEBUG] - Intrinsic Reward  : True, RE3
[08/04/2023 05:05:54 PM] - [INFO.] - Pre-training Mode : On
[08/04/2023 05:05:54 PM] - [DEBUG] - ================================================================================
...
```
Tip
When the pre-training mode is on, a reward module must be specified! For all supported reward modules, see the API Documentation.
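Any supported reward module can be swapped in via `.set()`. Below is a minimal sketch using RND instead of RE3, assuming RND is available in `rllte.xplore.reward` with the same constructor signature as RE3 above; check the API Documentation for the exact module list:

```python
from rllte.agent import PPO
from rllte.env import make_atari_env
# assumption: RND follows the same (envs, device) interface as RE3
from rllte.xplore.reward import RND

if __name__ == "__main__":
    device = "cuda:0"
    env = make_atari_env(device=device)
    agent = PPO(env=env, device=device, tag="ppo_atari", pretraining=True)
    # any supported intrinsic reward module can be passed to .set()
    agent.set(reward=RND(envs=env, device=device))
    agent.train(num_train_steps=5000)
```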
Fine-tuning
Once the pre-training is finished, you can find the model parameters in the `pretrained` subfolder of the working directory. To load them, just turn off the pre-training mode and pass the checkpoint path to the `init_model_path` argument of `.train()`:
train.py
```python
from rllte.agent import PPO
from rllte.env import make_atari_env

if __name__ == "__main__":
    # env setup
    device = "cuda:0"
    env = make_atari_env(device=device)
    eval_env = make_atari_env(device=device)
    # create agent and turn off pre-training mode
    agent = PPO(env=env,
                eval_env=eval_env,
                device=device,
                tag="ppo_atari",
                pretraining=False)
    # start training with the pre-trained parameters
    agent.train(num_train_steps=5000,
                init_model_path="/export/yuanmingqi/code/rllte/logs/ppo_atari/2023-06-05-02-42-12/pretrained/pretrained.pth")
```
and you'll see the pre-trained model parameters are loaded:
```
[08/04/2023 05:07:52 PM] - [INFO.] - Invoking RLLTE Engine...
[08/04/2023 05:07:52 PM] - [INFO.] - ================================================================================
[08/04/2023 05:07:52 PM] - [INFO.] - Tag               : ppo_atari
[08/04/2023 05:07:52 PM] - [INFO.] - Device            : NVIDIA GeForce RTX 3090
[08/04/2023 05:07:53 PM] - [DEBUG] - Agent             : PPO
[08/04/2023 05:07:53 PM] - [DEBUG] - Encoder           : MnihCnnEncoder
[08/04/2023 05:07:53 PM] - [DEBUG] - Policy            : OnPolicySharedActorCritic
[08/04/2023 05:07:53 PM] - [DEBUG] - Storage           : VanillaRolloutStorage
[08/04/2023 05:07:53 PM] - [DEBUG] - Distribution      : Categorical
[08/04/2023 05:07:53 PM] - [DEBUG] - Augmentation      : False
[08/04/2023 05:07:53 PM] - [DEBUG] - Intrinsic Reward  : False
[08/04/2023 05:07:53 PM] - [DEBUG] - ================================================================================
[08/04/2023 05:07:53 PM] - [INFO.] - Loading Initial Parameters from ./logs/ppo_atari/...
...
```
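If you want to verify what was saved before fine-tuning, you can open the checkpoint directly with PyTorch. A minimal sketch, assuming `pretrained.pth` is a standard PyTorch state dict (the actual layout may differ, and the path below is a placeholder for your own run's directory):

```python
import torch

# placeholder path: substitute your run's pretrained/pretrained.pth
ckpt_path = "./logs/ppo_atari/<timestamp>/pretrained/pretrained.pth"

# load on CPU so no GPU is required just for inspection
state_dict = torch.load(ckpt_path, map_location="cpu")
for name, tensor in state_dict.items():
    print(f"{name}: {tuple(tensor.shape)}")
```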