Sibling tail call optimization in Python

In "Tail Recursion In Python", Chris Penner implements (self) tail-call optimization (TCO) in Python using a function decorator. Here I extend the approach to sibling calls.

Problem

The running example is a factorial function written in functional, tail-recursive style:

def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    else:
        return factorial(n-1, accumulator * n)

The Python interpreter does not implement tail call optimization, so calling factorial(2000) overflows the stack:

Traceback (most recent call last):
  File "plain.py", line 10, in <module>
    print(factorial(2000))
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
  [..]
  File "plain.py", line 8, in factorial
    return factorial(n-1, accumulator * n)
RuntimeError: maximum recursion depth exceeded
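The failure can also be observed programmatically. This is a minimal sketch, assuming Python 3.5 or later, where the error is raised as RecursionError (a subclass of RuntimeError); the overflows helper is a name I made up for the illustration:

```python
import sys


def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    return factorial(n - 1, accumulator * n)


def overflows(n):
    """Return True if factorial(n) exhausts the interpreter's recursion limit."""
    try:
        factorial(n)
    except RecursionError:  # subclass of RuntimeError since Python 3.5
        return True
    return False


limit = sys.getrecursionlimit()  # configurable, see sys.setrecursionlimit
print(limit, overflows(10), overflows(limit + 100))
```

Raising the limit with sys.setrecursionlimit only moves the problem: each recursive call still consumes a stack frame.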

Original Solution

Chris Penner implements tail call optimization using a function decorator (tail_recursive):

# Normal recursion depth maxes out at 980, this one works indefinitely
@tail_recursive
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    recurse(n-1, accumulator=accumulator*n)

Here the recurse function triggers the self tail call (factorial calls itself).

The implementation is:

class Recurse(Exception):
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

def recurse(*args, **kwargs):
    raise Recurse(*args, **kwargs)

def tail_recursive(f):
    def decorated(*args, **kwargs):
        while True:
            try:
                return f(*args, **kwargs)
            except Recurse as r:
                args = r.args
                kwargs = r.kwargs
                continue
    return decorated

The decorator wraps the original factorial function. The recurse function throws an exception which is caught by the wrapper function; the wrapper then calls the factorial function again with the new arguments, from the same stack depth.
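To check that the stack really does not grow, one can record the stack depth on every invocation. This is a sketch only: the countdown function and the depths set are hypothetical instrumentation that I added, not part of the original post.

```python
import inspect


class Recurse(Exception):
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


def recurse(*args, **kwargs):
    raise Recurse(*args, **kwargs)


def tail_recursive(f):
    def decorated(*args, **kwargs):
        while True:
            try:
                return f(*args, **kwargs)
            except Recurse as r:
                # Unwind to the wrapper, then call f again with the new arguments.
                args = r.args
                kwargs = r.kwargs
    return decorated


depths = set()  # every distinct stack depth seen inside countdown


@tail_recursive
def countdown(n):
    depths.add(len(inspect.stack(0)))  # context=0 skips reading source lines
    if n == 0:
        return "done"
    recurse(n - 1)


result = countdown(5000)
print(result, len(depths))  # a single distinct depth means the stack did not grow
```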

Limitations

One limitation of this approach is that it only allows self tail calls (the function calls itself) but not sibling tail calls (e.g. function a calls function b and function b calls function a).
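The limitation shows up as soon as two decorated functions call each other. In this sketch (is_even and is_odd are hypothetical examples of mine), recurse cannot be used because it can only re-enter the function that raised it, so the sibling call has to be an ordinary call and the stack grows again:

```python
class Recurse(Exception):
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


def recurse(*args, **kwargs):
    raise Recurse(*args, **kwargs)


def tail_recursive(f):
    def decorated(*args, **kwargs):
        while True:
            try:
                return f(*args, **kwargs)
            except Recurse as r:
                args = r.args
                kwargs = r.kwargs
    return decorated


@tail_recursive
def is_even(n):
    if n == 0:
        return True
    return is_odd(n - 1)  # ordinary call: recurse() would re-enter is_even


@tail_recursive
def is_odd(n):
    if n == 0:
        return False
    return is_even(n - 1)  # ordinary call, each step adds stack frames


try:
    is_even(100000)
    overflowed = False
except RecursionError:
    overflowed = True

print(overflowed)  # True
```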

Sibling-call friendly version (without exceptions)

Instead we could use something like this:

from tco import tail_recursive, tail_call


# Normal recursion depth maxes out at 980, this one works indefinitely
@tail_recursive
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    return tail_call(factorial)(n-1, accumulator=accumulator*n)


print(factorial(2000))

With the implementation:

class TailCall:
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f):
    def wrapper(*args, **kwargs):
        return TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        while True:
            res = func(*args, **kwargs)
            if not isinstance(res, TailCall):
                return res
            args = res.args
            kwargs = res.kwargs
            if hasattr(res.f, "__tc_original__"):
                func = getattr(res.f, "__tc_original__")
            else:
                func = res.f
    wrapper.__tc_original__ = f
    return wrapper

This implementation does not use exceptions, so the function must return the TailCall value. If the return is omitted, the TailCall object is simply discarded and the function returns None.
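A small sketch of both cases. The functions total and broken_total are hypothetical examples of mine; the second one forgets the return on purpose:

```python
class TailCall:
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f):
    def wrapper(*args, **kwargs):
        return TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        while True:
            res = func(*args, **kwargs)
            if not isinstance(res, TailCall):
                return res
            args = res.args
            kwargs = res.kwargs
            # Unwrap decorated targets so the trampoline does not nest.
            func = getattr(res.f, "__tc_original__", res.f)
    wrapper.__tc_original__ = f
    return wrapper


@tail_recursive
def total(n, acc=0):
    if n == 0:
        return acc
    return tail_call(total)(n - 1, acc + n)


@tail_recursive
def broken_total(n, acc=0):
    if n == 0:
        return acc
    tail_call(broken_total)(n - 1, acc + n)  # missing return: TailCall is dropped


pending = tail_call(total)(3, 0)  # only describes a call, nothing runs yet
print(type(pending).__name__, pending.args)  # TailCall (3, 0)
print(total(10000))         # 50005000
print(broken_total(10000))  # None
```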

With this approach, we can have sibling TCO:

from tco import tail_recursive, tail_call


@tail_recursive
def factorial(n, accumulator=1):
    if n == 0:
        return accumulator
    return tail_call(factorial2)(n-1, accumulator=accumulator*n)


@tail_recursive
def factorial2(n, accumulator=1):
    if n == 0:
        return accumulator
    return tail_call(factorial)(n-1, accumulator=accumulator*n)


print(factorial(2000))

I tend to like the exception-free approach better. It might make a static type checker unhappy, however, because the decorated function returns either its real result or a TailCall.
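One possible way to annotate this is to declare the undecorated function as returning a Union of its result type and TailCall, and let the decorator strip the TailCall part. This is only a sketch of the idea: I have not run it through a type checker, and the annotations are my assumption, not part of the implementation above.

```python
import math
from typing import Any, Callable, Dict, Tuple, TypeVar, Union

T = TypeVar("T")


class TailCall:
    def __init__(self, f: Callable[..., Any], args: Tuple[Any, ...],
                 kwargs: Dict[str, Any]) -> None:
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f: Callable[..., Any]) -> Callable[..., TailCall]:
    def wrapper(*args: Any, **kwargs: Any) -> TailCall:
        return TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f: Callable[..., Union[T, TailCall]]) -> Callable[..., T]:
    def wrapper(*args: Any, **kwargs: Any) -> T:
        func: Callable[..., Any] = f
        while True:
            res = func(*args, **kwargs)
            if not isinstance(res, TailCall):
                return res
            args, kwargs = res.args, res.kwargs
            func = getattr(res.f, "__tc_original__", res.f)
    # setattr avoids assigning an undeclared attribute on a function object.
    setattr(wrapper, "__tc_original__", f)
    return wrapper


@tail_recursive
def factorial(n: int, accumulator: int = 1) -> Union[int, TailCall]:
    if n == 0:
        return accumulator
    return tail_call(factorial)(n - 1, accumulator=accumulator * n)


print(factorial(2000) == math.factorial(2000))  # True
```

The argument types are erased to Callable[..., T] here, so this only recovers the return type.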

Sibling-call friendly version (with exceptions)

Here is the same thing with exceptions:

class TailCall(BaseException):
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args
        self.kwargs = kwargs


def tail_call(f):
    def wrapper(*args, **kwargs):
        raise TailCall(f, args, kwargs)
    return wrapper


def tail_recursive(f):
    def wrapper(*args, **kwargs):
        func = f
        while True:
            try:
                return func(*args, **kwargs)
            except TailCall as e:
                args = e.args
                kwargs = e.kwargs
                if hasattr(e.f, "__tc_original__"):
                    func = getattr(e.f, "__tc_original__")
                else:
                    func = e.f
    wrapper.__tc_original__ = f
    return wrapper

I derive TailCall from BaseException instead of Exception because a tail-recursive function might itself catch Exception, which would swallow the TailCall and break the TCO mechanism.
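The difference can be seen by building the same machinery on both base classes. In this sketch, make_tco and run_countdown are hypothetical helpers I introduce only to parameterize the base class; the logic inside is the same as above.

```python
def make_tco(base):
    """Build tail_call and tail_recursive with TailCall derived from `base`."""

    class TailCall(base):
        def __init__(self, f, args, kwargs):
            self.f = f
            self.args = args
            self.kwargs = kwargs

    def tail_call(f):
        def wrapper(*args, **kwargs):
            raise TailCall(f, args, kwargs)
        return wrapper

    def tail_recursive(f):
        def wrapper(*args, **kwargs):
            func = f
            while True:
                try:
                    return func(*args, **kwargs)
                except TailCall as e:
                    args, kwargs = e.args, e.kwargs
                    func = getattr(e.f, "__tc_original__", e.f)
        wrapper.__tc_original__ = f
        return wrapper

    return tail_call, tail_recursive


def run_countdown(base, n):
    tail_call, tail_recursive = make_tco(base)

    @tail_recursive
    def countdown(n):
        try:
            if n == 0:
                return "done"
            tail_call(countdown)(n - 1)
        except Exception:  # a broad handler inside the user's function
            return "swallowed"

    return countdown(n)


print(run_countdown(BaseException, 5000))  # done
print(run_countdown(Exception, 5000))      # swallowed
```

A bare except: or an except BaseException: in the decorated function would still intercept the tail call, so the choice of base class reduces the risk but does not remove it.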