Sibling tail call optimization in Python
In "Tail Recursion In Python", Chris Penner implements self tail-call optimization (TCO) in Python using a function decorator. Here I extend the approach to sibling calls.
Problem
The running example is a functional-style factorial written with tail recursion:
```python
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    else:
        return factorial(n-1, accumulator * n)
```
The Python interpreter does not implement tail-call optimization, so calling factorial(2000) overflows the stack:
```
Traceback (most recent call last):
  File "plain.py", line 10, in <module>
    print(factorial(2000))
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  [..]
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
RuntimeError: maximum recursion depth exceeded
```
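A quick way to see the limit for yourself, as a minimal sketch assuming CPython 3.5 or later, where the error is RecursionError (a subclass of the RuntimeError shown above). The overflows helper is hypothetical, for illustration only:

```python
import sys


def factorial(n, accumulator=1):
    # Plain tail recursion: CPython still pushes one frame per call.
    if n == 0:
        return accumulator
    return factorial(n - 1, accumulator * n)


def overflows(n):
    """Return True if factorial(n) hits the interpreter's recursion limit."""
    try:
        factorial(n)
    except RecursionError:
        return True
    return False


limit = sys.getrecursionlimit()  # 1000 by default in CPython
print(issubclass(RecursionError, RuntimeError))  # True
print(overflows(100))  # False
print(overflows(limit + 100))  # True
```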
Original Solution
Chris Penner implements tail-call optimization with a function decorator, tail_recursive:
```python
# Normal recursion depth maxes out at 980, this one works indefinitely
@tail_recursive
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    recurse(n-1, accumulator=accumulator*n)
```
Here the recurse function triggers the self tail call (factorial calling itself).
The implementation is:
```python
class Recurse(Exception):
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

def recurse(*args, **kwargs):
    raise Recurse(*args, **kwargs)

def tail_recursive(f):
    def decorated(*args, **kwargs):
        while True:
            try:
                return f(*args, **kwargs)
            except Recurse as r:
                args = r.args
                kwargs = r.kwargs
                continue
    return decorated
```
This works by wrapping the original factorial function. The recurse function raises a Recurse exception, which is caught by the wrapper; the wrapper then calls factorial again with the new arguments. Because the wrapper loops instead of nesting calls, the stack depth stays constant however many iterations run.
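To check that the stack really stays flat, the sketch below reuses the Recurse/recurse/tail_recursive code above unchanged and records the interpreter stack depth on every iteration. The stack_depth helper is hypothetical and relies on sys._getframe, which is a CPython implementation detail:

```python
import math
import sys


class Recurse(Exception):
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


def recurse(*args, **kwargs):
    raise Recurse(*args, **kwargs)


def tail_recursive(f):
    def decorated(*args, **kwargs):
        while True:
            try:
                return f(*args, **kwargs)
            except Recurse as r:
                args = r.args
                kwargs = r.kwargs
                continue
    return decorated


def stack_depth():
    # Hypothetical helper: count frames from the caller up to the outermost one.
    depth, frame = 0, sys._getframe(1)
    while frame is not None:
        depth += 1
        frame = frame.f_back
    return depth


depths = set()  # distinct stack depths observed inside factorial


@tail_recursive
def factorial(n, accumulator=1):
    depths.add(stack_depth())
    if n == 0:
        return accumulator
    recurse(n - 1, accumulator=accumulator * n)


result = factorial(3000)  # well past the default recursion limit
print(len(depths))  # 1: every iteration runs at the same depth
print(result == math.factorial(3000))  # True
```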
Limitations
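One limitation of this approach is that it only supports self tail calls (a function calling itself), not sibling tail calls (e.g. function a calls function b and function b calls function a). The Recurse exception carries only the new arguments, not the function to call, so the wrapper can only restart f.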
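The limitation is easy to reproduce with a hypothetical even/odd pair, reusing the Recurse-based code above unchanged. Since recurse can only restart the function it was raised from, is_even has to reach is_odd through an ordinary call, so every hop goes through another decorated wrapper and adds stack frames:

```python
import sys


class Recurse(Exception):
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


def recurse(*args, **kwargs):
    raise Recurse(*args, **kwargs)


def tail_recursive(f):
    def decorated(*args, **kwargs):
        while True:
            try:
                return f(*args, **kwargs)
            except Recurse as r:
                args = r.args
                kwargs = r.kwargs
                continue
    return decorated


@tail_recursive
def is_even(n):
    if n == 0:
        return True
    # recurse(n - 1) would restart is_even, not is_odd, so a plain call is needed.
    return is_odd(n - 1)


@tail_recursive
def is_odd(n):
    if n == 0:
        return False
    return is_even(n - 1)


def overflows(n):
    """Return True if is_even(n) hits the interpreter's recursion limit."""
    try:
        is_even(n)
    except RecursionError:
        return True
    return False


print(is_even(10))  # True: small inputs still work
print(overflows(sys.getrecursionlimit()))  # True: two frames per hop
```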
Sibling-call friendly version (without exceptions)
Instead, we could use something like this:
```python
from tco import tail_recursive, tail_call

# Normal recursion depth maxes out at 980, this one works indefinitely
@tail_recursive
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    return tail_call(factorial)(n-1, accumulator=accumulator*n)

print(factorial(2000))
```
Here is the implementation of the tco module:
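```python
class TailCall:
    """A pending call: the function to call next and its arguments."""
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs

def tail_call(f):
    # Instead of calling f, return a description of the call.
    def wrapper(*args, **kwargs):
        return TailCall(f, args, kwargs)
    return wrapper

def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        # Trampoline: keep calling until the result is not a TailCall.
        while True:
            res = func(*args, **kwargs)
            if not isinstance(res, TailCall):
                return res
            args = res.args
            kwargs = res.kwargs
            # Call the undecorated function when there is one, so the
            # target does not start a nested trampoline one frame deeper.
            if hasattr(res.f, "__tc_original__"):
                func = getattr(res.f, "__tc_original__")
            else:
                func = res.f
    # Remember the undecorated function for other trampolines.
    wrapper.__tc_original__ = f
    return wrapper
```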
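The __tc_original__ lookup is what keeps the stack flat: the trampoline calls the undecorated function, so a tail call never enters another wrapper's loop. The sketch below makes this visible by toggling the lookup with a hypothetical unwrap flag (make_tail_recursive, build_parity and overflows are illustrative names, not part of the tco module) and running a mutually recursive even/odd pair past the recursion limit:

```python
import sys


class TailCall:
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f):
    def wrapper(*args, **kwargs):
        return TailCall(f, args, kwargs)
    return wrapper


def make_tail_recursive(unwrap):
    # Hypothetical factory: unwrap=True behaves like the tco code above,
    # unwrap=False skips the __tc_original__ lookup.
    def tail_recursive(f):
        def wrapper(*args, **kwargs):
            func = f
            while True:
                res = func(*args, **kwargs)
                if not isinstance(res, TailCall):
                    return res
                args, kwargs = res.args, res.kwargs
                if unwrap:
                    func = getattr(res.f, "__tc_original__", res.f)
                else:
                    # Calls the *decorated* target, which starts its own
                    # nested loop one stack frame deeper.
                    func = res.f
        wrapper.__tc_original__ = f
        return wrapper
    return tail_recursive


def build_parity(tail_recursive):
    # Mutually recursive even/odd pair using the given decorator.
    @tail_recursive
    def is_even(n):
        return True if n == 0 else tail_call(is_odd)(n - 1)

    @tail_recursive
    def is_odd(n):
        return False if n == 0 else tail_call(is_even)(n - 1)

    return is_even


def overflows(is_even, n):
    try:
        is_even(n)
    except RecursionError:
        return True
    return False


n = 2 * sys.getrecursionlimit()
print(overflows(build_parity(make_tail_recursive(unwrap=True)), n))   # False
print(overflows(build_parity(make_tail_recursive(unwrap=False)), n))  # True
```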
This implementation does not use exceptions, so the function must return the TailCall value: if it is not returned, the TailCall object is simply discarded and the function returns None.
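A minimal sketch of what happens when the return is forgotten, using the implementation above with the hasattr/else branches condensed into one getattr call with a default (countdown_ok and countdown_buggy are hypothetical examples):

```python
class TailCall:
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f):
    def wrapper(*args, **kwargs):
        return TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        while True:
            res = func(*args, **kwargs)
            if not isinstance(res, TailCall):
                return res
            args, kwargs = res.args, res.kwargs
            func = getattr(res.f, "__tc_original__", res.f)
    wrapper.__tc_original__ = f
    return wrapper


@tail_recursive
def countdown_ok(n):
    if n == 0:
        return "done"
    return tail_call(countdown_ok)(n - 1)


@tail_recursive
def countdown_buggy(n):
    if n == 0:
        return "done"
    tail_call(countdown_buggy)(n - 1)  # missing return: the TailCall is discarded


print(countdown_ok(3))     # done
print(countdown_buggy(3))  # None
```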
With this approach, we can have sibling TCO:
```python
from tco import tail_recursive, tail_call

@tail_recursive
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    return tail_call(factorial2)(n-1, accumulator=accumulator*n)

@tail_recursive
def factorial2(n, accumulator=1):
    if n == 0:
        return accumulator
    return tail_call(factorial)(n-1, accumulator=accumulator*n)

print(factorial(2000))
```
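Because of the else branch (func = res.f), the target of a tail_call does not even need to be decorated: only the entry point called from ordinary code needs the trampoline. A sketch with a hypothetical even/odd pair, reusing the implementation above in the condensed getattr form:

```python
class TailCall:
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f):
    def wrapper(*args, **kwargs):
        return TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        while True:
            res = func(*args, **kwargs)
            if not isinstance(res, TailCall):
                return res
            args, kwargs = res.args, res.kwargs
            # Same as the hasattr/else branches, condensed.
            func = getattr(res.f, "__tc_original__", res.f)
    wrapper.__tc_original__ = f
    return wrapper


@tail_recursive
def is_even(n):
    if n == 0:
        return True
    return tail_call(is_odd)(n - 1)


def is_odd(n):
    # Deliberately not decorated: is_even's trampoline calls it directly.
    if n == 0:
        return False
    return tail_call(is_even)(n - 1)


print(is_even(100_000))  # True
print(is_even(100_001))  # False
```

The trade-off is that calling is_odd directly from ordinary code returns a TailCall object rather than a boolean, so undecorated helpers should only be reached through tail_call.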
I tend to prefer the exception-free approach. It might make type checkers unhappy, however, since the undecorated function body returns either its real result or a TailCall object.
Sibling-call friendly version (with exceptions)
Here is the same thing with exceptions:
```python
class TailCall(BaseException):
    """Raised to request a tail call; carries the target and its arguments."""
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs

def tail_call(f):
    # Raise instead of calling f; the nearest enclosing
    # tail_recursive wrapper on the stack catches it.
    def wrapper(*args, **kwargs):
        raise TailCall(f, args, kwargs)
    return wrapper

def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        while True:
            try:
                return func(*args, **kwargs)
            except TailCall as e:
                args = e.args
                kwargs = e.kwargs
                # Prefer the undecorated function to avoid a nested wrapper.
                if hasattr(e.f, "__tc_original__"):
                    func = getattr(e.f, "__tc_original__")
                else:
                    func = e.f
    wrapper.__tc_original__ = f
    return wrapper
```
I derive TailCall from BaseException instead of Exception because a tail-recursive function might catch Exception (for example with a broad except Exception: handler), which would swallow the TailCall and break the TCO mechanism.
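The difference is easy to observe. The sketch below builds the exception-based mechanism above twice, once on Exception and once on BaseException (make_tco and build_countdown are hypothetical helpers), and runs a function whose tail call sits inside a broad except Exception handler:

```python
def make_tco(base):
    # Hypothetical factory: the exception-based tco code, parameterized by base class.
    class TailCall(base):
        def __init__(self, f, args, kwargs):
            self.f = f
            self.args = args
            self.kwargs = kwargs

    def tail_call(f):
        def wrapper(*args, **kwargs):
            raise TailCall(f, args, kwargs)
        return wrapper

    def tail_recursive(f):
        def wrapper(*args, **kwargs):
            func = f
            while True:
                try:
                    return func(*args, **kwargs)
                except TailCall as e:
                    args, kwargs = e.args, e.kwargs
                    func = getattr(e.f, "__tc_original__", e.f)
        wrapper.__tc_original__ = f
        return wrapper

    return tail_call, tail_recursive


def build_countdown(base):
    tail_call, tail_recursive = make_tco(base)

    @tail_recursive
    def countdown(n):
        if n == 0:
            return "done"
        try:
            return tail_call(countdown)(n - 1)
        except Exception:
            # A broad handler like this catches TailCall only if it derives from Exception.
            return "swallowed"

    return countdown


print(build_countdown(Exception)(3))      # swallowed
print(build_countdown(BaseException)(3))  # done
```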