The Memoize Function
(An earlier version of this post had embedded Scastie windows. While really cool, this was greatly affecting page performance and I had to replace them with static images. You can still reach the Scastie windows by clicking on each image.)
(The code samples in this post can also be found in this GitHub repo.)
My name is Brad and I’m a software engineer at Axon. Something that is interesting about engineering at Axon is that we write our backend services in Scala. I did not know Scala before starting here last year and I’ve had to learn a lot along the way.
Today, I want to introduce all of you to what has become one of my favorite tricks in Scala (and functional programming more generally): the memoize function.
In this post, we’re going to look at the type of optimization problems where memoization is useful, review how we traditionally memoize functions, and then take a look at how the memoize function can make our code cleaner and our lives easier.
I’ve Got 99 Problems and All of Them are Overlapping Subproblems
Sometimes when we’re working on programming problems that involve recursive structures, we end up repeating a lot of work and solving the same sub-problems over and over again. This class of problems has a name: Problems with Overlapping Subproblems. More specifically, these happen whenever:
- A problem is composed of at least two other subproblems
- Some or all of the subproblems show up more than once during the calculation
For example, in the naive recursive implementation of the Fibonacci sequence, each term (aside from the base cases) is calculated as the sum of two other smaller terms. Additionally, the smaller a term gets, the more frequently it shows up in the calculation:
As you can see, the sixth and fifth terms of the Fibonacci sequence are each only calculated once, but the fourth is calculated twice, the third three times, the second five times, and the first is calculated a whopping 8 times. (Also, have we seen that sequence of numbers anywhere before? 😉)
In seriousness though, this is a big problem because it means this solution is going to run in exponential time, which is wholly unnecessary when linear time solutions to this problem exist. That's a shame, because once you get rid of the `println` statements, this is an extremely clean implementation:
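The code in the image is presumably along these lines — a sketch of the naive recursive implementation, with the `println` reconstructed from the trace described above:

```scala
// Naive recursive Fibonacci (with fibonacci(0) = fibonacci(1) = 1, matching
// the term counts described above). The println lets us watch the repeated work.
def fibonacci(n: Int): Int = {
  println(s"Calculating fibonacci($n)")
  n match {
    case 0 | 1 => 1
    case n => fibonacci(n - 1) + fibonacci(n - 2)
  }
}

fibonacci(6) // exponential time: the smallest terms are recomputed many times
```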
There are two general strategies that can be used to improve the performance of solutions to problems like these: memoization and dynamic programming. To keep things tidy, I’m going to shelve the discussion of dynamic programming for another day and instead take a deep dive into memoization.
The Top Down Approach
The basic idea of memoization is to only do a calculation once. When we solve a subproblem, we note the solution in a cache (a.k.a. take a memo of it), and if we need to solve the problem again later, we just retrieve the solution from the cache.
If we were to memoize the naive recursive fibonacci function, it would look something like this:
With the verbose variant looking like this:
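The hand-memoized version in the image is presumably something like the following sketch (the `cache` name and the exact check-compute-store structure are my assumptions). Notice how the cache bookkeeping sits right inside the Fibonacci logic:

```scala
import scala.collection.mutable

// A cache mapping inputs to already-computed results.
val cache = mutable.HashMap.empty[Int, Int]

def fibonacci(n: Int): Int = cache.getOrElse(n, {
  // The memoization plumbing (cache lookup above, insertion below) is
  // tangled up with the core Fibonacci logic.
  val result = n match {
    case 0 | 1 => 1
    case n => fibonacci(n - 1) + fibonacci(n - 2)
  }
  cache(n) = result
  result
})

fibonacci(40) // linear time: each term is computed at most once
```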
On the whole, I consider this to be a pretty decent solution. It’s fairly readable and it works in linear time which is a massive improvement over the exponential run time of the naive solution. My only concern is that the code which is responsible for memoization is all tangled up with the core code that computes the Fibonacci sequence. This isn’t ideal for readability or for reuse if there happens to be another function that we would like to memoize.
Fortunately for us though, we have a solution to that problem.
Behold the Memoize Function!
Now we have come to the main course:
The memoize
function allows us to cleanly separate our memoization logic from the core logic of our function. The result is as readable as the naive recursive solution while still being as performant as the memoized one.
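Reconstructed from the description that follows — a `HashMap` whose `apply` method is overridden with `getOrElseUpdate` — the function presumably looks something like this, shown here together with the rewritten `fibonacci` that uses it:

```scala
import scala.collection.mutable

// memoize returns an I => O that *is* a mutable.HashMap, with apply
// overridden to consult the cache before falling back to f.
def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {
  override def apply(key: I): O = getOrElseUpdate(key, f(key))
}

// fibonacci refers to itself inside the block passed to memoize, so the
// recursive calls also go through the cache.
val fibonacci: Int => Int = memoize {
  case 0 | 1 => 1
  case n => fibonacci(n - 1) + fibonacci(n - 2)
}
```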
There's a lot going on in those three lines, so let's break that down:
- `memoize` is a higher order function. Higher order functions are ones which take other functions as input or return other functions as output. In this case, `memoize` does both, taking in a function as input and returning a modified version of that function as its output.
- `memoize` is a generic function. This function takes in two type parameters, `I` and `O`, which allows us to `memoize` any one-parameter function. We can do `Int => Int` functions like `fibonacci`, but we can also do `String => String`, `Double => Boolean`, `Foo => Bar`, or anything else we can imagine.
- The implementation of `memoize` doesn't just use a `HashMap` like the previous memoization example did. The implementation is a `HashMap`. How does that work? Well, in Scala, `HashMap[Key, Value]` is a subtype of the type `Key => Value`.
  - As we said in the first bullet point, functions are first class citizens in Scala. We can assign them to values and give them to and return them from other functions. They even have types. The type of `fibonacci` is `Int => Int`, but what does `Int => Int` mean? It's actually shorthand for the type `Function1[T1, R]`, which is the trait of all single-parameter functions.
  - The type `HashMap[Key, Value]` extends the `MapOps` trait, which extends the `PartialFunction[Key, Value]` trait, which extends `Function1[T1, R]`. In short, `HashMap[Int, Int]` is a subtype of `Int => Int` and can therefore be used anywhere we need an `Int => Int`, like in our `memoize` function.
- This isn't just a vanilla `HashMap` either. Its `apply` (a.k.a. getter) method has been overridden in a very special way. Specifically, it has been replaced with `getOrElseUpdate(key, f(key))`. This method will either return the value associated with `key` if it's in the `HashMap`, or it will calculate it using the supplied expression, add the key-value pair to the `HashMap`, and then return the value. What is this `f` function though? That's the function we're trying to memoize! In our example, that is the `fibonacci` function.
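To make the subtyping point concrete, here is a small sketch (the `squares` example is mine, not from the post) showing that a `HashMap` can be used anywhere a function is expected:

```scala
import scala.collection.mutable

// A mutable.HashMap[Int, Int] is a PartialFunction[Int, Int], which is a
// Function1[Int, Int] — so it can be assigned to an Int => Int directly.
val squares: Int => Int = mutable.HashMap(1 -> 1, 2 -> 4, 3 -> 9)

// And passed wherever an Int => Int is wanted, e.g. to map:
List(1, 2, 3).map(squares) // List(1, 4, 9)
```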
Shallow vs Deep Memoization
Any particularly astute observers out there may have noticed that we partially re-wrote our `fibonacci` function after we introduced `memoize`, taking it from this:
```scala
def fibonacci(n: Int): Int = n match {
  case 0 | 1 => 1
  case n => fibonacci(n - 1) + fibonacci(n - 2)
}
```
to this:
```scala
val fibonacci: Int => Int = memoize {
  case 0 | 1 => 1
  case n => fibonacci(n - 1) + fibonacci(n - 2)
}
```
You may now be wondering, "If we already had a perfectly good `fibonacci` function, couldn't we have just passed that into `memoize` instead of re-writing the whole thing?" as in:
```scala
val memoizedFibonacci = memoize(fibonacci)
```
Well, sort of. We could have done that and it would have compiled and returned correct results, but it wouldn’t have been exactly what we were expecting. Let’s walk through it to see what I mean:
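The experiment in the image presumably looks something like this sketch (`memoize` is repeated here for self-containment, and the `println` trace is my reconstruction of the console output in the screenshot):

```scala
import scala.collection.mutable

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {
  override def apply(key: I): O = getOrElseUpdate(key, f(key))
}

// The original, unmemoized function, with a println so we can watch the calls.
def unmemoizedFibonacci(n: Int): Int = {
  println(s"Calculating fibonacci($n)")
  n match {
    case 0 | 1 => 1
    case n => unmemoizedFibonacci(n - 1) + unmemoizedFibonacci(n - 2)
  }
}

val memoizedFibonacci = memoize(unmemoizedFibonacci)

memoizedFibonacci(5) // prints repeated "Calculating" lines: no caching inside
memoizedFibonacci(5) // cache hit at the outermost level: prints nothing
memoizedFibonacci(4) // prints again: the intermediate result was never cached
```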
There are a couple of things going on here that are weird:

1. We're not using memoization to calculate `fibonacci(5)` initially. We can see from the console output that the calculation is making repeated calls to `fibonacci` for the inputs `3`, `2`, `1`, and `0`. Fortunately though, when we call `fibonacci(5)` a second time, it is correctly pulling that out of the cache.
2. It didn't have the value of `fibonacci(4)` cached. Even though this was the first time we called `fibonacci(4)` directly, `fibonacci(4)` was definitely calculated while calculating `fibonacci(5)`, and so it's strange that that result hadn't already been memoized.
The cause of both of these problems becomes pretty clear if we step through the calculation ourselves:

1. We first calculate `memoizedFibonacci` as the result of memoizing `unmemoizedFibonacci`. Even though the explicit type of `memoizedFibonacci` is `Int => Int`, the underlying implementation is a `HashMap` with an overridden `apply` method.
2. When we call `memoizedFibonacci(5)` the first time, we're really invoking the overridden `apply` method of the `HashMap`, which checks if the value `5` is contained in the `HashMap`.
3. `5` is not in the `HashMap` though, so it goes on to calculate the value for `5` using the expression `f(5)`. What is `f` though? It's `unmemoizedFibonacci`, and this is the cause of our problems!
4. `unmemoizedFibonacci` does not check the cache before calculating its result. (It doesn't even know about the cache.) Rather, it just goes on to calculate `unmemoizedFibonacci(5)` as `unmemoizedFibonacci(4) + unmemoizedFibonacci(3)`, which is why we're seeing all the repeated calls when we call `memoizedFibonacci(5)` the first time.
5. Additionally, `unmemoizedFibonacci` doesn't store any of its intermediate results in the cache either (because it still doesn't know about the cache). It's only after `HashMap.apply` has calculated `f(5)` that it stores the result of `f(5)` (and only the result of `f(5)`) in the `HashMap`. This is why we have to calculate `memoizedFibonacci(4)` from scratch instead of just being able to look it up in the cache.
In short, the root of our problem is that our function is only shallowly memoized. It's only checking for and storing the results of its calculations at the outermost level. Once we enter the function `f`, all efforts to memoize the calculation go out the window.
In order to realize all the benefits of the `memoize` function, we need to deeply memoize our functions. Our default value function, `f`, needs to use the memoized function that we are creating. This is why we had to rewrite our `fibonacci` function.
Let’s Wrap Up
Before we get out of here, I need to offer a huge shoutout to Stack Overflow user pathikrit for introducing me to the `memoize` function.
To summarize what we've learned today:

- There's a class of problems called Problems with Overlapping Subproblems which can be solved recursively, but the recursive solutions usually exhibit exponential runtime and need some help to become performant enough for practical use.
- One method for optimizing this type of problem is memoization. This technique involves recording the solutions to subproblems in a cache so they can be quickly retrieved when they're needed again later.
- The `memoize` function is a helpful method for cleanly separating memoization logic from the core logic of your solution.
- You can use the `memoize` function to create shallow memoizations of any existing function. To realize all the benefits of memoization though, you will need to write new, deeply-memoized implementations of your functions.
I hope this post has been interesting or useful or both.
Happy Scala-ing!