TD3-JAX
A JAX Implementation of the Twin Delayed DDPG Algorithm
Requirements
Next to each requirement, I have listed the version installed on my system, for reproducibility.
JAX - jax 0.1.59, jaxlib 0.1.39
Haiku - dm-haiku 0.0.1a0, dm-sonnet 2.0.0b0
RLax - rlax 0.0.0
Gym - gym 0.15.4
MuJoCo - mujoco-py 2.0.2.9
Command line arguments
To run an environment over seeds 0 to 9 (shown here for Hopper-v2):
```bash
for seed in {0..9}; do python main.py --env Hopper-v2 --seed $seed; done
```
The default hyperparameters are not ideal for every domain. Based on limited testing and intuition, the following settings worked better than the defaults:
| Environment | Command line addition |
|---|---|
| Swimmer-v2 | --discount 0.995 |
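One rough way to read the Swimmer-v2 setting is through the "effective horizon" of a discount factor: since the sum of gamma**t over all t equals 1 / (1 - gamma), rewards much further away than that many steps contribute little to the optimized return. The sketch below is a back-of-the-envelope heuristic, not something measured in this repository; `effective_horizon` is a hypothetical helper, and comparing against 0.99 assumes the default discount matches the value used in the original TD3 paper, since the default is not stated here.

```python
def effective_horizon(gamma: float) -> float:
    """Approximate horizon implied by a discount factor: sum_t gamma**t = 1 / (1 - gamma)."""
    if not 0.0 <= gamma < 1.0:
        raise ValueError("gamma must be in [0, 1)")
    return 1.0 / (1.0 - gamma)


if __name__ == "__main__":
    # 0.99 is an assumed default (the TD3 paper's value); 0.995 is the
    # Swimmer-v2 setting from the table above.
    for gamma in (0.99, 0.995):
        print(f"gamma={gamma}: roughly {effective_horizon(gamma):.0f} steps")
```

Under this heuristic, moving from 0.99 to 0.995 roughly doubles the horizon (about 100 to about 200 steps).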
Results
For each seed, we keep the best policy seen during evaluation and re-evaluate it at the end of training. The table reports the mean ± one standard deviation of this metric across 10 seeds (0 to 9).
| Environment | Best policy per run |
|---|---|
| Hopper-v2 | 3691.5 ± 61.7 |
| Humanoid-v2 | 5194.0 ± 97.1 |
| Walker2d-v2 | 4328.8 ± 1059.0 |
| Ant-v2 | 3505.4 ± 411.7 |
| HalfCheetah-v2 | 10411.0 ± 1323.5 |
| Swimmer-v2 | 314.1 ± 69.2 |
| InvertedPendulum-v2 | 1000.0 ± 0.0 |
| InvertedDoublePendulum-v2 | 9350.6 ± 26.8 |
| Reacher-v2 | -4.0 ± 0.3 |
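The "best policy seen" metric can be sketched in plain Python as follows. This is an illustration under stated assumptions, not code from this repository: `track_best` and `summarize` are hypothetical helpers, "parameters" can be any copyable object, and the sketch uses the population standard deviation (`statistics.pstdev`, the same estimator as NumPy's default `np.std`). The README does not say which estimator produced the table above.

```python
import copy
import statistics
from typing import Callable, Iterable, List, Optional, Tuple, TypeVar

Params = TypeVar("Params")


def track_best(
    checkpoints: Iterable[Params], evaluate: Callable[[Params], float]
) -> Tuple[Optional[Params], float]:
    """Return the checkpoint with the highest evaluation score and that score."""
    best_params: Optional[Params] = None
    best_score = float("-inf")
    for params in checkpoints:
        score = evaluate(params)
        if score > best_score:
            # Deep-copy so later in-place updates to the live parameters
            # cannot overwrite the stored best policy.
            best_params, best_score = copy.deepcopy(params), score
    return best_params, best_score


def summarize(per_seed_scores: List[float]) -> Tuple[float, float]:
    """Mean and population standard deviation across seeds."""
    return statistics.mean(per_seed_scores), statistics.pstdev(per_seed_scores)


def toy_evaluate(params: float) -> float:
    # Hypothetical stand-in for running evaluation episodes with a policy.
    return params


if __name__ == "__main__":
    final_scores = []
    for seed in range(3):
        # Made-up "checkpoints" for illustration, not results from this repo.
        checkpoints = [float(seed), seed + 1.0, seed + 0.5]
        best, _ = track_best(checkpoints, toy_evaluate)
        # Re-evaluate the stored best policy at the end of training.
        final_scores.append(toy_evaluate(best))
    print(summarize(final_scores))
```

In a real run the re-evaluation would use fresh episodes, so the final score can differ from the score that selected the checkpoint.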
The code for reproducing the figures, including per-seed results for each environment, is provided in plot_results.ipynb.



The per-seed analysis suggests that TD3's results could improve dramatically with some hyperparameter tuning. In particular, in some domains the algorithm takes a while to start learning, possibly because of low learning rates, a large experience replay buffer, or an untuned discount factor.
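The discount factor enters TD3 through the critics' bootstrapped target. Below is a minimal NumPy sketch of that target as described in the TD3 paper (clipped double-Q learning plus target policy smoothing). It is not this repository's JAX/Haiku code: the function name `td3_target`, the callables `actor_target`, `critic1_target` and `critic2_target`, and the defaults (discount 0.99, policy noise 0.2, noise clip 0.5, taken from the TD3 paper) are assumptions for illustration. A `jax.numpy` port would mainly replace `rng.normal` with `jax.random.normal` and an explicit PRNG key.

```python
import numpy as np


def td3_target(reward, done, next_obs, actor_target, critic1_target,
               critic2_target, rng, discount=0.99, policy_noise=0.2,
               noise_clip=0.5, max_action=1.0):
    """Clipped double-Q target with target policy smoothing (TD3)."""
    # Target policy smoothing: perturb the target action with clipped
    # Gaussian noise, then clip back into the valid action range.
    next_action = actor_target(next_obs)
    noise = np.clip(rng.normal(0.0, policy_noise, size=next_action.shape),
                    -noise_clip, noise_clip)
    next_action = np.clip(next_action + noise, -max_action, max_action)
    # Clipped double Q-learning: use the smaller of the two target critics.
    target_q = np.minimum(critic1_target(next_obs, next_action),
                          critic2_target(next_obs, next_action))
    # The discount scales the bootstrapped term; terminal transitions
    # (done == 1) do not bootstrap.
    return reward + discount * (1.0 - done) * target_q


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    y = td3_target(
        reward=np.array([1.0, 1.0]),
        done=np.array([0.0, 1.0]),
        next_obs=np.zeros((2, 3)),
        # Toy actor that saturates the action bound after clipping.
        actor_target=lambda s: np.full((s.shape[0], 1), 10.0),
        critic1_target=lambda s, a: a[:, 0] + 1.0,
        critic2_target=lambda s, a: a[:, 0],
        rng=rng,
        discount=0.9,
    )
    print(y)
```

The delayed (less frequent) actor and target-network updates that complete TD3 are not shown here.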