{-# language OverloadedStrings #-}
{-# language DeriveFunctor, GeneralizedNewtypeDeriving #-}
{-# language FlexibleContexts #-}
module System.Random.MWC.Probability.Transition (
  -- * Transition
    Transition
  , mkTransition
  , runTransition
  -- ** Specialized combinators
  , evalTransition
  , execTransition
  -- ** Conditional execution
  , stepConditional
  -- * Helper functions
  , withSeverity
  -- * Re-exported from @logging-effect@
  -- ** Log message severity
  , L.WithSeverity(..), L.Severity(..)
  -- ** Handlers
  , L.Handler, L.withFDHandler
    -- ** Batched logging
  , L.BatchingOptions(..), L.defaultBatchingOptions, L.withBatchedHandler
  -- * Re-exported from @GHC.IO.Handle.FD@
  , stdout, stderr
  ) where

import Control.Monad
import Control.Monad.Primitive

-- import Control.Monad.Catch (MonadMask(..))
import GHC.IO.Handle (Handle(..))
import GHC.IO.Handle.FD (stdout, stderr)

import qualified Control.Monad.State as S
import Control.Monad.Trans.Class (MonadTrans(..), lift)
import Control.Monad.Trans.State.Strict (StateT(..), runStateT)

import qualified Control.Monad.Log as L
import Data.Char

import System.Random.MWC.Probability



-- | A Markov transition kernel.
newtype Transition message s m a = Transition (
  Gen (PrimState m) -> StateT s (L.LoggingT message m) a
  ) deriving (Functor)

instance Show (Transition msg s m a) where
  show _ = "<Transition>"

-- | Construct a 'Transition' from sampling, state transformation and logging functions.
--
-- NB: The three function arguments are used in the order in which they appear here:
--
-- 1. a random sample @w :: t@ is produced, using the current state @x :: s@ as input
--
-- 2. output @z :: a@ and next state @x' :: s@ are computed using @w@ and @x@
--
-- 3. a logging message is constructed, using @z@ and @x'@ as arguments.
mkTransition :: Monad m =>
        (s -> Prob m t)     -- ^ Generation of random data
     -> (s -> t -> (a, s))  -- ^ (Output, Next state)
     -> (s -> t -> a -> message) -- ^ Log message construction using (Next state, current random data, Output)
     -> Transition message s m a
mkTransition fm fs flog = Transition $ \gen -> do
  s <- S.get
  w <- lift . lift $ sample (fm s) gen
  let (a, s') = fs s w
  lift $ L.logMessage $ flog s' w a 
  S.put s'
  return a

-- | Run a 'Transition' for a number of steps, while logging each iteration.
--
-- Returns both the list of outputs and the final state.
runTransition :: Monad m =>
         L.Handler m message        -- ^ Logging handler
      -> Transition message s m a
      -> Int                      -- ^ Number of iterations 
      -> s                        -- ^ Initial state
      -> Gen (PrimState m)        -- ^ PRNG
      -> m ([a], s)               -- ^ (Outputs, Final state)
runTransition logf (Transition fm) n s0 g =
  L.runLoggingT (runStateT (replicateM n (fm g)) s0) logf


-- | Run a 'Transition' for a number of steps, while logging each iteration.
--
-- Returns the list of outputs.
evalTransition :: Monad m =>
                  L.Handler m message
               -> Transition message s m a
               -> Int
               -> s
               -> Gen (PrimState m)
               -> m [a]              -- ^ Outputs
evalTransition logf tfm n s0 g = fst <$> runTransition logf tfm n s0 g
  

-- | Run a 'Transition' for a number of steps, while logging each iteration.
--
-- Returns the final state.
execTransition :: Monad m =>
                  L.Handler m message
               -> Transition message s m a
               -> Int
               -> s
               -> Gen (PrimState m)
               -> m s              -- ^ Final state
execTransition logf tfm n s0 g = snd <$> runTransition logf tfm n s0 g


-- | Perform one 'Transition' and check output and updated state against the current state, producing an Either with the result of the comparison.
--
-- Can be useful for detecting early divergence or lack of convergence etc.
stepConditional :: Monad m =>
                   (a -> s -> s -> Bool)  -- ^ Inputs: Model output, Current state, New state
                -> (a -> s -> s -> l)     -- ^ "
                -> (a -> s -> s -> r)     -- ^ "
                -> L.Handler m message      
                -> Transition message s m a
                -> s                      -- ^ Current state
                -> Gen (PrimState m)
                -> m (Either l r)
stepConditional q fleft fright logf (Transition fm) s g = do 
  (a, s') <- L.runLoggingT (runStateT (fm g) s) logf
  if q a s s' then pure (Left $ fleft a s s') else pure (Right $ fright a s s')




-- * Helpers
  
bracketsUpp :: Show a => a -> String
bracketsUpp p = unwords ["[", map toUpper (show p), "]"]

-- | Render a logging message along with an annotation of its severity.
withSeverity :: (t -> String) -> L.WithSeverity t -> String
withSeverity k (L.WithSeverity u a ) = unwords [bracketsUpp u, k a]