Recap
This is part two of my series on how neural nets work. If you’re here, you should have a solid understanding of how forward propagation works; if not, check out my article on it here. Let’s quickly go over everything as a refresher. Our network is a digit classifier for greyscale images with dimensions of 28 by 28. It consists of 4 layers of nodes, which you can think of as compound functions. Connecting the nodes are numbers that we call weights. The first layer of our network does not consist of nodes but of our input data, which is a one-dimensional vector of length 784 (28 x 28). For each node in the second layer, we compute a dot product of the previous layer and the weights that connect them, and feed the result into a ReLU activation function. The third layer works the same way, except that its input is not our image but the second layer’s output. For the fourth layer, we swap out the ReLU for a Softmax function. The output of our network is a vector of probabilities, where each value represents a class (digits 0-9). We are now ready to begin backpropagation, but before we can do that, there are two more topics to discuss: batches and one-hot encoding.
Batches
Instead of forward propagating a single image at a time, we’ll pass many images through concurrently in what is called a batch. Once we have gone through forward propagation for each piece of data, we can compute the gradients using the average of the errors. This helps our network converge faster and lowers the chance of overfitting.
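To make this concrete, here is a minimal sketch of the batched forward pass in NumPy. The batch size, hidden-layer widths, and weight initialization are placeholder assumptions for illustration, not values from this series:

```python
import numpy as np

def relu(z):
    return np.maximum(0, z)

def softmax(z):
    # Subtract the row-wise max for numerical stability
    exp = np.exp(z - z.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
W1 = rng.normal(0, 0.01, (784, 64))  # input layer -> second layer
W2 = rng.normal(0, 0.01, (64, 64))   # second layer -> third layer
W3 = rng.normal(0, 0.01, (64, 10))   # third layer -> output layer

X = rng.random((32, 784))  # a batch of 32 flattened 28x28 images

A1 = relu(X @ W1)          # second-layer activations for the whole batch
A2 = relu(A1 @ W2)         # third-layer activations
A3 = softmax(A2 @ W3)      # output probabilities, shape (32, 10)
```

Note that nothing about the math changes; the dot products simply operate on a whole matrix of images at once instead of a single vector.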
One-hot encoding
This is a method of converting a scalar class label into a vector that is zero everywhere except at the index of that class, where it is one. We often use this encoding to simplify our equations and eventual code. Here’s an example:
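A minimal sketch in NumPy (the helper name `one_hot` is my own):

```python
import numpy as np

def one_hot(label, num_classes=10):
    """Convert a scalar class label into a one-hot vector."""
    vec = np.zeros(num_classes)
    vec[label] = 1.0
    return vec

print(one_hot(4))  # [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
```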
Error functions
Before we can begin to improve our network, let's first figure out how to quantify the word “better”. Your intuition may say that “better” means the accuracy has increased, but as it turns out there is a more useful statistic. Remember, our Softmax output is not a single digit representing the network's guess, but a vector of probabilities, one per class. You can think of this vector as the network's confidence that each class is the correct label. Now, imagine that the input to our network is an image of a 4. If our probability for a 4 is 11%, it could still be the network's guess, since every other probability could be 10% or 9%; the “guess” is simply the largest value in the vector. While we may still get the correct output from our network, a small change in the input data would likely change that. Thus, we want to define our error function in such a way that the error is low when our network has high confidence in the correct class, and high when it has low confidence. The first implementation you might think of is 1 − X, where X is the correct class’s probability.
Better error functions
However, this simple error function has some drawbacks. We might want our function to punish low confidences much more harshly than high ones. An example of this would be (1 − X)^2, called Mean Squared Error. This is very common in regression problems; however, for classification problems we prefer another function: Categorical Cross Entropy, or log-loss. I will go over what the function actually is once some of the variables have been explained.
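A quick numerical comparison makes the difference in punishment visible. The values below are purely illustrative; `p` stands for the confidence assigned to the correct class:

```python
import numpy as np

# Compare how each candidate error function reacts as confidence drops
for p in (0.9, 0.5, 0.1):
    simple = 1 - p            # linear error
    squared = (1 - p) ** 2    # squared error
    log_loss = -np.log(p)     # cross-entropy with a one-hot target
    print(f"p={p}: 1-x={simple:.2f}, (1-x)^2={squared:.2f}, -log(x)={log_loss:.2f}")
```

The first two functions are bounded by 1, while −log(x) grows without bound as the confidence approaches zero, which is exactly the harsher punishment we want.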
One of the best parts of Categorical Cross Entropy is its relationship with the Softmax function. For reasons that will be explained later, we need to be able to find the derivatives of each function we use. Now, Categorical Cross Entropy’s derivative on its own is absolutely hideous. You can read about deriving it here. However, when we take the derivative of it combined with the Softmax, we get a beautifully simple function, which we will go over later.
Gradient descent
The “learning” in machine learning is the process of tweaking our weights such that our error function will decrease. Essentially, we want to find the minimum of our error function. Let's imagine for a minute that instead of a complex neural net, our function is a simple parabola. If you have taken calculus, you might have some ideas about how to find this minimum. Perhaps we would first find every zero of the function’s first derivative, and then use the second derivative to determine whether those points are minima. While this works very well for our simple parabola, finding every zero of the first derivative is not very practical for complex functions like our neural net. Instead, it’s common to use a strategy called Gradient Descent.
Let's continue with our parabola. First, we begin at a random point along the curve. Then we take the derivative at this location. Now, we take a small step in the negative direction of that derivative. If we repeat this process many times, we will eventually end up at a minimum. To make this more efficient, we can make our step size proportional to the derivative. Even though our network has thousands of parameters (each weight can be thought of as one) while the parabola only has one, gradient descent works just as well. So, our next goal is to figure out how to get the derivative of our error function with respect to each weight.
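The steps above can be sketched in a few lines. This example minimizes a hypothetical parabola, f(x) = (x − 3)^2, whose minimum we know is at x = 3; the learning rate and step count are arbitrary choices:

```python
def gradient_descent(start, lr=0.1, steps=100):
    """Minimize f(x) = (x - 3)^2 by repeatedly stepping against the derivative."""
    x = start
    for _ in range(steps):
        grad = 2 * (x - 3)  # f'(x), the derivative at our current location
        x -= lr * grad      # a step in the negative direction of the derivative
    return x

print(gradient_descent(start=10.0))  # converges toward the minimum at x = 3
```

Notice that the step `lr * grad` is already proportional to the derivative, so steps shrink automatically as we approach the minimum.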
Chain Rule
Before we can begin our derivations, we need to figure out how to take derivatives of complex functions. If you have taken any level of calculus, you should hopefully be familiar with the chain rule. If not, the rule is simply a means to compute the derivative of a compound function. It looks like this:
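For a compound function f(g(x)), the chain rule reads:

```latex
\frac{d}{dx}\, f(g(x)) = f'(g(x)) \cdot g'(x)
```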
In our context, we use the chain rule to get the derivative of a series of functions. The chain rule will look something like this:
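Sketched with generic symbols (an error E, an activation A, a linear combination Z, and a weight W, which we define formally below), the chained derivative takes roughly this shape:

```latex
\frac{\partial E}{\partial W}
= \frac{\partial E}{\partial A} \cdot \frac{\partial A}{\partial Z} \cdot \frac{\partial Z}{\partial W}
```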
Variables & notation
Let's go over the notation we will use to describe our equations before we go any further.
- X refers to the input data for our network: a matrix in which each row is a flattened image from our batch.
- Y refers to the one-hot encodings of the digits our images depict.
- W refers to a two-dimensional matrix of weights for a given layer.
When working with vectors and matrices, it’s very important to keep their dimensions (or shapes) in mind; otherwise, implementation can become a nightmare. Here are the dimensions of our variables so far:
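As a sketch of those shapes in code, with a placeholder batch size `m` and hidden-layer width `h` (both are assumptions for illustration, not values fixed by this series):

```python
import numpy as np

m = 32  # batch size (an assumption for illustration)
h = 64  # hidden-layer width (also an assumption)

X  = np.zeros((m, 784))  # batch of flattened 28x28 images
Y  = np.zeros((m, 10))   # one-hot labels, one row per image
W1 = np.zeros((784, h))  # input layer -> second layer
W2 = np.zeros((h, h))    # second layer -> third layer
W3 = np.zeros((h, 10))   # third layer -> output layer

# A dot product only works when the inner dimensions match:
assert (X @ W1).shape == (m, h)
assert (((X @ W1) @ W2) @ W3).shape == (m, 10)
```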
We should also formally define the functions that we use during the derivations.
- Z refers to the linear combination (dot product) of our weights and inputs.
- A refers to the activation function’s output for a given layer.
- E is our error function: Categorical Cross-entropy.
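One consistent way to write these definitions out, using A_0 = X for the input layer, Softmax in place of ReLU on the final layer, and m for the batch size, is:

```latex
Z_i = A_{i-1} W_i,
\qquad
A_i = \mathrm{ReLU}(Z_i),
\qquad
E = -\frac{1}{m} \sum_{n=1}^{m} \sum_{j} Y_{nj} \log\left(A_{nj}\right)
```

where the A in the error function is the final (Softmax) layer's output.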
First steps
Intuitively, it makes sense to begin our derivations with the first layer of weights. Unfortunately, as we propagate through our layers, this becomes incredibly difficult: a single weight may only affect a single node in the subsequent layer, but that node affects every node in the layer after it. As you can see, the derivation quickly becomes infeasible. Instead, we will begin at the end of our network and move towards the front. This is where the name backpropagation comes from, as we are propagating our gradients backwards.
Derivations
Let’s begin at the very end of our network with our error function. Remember, we are using Categorical Cross Entropy, which has the property that its derivative when combined with the Softmax becomes very simple. This essentially means that during backpropagation, we are effectively combining these two functions into one and computing only a single derivative, bypassing the chain rule altogether.
So, our derivation (miraculously) ends up as:
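Written out, with Z_3 the final linear combination and A_3 the Softmax output, the combined Softmax/cross-entropy derivative is simply:

```latex
\frac{\partial E}{\partial Z_3} = A_3 - Y
```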
Now that we have the derivative of the error with respect to our dot product, we should find the derivative of the dot product with respect to the weights.
We now have everything we need to calculate the derivative of the error with respect to the weights, simply by multiplying the partial derivatives together.
You might have realized that this derivative is not just the product of our two derivatives. Because our input is not a single image but a batch, we have to divide the first derivative by the batch size. We also have to transpose the second derivative for the matrix multiplication to work out.
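Putting those pieces together, with m the batch size, the last layer's weight gradient can be written as:

```latex
\frac{\partial E}{\partial W_3} = \frac{1}{m}\, A_2^{\top} \left( A_3 - Y \right)
```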
Derivations (contd.)
OK, we can now compute the derivative of our error function with respect to the last layer of weights. Let’s move on to the next layer. Because we already have the derivative of the error with respect to the last layer's linear combination, we need to find just three more gradients:
The first is the derivative of the third layer’s linear combination with respect to the second layer’s activation:
The second gradient we must find is that of the second layer’s activation with respect to the second layer’s linear combination, which is simply the derivative of the activation function:
The third gradient is that of the second layer’s linear combination with respect to the second layer’s weights:
Finally, we can again multiply these together to find the gradients of the second layer of weights:
We have to do some transposing and an element-wise multiplication (⊙) to make the matrix multiplication work, but these steps use the same concepts we used earlier to derive the last layer’s weight gradients.
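Assembled into one expression, with m again the batch size, the second layer's weight gradient takes this form:

```latex
\frac{\partial E}{\partial W_2}
= \frac{1}{m}\, A_1^{\top} \left[ \left( A_3 - Y \right) W_3^{\top} \odot \mathrm{ReLU}'(Z_2) \right]
```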
The process for finding the third set of gradients is just about the same as for the second, so I won’t include it in this post; however, you are encouraged to derive it yourself.
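As a sanity check on the whole backward pass, here is a sketch of it in NumPy. The layer sizes are placeholder assumptions, and I fold the 1/batch-size factor into the first derivative so every downstream gradient inherits it:

```python
import numpy as np

def relu(z):
    return np.maximum(0, z)

def relu_deriv(z):
    # Derivative of ReLU: 1 where the input was positive, 0 elsewhere
    return (z > 0).astype(float)

def softmax(z):
    exp = np.exp(z - z.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

def backprop(X, Y, W1, W2, W3):
    """Forward pass, then the gradient of the error for every weight matrix."""
    m = X.shape[0]                       # batch size
    Z1 = X @ W1;  A1 = relu(Z1)
    Z2 = A1 @ W2; A2 = relu(Z2)
    Z3 = A2 @ W3; A3 = softmax(Z3)

    dZ3 = (A3 - Y) / m                   # combined Softmax + cross-entropy derivative
    dW3 = A2.T @ dZ3
    dZ2 = (dZ3 @ W3.T) * relu_deriv(Z2)  # chain back through the ReLU
    dW2 = A1.T @ dZ2
    dZ1 = (dZ2 @ W2.T) * relu_deriv(Z1)
    dW1 = X.T @ dZ1
    return dW1, dW2, dW3
```

A quick check that each gradient has the same shape as the weight matrix it updates is a good way to catch transposition mistakes early.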
Updating the network
Now that we have the derivative of the error function with respect to each of our weights, we can enjoy the fruits of our labor and update them as Gradient Descent prescribes. This is very simple, and will look like this:
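For each layer i, the update is the same gradient-descent step we saw on the parabola:

```latex
W_i \leftarrow W_i - \alpha\, \frac{\partial E}{\partial W_i}
```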
The alpha in each of these equations is our learning rate, yet another hyperparameter. The learning rate determines the size of our gradient descent steps. A typical value for this would be 0.1.
Conclusion
That's it! You now (hopefully) understand both forward and backward propagation. All that's left to do is repeat this process using new images in our input batch each time. As you perform more iterations, your network will slowly become more and more accurate.
However, if you have played with my demo at all, you might have noticed the accuracy is still not great. The truth is that basic neural nets have some large issues, as their lack of spatial reasoning generally makes images hard to classify. To fix these issues, Convolutional Neural Networks (CNNs) were created, which will be the topic of a future post! Thanks for reading, and have a great day!