JAX is more than just "NumPy for the GPU": it offers advanced features, but it also presents unique challenges. This hands-on tutorial provides a practical introduction to JAX through interactive exercises covering key concepts such as:
- `jit` compilation for performance optimization
- Native control flow using loop primitives
- Efficient function mapping with `vmap`
- Performance profiling techniques
- `jax` random number generation design and usage
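The snippet below is a minimal sketch that touches each item in the list: explicit PRNG keys, a loop staged with `jax.lax.fori_loop`, batching with `vmap`, compilation with `jit`, and a basic timing pattern (not a full profiler). The function `iterate` and the array sizes are hypothetical, chosen only for illustration, and are not taken from the tutorial's notebooks.

```python
import time

import jax
import jax.numpy as jnp

# Explicit PRNG keys: split a key instead of reusing it, so each draw is
# independent and reproducible.
key = jax.random.PRNGKey(0)
key_x, key_w = jax.random.split(key)
xs = jax.random.normal(key_x, (1000, 3))  # batch of 1000 3-vectors
w = jax.random.normal(key_w, (3,))


def iterate(x, w):
    """Hypothetical toy function: apply x <- tanh(x * w) ten times."""
    def body(i, x):
        return jnp.tanh(x * w)
    # lax.fori_loop stages the loop into the compiled program instead of
    # unrolling a Python for-loop during tracing.
    return jax.lax.fori_loop(0, 10, body, x)


# vmap batches over the leading axis of x while sharing w across the batch;
# jit then compiles the batched function with XLA.
batched = jax.jit(jax.vmap(iterate, in_axes=(0, None)))

# JAX dispatches work asynchronously, so block on the result before stopping
# the clock. The first call includes tracing and compilation; the second call
# reuses the compiled executable.
for label in ("first call (compile + run)", "second call (cached)"):
    start = time.perf_counter()
    out = batched(xs, w).block_until_ready()
    print(f"{label}: {time.perf_counter() - start:.4f} s")
```

Timing the first and second calls separately is a simple way to tell compilation overhead apart from steady-state run time. Omitting `block_until_ready()` would mostly measure dispatch rather than computation.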
Participants will then deepen their understanding by iteratively migrating a Gaussian Mixture Model from a pure `numpy` implementation to an optimized `jax` version, highlighting a real-world use case.
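For a sense of where such a migration can end up, here is a sketch of one expectation-maximization (EM) step for a Gaussian Mixture Model written in `jax`. It assumes diagonal covariances and is not the tutorial's implementation. The names `log_gaussian_diag` and `em_step`, the synthetic two-cluster data, and the initial parameters are all hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_gaussian_diag(x, mean, var):
    """Log density of N(mean, diag(var)) at each row of x: (N, D) -> (N,)."""
    return -0.5 * jnp.sum(jnp.log(2 * jnp.pi * var) + (x - mean) ** 2 / var, axis=-1)


@jax.jit
def em_step(x, weights, means, variances):
    """One EM iteration for a diagonal-covariance GMM (illustrative sketch).

    Returns updated parameters and the log-likelihood of x under the
    parameters that were passed in.
    """
    # E-step: vmap over components replaces a Python loop -> shape (K, N).
    log_p = jax.vmap(log_gaussian_diag, in_axes=(None, 0, 0))(x, means, variances)
    log_joint = jnp.log(weights)[:, None] + log_p
    # Normalize in log space with logsumexp to avoid underflow.
    log_norm = logsumexp(log_joint, axis=0)             # (N,)
    resp = jnp.exp(log_joint - log_norm)                # (K, N); columns sum to 1
    # M-step: closed-form updates weighted by the responsibilities.
    nk = resp.sum(axis=1)                               # (K,)
    new_weights = nk / x.shape[0]
    new_means = (resp @ x) / nk[:, None]                # (K, D)
    diff = x[None, :, :] - new_means[:, None, :]        # (K, N, D)
    # Small floor keeps variances strictly positive.
    new_vars = jnp.einsum("kn,knd->kd", resp, diff**2) / nk[:, None] + 1e-6
    return new_weights, new_means, new_vars, log_norm.sum()


# Usage on synthetic data: two well-separated clusters in 2-D.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jnp.concatenate([
    jax.random.normal(k1, (500, 2)) - 3.0,
    jax.random.normal(k2, (500, 2)) + 3.0,
])
weights = jnp.array([0.5, 0.5])
means = jnp.array([[-1.0, 0.0], [1.0, 0.0]])
variances = jnp.ones((2, 2))
for _ in range(20):
    weights, means, variances, loglik = em_step(x, weights, means, variances)
print(weights, means, loglik)
```

The design choices here are typical of a `numpy`-to-`jax` port. `vmap` replaces an explicit loop over components, the E-step works in log space for stability, and `jit` compiles the entire step into a single XLA program. The outer Python loop could also be moved into a `jax.lax.fori_loop` if the number of iterations is fixed.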
This tutorial distills lessons the authors found invaluable during their own migration from `numpy` to `jax`, which achieved more than an order-of-magnitude speedup in real-world applications. Designed to give attendees a jumpstart on adopting `jax`, this session, along with its comprehensive set of notebooks, aims to be a one-stop resource for anyone looking to leverage `jax` for numerical computing and machine learning.