# Sibling Tail Call Optimization in Python

In *Tail Recursion In Python*, Chris Penner implements self tail-call optimization (TCO) in Python using a function decorator. Here I extend the approach to sibling calls.

## Problem

The running example is a functional-style factorial function, written with tail recursion:

```python
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    else:
        return factorial(n-1, accumulator * n)
```

The Python interpreter does not implement tail-call optimization, so calling `factorial(2000)` overflows the stack (CPython's default recursion limit is 1000 frames):

```
Traceback (most recent call last):
  File "plain.py", line 10, in <module>
    print(factorial(2000))
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  [..]
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
RuntimeError: maximum recursion depth exceeded
```
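On Python 3.5 and later the same overflow raises `RecursionError`, a subclass of `RuntimeError`. Conceptually, TCO would turn the tail call into a jump back to the top of the function, i.e. a loop that rebinds the parameters. A minimal sketch of both behaviours; `factorial_loop` is a hypothetical hand-written equivalent, not part of the original code:

```python
import math
import sys


def factorial(n, accumulator=1):
    # Tail-recursive, but every call still pushes a new Python frame.
    if n == 0:
        return accumulator
    return factorial(n - 1, accumulator * n)


def factorial_loop(n, accumulator=1):
    # Roughly what TCO would produce: rebind the parameters and loop
    # instead of pushing a new frame for the tail call.
    while n != 0:
        n, accumulator = n - 1, accumulator * n
    return accumulator


n = 2 * sys.getrecursionlimit()  # comfortably past the recursion limit
try:
    factorial(n)
    overflowed = False
except RecursionError:  # subclass of RuntimeError on Python 3.5+
    overflowed = True

print(overflowed)  # True
print(factorial_loop(n) == math.factorial(n))  # True
```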

## Original Solution

Chris Penner implements tail-call optimization using a function decorator (`tail_recursive`):

```python
# Normal recursion depth maxes out at 980, this one works indefinitely
@tail_recursive
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    recurse(n-1, accumulator=accumulator*n)
```

Here the `recurse` function triggers the self tail-call (`factorial` calling itself).

The implementation is:

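```python
class Recurse(Exception):
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


def recurse(*args, **kwargs):
    # Abort the current call; the wrapper will call again with these arguments.
    raise Recurse(*args, **kwargs)


def tail_recursive(f):
    def decorated(*args, **kwargs):
        while True:
            try:
                return f(*args, **kwargs)
            except Recurse as r:
                # The previous call has already unwound back to this loop:
                # re-invoke f with the arguments carried by the exception.
                args = r.args
                kwargs = r.kwargs
                continue
    return decorated
```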

This works by wrapping the original `factorial` function. The `recurse` function raises an exception, which is caught by the wrapper; the wrapper then calls `factorial` again with the new arguments. Because the exception has already unwound the previous call, the stack does not grow.
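One way to observe that the stack stays flat is to record the stack depth on every pass. This sketch repeats the decorator above and adds a hypothetical `stack_depth` helper (walking frames from `inspect.currentframe()`, assuming CPython) and a hypothetical `depths` list; neither is part of the original code:

```python
import inspect
import math


class Recurse(Exception):
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


def recurse(*args, **kwargs):
    raise Recurse(*args, **kwargs)


def tail_recursive(f):
    def decorated(*args, **kwargs):
        while True:
            try:
                return f(*args, **kwargs)
            except Recurse as r:
                args, kwargs = r.args, r.kwargs
    return decorated


def stack_depth():
    # Hypothetical helper: count the frames above our caller (CPython only).
    frame = inspect.currentframe().f_back
    depth = 0
    while frame is not None:
        depth += 1
        frame = frame.f_back
    return depth


depths = []  # hypothetical instrumentation: one entry per pass


@tail_recursive
def factorial(n, accumulator=1):
    depths.append(stack_depth())
    if n == 0:
        return accumulator
    recurse(n - 1, accumulator=accumulator * n)


print(factorial(5000) == math.factorial(5000))  # True
print(len(set(depths)))  # 1: every pass runs at the same depth
```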

## Limitations

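One limitation of this approach is that it only supports self tail-calls (a function calling itself), not sibling tail-calls (e.g. function `a` calls function `b`, and `b` calls `a`). The `Recurse` exception only carries the new arguments, not the function to call, so the wrapper can only ever re-invoke the function it decorates.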
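A classic sibling-call example is a pair of mutually recursive parity checks. The names `is_even` and `is_odd` are illustrative and not from the original post. Both calls are in tail position, yet plain Python still overflows the stack, and `recurse` cannot express the call because it has no way to name `is_odd` as the target:

```python
import sys


def is_even(n):
    # Tail call to the *sibling* function, not to itself.
    if n == 0:
        return True
    return is_odd(n - 1)


def is_odd(n):
    if n == 0:
        return False
    return is_even(n - 1)


print(is_even(10))  # True

try:
    is_even(2 * sys.getrecursionlimit())
    overflowed = False
except RecursionError:
    overflowed = True

print(overflowed)  # True
```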

## Sibling-call friendly version (without exceptions)

Instead, we could use something like this:

```python
from tco import tail_recursive, tail_call

# Normal recursion depth maxes out at 980, this one works indefinitely
@tail_recursive
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    return tail_call(factorial)(n-1, accumulator=accumulator*n)

print(factorial(2000))
```

With the implementation:

```python
class TailCall:
    # A pending call: which function to run next, and with which arguments.
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f):
    # Instead of calling f, return a TailCall describing the call.
    def wrapper(*args, **kwargs):
        return TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        while True:
            res = func(*args, **kwargs)
            if not isinstance(res, TailCall):
                return res
            args = res.args
            kwargs = res.kwargs
            # If the target is itself a tail_recursive wrapper, run the
            # undecorated function so the call stays in this loop.
            if hasattr(res.f, "__tc_original__"):
                func = getattr(res.f, "__tc_original__")
            else:
                func = res.f
    wrapper.__tc_original__ = f
    return wrapper
```

This implementation does not use exceptions, so we need to `return` the `TailCall` value. Otherwise it is simply discarded: the function returns `None` and the tail call never happens.
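A minimal sketch of that pitfall, repeating the exception-free implementation in condensed form (`getattr` with a default replaces the `hasattr` check); `broken_factorial` is a hypothetical name used only for illustration:

```python
class TailCall:
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f):
    def wrapper(*args, **kwargs):
        return TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        while True:
            res = func(*args, **kwargs)
            if not isinstance(res, TailCall):
                return res
            args, kwargs = res.args, res.kwargs
            func = getattr(res.f, "__tc_original__", res.f)
    wrapper.__tc_original__ = f
    return wrapper


@tail_recursive
def broken_factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    # Missing `return`: the TailCall is built and immediately discarded,
    # so the function falls off the end and returns None.
    tail_call(broken_factorial)(n - 1, accumulator=accumulator * n)


print(broken_factorial(5))  # None
print(broken_factorial(0))  # 1
```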

With this approach, we can have sibling TCO:

```python
from tco import tail_recursive, tail_call

@tail_recursive
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    return tail_call(factorial2)(n-1, accumulator=accumulator*n)

@tail_recursive
def factorial2(n, accumulator=1):
    if n == 0:
        return accumulator
    return tail_call(factorial)(n-1, accumulator=accumulator*n)

print(factorial(2000))
```
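Two functions that actually differ make the sibling case more visible. This sketch applies the exception-free implementation (repeated in condensed form so the block runs on its own) to the illustrative `is_even`/`is_odd` pair, which overflows the stack in plain Python:

```python
class TailCall:
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f):
    def wrapper(*args, **kwargs):
        return TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        while True:
            res = func(*args, **kwargs)
            if not isinstance(res, TailCall):
                return res
            args, kwargs = res.args, res.kwargs
            func = getattr(res.f, "__tc_original__", res.f)
    wrapper.__tc_original__ = f
    return wrapper


@tail_recursive
def is_even(n):
    if n == 0:
        return True
    # Sibling tail call: the trampoline in is_even's wrapper runs is_odd.
    return tail_call(is_odd)(n - 1)


@tail_recursive
def is_odd(n):
    if n == 0:
        return False
    return tail_call(is_even)(n - 1)


print(is_even(100_000))  # True
print(is_odd(100_001))   # True
```

Whichever function is called first, its wrapper runs the whole chain: each `TailCall` target is unwrapped through `__tc_original__`, so the other wrapper's loop is never entered.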

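I tend to prefer the exception-free approach. However, it might make static type checkers unhappy: each undecorated function now returns either its real result or a `TailCall`, while the decorated function only ever returns the real result.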
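One possible way to describe this to a type checker is to make `TailCall` generic in the final result type `T`: the undecorated function is annotated as returning `Union[T, TailCall[T]]`, and `tail_recursive` promises callers a plain `T`. This is a sketch under assumptions: the generic `TailCall[T]` and all annotations are additions for this example, how well a given checker infers `T` is not guaranteed, and `Callable[..., ...]` gives up on checking the arguments passed through `tail_call`.

```python
from typing import Any, Callable, Generic, TypeVar, Union

T = TypeVar("T")


class TailCall(Generic[T]):
    # A pending call whose eventual result has type T.
    def __init__(self, f: Callable[..., Any], args: tuple, kwargs: dict) -> None:
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f: Callable[..., T]) -> Callable[..., TailCall[T]]:
    def wrapper(*args: Any, **kwargs: Any) -> TailCall[T]:
        return TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f: Callable[..., Union[T, TailCall[T]]]) -> Callable[..., T]:
    def wrapper(*args: Any, **kwargs: Any) -> T:
        func: Callable[..., Any] = f
        while True:
            res = func(*args, **kwargs)
            if not isinstance(res, TailCall):
                return res
            args, kwargs = res.args, res.kwargs
            func = getattr(res.f, "__tc_original__", res.f)
    # setattr avoids a checker error about unknown function attributes.
    setattr(wrapper, "__tc_original__", f)
    return wrapper


@tail_recursive
def factorial(n: int, accumulator: int = 1) -> Union[int, TailCall[int]]:
    if n == 0:
        return accumulator
    return tail_call(factorial)(n - 1, accumulator=accumulator * n)


print(factorial(10))  # 3628800
```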

## Sibling-call friendly version (with exceptions)

Here's the same thing with exceptions:

```python
class TailCall(BaseException):
    # Derives from BaseException so `except Exception` handlers don't catch it.
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f):
    # Unwind the current call and ask the enclosing wrapper to call f.
    def wrapper(*args, **kwargs):
        raise TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        while True:
            try:
                return func(*args, **kwargs)
            except TailCall as e:
                args = e.args
                kwargs = e.kwargs
                # Unwrap decorated targets so the call runs in this loop.
                if hasattr(e.f, "__tc_original__"):
                    func = getattr(e.f, "__tc_original__")
                else:
                    func = e.f
    wrapper.__tc_original__ = f
    return wrapper
```

I'm deriving `TailCall` from `BaseException` instead of `Exception` because the tail-recursive functions might catch `Exception`, which would swallow the `TailCall` and break the TCO mechanism.
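To see the difference, this sketch builds the exception-based decorators twice through a hypothetical `make_tco(base)` factory, once with `TailCall` deriving from `BaseException` and once from `Exception`. It then runs a hypothetical `countdown` function that wraps its tail call in a broad `try`/`except Exception`:

```python
def make_tco(base):
    # Hypothetical factory: the exception-based decorators with a chosen base class.
    class TailCall(base):
        def __init__(self, f, args, kwargs):
            self.f = f
            self.args = args
            self.kwargs = kwargs

    def tail_call(f):
        def wrapper(*args, **kwargs):
            raise TailCall(f, args, kwargs)
        return wrapper

    def tail_recursive(f):
        def wrapper(*args, **kwargs):
            func = f
            while True:
                try:
                    return func(*args, **kwargs)
                except TailCall as e:
                    args, kwargs = e.args, e.kwargs
                    func = getattr(e.f, "__tc_original__", e.f)
        wrapper.__tc_original__ = f
        return wrapper

    return tail_call, tail_recursive


def make_countdown(base):
    tail_call, tail_recursive = make_tco(base)

    @tail_recursive
    def countdown(n):
        try:
            if n == 0:
                return "done"
            tail_call(countdown)(n - 1)
        except Exception:  # broad handler, as user code often has
            return "swallowed"

    return countdown


print(make_countdown(BaseException)(10_000))  # done
print(make_countdown(Exception)(10_000))      # swallowed
```

Only the `BaseException` variant lets the `TailCall` pass through the user's handler to the loop in `tail_recursive`. A bare `except:` or an `except BaseException:` in the user function would still swallow it.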