-- | A labelled front end to "Learning.HMM.Internal".
--
-- Models in this module carry caller-chosen state and output values. Most
-- operations convert the model to the index-based internal representation,
-- delegate to the function of the same name in "Learning.HMM.Internal", and
-- relabel the result with the original states and outputs.
module Learning.HMM
( HMM (..)
, LogLikelihood
, init
, withEmission
, euclideanDistance
, viterbi
, baumWelch
, baumWelch'
, simulate
) where
import Control.Applicative ( (<$>) )
import Control.Arrow ( first )
import Data.List ( elemIndex )
import Data.Maybe ( fromJust )
import Data.Random.Distribution ( rvar )
import qualified Data.Random.Distribution.Categorical as C ( Categorical, fromList, normalizeCategoricalPs )
import Data.Random.Distribution.Extra ( pmf )
import Data.Random.RVar ( RVar )
import qualified Data.Vector as V ( elemIndex, fromList, map, toList, unsafeIndex )
import qualified Data.Vector.Generic as G ( convert )
import qualified Data.Vector.Unboxed as U ( fromList )
import Learning.HMM.Internal ( LogLikelihood )
import qualified Learning.HMM.Internal as I
import qualified Numeric.LinearAlgebra.Data as H ( (!), fromList, fromLists, toList )
import qualified Numeric.LinearAlgebra.HMatrix as H ( tr )
import Prelude hiding ( init )
-- | A model over states of type @s@ and outputs of type @o@.
--
-- The three distributions are categorical distributions with 'Double'
-- weights; the transition and emission distributions are given as
-- functions of the current state.
data HMM s o = HMM
  { states           :: [s]
    -- ^ The states of the model.
  , outputs          :: [o]
    -- ^ The outputs the model can produce.
  , initialStateDist :: C.Categorical Double s
    -- ^ Distribution the first state is drawn from.
  , transitionDist   :: s -> C.Categorical Double s
    -- ^ Distribution of the next state, given the current state.
  , emissionDist     :: s -> C.Categorical Double o
    -- ^ Distribution of the output, given the current state.
  }
-- | Record-like rendering of a model.
--
-- The two function-valued fields cannot be shown directly, so they are
-- tabulated over 'states' as lists of @(distribution, state)@ pairs.
instance (Show s, Show o) => Show (HMM s o) where
  show model = concat
    [ "HMM {states = "
    , show ss
    , ", outputs = "
    , show (outputs model)
    , ", initialStateDist = "
    , show (initialStateDist model)
    , ", transitionDist = "
    , show [(transitionDist model s, s) | s <- ss]
    , ", emissionDist = "
    , show [(emissionDist model s, s) | s <- ss]
    , "}"
    ]
    where
      ss = states model
-- | @init ss os@ produces, as a random variable, a model whose states are
-- @ss@ and whose outputs are @os@.
--
-- The parameters come from 'I.init', which is told only how many states and
-- outputs there are; the result is relabelled with 'fromInternal'.
init :: (Eq s, Eq o) => [s] -> [o] -> RVar (HMM s o)
init ss os = fmap relabel internalModel
  where
    relabel       = fromInternal ss os
    internalModel = I.init (length ss) (length os)
