Coding a Neural Network from Scratch in Pure JAX - Machine Learning with JAX - Tutorial 3

Offered By: Aleksa Gordić - The AI Epiphany via YouTube

Tags

Neural Networks Courses, Programming Courses, Machine Learning Courses, t-SNE Courses, MNIST Dataset Courses, JAX Courses

Course Description

Overview

Learn to code a Neural Network from scratch using pure JAX in this comprehensive tutorial video. Dive into creating a Multi-Layer Perceptron (MLP) and training it as a classifier on the MNIST dataset. Follow along as the instructor guides you through the process, from initializing the MLP and implementing prediction functions to setting up PyTorch data loaders and constructing the training loop. Enhance your understanding with visualizations of learned weights, embeddings using t-SNE, and analysis of dead neurons. Gain practical insights into advanced JAX techniques and neural network implementation over the course of this 86-minute learning experience.
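To make the workflow concrete, here is a minimal sketch of the kind of pure-JAX MLP initialization and prediction code the tutorial builds. The layer sizes, function names, and weight-scaling constant are illustrative assumptions, not the instructor's exact implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def init_mlp(layer_sizes, key, scale=0.01):
    # One (weights, biases) pair per layer; the sizes used below (784 -> 512 -> 256 -> 10) are assumptions.
    params = []
    keys = jax.random.split(key, len(layer_sizes) - 1)
    for in_dim, out_dim, k in zip(layer_sizes[:-1], layer_sizes[1:], keys):
        w_key, b_key = jax.random.split(k)
        params.append((scale * jax.random.normal(w_key, (out_dim, in_dim)),
                       scale * jax.random.normal(b_key, (out_dim,))))
    return params

def predict(params, image):
    # Forward pass for one flattened 28x28 MNIST image; returns log-probabilities over 10 classes.
    activation = image
    for w, b in params[:-1]:
        activation = jax.nn.relu(jnp.dot(w, activation) + b)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activation) + final_b
    return logits - logsumexp(logits)

# vmap turns the single-example predict into a batched prediction function.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

params = init_mlp([784, 512, 256, 10], jax.random.PRNGKey(0))
```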

Syllabus

Intro, structuring the code
MLP initialization function
Prediction function
PyTorch MNIST dataset
PyTorch data loaders
Training loop
Adding the accuracy metric
Visualize the image and prediction
Small code refactoring
Visualizing MLP weights
Visualizing embeddings using t-SNE
Analyzing dead neurons
Outro
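
The training-loop and accuracy-metric steps listed above might look roughly like the following sketch. It assumes the init_mlp and batched_predict functions from the earlier sketch, a PyTorch DataLoader that yields NumPy arrays, and one-hot encoded labels; the learning rate and loop structure are assumptions rather than the instructor's exact code.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # assumed hyperparameter

def loss_fn(params, images, labels):
    # Mean negative log-likelihood; labels are assumed to be one-hot encoded.
    log_probs = batched_predict(params, images)
    return -jnp.mean(jnp.sum(log_probs * labels, axis=1))

@jax.jit
def update(params, images, labels):
    # One SGD step: differentiate the loss w.r.t. the parameter pytree and apply the gradients.
    grads = jax.grad(loss_fn)(params, images, labels)
    return [(w - LEARNING_RATE * dw, b - LEARNING_RATE * db)
            for (w, b), (dw, db) in zip(params, grads)]

@jax.jit
def accuracy(params, images, labels):
    # Fraction of images whose argmax prediction matches the one-hot label.
    predicted = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted == jnp.argmax(labels, axis=1))

# Hypothetical loop over a PyTorch DataLoader yielding (images, one-hot labels) as NumPy arrays:
# for epoch in range(num_epochs):
#     for images, labels in train_loader:
#         params = update(params, images, labels)
#     print(epoch, accuracy(params, test_images, test_labels))
```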


Taught by

Aleksa Gordić - The AI Epiphany

Related Courses

JAX Crash Course - Accelerating Machine Learning Code
AssemblyAI via YouTube
NFNets - High-Performance Large-Scale Image Recognition Without Normalization
Yannic Kilcher via YouTube
Diffrax - Numerical Differential Equation Solvers in JAX
Fields Institute via YouTube
JAX - Accelerated Machine Learning Research via Composable Function Transformations in Python
Fields Institute via YouTube
Getting Started with Automatic Differentiation
PyCon US via YouTube