Saturday, June 4, 2011

What's the best way to transform a list with a cumulative function?

I've been doing this kind of thing a lot lately: I have a list of somethings. I want a new list based on that list, but it's not a straight one-to-one map() operation - each value in the resulting list is a function of the corresponding value and all values before it in the input list. (I'm sure there's a one-word name for this type of function, but I don't specialise in maths vocab.)

Example
As an example to discuss, imagine I have a list of numbers, and I want a list of the cumulative totals after each number:
def main(args: Array[String]): Unit = {
  val numbers = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
  val cumulativeTotals = ... ?
  println(cumulativeTotals)
}
should output:
List(1, 3, 6, 10, 15, 21, 28, 36, 45, 55)
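Every approach below carries a single running value forward. That works because the totals satisfy a simple recurrence. Writing x_i for the i-th input number and t_i for the i-th output total:

```latex
t_0 = 0, \qquad t_i = t_{i-1} + x_i = \sum_{j=1}^{i} x_j \quad (1 \le i \le n)
```

For the example list, t_3 = t_2 + x_3 = 3 + 3 = 6, which matches the third element of the expected output.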

Funky Folds


So, until the other day, I'd usually been doing this with a foldLeft():
private def cumulativeTotalFolded(numbers: List[Int]): Seq[Int] = {
  numbers.foldLeft((0, List[Int]()))((currentTotalAndCumulativeTotals, n) => {
    val (currentTotal, cumulativeTotals) = currentTotalAndCumulativeTotals
    (currentTotal + n, (currentTotal + n) :: cumulativeTotals)
  })._2.reverse
}
Now, I'm not holding that up as a good example of anything. Folds can be hard to understand at the best of times, and a fold that passes around a Tuple2 as its accumulator does not clearly communicate to the next guy what's going on.
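For comparison, here is a minimal sketch of the same fold. The object and method names are hypothetical. It still uses a Tuple2 accumulator, but it destructures the tuple in a pattern-matching anonymous function instead of a separate val, which may make the intent a little easier to follow:

```scala
object FoldSketch {
  // Same algorithm as cumulativeTotalFolded: the accumulator is a
  // (runningTotal, totalsSoFarInReverse) pair, destructured in the case pattern.
  def cumulativeTotals(numbers: List[Int]): List[Int] = {
    val (_, totalsReversed) =
      numbers.foldLeft((0, List.empty[Int])) { case ((total, totals), n) =>
        val next = total + n
        (next, next :: totals) // prepending is O(1); reverse once at the end
      }
    totalsReversed.reverse
  }
}
```

FoldSketch.cumulativeTotals(List(1, 2, 3)) gives List(1, 3, 6).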

Stream Simplicity
So, after a particularly hairy instance of one of these the other night, I lay in bed trying to think of a better way. It struck me ('round about 1am) that Streams are a much more natural construct for calculating cumulative functions.

If you haven't come across Streams before, they're basically a way to define a collection recursively by providing the first value in the collection and a function that will calculate the next element in the collection (and, recursively, all the elements after it) as each subsequent element is requested by the client.
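As a small illustration of that recursive definition, here is a sketch of the natural numbers as an infinite lazy sequence. It assumes Scala 2.13 or later, where LazyList replaced the now-deprecated Stream; from and firstFive are hypothetical names. Each call supplies one head and defers the rest:

```scala
object StreamSketch {
  // The head (n) is given up front; the tail, defined by the same function,
  // is only computed when a client asks for it, because #:: takes it lazily.
  def from(n: Int): LazyList[Int] = n #:: from(n + 1)

  // The sequence is infinite, but only the five requested elements are forced.
  def firstFive: List[Int] = from(1).take(5).toList
}
```

StreamSketch.firstFive evaluates to List(1, 2, 3, 4, 5); the infinitely many remaining elements are never computed.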

Streams are good for this problem because the iteration through the list is simple, while the "next" value usually has easy access to the "previous" value, so to speak. I've used this once or twice now and I like it a lot better.

For the problem above, my solution using a Stream looks like this:
private def cumulativeTotalStream(numbers: List[Int], total: Int = 0): Stream[Int] = {
  numbers match {
    case head :: tail =>
      Stream.cons(total + head, cumulativeTotalStream(tail, total + head))
    case Nil => Stream.Empty
  }
}
(Note: to make this work with the println() above, you'll need to call toList on the Stream.)
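On Scala 2.13 or later, where Stream is deprecated in favour of LazyList, the same approach might be written as below. This is a sketch and the object name is hypothetical. The #:: operator defers evaluation of the recursive tail until it is needed, playing the role of Stream.cons:

```scala
object LazyListSketch {
  // Port of cumulativeTotalStream to LazyList. Each element is the running
  // total; the recursive call for the rest is deferred by #:: until requested.
  def cumulativeTotals(numbers: List[Int], total: Int = 0): LazyList[Int] =
    numbers match {
      case head :: tail =>
        val next = total + head
        next #:: cumulativeTotals(tail, next)
      case Nil => LazyList.empty
    }
}
```

As with the Stream version, call toList before printing: LazyListSketch.cumulativeTotals(List(1, 2, 3)).toList yields List(1, 3, 6).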

Recursion Wrangling
There is, of course, another obvious way, which is to construct the list using a recursive function that passes two accumulators: one for the current total and one for the resulting List of totals that is accumulating:
import scala.annotation.tailrec

@tailrec
private def cumulativeTotalRecursive(
    numbers: List[Int], currentTotal: Int = 0, cumulativeTotals: List[Int] = Nil): Seq[Int] = {
  numbers match {
    case head :: tail =>
      cumulativeTotalRecursive(tail, currentTotal + head, (currentTotal + head) :: cumulativeTotals)
    case Nil => cumulativeTotals.reverse
  }
}
There's nothing wrong with this solution, and it probably performs much better than the Stream version, but I feel a bit weird about passing so many parameters around for such a simple operation.
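One common way to keep both accumulators without exposing them is to hide them in a local helper, so callers only ever pass the list. A sketch, with hypothetical object and helper names:

```scala
import scala.annotation.tailrec

object NestedHelperSketch {
  // The public signature takes a single parameter; the running total and the
  // reversed list of totals live only in the local tail-recursive helper.
  def cumulativeTotals(numbers: List[Int]): List[Int] = {
    @tailrec
    def loop(remaining: List[Int], total: Int, totals: List[Int]): List[Int] =
      remaining match {
        case head :: tail =>
          val next = total + head
          loop(tail, next, next :: totals)
        case Nil => totals.reverse
      }
    loop(numbers, 0, Nil)
  }
}
```

The work done is the same as in cumulativeTotalRecursive; only the extra parameters disappear from the caller's view.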

I could reduce the parameter count by getting the currentTotal from the accumulator list instead of passing it around:
@tailrec
private def cumulativeTotalRecursive2(
    numbers: List[Int], cumulativeTotals: List[Int] = Nil): Seq[Int] = {
  (numbers, cumulativeTotals) match {
    case (number :: tail, Nil) => cumulativeTotalRecursive2(tail, List(number))
    case (number :: tail, lastTotal :: _) =>
      cumulativeTotalRecursive2(tail, (number + lastTotal) :: cumulativeTotals)
    case (Nil, _) => cumulativeTotals.reverse
  }
}
but then the function body ends up more complex than the one with an extra parameter, which isn't a good trade-off.

Mutability the key to Simplicity?
Finally, I thought about the possibility of solving it with a for comprehension. I realised quickly that I'd need a mutable variable, but the result is a very, very simple piece of code:

private def cumulativeTotalComprehension(numbers: List[Int]): Seq[Int] = {
  var currentTotal = 0
  for (n <- numbers) yield {
    currentTotal += n
    currentTotal
  }
}
I'm pretty sure this wouldn't look as neat for a lot of the problems I've been tackling with foldLeft() but, mind you, none of those looked very pretty either. Is this an okay solution? Are Scala aficionados going to vomit in disgust when they see the var keyword?
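For what it's worth, the compiler rewrites for (n <- numbers) yield expr as numbers.map(n => expr), so the var-based version is equivalent to this sketch (the object name is hypothetical):

```scala
object ComprehensionSketch {
  // Desugared form of cumulativeTotalComprehension: a strict, in-order map
  // over a List, updating a local var that never escapes the method.
  def cumulativeTotals(numbers: List[Int]): List[Int] = {
    var currentTotal = 0
    numbers.map { n =>
      currentTotal += n
      currentTotal
    }
  }
}
```

This relies on List.map being strict and visiting elements in order. On a view, which recomputes its elements each time it is traversed, the side effect would run again on every traversal, so the totals would keep growing.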

Your Turn
What I'd really like to know is whether there's an idiomatic way of doing this that I've just never come across.

That's all I can come up with at the moment. I'm sure there are other ways to do it. One that looks simple, takes a single parameter and runs fast would be ideal. If you've got some good ideas for other ways to do this, please leave something in the comments!


6 comments:

  1. scala> val numbers = List.range(1, 11)
    numbers: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

    scala> numbers.scanLeft(0)(_ + _)
    res0: List[Int] = List(0, 1, 3, 6, 10, 15, 21, 28, 36, 45, 55)

    scala> res0.tail
    res1: List[Int] = List(1, 3, 6, 10, 15, 21, 28, 36, 45, 55)

    scala>

  2. You can simplify the 2nd recursive solution a bit by using getOrElse:

    def cumulativeTotalRec(nums: List[Int], totals: List[Int] = Nil): Seq[Int] = {
      nums match {
        case Nil => totals.reverse
        case n :: ns =>
          cumulativeTotalRec(ns, (n + totals.headOption.getOrElse(0)) :: totals)
      }
    }

    That makes it nicer than the first one, with no extra complication.

  3. How about this:

    scala> List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).scanLeft(0)(_ + _)
    res0: List[Int] = List(0, 1, 3, 6, 10, 15, 21, 28, 36, 45, 55)

  4. @missingfaktor and @Derek: Thanks! I had a feeling there might be a built-in function for doing this, but couldn't find it. (I meant to ask about that in the post, but forgot.)

    @Andrew: Good addition as well. I usually forget I can get the head like that rather than matching it with ::.

    Cheers, guys.

  5. Besides the previous answers for the specific problem, from a general point of view there is nothing wrong with using mutable variables as long as they don't leak outside the definition, as is the case here.
    An external user of cumulativeTotalComprehension can't access the mutable state, so the referential transparency of the function is preserved.

  6. scala> numbers.foldLeft(List[Int](0))((acc,e) => acc.head + e :: acc).reverse.tail
    res10: List[Int] = List(1, 3, 6, 10, 15, 21, 28, 36, 45, 55)

    scanLeft is definitely nicer :)

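Pulling the scanLeft suggestions from the comments together, here is a sketch of a single-parameter wrapper (names are hypothetical). scanLeft emits its seed value first, so .tail drops the leading 0. Seeding with the list's own head gives another cumulative function, a running maximum, without needing a neutral starting value:

```scala
object ScanSketch {
  // Running totals: scanLeft's output starts with the seed 0, which tail drops.
  def cumulativeTotals(numbers: List[Int]): List[Int] =
    numbers.scanLeft(0)(_ + _).tail

  // Running maximum: seeding with the head means there is no extra element to drop.
  def runningMax(numbers: List[Int]): List[Int] = numbers match {
    case Nil          => Nil
    case head :: tail => tail.scanLeft(head)(_ max _)
  }
}
```

ScanSketch.cumulativeTotals(List.range(1, 11)) gives List(1, 3, 6, 10, 15, 21, 28, 36, 45, 55), and ScanSketch.runningMax(List(3, 1, 4, 1, 5)) gives List(3, 3, 4, 4, 5).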