-- | @model \`withEmission\` xs@ returns a model with the same states and
-- outputs as @model@, obtained by handing the internal form of @model@ and
-- the observations @xs@ to 'I.withEmission'.
-- (NOTE(review): presumably only the emission distribution is re-estimated;
-- confirm against "Learning.HMM.Internal".)
--
-- Raises an error, labelled with this function's name, when the model has
-- no states or no outputs, or when @xs@ contains a value that is not one of
-- the model's 'outputs'.
withEmission :: (Eq s, Eq o) => HMM s o -> [o] -> HMM s o
withEmission model xs =
  checkModelIn "withEmission" model `seq`
  checkDataIn "withEmission" model xs `seq`
  fromInternal (states model) (outputs model) (I.withEmission model' xs')
  where
    outputs' = V.fromList (outputs model)
    model'   = toInternal model
    -- Observations are passed on as indices into 'outputs'. 'fromJust' is
    -- safe: this is only evaluated after 'checkDataIn' has accepted @xs@.
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') xs
-- | Distance between two models, as computed by 'I.euclideanDistance' on
-- their internal forms.
--
-- Raises an error when the two models do not have equal 'states' lists or
-- equal 'outputs' lists.
euclideanDistance :: (Eq s, Eq o) => HMM s o -> HMM s o -> Double
euclideanDistance model1 model2 = compatible `seq` distance
  where
    -- Evaluated first, purely for its possible error.
    compatible = checkTwoModelsIn "euclideanDistance" model1 model2
    distance   = I.euclideanDistance (toInternal model1) (toInternal model2)
-- | @viterbi model xs@ runs 'I.viterbi' on the internal form of @model@ and
-- the observations @xs@, returning a state path together with a log
-- likelihood. (NOTE(review): presumably the most likely path and its log
-- likelihood; the algorithm itself lives in "Learning.HMM.Internal".)
--
-- Raises an error when the model has no states or no outputs, or when @xs@
-- contains a value that is not one of the model's 'outputs'.
viterbi :: (Eq s, Eq o) => HMM s o -> [o] -> ([s], LogLikelihood)
viterbi model xs =
  checkModelIn "viterbi" model `seq`
  checkDataIn "viterbi" model xs `seq`
  first toStates (I.viterbi model' xs')
  where
    states'  = V.fromList (states model)
    outputs' = V.fromList (outputs model)
    model'   = toInternal model
    -- Observations are passed on as indices into 'outputs'. 'fromJust' is
    -- safe: this is only evaluated after 'checkDataIn' has accepted @xs@.
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') xs
    -- Turn the returned state indices back into the caller's states.
    toStates = V.toList . V.map (V.unsafeIndex states') . G.convert
-- | @baumWelch model xs@ runs 'I.baumWelch' on the internal form of @model@
-- and the observations @xs@, and relabels every model in the resulting list
-- with the states and outputs of @model@. Each model is paired with a log
-- likelihood. (NOTE(review): presumably one entry per re-estimation step;
-- confirm against "Learning.HMM.Internal".)
--
-- Raises an error when the model has no states or no outputs, or when @xs@
-- contains a value that is not one of the model's 'outputs'.
baumWelch :: (Eq s, Eq o) => HMM s o -> [o] -> [(HMM s o, LogLikelihood)]
baumWelch model xs =
  checkModelIn "baumWelch" model `seq`
  checkDataIn "baumWelch" model xs `seq`
  map (first relabel) (I.baumWelch model' xs')
  where
    relabel  = fromInternal (states model) (outputs model)
    outputs' = V.fromList (outputs model)
    model'   = toInternal model
    -- Observations are passed on as indices into 'outputs'. 'fromJust' is
    -- safe: this is only evaluated after 'checkDataIn' has accepted @xs@.
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') xs
-- | @baumWelch' model xs@ runs 'I.baumWelch'' on the internal form of
-- @model@ and the observations @xs@, and relabels the single resulting
-- model with the states and outputs of @model@. The model is paired with a
-- log likelihood.
--
-- Raises an error, labelled @baumWelch'@, when the model has no states or
-- no outputs, or when @xs@ contains a value that is not one of the model's
-- 'outputs'.
baumWelch' :: (Eq s, Eq o) => HMM s o -> [o] -> (HMM s o, LogLikelihood)
baumWelch' model xs =
  checkModelIn "baumWelch'" model `seq`
  checkDataIn "baumWelch'" model xs `seq`
  first relabel (I.baumWelch' model' xs')
  where
    relabel  = fromInternal (states model) (outputs model)
    outputs' = V.fromList (outputs model)
    model'   = toInternal model
    -- Observations are passed on as indices into 'outputs'. 'fromJust' is
    -- safe: this is only evaluated after 'checkDataIn' has accepted @xs@.
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') xs
-- | @simulate model step@ draws a random run of @step@ time steps from
-- @model@ and returns the visited states and the emitted outputs, in order.
--
-- The first state is drawn from 'initialStateDist'; each later state is
-- drawn from 'transitionDist' of its predecessor; every output is drawn
-- from 'emissionDist' of the state at that step. For @step < 1@ both lists
-- are empty.
simulate :: HMM s o -> Int -> RVar ([s], [o])
simulate model step
  | step < 1  = return ([], [])
  | otherwise = do
      s0 <- rvar (initialStateDist model)
      x0 <- rvar (emissionDist model s0)
      -- The first step is already drawn, so @step - 1@ steps remain.
      unzip . ((s0, x0) :) <$> sim s0 (step - 1)
  where
    -- @sim s t@ draws @t@ further (state, output) pairs starting from the
    -- state after @s@.
    sim _ 0 = return []
    sim s t = do
      s' <- rvar (transitionDist model s)
      x' <- rvar (emissionDist model s')
      ((s', x') :) <$> sim s' (t - 1)
-- | @checkModelIn fun model@ is @()@ when @model@ has at least one state
-- and at least one output. Otherwise it raises an error attributed to the
-- function named @fun@; empty states are reported before empty outputs.
checkModelIn :: String -> HMM s o -> ()
checkModelIn fun model =
  case (states model, outputs model) of
    ([], _) -> errorIn fun "empty states"
    (_, []) -> errorIn fun "empty outputs"
    _       -> ()
-- | @checkTwoModelsIn fun model model'@ is @()@ when both models have equal
-- 'states' lists and equal 'outputs' lists. Otherwise it raises an error
-- attributed to the function named @fun@; a mismatch in states is reported
-- before a mismatch in outputs.
checkTwoModelsIn :: (Eq s, Eq o) => String -> HMM s o -> HMM s o -> ()
checkTwoModelsIn fun model model'
  | statesDiffer  = errorIn fun "states disagree"
  | outputsDiffer = errorIn fun "outputs disagree"
  | otherwise     = ()
  where
    statesDiffer  = states model /= states model'
    outputsDiffer = outputs model /= outputs model'
-- | @checkDataIn fun model xs@ is @()@ when every element of @xs@ is one of
-- the model's 'outputs'. Otherwise it raises an error attributed to the
-- function named @fun@.
checkDataIn :: Eq o => String -> HMM s o -> [o] -> ()
checkDataIn fun model xs
  | all (`elem` legal) xs = ()
  | otherwise             = errorIn fun "illegal data"
  where
    legal = outputs model
-- | @fromInternal ss os m@ relabels the index-based model @m@ with the
-- states @ss@ and the outputs @os@.
--
-- Weights are paired with labels positionally: the /i/-th weight of a
-- state distribution goes with the /i/-th element of @ss@, and likewise for
-- outputs and @os@. For a state that does not occur in @ss@, the transition
-- and emission distributions are built from an empty list.
fromInternal :: (Eq s, Eq o) => [s] -> [o] -> I.HMM -> HMM s o
fromInternal ss os internal =
  HMM { states           = ss
      , outputs          = os
      , initialStateDist = C.fromList initialWeights
      , transitionDist   = \s ->
          maybe (C.fromList []) (C.fromList . transitionRow) (stateIndex s)
      , emissionDist     = \s ->
          maybe (C.fromList []) (C.fromList . emissionRow) (stateIndex s)
      }
  where
    -- Position of a state in @ss@, which is its index in the internal model.
    stateIndex s    = elemIndex s ss
    initialWeights  = zip (H.toList (I.initialStateDist internal)) ss
    transitionRow i = zip (H.toList (I.transitionDist internal H.! i)) ss
    -- The internal model stores the emission matrix transposed; undo that
    -- so that row @i@ belongs to state @i@.
    emissionByState = H.tr (I.emissionDistT internal)
    emissionRow i   = zip (H.toList (emissionByState H.! i)) os
-- | Convert a labelled model to the index-based internal representation.
--
-- Every distribution is passed through 'C.normalizeCategoricalPs' and then
-- tabulated with 'pmf' in the order of 'states' (and 'outputs'). The
-- emission matrix is built with one row per output and one column per
-- state, as the field name 'I.emissionDistT' suggests.
toInternal :: (Eq s, Eq o) => HMM s o -> I.HMM
toInternal model =
  I.HMM { I.nStates          = length ss
        , I.nOutputs         = length os
        , I.initialStateDist = H.fromList (map (pmf initial) ss)
        , I.transitionDist   = H.fromLists (map transitionRow ss)
        , I.emissionDistT    = H.fromLists (map emissionRow os)
        }
  where
    ss = states model
    os = outputs model
    initial         = C.normalizeCategoricalPs (initialStateDist model)
    transitionFrom  = C.normalizeCategoricalPs . transitionDist model
    emissionFrom    = C.normalizeCategoricalPs . emissionDist model
    -- Row for state @s@: weight of moving from @s@ to each state.
    transitionRow s = [pmf (transitionFrom s) s' | s' <- ss]
    -- Row for output @o@: weight of emitting @o@ from each state.
    emissionRow o   = [pmf (emissionFrom s) o | s <- ss]
-- | @errorIn fun msg@ raises an error whose message is @msg@ prefixed with
-- the qualified name of the function @fun@ of this module.
errorIn :: String -> String -> a
errorIn fun msg = error (concat ["Learning.HMM.", fun, ": ", msg])