module Learning.HMM.Internal (
HMM' (..)
, Likelihood
, Probability
, init'
, withEmission'
, viterbi'
, baumWelch'
) where
import Control.Applicative ((<$>))
import Control.Monad (forM_, replicateM)
import Control.Monad.ST (runST)
import qualified Data.Map.Strict as M (empty, insertWith, findWithDefault)
import Data.Number.LogFloat (LogFloat, logFloat)
import Data.Random.RVar (RVar)
import Data.Random.Distribution.Simplex (stdSimplex)
import Data.Random.Distribution.Uniform.Util ()
import Data.Vector (Vector, (!))
import qualified Data.Vector as V (
filter, foldl', foldl1', freeze, fromList, last, length, map, maximum
, maxIndex, replicate, sum, tail, zip, zipWith, zipWith3, zipWith4
)
import qualified Data.Vector.Mutable as MV (new, read, write)
import qualified Data.Vector.Util as V (unsafeElemIndex)
import Data.Vector.Util.LinearAlgebra (
(>+>), (>.>), (>/>), (#+#), (.>), (>/), (#/), (<.>), (#.>), (<.#)
)
import qualified Data.Vector.Util.LinearAlgebra as V (transpose)
-- | Likelihood of an observation sequence, represented as a 'LogFloat'.
type Likelihood = LogFloat
-- | A single probability value, represented as a 'LogFloat'.
type Probability = LogFloat
-- | Internal representation of a hidden Markov model with state type @s@
--   and output (observation) type @o@. All distributions are stored as
--   vectors addressed by position in 'states'' / 'outputs''.
data HMM' s o = HMM' { states'           :: Vector s -- ^ Hidden states; positions here index the distributions below.
                     , outputs'          :: Vector o -- ^ Output symbols; observations are mapped to positions in this vector.
                     , initialStateDist' :: Vector Probability -- ^ One probability per state for the first time step.
                     , transitionDist'   :: Vector (Vector Probability) -- ^ One row per state; 'init'' draws each row as a distribution over states.
                     , emissionDistT'    :: Vector (Vector Probability) -- ^ Emission matrix stored transposed: indexed by output position first, then by state ('init'' stores @V.transpose@ of a per-state matrix).
                     }
-- | Draw a random model over the given states and outputs.
--
--   The initial-state distribution, every row of the transition matrix and
--   every per-state emission row are sampled independently with
--   'stdSimplex'. 'stdSimplex' is parameterised by the simplex dimension,
--   so @n - 1@ is passed to obtain @n@ coordinates (and @m - 1@ for @m@).
--   The emission matrix is stored transposed in 'emissionDistT''.
--
--   NOTE(review): empty @ss@ or @os@ passes a negative dimension to
--   'stdSimplex'; callers are assumed to supply non-empty vectors.
init' :: Vector s -> Vector o -> RVar (HMM' s o)
init' ss os = do
  let n = V.length ss
      m = V.length os
  pi0 <- V.fromList <$> stdSimplex (n - 1)
  w   <- V.fromList <$> replicateM n (V.fromList <$> stdSimplex (n - 1))
  phi <- V.fromList <$> replicateM n (V.fromList <$> stdSimplex (m - 1))
  return HMM' { states'           = ss
              , outputs'          = os
              , initialStateDist' = pi0
              , transitionDist'   = w
              , emissionDistT'    = V.transpose phi
              }
