Coding a Neural Network from Scratch in Pure JAX - Machine Learning with JAX - Tutorial 3
Offered By: Aleksa Gordić - The AI Epiphany via YouTube
Course Description
Overview
Learn to code a neural network from scratch in pure JAX. This tutorial builds a Multi-Layer Perceptron (MLP) and trains it as a classifier on the MNIST dataset. The instructor walks through initializing the MLP, implementing the prediction function, setting up PyTorch data loaders, and constructing the training loop. The video then visualizes the learned weights, plots embeddings with t-SNE, and analyzes dead neurons. Total running time is 86 minutes.
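The init / predict / train pattern described above can be sketched in a few dozen lines of JAX. This is a minimal sketch under stated assumptions, not the code from the video: the 784-512-256-10 ReLU architecture, Gaussian init scale, plain SGD with a 0.01 learning rate, cross-entropy loss, and the names `init_mlp`, `predict`, `batched_predict`, `loss_fn`, `update`, and `accuracy` are all illustrative choices. Random arrays stand in for MNIST so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

# Illustrative layer sizes (assumption): 28*28 pixels in, two hidden layers, 10 classes out.
LAYER_SIZES = [784, 512, 256, 10]
LEARNING_RATE = 0.01  # assumed SGD step size


def init_mlp(layer_sizes, key, scale=0.01):
    """Return a list of (W, b) pairs drawn from a small-scale Gaussian."""
    params = []
    keys = jax.random.split(key, len(layer_sizes) - 1)
    for n_in, n_out, k in zip(layer_sizes[:-1], layer_sizes[1:], keys):
        w_key, b_key = jax.random.split(k)
        W = scale * jax.random.normal(w_key, (n_out, n_in))
        b = scale * jax.random.normal(b_key, (n_out,))
        params.append((W, b))
    return params


def predict(params, x):
    """Forward pass for ONE flattened image x of shape (784,); returns log-probabilities."""
    hidden = x
    for W, b in params[:-1]:
        hidden = jax.nn.relu(W @ hidden + b)
    W_out, b_out = params[-1]
    logits = W_out @ hidden + b_out
    return jax.nn.log_softmax(logits)


# vmap maps predict over the leading batch axis of x while sharing params: (B, 784) -> (B, 10).
batched_predict = jax.vmap(predict, in_axes=(None, 0))


def loss_fn(params, x, y_onehot):
    """Mean cross-entropy between predicted log-probabilities and one-hot targets."""
    log_probs = batched_predict(params, x)
    return -jnp.mean(jnp.sum(log_probs * y_onehot, axis=1))


@jax.jit
def update(params, x, y_onehot):
    """One plain-SGD step; tree_map applies p - lr * g to every array in the params pytree."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y_onehot)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


def accuracy(params, x, labels):
    """Fraction of examples whose argmax prediction equals the integer label."""
    preds = jnp.argmax(batched_predict(params, x), axis=1)
    return jnp.mean(preds == labels)


# Usage on random stand-in data (not MNIST): repeated steps on one batch should lower the loss.
params = init_mlp(LAYER_SIZES, jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
labels = jax.random.randint(jax.random.PRNGKey(2), (32,), 0, 10)
y_onehot = jax.nn.one_hot(labels, 10)

losses = []
for _ in range(20):
    params, loss = update(params, x, y_onehot)
    losses.append(float(loss))
print(losses[0], losses[-1], float(accuracy(params, x, labels)))
```

In a real training loop, `update` would be called once per mini-batch drawn from a data loader, inside an outer loop over epochs. Keeping the parameters as a plain list of `(W, b)` tuples lets JAX's pytree utilities (`value_and_grad`, `tree_map`) handle them without any neural-network framework.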
Syllabus
Intro, structuring the code
MLP initialization function
Prediction function
PyTorch MNIST dataset
PyTorch data loaders
Training loop
Adding the accuracy metric
Visualize the image and prediction
Small code refactoring
Visualizing MLP weights
Visualizing embeddings using t-SNE
Analyzing dead neurons
Outro
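The "Analyzing dead neurons" step can be illustrated with a small check. This is a minimal sketch, assuming a ReLU unit counts as dead when its output is zero for every input in a batch; the video's exact criterion may differ. `hidden_activations` and `count_dead_neurons` are hypothetical helper names, and a random 8-unit layer with hand-set biases stands in for a trained MLP layer, so the expected count (3) is known in advance.

```python
import jax
import jax.numpy as jnp


def hidden_activations(W, b, x):
    """ReLU activations of one dense layer for a batch x of shape (B, n_in)."""
    return jax.nn.relu(x @ W.T + b)  # shape (B, n_out)


def count_dead_neurons(activations):
    """Count units whose activation is exactly zero for every example in the batch."""
    never_fires = jnp.all(activations == 0.0, axis=0)  # boolean mask, shape (n_out,)
    return int(jnp.sum(never_fires)), never_fires


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = 0.01 * jax.random.normal(key_w, (8, 784))  # small weights: pre-activations stay near the bias
# Assumed biases: 3 units pushed far negative (cannot fire), 5 pushed far positive (always fire).
b = jnp.array([-10.0] * 3 + [10.0] * 5)
x = jax.random.uniform(key_x, (64, 784))  # pixel-like inputs in [0, 1)

acts = hidden_activations(W, b, x)
n_dead, dead_mask = count_dead_neurons(acts)
print(n_dead, dead_mask)
```

On a trained network, the same mask computed over a large sample of MNIST images would point to hidden units that receive no gradient through the ReLU and so no longer learn.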
Taught by
Aleksa Gordić - The AI Epiphany