Background

sbx is a JAX implementation of stable-baselines3 (sb3). As claimed here, it can accelerate RL training through JIT compilation compared to sb3 with PyTorch.

Question

I tested sbx PPO and sb3 PPO on the Gym environment Hopper-v5. The results show that neither sbx on CPU nor sbx on CUDA is significantly faster than sb3 on CPU. Only a 1.06x speedup is achieved according to the FPS (frames per second) graph.

Mean episodic return learning curve (green: sbx on CPU and CUDA, red: sb3 on CPU):

[Figure: eps_rew_mean]

FPS (frames per second) graph:

[Figure: FPS graph]

My questions are:

  1. Is this result expected? If it is, what is the benefit of using sbx PPO? If not, what did I implement incorrectly?
  2. Are the performance improvements claimed here still valid in 2025?

Additional context

Companion GitHub issue.

The hyperparameters are kept consistent across all runs, as listed below:

batch_size:64
device:"cuda"
env_id:"hopper"
learning_rate:0.0001
n_envs:16
n_epochs:10
n_steps:128
seed:0
total_timesteps:1,000,000
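
For reference, the amount of work per PPO iteration implied by these settings can be worked out with a few lines of arithmetic. This is a minimal sketch, assuming the usual sb3-style PPO loop in which each rollout collects n_envs * n_steps transitions and each epoch sweeps that buffer once in minibatches of batch_size. The variable names are illustrative and are not taken from the linked script.

```python
import math

# Hyperparameters copied from the list above.
n_envs = 16
n_steps = 128
batch_size = 64
n_epochs = 10
total_timesteps = 1_000_000

# Transitions collected before each policy update phase.
rollout_size = n_envs * n_steps
# Minibatches needed to sweep the rollout buffer once.
minibatches_per_epoch = rollout_size // batch_size
# Gradient updates performed after every rollout.
updates_per_rollout = minibatches_per_epoch * n_epochs
# Assumption: training runs whole rollouts until total_timesteps is reached.
n_rollouts = math.ceil(total_timesteps / rollout_size)

print(rollout_size, minibatches_per_epoch, updates_per_rollout, n_rollouts)
```

With these values, each rollout of 2048 transitions is followed by 320 gradient updates on minibatches of 64 samples, so the reported FPS reflects both environment stepping and a large number of small updates.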

System hardware:

Item               Value
CPU count          8
Logical CPU count  16
GPU count          1
GPU type           NVIDIA GeForce RTX 3080 Ti

Here is the wandb report: https://wandb.ai/zhixin/jaxrl/reports/performance-compare-sbx-vs-sb3-ppo--VmlldzoxMjU0ODY2NQ

The code can be found here.

Run script:

# sbx gym cpu
python -O sb3_learn.py --sbx --seed 0 --n_envs 16 --n_steps 128 --learning_rate 0.0001 --batch-size 64 --total_timesteps 1000000

# sbx gym cuda
python -O sb3_learn.py --sbx --device cuda --seed 0 --n_envs 16 --n_steps 128 --learning_rate 0.0001 --batch-size 64 --total_timesteps 1000000

# sb3 gym cpu
python -O sb3_learn.py --seed 0 --n_envs 16 --n_steps 128 --learning_rate 0.0001 --batch-size 64 --total_timesteps 1000000

zhixin

1 Answer

After some research, I found that the cause was that I had tested sbx on WSL2 rather than on a native Linux system. The official JAX tutorial says that JAX on WSL is experimental.

After switching to native Linux (Ubuntu 18.04) on the same hardware, sbx shows a significant performance improvement over sb3. Below is the FPS of PPO on the Hopper environment with sbx on CUDA, sbx on CPU, and sb3 on CPU.

[Figure: FPS of PPO on Hopper with sbx on CUDA, sbx on CPU, and sb3 on CPU]

So the lesson is: don't expect performance gains from JAX-based libraries when running them on WSL.
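
Before benchmarking, it may help to confirm what environment the run is actually in. The following is a minimal sketch, not part of my training script: it assumes the common heuristic that WSL kernels report a release string containing "microsoft", and the helper names (is_wsl_release, running_under_wsl, describe_jax_backend) are hypothetical. The JAX part only uses jax.default_backend() and jax.devices(), and it is skipped if JAX is not installed.

```python
import platform


def is_wsl_release(release: str) -> bool:
    """Heuristic (assumption): WSL kernel release strings contain 'microsoft'."""
    return "microsoft" in release.lower()


def running_under_wsl() -> bool:
    """Apply the heuristic to the kernel release of the current machine."""
    return platform.system() == "Linux" and is_wsl_release(platform.uname().release)


def describe_jax_backend() -> str:
    """Report the backend JAX selected, or note that JAX is not installed."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    return f"backend={jax.default_backend()}, devices={jax.devices()}"
```

Calling running_under_wsl() and describe_jax_backend() at the start of a benchmark run and logging the results makes it easier to tell afterwards whether a slow run came from WSL or from JAX silently falling back to the CPU backend.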

zhixin