Creating a Neural Network from Scratch in Rust — Part 2
In the previous article of this series, Creating a Neural Network from Scratch in Rust — Part 1, we learned about neurons and the perceptron model. Now we're ready to dig deeper into how neural networks learn, process, and "store" information. My goal in this post is to build up the foundation and understanding needed so that, in the next articles, we can continue implementing our neural network in Rust. I'm not trying to explain all the math and details here, but rather to provide a nice mental model and visualization of the processes, so that when you look at the formulas, they don't look too scary ;)
Teaching neural networks

In the previous article, we implemented some simple logic gates with our perceptron model. They work well, but even for such a simple model, tweaking the parameters by hand to match the desired output is quite difficult. Now, think of a model with billions or even trillions of parameters! It's clearly impossible to adjust all of them by hand. We need a way to do that automatically!

But here's the catch: how can we adjust those parameters? We need to model the problem in a way that allows us to at least attempt to express it through code. But… how?

Describing the output

We can start by modeling the desired output of our neural network, since we already know what it is intended to be. Let's define some sort of function that tells us how well or how poorly the neural network predicts the desired output for a given training sample. A common function used for this, especially in regression problems, is MSE (Mean Squared Error). Let's take a look at it for a second:

MSE = (1/n) · Σᵢ (Yᵢ − Ŷᵢ)²

where n is the number of training samples, Yᵢ is the expected output for sample i, and Ŷᵢ is the output our neural network produced for it.
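To make this concrete, here's a minimal sketch of how the MSE could be computed in Rust. The function name and signature are just illustrative; they're not (yet) part of the network we're building:

```rust
/// Mean Squared Error over a batch of predictions (illustrative sketch).
fn mse(expected: &[f64], predicted: &[f64]) -> f64 {
    assert_eq!(expected.len(), predicted.len());
    let sum_of_squares: f64 = expected
        .iter()
        .zip(predicted.iter())
        .map(|(y, y_hat)| (y - y_hat).powi(2))
        .sum();
    sum_of_squares / expected.len() as f64
}

fn main() {
    // Perfect predictions give an error of 0.0 ...
    println!("{}", mse(&[1.0, 0.0, 1.0], &[1.0, 0.0, 1.0]));
    // ... and the further the predictions drift, the larger the error gets.
    println!("{}", mse(&[1.0, 0.0, 1.0], &[0.2, 0.9, 0.4]));
}
```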
Don't be scared! This function is quite simple, actually! We're basically computing the difference between the expected result Yᵢ and the result we got from the neural network, Ŷᵢ. We then square the result so it's always positive, sum the terms up, and take the mean. That effectively gives us a way to measure the "performance" of our neural network! Notice that if our neural network starts to predict values that are far from the expected output, the squared sum will grow, meaning our neural network is performing poorly. On the other hand, if our neural network predicts the values correctly, each difference will be zero, so the average at the end will approach zero, meaning our neural network is well tuned.

Optimize a function and get a NNW for free!

Now we have a way to measure how well or how poorly our neural network performs! But have you noticed something interesting about this formula? Since it measures how well the neural network predicts the correct output, optimizing this function means getting a better neural network! That's the key idea behind neural network training!

Optimizing functions

Alright, now we're making progress! We know that we need to optimize (in our case, minimize) a certain function in order to train our neural network. The question is: how? This kind of problem is known as an optimization problem, and there are many ways to approach it. One way that's particularly good for us, given that we're dealing with a differentiable function and a training dataset, is a technique called gradient descent!

Gradient Descent Algorithm

Gradient descent is at the heart of the optimization of every neural network out there! This section will be a bit more intense on the math, so hang tight. Don't be scared, though. You don't need to fully grasp all the concepts here, but make sure you get the intuition in your head!

What's a Gradient, even?

Looking at the definition from Wikipedia, a gradient is:

"In vector calculus, the gradient of a scalar-valued differentiable function f of several variables is the vector field (or vector-valued function) ∇f whose value at a point p gives the direction and the rate of fastest increase."

That sounds very complicated, I know. But let's try to unpack together what this definition is saying. First, the gradient of a scalar-valued function of several variables is itself a vector (one vector for each point we evaluate it at). Perhaps the most important piece of information is the ending: the gradient vector always points in the direction of the fastest increase of the function at that specific point.
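Before looking at a picture, it may help to see a gradient numerically. Here's a small sketch (an illustrative example of my own, assuming the toy function f(x, y) = x² + y²) that approximates the gradient at a point with finite differences; the exact gradient of this function is (2x, 2y):

```rust
// Approximate the gradient of f(x, y) = x^2 + y^2 at a point
// using central finite differences (toy example, not part of the network).
fn f(x: f64, y: f64) -> f64 {
    x * x + y * y
}

fn gradient(x: f64, y: f64) -> (f64, f64) {
    let h = 1e-6;
    let df_dx = (f(x + h, y) - f(x - h, y)) / (2.0 * h);
    let df_dy = (f(x, y + h) - f(x, y - h)) / (2.0 * h);
    (df_dx, df_dy)
}

fn main() {
    // At (1, 2) the exact gradient is (2, 4): it points "uphill",
    // away from the function's minimum at (0, 0).
    let (gx, gy) = gradient(1.0, 2.0);
    println!("gradient at (1, 2) ≈ ({gx:.3}, {gy:.3})");
}
```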
If we have a function that operates in three dimensions, for example, we may see the gradient as something like this:

Did you see that? The red arrows represent the gradient of the function calculated at specific points in the 3-dimensional space. As we can see, this vector always points to where the function grows fastest. Keep that intuition in mind, we're gonna need it soon!

Optimizing using the gradient

As we just learned, the gradient of a function always points in the direction of fastest increase, right? But we're trying to minimize our function, meaning we wanna get to a minimum of that function. So we can take the opposite direction of the gradient vector! If we follow the direction of the negative gradient, we're in fact "walking down" the function, meaning we're gonna reach a minimum at some point! Again, in a 3-dimensional space, we can imagine walking down the function until we reach a minimum. It will not necessarily be the global minimum, meaning the lowest value we can get; it could also be a local minimum.
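Here's a tiny sketch of that idea in Rust (again, an illustrative toy rather than our final training code): we repeatedly step against the gradient of f(x, y) = x² + y² and watch the point slide toward the minimum at (0, 0):

```rust
// Minimize f(x, y) = x^2 + y^2 by repeatedly stepping against its gradient.
// For this toy function the gradient is simply (2x, 2y).
fn main() {
    let (mut x, mut y) = (3.0_f64, -4.0_f64); // arbitrary starting point
    let learning_rate = 0.1;                  // how big each step is

    for step in 0..50 {
        let (gx, gy) = (2.0 * x, 2.0 * y); // gradient at the current point
        x -= learning_rate * gx;           // walk opposite to the gradient
        y -= learning_rate * gy;
        if step % 10 == 0 {
            println!("step {step:2}: x = {x:.4}, y = {y:.4}, f = {:.6}", x * x + y * y);
        }
    }
    // The point converges toward (0, 0), the minimum of the function.
}
```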
Following the gradient descent

That's, in short, what the gradient descent algorithm does! Pretty simple, huh? Now, if we want to train our neural network, we can just minimize this function and use that information to update our neural network's parameters. We will see how this is done in the next article!

Conclusions

In this article, we've explored a way to represent the error of a neural network using the MSE function. We've also seen how it can be used as feedback to our neural network during training. In the next articles, we're gonna explore in a bit more detail how this training works. We will also build an automatic differentiation tool from scratch in Rust for computing this gradient! All this knowledge will allow us to have a fully working neural network in no time, so stay tuned!