{-# LANGUAGE ParallelListComp #-}
-- | Hidden Markov Models: training from labelled (observation, state) data,
-- most-likely state sequences via Viterbi's algorithm, and observation
-- sequence probabilities via the forward algorithm.
module Data.HMM (Prob, HMM, train, bestSequence, sequenceProb) where

import qualified Data.Map as M
import Data.List (sort, groupBy, maximumBy)
import Data.Maybe (fromMaybe, fromJust)
import System.IO.Unsafe (unsafeInterleaveIO)
import System.Environment (getArgs)
import Control.Monad
import qualified Data.Foldable
import Debug.Trace
import Data.Lognum

-- | Probability type used throughout the model.
--
-- NOTE(review): 'Lognum' comes from "Data.Lognum"; presumably a log-domain
-- representation that keeps the long products below from underflowing —
-- confirm against that module.
type Prob = Lognum Double

-- | The type of Hidden Markov Models.
--
-- The constructor fields are, in order:
--
-- 1. the list of states; its order indexes every per-state list below;
-- 2. the initial weight of each state;
-- 3. the transition matrix: row @i@ holds the weights /into/ the @i@-th
--    state, column @j@ the weights /out of/ the @j@-th state;
-- 4. a function giving, for an observation, one emission weight per state.
data HMM state observation =
  HMM [state] [Prob] [[Prob]] (observation -> [Prob])

-- | Perform a single step in the Viterbi algorithm.
--
-- Takes, for every state (in model order), the probability of the best path
-- ending in that state together with the path itself (newest state first),
-- plus the next observation. Returns the updated list in the same form: one
-- surviving path per state.
--
-- The observation is scored against the state each path currently ends in
-- (the source state), and the path is then extended with the destination
-- state. So after @n@ steps each path holds @n + 1@ states, and the newest
-- of them has not emitted anything yet.
viterbi :: Ord observation => HMM state observation -> [(Prob, [state])] -> observation -> [(Prob, [state])]
viterbi (HMM states _ state_transitions observations) prev x =
  [ maximumBy (compare `on` fst)
      [ (transition_prob * prev_prob * observation_prob, new_state : path)
      | transition_prob <- transition_probs
      | (prev_prob, path) <- prev
      | observation_prob <- observation_probs
      ]
  | transition_probs <- state_transitions
  | new_state <- states
  ]
  where
    -- One emission weight per state, aligned with 'prev' (source states).
    observation_probs = observations x

-- | The initial value for the Viterbi algorithm: one single-state path per
-- state, weighted by that state's initial probability.
viterbi_init :: HMM state observation -> [(Prob, [state])]
viterbi_init (HMM states state_probs _ _) = zip state_probs (map (:[]) states)

-- | Perform a single step of the forward algorithm.
--
-- Each item in the input and output list is the probability that the system
-- ended in the respective state.
forward :: Ord observation => HMM state observation -> [Prob] -> observation -> [Prob]
forward (HMM _ _ state_transitions observations) prev x =
  [ sum (zipWith3 weigh transition_probs prev observation_probs)
  | transition_probs <- state_transitions
  ]
  where
    -- One emission weight per state, aligned with 'prev' (source states).
    observation_probs = observations x
    -- Contribution of one source state: transition weight, times the mass
    -- currently in that state, times that state's emission of @x@.
    weigh transition_prob prev_prob observation_prob =
      transition_prob * prev_prob * observation_prob

-- | The initial value for the forward algorithm: the initial state weights.
forward_init :: HMM state observation -> [Prob]
forward_init (HMM _ state_probs _ _) = state_probs

-- | Relative frequency of each state in the labelled sample.
learn_states :: (Ord state, Fractional prob) => [(observation, state)] -> M.Map state prob
learn_states xs = histogram $ map snd xs

-- | Maximum-likelihood transition probabilities estimated from consecutive
-- states in the sample.
--
-- The entry for @(from, to)@ is the number of times @to@ directly follows
-- @from@, divided by the number of times @from@ is followed by any state.
-- The outgoing entries of each source state therefore sum to 1. This
-- conditional form is what the Viterbi and forward steps need; a plain joint
-- pair frequency would multiply in the source state's frequency once more at
-- every step.
learn_transitions :: (Ord state, Fractional prob) => [(observation, state)] -> M.Map (state, state) prob
learn_transitions xs = M.mapWithKey normalise pair_probs
  where
    labels = map snd xs
    steps = zip labels (drop 1 labels)
    -- Both histograms are normalised by the same number of steps, so their
    -- ratio is the ratio of the raw counts.
    pair_probs = histogram steps
    source_probs = histogram (map fst steps)
    -- Every @from@ key of 'pair_probs' also occurs in 'source_probs', because
    -- both are built from the same list of steps.
    normalise (from, _) p = p / fromJust (M.lookup from source_probs)

-- | Emission probabilities @P(observation | state)@.
--
-- Each entry is the relative frequency of an (observation, state) pair in the
-- sample, divided by the state's entry in @state_prob@. Every state occurring
-- in the sample must have an entry there; 'fromJust' fails otherwise.
learn_observations :: (Ord state, Ord observation, Fractional prob) => M.Map state prob -> [(observation, state)] -> M.Map (observation, state) prob
learn_observations state_prob =
  M.mapWithKey (\ (_, state) prob -> prob / fromJust (M.lookup state state_prob))
    . histogram

-- | Relative frequency of each distinct element of the list.
--
-- Returns an empty map for an empty list.
histogram :: (Ord a, Fractional prob) => [a] -> M.Map a prob
histogram xs = M.map (/ total) counts
  where
    counts = M.fromListWith (+) [(x, 1) | x <- xs]
    -- 'M.foldr' replaces 'M.fold', which was removed from Data.Map.
    total = M.foldr (+) 0 counts

-- | Read a file of whitespace-separated @word/tag@ tokens into
-- @(word, tag)@ pairs.
--
-- A token is split at its /last/ slash, so a word that itself contains a
-- slash keeps it (@1/2/cd@ gives @("1/2", "cd")@). A token without any slash
-- becomes @(token, "")@.
readBrownFile :: FilePath -> IO [(String, String)]
readBrownFile path = liftM (map splitToken . words) (readFile path)
  where
    splitToken token = case break (== '/') (reverse token) of
      (reversed_tag, '/' : reversed_word) -> (reverse reversed_word, reverse reversed_tag)
      _ -> (token, [])

-- | Calculate the parameters of an HMM from a list of observations
-- and the corresponding states.
--
-- The list is treated as one contiguous sequence: transitions are counted
-- between every pair of neighbouring entries.
train :: (Ord observation, Ord state) => [(observation, state)] -> HMM state observation
train sample = model
  where
    states = learn_states sample
    state_list = M.keys states
    transitions = learn_transitions sample
    -- One row per destination state, one column per source state: the layout
    -- the Viterbi and forward steps expect. Unseen transitions get
    -- 'unseenProb'.
    trans_prob_mtx =
      [ [ fromMaybe unseenProb $ M.lookup (old_state, new_state) transitions
        | old_state <- state_list ]
      | new_state <- state_list ]
    observations = learn_observations states sample
    -- Dense emission list for every observation seen in training.
    -- 'M.toAscList' orders keys by observation, then by state. Each group's
    -- states are therefore ascending, which matches 'state_list' as 'fill'
    -- requires. 'groupBy' never yields an empty group, so 'head' is safe.
    emission_table =
      M.fromList
        [ (fst (head grp), fill state_list (map snd grp))
        | grp <- groupBy ((==) `on` fst)
            [ (observation, (state, prob))
            | ((observation, state), prob) <- M.toAscList observations ] ]
    -- Observations never seen in training get 'unseenProb' for every state.
    observation_probs x = fromMaybe (fill state_list []) (M.lookup x emission_table)
    model = HMM state_list (fill state_list $ M.toAscList states) trans_prob_mtx observation_probs

-- | Weight given to transitions and emissions never seen in training. It is
-- small but non-zero, so unseen events lower a path's score without ruling
-- the path out entirely.
unseenProb :: Prob
unseenProb = 1e-10

-- | Lay out sparse per-state probabilities as a dense list aligned with the
-- given states.
--
-- Both lists must be in the same (ascending) state order, and every entry's
-- state must occur in the state list. States with no entry get 'unseenProb'.
fill :: Eq state => [state] -> [(state, Prob)] -> [Prob]
fill states [] = map (const unseenProb) states
fill [] _ = error "Data.HMM.fill: entry for a state missing from the state list (or out of order)"
fill (s:states) xs@((s', p):xs')
  | s /= s' = unseenProb : fill states xs
  | otherwise = p : fill states xs'

-- | Test Viterbi's algorithm on an HMM by comparing the predicted states
-- against known states for the observations.
--
-- The test data is decoded as a single sequence. The result is the fraction
-- of positions whose predicted state matches the known one. Empty test data
-- makes the 'Rational' division fail with a zero denominator.
testViterbi :: (Ord observation, Ord state) => HMM state observation -> [(observation, state)] -> Rational
testViterbi hmm testData = fromIntegral correct / fromIntegral (length testData)
  where
    observations = map fst testData
    states = map snd testData
    correct = length $ filter id $ zipWith (==) (bestSequence hmm observations) states

-- | Train an HMM on a file in the format read by 'readBrownFile'.
train2 :: FilePath -> IO (HMM String String)
train2 = liftM train . readBrownFile

-- | Calculate the most likely sequence of states for a given sequence of
-- observations using Viterbi's algorithm.
--
-- The result has one state per observation. Each final Viterbi path carries
-- one extra, not-yet-emitting state at its head, and 'tail' drops it.
bestSequence :: (Ord observation) => HMM state observation -> [observation] -> [state]
bestSequence hmm =
  reverse . tail . snd . maximumBy (compare `on` fst)
    . Data.Foldable.foldl' step (viterbi_init hmm)
  where
    -- Force every path probability at each step. Otherwise a long
    -- observation sequence builds a chain of unevaluated Viterbi steps.
    step paths x = forceWith fst (viterbi hmm paths x)

-- | Calculate the probability of a given sequence of observations using the
-- forward algorithm: the total mass left across all states after the last
-- step.
sequenceProb :: (Ord observation) => HMM state observation -> [observation] -> Prob
sequenceProb hmm = sum . Data.Foldable.foldl' step (forward_init hmm)
  where
    -- Force every state's mass at each step, for the same reason as in
    -- 'bestSequence'.
    step probs x = forceWith id (forward hmm probs x)

-- | Evaluate @f@ of every element of a list to weak head normal form, then
-- return the list unchanged.
forceWith :: (a -> b) -> [a] -> [a]
forceWith f xs = foldr (seq . f) () xs `seq` xs

-- | Apply a binary function to the results of applying a unary function to
-- both arguments (the same as 'on' from "Data.Function").
on :: (b -> b -> c) -> (a -> b) -> a -> a -> c
on f g a b = f (g a) (g b)