|
| _hark wrote:
| jax.vmap() is all you need?
| schizo89 wrote:
| It's not only the vectorization, but also the plethora of
| environments written in jax. Hopefully someone will port MuJoCo to
| jax soon.
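|
| A minimal sketch of what jax.vmap buys you here, using a made-up
| point-mass step function (not from any actual jax environment
| suite):
|
|     import jax
|     import jax.numpy as jnp
|
|     def step(state, action):
|         # Toy point-mass dynamics: state is (position, velocity).
|         pos, vel = state
|         vel = vel + 0.1 * action
|         pos = pos + 0.1 * vel
|         reward = -jnp.abs(pos)
|         return (pos, vel), reward
|
|     # vmap turns the single-environment step into a batched one,
|     # and jit compiles the batched step with XLA.
|     batched_step = jax.jit(jax.vmap(step))
|
|     states = (jnp.zeros(4096), jnp.zeros(4096))  # 4096 parallel envs
|     actions = jnp.ones(4096)
|     states, rewards = batched_step(states, actions)
|     print(rewards.shape)  # (4096,)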
| sillysaurusx wrote:
| It's a little disingenuous to say that the 4000x speedup is due
| to Jax. I'm a huge Jax fanboy (one of the biggest) but the
| speedup here is thanks to running the simulation environment on a
| GPU. But as much as I love Jax, it's still extraordinarily
| difficult to implement even simple environments purely on a GPU.
|
| My long-term ambition is to replicate OpenAI's Dota 2
| reinforcement learning work, since it's one of the most impactful
| (or at least most entertaining) uses of RL. It would be more or
| less impossible to translate the game logic into Jax, short of
| transpiling C++ to Jax somehow. Which isn't a bad idea - someone
| should make that.
|
| It should also be noted that there's a long history of RL being
| done on accelerators. AlphaZero's chess evaluations ran entirely
| on TPUs. Pytorch CUDA graphs also make it easier to implement
| this kind of thing nowadays, since (again, as much as I love Jax)
| some Pytorch constructs are simply easier to use than rewriting
| everything in a functional programming style.
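|
| For reference, the capture/replay pattern for Pytorch CUDA graphs
| looks roughly like this; the tiny policy network and shapes are
| illustrative assumptions, not anyone's actual RL setup:
|
|     import torch
|
|     policy = torch.nn.Linear(32, 4).cuda()
|     static_obs = torch.randn(1024, 32, device="cuda")
|
|     # Warm up on a side stream before capture, as the Pytorch docs
|     # recommend.
|     s = torch.cuda.Stream()
|     s.wait_stream(torch.cuda.current_stream())
|     with torch.cuda.stream(s):
|         for _ in range(3):
|             policy(static_obs)
|     torch.cuda.current_stream().wait_stream(s)
|
|     # Capture one forward pass, then replay it on new data by
|     # copying into the same static input tensor.
|     graph = torch.cuda.CUDAGraph()
|     with torch.cuda.graph(graph):
|         static_actions = policy(static_obs)
|
|     static_obs.copy_(torch.randn(1024, 32, device="cuda"))
|     graph.replay()  # static_actions now holds the new result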
|
| All that said, you should really try out Jax. The fact that you
| can take gradients of any arbitrary function is just
| amazing, and you have complete control over what's JIT'ed into a
| GPU graph and what's not. It's a wonderful feeling compared to
| using Pytorch's accursed .backward() accumulation scheme.
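|
| A small sketch of both points, with an arbitrary toy loss (not
| anything from the article):
|
|     import jax
|     import jax.numpy as jnp
|
|     def loss(params, x):
|         w, b = params
|         pred = jnp.tanh(x @ w + b)
|         return jnp.mean(pred ** 2)
|
|     grad_fn = jax.grad(loss)         # gradient w.r.t. the first argument
|     fast_grad_fn = jax.jit(grad_fn)  # you decide what gets compiled
|
|     params = (jnp.ones((8, 1)), jnp.zeros(1))
|     x = jnp.ones((32, 8))
|     grads = fast_grad_fn(params, x)  # same pytree structure as params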
|
| Can't wait for a framework that feels closer to pure arbitrary
| Python. Maybe AI can figure out how to do it.
| schizo89 wrote:
| Neural differential equations are also easier with jax, and
| sim2real may be easier with a simulator where some of the hard
| computations are replaced with neural approximations.
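|
| A minimal sketch of that point: you can differentiate straight
| through a hand-rolled integrator. The fixed-step Euler loop and the
| tiny "neural" vector field are illustrative assumptions; real work
| would more likely use a library such as diffrax:
|
|     import jax
|     import jax.numpy as jnp
|
|     def vector_field(params, y):
|         # One-layer stand-in for a learned dynamics model.
|         w, b = params
|         return jnp.tanh(y @ w + b)
|
|     def odeint_euler(params, y0, dt=0.01, steps=100):
|         def body(y, _):
|             return y + dt * vector_field(params, y), None
|         y_final, _ = jax.lax.scan(body, y0, None, length=steps)
|         return y_final
|
|     def sim_loss(params, y0, target):
|         return jnp.sum((odeint_euler(params, y0) - target) ** 2)
|
|     params = (0.1 * jnp.ones((3, 3)), jnp.zeros(3))
|     grads = jax.grad(sim_loss)(params, jnp.ones(3), jnp.zeros(3))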