[HN Gopher] 4000x Speedup in Reinforcement Learning with Jax
___________________________________________________________________
 
4000x Speedup in Reinforcement Learning with Jax
 
Author : _hark
Score  : 18 points
Date   : 2023-04-06 21:46 UTC (1 hour ago)
 
web link (chrislu.page)
w3m dump (chrislu.page)
 
| _hark wrote:
| jax.vmap() is all you need?
 
  | schizo89 wrote:
  | Not only the vectorization, but also the plethora of
  | environments written in Jax. Hopefully someone will port MuJoCo
  | to Jax soon
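
A minimal sketch of the jax.vmap pattern being discussed: the toy environment step below is hypothetical (not from the linked post), but it shows how one pure step function is vectorized across thousands of environments and JIT-compiled for the accelerator.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment: state is a scalar position, the action
# nudges it, and the reward favors staying near zero. Written as a pure
# function so JAX transformations (vmap, jit, grad) apply to it.
def env_step(state, action):
    new_state = state + action
    reward = -jnp.abs(new_state)
    return new_state, reward

# jax.vmap maps the single-environment step over a batch dimension;
# jax.jit compiles the batched step into one fused accelerator program.
batched_step = jax.jit(jax.vmap(env_step))

states = jnp.zeros(4096)           # 4096 parallel environments
actions = jnp.full(4096, 0.1)
new_states, rewards = batched_step(states, actions)
```

One call now advances every environment at once, which is where the bulk of the reported speedup over stepping Python environments in a loop comes from.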
 
| sillysaurusx wrote:
| It's a little disingenuous to say that the 4000x speedup is due
| to Jax. I'm a huge Jax fanboy (one of the biggest) but the
| speedup here is thanks to running the simulation environment on a
| GPU. But as much as I love Jax, it's still extraordinarily
| difficult to implement even simple environments purely on a GPU.
| 
| My long-term ambition is to replicate OpenAI's Dota 2
| reinforcement learning work, since it's one of the most impactful
| (or at least most entertaining) uses of RL. It would be more or
| less impossible to translate the game logic into Jax, short of
| transpiling C++ to Jax somehow. Which isn't a bad idea - someone
| should make that.
| 
| It should also be noted that there's a long history of RL being
| done on accelerators. AlphaZero's chess evaluations ran entirely
| on TPUs. Pytorch CUDA graphs also make it easier to implement
| this kind of thing nowadays, since (again, as much as I love Jax)
| some Pytorch constructs are simply easier to use than turning
| everything into a functional programming paradigm.
| 
| All that said, you should really try out Jax. The fact that you
| can calculate gradients of any arbitrary function is just
| amazing, and you have complete control over what's JIT'ed into a
| GPU graph and what's not. It's a wonderful feeling compared to
| using Pytorch's accursed .backward() accumulation scheme.
| 
| Can't wait for a framework that feels closer to pure arbitrary
| Python. Maybe AI can figure out how to do it.
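
The gradient-and-JIT workflow praised above can be sketched in a few lines; the loss function and data here are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Any pure function of arrays can be differentiated directly with
# jax.grad -- no tape, no .backward() accumulation on tensors.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# You choose explicitly what gets compiled: wrapping in jax.jit turns
# the gradient computation into a single fused accelerator program.
grad_fn = jax.jit(jax.grad(loss))   # gradient w.r.t. the first argument, w

w = jnp.array([1.0, -2.0])
x = jnp.array([[0.5, 1.0], [1.5, -0.5]])
y = jnp.array([0.0, 1.0])
g = grad_fn(w, x, y)                # gradient has the same shape as w
```

This is the functional-programming trade-off the comment alludes to: functions must be pure, but in exchange gradients and compilation are explicit, composable transformations.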
 
| schizo89 wrote:
| Neural differential equations are also easier with Jax. sim2real
| may also be easier with a simulator where some of the hard
| computations are replaced with neural approximations
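
A rough sketch of the idea above, with an entirely hypothetical setup: a tiny neural network stands in for an expensive simulator term, the dynamics are rolled out with a fixed-step Euler scheme via jax.lax.scan, and the whole rollout stays differentiable end to end.

```python
import jax
import jax.numpy as jnp

# Hypothetical neural vector field: a one-hidden-layer MLP replacing a
# hard-to-compute dynamics term in the simulator.
def mlp_field(params, state):
    w1, w2 = params
    return w2 @ jnp.tanh(w1 @ state)

# Fixed-step Euler integration of dx/dt = mlp_field(params, x).
def integrate(params, state0, dt=0.01, steps=100):
    def step(state, _):
        state = state + dt * mlp_field(params, state)
        return state, state
    final_state, _ = jax.lax.scan(step, state0, None, length=steps)
    return final_state

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = (jax.random.normal(k1, (8, 2)), jax.random.normal(k2, (2, 8)))
state0 = jnp.array([1.0, 0.0])
final = integrate(params, state0)

# Gradients flow through the entire rollout, e.g. w.r.t. the initial
# state -- this is what makes training through the simulator tractable.
g = jax.grad(lambda s: jnp.sum(integrate(params, s)))(state0)
```

Dedicated libraries exist for this (e.g. Diffrax for JAX), but the sketch shows why the functional style helps: the integrator is just a composition of pure functions, so grad, jit, and vmap all apply to it directly.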
 
___________________________________________________________________
(page generated 2023-04-06 23:00 UTC)