We have a tree data type like this:
data Tree a = Tree a [Tree a] deriving Show
And we wish to write a function that assigns a number to each node of the tree, from an incrementing counter:
tag :: Tree a -> Tree (a, Int)
First we'll do it the long way around, since it illustrates the State
monad's low-level mechanics quite nicely.
import Control.Monad.State
-- Function that numbers the nodes of a `Tree`.
tag :: Tree a -> Tree (a, Int)
tag tree =
    -- tagStep is where the action happens. This just gets the ball
    -- rolling, with `0` as the initial counter value.
    evalState (tagStep tree) 0
-- This is one monadic "step" of the calculation. It assumes that
-- it has access to the current counter value implicitly.
tagStep :: Tree a -> State Int (Tree (a, Int))
tagStep (Tree a subtrees) = do
    -- The `get :: State s s` action accesses the implicit state
    -- parameter of the State monad. Here we bind that value to
    -- the variable `counter`.
    counter <- get
    -- The `put :: s -> State s ()` sets the implicit state parameter
    -- of the `State` monad. The next `get` that we execute will see
    -- the value of `counter + 1` (assuming no other puts in between).
    put (counter + 1)
    -- Recurse into the subtrees. `mapM` is a utility function
    -- for executing a monadic action (like `tagStep`) on each element
    -- of a list, and producing the list of results. Each execution of
    -- `tagStep` starts with the counter value that resulted
    -- from the previous list element's execution.
    subtrees' <- mapM tagStep subtrees
    return $ Tree (a, counter) subtrees'
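To see the order in which the counter is handed out, here is a minimal sketch that runs tag on a small tree. The sample tree and the extra Eq instance are assumptions added for illustration; the rest repeats the definitions above without the comments.

```haskell
import Control.Monad.State

-- Eq is derived here only so that results can be compared; the
-- original definition derives Show alone.
data Tree a = Tree a [Tree a] deriving (Show, Eq)

tag :: Tree a -> Tree (a, Int)
tag tree = evalState (tagStep tree) 0

tagStep :: Tree a -> State Int (Tree (a, Int))
tagStep (Tree a subtrees) = do
    counter <- get
    put (counter + 1)
    subtrees' <- mapM tagStep subtrees
    return $ Tree (a, counter) subtrees'

-- Hypothetical sample tree, not part of the original text.
example :: Tree Char
example = Tree 'a' [Tree 'b' [Tree 'c' []], Tree 'd' []]

-- tag example evaluates to:
--   Tree ('a',0) [Tree ('b',1) [Tree ('c',2) []],Tree ('d',3) []]
```

Each node takes its number before its subtrees are visited, so the numbering is a pre-order walk: a parent is numbered first, then its children from left to right.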
The part where we get the current counter and then put counter + 1 can be split off into a postIncrement action, similar to the post-increment operator that many C-style languages provide:
postIncrement :: Enum s => State s s
postIncrement = do
    result <- get
    modify succ
    return result
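A quick way to check the behaviour of postIncrement is to run it with runState, which returns both the result and the final state. The demo names below are hypothetical and only serve as a sketch.

```haskell
import Control.Monad (replicateM)
import Control.Monad.State

postIncrement :: Enum s => State s s
postIncrement = do
    result <- get
    modify succ
    return result

-- The action returns the old state and leaves its successor behind.
demoInt :: (Int, Int)
demoInt = runState postIncrement 5                    -- (5,6)

-- Because only Enum is required, it is not limited to numbers.
demoChar :: (Char, Char)
demoChar = runState postIncrement 'a'                 -- ('a','b')

-- Three increments in a row, collecting each returned value.
demoThree :: ([Int], Int)
demoThree = runState (replicateM 3 postIncrement) 10  -- ([10,11,12],13)
```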
The tree walk logic can be split out into its own function, like this:
mapTreeM :: Monad m => (a -> m b) -> Tree a -> m (Tree b)
mapTreeM action (Tree a subtrees) = do
    a' <- action a
    subtrees' <- mapM (mapTreeM action) subtrees
    return $ Tree a' subtrees'
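Since mapTreeM is polymorphic in the monad, it is not tied to State. As a sketch, the example below runs it in the Maybe monad to validate every label; the nonNegative check and the sample trees are assumptions made up for illustration.

```haskell
-- Eq is derived here only so that results can be compared.
data Tree a = Tree a [Tree a] deriving (Show, Eq)

mapTreeM :: Monad m => (a -> m b) -> Tree a -> m (Tree b)
mapTreeM action (Tree a subtrees) = do
    a' <- action a
    subtrees' <- mapM (mapTreeM action) subtrees
    return $ Tree a' subtrees'

-- Hypothetical check: accept only non-negative labels.
nonNegative :: Int -> Maybe Int
nonNegative n
    | n >= 0    = Just n
    | otherwise = Nothing

-- Every label passes, so the whole tree comes back inside Just.
allGood :: Maybe (Tree Int)
allGood = mapTreeM nonNegative (Tree 1 [Tree 2 [], Tree 3 []])
-- Just (Tree 1 [Tree 2 [],Tree 3 []])

-- A single failing label makes the whole traversal fail.
oneBad :: Maybe (Tree Int)
oneBad = mapTreeM nonNegative (Tree 1 [Tree (-2) [], Tree 3 []])
-- Nothing
```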
With this and the postIncrement function we can rewrite tagStep:
tagStep :: Tree a -> State Int (Tree (a, Int))
tagStep = mapTreeM step
    where step :: a -> State Int (a, Int)
          step a = do
              counter <- postIncrement
              return (a, counter)
Traversable class

The mapTreeM solution above can be easily rewritten into an instance of the Traversable class:
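-- Traversable has Functor and Foldable as superclasses, so Tree
-- needs instances of those two classes as well.
instance Traversable Tree where
    traverse action (Tree a subtrees) =
        -- The outer `traverse` walks the list of subtrees; the inner
        -- one recurses into each subtree.
        Tree <$> action a <*> traverse (traverse action) subtrees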
Note that this required us to use Applicative (the <*> operator) instead of Monad.
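A Traversable instance only compiles if the type also has Functor and Foldable instances. One way to supply them without writing more traversal code, sketched below, is to use fmapDefault and foldMapDefault from Data.Traversable, which define both in terms of traverse. The sample tree is a hypothetical value added for illustration.

```haskell
import Data.Traversable (fmapDefault, foldMapDefault)

-- Eq is derived here only so that results can be compared.
data Tree a = Tree a [Tree a] deriving (Show, Eq)

-- Both superclass instances are derived from `traverse` below.
instance Functor Tree where
    fmap = fmapDefault

instance Foldable Tree where
    foldMap = foldMapDefault

instance Traversable Tree where
    traverse action (Tree a subtrees) =
        Tree <$> action a <*> traverse (traverse action) subtrees

-- Hypothetical sample tree.
sample :: Tree Int
sample = Tree 1 [Tree 2 [], Tree 3 []]

-- fmap (+ 1) sample  evaluates to  Tree 2 [Tree 3 [],Tree 4 []]
-- sum sample         evaluates to  6
```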
With that, now we can write tag like a pro:
tag :: Traversable t => t a -> t (a, Int)
tag t = evalState (traverse step t) 0
    where step a = do counter <- postIncrement
                      return (a, counter)
Note that this works for any Traversable type, not just our Tree type!
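As a small sketch of that generality, the same tag can number the elements of a list or the contents of a Maybe, both of which have Traversable instances in base. The value names are hypothetical.

```haskell
import Control.Monad.State

postIncrement :: Enum s => State s s
postIncrement = do
    result <- get
    modify succ
    return result

tag :: Traversable t => t a -> t (a, Int)
tag t = evalState (traverse step t) 0
    where step a = do counter <- postIncrement
                      return (a, counter)

-- A list (here a String) is numbered element by element.
taggedList :: [(Char, Int)]
taggedList = tag "abc"          -- [('a',0),('b',1),('c',2)]

-- A Maybe holds at most one element, which gets number 0.
taggedMaybe :: Maybe (Bool, Int)
taggedMaybe = tag (Just True)   -- Just (True,0)
```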
Traversable boilerplate

GHC has a DeriveTraversable extension that eliminates the need for writing the instance above:
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
data Tree a = Tree a [Tree a]
    deriving (Show, Functor, Foldable, Traversable)
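Putting the pieces together, the sketch below uses only derived instances and the generic tag. The derived traverse visits the constructor fields from left to right, so the label comes before the subtrees, just as in the hand-written version. The sample tree and the extra Eq instance are assumptions for illustration.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
import Control.Monad.State

-- Eq is derived here only so that results can be compared.
data Tree a = Tree a [Tree a]
    deriving (Show, Eq, Functor, Foldable, Traversable)

postIncrement :: Enum s => State s s
postIncrement = do
    result <- get
    modify succ
    return result

tag :: Traversable t => t a -> t (a, Int)
tag t = evalState (traverse step t) 0
    where step a = do counter <- postIncrement
                      return (a, counter)

-- Hypothetical sample tree.
example :: Tree Char
example = Tree 'x' [Tree 'y' [], Tree 'z' []]

-- tag example     evaluates to  Tree ('x',0) [Tree ('y',1) [],Tree ('z',2) []]
-- length example  evaluates to  3, via the derived Foldable instance
```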