leo.blog();

JAX

JAX is a Python library from Google for high-performance numerical computing and machine learning.

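The core of JAX provides composable function transformations, such as computing gradients (jax.grad), vectorizing operations (jax.vmap), and just-in-time compilation (jax.jit). For building neural networks, there are libraries on top of JAX such as Flax. Linen (flax.linen) is Flax's API for defining neural network modules, not a separate library.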
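Here is a minimal sketch of those core transformations, assuming JAX is installed. The function loss and its inputs are made up for illustration. jax.grad turns a scalar function into one that returns its gradient, jax.vmap maps a function over a batch axis, and jax.jit compiles a function.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar function of weights w and a single input x.
    return jnp.sum((x * w) ** 2)

# jax.grad differentiates with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])
print(grad_loss(w, x))  # analytically 2 * x**2 * w, i.e. 18 and 64

# jax.vmap: keep w fixed (None) and map over the leading axis of a batch of x.
per_example_loss = jax.vmap(loss, in_axes=(None, 0))
xs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
print(per_example_loss(w, xs))  # one loss per row: 1 and 4

# jax.jit compiles the gradient function; transformations compose freely.
fast_grad = jax.jit(grad_loss)
print(fast_grad(w, x))  # same values as grad_loss(w, x)
```

Because each transformation takes a function and returns a function, they can be stacked, for example jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0))) for compiled per-example gradients.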

If you want to learn how JAX works internally, the official Autodidax tutorial builds a simplified version of JAX from scratch. You can find it at https://docs.jax.dev/en/latest/autodidax.html.
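One of the first things Autodidax builds is forward-mode differentiation, where each value is carried together with its derivative (tangent). As a much smaller taste of that idea, here is a plain-Python sketch using dual numbers. It is not code from the tutorial: the Dual class is hypothetical, and jvp only mirrors the name of JAX's jax.jvp.

```python
class Dual:
    """A primal value paired with its tangent (derivative) for forward-mode AD."""

    def __init__(self, primal, tangent=0.0):
        self.primal = primal
        self.tangent = tangent

    def __add__(self, other):
        # Treat plain numbers as constants (tangent 0).
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a * b)' = a * b' + a' * b
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def jvp(f, primal, tangent):
    """Evaluate f at primal and its directional derivative along tangent."""
    out = f(Dual(primal, tangent))
    return out.primal, out.tangent


def f(x):
    return 3.0 * x * x + 2.0 * x + 1.0


print(jvp(f, 2.0, 1.0))  # f(2) = 17, f'(2) = 6 * 2 + 2 = 14
```

Running f on a Dual instead of a float reuses the same Python code to compute the derivative. JAX generalizes this with tracers that can interpret a function in many different ways (differentiation, batching, compilation).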
