{-# LANGUAGE RecordWildCards #-}

module Learning.IOHMM
  ( IOHMM (..)
  , LogLikelihood
  , init
  , withEmission
  , viterbi
  , baumWelch
  , baumWelch'
  , simulate
  ) where

import           Control.Applicative                         ( (<$>) )
import           Control.Arrow                               ( first )
import           Data.List                                   ( elemIndex )
import           Data.Maybe                                  ( fromJust )
import           Data.Random.Distribution                    ( rvar )
import qualified Data.Random.Distribution.Categorical as C   ( Categorical, fromList, normalizeCategoricalPs )
import           Data.Random.Distribution.Extra              ( pmf )
import           Data.Random.RVar                            ( RVar )
import qualified Data.Vector                          as V   ( elemIndex, fromList, map, toList, unsafeIndex )
import qualified Data.Vector.Generic                  as G   ( convert )
import qualified Data.Vector.Unboxed                  as U   ( fromList, zip )
import qualified Numeric.LinearAlgebra.Data           as H   ( (!), fromList, fromLists, toList )
import qualified Numeric.LinearAlgebra.HMatrix        as H   ( tr )
import           Learning.IOHMM.Internal                     ( LogLikelihood )
import qualified Learning.IOHMM.Internal              as I
import           Prelude                              hiding ( init )

-- | Parameter set of the input-output hidden Markov model with discrete emission.
--   This 'IOHMM' assumes that the inputs affect only the transition
--   probabilities. The model schema is as follows.
--
--   @
--       x_0    x_1           x_n
--               |             |
--               v             v
--       z_0 -> z_1 -> ... -> z_n
--        |      |             |
--        v      v             v
--       y_0    y_1           y_n
--   @
--
--   Here, @[x_0, x_1, ..., x_n]@ are given inputs, @[z_0, z_1, ..., z_n]@
--   are hidden states, and @[y_0, y_1, ..., y_n]@ are observed outputs.
--   @z_0@ is determined by the 'initialStateDist'.
--   For @i = 1, ..., n@, @z_i@ is determined by the 'transitionDist'
--   conditioned by @x_i@ and @z_{i-1}@.
--   For @i = 0, ..., n@, @y_i@ is determined by the 'emissionDist'
--   conditioned by @z_i@.
data IOHMM i s o = IOHMM { inputs  :: [i]
                         , states  :: [s]
                         , outputs :: [o]
                         , initialStateDist :: C.Categorical Double s
                           -- ^ Categorical distribution of initial state
                         , transitionDist :: i -> s -> C.Categorical Double s
                           -- ^ Categorical distribution of next state
                           --   conditioned by the input and previous state
                         , emissionDist :: s -> C.Categorical Double o
                           -- ^ Categorical distribution of output conditioned
                           --   by the hidden state
                         }

instance (Show i, Show s, Show o) => Show (IOHMM i s o) where
  show IOHMM {..} = "IOHMM {inputs = "           ++ show inputs
                      ++ ", states = "           ++ show states
                      ++ ", outputs = "          ++ show outputs
                      ++ ", initialStateDist = " ++ show initialStateDist
                      ++ ", transitionDist = "   ++ show [(transitionDist i s, (i, s)) | i <- inputs, s <- states]
                      ++ ", emissionDist = "     ++ show [(emissionDist s, s) | s <- states]
                      ++ "}"

-- | @init inputs states outputs@ returns a random variable of models with the
--   @inputs@, @states@, and @outputs@, wherein parameters are sampled from uniform
--   distributions.
init :: (Eq i, Eq s, Eq o) => [i] -> [s] -> [o] -> RVar (IOHMM i s o)
init is ss os = fromInternal is ss os <$> I.init (length is) (length ss) (length os)

-- | @withEmission model xs ys@ returns a model in which the
--   'emissionDist' is updated by re-estimations using the inputs @xs@ and
--   outputs @ys@. The 'emissionDist' is set to be normalized histograms
--   each of which is calculated from segumentations of @ys@ based on the
--   Viterbi state path.
--   If the lengths of @xs@ and @ys@ are different, the longer one is cut
--   by the length of the shorter one.
withEmission :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> IOHMM i s o
withEmission (model @ IOHMM {..}) xs ys = fromInternal inputs states outputs $ I.withEmission model' $ U.zip xs' ys'
  where
    inputs'  = V.fromList inputs
    outputs' = V.fromList outputs
    model'   = toInternal model
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` inputs') xs
    ys'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') ys

-- | @viterbi model xs ys@ performs the Viterbi algorithm using the inputs
--   @xs@ and outputs @ys@, and returns the most likely state path and its
--   log likelihood.
--   If the lengths of @xs@ and @ys@ are different, the longer one is cut
--   by the length of the shorter one.
viterbi :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> ([s], LogLikelihood)
viterbi (model @ IOHMM {..}) xs ys =
  checkModelIn "viterbi" model `seq`
  checkDataIn "viterbi" model xs ys `seq`
  first toStates $ I.viterbi model' $ U.zip xs' ys'
  where
    inputs'  = V.fromList inputs
    states'  = V.fromList states
    outputs' = V.fromList outputs
    model'   = toInternal model
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` inputs') xs
    ys'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') ys
    toStates = V.toList . V.map (V.unsafeIndex states') . G.convert

-- | @baumWelch model xs ys@ iteratively performs the Baum-Welch algorithm
--   using the inputs @xs@ and outputs @ys@, and returns a list of updated
--   models and their corresponding log likelihoods.
--   If the lengths of @xs@ and @ys@ are different, the longer one is cut
--   by the length of the shorter one.
baumWelch :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> [(IOHMM i s o, LogLikelihood)]
baumWelch (model @ IOHMM {..}) xs ys =
  checkModelIn "baumWelch" model `seq`
  checkDataIn "baumWelch" model xs ys `seq`
  map (first $ fromInternal inputs states outputs) $ I.baumWelch model' $ U.zip xs' ys'
  where
    inputs'  = V.fromList inputs
    outputs' = V.fromList outputs
    model'   = toInternal model
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` inputs') xs
    ys'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') ys

-- | @baumWelch' model xs@ performs the Baum-Welch algorithm using the
--   inputs @xs@ and outputs @ys@, and returns a model locally maximizing
--   its log likelihood.
baumWelch' :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> (IOHMM i s o, LogLikelihood)
baumWelch' (model @ IOHMM {..}) xs ys =
  checkModelIn "baumWelch" model `seq`
  checkDataIn "baumWelch" model xs ys `seq`
  first (fromInternal inputs states outputs) $ I.baumWelch' model' $ U.zip xs' ys'
  where
    inputs'  = V.fromList inputs
    outputs' = V.fromList outputs
    model'   = toInternal model
    xs'      = U.fromList $ fromJust $ mapM (`V.elemIndex` inputs') xs
    ys'      = U.fromList $ fromJust $ mapM (`V.elemIndex` outputs') ys

-- | @simulate model xs@ generates a Markov process coinciding with the
--   inputs @xs@ using the @model@, and returns its state path and observed
--   outputs.
simulate :: IOHMM i s o -> [i] -> RVar ([s], [o])
simulate IOHMM {..} xs
  | null xs   = return ([], [])
  | otherwise = do s0 <- rvar initialStateDist
                   y0 <- rvar $ emissionDist s0
                   unzip . ((s0, y0) :) <$> sim s0 (tail xs)
  where
    sim _ []      = return []
    sim s (x:xs') = do s' <- rvar $ transitionDist x s
                       y' <- rvar $ emissionDist s'
                       ((s', y') :) <$> sim s' xs'

-- | Check if the model is valid in the sense of whether the 'states' and
--   'outputs' are not empty.
checkModelIn :: String -> IOHMM i s o -> ()
checkModelIn fun IOHMM {..}
  | null inputs  = errorIn fun "empty inputs"
  | null states  = errorIn fun "empty states"
  | null outputs = errorIn fun "empty outputs"
  | otherwise    = ()

-- | Check if all the elements of the given inputs (outputs) are contained
--   in the 'inputs' ('outputs') of the model.
checkDataIn :: (Eq i, Eq o) => String -> IOHMM i s o -> [i] -> [o] -> ()
checkDataIn fun IOHMM {..} xs ys
  | all (`elem` inputs) xs && all (`elem` outputs) ys = ()
  | otherwise                                         = errorIn fun "illegal data"

-- | Convert internal 'IOHMM' to 'IOHMM'.
fromInternal :: (Eq i, Eq s, Eq o) => [i] -> [s] -> [o] -> I.IOHMM -> IOHMM i s o
fromInternal is ss os I.IOHMM {..} = IOHMM { inputs           = is
                                           , states           = ss
                                           , outputs          = os
                                           , initialStateDist = C.fromList pi0'
                                           , transitionDist   = \i s -> case (elemIndex i is, elemIndex s ss) of
                                                                          (Nothing, _)     -> C.fromList []
                                                                          (_, Nothing)     -> C.fromList []
                                                                          (Just j, Just k) -> C.fromList $ w' j k
                                           , emissionDist     = \s -> case elemIndex s ss of
                                                                        Nothing -> C.fromList []
                                                                        Just i  -> C.fromList $ phi' i
                                           }
  where
    pi0'   = zip (H.toList initialStateDist) ss
    w' j k = zip (H.toList $ V.unsafeIndex transitionDist j H.! k) ss
    phi' i = zip (H.toList $ H.tr emissionDistT H.! i) os

-- | Convert 'IOHMM' to internal 'IOHMM'. The 'initialStateDist'',
--   'transitionDist'', and 'emissionDistT'' are normalized.
toInternal :: (Eq i, Eq s, Eq o) => IOHMM i s o -> I.IOHMM
toInternal IOHMM {..} = I.IOHMM { I.nInputs          = length inputs
                                , I.nStates          = length states
                                , I.nOutputs         = length outputs
                                , I.initialStateDist = pi0
                                , I.transitionDist   = w
                                , I.emissionDistT    = phi'
                                }
  where
    pi0_ = C.normalizeCategoricalPs initialStateDist
    w_ i = C.normalizeCategoricalPs . transitionDist i
    phi_ = C.normalizeCategoricalPs . emissionDist
    pi0  = H.fromList [pmf pi0_ s | s <- states]
    w    = V.fromList $ map (\i -> H.fromLists [[pmf (w_ i s) s' | s' <- states] | s <- states]) inputs
    phi' = H.fromLists [[pmf (phi_ s) o | s <- states] | o <- outputs]

errorIn :: String -> String -> a
errorIn fun msg = error $ "Learning.IOHMM." ++ fun ++ ": " ++ msg