module Learning.IOHMM
( IOHMM (..)
, LogLikelihood
, init
, withEmission
, viterbi
, baumWelch
, baumWelch'
, simulate
) where
import Control.Applicative ( (<$>) )
import Control.Arrow ( first )
import Data.List ( elemIndex )
import Data.Maybe ( fromJust )
import Data.Random.Distribution ( rvar )
import qualified Data.Random.Distribution.Categorical as C ( Categorical, fromList, normalizeCategoricalPs )
import Data.Random.Distribution.Extra ( pmf )
import Data.Random.RVar ( RVar )
import qualified Data.Vector as V ( elemIndex, fromList, map, toList, unsafeIndex )
import qualified Data.Vector.Generic as G ( convert )
import qualified Data.Vector.Unboxed as U ( fromList, zip )
import qualified Numeric.LinearAlgebra.Data as H ( (!), fromList, fromLists, toList )
import qualified Numeric.LinearAlgebra.HMatrix as H ( tr )
import Learning.IOHMM.Internal ( LogLikelihood )
import qualified Learning.IOHMM.Internal as I
import Prelude hiding ( init )
-- | An input-output hidden Markov model labelled with caller-supplied symbols.
--
-- Inputs of type @i@ drive the hidden-state dynamics; hidden states of type
-- @s@ emit outputs of type @o@. All distributions are categorical with
-- 'Double' weights.
data IOHMM i s o = IOHMM
  { inputs           :: [i]
    -- ^ Input alphabet.
  , states           :: [s]
    -- ^ Hidden state space.
  , outputs          :: [o]
    -- ^ Output alphabet.
  , initialStateDist :: C.Categorical Double s
    -- ^ Distribution of the first hidden state.
  , transitionDist   :: i -> s -> C.Categorical Double s
    -- ^ Distribution of the next hidden state, given the input at that step
    -- and the current hidden state.
  , emissionDist     :: s -> C.Categorical Double o
    -- ^ Distribution of the output emitted from a hidden state.
  }
instance (Show i, Show s, Show o) => Show (IOHMM i s o) where
  -- The transition and emission fields are functions, so they are rendered
  -- as tables: one entry per (input, state) pair and one per state.
  -- The @IOHMM {}@ pattern forces the constructor before any text is produced.
  show model@IOHMM {} = concat
    [ "IOHMM {inputs = ", show (inputs model)
    , ", states = ", show (states model)
    , ", outputs = ", show (outputs model)
    , ", initialStateDist = ", show (initialStateDist model)
    , ", transitionDist = ", show transitionTable
    , ", emissionDist = ", show emissionTable
    , "}"
    ]
    where
      transitionTable =
        [ (transitionDist model i s, (i, s)) | i <- inputs model, s <- states model ]
      emissionTable =
        [ (emissionDist model s, s) | s <- states model ]
-- | Random-variable computation that builds a model over the given input,
-- state and output alphabets.
--
-- The numeric parameters come from the internal 'I.init', which only sees the
-- alphabet sizes; the result is then relabelled with the supplied symbols.
init :: (Eq i, Eq s, Eq o) => [i] -> [s] -> [o] -> RVar (IOHMM i s o)
init is ss os = fmap relabel (I.init nIn nSt nOut)
  where
    relabel = fromInternal is ss os
    nIn = length is
    nSt = length ss
    nOut = length os
-- | Rebuild a model from the internal 'I.withEmission' applied to the given
-- input sequence and its paired output sequence.
--
-- NOTE(review): presumably this re-estimates the emission distributions from
-- the observed data; confirm against "Learning.IOHMM.Internal".
--
-- Like the other data-consuming functions in this module, it calls 'error' in
-- two cases: when the model has an empty input, state or output alphabet, and
-- when @xs@ or @ys@ contains a symbol outside the model's alphabets. The
-- pairing uses 'U.zip', which truncates to the shorter of @xs@ and @ys@.
withEmission :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> IOHMM i s o
withEmission model@IOHMM {..} xs ys =
  checkModelIn "withEmission" model `seq`
  checkDataIn "withEmission" model xs ys `seq`
  fromInternal inputs states outputs (I.withEmission model' (U.zip xs' ys'))
  where
    inputs' = V.fromList inputs
    outputs' = V.fromList outputs
    model' = toInternal model
    -- Symbols become their positions in the alphabets. 'fromJust' is safe:
    -- the result is only produced after 'checkDataIn' has accepted the data.
    xs' = U.fromList $ fromJust $ mapM (`V.elemIndex` inputs') xs
    ys' = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') ys
