YoVDO

Machine Learning with JAX - From Hero to HeroPro+

Offered By: Aleksa Gordić - The AI Epiphany via YouTube

Tags

Machine Learning Courses
Neural Networks Courses
Gradient Descent Courses
TPUs Courses

Course Description

Overview

This tutorial video covers advanced machine learning with JAX. It shows how to convert stateful models into stateless ones, works through PyTrees in depth (including custom PyTrees), and trains a multilayer perceptron in pure JAX. It then turns to parallelism, using TPUs as the example, and to communication between devices, before covering value_and_grad with has_aux, training a model on multiple machines, stop-gradient, and per-example gradients. It closes with a 3-line implementation of MAML for meta-learning.
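
To give a flavor of the first few topics, here is a minimal sketch of a stateless MLP whose parameters live in a PyTree and are trained in pure JAX. It is not taken from the video: the names init_mlp, forward, loss_fn, sgd_step, and LR, the layer sizes, and the toy regression target are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

LR = 0.01  # hypothetical learning rate, chosen only for this sketch


def init_mlp(key, sizes):
    """Return MLP parameters as a PyTree: a list of {'w', 'b'} dicts."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        {
            "w": jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in),
            "b": jnp.zeros(n_out),
        }
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def forward(params, x):
    """Stateless forward pass: the parameters are an explicit argument."""
    *hidden, last = params
    for layer in hidden:
        x = jax.nn.relu(x @ layer["w"] + layer["b"])
    return x @ last["w"] + last["b"]


def loss_fn(params, x, y):
    """Mean squared error between predictions and targets."""
    return jnp.mean((forward(params, x) - y) ** 2)


@jax.jit
def sgd_step(params, x, y):
    """One gradient-descent step; returns new parameters instead of mutating."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # tree_map applies the update leaf by leaf across the whole PyTree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss


key = jax.random.PRNGKey(0)
params = init_mlp(key, [1, 16, 1])

# Toy regression data: fit y = x^2 on [-1, 1].
x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = x ** 2

losses = []
for _ in range(200):
    params, loss = sgd_step(params, x, y)
    losses.append(float(loss))
```

Because params is an ordinary nested list of dicts, the same tree_map update keeps working if more layers are added to the sizes list.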

Syllabus

My get started with JAX repo
Stateful to stateless conversion
PyTrees in depth
Training an MLP in pure JAX
Custom PyTrees
Parallelism in JAX (TPU example)
Communication between devices
value_and_grad and has_aux
Training an ML model on multiple machines
stop_gradient and per-example gradients
Implementing MAML in 3 lines
Outro
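
The later syllabus items (per-example gradients, stop-gradient, and MAML) can be sketched with a few function transformations. The code below is an illustrative sketch under stated assumptions, not the implementation shown in the video: the linear model, the names loss_fn, inner_update, maml_loss, fo_maml_loss, and INNER_LR, and the tiny support/query split are all hypothetical.

```python
import jax
import jax.numpy as jnp

INNER_LR = 0.1  # hypothetical inner-loop step size


def loss_fn(w, x, y):
    """Squared error of a linear model, for one example or a batch."""
    return jnp.mean((x @ w - y) ** 2)


# Per-example gradients: map grad over the batch axis of x and y, but not w.
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))


def inner_update(w, x, y):
    """One task-specific gradient step (the MAML inner loop)."""
    return w - INNER_LR * jax.grad(loss_fn)(w, x, y)


def maml_loss(w, xs, ys, xq, yq):
    """Query loss after adapting on the support set; differentiable in w."""
    return loss_fn(inner_update(w, xs, ys), xq, yq)


def fo_maml_loss(w, xs, ys, xq, yq):
    """First-order variant: stop_gradient blocks differentiation through
    the inner gradient, so no second derivatives are computed."""
    g = jax.lax.stop_gradient(jax.grad(loss_fn)(w, xs, ys))
    return loss_fn(w - INNER_LR * g, xq, yq)


w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([0.0, 0.0, 0.0])

g = per_example_grads(w, x, y)  # one gradient row per example

# Toy split: first two examples as support set, last one as query set.
xs, ys, xq, yq = x[:2], y[:2], x[2:], y[2:]
meta_g = jax.grad(maml_loss)(w, xs, ys, xq, yq)   # differentiates through the inner step
fo_g = jax.grad(fo_maml_loss)(w, xs, ys, xq, yq)  # first-order approximation
```

Averaging the rows of g recovers the ordinary batch gradient, which is a quick sanity check that the vmap axes are set up correctly.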


Taught by

Aleksa Gordić - The AI Epiphany

Related Courses

Production Machine Learning Systems (Google Cloud via Coursera)
Deep Learning (Kaggle via YouTube)
All About AI Accelerators - GPU, TPU, Dataflow, Near-Memory, Optical, Neuromorphic & More (Yannic Kilcher via YouTube)
PyTorch NLP Model Training and Fine-Tuning on Colab TPU Multi-GPU with Accelerate (1littlecoder via YouTube)
Solving a Complex Game with AI and All the Google Cloud Power (Devoxx via YouTube)