JAX is an up-and-coming library in the machine learning space that aims to make machine learning research both simple and efficient. JAX is still a Google and DeepMind research project rather than an official Google product, but it has been used extensively inside Google and adopted by external ML researchers. We wanted to offer an introduction to JAX, how to install it, and its advantages and capabilities.
What Is JAX for Machine Learning?
JAX is a Python library designed for high-performance numerical computing, especially machine learning research. Its API for numerical functions is based on NumPy, the standard Python library for array-based scientific computing. JAX speeds up machine learning workloads by using XLA to compile and run NumPy-style programs on accelerators such as GPUs and TPUs. It also uses an updated version of Autograd to automatically differentiate native Python and NumPy functions, which makes gradient-based optimization straightforward. JAX can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports both reverse-mode differentiation (backpropagation) and forward-mode differentiation, and the differentiated functions can be compiled with XLA for GPU acceleration like any other JAX code.
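The capabilities described above map onto a handful of JAX transformations. The following is a minimal sketch, assuming a recent JAX release is installed (for example via `pip install -U jax`); it runs on CPU and uses a GPU or TPU automatically if one is available. The function names `cube`, `piecewise`, and `g` are hypothetical examples chosen for illustration. It shows `jax.grad` differentiating through a Python loop and branch, nested `grad` calls for higher-order derivatives, reverse mode (`jax.grad`) and forward mode (`jax.jacfwd`) agreeing, and `jax.jit` compiling a function with XLA.

```python
import jax
import jax.numpy as jnp

# A Python loop: multiplies x by itself three times, i.e. computes x**3.
def cube(x):
    result = 1.0
    for _ in range(3):
        result = result * x
    return result

# Derivatives of derivatives of derivatives via nested jax.grad.
df = jax.grad(cube)   # 3x^2
d2f = jax.grad(df)    # 6x
d3f = jax.grad(d2f)   # 6

# A Python branch on the value of x. This works under jax.grad on its own;
# under jax.jit you would need jax.lax.cond instead, because x is traced abstractly.
def piecewise(x):
    if x < 3:
        return 3.0 * x ** 2
    return -4.0 * x

# A scalar-valued function of a vector, written with the NumPy-like jax.numpy API.
def g(v):
    return jnp.sum(jnp.tanh(v) ** 2)

v = jnp.arange(3.0)
rev = jax.grad(g)(v)     # reverse-mode (backpropagation) gradient
fwd = jax.jacfwd(g)(v)   # forward-mode Jacobian; same values for a scalar output

# jax.jit traces g once and compiles it with XLA for the available backend.
fast_g = jax.jit(g)

print(df(2.0), d2f(2.0), d3f(2.0))
print(jax.grad(piecewise)(2.0), jax.grad(piecewise)(4.0))
print(rev, fwd, fast_g(v))
```

Reverse mode is usually the better choice when a function has many inputs and a scalar output (the typical loss function), while forward mode tends to be cheaper when there are few inputs and many outputs; JAX lets you pick either, or compose them.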