
Benchmarks

rllte-hub provides a large number of reusable datasets and models of representative RL benchmarks. All the files are hosted on the Hugging Face platform and can be viewed there.

| Module | Remark |
|--------|--------|
| rllte.hub.datasets | Provides test scores and learning curves of various RL algorithms on different benchmarks. |
| rllte.hub.models | Provides trained models of various RL algorithms on different benchmarks. |
| rllte.hub.applications | Provides fast APIs for training RL agents on recognized benchmarks. |
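
The examples later on this page use the Procgen classes from each of these modules. As a minimal orientation sketch (the aliases are only for illustration), the three entry points can be imported like this:

from rllte.hub.datasets import Procgen as ProcgenData        # test scores and learning curves
from rllte.hub.models import Procgen as ProcgenModels        # trained agents
from rllte.hub.applications import Procgen as ProcgenApp     # fast training API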

Support list

| Benchmark | Algorithm | Remark | Reference |
|-----------|-----------|--------|-----------|
| Atari Games | PPO | 50M, πŸ’―πŸ“ŠπŸ€– | Paper |
| Atari Games | SAC | 1M, πŸ’―πŸ“ŠπŸ€– | Paper |
| DeepMind Control (Pixel) | DrQ-v2 | 1M, πŸ’―πŸ“ŠπŸ€– | Paper |
| DeepMind Control (State) | SAC | 10M, πŸ’―πŸ“ŠπŸ€– | |
| DeepMind Control (State) | DDPG | 10M, πŸ’―πŸ“ŠπŸ€– | |
| Procgen Games | PPO | 25M, πŸ’―πŸ“ŠπŸ€– | Paper |
| Procgen Games | DAAC | 25M, πŸ’―πŸ“ŠπŸ€– | Paper |
| MiniGrid Games | | 🐌 | |

Tip

  • 🐌: Incoming.
  • (25M): 25 million training steps.
  • πŸ’―Scores: Available final scores.
  • πŸ“ŠCurves: Available training curves.
  • πŸ€–Models: Available trained models.

Datasets

.load_scores

Suppose we want to evaluate algorithm performance on the Procgen benchmark. Here is an example:

example.py
from rllte.hub.datasets import Procgen

procgen = Procgen()
procgen_scores = procgen.load_scores()
print(procgen_scores['ppo'].shape)

# Output:
# (10, 16)
For each algorithm, this returns an NdArray of shape (10, 16), where scores[n][m] is the score of run n on task m.
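
For instance, the runs can be aggregated into per-task means with NumPy. This is a minimal sketch; which column corresponds to which Procgen task is an assumption not documented here:

import numpy as np
from rllte.hub.datasets import Procgen

procgen = Procgen()
ppo_scores = procgen.load_scores()['ppo']    # shape: (10 runs, 16 tasks)

# Average over the 10 runs to get one score per task,
# then summarize across all tasks.
per_task_mean = ppo_scores.mean(axis=0)      # shape: (16,)
print("per-task means:", per_task_mean)
print("overall mean:", per_task_mean.mean())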

.load_curves

Meanwhile, .load_curves will return the learning curves as a Python dict like:

curves = {
    "ppo": {
        "train": {"bigfish": np.ndarray(shape=(Number of seeds, Number of points)), ...}, 
        "eval": {"bigfish": np.ndarray(shape=(Number of seeds, Number of points)), ...}, 
    },
    "daac": {
        "train": {"bigfish": np.ndarray(shape=(Number of seeds, Number of points)), ...}, 
        "eval": {"bigfish": np.ndarray(shape=(Number of seeds, Number of points)), ...}, 
    },
    ...
}
A code example for loading curves of the Procgen benchmark:
example.py
from rllte.hub.datasets import Procgen

if __name__ == "__main__":
    # load data
    procgen = Procgen()
    curves = procgen.load_curves()

    print(curves['ppo']['train']['bigfish'].shape)
    print(curves['ppo']['eval']['bigfish'].shape)

# Output:
# (10, 1525)
# (10, 153)
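
A common next step is to average the curves over the seed axis and plot a mean curve with a standard-deviation band. The sketch below assumes matplotlib is installed and that the x-axis is simply the index of each logged point:

import numpy as np
import matplotlib.pyplot as plt
from rllte.hub.datasets import Procgen

if __name__ == "__main__":
    curves = Procgen().load_curves()
    train_curve = curves['ppo']['train']['bigfish']   # (number of seeds, number of points)

    # average over seeds and compute the spread
    mean = train_curve.mean(axis=0)
    std = train_curve.std(axis=0)
    x = np.arange(mean.shape[0])

    # plot the seed-averaged curve with a one-standard-deviation band
    plt.plot(x, mean, label="PPO on bigfish (train)")
    plt.fill_between(x, mean - std, mean + std, alpha=0.3)
    plt.xlabel("logged point index")
    plt.ylabel("episode return")
    plt.legend()
    plt.savefig("ppo_bigfish_train.png")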

Models

Suppose we want to load a PPO agent trained on the Procgen benchmark. Here is an example:

example.py
from rllte.hub.models import Procgen
from rllte.env import make_procgen_env
import torch as th
import numpy as np

if __name__ == "__main__":
    # env setup
    device = "cuda:0"
    env_id = "starpilot"
    seed = 1
    # download the model
    procgen = Procgen()
    agent = procgen.load_models(agent="ppo",
                                env_id=env_id,
                                seed=seed,
                                device=device)
    # create env
    env = make_procgen_env(env_id=env_id, device=device, num_envs=1, seed=seed)
    # evaluate the model
    obs, infos = env.reset(seed=seed)
    # run the model
    episode_rewards, episode_steps = list(), list()
    while len(episode_rewards) < 10:
        # the exported model outputs logits of the action distribution
        action = th.softmax(agent(obs), dim=1).argmax(dim=1)
        obs, rewards, terminateds, truncateds, infos = env.step(action)

        if "episode" in infos:
            indices = np.nonzero(infos["episode"]["l"])
            episode_rewards.extend(infos["episode"]["r"][indices].tolist())
            episode_steps.extend(infos["episode"]["l"][indices].tolist())

    print(f"mean episode reward: {np.mean(episode_rewards)}")
    print(f"mean episode length: {np.mean(episode_steps)}")

# Output:
# mean episode reward: 30.0
# mean episode length: 296.1

Applications

Suppose we want to train a PPO agent on the Procgen benchmark. It suffices to write a train.py like:

train.py
from rllte.hub.applications import Procgen

app = Procgen(agent="PPO", env_id="coinrun", seed=1, device="cuda")
app.train(num_train_steps=2.5e+7)
All the results in rllte.hub.datasets and rllte.hub.models were produced by rllte.hub.applications, and all the hyper-parameters can be found in the reference.
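
As a hedged sketch (re-using only the constructor arguments shown above), several tasks can be trained back-to-back by looping over environment ids:

from rllte.hub.applications import Procgen

if __name__ == "__main__":
    # Train PPO on several Procgen tasks in sequence; the task list and
    # step budget are illustrative assumptions, not recommended settings.
    for env_id in ["coinrun", "starpilot", "bigfish"]:
        app = Procgen(agent="PPO", env_id=env_id, seed=1, device="cuda")
        app.train(num_train_steps=2.5e+7)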