How does automatic differentiation really work?


I have been using PyTorch for almost two years, and I have never felt the need to use its autograd module explicitly. However, calling loss.backward() without any clue of what’s going on inside made me curious, so I decided to do some digging.

This article will help you understand what automatic differentiation is. Although you may see a few equations, let me assure you that nothing more than basic differential calculus will be required.

Let’s explore how automatic differentiation (AD) works.

Why not use the definition of differentiation?

Consider a differentiable function f: R -> R. The derivative of f, by definition, is -

$$ f'(x) = \lim_{h \to 0} \frac{f(x + h) - f(x)}{h} $$

Derivative definition. Original Image

If we want to write a program to compute the derivative of f at some x, can’t we approximate it by choosing a small enough h and using the above expression?

Let’s try an exercise. Consider the simple function f(x) = e^x, whose derivative is f itself. To compute the derivative using the above definition, we take h as small as 1e-9 (1/10^9). In exact arithmetic, the smaller the h, the better the approximation. To measure how accurate the computed derivative is, we visualize the difference between the true derivative (derived analytically) and the computed derivative.

$$ y = f'(x) - \frac{f(x + h) - f(x)}{h} $$

Original Image

In our example, this translates to the following.

$$ y = e^{x} - \frac{e^{x + h} - e^{x}}{h} $$

Example equation. Original Image

A simple Python program and a little plotting with Plotly produce the following plot:

Plot of the difference between the derivative computed using finite differences and the true derivative. Original Image

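The value of y, which represents the error, stays close to zero until about x = 14, after which things start to get a bit shaky. This is caused by floating-point round-off: for a tiny h, f(x + h) and f(x) are nearly equal, so their difference loses most of its significant digits, and dividing by a small number like h amplifies that error. The error also grows with x, because the round-off in e^x is proportional to e^x itself.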
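The original program isn’t shown, so here is a minimal sketch of the same experiment without the Plotly part. The grid of x values (0 to 20 in steps of 0.1) and the function names are our own choices; it computes y = e^x - (e^(x+h) - e^x)/h with h = 1e-9 and reports the worst error on a small-x range and on a large-x range.

```python
import math

H = 1e-9  # step size from the article


def forward_difference(f, x, h=H):
    """Approximate f'(x) with the forward difference (f(x + h) - f(x)) / h."""
    return (f(x + h) - f(x)) / h


def error(x, h=H):
    """Error y = f'(x) - estimate for f(x) = e^x, whose true derivative is e^x."""
    return math.exp(x) - forward_difference(math.exp, x, h)


xs = [i / 10 for i in range(201)]  # x = 0.0, 0.1, ..., 20.0 (our choice of grid)
errors = [error(x) for x in xs]

# Worst absolute error on a small-x range and on a large-x range.
small_x_err = max(abs(e) for x, e in zip(xs, errors) if x <= 5)
large_x_err = max(abs(e) for x, e in zip(xs, errors) if x >= 18)
print(f"max |y| for x <= 5:  {small_x_err:.3e}")
print(f"max |y| for x >= 18: {large_x_err:.3e}")
```

On a typical IEEE 754 double-precision platform, the second number comes out several orders of magnitude larger than the first, which is the same effect the plot shows.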

It is also a costly approach, especially for training neural networks. Suppose you want to compute

$$ \frac{\partial L}{\partial w_i} \approx \frac{L(w_i + h) - L(w_i)}{h} $$

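where L is the loss and wi is some parameter. Each estimate needs two forward passes through the entire network, one with wi and one with wi + h. The unperturbed pass can be shared across parameters, but that still leaves one extra forward pass per parameter. Imagine doing this for every parameter in the network!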
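To make the cost concrete, here is a sketch that estimates a whole gradient with forward differences and counts the loss evaluations. toy_loss is a hypothetical stand-in for a network’s loss, and h = 1e-6 is an arbitrary step chosen for this example.

```python
def finite_difference_gradient(loss, params, h=1e-6):
    """Estimate d(loss)/d(params[i]) for every i with forward differences.

    Returns the estimate and the number of loss evaluations ("forward passes").
    """
    calls = 0

    def counted_loss(p):
        nonlocal calls
        calls += 1
        return loss(p)

    base = counted_loss(params)        # one shared pass with unperturbed parameters
    grad = []
    for i in range(len(params)):
        shifted = list(params)
        shifted[i] += h                # perturb only parameter i
        grad.append((counted_loss(shifted) - base) / h)
    return grad, calls


def toy_loss(w):
    """Hypothetical stand-in for a network loss: the sum of squared parameters."""
    return sum(wi * wi for wi in w)


w = [0.5, -1.0, 2.0, 3.0]
grad, calls = finite_difference_gradient(toy_loss, w)
print(grad)   # approximately [1.0, -2.0, 4.0, 6.0], the exact gradient 2 * w
print(calls)  # 5, i.e. len(w) + 1
```

With n parameters this needs n + 1 forward passes for a single gradient, which is hopeless for a network with millions of parameters.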

Symbolic Differentiation

Besides being unreliable, the previous approach only needed the function’s outputs at particular inputs x and x + h. It treated the function as a black box and ignored its structure and the variables involved.

In symbolic differentiation, the aim is to take a mathematical expression and return a mathematical expression for the derivative.

Automatic Differentiation, Roger Grosse

To understand symbolic differentiation, consider the following example -

credits - [1]

The gradient of the above function can be represented by -

Expansion of the equation as used by symbolic differentiation. Credits - [1]

Assume that some software generates the above expression. All that is left to do is plug in the values of the xi and get the derivative. However, notice that this representation takes an enormous amount of space, and most of it consists of repeated sub-expressions. Not very efficient, is it?
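A tiny symbolic differentiator makes the blow-up visible. This is an illustration under our own assumptions, not the function from the figure: expressions are nested tuples, the textbook rules are applied literally with no simplification, and the test function is the recurrence l1 = x, l(k+1) = 4·lk·(1 - lk), chosen here because each step reuses the previous expression twice.

```python
# Expressions are nested tuples ("add" | "sub" | "mul", left, right),
# the variable "x", or a numeric constant.

def diff(e):
    """Symbolic d/dx of e, applying the rules literally with no simplification."""
    if e == "x":
        return 1
    if isinstance(e, (int, float)):
        return 0
    op, a, b = e
    if op in ("add", "sub"):
        return (op, diff(a), diff(b))
    if op == "mul":  # product rule: (a*b)' = a'*b + a*b'
        return ("add", ("mul", diff(a), b), ("mul", a, diff(b)))
    raise ValueError(f"unknown operator: {op}")


def size(e):
    """Number of nodes when the expression is written out as a tree."""
    if isinstance(e, tuple):
        return 1 + size(e[1]) + size(e[2])
    return 1


def evaluate(e, x):
    """Numeric value of expression e at the given x."""
    if e == "x":
        return x
    if isinstance(e, (int, float)):
        return e
    op, a, b = e
    left, right = evaluate(a, x), evaluate(b, x)
    if op == "add":
        return left + right
    if op == "sub":
        return left - right
    return left * right


def logistic(n):
    """l_1 = x, l_{k+1} = 4 * l_k * (1 - l_k), built as an expression tree."""
    e = "x"
    for _ in range(n - 1):
        e = ("mul", ("mul", 4, e), ("sub", 1, e))
    return e


for n in range(1, 6):
    f = logistic(n)
    print(f"n={n}: expression size {size(f)}, derivative size {size(diff(f))}")
```

For n = 1 to 5 this should report expression sizes 1, 7, 19, 43 and 91 against derivative sizes 1, 19, 73, 217 and 577: the product rule copies whole sub-expressions, so the derivative grows much faster than the function itself.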

Although this approach has its problems, performing it consistently and efficiently, by computing each repeated sub-expression once instead of expanding it again and again, brings us close to automatic differentiation.

Automatic Differentiation (AD)

After looking at two naive methods to differentiate a function, we now step into what works in the real world. To begin with, we need to understand computation graphs.

Consider the following normal distribution equation.

Our objective is to differentiate y with respect to its inputs x1 and x2. AD splits this equation into a series of basic low-level operations (operations such as +, -, and exp() that cannot be simplified further).

Example equation

If you use PyTorch, you can compute the derivative within 2-3 lines as follows

We took the example input values x1=2 and x2=3. y.backward() computes the gradients. Original image.

x1.grad is the derivative of y with respect to x1.

What’s happening under the hood?

PyTorch knows that x1 and x2 are the inputs whose gradients we want (tensors that require gradients). As the code runs, autograd records every operation applied to them, rather than scanning the code in advance. The result is a graph of computations that looks like this -

The computational graph. Original image.

You might already get the gist of the figure, but let’s walk through it. At the bottom, we have the inputs x1 and x2, also called the leaf variables. As we move up, each node applies a basic operation to the inputs of that node; basic operations are ones that cannot be reduced to even more fundamental operations. After all the operations, we arrive at the final output at the top.

V-1 (= x1) and V0 (= x2) are the independent variables, and V1, V2, …, V5 are intermediate variables. None of the Vi appear in the above code snippet; AD maintains and evaluates them internally. Finally, the target variable V6 (= y) is at the top of the graph. Each node is associated with a variable Vi, which can be evaluated once the inputs to that node are known.
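PyTorch’s real machinery is far more elaborate, but the core idea of recording operations while they execute can be sketched with a hypothetical Node class that appends every result to a tape. Since the equation image isn’t reproduced, the example function is an assumption: y = [sin(x1/x2) + x1/x2 - exp(x2)] · [x1/x2 - exp(x2)], a standard textbook example whose structure matches the graph described here (V6 = V5 * V4, with V2 and V4 both computed from V1). The inputs 1.5 and 0.5 are the ones used in the forward-mode walkthrough.

```python
import math


class Node:
    """A value recorded in the computation graph (hypothetical mini-tracer)."""

    tape = []  # shared list of every node, in the order it was created

    def __init__(self, value, op="leaf", parents=()):
        self.value = value
        self.op = op
        self.parents = parents
        self.name = f"V{len(Node.tape) - 1}"  # V-1, V0, V1, ... in creation order
        Node.tape.append(self)

    # Each overloaded operator computes the result AND records how it was made.
    def __truediv__(self, other):
        return Node(self.value / other.value, "div", (self, other))

    def __sub__(self, other):
        return Node(self.value - other.value, "sub", (self, other))

    def __add__(self, other):
        return Node(self.value + other.value, "add", (self, other))

    def __mul__(self, other):
        return Node(self.value * other.value, "mul", (self, other))


def sin(a):
    return Node(math.sin(a.value), "sin", (a,))


def exp(a):
    return Node(math.exp(a.value), "exp", (a,))


# User code: no Vi appears here, yet every step is recorded as it runs.
x1 = Node(1.5)
x2 = Node(0.5)
u = x1 / x2
s = sin(u)
w = u - exp(x2)
y = (s + w) * w

for node in Node.tape:
    inputs = ", ".join(p.name for p in node.parents) or "-"
    print(f"{node.name:>3} = {node.op}({inputs}) -> {node.value:.4f}")
```

The names V-1, V0, …, V6 are assigned in creation order, so the intermediate variables show up in the recorded graph even though the user code never names them.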

We know how to differentiate each basic operation: the derivative of sin(x) at any point x is cos(x), and the derivative of exp(x) is exp(x). Hence, once we have our graph, computing derivatives comes down to combining these simple local derivatives with the chain rule.


You have the graph, now what? Instead of explaining the theory, I’m going to continue with the same example. For now, know that there are two ways to perform AD, forward mode and reverse mode. We’re going to look at the forward mode first.


Forward Mode

Forward mode applies the chain rule step by step, in the same order in which the function is evaluated, differentiating each intermediate variable until it reaches the derivative of interest. In the above example, our goal is to compute

$$ \frac{\partial y}{\partial x_1} $$

Assume that we are evaluating the above derivative at x1 = 1.5 and x2 = 0.5. For simplicity, we won’t deal with the derivative with respect to x2 for now. First, let us evaluate all the Vi.

Evaluation trace. Credits - [1]

If any step is unclear, refer to the graph above.

Let us define the following notation

Credits - [1]

To evaluate a node in the graph, we needed values of the inputs to the node. Similarly, to evaluate the partial derivative (PD) of the node in the graph, we need the PDs (and the values) of the inputs to the node. Once we have that, our job is done.

Example snippet of how derivatives would be computed in forward mode. Original image.


Using this logic, we evaluate the PDs for each Vi

Credits - [1]

Evaluating the PDs of the independent variables x1 and x2 isn’t hard: with respect to x1, the PD of x1 is 1 and the PD of x2 is 0. For the intermediate variables V1, …, V5, we follow basic differentiation rules. Let’s evaluate V1 as an example:

Evaluating V1. Original Image
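The figure isn’t reproduced, so here is the V1 step worked out under an assumption: V1 = V-1 / V0, as in the standard textbook example y = [sin(x1/x2) + x1/x2 - exp(x2)] · [x1/x2 - exp(x2)], whose structure matches this graph. At x1 = 1.5 and x2 = 0.5 the quotient rule gives:

$$
\frac{\partial V_1}{\partial x_1}
= \frac{\partial}{\partial x_1}\left(\frac{V_{-1}}{V_0}\right)
= \frac{1}{V_0}\,\frac{\partial V_{-1}}{\partial x_1} - \frac{V_{-1}}{V_0^{2}}\,\frac{\partial V_0}{\partial x_1}
= \frac{1}{0.5}\cdot 1 - \frac{1.5}{0.5^{2}}\cdot 0
= 2
$$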

When we reach a Vi, its derivative depends only on its predecessors (the inputs to its node), whose values and derivatives we have already calculated. Evaluating the nodes in this order guarantees that we have everything we need to calculate the PD of Vi. And that’s all you need to know about forward mode!

It is called ‘forward’ mode because the derivative of each Vi is carried along with Vi itself, in the same forward pass through the graph.
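Since the trace images aren’t reproduced, this sketch assumes the example y = [sin(x1/x2) + x1/x2 - exp(x2)] · [x1/x2 - exp(x2)], a standard textbook example consistent with the graph in the text (V6 = V5 * V4, with V2 and V4 both computed from V1). Each step computes a value and its tangent (its derivative with respect to the seeded input) together, which is exactly the "carried along" idea.

```python
import math


def forward_mode(x1, x2, dx1=1.0, dx2=0.0):
    """Forward-mode AD on the assumed example function.

    Every step computes a value Vi and its tangent dVi (the derivative of Vi
    with respect to the seeded input). Seeding (dx1, dx2) = (1, 0) gives
    derivatives with respect to x1.
    """
    v_m1, d_m1 = x1, dx1                          # V-1 = x1
    v0, d0 = x2, dx2                              # V0  = x2
    v1 = v_m1 / v0                                # V1  = V-1 / V0
    d1 = (d_m1 - v1 * d0) / v0                    #     quotient rule
    v2, d2 = math.sin(v1), math.cos(v1) * d1      # V2  = sin(V1)
    v3, d3 = math.exp(v0), math.exp(v0) * d0      # V3  = exp(V0)
    v4, d4 = v1 - v3, d1 - d3                     # V4  = V1 - V3
    v5, d5 = v2 + v4, d2 + d4                     # V5  = V2 + V4
    v6, d6 = v5 * v4, d5 * v4 + v5 * d4           # V6  = V5 * V4, product rule
    return v6, d6                                 # y and dy/d(seeded input)


y, dy_dx1 = forward_mode(1.5, 0.5)                # one sweep for x1
_, dy_dx2 = forward_mode(1.5, 0.5, 0.0, 1.0)      # a second sweep for x2
print(f"y = {y:.4f}, dy/dx1 = {dy_dx1:.4f}, dy/dx2 = {dy_dx2:.4f}")
```

Seeding (dx1, dx2) = (1, 0) gives ∂y/∂x1, about 3.0118 at (1.5, 0.5). Getting ∂y/∂x2 (about -13.724) takes a second sweep seeded with (0, 1).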

Reverse Mode

In forward mode, we computed the PD of every Vi with respect to the input x1. In reverse mode, we compute the derivative of the output variable y with respect to every Vi, both the intermediate and the input variables.

Difference between forward and reverse mode. Original image.

We start by computing the derivative of y with respect to V6, which equals 1 because V6 = y. For any other Vi, the derivative is calculated as follows.

$$ \bar{V}_i = \frac{\partial y}{\partial V_i} = \sum_{j \in \text{children}(i)} \bar{V}_j \, \frac{\partial V_j}{\partial V_i} $$

Computing the PD in reverse mode. Original image.

For a variable Vi, the children of Vi are the nodes Vj for which a directed edge from Vi to Vj exists. In our example, V6 is a child of V5, V2 and V4 are children of V1, and so on. With this definition, the PD calculation becomes pleasantly mechanical. For example, let’s compute the derivative of y with respect to V5.

$$ \bar{V}_5 = \bar{V}_6 \, \frac{\partial V_6}{\partial V_5} = \bar{V}_6 \, V_4 = V_4 $$

Original image

V5 has only one child, V6, so V5 bar follows directly from the equation V6 = V5 * V4. As an example of a node with multiple children, consider the evaluation of V4 bar:

Evaluation of V4 bar. Original image

Pretty straightforward, right?
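Here is the whole reverse sweep as a sketch. As the figures aren’t reproduced, it assumes the example y = [sin(x1/x2) + x1/x2 - exp(x2)] · [x1/x2 - exp(x2)], a standard textbook example consistent with the graph in the text; under that assumption V4’s two children are V6 and V5. The variable b_i is our shorthand for Vi bar.

```python
import math


def reverse_mode(x1, x2):
    """Reverse-mode AD on the assumed example function.

    Returns y and the adjoints (dy/dx1, dy/dx2). b_i stands for Vi bar = dy/dVi.
    """
    # Forward pass: evaluate and store every intermediate value.
    v_m1, v0 = x1, x2
    v1 = v_m1 / v0          # V1 = V-1 / V0
    v2 = math.sin(v1)       # V2 = sin(V1)
    v3 = math.exp(v0)       # V3 = exp(V0)
    v4 = v1 - v3            # V4 = V1 - V3
    v5 = v2 + v4            # V5 = V2 + V4
    v6 = v5 * v4            # V6 = V5 * V4 = y

    # Reverse sweep: Vi bar = sum over children Vj of (Vj bar * dVj/dVi).
    b6 = 1.0                               # dy/dV6 = 1 because V6 = y
    b5 = b6 * v4                           # only child: V6 = V5 * V4
    b4 = b6 * v5 + b5 * 1.0                # children: V6 = V5 * V4 and V5 = V2 + V4
    b3 = b4 * -1.0                         # only child: V4 = V1 - V3
    b2 = b5 * 1.0                          # only child: V5 = V2 + V4
    b1 = b4 * 1.0 + b2 * math.cos(v1)      # children: V4 = V1 - V3 and V2 = sin(V1)
    b0 = b3 * v3 + b1 * (-v_m1 / v0 ** 2)  # children: V3 = exp(V0) and V1 = V-1 / V0
    b_m1 = b1 / v0                         # only child: V1 = V-1 / V0
    return v6, (b_m1, b0)


y, (dy_dx1, dy_dx2) = reverse_mode(1.5, 0.5)
print(f"y = {y:.4f}, dy/dx1 = {dy_dx1:.4f}, dy/dx2 = {dy_dx2:.4f}")
```

One reverse sweep returns both ∂y/∂x1 ≈ 3.0118 and ∂y/∂x2 ≈ -13.724, the same values one forward-mode sweep per input would give. Note how b4 sums contributions from both of V4’s children.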

Which one is better?

If both methods achieve the same result, which one to choose?

Remember that in forward mode, one entire sweep through the graph gave us the derivatives of every Vi with respect to a single input, x1; the derivatives with respect to x2 would need another sweep. In reverse mode, a single sweep gave us the derivatives of the dependent variable y with respect to every Vi, including all the inputs.

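In general, forward mode needs one sweep per independent variable, while reverse mode needs one sweep per dependent variable. So if we have more independent variables than dependent variables, reverse mode is more efficient. Conversely, forward mode is more efficient if we have more dependent variables than independent variables.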
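A small sketch of the sweep-counting argument for forward mode. It uses dual numbers, a common way to implement forward mode (our choice here, a different presentation from the explicit trace used in the article), and a hypothetical four-input function, the product of its inputs.

```python
class Dual:
    """Dual number value + tangent*eps with eps**2 = 0 (forward-mode AD)."""

    def __init__(self, value, tangent=0.0):
        self.value = value
        self.tangent = tangent

    def __mul__(self, other):
        if not isinstance(other, Dual):
            other = Dual(other)  # constants have zero tangent
        # Product rule carried along with the value.
        return Dual(self.value * other.value,
                    self.tangent * other.value + self.value * other.tangent)

    __rmul__ = __mul__


def product(xs):
    """Hypothetical scalar function of n inputs: x0 * x1 * ... * x(n-1)."""
    result = 1.0
    for x in xs:
        result = result * x
    return result


def forward_gradient(f, xs):
    """Gradient of a scalar f: R^n -> R by forward mode, one sweep per input."""
    grad, sweeps = [], 0
    for i in range(len(xs)):
        # Seed a tangent of 1 on input i and 0 on all others, then evaluate once.
        seeded = [Dual(x, 1.0 if j == i else 0.0) for j, x in enumerate(xs)]
        grad.append(f(seeded).tangent)
        sweeps += 1
    return grad, sweeps


grad, sweeps = forward_gradient(product, [1.0, 2.0, 3.0, 4.0])
print(grad, sweeps)
```

For this four-input function, forward mode needs four sweeps to assemble the gradient [24.0, 12.0, 8.0, 6.0], whereas reverse mode would get all four partial derivatives from a single backward sweep after one evaluation pass.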

In deep learning, the number of independent variables (the network’s parameters) is usually much larger than the number of dependent variables (typically a single scalar loss). Hence, using reverse mode makes more sense, and PyTorch uses reverse-mode AD.

Conclusion

If you have stuck with me so far, I hope it was worth your time. There is more to know about AD, such as how it works for matrices and how memory is managed. If you are still hungry to know more, check out the references below.

Lastly, if you appreciate my efforts, make sure you subscribe to my blog.
