{-# LANGUAGE RankNTypes #-}

-- Technically, only a Getter is needed when calculating the density of the move
-- ('densCont' and similar functions). I tried splitting the lens into a getter
-- and a setter. However, speed improvements were marginal, and sometimes not
-- even measurable. Using a 'Lens'' is just easier and has no real drawbacks.

-- |
-- Module      :  Mcmc.Move.Generic
-- Description :  Generic interface to create moves
-- Copyright   :  (c) Dominik Schrempf 2020
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Thu May 14 20:26:27 2020.
module Mcmc.Move.Generic
  ( moveGenericContinuous,
    moveSymmetricGenericContinuous,
    moveGenericDiscrete,
    moveSymmetricGenericDiscrete,
  )
where

import Lens.Micro
import Mcmc.Move
import Numeric.Log
import Statistics.Distribution
import System.Random.MWC

-- | Propose a new state for a continuous parameter.
--
-- An increment @dx@ is sampled from the distribution @d@ using the generator
-- @g@; the parameter focused by the lens is then replaced by @f v dx@, where
-- @v@ is its current value. All other parts of the state are left untouched.
jumpCont ::
  (ContDistr d, ContGen d) =>
  Lens' a Double ->
  d ->
  (Double -> Double -> Double) ->
  a ->
  GenIO ->
  IO a
jumpCont l d f x g = step <$> genContVar d g
  where
    -- Combine the current parameter value with the sampled increment.
    step dx = over l (`f` dx) x
{-# INLINEABLE jumpCont #-}

-- | Density of the increment that turns state @x@ into state @y@.
--
-- The increment is recovered as @fInv (y ^. l) (x ^. l)@ and evaluated under
-- the log density of @d@.
--
-- NOTE(review): only the density of the increment is returned; no Jacobian
-- term for the change of variables from the increment to the new value is
-- included. This is exact for translation-type operators such as (+), but
-- for e.g. multiplicative moves the Hastings ratio may need a correction.
-- Confirm against the callers.
densCont ::
  (ContDistr d, ContGen d) =>
  Lens' a Double ->
  d ->
  (Double -> Double -> Double) ->
  a ->
  a ->
  Log Double
densCont l d fInv x y = Exp (logDensity d increment)
  where
    -- The increment that was applied to the old value to obtain the new one.
    increment = fInv (y ^. l) (x ^. l)
{-# INLINEABLE densCont #-}

-- | Generic function to create moves for continuous parameters ('Double').
--
-- A proposal samples an increment @dx@ from the distribution and sets the
-- focused parameter from @x@ to @f x dx@. The density of a transition from
-- @x@ to @y@ is the density of the increment recovered as @fInv y x@.
--
-- NOTE(review): no Jacobian correction is applied to that density. This is
-- fine for operators such as (+), but may bias non-additive moves such as
-- multiplication; confirm before using this for scaling moves.
moveGenericContinuous ::
  (ContDistr d, ContGen d) =>
  -- | Instruction about which parameter to change.
  Lens' a Double ->
  -- | Probability distribution of the increment.
  d ->
  -- | Forward operator, e.g. (+), so that @f x dx = y@.
  (Double -> Double -> Double) ->
  -- | Inverse operator, e.g. (-), so that @fInv y x = dx@.
  (Double -> Double -> Double) ->
  MoveSimple a
moveGenericContinuous lns distr fwd bwd = MoveSimple propose (Just density)
  where
    propose = jumpCont lns distr fwd
    density = densCont lns distr bwd

-- | Generic function to create symmetric moves for continuous parameters
-- ('Double').
--
-- The move samples an increment @dx@ and sets the focused parameter from @x@
-- to @f x dx@. No proposal density is attached (the density field is
-- 'Nothing'). The caller must ensure the proposal really is symmetric, for
-- example a distribution symmetric around zero combined with (+).
--
-- NOTE(review): presumably "Mcmc.Move" treats a missing density as a
-- symmetric proposal; confirm.
moveSymmetricGenericContinuous ::
  (ContDistr d, ContGen d) =>
  -- | Instruction about which parameter to change.
  Lens' a Double ->
  -- | Probability distribution of the increment.
  d ->
  -- | Forward operator, e.g. (+), so that @f x dx = y@.
  (Double -> Double -> Double) ->
  MoveSimple a
moveSymmetricGenericContinuous lns distr fwd = MoveSimple propose Nothing
  where
    propose = jumpCont lns distr fwd

-- | Propose a new state for a discrete parameter.
--
-- An integer increment @dx@ is sampled from the distribution @d@ using the
-- generator @g@; the parameter focused by the lens is then replaced by
-- @f v dx@, where @v@ is its current value.
jumpDiscrete ::
  (DiscreteDistr d, DiscreteGen d) =>
  Lens' a Int ->
  d ->
  (Int -> Int -> Int) ->
  a ->
  GenIO ->
  IO a
jumpDiscrete l d f x g = fmap applyIncrement (genDiscreteVar d g)
  where
    -- Update only the focused field; the rest of the state is unchanged.
    applyIncrement dx = x & l %~ (`f` dx)
{-# INLINEABLE jumpDiscrete #-}

-- | Probability of the increment that turns state @x@ into state @y@.
--
-- The increment is recovered as @fInv (y ^. l) (x ^. l)@ and evaluated under
-- the log probability mass of @d@.
densDiscrete ::
  (DiscreteDistr d, DiscreteGen d) =>
  Lens' a Int ->
  d ->
  (Int -> Int -> Int) ->
  a ->
  a ->
  Log Double
densDiscrete l d fInv x y = Exp (logProbability d increment)
  where
    -- The increment that was applied to the old value to obtain the new one.
    increment = fInv (y ^. l) (x ^. l)
{-# INLINEABLE densDiscrete #-}

-- | Generic function to create moves for discrete parameters ('Int').
--
-- A proposal samples an integer increment @dx@ from the distribution and
-- sets the focused parameter from @x@ to @f x dx@. The probability of a
-- transition from @x@ to @y@ is the probability of the increment recovered
-- as @fInv y x@.
moveGenericDiscrete ::
  (DiscreteDistr d, DiscreteGen d) =>
  -- | Instruction about which parameter to change.
  Lens' a Int ->
  -- | Probability distribution of the increment.
  d ->
  -- | Forward operator, e.g. (+), so that @f x dx = y@.
  (Int -> Int -> Int) ->
  -- | Inverse operator, e.g. (-), so that @fInv y x = dx@.
  (Int -> Int -> Int) ->
  MoveSimple a
moveGenericDiscrete lns distr fwd bwd = MoveSimple propose (Just density)
  where
    propose = jumpDiscrete lns distr fwd
    density = densDiscrete lns distr bwd

-- | Generic function to create symmetric moves for discrete parameters
-- ('Int').
--
-- The move samples an integer increment @dx@ and sets the focused parameter
-- from @x@ to @f x dx@. No proposal density is attached (the density field
-- is 'Nothing'). The caller must ensure the proposal really is symmetric.
--
-- NOTE(review): presumably "Mcmc.Move" treats a missing density as a
-- symmetric proposal; confirm.
moveSymmetricGenericDiscrete ::
  (DiscreteDistr d, DiscreteGen d) =>
  -- | Instruction about which parameter to change.
  Lens' a Int ->
  -- | Probability distribution of the increment.
  d ->
  -- | Forward operator, e.g. (+), so that @f x dx = y@.
  (Int -> Int -> Int) ->
  MoveSimple a
moveSymmetricGenericDiscrete lns distr fwd = MoveSimple propose Nothing
  where
    propose = jumpDiscrete lns distr fwd