State transformers

One useful monad is the state transformer, which allows us to avoid explicitly carrying state information from one part of a program to another.

Example 0. Stateful random-number generator

As an example of the clumsiness of carrying state, look at this random-number generator:
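-- Random-number generator in state-passing style
data State = State Int

rng0 :: State -> (Double, State)
rng0 (State seed) = (scale seed, State (next seed))
For completeness, here are suitable auxiliary functions. They implement a linear congruential generator with constants from "Numerical Recipes"; the modulus is small enough that every intermediate product fits in 31-bit arithmetic. They are defined at top level so that the later examples can reuse them.
scale :: Int -> Double
scale x = fromIntegral x / fromIntegral m    -- scale to [0,1)

next :: Int -> Int
next x = (x*7141 + 54773) `mod` m

m :: Int
m = 259200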
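To see the clumsiness concretely, here is a sketch that draws n numbers in a row, assuming State simply wraps an Int seed. rngN is a hypothetical helper, not part of the original program: every step must receive the state produced by the step before. As a worked step, next 12345 = (12345*7141 + 54773) `mod` 259200 = 88210418 `mod` 259200 = 82418.

```haskell
-- Assumption: State wraps a single Int seed, as in Example 0.
data State = State Int deriving (Show, Eq)

m :: Int
m = 259200

-- Map a seed in [0, m) to a Double in [0, 1).
scale :: Int -> Double
scale x = fromIntegral x / fromIntegral m

-- One step of the linear congruential generator.
next :: Int -> Int
next x = (x * 7141 + 54773) `mod` m

rng0 :: State -> (Double, State)
rng0 (State seed) = (scale seed, State (next seed))

-- Hypothetical helper: draw n numbers, threading the state by hand.
-- Each recursive call must be given the state returned by the previous draw.
rngN :: Int -> State -> ([Double], State)
rngN 0 s = ([], s)
rngN n s = let (x, s')   = rng0 s
               (xs, s'') = rngN (n - 1) s'
           in (x : xs, s'')
```

rngN 2 (State 12345) yields the values 12345/259200 and 82418/259200, together with the state State (next 82418) that the next caller must remember to pass on.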
While the definition of rng0 is succinct enough, its state obtrudes in all the code that uses it. Witness this code, which generates a path through randomly chosen points in the unit square.
extend adds another point to a path. The state is mentioned six times, three times as often as any other variable.
type Point = (Double,Double)
type Path = [Point]

extend :: (Path,State) -> (Path,State)
extend (ps, s) = let
                     (x, s') = rng0 s
                     (y, s'') = rng0 s'
                 in (ps++[(x,y)], s'')
To obtain a random path using extend, we must initialize the state and the path, and extract the path from the result, as in this fragment that creates a 2-point path:
test = fst ((extend . extend) ([], State 12345))
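Assembled into one module, Example 0 can be loaded and run as it stands. This is a sketch assuming State wraps a single Int seed and the auxiliary functions live at top level; extend . extend is applied directly to the initial pair ([], State 12345).

```haskell
-- Assumption: State wraps a single Int seed.
data State = State Int deriving (Show, Eq)

type Point = (Double, Double)
type Path  = [Point]

m :: Int
m = 259200

scale :: Int -> Double
scale x = fromIntegral x / fromIntegral m   -- scale to [0,1)

next :: Int -> Int
next x = (x * 7141 + 54773) `mod` m

rng0 :: State -> (Double, State)
rng0 (State seed) = (scale seed, State (next seed))

-- Add one random point; the state is threaded through both draws by hand.
extend :: (Path, State) -> (Path, State)
extend (ps, s) = let (x, s')  = rng0 s
                     (y, s'') = rng0 s'
                 in (ps ++ [(x, y)], s'')

-- Two applications give a 2-point path; fst discards the final state.
test :: Path
test = fst ((extend . extend) ([], State 12345))
```

The first point of test is (scale 12345, scale 82418), because the seed after the first draw is 82418.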
We can hide most explicit mention of the state by programming the random-number generator as a monadic state transformer.

The Monad class

Recall the definition of the Monad type class as given in the Haskell 98 standard:
class Monad m where
    (>>=)  :: m a -> (a -> m b) -> m b
    (>>)   :: m a -> m b -> m b
    return :: a -> m a
    fail   :: String -> m a

    p >> q = p >>= \_ -> q
    fail s = error s
In particular type IO is an instance of class Monad.

State transformer

A state transformer carries a function from a state to a value paired with a new state. Both value and state can have any type. A state transformer is equipped with a function that applies the carried function to a state.
data ST state value = ST (state -> (value,state))

apply :: ST state value -> state -> (value, state)
apply (ST f) s = f s
Since state transformers are monads, they should be equipped with (at least) bind (>>=) and return functions.
instance Monad (ST state) where
    (ST f) >>= q = ST f' where
        f' = \s -> let (v,s') = f s in apply (q v) s'
    return x = ST (\s -> (x,s))
The following summary of types may help in threading one's way through the definition of (>>=). Here value' is the value type of the state transformer returned by q, which need not be the same as value.
    f :: state -> (value,state)          -- definition of ST
    (v,s') :: (value,state)
    q :: value -> ST state value'        -- definition of >>=
    q v :: ST state value'
    apply (q v) s' :: (value',state)     -- definition of apply
    f' :: state -> (value',state)        -- lambda abstraction
    ST f' :: ST state value'             -- definition of ST
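The instance above is written for Haskell 98. A current GHC rejects it unless ST state is also made an instance of Functor and Applicative (and there, return defaults to pure). A minimal sketch of the missing pieces follows, deriving them from (>>=) with liftM and ap from Control.Monad; counter and threeCounts are hypothetical names used only for illustration.

```haskell
import Control.Monad (ap, liftM)

data ST state value = ST (state -> (value, state))

apply :: ST state value -> state -> (value, state)
apply (ST f) s = f s

-- Functor and Applicative are obtained from the Monad operations.
instance Functor (ST state) where
  fmap = liftM

instance Applicative (ST state) where
  pure x = ST (\s -> (x, s))   -- plays the role of return: state untouched
  (<*>)  = ap

instance Monad (ST state) where
  -- Run f, pass its value to q, and run the resulting transformer on the new state.
  ST f >>= q = ST (\s -> let (v, s') = f s in apply (q v) s')

-- Hypothetical example: yield the current state, then increment it.
counter :: ST Int Int
counter = ST (\n -> (n, n + 1))

-- Three draws; (>>=) threads the state so the code never mentions it.
threeCounts :: ST Int [Int]
threeCounts = do
  a <- counter
  b <- counter
  c <- counter
  return [a, b, c]
```

apply threeCounts 10 gives ([10,11,12],13): the values seen along the way, paired with the final state.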

Example 1. Monadic random-number generator

It is easy to re-express the random number generator of Example 0 as a state transformer.
rng1 :: ST State Double
rng1 = ST (\(State seed) -> (scale seed, State (next seed)))
Even more succinctly, we could have written
rng1 = ST rng0
Having implemented rng1 as a state transformer, we now have to use it in monadic style. The state recedes from sight; only the values ps, x, and y appear explicitly in the program.
extend :: Path -> ST State Path
extend ps = do 
               x<-rng1
               y<-rng1
               return (ps++[(x,y)])

test :: ST State Path
test = do
            ps<-return []	-- start with the empty path
            ps'<-extend ps	-- extend it twice
            ps''<-extend ps'
            return ps''         -- to give a result
We aren't done, though. With everything having been lifted into the monad world, test yields only a state transformer. To actually do the intended calculation, we have to initialize the state, apply the state transformer, and extract the first (value) part of the result.
runtest = fst (apply test (State 12345))
We see ugly intermediate values ps, ps', ps'' in the definition of test. To get rid of them, we can chain the state transformers with (>>=), which feeds the value produced by one into the next, and recall that return is a right identity for (>>=) (that is, m >>= return means the same as m), to get
test = return [] >>= extend >>= extend
Function composition runs here from left to right, like a pipeline in a Unix shell, but opposite to the order of Haskell's (.) operator.
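The sketch below assembles Example 1 into a single module and checks that the do-block version of test and the (>>=) pipeline produce the same path. It assumes the Functor and Applicative instances that a current GHC requires. As a hypothetical third spelling, it also uses Kleisli composition (>=>) from Control.Monad, which composes the two extend functions themselves rather than chaining their results.

```haskell
import Control.Monad (ap, liftM, (>=>))

-- Assumption: State wraps a single Int seed, as in Example 0.
data State = State Int deriving (Show, Eq)

type Point = (Double, Double)
type Path  = [Point]

data ST state value = ST (state -> (value, state))

apply :: ST state value -> state -> (value, state)
apply (ST f) s = f s

instance Functor (ST state) where
  fmap = liftM

instance Applicative (ST state) where
  pure x = ST (\s -> (x, s))
  (<*>)  = ap

instance Monad (ST state) where
  ST f >>= q = ST (\s -> let (v, s') = f s in apply (q v) s')

m :: Int
m = 259200

scale :: Int -> Double
scale x = fromIntegral x / fromIntegral m

next :: Int -> Int
next x = (x * 7141 + 54773) `mod` m

rng1 :: ST State Double
rng1 = ST (\(State seed) -> (scale seed, State (next seed)))

extend :: Path -> ST State Path
extend ps = do
  x <- rng1
  y <- rng1
  return (ps ++ [(x, y)])

-- do-notation with explicit intermediate paths
test :: ST State Path
test = do
  ps   <- return []
  ps'  <- extend ps
  ps'' <- extend ps'
  return ps''

-- the same computation as a left-to-right (>>=) pipeline
testPipe :: ST State Path
testPipe = return [] >>= extend >>= extend

-- hypothetical variant: Kleisli composition of the two extends
testKleisli :: ST State Path
testKleisli = (extend >=> extend) []

-- Initialize the state, run a transformer, keep only the value.
runWith :: ST State Path -> Path
runWith t = fst (apply t (State 12345))

runtest :: Path
runtest = runWith test
```

All three spellings yield the same two-point path as Example 0; only the bookkeeping differs.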

Example 2. Queue

We wish to maintain a queue of elements of type a. The queue primitives are
    newQueue    make a new (empty) queue
    enq x       enqueue value x
    deq         dequeue a value
    return x    yield value x; don't change the queue
For example, this sequence of operations yields a value of 1 and a queue containing 2.
   apply (do { enq 1; enq 2; x<-deq; return x}) newQueue
If the data type of a queue of a's is Q a, and that of a state transformer on queues is QProg a, we define
newQueue :: Q a

type QProg a = ST (Q a) a	
 
enq :: a -> QProg a
enq x = ST (\q -> (x, put x q))

deq :: QProg a
deq = ST get
We can implement queues in various ways. A convenient approach, with O(1) amortized cost for enq or deq, is to use two lists: the input end of the queue in LIFO order and the output in FIFO order. enq and deq access the head of one list or the other. When the output end is empty it is refilled from the input end.
data Q a = Q [a] [a] deriving Show

newQueue = Q [] []

put :: a -> Q a -> Q a
put x (Q xs ys) = (Q (x:xs) ys)

get :: Q a -> (a, Q a)
get (Q xs (y:ys)) = (y, Q xs ys)
get (Q [] []) = error "get from empty q"
get (Q xs []) = get (Q [] (reverse xs))
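Putting the queue pieces together gives a runnable sketch, again assuming Functor and Applicative instances for ST as a current GHC requires. The hypothetical value example reproduces the sequence of operations shown earlier. Its result should be (1, Q [] [2]): enq 1 and enq 2 build Q [2,1] [], then deq finds the output end empty, refills it by reversing the input end to get Q [] [1,2], and returns its head.

```haskell
import Control.Monad (ap, liftM)

data ST state value = ST (state -> (value, state))

apply :: ST state value -> state -> (value, state)
apply (ST f) s = f s

instance Functor (ST state) where
  fmap = liftM

instance Applicative (ST state) where
  pure x = ST (\s -> (x, s))
  (<*>)  = ap

instance Monad (ST state) where
  ST f >>= q = ST (\s -> let (v, s') = f s in apply (q v) s')

-- Two-list queue: input end in LIFO order, output end in FIFO order.
data Q a = Q [a] [a] deriving (Show, Eq)

newQueue :: Q a
newQueue = Q [] []

put :: a -> Q a -> Q a
put x (Q xs ys) = Q (x : xs) ys

get :: Q a -> (a, Q a)
get (Q xs (y : ys)) = (y, Q xs ys)
get (Q [] [])       = error "get from empty q"
get (Q xs [])       = get (Q [] (reverse xs))   -- refill the output end

type QProg a = ST (Q a) a

enq :: a -> QProg a
enq x = ST (\q -> (x, put x q))

deq :: QProg a
deq = ST get

-- Hypothetical name: the sequence of operations from the text.
example :: (Int, Q Int)
example = apply (do { enq 1; enq 2; x <- deq; return x }) newQueue
```

Each element is reversed at most once on its way from the input list to the output list, which is where the amortized O(1) cost comes from.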

Example 3. Expanding the state

Once we've got something written in monad style, we can build on top of it.

Suppose we want to count function invocations. We generalize the state, using Haskell's optional field-naming convention (record syntax):

data State = State {seed::Int, count::Int}
Field names serve both to select and to replace fields. The random-number generator refers to the seed; a ticker refers to the count.
rng2 :: ST State Double
rng2 = ST (\s -> (scale (seed s), s{seed = next (seed s)}))

tick :: ST State ()
tick = ST (\s -> ((), s{count = count s + 1}))

extend :: Path -> ST State Path
extend ps = do
               tick
               x<-rng2
               y<-rng2
               return (ps++[(x,y)])

-- test is the same as in Example 1.

runtest = let (ps,s) = apply test (State{seed=12345, count=0})
          in (ps, count s)
runtest returns the final path and the final count. Deciding which quantities are values and which are state is a judgment call: here the path is the value, while the seed and the count are state.

If we add more fields to the state, the only code that needs to be changed is the data definition and the initialization in runtest.
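Finally, a sketch assembling Example 3, under the same assumptions as before (Functor and Applicative instances for ST, auxiliary functions at top level). Two calls of extend tick twice, so runtest should report a count of 2 alongside the same path as Example 1. Note that tick appears as a statement on its own in the do block: it is a state transformer, not a function to be applied.

```haskell
import Control.Monad (ap, liftM)

data ST state value = ST (state -> (value, state))

apply :: ST state value -> state -> (value, state)
apply (ST f) s = f s

instance Functor (ST state) where
  fmap = liftM

instance Applicative (ST state) where
  pure x = ST (\s -> (x, s))
  (<*>)  = ap

instance Monad (ST state) where
  ST f >>= q = ST (\s -> let (v, s') = f s in apply (q v) s')

-- Record syntax: field names both select and update fields.
data State = State { seed :: Int, count :: Int } deriving Show

type Point = (Double, Double)
type Path  = [Point]

m :: Int
m = 259200

scale :: Int -> Double
scale x = fromIntegral x / fromIntegral m

next :: Int -> Int
next x = (x * 7141 + 54773) `mod` m

-- Touches only the seed field.
rng2 :: ST State Double
rng2 = ST (\s -> (scale (seed s), s { seed = next (seed s) }))

-- Touches only the count field.
tick :: ST State ()
tick = ST (\s -> ((), s { count = count s + 1 }))

extend :: Path -> ST State Path
extend ps = do
  tick
  x <- rng2
  y <- rng2
  return (ps ++ [(x, y)])

test :: ST State Path
test = return [] >>= extend >>= extend

-- Final path paired with the final count.
runtest :: (Path, Int)
runtest = let (ps, s) = apply test (State { seed = 12345, count = 0 })
          in (ps, count s)
```

Because rng2 and tick mention only the fields they use, a further field could be added to State without touching either of them.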