Wednesday, April 7, 2010

Ensuring Tail Call Optimization in Scala

If you're writing functional code you're probably writing recursive code. And if you're writing recursive code in a language that is compiled into imperative instructions - as Scala is - then tail call optimisation becomes very important.

What is a Tail Call?
When you write a recursive function, there are two ways to do it. If your function calls itself recursively and then performs some other operation with the result, it is NOT tail recursive. On the other hand, if the last thing your function does is call itself, so that the result of the recursive call is the result of the function, that is a tail call, so named because the recursive call sits at the "tail" of the function.

Here is an example of both:

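object FactorialTest {
  def main(args: Array[String]): Unit = {
    // Both calls print 120
    println(factorial_TailCall(5))
    println(factorial_NonTailCall(5))
  }

  // Not a tail call: the multiplication happens after the recursive call returns
  def factorial_NonTailCall(n: Int): Int = n match {
    case 1 => 1
    case _ => factorial_NonTailCall(n - 1) * n
  }

  // Tail call: the recursive call to f() is the last thing f() does
  def factorial_TailCall(n: Int): Int = {
    def f(n: Int, r: Int): Int = n match {
      case 1 => r
      case _ => f(n - 1, r * n)
    }
    f(n, 1)
  }
}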

You can see in factorial_NonTailCall() that, after the function calls itself recursively, it multiplies the result by n and returns the result of that operation. However, in the f() function inside factorial_TailCall(), the recursive call (to f(n - 1, r * n)) is also the result of the function, making this a tail call.
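To make the difference concrete, here is a rough trace of how each version evaluates for n = 4, written as comments next to the code. The names TraceSketch, fact, factAcc and go are made up for this illustration; the bodies mirror the two functions above.

```scala
object TraceSketch {
  // Non-tail version: every call has to wait for the inner call
  // before it can do its multiplication.
  //   fact(4)
  //   = fact(3) * 4
  //   = (fact(2) * 3) * 4
  //   = ((fact(1) * 2) * 3) * 4
  //   = ((1 * 2) * 3) * 4
  //   = 24
  def fact(n: Int): Int = n match {
    case 1 => 1
    case _ => fact(n - 1) * n
  }

  // Tail version: the partial product travels in the accumulator r,
  // so nothing is left to do once the inner call returns.
  //   go(4, 1) = go(3, 4) = go(2, 12) = go(1, 24) = 24
  def factAcc(n: Int): Int = {
    def go(n: Int, r: Int): Int = n match {
      case 1 => r
      case _ => go(n - 1, r * n)
    }
    go(n, 1)
  }
}
```

Both versions give the same answer; the only difference is where the pending work is kept.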

What is Tail Call Optimisation?
Now, if you write your recursive functions using tail calls, the Scala compiler *may* be able to optimise your function. What needs to be optimised? Well, if you've ever written a recursive function in Java - either intentionally or by accident - you may have noticed that, if your function recurses too many times, you get a StackOverflowError. This is because, every time the method calls itself, another frame is added to the stack, and stack memory is relatively limited (compared to heap memory). Scala, running on the JVM, has the same problem - if a method calls itself too many times, it will run out of stack memory.
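If you want to see the problem for yourself, something like the sketch below should do it. The names (StackDepthSketch, sumTo, sumToAcc, loop) are invented for this example, and the depth of ten million is only an assumption about what is deep enough to exhaust a default-sized stack; your JVM settings may differ.

```scala
object StackDepthSketch {
  // Non-tail recursion: the addition happens after the recursive call
  // returns, so every level keeps its own stack frame alive.
  def sumTo(n: Int): Long =
    if (n == 0) 0L else n + sumTo(n - 1)

  // Tail recursion: the running total travels in the accumulator,
  // so the compiler should be able to reuse a single frame.
  def sumToAcc(n: Int): Long = {
    def loop(remaining: Int, total: Long): Long =
      if (remaining == 0) total else loop(remaining - 1, total + remaining)
    loop(n, 0L)
  }

  def main(args: Array[String]): Unit = {
    val depth = 10000000 // assumed deep enough to overflow a default stack
    println(sumToAcc(depth))
    try println(sumTo(depth))
    catch { case _: StackOverflowError => println("sumTo overflowed the stack") }
  }
}
```

The accumulator version should get through the whole range, while the plain version is the one likely to hit the StackOverflowError.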

Scala's tail call optimisation gets around this problem by writing the byte code for tail-call recursive methods in such a way that they don't actually have to call themselves. It compiles the recursive function into something more like a for loop, using a goto operation to jump to the top of the function without actually invoking a method call. If you're interested in reading byte code, you can have a look at the end result at the end of this post.
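As a rough picture of what that means, here is what the f(n, r) helper from the factorial example might look like if you wrote the loop out by hand. This is only a sketch of the idea, not what the compiler literally emits, and LoopSketch and factorialLoop are names I've made up. Like the original, it assumes n is at least 1.

```scala
object LoopSketch {
  // Hand-written equivalent of the tail-recursive f(n, r):
  // instead of calling itself, the method updates its two "parameters"
  // and jumps back to the top of the loop.
  def factorialLoop(start: Int): Int = {
    var n = start
    var r = 1
    while (n != 1) {
      r = r * n // same as passing r * n as the new r
      n = n - 1 // same as passing n - 1 as the new n
    }
    r
  }
}
```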

Ensuring Tail Call Optimisation in Scala
So, tail call recursion is good, because without the optimisation a deeply recursive function can blow the stack. The theory of tail calls is quite simple, but in practice it can be hard to tell whether you've got it right.

As of Scala 2.8, there is now an easy way to ensure that your function is optimised for tail recursion: the @tailrec annotation. If you add this annotation to any function, the compiler will fail with an error if it cannot perform tail call optimisation on the function.
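Here's a minimal sketch of the annotation in use, applied to the same accumulator-style factorial as before. The annotation lives in scala.annotation; TailrecSketch and bad are names invented for this example, and the rejected version is left commented out so that the rest compiles.

```scala
import scala.annotation.tailrec

object TailrecSketch {
  def factorial(n: Int): Int = {
    // The annotation asks the compiler to check that f is optimised;
    // if the recursive call were not in tail position, compilation would fail.
    @tailrec
    def f(n: Int, r: Int): Int = n match {
      case 1 => r
      case _ => f(n - 1, r * n)
    }
    f(n, 1)
  }

  // Annotating the non-tail version is rejected by the compiler,
  // because the multiplication happens after the recursive call:
  //
  //   @tailrec
  //   def bad(n: Int): Int = n match {
  //     case 1 => 1
  //     case _ => bad(n - 1) * n
  //   }
}
```

Note that the annotation doesn't switch the optimisation on; it only makes the compiler complain when the optimisation can't be applied.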

Pre-Requisites for Tail Call Optimisation
Oddly enough, using tail calls isn't the only requirement for getting your function optimised. The other pre-requisite is that your function must not be overridable. This is because any function that can be overridden has to be invoked polymorphically by the JVM, which means that any recursion within it has to use a dynamic method invocation and hence can't be optimised to a goto.

If that's all a bit complex, the moral of the story is quite simple: in order for your tail recursive function to be optimised, it also has to be either private or final (or in a final class or object, or a local function nested inside another method, like f() above). So if you've written a tail recursive function and annotated it with @tailrec but can't figure out why the compiler is telling you "error: could not optimize @tailrec annotated method", the first thing to check is that your function can't be overridden.
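A small sketch of the rule, using a hypothetical Counter class (the class and its method names are made up for illustration). The final and private methods should both pass the @tailrec check; the open one is left commented out because the compiler rejects it.

```scala
import scala.annotation.tailrec

class Counter {
  // Final, so no subclass can override it and the self-call can become a jump.
  @tailrec
  final def countDown(n: Int): Int =
    if (n <= 0) 0 else countDown(n - 1)

  // Private works for the same reason.
  @tailrec
  private def sumDigits(n: Int, acc: Int): Int =
    if (n == 0) acc else sumDigits(n / 10, acc + n % 10)

  def digitSum(n: Int): Int = sumDigits(n, 0)

  // Without final or private, the same body is rejected under @tailrec,
  // since a subclass could override it:
  //
  //   @tailrec
  //   def countDownOpen(n: Int): Int =
  //     if (n <= 0) 0 else countDownOpen(n - 1)
}
```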

Remember: private or final.

The Guts
As promised, here's the byte code for the FactorialTest. You can see in the code for the f() method (compiled as f$1) that it never invokes itself, but instead has a "goto 0" instruction at offset 43.

[scala] graham$ javap -private -c -classpath . FactorialTest$
Compiled from "FactorialTest.scala"
public final class FactorialTest$ extends java.lang.Object implements scala.ScalaObject{
public static final FactorialTest$ MODULE$;

public static {};
0: new #10; //class FactorialTest$
3: invokespecial #13; //Method "<init>":()V
6: return

private FactorialTest$();
0: aload_0
1: invokespecial #17; //Method java/lang/Object."<init>":()V
4: aload_0
5: putstatic #19; //Field MODULE$:LFactorialTest$;
8: return

private final int f$1(int, int);
0: iload_1
1: istore 4
3: iload 4
5: iconst_1
6: if_icmpne 31
9: iconst_1
10: ifeq 15
13: iload_2
14: ireturn
15: new #23; //class scala/MatchError
18: dup
19: iload 4
21: invokestatic #29; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
24: invokevirtual #33; //Method java/lang/Object.toString:()Ljava/lang/String;
27: invokespecial #36; //Method scala/MatchError."<init>":(Ljava/lang/String;)V
30: athrow
31: iconst_1
32: ifeq 46
35: iload_1
36: iconst_1
37: isub
38: iload_2
39: iload_1
40: imul
41: istore_2
42: istore_1
43: goto 0
46: new #23; //class scala/MatchError
49: dup
50: iload 4
52: invokestatic #29; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
55: invokevirtual #33; //Method java/lang/Object.toString:()Ljava/lang/String;
58: invokespecial #36; //Method scala/MatchError."<init>":(Ljava/lang/String;)V
61: athrow

private int factorial_TailCall(int);
0: aload_0
1: iload_1
2: iconst_1
3: invokespecial #46; //Method f$1:(II)I
6: ireturn

private int factorial_NonTailCall(int);
0: iload_1
1: istore_2
2: iload_2
3: iconst_1
4: if_icmpne 30
7: iconst_1
8: ifeq 15
11: iconst_1
12: goto 43
15: new #23; //class scala/MatchError
18: dup
19: iload_2
20: invokestatic #29; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
23: invokevirtual #33; //Method java/lang/Object.toString:()Ljava/lang/String;
26: invokespecial #36; //Method scala/MatchError."<init>":(Ljava/lang/String;)V
29: athrow
30: iconst_1
31: ifeq 44
34: aload_0
35: iload_1
36: iconst_1
37: isub
38: invokespecial #49; //Method factorial_NonTailCall:(I)I
41: iload_1
42: imul
43: ireturn
44: new #23; //class scala/MatchError
47: dup
48: iload_2
49: invokestatic #29; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
52: invokevirtual #33; //Method java/lang/Object.toString:()Ljava/lang/String;
55: invokespecial #36; //Method scala/MatchError."<init>":(Ljava/lang/String;)V
58: athrow

public void main(java.lang.String[]);
0: getstatic #57; //Field scala/Predef$.MODULE$:Lscala/Predef$;
3: aload_0
4: iconst_5
5: invokespecial #59; //Method factorial_TailCall:(I)I
8: invokestatic #29; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
11: invokevirtual #63; //Method scala/Predef$.println:(Ljava/lang/Object;)V
14: getstatic #57; //Field scala/Predef$.MODULE$:Lscala/Predef$;
17: aload_0
18: iconst_5
19: invokespecial #49; //Method factorial_NonTailCall:(I)I
22: invokestatic #29; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
25: invokevirtual #63; //Method scala/Predef$.println:(Ljava/lang/Object;)V
28: return

