F# Simple memoization


Memoization consists of caching function results to avoid computing the same result multiple times. This is useful when working with functions that perform costly computations.

We can use a simple factorial function as an example:

let factorial index =
    let rec innerLoop i acc =
        match i with
        | 0 | 1 -> acc
        | _ -> innerLoop (i - 1) (acc * i)

    innerLoop index 1

Calling this function several times with the same argument repeats the same computation every time.

Memoization will help us to cache the factorial result and return it if the same parameter is passed again.
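The repeated work can be made visible with a call counter. The sketch below is only an illustration: it adds a hypothetical `computations` counter (not part of the original function) to the factorial function and calls it three times with the same argument.

```fsharp
// Hypothetical counter, introduced only for this illustration
let computations = ref 0

let factorial index =
    let rec innerLoop i acc =
        match i with
        | 0 | 1 -> acc
        | _ -> innerLoop (i - 1) (acc * i)

    // Record that a full computation is about to run
    computations.Value <- computations.Value + 1
    innerLoop index 1

let first = factorial 10
let second = factorial 10
let third = factorial 10

printfn "%i %i %i" first second third
printfn "computations: %i" computations.Value

// Prints
// 3628800 3628800 3628800
// computations: 3
```

Every call runs the whole loop again, even though the answer is already known after the first one.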

Here is a simple memoization implementation:

open System.Collections.Generic

// memoization takes a function as a parameter
// This function will be called every time a value is not in the cache
let memoization f =
    // The dictionary is used to store values for every parameter that has been seen
    let cache = Dictionary<_,_>()
    fun c ->
        let exist, value = cache.TryGetValue (c)
        match exist with
        | true -> 
            // Return the cached result directly, no function call
            printfn "%O -> In cache" c
            value
        | _ -> 
            // Call the function first, then cache the result for the next call with the same parameter
            printfn "%O -> Not in cache, calling function..." c
            let value = f c
            cache.Add (c, value)
            value

The memoization function takes a function as a parameter and returns a function with the same signature. You can see this in its signature, f:('a -> 'b) -> ('a -> 'b). As a result, the memoized function is called exactly the same way as the original factorial function.

The printfn calls are to show what happens when we call the function multiple times; they can be removed safely.
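With the printfn calls removed, the helper could look something like the following sketch. The names `memoizeQuiet`, `slowSquare`, `fastSquare` and `squareCalls` are assumptions made for this example, and matching directly on the tuple returned by TryGetValue is a stylistic variation rather than a required change.

```fsharp
open System.Collections.Generic

// Hypothetical name: the same idea as memoization, without the logging
let memoizeQuiet f =
    let cache = Dictionary<_,_>()
    fun x ->
        match cache.TryGetValue x with
        | true, value -> value          // Already computed: return the cached value
        | false, _ ->
            let value = f x             // First time: compute...
            cache.Add (x, value)        // ...store...
            value                       // ...and return

// Hypothetical function with a counter, so the caching can be observed without printing inside the helper
let squareCalls = ref 0
let slowSquare (x: int) =
    squareCalls.Value <- squareCalls.Value + 1
    x * x

let fastSquare = memoizeQuiet slowSquare
let r1 = fastSquare 12
let r2 = fastSquare 12

printfn "%i %i, underlying calls: %i" r1 r2 squareCalls.Value

// Prints
// 144 144, underlying calls: 1
```

The counter shows that the wrapped function ran only once, even though the memoized function was called twice.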

Using memoization is easy:

// Pass the function we want to cache; memoization returns the cached version of it
let factorialMem = memoization factorial

// Then call the memoized function
printfn "%i" (factorialMem 10)
printfn "%i" (factorialMem 10)
printfn "%i" (factorialMem 10)
printfn "%i" (factorialMem 4)
printfn "%i" (factorialMem 4)

// Prints
// 10 -> Not in cache, calling function...
// 3628800
// 10 -> In cache
// 3628800
// 10 -> In cache
// 3628800
// 4 -> Not in cache, calling function...
// 24
// 4 -> In cache
// 24
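One detail worth keeping in mind is that the cache lives inside the function returned by the memoizing helper, so the memoized function should be bound once as a value. The sketch below illustrates the difference; `memoize`, `square`, `squareMem`, `squareNotCached` and `lookupCalls` are hypothetical names used only for this example.

```fsharp
open System.Collections.Generic

// Hypothetical print-free helper, equivalent in spirit to memoization
let memoize f =
    let cache = Dictionary<_,_>()
    fun x ->
        match cache.TryGetValue x with
        | true, value -> value
        | false, _ ->
            let value = f x
            cache.Add (x, value)
            value

// Counter used to observe how often the wrapped function really runs
let lookupCalls = ref 0
let square (x: int) =
    lookupCalls.Value <- lookupCalls.Value + 1
    x * x

// Value binding: memoize runs once, so a single cache is shared by all calls
let squareMem = memoize square

// Function binding: memoize runs on every call, each time with a new empty cache
let squareNotCached x = memoize square x

squareMem 5 |> ignore
squareMem 5 |> ignore
let callsWithSharedCache = lookupCalls.Value

squareNotCached 5 |> ignore
squareNotCached 5 |> ignore
let callsWithFreshCache = lookupCalls.Value - callsWithSharedCache

printfn "shared cache: %i call(s), fresh cache each time: %i call(s)" callsWithSharedCache callsWithFreshCache

// Prints
// shared cache: 1 call(s), fresh cache each time: 2 call(s)
```

This is why factorialMem above is written as let factorialMem = memoization factorial, without an explicit parameter.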