-- |
-- Module      :  Mcmc.Proposal.Bactrian
-- Description :  Bactrian proposals
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Thu Jun 25 15:49:48 2020.
--
-- See https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3845170/.
module Mcmc.Proposal.Bactrian
  ( SpikeParameter,
    slideBactrian,
    scaleBactrian,
  )
where

import Mcmc.Proposal
import Mcmc.Statistics.Types
import Numeric.Log
import Statistics.Distribution
import Statistics.Distribution.Normal
import System.Random.MWC.Distributions
import System.Random.Stateful

-- | Spike parameter \(m\) of a Bactrian kernel.
--
-- The spike parameter shapes the two humps of the kernel. The humps are
-- centered at \(\pm m s\) and each has standard deviation \(\sqrt{1 - m^2} s\),
-- where \(s\) is the standard deviation of the complete kernel. The proposals
-- of this module require \(0 \le m < 1\).
type SpikeParameter = Double

-- | Draw a random value from the Bactrian kernel with spike parameter @m@ and
-- overall standard deviation @s@.
--
-- The kernel is an equal-weight mixture of the normal distributions
-- \(N(-ms, (1-m^2)s^2)\) and \(N(ms, (1-m^2)s^2)\). A sample is obtained by
-- drawing from the right hump and reflecting it onto the left hump with
-- probability one half.
--
-- Assumes \(0 \le m < 1\) and \(s > 0\); otherwise the hump standard deviation
-- is not positive and 'normalDistr' fails.
genBactrian ::
  SpikeParameter ->
  StandardDeviation Double ->
  IOGenM StdGen ->
  IO Double
genBactrian m s g = do
  -- The draw order (normal variate first, then the coin flip) fixes which
  -- random numbers are consumed for what; keep it stable.
  hump <- genContVar (normalDistr (m * s) humpSd) g
  keepSign <- bernoulli 0.5 g
  pure (if keepSign then hump else negate hump)
  where
    -- Standard deviation of a single hump; chosen such that the variance of the
    -- complete mixture, (m s)^2 + (1 - m^2) s^2, equals s^2.
    humpSd = sqrt (1 - m * m) * s

-- | Density of the Bactrian kernel with spike parameter @m@ and overall
-- standard deviation @s@, evaluated at @x@ and returned in log domain.
--
-- The kernel is the equal-weight mixture
-- \(\frac{1}{2} N(x; -ms, (1-m^2)s^2) + \frac{1}{2} N(x; ms, (1-m^2)s^2)\).
-- The factor one half cancels in kernel ratios, but it is included so that the
-- result is a proper density.
--
-- The component densities are computed with 'logDensity' and added in the
-- 'Log' semiring (log-sum-exp). Values far out in the tails therefore do not
-- underflow to zero, which would otherwise turn kernel ratios into 0/0.
logDensityBactrian :: SpikeParameter -> StandardDeviation Double -> Double -> Log Double
logDensityBactrian m s x = 0.5 * (hump (negate mn) + hump mn)
  where
    mn = m * s
    sd = sqrt (1 - m * m) * s
    -- Density of a single hump centered at @mu@, kept in log domain.
    hump mu = Exp (logDensity (normalDistr mu sd) x)

-- | Additive Bactrian proposal function: shift the current value by a draw
-- from the Bactrian kernel.
--
-- The kernel is symmetric, so the kernel ratio and the Jacobian are both one.
-- The second component of the result is always 'Nothing'.
bactrianAdditive ::
  SpikeParameter ->
  StandardDeviation Double ->
  PFunction Double
bactrianAdditive m s x g = toResult <$> genBactrian m s g
  where
    toResult dx = (Propose (x + dx) 1.0 1.0, Nothing)

-- | Validate the parameters of the additive Bactrian proposal, then scale the
-- standard deviation @s@ by the tuning parameter @t@.
--
-- Calls 'error' if the spike parameter lies outside \([0, 1)\) or if the
-- standard deviation is not positive. The tuning parameter is not checked.
bactrianAdditivePFunction ::
  SpikeParameter ->
  StandardDeviation Double ->
  TuningParameter ->
  PFunction Double
bactrianAdditivePFunction m s t =
  if m < 0
    then error "bactrianAdditivePFunction: Spike parameter negative."
    else
      if m >= 1
        then error "bactrianAdditivePFunction: Spike parameter 1.0 or larger."
        else
          if s <= 0
            then error "bactrianAdditivePFunction: Standard deviation 0.0 or smaller."
            else bactrianAdditive m tunedSd
  where
    tunedSd = t * s