-- | Decode the hidden state sequence for an input sequence and its paired
-- output sequence using the internal 'I.viterbi'. The decoded internal state
-- indices are mapped back to the model's state labels; the log-likelihood is
-- passed through unchanged.
--
-- Calls 'error' in two cases: when the model has an empty input, state or
-- output alphabet, and when the data contains a symbol outside the model's
-- alphabets. The pairing uses 'U.zip', which truncates to the shorter of
-- @xs@ and @ys@.
viterbi :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> ([s], LogLikelihood)
viterbi model@IOHMM {..} xs ys =
  checkModelIn "viterbi" model `seq`
  checkDataIn "viterbi" model xs ys `seq`
  first toStates (I.viterbi model' (U.zip xs' ys'))
  where
    inputs' = V.fromList inputs
    states' = V.fromList states
    outputs' = V.fromList outputs
    model' = toInternal model
    -- Symbols become their positions in the alphabets. 'fromJust' is safe:
    -- the result is only produced after 'checkDataIn' has accepted the data.
    xs' = U.fromList $ fromJust $ mapM (`V.elemIndex` inputs') xs
    ys' = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') ys
    -- Internal state indices back to the caller's state labels.
    toStates = V.toList . V.map (V.unsafeIndex states') . G.convert
-- | Run the internal 'I.baumWelch' on an input sequence and its paired output
-- sequence. Each internal model it produces is relabelled with this model's
-- alphabets and paired with the log-likelihood 'I.baumWelch' reports for it.
--
-- NOTE(review): whether the list is finite and whether each log-likelihood
-- refers to the model before or after its update is decided by
-- "Learning.IOHMM.Internal"; confirm there.
--
-- Calls 'error' in two cases: when the model has an empty input, state or
-- output alphabet, and when the data contains a symbol outside the model's
-- alphabets.
baumWelch :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> [(IOHMM i s o, LogLikelihood)]
baumWelch model@IOHMM {..} xs ys =
  checkModelIn "baumWelch" model `seq`
  checkDataIn "baumWelch" model xs ys `seq`
  map (first $ fromInternal inputs states outputs) (I.baumWelch model' (U.zip xs' ys'))
  where
    inputs' = V.fromList inputs
    outputs' = V.fromList outputs
    model' = toInternal model
    -- Symbols become their positions in the alphabets. 'fromJust' is safe:
    -- the result is only produced after 'checkDataIn' has accepted the data.
    xs' = U.fromList $ fromJust $ mapM (`V.elemIndex` inputs') xs
    ys' = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') ys
-- | Run the internal 'I.baumWelch'' on an input sequence and its paired
-- output sequence, and return its single resulting model (relabelled with
-- this model's alphabets) together with the reported log-likelihood.
--
-- NOTE(review): presumably the model at convergence; the stopping rule lives
-- in "Learning.IOHMM.Internal".
--
-- Calls 'error' in two cases: when the model has an empty input, state or
-- output alphabet, and when the data contains a symbol outside the model's
-- alphabets. Errors are reported under this function's own name.
baumWelch' :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> (IOHMM i s o, LogLikelihood)
baumWelch' model@IOHMM {..} xs ys =
  checkModelIn "baumWelch'" model `seq`
  checkDataIn "baumWelch'" model xs ys `seq`
  first (fromInternal inputs states outputs) (I.baumWelch' model' (U.zip xs' ys'))
  where
    inputs' = V.fromList inputs
    outputs' = V.fromList outputs
    model' = toInternal model
    -- Symbols become their positions in the alphabets. 'fromJust' is safe:
    -- the result is only produced after 'checkDataIn' has accepted the data.
    xs' = U.fromList $ fromJust $ mapM (`V.elemIndex` inputs') xs
    ys' = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') ys
-- | Sample a hidden state sequence and an output sequence driven by the given
-- inputs. Both lists have the same length as the input list.
--
-- The first state is drawn from 'initialStateDist'; the first input is not
-- consulted. Every later state is drawn from 'transitionDist' applied to that
-- step's input and the previous state. Each state emits one output via
-- 'emissionDist'.
simulate :: IOHMM i s o -> [i] -> RVar ([s], [o])
simulate IOHMM {..} xs = case xs of
  [] -> return ([], [])
  _ : rest -> do
    s0 <- rvar initialStateDist
    y0 <- rvar (emissionDist s0)
    fmap (unzip . ((s0, y0) :)) (walk s0 rest)
  where
    walk _ [] = return []
    walk prev (x : more) = do
      next <- rvar (transitionDist x prev)
      out <- rvar (emissionDist next)
      fmap ((next, out) :) (walk next more)
