{-# LANGUAGE RecordWildCards #-}

module Learning.HMM
  ( HMM (..)
  , LogLikelihood
  , init
  , withEmission
  , euclideanDistance
  , viterbi
  , baumWelch
  , baumWelch'
  , simulate
  ) where

import           Control.Applicative                         ( (<$>) )
import           Control.Arrow                               ( first )
import           Data.List                                   ( elemIndex )
import           Data.Maybe                                  ( fromJust )
import           Data.Random.Distribution                    ( rvar )
import qualified Data.Random.Distribution.Categorical as C   ( Categorical, fromList, normalizeCategoricalPs )
import           Data.Random.Distribution.Extra              ( pmf )
import           Data.Random.RVar                            ( RVar )
import qualified Data.Vector                          as V   ( elemIndex, fromList, map, toList, unsafeIndex )
import qualified Data.Vector.Generic                  as G   ( convert )
import qualified Data.Vector.Unboxed                  as U   ( fromList )
import           Learning.HMM.Internal                       ( LogLikelihood )
import qualified Learning.HMM.Internal                as I
import qualified Numeric.LinearAlgebra.Data           as H   ( (!), fromList, fromLists, toList )
import qualified Numeric.LinearAlgebra.HMatrix        as H   ( tr )
import           Prelude                              hiding ( init )

-- | Parameter set of a hidden Markov model with discrete emission.
--   The model has the following structure.
--
--   @
--       z_0 -> z_1 -> ... -> z_n
--        |      |             |
--        v      v             v
--       x_0    x_1           x_n
--   @
--
--   Here @[z_0, z_1, ..., z_n]@ are the hidden states and
--   @[x_0, x_1, ..., x_n]@ are the observed outputs. @z_0@ is drawn from
--   the 'initialStateDist'. For @i = 1, ..., n@, @z_i@ is drawn from the
--   'transitionDist' conditioned on @z_{i-1}@. For @i = 0, ..., n@, @x_i@
--   is drawn from the 'emissionDist' conditioned on @z_i@.
data HMM s o = HMM
  { states           :: [s]
    -- ^ Hidden states of the model
  , outputs          :: [o]
    -- ^ Outputs the model can emit
  , initialStateDist :: C.Categorical Double s
    -- ^ Categorical distribution of the initial state
  , transitionDist   :: s -> C.Categorical Double s
    -- ^ Categorical distribution of the next state, conditioned on the
    --   previous state
  , emissionDist     :: s -> C.Categorical Double o
    -- ^ Categorical distribution of the output, conditioned on the
    --   hidden state
  }

-- | Render a model in a record-like syntax. The function-valued fields
--   'transitionDist' and 'emissionDist' are shown as lists that pair the
--   distribution of each element of 'states' with that state.
instance (Show s, Show o) => Show (HMM s o) where
  show HMM {..} = concat
    [ "HMM {states = "
    , show states
    , ", outputs = "
    , show outputs
    , ", initialStateDist = "
    , show initialStateDist
    , ", transitionDist = "
    , show (zip (map transitionDist states) states)
    , ", emissionDist = "
    , show (zip (map emissionDist states) states)
    , "}"
    ]

-- | @init states outputs@ is a random variable of models over the given
--   @states@ and @outputs@, whose parameters are sampled from uniform
--   distributions. The sampling itself is delegated to the internal
--   module, which only receives the numbers of states and outputs.
init :: (Eq s, Eq o) => [s] -> [o] -> RVar (HMM s o)
init ss os = fmap (fromInternal ss os) internal
  where
    internal = I.init (length ss) (length os)

-- | @model \`withEmission\` xs@ returns a model in which the
--   'emissionDist' is updated by re-estimation using the observed outputs
--   @xs@. The 'emissionDist' is set to normalized histograms, each of
--   which is calculated from the segmentation of @xs@ based on the Viterbi
--   state path.
--
--   Calls 'error' if the model has no states or no outputs, or if @xs@
--   contains an element that is not one of the model's 'outputs'.
withEmission :: (Eq s, Eq o) => HMM s o -> [o] -> HMM s o
withEmission model@HMM {..} xs =
  checkModelIn "withEmission" model `seq`
  checkDataIn "withEmission" model xs `seq`
  fromInternal states outputs $ I.withEmission model' xs'
  where
    outputs' = V.fromList outputs
    model'   = toInternal model
    -- 'fromJust' cannot fail here: 'checkDataIn' has already verified that
    -- every observation occurs in 'outputs'.
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') xs

-- | Euclidean distance between two models. Both models must have the same
--   'states' and 'outputs'; otherwise 'error' is called.
euclideanDistance :: (Eq s, Eq o) => HMM s o -> HMM s o -> Double
euclideanDistance model1 model2 = validated `seq` distance
  where
    -- Forced first so that mismatching models are reported before any
    -- numeric work is attempted.
    validated = checkTwoModelsIn "euclideanDistance" model1 model2
    distance  = I.euclideanDistance (toInternal model1) (toInternal model2)

-- | @viterbi model xs@ performs the Viterbi algorithm using the observed
--   outputs @xs@, and returns the most likely state path and its log
--   likelihood.
--
--   Calls 'error' if the model has no states or no outputs, or if @xs@
--   contains an element that is not one of the model's 'outputs'.
viterbi :: (Eq s, Eq o) => HMM s o -> [o] -> ([s], LogLikelihood)
viterbi model@HMM {..} xs =
  checkModelIn "viterbi" model `seq`
  checkDataIn "viterbi" model xs `seq`
  first toStates $ I.viterbi model' xs'
  where
    states'  = V.fromList states
    outputs' = V.fromList outputs
    model'   = toInternal model
    -- 'fromJust' cannot fail here: 'checkDataIn' has already verified that
    -- every observation occurs in 'outputs'.
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') xs
    -- Map the state indices of the internal path back to the user's
    -- states. No bounds check is performed; the indices are presumed to
    -- lie within @[0, length states)@ (TODO confirm against 'I.viterbi').
    toStates = V.toList . V.map (V.unsafeIndex states') . G.convert

-- | @baumWelch model xs@ iteratively performs the Baum-Welch algorithm
--   using the observed outputs @xs@, and returns a list of updated models
--   and their corresponding log likelihoods.
--
--   Calls 'error' if the model has no states or no outputs, or if @xs@
--   contains an element that is not one of the model's 'outputs'.
baumWelch :: (Eq s, Eq o) => HMM s o -> [o] -> [(HMM s o, LogLikelihood)]
baumWelch model@HMM {..} xs =
  checkModelIn "baumWelch" model `seq`
  checkDataIn "baumWelch" model xs `seq`
  map (first $ fromInternal states outputs) $ I.baumWelch model' xs'
  where
    outputs' = V.fromList outputs
    model'   = toInternal model
    -- 'fromJust' cannot fail here: 'checkDataIn' has already verified that
    -- every observation occurs in 'outputs'.
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') xs

-- | @baumWelch' model xs@ performs the Baum-Welch algorithm using the
--   observed outputs @xs@, and returns a model locally maximizing its log
--   likelihood, together with that log likelihood.
--
--   Calls 'error' if the model has no states or no outputs, or if @xs@
--   contains an element that is not one of the model's 'outputs'.
baumWelch' :: (Eq s, Eq o) => HMM s o -> [o] -> (HMM s o, LogLikelihood)
baumWelch' model@HMM {..} xs =
  -- Errors are tagged with this function's own name so that they can be
  -- told apart from those raised by 'baumWelch'.
  checkModelIn "baumWelch'" model `seq`
  checkDataIn "baumWelch'" model xs `seq`
  first (fromInternal states outputs) $ I.baumWelch' model' xs'
  where
    outputs' = V.fromList outputs
    model'   = toInternal model
    -- 'fromJust' cannot fail here: 'checkDataIn' has already verified that
    -- every observation occurs in 'outputs'.
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') xs

-- | @simulate model t@ generates a Markov process of length @t@ using the
--   @model@, and returns its state path and outputs. When @t@ is less
--   than 1, both lists are empty.
simulate :: HMM s o -> Int -> RVar ([s], [o])
simulate HMM {..} step
  | step < 1  = return ([], [])
  | otherwise = do
      s0   <- rvar initialStateDist
      x0   <- rvar (emissionDist s0)
      rest <- walk s0 (step - 1)
      return $ unzip ((s0, x0) : rest)
  where
    -- Sample the remaining (state, output) pairs, each state being drawn
    -- from the transition distribution of its predecessor.
    walk _    0         = return []
    walk prev remaining = do
      next <- rvar (transitionDist prev)
      out  <- rvar (emissionDist next)
      rest <- walk next (remaining - 1)
      return ((next, out) : rest)

-- | Validate a model on behalf of the function named by the first
--   argument: 'error' is raised when the model's 'states' or 'outputs'
--   are empty, and @()@ is returned otherwise.
checkModelIn :: String -> HMM s o -> ()
checkModelIn fun model
  | null (states model)  = errorIn fun "empty states"
  | null (outputs model) = errorIn fun "empty outputs"
  | otherwise            = ()

-- | Validate that two models share the same 'states' and the same
--   'outputs'. The states are compared first; 'error' is raised on the
--   first disagreement, and @()@ is returned when both agree.
checkTwoModelsIn :: (Eq s, Eq o) => String -> HMM s o -> HMM s o -> ()
checkTwoModelsIn fun lhs rhs
  | states lhs  /= states rhs  = errorIn fun "states disagree"
  | outputs lhs /= outputs rhs = errorIn fun "outputs disagree"
  | otherwise                  = ()

-- | Validate observed outputs against a model: @()@ is returned when every
--   observation is an element of the model's 'outputs', and 'error' is
--   raised otherwise.
checkDataIn :: Eq o => String -> HMM s o -> [o] -> ()
checkDataIn fun model xs
  | all known xs = ()
  | otherwise    = errorIn fun "illegal data"
  where
    known x = x `elem` outputs model

-- | Convert an internal 'I.HMM' to an 'HMM' over the given states and
--   outputs. Probabilities are paired with states and outputs by position.
--   The resulting 'transitionDist' and 'emissionDist' yield an empty
--   categorical distribution for a state that is not in the given list.
fromInternal :: (Eq s, Eq o) => [s] -> [o] -> I.HMM -> HMM s o
fromInternal ss os I.HMM {..} =
  HMM { states           = ss
      , outputs          = os
      , initialStateDist = C.fromList initialPairs
      , transitionDist   = \s -> C.fromList $ maybe [] transitionPairs (elemIndex s ss)
      , emissionDist     = \s -> C.fromList $ maybe [] emissionPairs (elemIndex s ss)
      }
  where
    -- Within this 'where' clause the field names refer to the fields of
    -- the internal model bound by the record wildcard.
    initialPairs      = zip (H.toList initialStateDist) ss
    -- Row @i@ of the internal transition matrix, paired with the states.
    transitionPairs i = zip (H.toList $ transitionDist H.! i) ss
    -- Row @i@ of the transposed 'emissionDistT', paired with the outputs.
    emissionPairs i   = zip (H.toList $ H.tr emissionDistT H.! i) os
-- | Convert an 'HMM' to an internal 'I.HMM'. Every categorical
--   distribution is normalized before its probabilities are read out, in
--   the order given by 'states' and 'outputs'. The emission matrix is
--   built with one row per output and one column per state.
toInternal :: (Eq s, Eq o) => HMM s o -> I.HMM
toInternal HMM {..} =
  I.HMM { I.nStates          = length states
        , I.nOutputs         = length outputs
        , I.initialStateDist = H.fromList initialProbs
        , I.transitionDist   = H.fromLists transitionRows
        , I.emissionDistT    = H.fromLists emissionRowsT
        }
  where
    normInitial    = C.normalizeCategoricalPs initialStateDist
    normTransition = C.normalizeCategoricalPs . transitionDist
    normEmission   = C.normalizeCategoricalPs . emissionDist
    initialProbs   = map (pmf normInitial) states
    transitionRows = [map (pmf $ normTransition from) states | from <- states]
    emissionRowsT  = [[pmf (normEmission s) out | s <- states] | out <- outputs]

-- | Raise an 'error' whose message is prefixed with the module name and
--   the name of the function reporting the problem.
errorIn :: String -> String -> a
errorIn fun msg = error (concat ["Learning.HMM.", fun, ": ", msg])