The Memoize Function

(An earlier version of this post had embedded Scastie windows. While really cool, this was greatly affecting page performance and I had to replace them with static images. You can still reach the Scastie windows by clicking on each image.)

(The code samples in this post can also be found in this github repo.)

My name is Brad and I’m a software engineer at Axon. Something that is interesting about engineering at Axon is that we write our backend services in Scala. I did not know Scala before starting here last year and I’ve had to learn a lot along the way.

Today, I want to introduce all of you to what has become one of my favorite tricks in Scala (and functional programming more generally): the memoize function.

In this post, we’re going to look at the type of optimization problems where memoization is useful, review how we traditionally memoize functions, and then take a look at how the memoize function can make our code cleaner and our lives easier.

I’ve Got 99 Problems and All of Them are Overlapping Subproblems

Sometimes when we’re working on programming problems that involve recursive structures, we end up repeating a lot of work and solving the same sub-problems over and over again. This class of problems has a name: Problems with Overlapping Subproblems. More specifically, these happen whenever:

  • A problem is composed of at least two other subproblems

  • Some or all of the subproblems show up more than once during the calculation

For example, in the naive recursive implementation of the Fibonacci sequence, each term (aside from the base cases) is calculated as the sum of two other smaller terms. Additionally, the smaller a term gets, the more frequently it shows up in the calculation:
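A minimal sketch of that naive implementation, instrumented with a println so the repeated subproblem calls are visible (the instrumentation is mine, standing in for the original Scastie snippet):

```scala
// Naive recursive Fibonacci, with a println so every call announces itself.
// Running fibonacci(6) shows the smaller terms being recalculated over and over.
def fibonacci(n: Int): Int = {
  println(s"Calculating fibonacci($n)")
  n match {
    case 0 | 1 => 1
    case n => fibonacci(n - 1) + fibonacci(n - 2)
  }
}

fibonacci(6)
```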

As you can see, the sixth and fifth terms of the Fibonacci sequence are each only calculated once, but the fourth is calculated twice, the third three times, the second five times, and the first a whopping 8 times. (Also, have we seen that sequence of numbers anywhere before? 😉)

In seriousness, this is a big problem: it means this solution runs in exponential time, which is wholly unnecessary when linear-time solutions to this problem exist. That's a shame, because once you get rid of the println statements, this is an extremely clean implementation:
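For reference, here is that clean version with the instrumentation stripped out:

```scala
// The naive recursion, with no printing: readable, but exponential time.
def fibonacci(n: Int): Int = n match {
  case 0 | 1 => 1
  case n => fibonacci(n - 1) + fibonacci(n - 2)
}
```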

There are two general strategies that can be used to improve the performance of solutions to problems like these: memoization and dynamic programming. To keep things tidy, I’m going to shelve the discussion of dynamic programming for another day and instead take a deep dive into memoization.

The Top Down Approach

The basic idea of memoization is to only do a calculation once. When we solve a subproblem, we note the solution in a cache (a.k.a. take a memo of it), and if we need to solve the problem again later, we just retrieve the solution from the cache.

If we were to memoize the naive recursive fibonacci function, it would look something like this:
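A sketch of one way to write that (the cache name and layout here are my own, in place of the original Scastie snippet): an explicit mutable.HashMap sits next to the function, and getOrElseUpdate does the bookkeeping.

```scala
import scala.collection.mutable

// The cache of already-solved subproblems.
val cache = mutable.HashMap.empty[Int, Int]

// getOrElseUpdate returns the cached value on a hit; on a miss it runs the
// recursive calculation, stores the result, and returns it.
def fibonacci(n: Int): Int = cache.getOrElseUpdate(n, n match {
  case 0 | 1 => 1
  case n => fibonacci(n - 1) + fibonacci(n - 2)
})
```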

With the verbose variant looking like this:
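A sketch of that more verbose variant, with every step spelled out instead of leaning on getOrElseUpdate:

```scala
import scala.collection.mutable

val cache = mutable.HashMap.empty[Int, Int]

// Check the cache, compute on a miss, store the result, then return it.
def fibonacci(n: Int): Int =
  if (cache.contains(n)) cache(n)
  else {
    val result = n match {
      case 0 | 1 => 1
      case n => fibonacci(n - 1) + fibonacci(n - 2)
    }
    cache(n) = result
    result
  }
```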

On the whole, I consider this to be a pretty decent solution. It’s fairly readable and it works in linear time which is a massive improvement over the exponential run time of the naive solution. My only concern is that the code which is responsible for memoization is all tangled up with the core code that computes the Fibonacci sequence. This isn’t ideal for readability or for reuse if there happens to be another function that we would like to memoize.

Fortunately for us, we have a solution to that problem.

Behold the Memoize Function!

Now we have come to the main course:
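It is a three-line higher-order function, shown here together with the rewritten fibonacci that uses it (a sketch consistent with the breakdown that follows; credit for the technique goes to the Stack Overflow answer mentioned at the end of this post):

```scala
import scala.collection.mutable

// memoize wraps any one-parameter function in a caching layer: the returned
// "function" is really a HashMap whose apply method falls back to f on a miss.
def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {
  override def apply(key: I): O = getOrElseUpdate(key, f(key))
}

// fibonacci, rewritten so its recursive calls go through the memoized val.
val fibonacci: Int => Int = memoize {
  case 0 | 1 => 1
  case n => fibonacci(n - 1) + fibonacci(n - 2)
}
```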

The memoize function allows us to cleanly separate our memoization logic from the core logic of our function. The result is as readable as the naive recursive solution while still being as performant as the memoized one.

There’s a lot going on in the three lines of memoize itself, so let’s break that down:

  1. memoize is a higher order function. Higher order functions are ones which take other functions as input or return other functions as output. In this case, memoize does both, taking in a function as input and returning a modified version of that function as its output.

  2. memoize is a generic function. This function takes in two type parameters, I and O, which allows us to memoize any one-parameter function. We can do Int => Int functions like fibonacci, but we can also do String => String, Double => Boolean, Foo => Bar, or anything else we can imagine.

  3. The implementation of memoize doesn’t just use a HashMap like the previous memoization example did. The implementation is a HashMap. How does that work? Well in Scala, HashMap[Key, Value] is a subtype of the type Key => Value.

    As the first bullet point implied, functions are first class citizens in Scala. We can assign them to values and give them to and return them from other functions. They even have types. The type of fibonacci is Int => Int, but what does Int => Int mean? It’s actually shorthand for the type Function1[T1, R], which is the trait of all single-parameter functions.

    The type HashMap[Key, Value] extends the MapOps trait, which extends the PartialFunction[Key, Value] trait, which extends Function1[T1, R]. In short, HashMap[Int, Int] is a subtype of Int => Int and can therefore be used anywhere we need an Int => Int, like in our memoize function.

  4. This isn’t just a vanilla HashMap either. Its apply (a.k.a. getter) method has been overridden in a very special way. Specifically, it has been replaced with getOrElseUpdate(key, f(key)). This method will either return the value associated with key if it’s in the HashMap, or it will calculate it using the supplied expression, add the key-value pair to the HashMap, and then return the value. What is this f function though? That’s the function we’re trying to memoize! In our example, that is the fibonacci function.
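The hit-versus-miss behavior of getOrElseUpdate is easy to see in isolation (this toy example is mine, not from the original post):

```scala
import scala.collection.mutable

val m = mutable.HashMap(1 -> 1)

// Hit: the stored value is returned and the default expression never runs.
val hit = m.getOrElseUpdate(1, { println("computing 1"); 99 })

// Miss: the default expression runs, (2 -> 99) is stored, and 99 is returned.
val miss = m.getOrElseUpdate(2, { println("computing 2"); 99 })
```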

Shallow vs Deep Memoization

Any particularly astute observers out there may have noticed that we partially re-wrote our fibonacci function after we introduced memoize, taking it from this:

def fibonacci(n: Int): Int = n match {
  case 0 | 1 => 1
  case n => fibonacci(n - 1) + fibonacci(n - 2)
}

to this:

val fibonacci: Int => Int = memoize {
  case 0 | 1 => 1
  case n => fibonacci(n - 1) + fibonacci(n - 2)
}

You may now be wondering, “If we already had a perfectly good fibonacci function, couldn’t we have just passed that into memoize instead of re-writing the whole thing?” as in:

val memoizedFibonacci = memoize(fibonacci)

Well, sort of. We could have done that and it would have compiled and returned correct results, but it wouldn’t have been exactly what we were expecting. Let’s walk through it to see what I mean:
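A sketch of that experiment (the instrumented function and its name are mine; memoize is the function from earlier, repeated here so the snippet stands on its own):

```scala
import scala.collection.mutable

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {
  override def apply(key: I): O = getOrElseUpdate(key, f(key))
}

// Instrumented so every real calculation announces itself on the console.
def unmemoizedFibonacci(n: Int): Int = {
  println(s"Calculating fibonacci($n)")
  n match {
    case 0 | 1 => 1
    case n => unmemoizedFibonacci(n - 1) + unmemoizedFibonacci(n - 2)
  }
}

val memoizedFibonacci = memoize(unmemoizedFibonacci)

memoizedFibonacci(5) // prints every recursive call
memoizedFibonacci(5) // silent: the top-level result was cached
memoizedFibonacci(4) // prints again: intermediate results were never cached
```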

There are a couple of things going on here that are weird:

  • We’re not using memoization to calculate fibonacci(5) initially. We can see from the console output that the calculation is making repeated calls to fibonacci for the inputs 3, 2, 1 and 0. Fortunately though, when we call fibonacci(5) a second time, it is correctly pulling that out of the cache.

  • It didn’t have the value of fibonacci(4) cached. Even though this was the first time we called fibonacci(4) directly, fibonacci(4) was definitely calculated while calculating fibonacci(5), and so it’s strange that that result hadn’t already been memoized.

The cause of both of these problems becomes pretty clear if we step through the calculation ourselves:

  1. We first calculate memoizedFibonacci as the result of memoizing unmemoizedFibonacci. Even though the explicit type of memoizedFibonacci is Int => Int, the underlying implementation is a HashMap with an overridden apply method.

  2. When we call memoizedFibonacci(5) the first time, we’re really invoking the overridden apply method of the HashMap, which checks if the value 5 is contained in the HashMap. 5 is not in the HashMap though, so it goes on to calculate the value of 5 using the expression f(5). What is f though? It’s unmemoizedFibonacci, and this is the cause of our problems!

  3. unmemoizedFibonacci does not check the cache before calculating its result. (It doesn’t even know about the cache.) Rather, it just goes on to calculate unmemoizedFibonacci(5) as unmemoizedFibonacci(4) + unmemoizedFibonacci(3) which is why we’re seeing all the repeated calls when we call memoizedFibonacci(5) the first time.

  4. Additionally, unmemoizedFibonacci doesn’t store any of its intermediate results in the cache either (because it still doesn’t know about the cache). Only after HashMap.apply has calculated f(5) does it store the result of f(5) (and only the result of f(5)) in the HashMap. This is why we have to calculate memoizedFibonacci(4) instead of just being able to look it up in the cache.

In short, the root of our problem is that our function is only shallowly memoized. It’s only checking for and storing the results of its calculations at the outermost level. Once we enter the function f, all efforts to memoize the calculation go out the window.

In order to realize all the benefits of the memoize function, we need to deeply memoize our functions. Our default value function, f, needs to use the memoized function that we are creating. This is why we had to rewrite our fibonacci function.

Let’s Wrap Up

Before we get out of here, I need to offer a huge shoutout to Stack Overflow user pathikrit for introducing me to the memoize function.

To summarize what we’ve learned today:

  • There’s a class of problems called Problems with Overlapping Subproblems which can be solved recursively, but the recursive solutions usually exhibit exponential runtime and need some help to become performant enough for practical use.

  • One method for optimizing this type of problem is memoization. This technique involves recording the solutions to subproblems in a cache so they can be quickly retrieved when they’re needed again later.

  • The memoize function is a helpful method for cleanly separating memoization logic from the core logic of your solution.

  • You can use the memoize function to create shallow memoizations of any existing function. To realize the full benefits of memoization, though, you will need to write new, deeply memoized implementations of your functions.

I hope this post has been interesting or useful or both.

Happy Scala-ing!
