Machine Learning with JAX - From Hero to HeroPro+
Offered By: Aleksa Gordić - The AI Epiphany via YouTube
Course Description
Overview
This tutorial video covers advanced machine learning topics in JAX. It shows how to convert stateful models into stateless ones, explains PyTrees in depth, and trains a multilayer perceptron (MLP) in pure JAX. Later sections cover custom PyTrees, parallelism on TPUs, communication between devices, value_and_grad with has_aux, and training a model across multiple machines. The video then covers stopping gradients and computing per-example gradients, and closes with meta-learning: an implementation of MAML in 3 lines.
Syllabus
My "Get started with JAX" repo
Stateful to stateless conversion
PyTrees in depth
Training an MLP in pure JAX
Custom PyTrees
Parallelism in JAX (TPU example)
Communication between devices
value_and_grad and has_aux
Training an ML model on multiple machines
Stop gradients, per-example gradients
Implementing MAML in 3 lines
Outro
Taught by
Aleksa Gordić - The AI Epiphany
Related Courses
Neural Networks for Machine Learning (University of Toronto via Coursera)
Good Brain, Bad Brain: Basics (University of Birmingham via FutureLearn)
Statistical Learning with R (Stanford University via edX)
Machine Learning 1—Supervised Learning (Brown University via Udacity)
Fundamentals of Neuroscience, Part 2: Neurons and Networks (Harvard University via edX)