[HN Gopher] Useful algorithms that are not optimized by Jax, PyT...
___________________________________________________________________
 
Useful algorithms that are not optimized by Jax, PyTorch, or
TensorFlow
 
Author : ChrisRackauckas
Score  : 152 points
Date   : 2021-07-21 11:10 UTC (1 day ago)
 
web link (www.stochasticlifestyle.com)
w3m dump (www.stochasticlifestyle.com)
 
| cjv wrote:
| ...doesn't the JAX example just need the argument set to
| static_argnums and then it will work?
 
  | ChrisRackauckas wrote:
  | static_argnums is really just a way to hand the compiler a few
  | more assumptions so it can attempt to build quasi-static code
  | even when it's using dynamic constructs. In this example that
  | will force it to trace only one of the two branches (whichever
  | one the static argument sends it down). That is going to
  | generate incorrect code for input values which should've traced
  | the other branch (so the real solution, `lax.cond`, is to
  | always trace and always compute both branches, as mentioned in
  | the post). If the computation is actually not quasi-static,
  | there's no good choice for a static argnum. See the factorial
  | example.
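  | 
  | To make that concrete, here's a minimal sketch (a toy example of
  | mine, not the one from the post): a Python `if` on a traced value
  | fails under `jit`, `static_argnums` bakes the value in at trace
  | time so each cached trace only ever contains one branch, and
  | `lax.cond` keeps the value dynamic by tracing both branches.
  | 
  |     from functools import partial
  |     import jax
  |     from jax import lax
  | 
  |     # Plain Python control flow fails under jit, since a traced
  |     # value has no concrete truth value:
  |     # @jax.jit
  |     # def f(x):
  |     #     return 2 * x if x > 0 else -x
  | 
  |     # static_argnums: x becomes a compile-time constant, so each
  |     # distinct value triggers a fresh trace containing only the
  |     # branch that value selects.
  |     @partial(jax.jit, static_argnums=0)
  |     def f_static(x):
  |         return 2 * x if x > 0 else -x
  | 
  |     # lax.cond: x stays dynamic; both branches are traced into the
  |     # one compiled program, as described in the post.
  |     @jax.jit
  |     def f_cond(x):
  |         return lax.cond(x > 0, lambda v: 2 * v, lambda v: -v, x)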
 
    | cjv wrote:
    | Ah, thanks for the explanation.
 
| ssivark wrote:
| There are many interesting threads in this post, one of which is
| using "non standard interpretations" of programs, and enabling
| the compiler to augment the human-written code with the extra
| pieces necessary to get gradients, propagate uncertainties, etc.
| I wonder whether there's a more unified discussion of the
| potential of these methods. I suspect that a lot of "solvers"
| (each typically with their own DSL for specifying the problem)
| might be nicely formulated in such a framework. (Particularly in
| the case of auto diff, I found recent work/talks by Conal Elliott
| and Tom Minka quite enlightening.)
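| 
| As a tiny illustration of that "non-standard interpretation" idea,
| here is a hypothetical sketch of mine using jax.jvp: the same user
| code is re-run under different semantics so that a first-order
| input perturbation (e.g. a linearized uncertainty) is carried along
| with the value, without the author of `model` writing any
| derivative code.
| 
|     import jax
|     import jax.numpy as jnp
| 
|     def model(x):
|         return jnp.sin(x[0]) * jnp.exp(-x[1])
| 
|     x     = jnp.array([0.3, 1.2])    # nominal inputs
|     sigma = jnp.array([0.05, 0.1])   # input perturbation to propagate
| 
|     # y is model(x); dy is the first-order response to that
|     # perturbation, obtained by reinterpreting `model`.
|     y, dy = jax.jvp(model, (x,), (sigma,))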
| 
| Tangentially, thinking about Julia: one is initially awed by the
| speed, and then by the multiple dispatch, but I wonder whether
| its deepest superpower (that we're still discovering) might be
| the expressiveness to augment the compiler to do interesting
| things with a piece of code. Generic programming then acts as a
| lever to use these improvements for a variety of use cases, and
| the speed is merely the icing on the cake!
 
  | shakow wrote:
  | > the expressiveness to augment the compiler to do interesting
  | things with a piece of code.
  | 
  | Julia has very interesting propositions on the subject, from
  | language-level autodiff (https://fluxml.ai/Zygote.jl/latest/)
  | to automated probabilistic programming (https://turing.ml/dev/)
  | through DEs (https://diffeq.sciml.ai/stable/) and optimization
  | (https://jump.dev/).
  | 
  | The whole ecosystem is bubbling with activity, and I'm very eager
  | to see whether, in the coming years, it will be able to transform
  | into a solid foundation able to rival the layers of warts stacked
  | on top of Python.
 
  | mccoyb wrote:
  | Just a comment: you're right on the money here. This is the
  | dream that a few people in the Julia community are working
  | towards.
  | 
  | The framework of abstract interpretation, when combined with
  | multiple dispatch as a language design feature, is absolutely
  | insane.
  | 
  | I think programming language enthusiasts might meditate on
  | these points --- and get quite excited with the direction that
  | the Julia compiler implementation is heading.
 
| marcle wrote:
| There is no free lunch:).
| 
| I remember spending a summer using Template Model Builder (TMB),
| which is a useful R/C++ automatic differentiation (AD) framework,
| for working with accelerated failure time models. For these
| models, the survival to time T given covariates X is defined by
| S(t|X) = P(T>t|X) = S_0(t exp(-beta^T X)) for baseline survival
| S_0(t). I wanted to use splines for the baseline survival and
| then use AD for gradients and random effects. Unfortunately,
| after implementing the splines in template C++, I found a web
| page entitled "Things you should NOT do in TMB"
| (https://github.com/kaskr/adcomp/wiki/Things-you-should-NOT-d...)
| - which included using if statements that depend on the
| coefficients. In this case, the splines for S_0 depend on beta,
| which is exactly the excluded case :(. An older framework (ADMB)
| did not have this constraint, but disseminating code with it was
| more difficult. Finally, PyTorch had neither a B-spline
| implementation nor one of Laplace's approximation. Returning to
| my opening comment, there is no free lunch.
 
  | hyperbovine wrote:
  | Were you optimizing over the knots as well? Otherwise I can't
  | see why this would be disallowed using either forward or
  | reverse-mode AD. An infinitesimal perturbation of beta will not
  | cause t * exp(-beta^T x) to cross a knot, so the whole thing is
  | smooth. (And, with B-splines the derivatives are continuous
  | from piece to piece anyway.) But in general I agree--a good
  | spline implementation is something I miss the most when moving
  | from scipy.interpolate to jax.scipy. Given that the SciPy
  | implementation is mostly F77 code written before I was born, I
  | do not see this situation resolving itself anytime soon.
 
    | svantana wrote:
    | It's not about smoothness, it's about how to JIT the gradient
    | function. ML libs don't generally do interpolation, partly
    | because it's tricky to vectorize (you have to search for
    | which segment to use for each element) and partly because
    | most ML practitioners don't need it. What I've done in my code
    | is use all the vertices for all the elements, but with
    | weights that are mostly zero. It's pretty fast on GPU because
    | I don't use that many vertices.
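    | 
    | For what it's worth, here is roughly what that trick looks like
    | as a minimal sketch (my own toy version, uniformly spaced knots
    | assumed): every query point gets a hat-function weight for every
    | vertex, most of them zero, so the whole interpolation is one
    | dense matmul that jit/vmap/grad handle with no data-dependent
    | branching.
    | 
    |     import jax
    |     import jax.numpy as jnp
    | 
    |     @jax.jit
    |     def interp_dense(t, knots, values):
    |         h = knots[1] - knots[0]  # uniform spacing assumed
    |         # hat (tent) basis: weight 1 at the matching knot,
    |         # decaying linearly to 0 at its neighbours, 0 elsewhere
    |         w = jnp.clip(1.0 - jnp.abs(t[:, None] - knots[None, :]) / h,
    |                      0.0, 1.0)
    |         return w @ values        # one value per query point
    | 
    |     knots  = jnp.linspace(0.0, 1.0, 11)
    |     values = jnp.sin(2 * jnp.pi * knots)
    |     y = interp_dense(jnp.array([0.05, 0.5, 0.87]), knots, values)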
 
  | ChrisRackauckas wrote:
  | There is definitely no free lunch; it's good to really
  | delineate the engineering trade-offs you're making! A lot of
  | this work actually comes from the fact that some people I work
  | with were building tools that could efficiently handle dynamic
  | control flow without requiring tracing (see the description of
  | Zygote.jl https://arxiv.org/abs/1810.07951). I had to bring up
  | the question: why? It's much harder to build, needs more
  | machinery, and in some cases can make fewer assumptions and do
  | fewer fusions (a general form of vmap is much harder, for example,
  | if you cannot trace; see KernelAbstractions.jl for details). This
  | line of inquiry led to an example of why you might want to support
  | such dynamic behaviors, so I'll leave it up to someone else to
  | declare whether the maintenance or complexity cost is worth it
  | to them. I wouldn't say that this means Jax or Tensorflow are
  | doomed (far from it: simple ML architectures are quasi-static,
  | so it's building for the correct audience), but it's good to
  | know what exactly you're leaving out when you make a
  | simplifying assumption.
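  | 
  | As a small sketch (mine, not from the post) of what tracing buys
  | in return: jax.vmap can mechanically batch a function written for
  | a single sample precisely because the trace exposes a quasi-static
  | program that the transform is allowed to rewrite.
  | 
  |     import jax
  |     import jax.numpy as jnp
  | 
  |     def predict(w, x):             # written for one sample
  |         return jnp.tanh(w @ x)
  | 
  |     w  = jnp.ones((4, 3))
  |     xs = jnp.ones((8, 3))          # a batch of 8 samples
  |     ys = jax.vmap(predict, in_axes=(None, 0))(w, xs)   # shape (8, 4)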
 
| _hl_ wrote:
| Tangentially related: faster training of Neural ODEs is super
| exciting! There are a lot of promising applications (although
| personally I believe that the intuition of "magically choosing
| the number of layers" is misguided, but I'm not an expert and
| might be wrong). Right now it takes forever to train even on toy
| problems, but I'm sure that enough work in this direction will
| eventually lead to more practical methods.
 
| 6gvONxR4sf7o wrote:
| This is a really cool post.
| 
| It seems like you can't solve this kind of thing with a new jax
| primitive for the algorithm, but what prevents new function
| transformations from doing what the mentioned julia libraries do?
| It seems like between new function transformations and new
| primitives, you ought to be able to do just about anything. Is XLA
| the issue, such that you could run but not jit the result?
 
  | ChrisRackauckas wrote:
  | XLA is the limiting factor in a lot of these cases, though
  | maybe saying limiting factor is wrong because it's more of a
  | "trade-off factor". XLA wants to know the static size of a lot
  | of arguments so it can build a mathematical description of the
  | compute graph and fuse linear algebra commands freely. What the
  | Julia libraries like Zygote do is say "there is no good
  | mathematical description of this, so I will generate source
  | code instead" (and some programs like Tapenade are similar).
  | For example, while loops are translated into for loops where a
  | stack of the Boolean choices is stored so they can be run in
  | reverse during the backwards pass. The Julia libraries can
  | sometimes have more trouble automatically fusing linear algebra
  | commands though, since then they need to say "my IR lets
  | non-static things occur, so I need to prove it's static before
  | doing transformation X". It's much easier to know you can do
  | such transformations if anything written in the IR obeys the
  | rules required for the transform! So it's a trade-off. In the
  | search for allowing differentiation of any program in the
  | language, the Julia AD tools have gone for extreme flexibility
  | (and can rely on the fact that Julia has a compiler that can
  | JIT compile any generated valid Julia code) and I find it
  | really interesting to try and elucidate what you actually gain
  | from that.
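  | 
  | A couple of lines show the JAX side of that trade-off (a toy
  | sketch of mine, not code from the post): a while loop with a
  | value-dependent trip count can be compiled, but reverse-mode AD
  | through it is refused because the size of the "stack" needed for
  | the backwards pass is not static, whereas a fixed-length scan
  | differentiates fine.
  | 
  |     import jax
  |     from jax import lax
  | 
  |     def halve_until_small(x):
  |         # trip count depends on the runtime value of x
  |         return lax.while_loop(lambda s: s > 1.0, lambda s: s * 0.5, x)
  | 
  |     # jax.grad(halve_until_small)(8.0)  # errors: reverse-mode AD of
  |     #                                   # while_loop is unsupported
  | 
  |     def halve_ten_times(x):
  |         y, _ = lax.scan(lambda c, _: (c * 0.5, None), x, None,
  |                         length=10)
  |         return y
  | 
  |     g = jax.grad(halve_ten_times)(8.0)  # 0.5 ** 10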
 
    | awaythrowact wrote:
    | If the next machine learning killer-app model requires
    | autodiff'ed dynamic control flow, do you think
    | Google/Facebook will build that capability into
    | XLA/TorchScript? Seems like if NLP SOTA requires dynamic
    | control flow, Google will build it? Maybe they let you
    | declare some subgraph as "dynamic" to avoid static
    | optimizations? But maybe the static graph assumption is so
    | deeply embedded into the XLA architecture, they'd be better
    | off just adopting Julia? (I honestly don't know the answer,
    | asking your opinion!)
 
      | ChrisRackauckas wrote:
      | "Maybe they let you declare some subgraph as 'dynamic' to
      | avoid static optimizations?" What you just described is
      | Tensorflow Eager and why it has some performance issues
      | (but more flexibility!). XLA makes some pretty strong
      | assumptions and I don't think that should change.
      | Tensorflow's ability to automatically generate good
      | automatically parallelized production code stems from the
      | restrictions it has imposed. So I wouldn't even try for a
      | "one true AD to rule them all" since making things more
      | flexible will reduce the amount of compiler optimizations
      | that can be automatically performed.
      | 
      | To get the more flexible form, you really would want to do
      | it in a way that uses a full programming language's IR as
      | its target. I think trying to use a fully dynamic
      | programming language's IR (Python, R, etc.) directly
      | would be pretty insane because it would be hard to enforce
      | rules and get performance. So some language that has a
      | front end over an optimizing compiler (LLVM) would probably
      | make the most sense. Zygote and Diffractor use Julia's IR,
      | but there are other ways to do this as well. Enzyme
      | (https://github.com/wsmoses/Enzyme.jl) uses the LLVM IR
      | directly for doing source-to-source translations. Using
      | some dialect of LLVM (provided by MLIR) might be an
      | interesting place to write a more ML-focused flexible AD
      | system. Swift for Tensorflow used the Swift IR. This
      | mindset starts to show why those tools were chosen.
 
        | awaythrowact wrote:
        | Makes sense. I don't use TF Eager, but I do use Jax, and
        | Jax lets you arbitrarily compose JITed and non-JITed
        | code, which made me think that might be a viable pattern.
        | I guess I wondered if there might be something like
        | "nonstatic_jit(foo)" that would do "julia style"
        | compiling on function foo, in addition to "jit(foo)" that
        | compiles foo to optimized XLA ops. Probably impractical.
        | Thanks.
 
    | 6gvONxR4sf7o wrote:
    | > and can rely on the fact that Julia has a compiler that can
    | JIT compile any generated valid Julia code
    | 
    | This seems to be the key bit. It's a great data point around
    | the meme of "with a sufficiently advanced compiler..." In
    | this case we have sufficiently advanced compilers to make
    | very different JIT trade offs. XLA is differently powerful
    | compared to Julia. Very cool, thanks for the insight.
 
| ipsum2 wrote:
| The example that fails in Jax would work fine in PyTorch. If
| you're purely training the model, TorchScript doesn't give many
| benefits, if any.
 
___________________________________________________________________
(page generated 2021-07-22 23:00 UTC)