Adding Tail Call Optimization to A Lisp Written in Go


Later: Practices for Software Projects
Earlier: Tests by Example in Clojure and Common Lisp

The last few days have been devoted to improving l1, the homegrown lisp I wrote about earlier this year. A number of changes have landed in the last week.

I also implemented the bulk of the automated tests in the language itself. This was a decisive step forward in both ease of creating new tests and confidence that the language was approaching something usable.

The work I'm happiest with, though, because it taught me the most, was implementing tail call optimization (TCO) in the language, which the rest of this post will be about.

Motivation

The need for some form of TCO became clear as I started to write more small programs in l1. Perhaps the simplest example is one that sums all the natural numbers up to $n$:

(defn sum-to-acc (n acc)
  (cond ((zero? n) acc)
        (t (sum-to-acc (- n 1) (+ n acc)))))

(defn sum-to (n)
  (sum-to-acc n 0))

Calling sum-to for small $n$ worked fine:

(sum-to 100)
;;=>
5050

However, larger $n$ blew up spectacularly:

(sum-to (* 1000 1000))
;;=>
runtime: goroutine stack exceeds 1000000000-byte limit
runtime: sp=0x14020500360 stack=[0x14020500000, 0x14040500000]
fatal error: stack overflow

runtime stack:
runtime.throw({0x10289aa2b?, 0x10294ddc0?})
	/opt/homebrew/Cellar/go/1.18.3/libexec/src/runtime/panic.go:992 +0x50
runtime.newstack()
	/opt/homebrew/Cellar/go/1.18.3/libexec/src/runtime/stack.go:1101 +0x46c
runtime.morestack()
	/opt/homebrew/Cellar/go/1.18.3/libexec/src/runtime/asm_arm64.s:314 +0x70

goroutine 1 [running]:
strings.(*Reader).ReadByte(0x1401c2820e0?)
	/opt/homebrew/Cellar/go/1.18.3/libexec/src/strings/reader.go:66 +0x98 fp=0x14020500360 sp=0x14020500360 pc=0x102883348
math/big.nat.scan({0x0, 0x1401c2820e0?, 0x0}, {0x1028dc7c8, 0x1401c2820e0}, 0xa, 0x0)
	/opt/homebrew/Cellar/go/1.18.3/libexec/src/math/big/natconv.go:126 +0x80 fp=0x14020500430 sp=0x14020500360 pc=0x10288b1e0
;; ...

This happens, of course, because sum-to-acc calls itself a million times, and each call stores its own copy of the local bindings on the stack until the stack space is exhausted.

Getting simple recursive functions like this to work for large $n$ is especially important because l1 doesn't have loops (yet)!

The Optimization

The solution is hinted at already in my test case. Note that I did not write sum-to as a single recursive function, as follows:

(defn sum-to-notail (n)
  (cond ((zero? n) 0)
        (t (+ n (sum-to-notail (- n 1))))))

While this function looks slightly simpler, it is harder for a compiler or interpreter to optimize. The difference is that, whereas sum-to-notail does some work after calling itself (by adding n to the result), sum-to-acc calls itself from the tail position; that is, the function returns immediately after calling itself.

People have long realized that function calls from the tail position can be replaced by updating the arguments in place and then jumping directly to the new function, without adding new information to the stack. This is something I had heard about for years and used in various "functional" languages without ever implementing it myself (and therefore without fully understanding it). It's an easy thing to take for granted without knowing anything about how it's actually implemented under the hood. The failure of sum-to-acc and similar recursive functions, described above, meant I would have to learn.

Two very different blog posts were helpful to me in pointing the way forward: Adding tail call optimization to a Lisp interpreter in Ruby, and How Tail Call Optimization Works. The posts focus on very different languages (Ruby vs. C / assembler), but they each revolve around what are effectively GOTO statements. I'm old enough to remember BASIC and the pernicious GOTO statement leading to "spaghetti code." I doubt I've ever used a GOTO statement in production code; its use in modern programming languages fell out of favor in the aftermath of Dijkstra's famous Go To Statement Considered Harmful paper. But the ability to transfer control to another part of your program without invoking a function call is key to the optimization.

The Approach

Since the strategy is general, let's lose all the parentheses for a moment and rewrite sum-to-acc in language-agnostic pseudo-code:

function sum-to-acc(n sum)
   if n == 0, then return sum
   return sum-to-acc(n - 1, n + sum)

In most languages (without TCO), when this function is called, the values of n and sum, as well as the return address, will be put on the stack, whose evolution looks something like the following [1]:

 first invocation: [n=5, sum=0,  ret=sum-to:...]

second invocation: [n=4, sum=5,  ret=sum-to-acc:...]
                   [n=5, sum=0,  ret=sum-to:...]

 third invocation: [n=3, sum=9,  ret=sum-to-acc:...]
                   [n=4, sum=5,  ret=sum-to-acc:...]
                   [n=5, sum=0,  ret=sum-to:...]

fourth invocation: [n=2, sum=12, ret=sum-to-acc:...]
                   [n=3, sum=9,  ret=sum-to-acc:...]
                   [n=4, sum=5,  ret=sum-to-acc:...]
                   [n=5, sum=0,  ret=sum-to:...]

 fifth invocation: [n=1, sum=14, ret=sum-to-acc:...]
                   [n=2, sum=12, ret=sum-to-acc:...]
                   [n=3, sum=9,  ret=sum-to-acc:...]
                   [n=4, sum=5,  ret=sum-to-acc:...]
                   [n=5, sum=0,  ret=sum-to:...]

 sixth invocation: [n=0, sum=15, ret=sum-to-acc:...]
                   [n=1, sum=14, ret=sum-to-acc:...]
                   [n=2, sum=12, ret=sum-to-acc:...]
                   [n=3, sum=9,  ret=sum-to-acc:...]
                   [n=4, sum=5,  ret=sum-to-acc:...]
                   [n=5, sum=0,  ret=sum-to:...]

At the sixth invocation, our terminating condition is reached, and 15 is returned, with all the pending stack frames popped off the stack.

With TCO, the implementation looks more like the following:

function sum-to-acc(n sum)
TOP:
   if n == 0, then return sum
   sum = sum + n
   n = n - 1
   GOTO TOP

As a result, the evolution of the stack looks as follows:

 first invocation: [n=5, sum=0, ret=sum-to:...]

second invocation: [n=4, sum=5, ret=sum-to:...]

 third invocation: [n=3, sum=9, ret=sum-to:...]

fourth invocation: [n=2, sum=12, ret=sum-to:...]

 fifth invocation: [n=1, sum=14, ret=sum-to:...]

 sixth invocation: [n=0, sum=15, ret=sum-to:...]

All those extra stack frames are gone: recursion has turned into a form of iteration.

Implementing TCO, then, has two ingredients:

  1. Replace the values of the current arguments with their new values directly;
  2. Jump straight to the next call of the function without adding to the stack.

This low-level, imperative optimization makes high-level, functional, recursive implementations efficient.
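
As a minimal sketch of these two ingredients in Go itself (this is not l1's code; the name sumToAcc and the int64 types are chosen here purely for illustration), the pseudo-code above can be written with a label and a goto:

```go
package main

import "fmt"

// sumToAcc is a hand-optimized, illustrative version of sum-to-acc.
// It is not part of l1; it only mirrors the TCO pseudo-code.
func sumToAcc(n, sum int64) int64 {
top:
	if n == 0 {
		return sum
	}
	// Ingredient 1: overwrite the arguments with the values the
	// recursive call would have received. The tuple assignment
	// evaluates both right-hand sides using the old value of n.
	n, sum = n-1, sum+n
	// Ingredient 2: jump back to the start instead of calling again,
	// so no new stack frame is created.
	goto top
}

func main() {
	fmt.Println(sumToAcc(5, 0))         // 15
	fmt.Println(sumToAcc(1000*1000, 0)) // 500000500000
}
```

The simultaneous assignment matters: updating n first and then adding it to sum would add the wrong value on every pass.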

Implementation

In thinking about the implementation for l1, I was pleased to learn that Go actually has a goto statement. However, my implementation was poorly set up to use it.

Early in the implementation of l1, I noticed that each data type (numbers, atoms, and lists) had its own evaluation rules, so it made sense to make use of Go's features supporting polymorphism, namely interfaces and receivers. I had a Sexpr interface which looked like the following:

type Sexpr interface {
	String() string
	Eval(*env) (Sexpr, error)  // <--------
	Equal(Sexpr) bool
}

Numbers and atoms had fairly simple Eval implementations. For example,

func (a Atom) Eval(e *env) (Sexpr, error) {
	if a.s == "t" {
		return a, nil
	}
	ret, ok := e.Lookup(a.s)
	if ok {
		return ret, nil
	}
	ret, ok = builtins[a.s]
	if ok {
		return ret, nil
	}
	return nil, fmt.Errorf("unknown symbol: %s", a.s)
}

And, of course, numbers eval to themselves:

func (n Number) Eval(e *env) (Sexpr, error) {
	return n, nil
}

Lists, as you would expect, were more complicated – evaluating a list expression needs to handle special forms [2], user-defined functions, and built-in functions. Following the classic Structure and Interpretation of Computer Programs, I split the core logic for function application into separate Eval and Apply phases. And to prevent the Eval for lists from getting too large, I broke out the evaluation rules for different cases (e.g. for let and cond special forms and for function application) into their own functions.

In other words, I had evaluation logic spread over ten functions in five files. Sadly, the need to jump back to the beginning of an evaluation rather than recursively calling Eval again meant that several of those nicely broken out functions had to be brought together into a single function, because goto does not support jumping from one function to another. (C has setjmp and longjmp, which effectively do this, but I would want to upgrade my IQ by a few points before applying them in this situation.)

There were actually three cases where I was performing an evaluation step right before returning, and the goto pattern could be used:

  1. When evaluating code in the tail position of a user-defined function;
  2. When evaluating code in the last expression in a let block;
  3. When evaluating code in the chosen branch of a cond clause.

I wound up with code which looks like the following. Several steps are indicated only with comments. Note the tiny, easy-to-miss top: label at the very beginning:

// lisp.go
//
func eval(expr Sexpr, e *env) (Sexpr, error) {
top:
	switch t := expr.(type) {
	case Atom:
		return evAtom(t, e)
	case Number:
		return expr, nil
	// ...
	case *ConsCell:
		if t == Nil {
			return Nil, nil
		}
		// special forms:
		if carAtom, ok := t.car.(Atom); ok {
			switch {
			case carAtom.s == "quote":
				return t.cdr.(*ConsCell).car, nil
			case carAtom.s == "cond":
				pairList := t.cdr.(*ConsCell)
				if pairList == Nil {
					return Nil, nil
				}
				for {
					if pairList == Nil {
						return Nil, nil
					}
					pair := pairList.car.(*ConsCell)
					ev, err := eval(pair.car, e)
					if err != nil {
						return nil, err
					}
					if ev == Nil {
						pairList = pairList.cdr.(*ConsCell)
						continue
					}
					expr = pair.cdr.(*ConsCell).car
					goto top
				}
			// ...

The code so far shows the evaluation for atoms, numbers, and cond statements. cond does not introduce any new bindings, but when the first truthy condition is encountered, it evaluates the next argument as its final act. So the code above simply replaces the expression to be evaluated, expr, with the expression from the matching clause, and then restarts the evaluation via goto, without the overhead of a separate function call.
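
To see the pattern in isolation, here is a small self-contained sketch. It assumes a toy expression type with only numbers and a hypothetical If form standing in for a single cond clause; none of these types are l1's, and the depth counters exist only to make the effect visible:

```go
package main

import "fmt"

// Hypothetical toy types for illustration; these are not l1's Sexpr types.
type Expr interface{}
type Num int64
type If struct{ Cond, Then, Else Expr } // the chosen branch is in tail position

// depth and maxDepth record how deeply eval nests on the Go stack.
var depth, maxDepth int

func eval(expr Expr) (Num, error) {
	depth++
	defer func() { depth-- }()
	if depth > maxDepth {
		maxDepth = depth
	}
top:
	switch t := expr.(type) {
	case Num:
		return t, nil
	case If:
		// The condition is not in tail position, so it gets a real call.
		c, err := eval(t.Cond)
		if err != nil {
			return 0, err
		}
		// The chosen branch is: replace expr and restart, no new frame.
		if c != 0 {
			expr = t.Then
		} else {
			expr = t.Else
		}
		goto top
	default:
		return 0, fmt.Errorf("unknown expression: %v", expr)
	}
}

func main() {
	// Build 100000 nested Ifs whose innermost branch is 42.
	var e Expr = Num(42)
	for i := 0; i < 100000; i++ {
		e = If{Cond: Num(1), Then: e, Else: Num(0)}
	}
	v, err := eval(e)
	fmt.Println(v, err, maxDepth) // 42 <nil> 2
}
```

However deep the nesting, eval never recurses more than two levels: one for the expression being evaluated and one for the condition currently being tested.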

The code for let is somewhat similar:

			case carAtom.s == "let":
				args := t.cdr.(*ConsCell)
				if args == Nil {
					return nil, fmt.Errorf("let requires a binding list")
				}
				// ... code to set up let bindings ...
				body := args.cdr.(*ConsCell)
				var ret Sexpr = Nil
				for {
					var err error
					if body == Nil {
						return ret, nil
					}
					// Implement TCO for `let`:
					if body.cdr == Nil {
						expr = body.car
						e = &newEnv
						goto top
					}
					ret, err = eval(body.car, &newEnv)
					if err != nil {
						return nil, err
					}
					body = body.cdr.(*ConsCell)
				}

The for loop invokes a new eval for each expression in the body of the let except the last one. When the last expression is reached (its cdr is Nil), the final eval is done by jumping to the beginning of the function, after e has been updated to point to the environment that includes the new bindings.
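
The same idea can be sketched in a runnable form under simplifying assumptions: a hypothetical Let with a single binding, a Go slice for the body instead of a cons list, and a map-backed env. All of these names are illustrative and differ from l1's actual types:

```go
package main

import "fmt"

// Hypothetical, simplified types for illustration; these are not l1's.
type Expr interface{}
type Num int64
type Var string
type Let struct {
	Name string
	Val  Expr
	Body []Expr // evaluated in order; the last one is in tail position
}

type env struct {
	vars   map[string]Num
	parent *env
}

// lookup walks outward through enclosing environments.
func (e *env) lookup(name string) (Num, bool) {
	for ; e != nil; e = e.parent {
		if v, ok := e.vars[name]; ok {
			return v, true
		}
	}
	return 0, false
}

func eval(expr Expr, e *env) (Num, error) {
top:
	switch t := expr.(type) {
	case Num:
		return t, nil
	case Var:
		v, ok := e.lookup(string(t))
		if !ok {
			return 0, fmt.Errorf("unknown symbol: %s", string(t))
		}
		return v, nil
	case Let:
		v, err := eval(t.Val, e) // not a tail call: the body still has to run
		if err != nil {
			return 0, err
		}
		newEnv := &env{vars: map[string]Num{t.Name: v}, parent: e}
		if len(t.Body) == 0 {
			return 0, nil // simplification: an empty body yields 0
		}
		last := len(t.Body) - 1
		for _, b := range t.Body[:last] {
			if _, err := eval(b, newEnv); err != nil {
				return 0, err
			}
		}
		// Tail position: swap in both the expression and the environment.
		expr, e = t.Body[last], newEnv
		goto top
	default:
		return 0, fmt.Errorf("unknown expression: %v", expr)
	}
}

func main() {
	// Roughly (let ((x 1)) 99 (let ((y 2)) x))
	prog := Let{"x", Num(1), []Expr{Num(99), Let{"y", Num(2), []Expr{Var("x")}}}}
	fmt.Println(eval(prog, nil)) // 1 <nil>
}
```

Unlike the conditional case, both expr and e have to be replaced before the jump; forgetting the environment would evaluate the final expression without the new bindings.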

The last use of this pattern is in function invocation proper, which looks similar:

			// (... code to set up new environment based on passed arguments ...)
			var ret Sexpr = Nil
			// Walk a local cursor so the lambda's stored body is not modified.
			body := lambda.body
			for {
				if body == Nil {
					return ret, nil
				}
				// TCO:
				if body.cdr == Nil {
					expr = body.car
					e = &newEnv
					goto top
				}
				ret, err = eval(body.car, &newEnv)
				if err != nil {
					return nil, err
				}
				body = body.cdr.(*ConsCell)
			}

I've skipped various parts of eval that aren't relevant for TCO – if you're interested, you can check out the code yourself.

To be clear, what we are optimizing is all tail calls, not just recursive ones – though the recursive ones were the primary objective due to the stack overflows reported above.

The end result is that sum-to now can complete for large values of $n$:

(sum-to (* 1000 1000))
;;=>
500000500000

Incidentally, a variant of our test case failed before I added TCO for let, shown above; this now works as well:

(defn sum-to-acc-with-let (n acc)
  (let ((_ 1))
    (cond ((zero? n) acc)
          (t (sum-to-acc-with-let (- n 1) (+ n acc))))))

(defn sum-to-with-let (n) (sum-to-acc-with-let n 0))

(sum-to-with-let (* 1000 1000))
;;=>
500000500000

Conclusion

Getting tail-call optimization to work was very satisfying… though the eval implementation is certainly more complex than before. (Ah, optimization!)

To ensure TCO continues to work, variants of sum-to with and without let are run on every build, along with a few other short example programs.

After implementing TCO in my own code, I can appreciate and understand the optimization better when I see it in the wild. I fully expect to use the pattern again when implementing future lisps (yes, I hope there will be more).


[1] Note that this is a somewhat abstract representation: the details are language-specific. The ret=sum-to:... notation means that when the function returns, control will pass back to where it left off inside the sum-to function.

[2] A special form is one that does not follow the normal evaluation rule for functions – it may evaluate its arguments once, many times, or not at all. (I am glossing over macros for the time being; l1 does not have them yet.)
