
Speed Up Your Code by 50x: A Guide to Moving from NumPy to JAX

Wednesday, May 14th, 2025 9 a.m.–12:30 p.m. in Room 321

Presented by

Ian Quah, Bryan Quah

Experience Level:

Some experience

Description

JAX is more than just "NumPy for the GPU": it offers advanced features such as just-in-time compilation and automatic vectorization, but it also presents unique challenges. This hands-on tutorial provides a practical introduction to JAX through interactive exercises covering key concepts, including:

  • jit compilation for performance optimization
  • Native control flow using JAX's loop primitives
  • Efficient function mapping with vmap
  • Performance profiling techniques
  • JAX's random number generation: design and usage
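The description lists these concepts without showing code. The snippet below is a minimal sketch, not taken from the tutorial's notebooks, that touches each topic using standard public JAX APIs (jax.jit, jax.lax.fori_loop, jax.vmap, jax.random, and block_until_ready). The functions sq_dists, geometric_sum, weighted_norm and time_call are hypothetical illustrations.

```python
import time

import jax
import jax.numpy as jnp


# 1. jit: the first call traces the function and compiles it with XLA;
#    later calls with the same shapes and dtypes reuse the compiled code.
@jax.jit
def sq_dists(x, centers):
    # Squared Euclidean distance from every row of x to every center.
    return jnp.sum((x[:, None, :] - centers[None, :, :]) ** 2, axis=-1)


# 2. Native control flow: lax.fori_loop keeps the loop inside the compiled
#    program instead of unrolling a Python loop at trace time. Because n is
#    traced (not static), changing n does not trigger recompilation.
@jax.jit
def geometric_sum(x, n):
    # Computes 1 + x + x**2 + ... + x**(n-1) with a (total, term) carry.
    def body(i, carry):
        total, term = carry
        return total + term, term * x

    total, _ = jax.lax.fori_loop(0, n, body, (jnp.zeros_like(x), jnp.ones_like(x)))
    return total


# 3. vmap: write the per-example function, then map it over a batch axis.
#    in_axes=(0, None) maps over rows of v and shares w across the batch.
def weighted_norm(v, w):
    return jnp.sqrt(jnp.sum(w * v ** 2))


batched_weighted_norm = jax.vmap(weighted_norm, in_axes=(0, None))


# 4. Profiling: JAX dispatches work asynchronously, so block on the result,
#    and run a warm-up call first so compile time is not measured.
def time_call(fn, *args):
    fn(*args).block_until_ready()  # warm-up (includes compilation)
    start = time.perf_counter()
    fn(*args).block_until_ready()
    return time.perf_counter() - start


# 5. Random numbers: explicit keys instead of hidden global state; split
#    the key to get a fresh subkey for each draw so no key is reused.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
points = jax.random.normal(subkey, (1000, 2))
key, subkey = jax.random.split(key)
centers = jax.random.normal(subkey, (3, 2))

d = sq_dists(points, centers)
print(d.shape)  # (1000, 3)
print(geometric_sum(jnp.float32(2.0), 4))
print(batched_weighted_norm(points, jnp.array([1.0, 0.5])).shape)  # (1000,)
print(f"sq_dists took {time_call(sq_dists, points, centers):.6f} s")
```

The timing helper is deliberately simple; for real measurements one would repeat the call several times and look at the distribution rather than a single sample.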

Participants will then deepen their understanding by iteratively migrating a Gaussian Mixture Model from a pure NumPy implementation to an optimized JAX version, a real-world use case.
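The tutorial's GMM code lives in its notebooks and is not reproduced here. As a hypothetical illustration of the kind of step such a migration involves, the sketch below writes the E-step (computing responsibilities) of a diagonal-covariance GMM twice: a NumPy version that loops over components, and a JAX version that replaces the loop with nested vmap and a jitted logsumexp normalization. The function names and the diagonal-covariance assumption are illustrative choices, and no speedup is claimed at this toy size.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import logsumexp


def e_step_numpy(x, means, variances, weights):
    """Loop-over-components NumPy E-step (diagonal covariances)."""
    n, _ = x.shape
    k = means.shape[0]
    log_p = np.empty((n, k))
    for j in range(k):
        diff = x - means[j]
        log_p[:, j] = (
            np.log(weights[j])
            - 0.5 * np.sum(np.log(2 * np.pi * variances[j]))
            - 0.5 * np.sum(diff ** 2 / variances[j], axis=1)
        )
    # Normalize in log space for numerical stability.
    log_norm = np.logaddexp.reduce(log_p, axis=1, keepdims=True)
    return np.exp(log_p - log_norm)


def log_component(x_i, mean, var, log_w):
    # log(w * N(x_i | mean, diag(var))) for ONE point and ONE component.
    return (
        log_w
        - 0.5 * jnp.sum(jnp.log(2 * jnp.pi * var))
        - 0.5 * jnp.sum((x_i - mean) ** 2 / var)
    )


# Inner vmap maps over components, outer vmap over points -> shape (n, k).
_over_components = jax.vmap(log_component, in_axes=(None, 0, 0, 0))
_over_points = jax.vmap(_over_components, in_axes=(0, None, None, None))


@jax.jit
def e_step_jax(x, means, variances, weights):
    log_p = _over_points(x, means, variances, jnp.log(weights))
    return jnp.exp(log_p - logsumexp(log_p, axis=1, keepdims=True))


rng = np.random.default_rng(0)
x = rng.normal(size=(500, 2))
means = np.array([[-1.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
variances = np.ones((3, 2))
weights = np.array([0.5, 0.3, 0.2])

r_np = e_step_numpy(x, means, variances, weights)
r_jax = e_step_jax(x, means, variances, weights)
print(r_np.shape, r_jax.shape)
print(np.allclose(r_np, np.asarray(r_jax), atol=1e-4))
```

JAX defaults to 32-bit floats while NumPy uses 64-bit here, so the two results are compared with a tolerance rather than for exact equality.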

This tutorial distills lessons the authors found invaluable during their own migration from NumPy to JAX, which achieved more than an order-of-magnitude speedup in real-world applications. Designed to give attendees a jumpstart on adopting JAX, this session, along with its comprehensive set of notebooks, aims to be a one-stop resource for anyone looking to use JAX for numerical computing and machine learning.
