
Scala/Recursion

From Wikibooks, open books for an open world

Recursion refers to a general method of defining a solution or object in terms of itself. A recursive function is a function whose definition includes a call to the function itself. Recursive functions often take some input, divide it into smaller parts, solve the smaller (and potentially easier) parts, and combine the results to produce a solution.

Recursive functions can sometimes be difficult to reason about, but they have considerable expressive power. They are often used indirectly through higher-order functions, or directly to process recursively defined structures such as abstract syntax trees and parse trees.

Recursive functions


A simple example of a recursive function:

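def recursiveFunc(n: Int): Unit = {
  if (n > 0) {
    print(s"$n, ")        // print the current number
    recursiveFunc(n - 1)  // recursive call with a smaller argument
  }
  else {
    println(s"$n.")       // base case: print the last number and stop
  }
}
recursiveFunc(4) // Prints "4, 3, 2, 1, 0."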

The code defines a function named "recursiveFunc". It is recursive because it calls itself on line four. The function tests whether its argument satisfies the condition "n > 0". If it does, the function prints the number followed by a comma and a space, then calls itself with 1 subtracted from the number. Otherwise, it prints the number followed by a period and stops. Note that the return type is explicitly annotated as "Unit": Scala cannot infer the result type of a function that calls itself, so an explicit return type is always required for recursive functions.

On the last line, the function is called with argument 4. The function tests the condition "n > 0", and since "n" is 4, the condition is true. Therefore, it prints the number followed by a comma and a space, and calls itself with the value "n-1", which is 3. The condition is tested again; it is still true, so 3 is printed and the function calls itself with argument 2. The same happens for 2 and 1, until finally the argument becomes 0. This time the condition does not hold, so the function only prints 0 followed by a period and makes no further call. Thus, the function prints "4, 3, 2, 1, 0.".
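Because "recursiveFunc" prints as it goes, its result is hard to check programmatically. The sketch below is a hypothetical variant (the names Countdown and countdown are not from the original) with the same recursive structure that builds and returns the text instead, so each step of the walk-through above corresponds to one string concatenation:

```scala
object Countdown {
  // Hypothetical variant of recursiveFunc: same recursion, but returns the text.
  def countdown(n: Int): String =
    if (n > 0) s"$n, " + countdown(n - 1) // recursive case: this number, then the rest
    else s"$n."                            // base case: no further call
  // Countdown.countdown(4) returns "4, 3, 2, 1, 0."
}
```

The recursive call now sits inside an expression ("s"$n, " + ..."), so each call's result is combined with the results of the smaller calls, mirroring the divide-and-combine pattern described at the start of the page.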

Another example of a recursive function is an implementation of the factorial function:

def fact(n:Int):Int = {
  if (n == 0) 1
  else n*fact(n-1)
}
println(fact(5)) //Prints "120".

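The "fact" function takes an integer "n" as argument and returns the factorial of that integer, assuming that "n" is not negative. If "n" is negative, the recursion never reaches the base case 0; since every call adds a frame to the call stack, the program will in practice fail with a stack overflow (a StackOverflowError on the JVM).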
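The text leaves negative input undefined. One option, sketched here under the assumption that failing fast is preferable to a stack overflow, is to check the argument with Scala's built-in "require", which throws an IllegalArgumentException when its condition is false. The object name SafeFact is hypothetical:

```scala
object SafeFact {
  // Hypothetical guarded variant of fact: rejects negative input up front
  // instead of recursing until the stack overflows.
  def fact(n: Int): Int = {
    require(n >= 0, s"fact is undefined for negative n: $n")
    if (n == 0) 1
    else n * fact(n - 1)
  }
  // SafeFact.fact(5) returns 120; SafeFact.fact(-1) throws IllegalArgumentException.
}
```

The check runs on every recursive call, which is harmless here because each call's argument is still non-negative.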

The factorial of a number "n" is defined as the product of the integers from 1 to "n" inclusive. The "fact" function correctly computes the factorial of a non-negative number. To see why this is true, consider the following informal argument.

If "n" is 0, the factorial function is defined to be 1, and since "n == 0" is true, the first branch of the if-then-else expression is taken, giving the correct solution 1.

If "n" is greater than 0, the function first computes "fact(n-1)". Assume that this result is correct, that is, equal to the factorial of "n-1", the product of the numbers from 1 to "n-1" inclusive. Multiplying that product by "n" gives exactly the product of the numbers from 1 to "n" inclusive, which is the factorial of "n". So if "fact(n-1)" is computed correctly, so is "fact(n)".

Since this holds for any "n" greater than 0, it is enough to show that "fact" is correct for one starting value, the base case, to conclude that it is correct for every value above it. We showed above that "fact(0)" is correct, so 0 serves as the base case. Thus, the function "fact" correctly computes the factorial of any number greater than or equal to 0. (This style of argument is known as proof by mathematical induction.)
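Unfolding a small call by hand makes the argument concrete. The expansion below, written out here as a worked example rather than taken from the original text, applies the recursive case three times and then the base case:

```latex
\begin{aligned}
\mathrm{fact}(3) &= 3 \cdot \mathrm{fact}(2) \\
                 &= 3 \cdot 2 \cdot \mathrm{fact}(1) \\
                 &= 3 \cdot 2 \cdot 1 \cdot \mathrm{fact}(0) \\
                 &= 3 \cdot 2 \cdot 1 \cdot 1 = 6
\end{aligned}
```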

Alternative, simpler ways to calculate the factorial, which rely on library methods instead of explicit recursion:

def fact2(n:Int) = (1 to n).foldLeft(1)(_*_)
def fact3(n:Int) = (1 to n).product
println(fact2(5)) //Prints "120".
println(fact3(5)) //Prints "120".

The above examples use ranges, collections, higher-order functions, and function literals. Library methods such as "foldLeft" and "product" capture common recursive patterns (internally they may be implemented with loops), and using them is often an alternative to writing recursion directly.
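All the factorial functions above return Int, which can only hold factorials up to 12! (479,001,600); from 13! = 6,227,020,800 onward the result silently overflows. A minimal sketch that keeps the "product" style but switches to arbitrary-precision BigInt, assuming larger inputs matter (BigFact is a hypothetical name):

```scala
object BigFact {
  // Hypothetical BigInt variant: Int overflows past 12!, BigInt does not.
  // For n = 0 the range is empty and its product is 1, matching 0! = 1.
  def fact(n: Int): BigInt = (1 to n).map(BigInt(_)).product
  // BigFact.fact(13) returns 6227020800, which does not fit in an Int.
}
```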

Mutual recursion


Mutual recursion refers to functions that are defined in terms of each other:

//WARNING: Calling these functions will result in infinite execution or stack overflow.
def f1(a:Int) = f2(a)
def f2(a:Int) = f3(a)
def f3(a:Int):Int = f1(a)

In the above example, three mutually recursive functions are defined: "f1" calls "f2", "f2" calls "f3", and "f3" calls "f1". In a mutual recursion such as this, the compiler cannot infer the return types purely from one another, so at least one function in the cycle must have an explicit return type; here, "f3" is annotated.
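The example above never terminates, so it only illustrates the typing rule. A mutually recursive pair that does terminate is the classic even/odd test, sketched here for non-negative arguments (the names Parity, isEven and isOdd are illustrative, not from the original):

```scala
object Parity {
  // Hypothetical terminating mutual recursion, assuming n >= 0:
  // each function calls the other with a smaller argument until n reaches 0.
  // A negative n would never reach 0 and would overflow the stack.
  def isEven(n: Int): Boolean = if (n == 0) true else isOdd(n - 1)
  def isOdd(n: Int): Boolean = if (n == 0) false else isEven(n - 1)
  // Parity.isEven(10) returns true; Parity.isOdd(7) returns true.
}
```

Scala's tail call optimization does not cover mutual recursion, so very large arguments (say, a million) would likely overflow the stack; the standard library's scala.util.control.TailCalls provides a trampolined alternative for such cases.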

Tail call optimization


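Scala supports tail call optimization for self-recursive functions. When a function's recursive call is the last operation it performs (a tail call), the compiler turns the recursion into a loop, so the call does not consume additional stack space. The optimization only applies to methods that cannot be overridden (for example, methods that are final or private, defined in an object, or local to another function), and it does not apply to mutual recursion.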

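def fact(n: Int, acc: Int): Int = n match {
  case 0 => acc                   // base case: the accumulator holds the result
  case _ => fact(n - 1, n * acc)  // tail call: nothing is left to do after it returns
}
fact(10, 1) // Returns 3628800.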
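One way to see the optimization at work is to recurse far deeper than the call stack would normally allow. The hypothetical "sumTo" below follows the same accumulator pattern as "fact"; the input size of one million is an arbitrary choice made only to illustrate the point:

```scala
object DeepSum {
  // Hypothetical tail-recursive sum of 1..n with a Long accumulator.
  // The recursive call is the last thing evaluated, so the compiler turns it into
  // a loop; the method lives in an object, so it cannot be overridden.
  def sumTo(n: Int, acc: Long): Long =
    if (n == 0) acc
    else sumTo(n - 1, acc + n)
  // DeepSum.sumTo(1000000, 0L) returns 500000500000.
}
```

A non-tail-recursive version such as "n + sumTo(n - 1)" would need one stack frame per call and would likely fail with a StackOverflowError for an input this large.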

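Annotating a method with @tailrec (from scala.annotation.tailrec) makes the compiler emit an error if the method cannot be tail-call optimized, for example because a recursive call is not in tail position or because the method can be overridden.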
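A common idiom, sketched here rather than taken from the original, combines @tailrec with a nested helper function so that callers do not have to supply the initial accumulator. TailFact and loop are hypothetical names, and BigInt is used to avoid Int overflow:

```scala
import scala.annotation.tailrec

object TailFact {
  // Hypothetical one-argument wrapper: the accumulator is hidden inside.
  def fact(n: Int): BigInt = {
    require(n >= 0, s"fact is undefined for negative n: $n")

    // @tailrec makes compilation fail if the recursive call is not in tail
    // position, e.g. if the else branch were written as k * loop(k - 1, acc).
    @tailrec
    def loop(k: Int, acc: BigInt): BigInt =
      if (k == 0) acc
      else loop(k - 1, acc * k)

    loop(n, BigInt(1))
  }
  // TailFact.fact(10) returns 3628800.
}
```

Because "loop" is local to "fact", it cannot be overridden, so it satisfies the compiler's requirements for the optimization.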