
Scala/Higher-order functions 1

From Wikibooks, open books for an open world

Higher-order functions are functions that either take a function as an argument or return a function. Because they work with functions rather than simple values, higher-order functions are very flexible. A simple example is testing whether at least one element in a list passes some test. By using an existing higher-order function defined for List, we don't have to write the code that tests each element; we only have to write a function that contains the test itself. For instance, assume that we want to test whether some list contains the number 4:

def isEqualToFour(a: Int) = a == 4
val list = List(1, 2, 3, 4)
val resultExists4 = list.exists(isEqualToFour)
println(resultExists4) // Prints "true".

In the example above, we first define the function that contains the test (equality to 4). We then define a list that happens to contain the number 4, which means the final result should be true. In the third line, we call the method "exists", which takes our test function, applies it to the elements of the list, and returns whether the function was true for at least one of them. Since the list does contain 4, the result is true, which is what gets printed.
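To see what "exists" saves us from writing, here is a minimal sketch of the element-by-element check done by hand. The name "myExists" is hypothetical, the sketch is restricted to lists of Int for simplicity, and it is not the standard library's actual implementation; it only mirrors the behavior of returning true as soon as one element passes the test.

```scala
// Hypothetical helper (not part of the Scala library): the loop that
// `exists` spares us from writing, expressed as a recursive function.
def myExists(xs: List[Int], test: Int => Boolean): Boolean = xs match {
  case Nil          => false                             // no elements left: none passed
  case head :: tail => test(head) || myExists(tail, test) // stops at the first match
}

def isEqualToFour(a: Int) = a == 4
println(myExists(List(1, 2, 3, 4), isEqualToFour)) // Prints "true".
println(myExists(List(1, 2, 3), isEqualToFour))    // Prints "false".
```

Calling "list.exists(isEqualToFour)" gives the same answers without any of this code.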

If we instead wanted to test whether all the numbers in the list are equal to 4 (which is clearly false), we would use the "forall" method. "forall" tests whether the given function is true for every single element in the list.

val resultForall4 = list.forall(isEqualToFour)
println(resultForall4) // Prints "false".

As expected, the result is false. Note that we didn't have to redefine the function containing the test. By separating the test itself (the function "isEqualToFour") from the logic that applies it to a list (the higher-order functions "exists" and "forall"), we avoid a considerable amount of duplication.
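The two methods are closely related: a list passes "forall" exactly when no element fails the test, so "forall" can be expressed through "exists". The sketch below checks this with the same "isEqualToFour", and also reuses that test with two other standard List methods, "count" and "filter", to illustrate that one test function can serve many higher-order functions.

```scala
def isEqualToFour(a: Int) = a == 4
val list = List(1, 2, 3, 4)

// "Every element passes" is the same as "no element fails":
// forall(test) == !exists(element => !test(element))
val viaForall = list.forall(isEqualToFour)
val viaExists = !list.exists(a => !isEqualToFour(a))
println(viaForall) // Prints "false".
println(viaExists) // Prints "false".

// The same test function, reused with other standard higher-order methods:
println(list.count(isEqualToFour))  // Prints "1" (number of elements that pass).
println(list.filter(isEqualToFour)) // Prints "List(4)" (only the elements that pass).
```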

Another common higher-order function is "map". Say that you have a list of numbers and want to change each number in the list independently, for instance by multiplying each number by some constant. That is exactly what "map" does: it takes a transformation function and applies it to each element independently, creating a new list. Let's see it in action:

def multiplyBy42(a: Int) = 42 * a
val resultMultiplyBy42 = list.map(multiplyBy42)
println(resultMultiplyBy42) // Prints "List(42, 84, 126, 168)".

By using "map", we avoid having to apply the function to each element and construct the new list ourselves.
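For comparison, here is a sketch of the work "map" does for us, written by hand as a recursive function. "myMap" is a hypothetical name used only for illustration, restricted to Int for simplicity; it is not how the standard library implements "map".

```scala
// Hypothetical hand-written equivalent of `list.map(f)`, for illustration only:
// apply `f` to each element and build a new list from the results.
def myMap(xs: List[Int], f: Int => Int): List[Int] = xs match {
  case Nil          => Nil                        // empty input gives empty output
  case head :: tail => f(head) :: myMap(tail, f)  // transform head, recurse on the rest
}

def multiplyBy42(a: Int) = 42 * a
println(myMap(List(1, 2, 3, 4), multiplyBy42)) // Prints "List(42, 84, 126, 168)".
```

Because the recursive call is not in tail position, this sketch may overflow the stack on very long lists; it is meant only to show the idea.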

There are plenty of other higher-order functions defined not just for List but for most of the other collections, as well as for other classes in the Scala library. Some noteworthy ones include "reduce" and "foldLeft"/"foldRight".

"reduce" takes a function that combines two elements into a new element of the same type, and keeps applying it until only one element, the result, remains. Typical uses of "reduce" include finding the sum or product of some numbers, or combining many strings into one string, perhaps putting a separator such as "\n", "," or ";" between subsequent strings.
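The uses of "reduce" mentioned above can be sketched as follows, passing the combining function as an anonymous function. The example values are made up for illustration.

```scala
val numbers = List(1, 2, 3, 4)

// Sum: combine two numbers by adding them, until one number remains.
val sum = numbers.reduce((a, b) => a + b)
println(sum) // Prints "10".

// Product: the same idea, multiplying instead.
val product = numbers.reduce((a, b) => a * b)
println(product) // Prints "24".

// Joining strings, putting "," between subsequent strings.
val words = List("apple", "banana", "cherry")
val joined = words.reduce((a, b) => a + "," + b)
println(joined) // Prints "apple,banana,cherry".
```

Note that "reduce" needs at least one element; calling it on an empty list throws an exception.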

"foldLeft" and "foldRight" are essentially sequential transformations. While "map" takes each element and transforms it independently, the folds go through the collection sequentially, starting from an initial value and combining each element with the previous result into a new result (such as a new list or a sum). The folds are more difficult to use than "map" and "reduce", but they are more flexible and can in fact be used to define both "map" and "reduce". The "Left" and "Right" refer to the direction in which the fold goes through the elements.
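The sketch below shows "foldLeft" with an explicit initial value, how the direction of a fold changes the result for a non-associative operation such as subtraction, and how "map" and "reduce" can be written in terms of folds. The helper names "mapViaFold" and "reduceViaFold" are hypothetical and restricted to Int for simplicity.

```scala
val numbers = List(1, 2, 3, 4)

// foldLeft: start from an initial value (0) and combine it with each
// element from left to right: (((0 + 1) + 2) + 3) + 4
val sum = numbers.foldLeft(0)((acc, n) => acc + n)
println(sum) // Prints "10".

// Direction matters for non-associative operations such as subtraction:
// foldLeft:  (((0 - 1) - 2) - 3) - 4 = -10
// foldRight: 1 - (2 - (3 - (4 - 0))) = -2
println(numbers.foldLeft(0)((acc, n) => acc - n))  // Prints "-10".
println(numbers.foldRight(0)((n, acc) => n - acc)) // Prints "-2".

// Hypothetical helper: "map" as a foldRight that prepends each
// transformed element to the list built from the elements after it.
def mapViaFold(xs: List[Int], f: Int => Int): List[Int] =
  xs.foldRight(List.empty[Int])((n, acc) => f(n) :: acc)

// Hypothetical helper: "reduce" as a foldLeft that uses the first
// element as the initial value and folds over the remaining ones.
def reduceViaFold(xs: List[Int], op: (Int, Int) => Int): Int =
  xs.tail.foldLeft(xs.head)(op)

println(mapViaFold(numbers, n => 42 * n))         // Prints "List(42, 84, 126, 168)".
println(reduceViaFold(numbers, (a, b) => a * b))  // Prints "24".
```

Like "reduce", "reduceViaFold" fails on an empty list, since there is no first element to start from.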