Recursion occurs when a function call causes that same function to be called again before the original function call terminates. For example, consider the well-known mathematical expression x! (i.e. the factorial operation). The factorial operation is defined for all nonnegative integers as follows:

    0! = 1
    n! = n * (n - 1)!    for n > 0

In Python, a naïve implementation of the factorial operation can be defined as a function as follows:

    def factorial(n):
        if n == 0:
            return 1
        else:
            return n * factorial(n - 1)

Recursive functions can be difficult to grasp sometimes, so let's walk through this step by step. Consider the expression factorial(3). This and all function calls create a new environment. An environment is basically just a table that maps identifiers (e.g. n, factorial, print, etc.) to their corresponding values. At any point in time, you can inspect the local names of the current environment using locals(). In the first function call, the only local variable that gets defined is n = 3. Therefore, printing locals() would show {'n': 3}. Since n == 3 (and not 0), the return value becomes n * factorial(n - 1).
This next step is where things might get a little confusing. Looking at our new expression, we already know what n is. However, we don't yet know what factorial(n - 1) is. First, n - 1 evaluates to 2. Then, 2 is passed to factorial as the value for n. Since this is a new function call, a second environment is created to store this new n. Let A be the first environment and B be the second environment. A still exists and equals {'n': 3}; however, B (which equals {'n': 2}) is the current environment. Looking at the function body, the return value is, again, n * factorial(n - 1). Without evaluating this expression, let's substitute it into the original return expression. By doing this, we're mentally discarding B, so remember to substitute n accordingly (i.e. references to B's n are replaced with n - 1, which uses A's n). Now, the original return expression becomes n * ((n - 1) * factorial((n - 1) - 1)). Take a second to ensure that you understand why this is so.
Now, let's evaluate the factorial((n - 1) - 1) portion of that. Since A's n == 3, we're passing 1 into factorial. Therefore, we are creating a new environment C, which equals {'n': 1}. Again, the return value is n * factorial(n - 1). So let's replace factorial((n - 1) - 1) in the “original” return expression, similarly to how we adjusted the original return expression earlier. The “original” expression is now n * ((n - 1) * ((n - 2) * factorial((n - 2) - 1))).
Almost done. Now, we need to evaluate factorial((n - 2) - 1). This time, we're passing in 0. Therefore, this evaluates to 1. Now, let's perform our last substitution. The “original” return expression is now n * ((n - 1) * ((n - 2) * 1)). Recalling that the original return expression is evaluated under A, the expression becomes 3 * ((3 - 1) * ((3 - 2) * 1)). This, of course, evaluates to 6. To confirm that this is the correct answer, recall that 3! == 3 * 2 * 1 == 6. Before reading any further, be sure that you fully understand the concept of environments and how they apply to recursion.
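The walkthrough above can be checked with a small sketch that records each call's local environment. The names traced_factorial and environments are hypothetical and exist only for this illustration; the function body is otherwise the same as factorial.

```python
environments = []  # hypothetical log of every call's local environment


def traced_factorial(n):
    # Snapshot the locals of this call; at this point only n is defined.
    environments.append(dict(locals()))
    if n == 0:
        return 1
    else:
        return n * traced_factorial(n - 1)


result = traced_factorial(3)
print(result)        # 6
print(environments)  # [{'n': 3}, {'n': 2}, {'n': 1}, {'n': 0}]
```

The first three entries correspond to the environments A, B and C from the walkthrough; the fourth belongs to the call that hits n == 0.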
The statement if n == 0: return 1 is called a base case, because it exhibits no recursion. A base case is absolutely required. Without one, you'll run into infinite recursion. With that said, as long as you have at least one base case, you can have as many cases as you want. For example, we could have equivalently written factorial as follows:

    def factorial(n):
        if n == 0:
            return 1
        elif n == 1:
            return 1
        else:
            return n * factorial(n - 1)

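To see why a base case is required, here is a minimal sketch of what happens without one. The function broken_factorial is a hypothetical, deliberately incorrect variant; in CPython the runaway recursion is cut off with a RecursionError once the interpreter's recursion limit is exceeded.

```python
def broken_factorial(n):
    # No base case: every call makes another call, so this never returns.
    return n * broken_factorial(n - 1)


try:
    broken_factorial(3)
except RecursionError as exc:
    outcome = f"RecursionError: {exc}"
else:
    outcome = "returned normally"

print(outcome)
```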
You may also have multiple recursive cases, but we won't get into that since it's relatively uncommon and is often difficult to mentally process.
You can also have “parallel” recursive function calls. For example, consider the Fibonacci sequence, which is defined as follows:

    fib(0) = 0
    fib(1) = 1
    fib(n) = fib(n - 2) + fib(n - 1)    for n > 1

We can implement this as follows:

    def fib(n):
        if n == 0 or n == 1:
            return n
        else:
            return fib(n - 2) + fib(n - 1)

I won't walk through this function as thoroughly as I did with factorial(3), but the final return value of fib(5) is equivalent to the following (syntactically invalid) expression:

    (
      fib((n - 2) - 2)
      +
      (
        fib(((n - 2) - 1) - 2)
        +
        fib(((n - 2) - 1) - 1)
      )
    )
    +
    (
      (
        fib(((n - 1) - 2) - 2)
        +
        fib(((n - 1) - 2) - 1)
      )
      +
      (
        fib(((n - 1) - 1) - 2)
        +
        (
          fib((((n - 1) - 1) - 1) - 2)
          +
          fib((((n - 1) - 1) - 1) - 1)
        )
      )
    )

This becomes (1 + (0 + 1)) + ((0 + 1) + (1 + (0 + 1))), which of course evaluates to 5.
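Now, let's cover a few more vocabulary terms. A tail call is a function call that is the last operation performed before the function returns: return foo(n - 1) is a tail call, but return foo(n - 1) + 1 is not (since the addition is the last operation). Tail call optimization (TCO) means that the language implementation reuses the current call's environment for a tail call instead of creating a new one. This is helpful because it keeps memory usage constant and avoids exceeding the maximum recursion depth.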
Python has no form of TCO implemented, for a number of reasons. Therefore, other techniques are required to skirt this limitation. The method of choice depends on the use case. With some intuition, the definitions of factorial and fib can relatively easily be converted to iterative code as follows:

    def factorial(n):
        product = 1
        while n > 1:
            product *= n
            n -= 1
        return product

    def fib(n):
        a, b = 0, 1
        while n > 0:
            a, b = b, a + b
            n -= 1
        return a

This is usually the most efficient way to manually eliminate recursion, but it can become rather difficult for more complex functions.
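One way to see where such a loop comes from is to first rewrite the function so that the recursive call is the very last operation, and then replace that call with a rebinding of the variables. The sketch below does this for factorial; the names factorial_tail, factorial_loop and the accumulator parameter acc are hypothetical and used only for illustration.

```python
def factorial_tail(n, acc=1):
    """Tail-recursive factorial; acc carries the product computed so far."""
    if n <= 1:
        return acc
    # Tail call: nothing remains to be done after this call returns.
    return factorial_tail(n - 1, acc * n)


def factorial_loop(n):
    """The same computation, with the tail call replaced by rebinding n and acc."""
    acc = 1
    while n > 1:
        n, acc = n - 1, acc * n
    return acc


print(factorial_tail(5))  # 120
print(factorial_loop(5))  # 120
```

Since Python performs no tail call optimization, factorial_tail still creates one environment per call and is subject to the recursion limit for large n, whereas factorial_loop is not.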
Another useful tool is Python's functools.lru_cache decorator, which caches the results of previous calls and can therefore be used to reduce the number of redundant calculations.
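As a minimal sketch of this, the recursive Fibonacci definition can be wrapped with functools.lru_cache so that each distinct argument is computed only once. The name cached_fib is hypothetical; maxsize=None is one possible setting that lets the cache grow without bound.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def cached_fib(n):
    if n == 0 or n == 1:
        return n
    return cached_fib(n - 2) + cached_fib(n - 1)


print(cached_fib(30))           # 832040
print(cached_fib.cache_info())  # hit/miss statistics for the cache
```

Note that the function is still recursive, so very large arguments can still exceed the recursion limit on a cold cache.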
You now have an idea as to how to avoid recursion in Python, but when should you use recursion? The answer is “not often”. All recursive functions can be implemented iteratively; it's simply a matter of figuring out how to do so. However, there are rare cases in which recursion is okay. Recursion is commonly used in Python when the expected inputs wouldn't cause a significant number of recursive function calls.
If recursion is a topic that interests you, I implore you to study functional languages such as Scheme or Haskell. In such languages, recursion is much more useful.
Please note that the above example for the Fibonacci sequence, although good at showing how to apply the definition in Python and the later use of lru_cache, has an inefficient running time, since it makes two recursive calls for each non-base case. The number of calls to the function grows exponentially with n.
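The growth in the number of calls can be observed directly with a small counting sketch. The names call_count, counted_fib and calls_for are hypothetical helpers introduced only for this illustration.

```python
call_count = 0  # hypothetical global counter, for illustration only


def counted_fib(n):
    global call_count
    call_count += 1
    if n == 0 or n == 1:
        return n
    return counted_fib(n - 2) + counted_fib(n - 1)


def calls_for(n):
    """Return how many calls a single evaluation of counted_fib(n) makes."""
    global call_count
    call_count = 0
    counted_fib(n)
    return call_count


for n in (5, 10, 20):
    print(n, calls_for(n))
# 5 15
# 10 177
# 20 21891
```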
Rather non-intuitively, a more efficient implementation would use linear recursion:

    def fib(n):
        if n <= 1:
            return (n, 0)
        else:
            (a, b) = fib(n - 1)
            return (a + b, a)

But that one has the issue of returning a pair of numbers. This emphasizes that some functions really do not gain much from recursion.
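One possible way around the pair-returning interface, sketched here under the assumption that a thin wrapper is acceptable, is to keep the pair as an internal detail. The names _fib_pair and fib_linear are hypothetical.

```python
def _fib_pair(n):
    """Return (fib(n), fib(n - 1)); for n <= 1 the second element is just 0."""
    if n <= 1:
        return (n, 0)
    a, b = _fib_pair(n - 1)
    return (a + b, a)


def fib_linear(n):
    """Expose only the Fibonacci number itself, hiding the pair."""
    return _fib_pair(n)[0]


print([fib_linear(i) for i in range(8)])  # [0, 1, 1, 2, 3, 5, 8, 13]
```

This makes one recursive call per step, but it is still bounded by the recursion limit, so the iterative version remains the more practical choice in Python.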