Topic: State Monads
Date: Dec. 2, 2009
Number: 29
Examples: State.hs, MemoS.hs, EditDistanceM.hs, Fib.hs, StateT.hs, ParserST.hs, Fibonacci.hs, SimpleMemo.hs, EditDistanceSM.hs, Main.hs
Reading: Chapter 18

Review session Friday, 10:00, 006 Kemeny
Exam Saturday, 8:00 am, 006 Kemeny

-- State Monads

We saw in Parser how useful it was to have #, !, ?, and >-> worry about all the details of how the remaining string was passed from parser to parser. Last time we showed how to make Parser an instance of Monad, where the bind operation passed on state information. We saw that this simplified writing the various operators.

We would like to do this in general - have state passed around "under the hood". A State Monad is a way to do this. It is a generalization of what we saw in Parser. (Note - I find this way of defining a State monad clearer than the one in the book, but they are equivalent.)

The State data type is:

    data State state value = State { stateRunner :: state -> (value, state) }

We have the same pattern - the State (Parser in parser) is a function from a current state (String in parser) to a (value,state) pair. No Maybe this time. The returned state is the updated state. So our monad is:

    instance Monad (State state) where
        return value = State (\ state -> (value, state))

        action >>= function =
            State (\ firstState ->
                let (intermediateValue, secondState) = runState action firstState
                in  runState (function intermediateValue) secondState)

        fail = error

(Note - this is the form of the Monad class as of 2009. Current GHC also requires Functor and Applicative instances for any Monad, and fail now belongs to the MonadFail class.)

Again, this is similar to Parser. The return is very similar. The fail just raises an error. Bind is also similar, but simpler, because there is no Maybe in the return value.

Remember that bind's type signature is:

    m a -> (a -> m b) -> m b

So the action is of type (m a), which wraps a function from a state to a (value, state) pair. The function is of type (a -> m b). We first arrange for the result to have constructor State and to be a function from a state to something. The something should be a (value,state) pair.
To compute this pair, we run the action on the input state firstState. This yields an intermediateValue and a secondState. We get our (m b) to return by first applying function to intermediateValue to get an action of type (m b). We then run this action on secondState to get our final (value,state) pair. We have threaded the state into the first action, out of the first action and into the second action (which we got by evaluating "function"), and finally out of the second action.

The State module also provides some utility functions to make our life easier. runState we have already seen - it runs an action on a given state. We also have evalState (run action and return the value part of the pair) and execState (run action and return the new state).

The last two functions, get and put, are perhaps the strangest looking ones. Their idea is simple. "get" is used to fetch the current state and treat it as the value of the "get". "put" makes it so that the thing that it is bound to will get a specified new state.

These would have been useful in the "char" parser definition - the reason for the anonymous function was to make the state available to the two parsers:

    char :: Parser Char
    char = Parser $ \cs -> case cs of
        ""     -> Nothing
        (c:cs) -> Just (c,cs)

The alternative would have been:

    char = do
        string <- get
        case string of
            ""       -> fail "No more characters in string!"
            (c : cs) -> do put cs; return c

Note that the "string <- get" will bind the "value" of get (which is the state!) to string. We then can pattern match on this, just as we pattern matched on cs in the original. But we are not inside a Parser constructor and anonymous function, so we use monad operations. "fail" returns an anonymous function returning "Nothing", just as our first code did. "return c" creates a function whose value is (value, s), where s is the state passed into the return. The value should clearly be c. But how do we make the state cs? "put cs" does exactly that.

We now see what we want, but how do we get it?
To return the current value of the state as the value of get is easy, if strange-looking:

    get :: State a a
    get = State (\ state -> (state, state))

Performing "s <- get" will result in the state being bound to s (normally the value would be bound to s, but here the value IS the state).

To "put" we create a State with an anonymous function that ignores the old state and uses put's argument as the state. But what should the value be? There is no value, so () is used.

    put :: state -> State state ()
    put newState = State (\ _oldState -> ((), newState))

Because the Parser needs to go to Maybe (value,state) rather than (value,state) we cannot use State to implement it. Instead we use StateT, a "state monad transformer" type. The difference between State and StateT is that in StateT a monad is involved. So the definition of State is:

    data State state value = State { stateRunner :: state -> (value, state) }

while for StateT it is:

    data StateT state m value = StateT { stateRunner :: state -> m (value, state) }

Note that in this case the thing returned by the function is not just a "(value, state)" pair, but an "m (value, state)". So the pair is returned in a monad.

The StateT code is almost identical to the State code, except that everything returns a monad. The main things that change are type signatures and the fact that the "let" constructs in State are replaced by "do" constructs with "<-" to remove the (value, state) pair from the monad. Compare:

    evalState :: State state value -> state -> value
    evalState action initialState =
        let (value, _finalState) = runState action initialState
        in  value

to

    evalStateT :: Monad m => StateT state m value -> state -> m value
    evalStateT action initialState = do
        (value, _finalState) <- runStateT action initialState
        return value

For Parser this monad is Maybe. Look at ParserST.hs to see how this works. We define:

    type ReturnMonad = Maybe
    type Parser a = StateT String ReturnMonad a

The functions are very similar to what we saw in ParserM.
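To make the StateT-over-Maybe idea concrete, here is a minimal, self-contained sketch. It is not the code in StateT.hs or ParserST.hs: the field name runStateT, the helper failParse, and the parser twoChars are assumptions chosen for illustration, and the Functor and Applicative instances are there only because current GHC requires them.

```haskell
import Control.Applicative (Alternative (empty))

-- A state transformer: the state function returns its pair inside a monad m.
-- (Field name runStateT is an assumption; the notes call the field stateRunner.)
newtype StateT s m a = StateT { runStateT :: s -> m (a, s) }

instance Monad m => Functor (StateT s m) where
  fmap f action = StateT $ \s -> do
    (a, s') <- runStateT action s
    return (f a, s')

instance Monad m => Applicative (StateT s m) where
  pure a = StateT $ \s -> return (a, s)
  mf <*> ma = StateT $ \s -> do
    (f, s')  <- runStateT mf s
    (a, s'') <- runStateT ma s'
    return (f a, s'')

instance Monad m => Monad (StateT s m) where
  -- Same threading as State, but "do" in m replaces "let".
  action >>= function = StateT $ \s -> do
    (a, s') <- runStateT action s
    runStateT (function a) s'

-- The current state becomes the value.
get :: Monad m => StateT s m s
get = StateT $ \s -> return (s, s)

-- Ignore the old state, install the new one, return ().
put :: Monad m => s -> StateT s m ()
put s = StateT $ \_ -> return ((), s)

type ReturnMonad = Maybe
type Parser a = StateT String ReturnMonad a

-- Hypothetical helper: a parser that always fails (Nothing for Maybe).
failParse :: Parser a
failParse = StateT $ \_ -> empty

-- Parse one character, written with get and put.
char :: Parser Char
char = do
  string <- get
  case string of
    ""       -> failParse
    (c : cs) -> do put cs; return c

-- Hypothetical example: two characters in sequence; bind threads the string.
twoChars :: Parser (Char, Char)
twoChars = do
  a <- char
  b <- char
  return (a, b)
```

In GHCi, runStateT char "abc" gives Just ('a',"bc"), and runStateT twoChars "x" gives Nothing because the second char finds an empty string.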
But what if we wanted our parser to return a list of all valid parses rather than just a Maybe of a single parse? The main difference is that we would then have to consider both sides of a "!", even if the first succeeded. We could get multiple valid parses if the grammar is ambiguous. So we will want to return a list of valid parses, with the empty list indicating failure.

What change would be needed to make our parser return a list of valid parses? One word: replace "Maybe" by "[]" in the definition of ReturnMonad!

Why? mplus (from MonadPlus) for [] is (++), so it takes all left choices followed by all right choices when doing (!). When binding two parsers together, every valid first parse will be combined with every valid second parse, which is what we want. Failure in [] is defined to be the empty list, so that works as expected.

I modified the grammar to allow ambiguous sentences. For example, when parsing "time flies" I get:

    [(S (UnmodNoun "time") (IVerb "flies"),""),
     (S (ImpVerb "time") (UnmodNoun "flies"),"")]

The first choice says that "time" is a noun with no modifiers and "flies" is an intransitive verb. The second choice says that "time" is an imperative verb and "flies" is a noun with no modifiers. Run the parser on all 6 sentences in the test data.

Actually, saying I had to change only one word is a slight lie. parseWord is supposed to parse letters until something else is found and parseSpaces is supposed to parse white space until something else is found. But:

    runParser (takeRepeatParse letter) "The cat" =>
        [("The"," cat"),("Th","e cat"),("T","he cat")]

The problem is that takeRepeatParse calls takeWhileParse, and that parser must consider both choices for (!) every time it is called. With Maybe only the first alternative is used when it succeeds, so the "return []" alternative could not be taken until the letter parser failed. With lists both alternatives are kept at every step, so the parser may also stop early. How do we get around this?
We define:

    atEndOf :: (Char -> Bool) -> Parser String
    atEndOf test = do
        string <- get
        case string of
            ""     -> return ""    -- Succeed if string empty
            (c:cs) -> if test c
                        then fail "More valid characters"
                        else return ""

This is a new type of parser - it does not parse any characters! It succeeds only if the next character (if any) fails the predicate test. It is a "look-ahead" function. So we define

    parseSpaces = takeWhileParse space #- atEndOf isSpace

and

    parseWord = elimSpacesAround (takeRepeatParse letter #- atEndOf isAlpha)

This way neither can quit early. I made these modifications, which are not needed for Maybe, to avoid lots of ambiguous ways of parsing spaces using elimSpacesAround. (Consider "The cat". The space can either be eliminated as part of parsing "The" or as part of parsing "cat", giving two different parses.)

-- Memoizing

For an example that uses the State module we look at MemoS.hs. We saw in EditDistance that passing the state (the Map) around was a pain in the neck. We would like the map to be the state part of a State monad. So we define a data type MemoState, consisting of a state (Map.Map a b) and a value of type b. We define a type MemoFunction to be a function (a -> MemoState a b).

The heart of our memoization is a function callMemo, which can be applied to a MemoFunction and an argument, and yields a MemoState. We will apply this to every call to the function that we want to memoize. It performs the basic memoization actions. First, it gets the current map (the state) via a call to get. It looks up the argument in the map. If it finds it, it returns the result. (Note that return keeps the current map, so the state is passed along unchanged here.) If nothing is found, we compute the function on the argument and call it result. We get the current map again, since computing the result may have added entries to it. We put the updated map (with the (argument, result) pair inserted), thus making it the state that gets incorporated into the "return result".
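The callMemo steps just described might be sketched as follows. This is a guess at the shape of the code, not a copy of MemoS.hs: it uses Control.Monad.Trans.State from the transformers package as a stand-in for the course's State module, and the lattice-path function paths is a hypothetical example. A one-line runMemo (evaluate the function starting from an empty map) is included only so the sketch can be run.

```haskell
import Control.Monad.Trans.State (State, evalState, get, put)
import qualified Data.Map as Map

-- The map from arguments to results is the state; the result is the value.
type MemoState a b = State (Map.Map a b) b
type MemoFunction a b = a -> MemoState a b

callMemo :: Ord a => MemoFunction a b -> a -> MemoState a b
callMemo function argument = do
  table <- get                          -- fetch the current map
  case Map.lookup argument table of
    Just result -> return result        -- already known; map is unchanged
    Nothing -> do
      result <- function argument       -- compute; this may grow the map
      table' <- get                     -- so fetch the map again
      put (Map.insert argument result table')
      return result

-- Run a memoized function on one argument, starting with an empty map.
runMemo :: MemoFunction a b -> a -> b
runMemo function argument = evalState (function argument) Map.empty

-- Hypothetical example: number of monotone lattice paths from (r, c) to (0, 0).
paths :: MemoFunction (Int, Int) Integer
paths (r, c)
  | r == 0 || c == 0 = return 1
  | otherwise = do
      down  <- callMemo paths (r - 1, c)
      right <- callMemo paths (r, c - 1)
      return (down + right)
```

Without callMemo the two recursive calls would recompute the same grid points many times over; with it, each (r, c) pair is computed once and later calls are map lookups.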
Finally, runMemo on a function and argument simply computes (function argument) starting with an empty map. evalState returns the value part of the resulting state.

-- Fib

We first show how to memoize fibonacci numbers. It is easy:

    fibHelper :: MemoFunction Integer Integer
    fibHelper n =
        if n < 2
            then return 1
            else do
                fm1 <- callMemo fibHelper (n-1)
                fm2 <- callMemo fibHelper (n-2)
                return (fm1 + fm2)

    fib :: Integer -> Integer
    fib n = runMemo fibHelper n

We write a version of fib called fibHelper where each recursive call is preceded by "callMemo". We get the two results and return their sum. fib n then is defined in terms of runMemo and fibHelper.

-- EditDistance

The same trick is used for EditDistance. There was already a helper function eDist. We simply precede each recursive call to eDist by callMemo. We then runMemo on eDist with the desired pair. Note that we no longer have to thread the map as we compute matchDist, delDist, insDist, and swapDist.

-- Improved memoization

This is not bad, but it requires adding callMemo calls before all recursive calls in the function. It would be nice if that were not necessary - if you could take any function and memoize it without re-writing it.

This is possible. Taylor has written versions of the Memo module that can be used to memoize any function that is written in a certain style. It has to be written as a function from a function to a function.

    makeFibM :: (Integral a, Monad m) => (a -> m a) -> (a -> m a)
    makeFibM fib = \n ->
        if n < 2
            then return 1
            else do
                a <- fib (n - 1)
                b <- fib (n - 2)
                return (a + b)

The module then finds the "fixed point" of this function. A fixed point is a value x such that f(x) = x. But here we are dealing with functions being mapped to functions. What does this mean?

What happens if we call:

    (makeFibM fib) n

Then the function call returns the function

    \n -> if n < 2 then return 1 else do ...

So if n < 2, we know the answer. But what if it is not?
Then we cannot say, because we must compute fib (n-1) and fib (n-2), and fib is whatever function was passed as a parameter. If fib were (^2) (in the right monad) the answer is different than if fib were (*5). If fib were the normal fibonacci function then we would get the right answer. If not, we can get almost anything.

But what if we call (makeFibM (makeFibM fib)) 2? The n in the outer makeFibM is not < 2, so we call fib 1 and fib 0. But this time we know what fib is: it is (makeFibM fib). And we said that we already know what it computes for 0 and 1. So for n <= 2, (makeFibM (makeFibM fib)) n computes the fibonacci number correctly. For bigger n it depends on what fib is.

We can repeat that again. In particular, we could compute:

    (makeFibM (makeFibM (makeFibM fib))) n

for n = 3, because it will call (makeFibM (makeFibM fib)) for 2 and 1, and (makeFibM (makeFibM fib)) computes the right answers for these, independent of what fib is.

There is a pattern here - every time I nest in an additional makeFibM call, I can compute the correct value for fibonacci n for n one larger. But is there a way to get the normal fibonacci function from makeFibM, one that works for all n >= 0?

The standard definition of "fix" in Haskell (with its type specialized to functions) is:

    fix :: ((a -> b) -> (a -> b)) -> (a -> b)
    fix f = f (fix f)

What does this say? Substituting:

    fix f => f (fix f) => f (f (fix f)) => f (f (f (fix f))) => ...

We are repeatedly applying f. In fact, we have infinite recursion! We have an infinitely deep nesting of f. In a strict language we would have a problem, but Haskell evaluates lazily. So if we compute:

    (fix makeFibM) n

for n < 2 we get the answer from:

    makeFibM (fix makeFibM) n

But if n = 2, we need to go to

    (makeFibM (makeFibM (fix makeFibM))) 2

before we have the right answer. In general, we have to "unroll" n nested makeFibM calls to compute (fix makeFibM) n.
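The unrolling argument can be tried out directly. The sketch below runs makeFibM in the Identity monad, which is an assumption made only to have the simplest possible monad to run it in. The names fibIdentity, bogus, and twoLayers are hypothetical. fix is the one exported by Data.Function.

```haskell
import Data.Function (fix)
import Data.Functor.Identity (Identity, runIdentity)

makeFibM :: (Integral a, Monad m) => (a -> m a) -> (a -> m a)
makeFibM fib = \n ->
  if n < 2
    then return 1
    else do
      a <- fib (n - 1)
      b <- fib (n - 2)
      return (a + b)

-- Tie the knot: fix supplies as many nested makeFibM layers as a call needs.
fibIdentity :: Integer -> Integer
fibIdentity = runIdentity . fix makeFibM

-- A deliberately wrong "fib" to pass in as the innermost function.
bogus :: Integer -> Identity Integer
bogus _ = return 100

-- Two layers of makeFibM around bogus: correct for n <= 2, wrong beyond that.
twoLayers :: Integer -> Integer
twoLayers = runIdentity . makeFibM (makeFibM bogus)
```

Here twoLayers 2 is 2, the right answer, because only the two makeFibM layers are consulted. twoLayers 3 is 201, because computing it reaches bogus. fibIdentity 10 is 89, since fix never runs out of layers.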
The underlying idea is that if you give fix something like makeFibM it will keep applying it to itself, getting a function that is "more and more" defined until everything that can be defined is defined. (This is very hand-wavy!) For any specific argument it will "unroll" exactly as many nested calls as are needed to compute the value for that argument. Thus (fix makeFibM) is exactly the normal Fibonacci function.

The code in fixMemo basically inserts "callMemo" calls before each call to the function to be memoized in this chain. (Hand-wavy, but it gives the correct idea.)

This is beyond the scope of this class. But I wanted to show you that there are lots of powerful, interesting things that can be done with functional programming that we have not seen.

SimpleMemo.hs implements the fixMemo discussed above. Fibonacci.hs uses it to compute Fibonacci numbers and EditDistanceSM.hs uses it to compute edit distance in a memoized way. Main.hs is a driver program for all of these functions.

For the curious, see "Y combinator" in Wikipedia or other references.
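For readers who want to experiment, here is one way a fixMemo could look. This is a hypothetical sketch, not the code in SimpleMemo.hs: the type synonym Memo, the local name memoized, and the use of Control.Monad.Trans.State from the transformers package are all assumptions. The idea is the one described above: instead of handing makeFibM the plain fixed point, we hand it a version of itself that checks the map first.

```haskell
import Control.Monad.Trans.State (State, evalState, get, modify)
import qualified Data.Map as Map

-- The memo table is the state; the function result is the value.
type Memo a b = State (Map.Map a b) b

-- Like fix, but every recursive call goes through a table lookup first.
fixMemo :: Ord a => ((a -> Memo a b) -> (a -> Memo a b)) -> a -> b
fixMemo makeFunction argument = evalState (memoized argument) Map.empty
  where
    memoized x = do
      table <- get
      case Map.lookup x table of
        Just result -> return result
        Nothing -> do
          -- Recursive calls made by makeFunction come back to memoized.
          result <- makeFunction memoized x
          modify (Map.insert x result)
          return result

-- The function-to-function style from the notes; no callMemo inside.
makeFibM :: (Integral a, Monad m) => (a -> m a) -> (a -> m a)
makeFibM fib = \n ->
  if n < 2
    then return 1
    else do
      a <- fib (n - 1)
      b <- fib (n - 2)
      return (a + b)

fib :: Integer -> Integer
fib = fixMemo makeFibM
```

With this definition fib 50 returns 20365011074 quickly, whereas the unmemoized fixed point would take exponentially many calls. Note that makeFibM itself is unchanged; only the way the knot is tied differs.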