-- | Data.HMM is a library for using Hidden Markov Models (HMMs) with Haskell. HMMs are a common method of machine learning. All of the most frequently used algorithms---the forward and backwards algorithms, Viterbi, and Baum-Welch---are implemented in this library. -- The best way to learn to use it is to visit the tutorial at http://izbicki.me/blog/using-hmms-in-haskell-for-bioinformatics. The tutorial also includes performance benchmarks and caveats that you should be aware of. module Data.HMM ( HMM(..), Prob , forward , backward , viterbi , baumWelch, baumWelchItr , simpleMM, simpleHMM, hmmJoin , verifyhmm , loadHMM , saveHMM ) where import Debug.Trace import Data.Array import Data.List import Data.List.Extras import Data.Number.LogFloat import qualified Data.MemoCombinators as Memo import Control.Parallel import System.IO import Text.ParserCombinators.Parsec type Prob = LogFloat -- | The data types for our HMM. FIXME: This should probably be changed to be HMMArray data HMM stateType eventType = HMM { states :: [stateType] , events :: [eventType] , initProbs :: (stateType -> Prob) , transMatrix :: (stateType -> stateType -> Prob) , outMatrix :: (stateType -> eventType -> Prob) } -- deriving (Show, Read) instance (Show stateType, Show eventType) => Show (HMM stateType eventType) where show hmm = hmm2str hmm hmm2str hmm = "HMM" ++ "{ states=" ++ (show $ states hmm) ++ ", events=" ++ (show $ events hmm) ++ ", initProbs=" ++ (show [(s,initProbs hmm s) | s <- states hmm]) ++ ", transMatrix=" ++ (show [(s1,s2,transMatrix hmm s1 s2) | s1 <- states hmm, s2 <- states hmm]) ++ ", outMatrix=" ++ (show [(s,e,outMatrix hmm s e) | s <- states hmm, e <- events hmm]) ++ "}" elemIndex2 :: (Show a, Eq a) => a -> [a] -> Int elemIndex2 e list = case elemIndex e list of Nothing -> seq (error ("elemIndex2: Index "++show e++" not in HMM "++show list)) 0 Just x -> x stateIndex :: (Show stateType, Show eventType, Eq stateType) => HMM stateType eventType -> stateType -> Int stateIndex hmm state = case elemIndex state $ states hmm of Nothing -> seq (error ("stateIndex: Index "++show state++" not in HMM "++show hmm)) 0 Just x -> x eventIndex :: (Show stateType, Show eventType, Eq eventType) => HMM stateType eventType -> eventType -> Int eventIndex hmm event = case elemIndex event $ events hmm of Nothing -> seq (error ("eventIndex: Index "++show event++" not in HMM "++show hmm)) 0 Just x -> x -- | Use simpleMM to create an untrained standard Markov model simpleMM eL order = HMM { states = sL , events = eL , initProbs = \s -> evenDist--skewedDist s , transMatrix = \s1 -> \s2 -> if (length s1==0) || (isPrefixOf (tail s1) s2) then skewedDist s2 --1.0 / (logFloat $ length sL) else 0.0 , outMatrix = \s -> \e -> 1.0/(logFloat $ length eL) } where evenDist = 1.0 / sLlen skewedDist s = (logFloat $ 1+elemIndex2 s sL) / ( (sLlen * (sLlen+ (logFloat (1.0 :: Double))))/2.0) sLlen = logFloat $ length sL sL = enumerateStates (order-1) [[]] enumerateStates order' list | order' == 0 = list | otherwise = enumerateStates (order'-1) [symbol:l | l <- list, symbol <- eL] -- | Use simpleHMM to create an untrained hidden Markov model simpleHMM :: (Eq stateType, Show eventType, Show stateType) => [stateType] -> [eventType] -> HMM stateType eventType simpleHMM sL eL = HMM { states = sL , events = eL , initProbs = \s -> evenDist--skewedDist s , transMatrix = \s1 -> \s2 -> skewedDist s2 , outMatrix = \s -> \e -> 1.0/(logFloat $ length eL) } where evenDist = 1.0 / sLlen skewedDist s = (logFloat $ 1+elemIndex2 s sL) / ( (sLlen * (sLlen+ (logFloat (1.0 :: Double))))/2.0) sLlen = logFloat $ length sL -- | forward algorithm determines the probability that a given event array would be emitted by our HMM forward :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> [eventType] -> Prob forward hmm obs = forwardArray hmm (listArray (1,bT) obs) where bT = length obs forwardArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Prob forwardArray hmm obs = sum [alpha hmm obs bT state | state <- states hmm] where bT = snd $ bounds obs alpha :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Int -> stateType -> Prob alpha hmm obs = memo_alpha where memo_alpha t state = memo_alpha2 t (stateIndex hmm state) memo_alpha2 = (Memo.memo2 Memo.integral Memo.integral memo_alpha3) memo_alpha3 t' state' | t' == 1 = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $ (outMatrix hmm (states hmm !! state') $ obs!t')*(initProbs hmm $ states hmm !! state') | otherwise = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $ (outMatrix hmm (states hmm !! state') $ obs!t')*(sum [(memo_alpha (t'-1) state2)*(transMatrix hmm state2 (states hmm !! state')) | state2 <- states hmm]) -- | backwards algorithm does the same thing as the forward algorithm, just a different implementation backward :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> [eventType] -> Prob backward hmm obs = backwardArray hmm $ listArray (1,length obs) obs backwardArray :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Prob backwardArray hmm obs = backwardArray' hmm obs where backwardArray' hmm obs = sum [(initProbs hmm state) *(outMatrix hmm state $ obs!1) *(beta hmm obs 1 state) | state <- states hmm ] beta :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Int -> stateType -> Prob beta hmm obs = memo_beta where bT = snd $ bounds obs memo_beta t state = memo_beta2 t (stateIndex hmm state) memo_beta2 = (Memo.memo2 Memo.integral Memo.integral memo_beta3) memo_beta3 t' state' | t' == bT = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $ 1 | otherwise = -- trace ("memo_alpha' t'="++show t'++", state'="++show state') $ sum [(transMatrix hmm (states hmm !! state') state2) *(outMatrix hmm state2 $ obs!(t'+1)) *(memo_beta (t'+1) state2) | state2 <- states hmm ] -- | Viterbi's algorithm calculates the most probable path through our states given an event array viterbi :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> [stateType] viterbi hmm obs = [memo_x' t | t <- [1..bT]] where bT = snd $ bounds obs memo_x' = Memo.integral x' x' t | t == bT = argmax (\i -> memo_delta bT i) (states hmm) | otherwise = memo_psi (t+1) (memo_x' (t+1)) -- delta :: Int -> stateType -> Prob memo_delta t state = memo_delta2 t (stateIndex hmm state) memo_delta2 = (Memo.memo2 Memo.integral Memo.integral memo_delta3) memo_delta3 t state = delta t (states hmm !! state) delta t state | t == 1 = (outMatrix hmm state $ obs!t)*(initProbs hmm state) | otherwise = maximum [(memo_delta (t-1) i)*(transMatrix hmm i state)*(outMatrix hmm (state) $ obs!t) | i <- states hmm ] -- psi :: Int -> stateType -> stateType memo_psi t state = memo_psi2 t (stateIndex hmm state) memo_psi2 = (Memo.memo2 Memo.integral Memo.integral memo_psi3) memo_psi3 t state = psi t (states hmm !! state) psi t state | t == 1 = (states hmm) !! 0 | otherwise = argmax (\i -> (memo_delta (t-1) i) * (transMatrix hmm i state)) (states hmm) -- | Baum-Welch is used to train an HMM baumWelch :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> Int -> HMM stateType eventType baumWelch hmm obs count | count == 0 = hmm | otherwise = -- trace ("baumWelch iterations left: "++(show count)) $ trace (show itr) $ baumWelch itr obs (count-1) where itr = baumWelchItr hmm obs baumWelchItr :: (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMM stateType eventType -> Array Int eventType -> HMM stateType eventType baumWelchItr hmm obs = --par newInitProbs $ par newTransMatrix $ par newOutMatrix --trace "baumWelchItr " $ HMM { states = states hmm , events = events hmm , initProbs = memo_newInitProbs , transMatrix = {-newTransMatrix---} memo_newTransMatrix , outMatrix = {-outMatrix hmm ---} memo_newOutMatrix } where bT = snd $ bounds obs memo_newInitProbs state = memo_newInitProbs2 (stateIndex hmm state) memo_newInitProbs2 = Memo.integral memo_newInitProbs3 memo_newInitProbs3 state = newInitProbs (states hmm !! state) newInitProbs state = gamma 1 state memo_newTransMatrix state1 state2 = memo_newTransMatrix2 (stateIndex hmm state1) (stateIndex hmm state2) memo_newTransMatrix2 = (Memo.memo2 Memo.integral Memo.integral memo_newTransMatrix3) memo_newTransMatrix3 state1 state2 = newTransMatrix (states hmm !! state1) (states hmm !! state2) newTransMatrix state1 state2 = --trace ("newTransMatrix"++(hmmid hmm)) $ sum [xi t state2 state1 | t <- [2..bT]] /sum [gamma t state1 | t <- [2..bT]] memo_newOutMatrix state event = memo_newOutMatrix2 (stateIndex hmm state) (eventIndex hmm event) memo_newOutMatrix2 = (Memo.memo2 Memo.integral Memo.integral memo_newOutMatrix3) memo_newOutMatrix3 state event = newOutMatrix (states hmm !! state) (events hmm !! event) newOutMatrix state event = sum [if (obs!t == event) then gamma t state else 0 | t <- [2..bT] ] /sum [gamma t state | t <- [2..bT]] -- Greek functions, included here for memoization xi t state1 state2 = (memo_alpha (t-1) state1) *(transMatrix hmm state1 state2) *(outMatrix hmm state2 $ obs!t) *(memo_beta t state2) /backwardArrayVar -- (backwardArray hmm obs) gamma t state = (memo_alpha t state) *(memo_beta t state) /backwardArrayVar backwardArrayVar = (backwardArray hmm obs) memo_beta t state = memo_beta2 t (stateIndex hmm state) memo_beta2 = (Memo.memo2 Memo.integral Memo.integral memo_beta3) memo_beta3 t' state' | t' == bT = 1 | otherwise = sum [(transMatrix hmm (states hmm !! state') state2) *(outMatrix hmm state2 $ obs!(t'+1)) *(memo_beta (t'+1) state2) | state2 <- states hmm ] memo_alpha t state = memo_alpha2 t (stateIndex hmm state) memo_alpha2 = (Memo.memo2 Memo.integral Memo.integral memo_alpha3) memo_alpha3 t' state' | t' == 1 = (outMatrix hmm (states hmm !! state') $ obs!t')*(initProbs hmm $ states hmm !! state') | otherwise = (outMatrix hmm (states hmm !! state') $ obs!t')*(sum [(memo_alpha (t'-1) state2)*(transMatrix hmm state2 (states hmm !! state')) | state2 <- states hmm]) -- | Joins 2 HMMs by connecting every state in the first HMM to every state in the second, and vice versa, with probabilities based on the join ratio hmmJoin :: (Eq stateType, Eq eventType, Read stateType, Show stateType) => HMM stateType eventType -> HMM stateType eventType -> Prob -> HMM (Int,stateType) eventType hmmJoin hmm1 hmm2 ratio = HMM { states = states1 ++ states2 , events = if (events hmm1) == (events hmm2) then events hmm1 else error "hmmJoin: event sets not equal" , initProbs = \s -> if (s `elem` states1) then (initProbs hmm1 $ lift s)*r1 else (initProbs hmm2 $ lift s)*r2 , transMatrix = \s1 -> \s2 -> if (s1 `elem` states1 && s2 `elem` states1) then (transMatrix hmm1 (lift s1) (lift s2))*r1 else if (s2 `elem` states2 && s2 `elem` states2) then (transMatrix hmm2 (lift s1) (lift s2))*r2 else if (s1 `elem` states1) then (r2)/(logFloat $ length $ states2) else (r1)/(logFloat $ length $ states1) , outMatrix = \s -> if (s `elem` states1) then (outMatrix hmm1 $ lift s) else (outMatrix hmm2 $ lift s) } where r1=ratio r2=1-ratio states1 = map (\x -> (1,x)) $ states hmm1 states2 = map (\x -> (2,x)) $ states hmm2 -- lift :: (Int,String) -> a lift x =snd x -- lift x =read $ (snd x ) -- debug utils hmmid hmm = show $ initProbs hmm $ (states hmm) !! 1 -- | tests listCPExp :: [a] -> Int -> [[a]] listCPExp language order = listCPExp' order [[]] where listCPExp' order list | order == 0 = list | otherwise = listCPExp' (order-1) [symbol:l | l <- list, symbol <- language] -- | should always equal 1 forwardtest hmm x = sum [forward hmm e | e <- listCPExp (events hmm) x] -- | should always equal 1 backwardtest hmm x = sum [backward hmm e | e <- listCPExp (events hmm) x] -- | should always equal each other fbtest hmm events = "fwd: " ++ show (forward hmm events) ++ " bkwd:" ++ show (backward hmm events) -- | initProbs should always equal 1; the others should equal the number of states verifyhmm hmm = do seq ip $ check "initProbs" ip check "transMatrix" tm check "outMatrix" om where check str var = do putStrLn $ str++" tollerance check: "++show var {- if abs(var-1)<0.0001 then putStrLn "True" else putStrLn "False"-} ip = sum $ [initProbs hmm s | s <- states hmm] tm = (sum $ [transMatrix hmm s1 s2 | s1 <- states hmm, s2 <- states hmm]) -- (length $ states hmm) om = sum $ [outMatrix hmm s e | s <- states hmm, e <- events hmm] -- / length $ states hmm ----- -- File processing functions below here data -- (Eq eventType, Eq stateType, Show eventType, Show stateType) => HMMArray stateType eventType = HMMArray { statesA :: [stateType] , eventsA :: [eventType] , initProbsA :: Array Int Prob , transMatrixA :: Array Int (Array Int Prob) -- (stateType -> stateType -> Prob) , outMatrixA :: Array Int (Array Int Prob) -- (stateType -> eventType -> Prob) } deriving (Show,Read) instance Read LogFloat where readsPrec a str = do dbl <- readsPrec a (drop 8 str) :: [(Double,String)] -- trace ("LogFloat -> "++show str) $ [(logFloat ((read (drop 8 str)) :: Double), "")] return (logFloat $ fst dbl, snd dbl) hmm2Array :: (Show stateType, Show eventType) => (HMM stateType eventType) -> (HMMArray stateType eventType) hmm2Array hmm = HMMArray { statesA = states hmm , eventsA = events hmm , initProbsA = listArray (1,length $ states hmm) [initProbs hmm state | state <- states hmm] , transMatrixA = listArray (1,length $ states hmm) [ listArray (1,length $ states hmm) [transMatrix hmm s1 s2 | s1 <- states hmm] | s2 <- states hmm] , outMatrixA = listArray (1,length $ states hmm) [ listArray (1,length $ events hmm) [outMatrix hmm s e | e <- events hmm] | s <- states hmm] } array2hmm :: (Show stateType, Show eventType, Eq stateType, Eq eventType) => (HMMArray stateType eventType) -> (HMM stateType eventType) array2hmm hmmA = HMM { states = statesA hmmA , events = eventsA hmmA , initProbs = \s -> (initProbsA hmmA) ! (stateAIndex hmmA s) , transMatrix = \s1 -> \s2 -> transMatrixA hmmA ! (stateAIndex hmmA s1) ! (stateAIndex hmmA s2) , outMatrix = \s -> \e -> outMatrixA hmmA ! (stateAIndex hmmA s) ! (eventAIndex hmmA e) } -- | saves the HMM to a file for later retrieval. HMMs can take a long time to calculate, so this is very useful saveHMM :: (Show stateType, Show eventType) => String -> HMM stateType eventType -> IO () saveHMM file hmm = do outh <- openFile file WriteMode hPutStrLn outh $ show $ hmm2Array hmm hClose outh -- loadHMM :: (Read stateType, Read eventType) => String -> IO (HMM stateType eventType) loadHMM file = do inh <- openFile file ReadMode hmmstr <- hGetLine inh let hmm = read hmmstr -- :: HMMArray stateType eventType return (array2hmm hmm) stateAIndex :: (Show stateType, Show eventType, Eq stateType) => HMMArray stateType eventType -> stateType -> Int stateAIndex hmm state = case elemIndex state $ statesA hmm of Nothing -> seq (error "stateIndex: Index "++show state++" not in HMM "++show hmm) 0 Just x -> x+1 eventAIndex :: (Show stateType, Show eventType, Eq eventType) => HMMArray stateType eventType -> eventType -> Int eventAIndex hmm event = case elemIndex event $ eventsA hmm of Nothing -> seq (error ("eventIndex: Index "++show event++" not in HMM "++show hmm)) 0 Just x -> x+1 ------------ hmmParse :: {-(Read stateType, Read eventType) =>-} String -> Either ParseError (HMM String Char) hmmParse str = do parse hmmParser str str hmmParser :: (Read stateType, Read eventType) => GenParser Char st (HMM stateType eventType) hmmParser = do let hmm = HMM { states = [] , events = [] , initProbs = (\x -> 0) , transMatrix = (\x -> \y -> 0) , outMatrix = (\x -> \y -> 0) } return hmm