-- | Additive symmetric proposal with kernel similar to the silhouette of a
-- Bactrian camel.
--
-- The [Bactrian
-- kernel](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3845170/figure/fig01) is
-- a mixture of two symmetrically arranged normal distributions. The spike
-- parameter \(m \in [0, 1)\) loosely determines the standard deviations of the
-- individual humps while the second parameter \(s > 0\) refers to the
-- standard deviation of the complete Bactrian kernel.
--
-- Concretely, the humps are centered at \(\pm m s\) and each has standard
-- deviation \(\sqrt{1 - m^2} s\); for \(m = 0\) the kernel reduces to a normal
-- distribution with standard deviation \(s\). The tuning parameter passed to
-- the proposal function multiplies \(s\).
--
-- Invalid parameters (\(m\) outside \([0, 1)\), or \(s \le 0\)) result in a
-- call to 'error' when the proposal function is evaluated.
--
-- See https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3845170/.
slideBactrian ::
  SpikeParameter ->
  StandardDeviation Double ->
  PName ->
  PWeight ->
  Tune ->
  Proposal Double
slideBactrian m s =
  createProposal description (bactrianAdditivePFunction m s) PFast (PDimension 1)
  where
    description =
      PDescription $
        "Slide Bactrian; spike: " ++ show m ++ ", sd: " ++ show s

-- | Map the displacement of a multiplicative forward move to the displacement
-- of the corresponding reverse move.
--
-- We have:
-- x  (1+dx ) = x'
-- x' (1+dx') = x.
--
-- Hence, (1+dx) (1+dx') = 1, and
-- dx' = 1/(1+dx) - 1.
--
-- The result is infinite for dx = -1, where the forward move maps to zero.
fInv :: Double -> Double
fInv dx = recip (1 + dx) - 1

-- | Multiplicative Bactrian proposal function: scale the current value by
-- @u = 1 + du@, where @du@ is drawn from the Bactrian kernel.
--
-- The reverse move uses the displacement @du' = fInv du@, so the kernel ratio
-- is \(q(du') / q(du)\). The Jacobian determinant of the map
-- \((x, du) \mapsto (x u, du')\) is \(-1/u\); its absolute value \(1/|u|\) is
-- returned. The second component of the result is always 'Nothing'.
bactrianMult ::
  SpikeParameter ->
  StandardDeviation Double ->
  PFunction Double
bactrianMult m s x g = do
  du <- genBactrian m s g
  let u = 1.0 + du
      qXY = logDensityBactrian m s du
      qYX = logDensityBactrian m s (fInv du)
      -- Use the absolute value: for du < -1 the factor u is negative, and the
      -- logarithm of a negative number would yield NaN.
      jac = Exp (negate (log (abs u)))
  pure (Propose (x * u) (qYX / qXY) jac, Nothing)

-- | Validate the parameters of the multiplicative Bactrian proposal, then scale
-- the standard deviation @s@ by the tuning parameter @t@.
--
-- Calls 'error' if the spike parameter lies outside \([0, 1)\) or if the
-- standard deviation is not positive. The tuning parameter is not checked.
bactrianMultPFunction ::
  SpikeParameter ->
  StandardDeviation Double ->
  TuningParameter ->
  PFunction Double
bactrianMultPFunction m s t =
  if m < 0
    then error "bactrianMultPFunction: Spike parameter negative."
    else
      if m >= 1
        then error "bactrianMultPFunction: Spike parameter 1.0 or larger."
        else
          if s <= 0
            then error "bactrianMultPFunction: Standard deviation 0.0 or smaller."
            else bactrianMult m tunedSd
  where
    tunedSd = t * s

-- | Multiplicative proposal with kernel similar to the silhouette of a Bactrian
-- camel.
--
-- The current value is multiplied by \(1 + d\), where \(d\) is drawn from the
-- Bactrian kernel with spike parameter @m@ and standard deviation @s@ (scaled
-- by the tuning parameter). The asymmetry of this move is accounted for by a
-- kernel ratio and a Jacobian.
--
-- See 'Mcmc.Proposal.Scale.scale', and 'slideBactrian'.
scaleBactrian ::
  SpikeParameter ->
  StandardDeviation Double ->
  PName ->
  PWeight ->
  Tune ->
  Proposal Double
scaleBactrian m s = createProposal desc pFun PFast (PDimension 1)
  where
    pFun = bactrianMultPFunction m s
    desc =
      PDescription $
        concat ["Scale Bactrian; spike: ", show m, ", sd: ", show s]