-- | Return @()@ when the model's input, state and output alphabets are all
-- non-empty. Otherwise call 'errorIn' with @fun@ and a message naming the
-- first empty alphabet, checked in the order inputs, states, outputs.
checkModelIn :: String -> IOHMM i s o -> ()
checkModelIn fun model@IOHMM {} =
  case [msg | (isEmpty, msg) <- checks, isEmpty] of
    msg : _ -> errorIn fun msg
    [] -> ()
  where
    -- Evaluated lazily and in order, so only the needed checks run.
    checks =
      [ (null (inputs model), "empty inputs")
      , (null (states model), "empty states")
      , (null (outputs model), "empty outputs")
      ]
-- | Return @()@ when every element of @xs@ belongs to the model's input
-- alphabet and every element of @ys@ belongs to its output alphabet.
-- Otherwise call 'errorIn' with @fun@ and @"illegal data"@.
checkDataIn :: (Eq i, Eq o) => String -> IOHMM i s o -> [i] -> [o] -> ()
checkDataIn fun model@IOHMM {} xs ys =
  if knownInputs && knownOutputs
    then ()
    else errorIn fun "illegal data"
  where
    knownInputs = all (`elem` inputs model) xs
    knownOutputs = all (`elem` outputs model) ys
-- | Label an index-based internal model with concrete input, state and output
-- symbols. Position @n@ of each list names index @n@ of the internal model.
--
-- The transition and emission functions return an empty categorical
-- distribution for symbols that are not in the supplied alphabets.
fromInternal :: (Eq i, Eq s, Eq o) => [i] -> [s] -> [o] -> I.IOHMM -> IOHMM i s o
fromInternal is ss os m@I.IOHMM {} =
  IOHMM
    { inputs = is
    , states = ss
    , outputs = os
    , initialStateDist = C.fromList (withLabels ss (I.initialStateDist m))
    , transitionDist = \i s -> C.fromList $ maybe [] id $ do
        j <- elemIndex i is
        k <- elemIndex s ss
        return (transitionRow j k)
    , emissionDist = \s -> C.fromList $ maybe [] emissionRow (elemIndex s ss)
    }
  where
    -- Pair each probability in a vector with the symbol at the same position.
    withLabels labels vec = zip (H.toList vec) labels
    -- Row k of the transition matrix for input j, labelled with next states.
    transitionRow j k = withLabels ss (V.unsafeIndex (I.transitionDist m) j H.! k)
    -- The internal emission matrix is transposed (rows are outputs), so
    -- transpose it back to read one state's emission row.
    emissionByState = H.tr (I.emissionDistT m)
    emissionRow k = withLabels os (emissionByState H.! k)
-- | Convert a labelled model into the index-based internal representation.
--
-- Every distribution is normalised before being tabulated over the model's
-- alphabets:
--
-- * the initial vector is indexed by state;
-- * each transition matrix (one per input) has rows indexed by the current
--   state and columns by the next state;
-- * the emission matrix is stored transposed, with rows indexed by output and
--   columns by state.
toInternal :: (Eq i, Eq s, Eq o) => IOHMM i s o -> I.IOHMM
toInternal IOHMM { inputs = is, states = ss, outputs = os
                 , initialStateDist = d0, transitionDist = trans
                 , emissionDist = emis } =
  I.IOHMM
    { I.nInputs = length is
    , I.nStates = length ss
    , I.nOutputs = length os
    , I.initialStateDist = H.fromList (probsOver ss d0)
    , I.transitionDist = V.fromList (map transitionMatrix is)
    , I.emissionDistT = H.fromLists [[pmf d o | d <- emissionsByState] | o <- os]
    }
  where
    -- Probability of each element of the support under the normalised form
    -- of the distribution.
    probsOver support d =
      let d' = C.normalizeCategoricalPs d
      in [pmf d' x | x <- support]
    transitionMatrix i = H.fromLists [probsOver ss (trans i s) | s <- ss]
    emissionsByState = map (C.normalizeCategoricalPs . emis) ss
-- | Abort via 'error' with a message of the form
-- @Learning.IOHMM.<fun>: <msg>@, naming the public function that detected
-- the problem.
errorIn :: String -> String -> a
errorIn fun msg = error (concat ["Learning.IOHMM.", fun, ": ", msg])