Tutorials: Magical NumPy with JAX

Presented by:


Description

The greatest contribution of the decade in which deep learning exploded was not the big models themselves, but a generalized toolkit for training any model by gradient descent. We are now in an era where differential computing can give you the tools to train models of any kind. Does a Pythonista well-versed in the PyData stack have to learn an entirely new toolkit, a new array library, to gain access to this power?

This tutorial's answer is as follows: if you can write NumPy code, then with JAX, differential computing is at your fingertips, with no new array library to learn! In this tutorial, you will learn how to use the NumPy-compatible JAX API to write performant numerical models of the world and train them using gradient-based optimization. Along the way, you will write loopy numerical code without loops, think in data cubes, get your functional programming muscles trained up, generate random numbers completely deterministically (no, this is not an oxymoron!), and preview how to mix neural networks and probabilistic models together, leveraging everything you know about NumPy plus some newly-learned JAX magic sprinkled in! The short sketches below preview a few of these ideas.
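As a taste of "training by gradient descent" with the NumPy-compatible API, here is a minimal sketch, illustrative only and not taken from the tutorial materials: a linear model, a mean-squared-error loss, and `jax.grad` turning the loss into its own gradient function. All names (`predict`, `loss`, the synthetic data) are assumptions made for this example.

```python
# A minimal, illustrative sketch of gradient-based training in JAX.
import jax.numpy as jnp
from jax import grad

def predict(params, x):
    """A linear model: y = w * x + b."""
    w, b = params
    return w * x + b

def loss(params, x, y):
    """Mean squared error between predictions and targets."""
    return jnp.mean((predict(params, x) - y) ** 2)

# grad transforms loss into a function returning d(loss)/d(params),
# with the gradient matching the (w, b) tuple structure of params.
dloss = grad(loss)

params = (0.0, 0.0)
x = jnp.arange(10.0)
y = 3.0 * x + 1.0  # synthetic data with known slope and intercept

for _ in range(1000):
    dw, db = dloss(params, x, y)
    w, b = params
    params = (w - 0.01 * dw, b - 0.01 * db)  # plain gradient descent
```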
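"Loopy code without loops" refers to `jax.vmap`, which vectorizes a function written for a single example across a whole batch. A hedged sketch, with the function and array names invented for illustration:

```python
# vmap: write the single-example version, let JAX map it over a batch.
import jax.numpy as jnp
from jax import vmap

def dot_row(row, v):
    """Dot product of one row with a fixed vector."""
    return jnp.dot(row, v)

mat = jnp.arange(12.0).reshape(4, 3)
v = jnp.ones(3)

# in_axes=(0, None): map over axis 0 of mat, broadcast v unchanged.
row_dots = vmap(dot_row, in_axes=(0, None))(mat, v)
# Equivalent to jnp.stack([dot_row(r, v) for r in mat]), but without the loop.
```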
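And "deterministic random numbers" really is not an oxymoron: in JAX, randomness is a pure function of an explicit PRNG key, so the same key always yields the same draws. A brief sketch of the pattern:

```python
# JAX's explicit-key PRNG: reproducible by construction.
from jax import random

key = random.PRNGKey(42)          # same seed -> same numbers, every time
key, subkey = random.split(key)   # split to get fresh, independent streams
draws = random.normal(subkey, shape=(3,))
```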