{-# LANGUAGE BangPatterns #-}

-- |
-- Module      :  Mcmc.Metropolis
-- Description :  Metropolis-Hastings at its best
-- Copyright   :  (c) Dominik Schrempf 2020
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Tue May  5 20:11:30 2020.
module Mcmc.Metropolis
  ( mh,
    mhContinue,
  )
where

import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.State.Strict
import Data.Aeson
import Data.Maybe
import Mcmc.Item
import Mcmc.Mcmc
import Mcmc.Move
import Mcmc.Status
import Mcmc.Tools.Shuffle
import Mcmc.Trace
import Numeric.Log
import System.Random.MWC
import Prelude hiding (cycle)

-- Metropolis-Hastings ratio for moves that come with a density function.
--
-- The first two arguments are the (unnormalized) posterior values of the
-- current state x and of the proposed state y; the last two arguments are the
-- move density evaluated at (x, y) and at (y, x), respectively. All values live
-- in the log domain, so products and quotients are cheap and stable.
mhRatio :: Log Double -> Log Double -> Log Double -> Log Double -> Log Double
mhRatio posteriorX posteriorY densityXY densityYX =
  ((posteriorY * densityYX) / posteriorX) / densityXY
{-# INLINE mhRatio #-}

-- Metropolis-Hastings ratio for symmetric moves, where no move density is
-- needed: the posterior of the proposed state divided by the posterior of the
-- current state.
mhRatioSymmetric :: Log Double -> Log Double -> Log Double
mhRatioSymmetric posteriorCurrent posteriorProposed =
  posteriorProposed / posteriorCurrent
{-# INLINE mhRatioSymmetric #-}

-- Perform one Metropolis-Hastings step with the given move: propose a new
-- state, compute the acceptance ratio, and accept or reject the proposal. In
-- both cases the outcome is pushed to the acceptance record of the status.
mhMove :: Move a -> Mcmc a ()
mhMove m = do
  s <- get
  let Item x pX lX = item s
      simple = mvSimple m
      gen = generator s
  -- Propose a new state.
  !y <- liftIO $ mvSample simple x gen
  -- Evaluate prior and likelihood of the proposal, and the resulting ratio.
  let !pY = priorF s y
      !lY = likelihoodF s y
      -- Moves without a density function use the symmetric ratio.
      !r = case mvDensity simple of
        Nothing -> mhRatioSymmetric (pX * lX) (pY * lY)
        Just q -> mhRatio (pX * lX) (pY * lY) (q x y) (q y x)
      acc = acceptance s
      accept = put $ s {item = Item y pY lY, acceptance = pushA m True acc}
      reject = put $ s {acceptance = pushA m False acc}
  -- A ratio of at least one (log ratio >= 0) is always accepted; a random
  -- number is only drawn when the ratio is smaller than one.
  if ln r >= 0.0
    then accept
    else do
      u <- uniform gen
      if u < exp (ln r) then accept else reject

-- TODO: Split the generator here. See SaveSpec -> mhContinue.

-- Build the list of moves in which each 'Move' of the cycle is repeated
-- according to its weight, and hand it to 'shuffleN' together with the
-- remaining arguments (an 'Int' count and the generator).
getNCycles :: Cycle a -> Int -> GenIO -> IO [[Move a]]
getNCycles c = shuffleN weightedMoves
  where
    !weightedMoves = concatMap replicateByWeight (fromCycle c)
    replicateByWeight mv = replicate (mvWeight mv) mv

-- Run one iteration: perform all given moves in order, append the resulting
-- item to the trace, increment the iteration counter, and execute the
-- monitors.
mhIter :: ToJSON a => [Move a] -> Mcmc a ()
mhIter mvs = do
  forM_ mvs mhMove
  modify $ \s ->
    s
      { trace = pushT (item s) (trace s),
        iteration = succ (iteration s)
      }
  mcmcMonitorExec

-- Run N iterations. The move orders are obtained up front from 'getNCycles'
-- (called with the cycle, N, and the generator of the current status); then
-- one iteration is run per returned list of moves.
mhNIter :: ToJSON a => Int -> Mcmc a ()
mhNIter n = do
  s <- get
  shuffled <- liftIO $ getNCycles (cycle s) n (generator s)
  mapM_ mhIter shuffled

-- Burn in for the given number of iterations, with optional auto tuning.
--
-- Without an auto tuning period, the iterations are simply run. With an auto
-- tuning period t, blocks of t iterations are run, each followed by
-- 'mcmcAutotune'; the last block contains the remaining (at most t)
-- iterations. A period that is zero or negative is an error.
mhBurnInN :: ToJSON a => Int -> Maybe Int -> Mcmc a ()
mhBurnInN b period = case period of
  Nothing -> mhNIter b
  Just t
    | t <= 0 -> error "mhBurnInN: Auto tuning period smaller equal 0."
    | b > t -> do
      mhNIter t
      mcmcAutotune t
      mhBurnInN (b - t) period
    | otherwise -> do
      mhNIter b
      mcmcAutotune b

-- Set up and run the burn in phase for the given number of iterations and the
-- optional auto tuning period.
--
-- A negative number of iterations is an error; zero iterations do nothing.
-- Otherwise, the monitor header is printed, the burn in is run, the cycle is
-- summarized if auto tuning was requested, and finally the acceptance record
-- is reset.
mhBurnIn :: ToJSON a => Int -> Maybe Int -> Mcmc a ()
mhBurnIn b t = case compare b 0 of
  LT -> error "mhBurnIn: Negative number of burn in iterations."
  EQ -> return ()
  GT -> do
    liftIO $ putStrLn ("-- Burn in for " ++ show b ++ " cycles.")
    mcmcMonitorHeader
    mhBurnInN b t
    liftIO $ putStrLn "-- Burn in finished."
    -- Only report the cycle again when auto tuning was active.
    when (isJust t) $ mcmcSummarizeCycle t
    modify $ \s -> s {acceptance = resetA (acceptance s)}

-- Announce the run, print the monitor header, and run the chain for the given
-- number of iterations.
mhRun :: ToJSON a => Int -> Mcmc a ()
mhRun n = do
  liftIO (putStrLn announcement)
  mcmcMonitorHeader
  mhNIter n
  where
    announcement = "-- Run chain for " ++ show n ++ " iterations."

-- Complete Metropolis-Hastings run: initialization, report, cycle summary,
-- burn in (zero iterations if none are specified), main run, and clean up.
mhT :: ToJSON a => Mcmc a ()
mhT = do
  liftIO (putStrLn "-- Start of Metropolis-Hastings sampler.")
  -- The run settings are taken from the status as it is before 'mcmcInit'.
  settings <- get
  mcmcInit
  mcmcReport
  mcmcSummarizeCycle Nothing
  mhBurnIn
    (fromMaybe 0 $ burnInIterations settings)
    (autoTuningPeriod settings)
  mhRun (iterations settings)
  mcmcClose

-- Continue a chain for the given number of additional iterations: increase the
-- total number of iterations stored in the status, initialize, summarize the
-- cycle, run the additional iterations, and clean up.
mhContinueT :: ToJSON a => Int -> Mcmc a ()
mhContinueT dn = do
  liftIO $
    mapM_
      putStrLn
      [ "-- Continue Metropolis-Hastings sampler.",
        "-- Run chain for " ++ show dn ++ " additional iterations."
      ]
  modify $ \s -> s {iterations = iterations s + dn}
  mcmcInit
  mcmcSummarizeCycle Nothing
  mhRun dn
  mcmcClose

-- | Continue a Markov chain for a given number of additional
-- Metropolis-Hastings steps, and return the final status.
--
-- Calls 'error' if the number of additional steps is zero or negative.
mhContinue ::
  ToJSON a =>
  -- | Additional number of Metropolis-Hastings steps.
  Int ->
  -- | Loaded status of the Markov chain.
  Status a ->
  IO (Status a)
mhContinue dn =
  if dn > 0
    then execStateT (mhContinueT dn)
    else error "mhContinue: The number of iterations is zero or negative."

-- | Run a Markov chain from scratch using the Metropolis-Hastings algorithm,
-- and return the final status.
--
-- The given status has to be at iteration zero. Otherwise, a hint to use
-- 'mhContinue' is printed and 'error' is called.
mh ::
  ToJSON a =>
  -- | Initial (or last) status of the Markov chain.
  Status a ->
  IO (Status a)
mh s
  | current == 0 = execStateT mhT s
  | otherwise = do
    putStrLn "To continue a Markov chain run, please use 'mhContinue'."
    error $ "mh: Current iteration " ++ show current ++ " is non-zero."
  where
    current = iteration s