-- | Replace the emission distribution of a model with one estimated from
--   the observations @xs@.
--
--   The sequence is first decoded with 'viterbi''; every (state, output)
--   pair along the decoded path is counted, each state's counts are divided
--   by that state's total, and the resulting per-state matrix is stored
--   transposed.
withEmission' :: (Ord s, Ord o) => HMM' s o -> Vector o -> HMM' s o
withEmission' model xs = model { emissionDistT' = V.transpose normalized }
  where
    (statePath, _) = viterbi' model xs

    -- Number of times each (state, output) pair occurs along the path.
    bump acc pair = M.insertWith (+) pair 1 acc
    pairCounts    = V.foldl' bump M.empty (V.zip statePath xs)

    -- Histogram over all outputs for one state (0 for unseen pairs).
    countsFor s = V.map (\o -> M.findWithDefault 0 (s, o) pairCounts)
                        (outputs' model)

    -- Scale a histogram by its total.
    normalize row = row >/ V.sum row

    normalized = V.map (normalize . countsFor) (states' model)
-- | Viterbi decoding: return the state path selected by back-tracking
--   through the per-step argmax tables, together with the maximum value of
--   the final step's score vector.
--
--   Every observation in @xs@ is looked up in 'outputs'' with
--   'V.unsafeElemIndex'.
--
--   NOTE(review): @xs@ is assumed non-empty; an empty sequence indexes
--   position @-1@.
viterbi' :: Eq o => HMM' s o -> Vector o -> (Vector s, Likelihood)
viterbi' model xs = (path, likelihood)
  where
    n    = V.length xs
    k    = V.length $ states' model
    ss   = states' model
    pi0  = initialStateDist' model
    w'   = V.transpose $ transitionDist' model
    phi' = emissionDistT' model

    -- Positions of the observations within 'outputs''. Built once and
    -- shared by every time step instead of being rebuilt per lookup.
    xs'     = V.map (`V.unsafeElemIndex` outputs' model) xs
    obsIx i = xs' ! i

    -- Back-track: start from the best final state and follow the stored
    -- predecessor indices towards time 0.
    path = V.map (ss !) $ runST $ do
      ix <- MV.new n
      ix `MV.write` (n - 1) $ V.maxIndex $ deltas ! (n - 1)
      forM_ (reverse [0 .. n - 2]) $ \i -> do
        j <- ix `MV.read` (i + 1)
        ix `MV.write` i $ psis ! (i + 1) ! j
      V.freeze ix

    likelihood = V.maximum $ deltas ! (n - 1)

    -- deltas ! t holds the per-state scores at time t; psis ! t holds, for
    -- each state, the index of the predecessor that achieved that score.
    deltas :: Vector (Vector Probability)
    psis   :: Vector (Vector Int)
    (deltas, psis) = runST $ do
      ds <- MV.new n
      ps <- MV.new n
      ds `MV.write` 0 $ (phi' ! obsIx 0) >.> pi0
      ps `MV.write` 0 $ V.replicate k (0 :: Int)
      forM_ [1 .. n - 1] $ \i -> do
        d <- ds `MV.read` (i - 1)
        -- One candidate vector per destination state (rows of the
        -- transposed transition matrix combined with the previous scores).
        let dws = V.map (d >.>) w'
        ds `MV.write` i $ (phi' ! obsIx i) >.> V.map V.maximum dws
        ps `MV.write` i $ V.map V.maxIndex dws
      ds' <- V.freeze ds
      ps' <- V.freeze ps
      return (ds', ps')
-- | Iterate 'baumWelch1'' indefinitely, starting from @model@.
--
--   The result is an infinite lazy list. Each element pairs a model with
--   the likelihood that 'baumWelch1'' reports when applied to that same
--   model; the first element carries the input model itself.
baumWelch' :: (Eq s, Eq o) => HMM' s o -> Vector o -> [(HMM' s o, Likelihood)]
baumWelch' model xs = go model
  where
    -- Emit the current model with the likelihood computed for it, then
    -- continue from the re-estimated model. This needs no placeholder
    -- likelihood for the seed.
    go m = let (m', ell) = baumWelch1' m xs
           in (m, ell) : go m'
-- | One Baum-Welch re-estimation step.
--
--   Returns the re-estimated model together with the likelihood taken from
--   the forward/backward products of the /input/ model (the last element of
--   @ells@).
--
--   Two cases that previously crashed in 'V.foldl1'' on an empty vector are
--   handled:
--
--   * an output symbol that never occurs in @xs@ now gets an all-zero
--     emission row instead of raising an error;
--   * a single-observation sequence (no consecutive pairs) keeps the
--     current transition matrix, since there is nothing to re-estimate it
--     from.
--
--   NOTE(review): @xs@ is still assumed non-empty.
baumWelch1' :: (Eq s, Eq o) => HMM' s o -> Vector o -> (HMM' s o, Likelihood)
baumWelch1' model xs = (model', likelihood)
  where
    model' = model { initialStateDist' = pi0
                   , transitionDist'   = w
                   , emissionDistT'    = phi'
                   }
    likelihood = V.last ells

    os   = outputs' model
    k    = V.length $ states' model
    w0   = transitionDist' model
    phi0 = emissionDistT' model
    -- Positions of the observations within 'outputs''.
    xs'  = V.map (`V.unsafeElemIndex` os) xs

    alphas = forward' model xs
    betas  = backward' model xs
    -- Per-time-step combination of forward and backward vectors.
    ells   = V.zipWith (<.>) alphas betas
    -- Per-time-step state weights, scaled by the matching entry of ells.
    gammas = V.zipWith3 (\a b l -> a >.> b >/ l) alphas betas ells
    -- One matrix per consecutive pair of time steps (empty when xs has a
    -- single element, because every tail below is then empty).
    xis = V.zipWith4 pairStep alphas (V.tail betas) (V.tail ells) (V.tail xs')
      where
        pairStep a b l x =
          let w1 = V.zipWith (.>) a w0
              w2 = V.map (phi0 ! x >.> b >.>) w1
          in w2 #/ l

    -- New initial distribution: first gamma vector scaled by its sum.
    pi0 = let gs = gammas ! 0
          in gs >/ V.sum gs

    -- New transition matrix: summed xis with each row scaled by its sum.
    w | V.length xis == 0 = w0
      | otherwise         = let ws = V.foldl1' (#+#) xis
                                zs = V.map V.sum ws
                            in V.zipWith (>/) ws zs

    -- New (transposed) emission matrix: for each output, the sum of the
    -- gammas at the time steps where it was observed, divided by the sum
    -- of all gammas. Folding from an explicit zero vector keeps this total
    -- when the output was never observed.
    phi' = let zeros = V.replicate k (0 :: Probability)
               gs' o = V.map snd $ V.filter ((== o) . fst) $ V.zip xs gammas
               phis  = V.foldl' (>+>) zeros . gs'
               zs    = V.foldl1' (>+>) gammas
           in V.map (\o -> phis o >/> zs) os
-- | Forward pass: one vector per time step, each with one entry per state.
--
--   The first vector combines the emission row of the first observation
--   with the initial state distribution; each later vector combines the
--   emission row of the current observation with the previous vector pushed
--   through the transition matrix (@a <.# w@).
--
--   NOTE(review): @xs@ is assumed non-empty; position 0 is always written.
forward' :: Eq o => HMM' s o -> Vector o -> Vector (Vector Probability)
forward' model xs = runST $ do
  v <- MV.new n
  v `MV.write` 0 $ (phi' ! obsIx 0) >.> pi0
  forM_ [1 .. n - 1] $ \i -> do
    a <- v `MV.read` (i - 1)
    v `MV.write` i $ (phi' ! obsIx i) >.> (a <.# w)
  V.freeze v
  where
    n    = V.length xs
    pi0  = initialStateDist' model
    w    = transitionDist' model
    phi' = emissionDistT' model
    -- Positions of the observations within 'outputs''. Built once and
    -- shared by every time step instead of being rebuilt per lookup.
    xs'     = V.map (`V.unsafeElemIndex` outputs' model) xs
    obsIx i = xs' ! i
-- | Backward pass: one vector per time step, each with one entry per state.
--
--   The last vector is all ones; every earlier vector is obtained from its
--   successor by combining it with the emission row of the next observation
--   and applying the transition matrix (@w #.> ...@).
--
--   NOTE(review): @xs@ is assumed non-empty; position @n - 1@ is always
--   written.
backward' :: Eq o => HMM' s o -> Vector o -> Vector (Vector Probability)
backward' model xs = runST $ do
  v <- MV.new n
  v `MV.write` (n - 1) $ V.replicate k $ logFloat (1 :: Double)
  forM_ (reverse [0 .. n - 2]) $ \i -> do
    b <- v `MV.read` (i + 1)
    v `MV.write` i $ w #.> ((phi' ! obsIx (i + 1)) >.> b)
  V.freeze v
  where
    n    = V.length xs
    k    = V.length $ states' model
    w    = transitionDist' model
    phi' = emissionDistT' model
    -- Positions of the observations within 'outputs''. Built once and
    -- shared by every time step instead of being rebuilt per lookup.
    xs'     = V.map (`V.unsafeElemIndex` outputs' model) xs
    obsIx i = xs' ! i