Tuesday, September 23, 2014

Tail Call Recursion in Clojure

In this article, I explain tail call recursion, show an example in both Clojure and Common Lisp, opine that Clojure's approach to tail call recursion optimization is better than Lisp's, and finally look at Clojure's "lazy-seq", which seems to mysteriously provide the benefits of tail call recursion optimization for algorithms that are *not* tail call recursive.

Note: in this article, I refer to "tail call recursion", but I am actually dealing with a subset of that domain: tail *self* call (a.k.a. "direct") recursion. The ideas presented also apply to the general case, including mutual (a.k.a. "indirect") recursion. However, optimization of tail *mutual* call recursion is much rarer in compilers than optimization of tail *self* call recursion, so I only consider self calls.

What Is Tail Call Recursion?

Tail call recursion is a special case of recursion. With recursion, a function calls itself. *Tail* call recursion is where the function calls itself and then immediately returns, either returning nothing or returning the value of the recursive call. I.e. it doesn't call itself and then do something with the recursive call's value before returning. The reason tail call recursion is an interesting special case is that it can be optimized so that it doesn't need a new stack frame for each recursive invocation (i.e. the compiler converts it into a simple loop). This makes it faster and conserves stack space.
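
To make the distinction concrete, here is a minimal sketch on a simpler problem than factorial: summing the integers from 1 to n. The names sum-to and sum-to-acc are hypothetical, made up for this illustration. Neither one is optimized by Clojure; the only point is where the recursive call sits.

```clojure
;; Hypothetical example functions: sum the integers 1..n.

;; NOT tail recursive: after (sum-to (dec n)) returns, this level
;; still has to add n to its value, so n must be kept on the stack.
(defn sum-to [n]
  (if (zero? n)
    0
    (+ n (sum-to (dec n)))))

;; Tail recursive: the running total travels in acc, and the value of
;; the recursive call is returned unchanged, so nothing from this
;; level is needed once the call has been made.
(defn sum-to-acc [n acc]
  (if (zero? n)
    acc
    (sum-to-acc (dec n) (+ n acc))))

(sum-to 100)       ;= 5050
(sum-to-acc 100 0) ;= 5050
```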

Take the most obvious way to program factorial recursively in Clojure:

(defn fact1 [x]
  (if (<= x 1)
    1N
    (* x (fact1 (dec x)))))

(fact1 5) ;= 120N
(fact1 9000) ;= StackOverflowError

This is not tail call recursive. If you call (fact1 5), that initial invocation needs to wait for the recursive invocation to return the value of 4! so that it can multiply it by 5 and return the result. It needs to maintain the state of x across the recursive invocation, and it uses the stack for that. Clojure's stack (really, the JVM thread's stack) seems awfully small to me: fewer than 9000 levels deep.

It's easy to modify fact1 to be tail recursive. It requires a modified version in which you pass a "result so far" as an additional input parameter:

(defn fact2-tail [x result-so-far]
  (if (<= x 1)
    result-so-far
    (fact2-tail (dec x) (* x result-so-far))))

(fact2-tail 5 1N) ;= 120N
(fact2-tail 9000 1N) ;= StackOverflowError

(If you don't like having to pass the "1N", you can simply define a wrapper function.) This is tail call recursive because all the work, including the multiplication, is done before the recursive call. So theoretically no state needs to be stacked up as the recursion goes deeper. But as you can see from the stack overflow, Clojure doesn't take advantage of this: it does not perform tail call optimization automatically.
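
A minimal sketch of such a wrapper, assuming a hypothetical name fact2 (fact2-tail is repeated here, with its base case, so the snippet stands alone):

```clojure
;; fact2-tail as above: still not optimized by Clojure.
(defn fact2-tail [x result-so-far]
  (if (<= x 1)
    result-so-far
    (fact2-tail (dec x) (* x result-so-far))))

;; Hypothetical wrapper: supplies the initial 1N so callers don't have to.
(defn fact2 [x]
  (fact2-tail x 1N))

(fact2 5) ;= 120N
```

A multi-arity defn, whose one-argument arity calls the two-argument arity with 1N, is another common way to hide the extra parameter. Either way, the wrapper only hides the extra argument; it does nothing about the stack overflow.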

Be aware that some Lisps will do automatic tail call optimization (or at least tail self-call optimization). I tried the equivalent function with CLISP, and (once compiled; more on that below) it worked fine with big values:

(defun fact2-tail-lisp (x result-so-far)
  (if (<= x 1)
    result-so-far
    (fact2-tail-lisp (1- x) (* x result-so-far))))
(compile 'fact2-tail-lisp) ; CLISP only optimizes the tail call in compiled code

(fact2-tail-lisp 59999 1) ;= 260,630-digit number starting with 2606

So what's a Clojure programmer to do? Use the loop/recur construct. Let's convert fact1, which is NOT tail call recursive, to loop/recur:

(defn fact1-loop [x]
  (loop [i x]
    (if (<= i 1)
      1N
      (* i (recur (dec i))))))
CompilerException java.lang.UnsupportedOperationException: Can only recur from tail position

Interesting. Clojure is NOT smart enough to recognize tail call recursion in fact2-tail, but it is smart enough to recognize NON-tail call recursion in fact1-loop. So the Clojure programmer must both recode the algorithm to be tail call recursive and code it to use loop/recur:

(defn fact2-loop [x]
  (loop [i x, result-so-far 1N]
    (if (<= i 1)
      result-so-far
      (recur (dec i) (* i result-so-far)))))

(fact2-loop 5) ;= 120N
(fact2-loop 59999) ;= 260,630-digit number starting with 2606

There! One advantage of using loop/recur is that you don't need a separate function that takes the extra result-so-far parameter. It's kind of built into the loop construct.
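
As a side note, loop isn't strictly required: recur with no enclosing loop jumps back to the top of the enclosing function, rebinding its parameters. So a sketch of the smallest fix to fact2-tail is to replace its self-call with recur (fact2-recur is a hypothetical name used only here):

```clojure
;; Hypothetical variant of fact2-tail: recur with no enclosing loop
;; targets the function itself, rebinding x and result-so-far.
;; This runs in constant stack space, like fact2-loop.
(defn fact2-recur [x result-so-far]
  (if (<= x 1)
    result-so-far
    (recur (dec x) (* x result-so-far))))

(fact2-recur 5 1N)                     ;= 120N
(count (str (fact2-recur 9000 1N)))    ;= 31682, i.e. no StackOverflowError
```

The trade-off is that you're back to needing the extra result-so-far argument (or a wrapper), which the loop version avoids.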

Which Is Better?

So, which is better: CLISP, which automatically did the optimization, or Clojure, where the programmer basically has to *tell* the compiler to optimize it with loop/recur?

I have to side with Clojure, and I'll tell you why. I started my experiments knowing that Lisp did the optimization automatically. I coded it up, and it overflowed the stack. Huh? I had to do a bunch of googling before I found this page, which mentions in footnote 1 that you have to compile the code to get it to do the optimization. Had I not tried it with a large number, or had I mis-coded the algorithm so that it wasn't really tail call recursive, I would have put in some effort to achieve an intended benefit, would not have received that benefit, and might never have known it.

Most of the time, you need to carefully code your algorithm to be tail call recursive. I prefer Clojure's approach, where you tell the compiler that you intend it to perform the optimization and it gives you an error if it can't, as opposed to Lisp, which silently leaves it un-optimized if you make a mistake. It also gives a hint to a future maintainer that your potentially awkward algorithm is coded that way for a reason (admit it: fact2-tail and fact2-loop are more awkward algorithms than fact1).

What About lazy-seq?

Now I'm experimenting with lazy-seq, which can also be used to avoid stack consumption. It's a different animal: the factorial function only returns a single value, whereas lazy-seq is intended to assemble a list (or, more accurately, a lazy sequence).
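
For readers who haven't used lazy-seq, here is a minimal sketch of its basic behavior (naturals-from is a hypothetical example function, not part of Clojure). lazy-seq wraps an expression that produces a seq and defers evaluating it until the sequence is first accessed, which is what lets this definition describe an infinite sequence without recursing forever:

```clojure
;; Hypothetical example: the natural numbers starting at n, as an
;; infinite lazy sequence. The recursive call inside lazy-seq only
;; runs when something asks for that part of the sequence.
(defn naturals-from [n]
  (lazy-seq (cons n (naturals-from (inc n)))))

(take 5 (naturals-from 1)) ;= (1 2 3 4 5)
```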

Let's start by rewriting fact1 to return a list of all the factorials, from 1! up to the requested one:

(defn fact3 [x i result-so-far]
  (if (> i x)
    '()
    (cons (* i result-so-far)
          (fact3 x (inc i) (* i result-so-far)))))

(fact3 5 1N 1N) ;= (1N 2N 6N 24N 120N)
(last (fact3 5 1N 1N)) ;= 120N
(last (fact3 9000 1N 1N)) ;= StackOverflowError

This function is recursive, is not tail call recursive (the cons happens after the recursive call returns), and doesn't use loop/recur. No wonder it blew up the stack. Converting it to loop/recur requires a change in the output. The final return value has to be the entire list, so the "cons"ing needs to be done on the fly each time through the loop. Since cons adds to the front of a list, larger factorial values are "cons"ed onto a list-so-far of smaller factorial values, resulting in a reversed list:

(defn fact3-loop [x]
  (loop [i 2N, result-so-far '(1N)]
    (if (> i x)
      result-so-far
      (recur (inc i) (cons (* i (first result-so-far)) result-so-far)))))

(fact3-loop 5) ;= (120N 24N 6N 2N 1N)
(first (fact3-loop 5)) ;= 120N
(first (fact3-loop 9000)) ;= 31,682-digit number starting with 8099
(first (fact3-loop 59999)) ;= OutOfMemoryError Java heap space

We solved the stack problem, but ran out of heap assembling the massive list (note that most of the elements of the list hold VERY large numbers, which consumes significant space for each list element).
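
As an aside, the reversed order is an artifact of cons adding to the front of a list. If ascending order matters, a sketch using a vector accumulator (conj appends to the end of a vector, and peek reads its last element) avoids the reversal. fact3-vec is a hypothetical name, and this version still holds every factorial in memory, so it would hit the same heap limit:

```clojure
;; Hypothetical variant of fact3-loop that accumulates into a vector.
;; conj adds to the END of a vector and peek returns its LAST element,
;; so the factorials come out in ascending order.
(defn fact3-vec [x]
  (loop [i 2N, result-so-far [1N]]
    (if (> i x)
      result-so-far
      (recur (inc i) (conj result-so-far (* i (peek result-so-far)))))))

(fact3-vec 5)        ;= [1N 2N 6N 24N 120N]
(peek (fact3-vec 5)) ;= 120N
```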

lazy-seq to the rescue! Here it is in action:

(defn fact3-lazy [x i result-so-far]
  (if (> i x)
    '()
    (cons (* i result-so-far)
          (lazy-seq (fact3-lazy x (inc i) (* i result-so-far))))))

(fact3-lazy 5 1N 1N) ;= (1N 2N 6N 24N 120N)
(last (fact3-lazy 5 1N 1N)) ;= 120N
(last (fact3-lazy 9000 1N 1N)) ;= 31,682-digit number starting with 8099
(last (fact3-lazy 59999 1N 1N)) ;= 260,630-digit number starting with 2606

Look at the algorithm; it's almost identical to fact3. I.e. it is *not* tail call recursive! Each invocation conses (* i result-so-far) onto the front of the list returned by the recursive call.

The lazy-seq macro allows Clojure to conserve stack, even though the algorithm is not tail call recursive. It is also able to reclaim unneeded early list items (I'm only interested in the last one), so it conserves heap.
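
One way to watch the stack-saving part happen is to check whether the tail of the sequence has been computed yet. This sketch assumes fact3-lazy is written with an empty-list base case that conses each factorial onto a lazy-seq-wrapped recursive call (it is repeated here so the snippet stands alone); s is just a scratch var. realized? reports whether a lazy sequence has been evaluated:

```clojure
;; fact3-lazy, assumed to be written with '() as the base case.
(defn fact3-lazy [x i result-so-far]
  (if (> i x)
    '()
    (cons (* i result-so-far)
          (lazy-seq (fact3-lazy x (inc i) (* i result-so-far))))))

(def s (fact3-lazy 5 1N 1N))   ; s: scratch var for this demo
(first s)                      ;= 1N
;; The tail is an unrealized LazySeq: the recursive call hasn't run yet.
(realized? (rest s))           ;= false
(second s)                     ;= 2N
;; Asking for the second element forced the deferred call.
(realized? (rest s))           ;= true
```

So each call to fact3-lazy returns as soon as it has built one cons cell; the recursive call only runs later, when something like next or last asks for the tail, and by then the earlier call has already returned.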

Did Clojure figure out a way to not have to maintain state across all 59999 invocations? Or is it just doing it with fact3-lazy in a more-space-efficient way than fact3? I'm guessing that it is maintaining state on the heap instead of the stack, but I'll be honest, I'm not exactly sure what's going on.
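
For what it's worth, here is a hand-rolled model of that guess. It is a simplified sketch, not Clojure's actual LazySeq implementation, and fact-nodes and last-node-value are hypothetical names. Each node holds a value plus a zero-argument closure that builds the next node on demand, so the pending state (x, the next i, and the running product) lives in the closure, a heap object, rather than in a stack frame. The consumer walks the nodes with loop/recur, so the stack depth stays constant:

```clojure
;; Hypothetical model of a lazy list. A node is [value next-fn], where
;; next-fn is a zero-argument closure that builds the following node
;; (or returns nil at the end). Calling fact-nodes does NOT recurse:
;; it computes one value, captures the state in a closure, and returns.
(defn fact-nodes [x i result-so-far]
  (when (<= i x)
    (let [v (* i result-so-far)]
      [v (fn [] (fact-nodes x (inc i) v))])))

;; Hypothetical consumer: walk the nodes iteratively. Each next-fn call
;; returns after building one node, so the stack never grows.
(defn last-node-value [node]
  (loop [node node, last-v nil]
    (if (nil? node)
      last-v
      (let [[v next-fn] node]
        (recur (next-fn) v)))))

(last-node-value (fact-nodes 5 1N 1N)) ;= 120N
```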
