-- |
-- Module      :  ELynx.Simulate.MarkovProcess
-- Description :  Markov process helpers
-- Copyright   :  (c) Dominik Schrempf 2021
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Thu Jan 24 09:02:25 2019.
module ELynx.Simulate.MarkovProcess
  ( ProbMatrix,
    State,
    probMatrix,
    jump,
  )
where

import Control.Monad.Primitive
import ELynx.Data.MarkovProcess.RateMatrix
import Numeric.LinearAlgebra
import System.Random.MWC
import System.Random.MWC.Distributions

-- | Transition probability matrix of a Markov process over a time span @t@.
--
-- Entry @(i, j)@ holds @Pr (X_t = j | X_0 = i)@, the probability of being in
-- state @j@ after time @t@ when starting in state @i@.
type ProbMatrix = Matrix Double

-- | A state of the Markov process, used as a row/column index into a
-- 'ProbMatrix'. This alias only makes type signatures a little clearer.
type State = Int

-- | The important matrix that gives the probabilities to move from one state to
-- another in a specific time (branch length).
--
-- For a rate matrix @q@ and a time @t@:
--
-- * @t == 0@: returns the identity matrix of the same size as @q@. Calls
--   'error' if @q@ is not square.
--
-- * @t < 0@: calls 'error', because negative times are not allowed.
--
-- * otherwise: returns the matrix exponential @exp (t * q)@.
probMatrix :: RateMatrix -> Double -> ProbMatrix
probMatrix q t
  | t == 0 =
      -- Shortcut: no time has passed, so every state stays where it is.
      if rows q == cols q
        then ident (rows q)
        else error "probMatrix: Matrix is not square."
  | t < 0 = error "probMatrix: Time is negative."
  | otherwise = expm $ scale t q

-- | Move from a given state to a new one according to a transition probability
-- matrix.
--
-- Row @i@ of the matrix @p@ is used as the weight vector of a categorical
-- distribution, and the index of the sampled category is the new state. The
-- state @i@ must be a valid row index of @p@.
--
-- This function is the bottleneck of the simulator and takes up most of the
-- computation time.
jump :: (PrimMonad m) => State -> ProbMatrix -> Gen (PrimState m) -> m State
jump i p = categorical (p ! i)

-- XXX: Maybe for later, use condensed tables.
--
-- This would require a Storable instance, because compilation is really slow
-- otherwise:
--
-- instance Storable (Int, R) where sizeOf (x, y) = sizeOf x + sizeOf y
--
-- Also, do not generate a table for each jump.
--
-- jump :: (PrimMonad m) => State -> ProbMatrix -> Gen (PrimState m) -> m State
-- jump i p = genFromTable table
--   where
--     ws = toList $ p ! i
--     vsAndWs = fromList [ (v, w) | (v, w) <- zip [(0 :: Int) ..] ws
--                                 , w > 0 ]
--     table = tableFromProbabilities vsAndWs

-- -- | Perform N jumps from a given state and according to a transition
-- -- probability matrix transformed to a list of generators. This implementation
-- -- uses 'foldM' and I am not sure how to access or store the actual chain. This
-- -- could be done by an equivalent of 'scanl' for general monads, which I was
-- -- unable to find. This function is neat, but will most likely not be needed.
-- -- However, it is instructive and is left in place.
-- jumpN :: (MonadRandom m) => State -> [Generator State] -> Int -> m State
-- jumpN s p n = foldM jump s (replicate n p)