-- |
-- Module      :  Mcmc.Proposal.Bactrian
-- Description :  Bactrian proposals
-- Copyright   :  (c) Dominik Schrempf, 2021
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Thu Jun 25 15:49:48 2020.
--
-- See https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3845170/.
module Mcmc.Proposal.Bactrian
  ( SpikeParameter,
    slideBactrian,
    scaleBactrian,
  )
where

import Mcmc.Proposal
import Mcmc.Statistics.Types
import Numeric.Log
import Statistics.Distribution
import Statistics.Distribution.Normal
import System.Random.MWC
import System.Random.MWC.Distributions

-- | Type synonym indicating the spike parameter \(m\) of the Bactrian kernel.
--
-- The two humps of the kernel are centered at \(\pm m s\) and each has standard
-- deviation \(\sqrt{1-m^2}\, s\), where \(s\) is the standard deviation of the
-- complete kernel. The Bactrian proposals in this module require
-- \(0 \le m < 1\); other values are rejected with 'error'.
type SpikeParameter = Double

-- | Draw one sample from the Bactrian kernel.
--
-- The kernel is an equal-weight mixture of two normal distributions with means
-- @-m*s@ and @+m*s@ and common standard deviation @sqrt (1 - m^2) * s@. Its
-- total variance is @(m*s)^2 + (1 - m^2)*s^2 = s^2@. A sample is generated by
-- drawing from the right hump and flipping its sign with probability 0.5.
--
-- Expects @0 <= m < 1@ and @s > 0@; the public entry points validate this.
genBactrian ::
  SpikeParameter ->
  StandardDeviation Double ->
  GenIO ->
  IO Double
genBactrian m s g = do
  let mn = m * s
      sd = sqrt (1 - m * m) * s
      d = normalDistr mn sd
  x <- genContVar d g
  -- Fair coin decides which hump the sample belongs to.
  b <- bernoulli 0.5 g
  pure $ if b then x else negate x

-- | Density of the Bactrian kernel at @x@, returned in the log domain.
--
-- The kernel is the equal-weight mixture
--
-- \[
--   q(x) = \tfrac{1}{2} N(x; -m s, \sigma^2) + \tfrac{1}{2} N(x; m s, \sigma^2),
--   \qquad \sigma = \sqrt{1 - m^2}\, s.
-- \]
--
-- Each component is evaluated with 'logDensity' and the two are summed inside
-- 'Log' (a log-sum-exp), so arguments far in the tails keep a finite log
-- density instead of underflowing to @log 0@.
logDensityBactrian :: SpikeParameter -> StandardDeviation Double -> Double -> Log Double
logDensityBactrian m s x = half * (kernel1 + kernel2)
  where
    -- Mixture weight of each hump.
    half = Exp (log 0.5)
    mn = m * s
    sd = sqrt (1 - m * m) * s
    dist1 = normalDistr (-mn) sd
    dist2 = normalDistr mn sd
    kernel1 = Exp (logDensity dist1 x)
    kernel2 = Exp (logDensity dist2 x)

-- | Additive Bactrian move: @x' = x + dx@ with @dx@ drawn by 'genBactrian'.
--
-- The kernel is symmetric about zero, so forward and backward proposal
-- densities coincide. The second and third components of the result (the
-- proposal ratio and the Jacobian, mirroring 'bactrianMult') are therefore 1.
bactrianAdditive ::
  SpikeParameter ->
  StandardDeviation Double ->
  ProposalSimple Double
bactrianAdditive m s x g = do
  dx <- genBactrian m s g
  pure (x + dx, 1.0, 1.0)

-- | Validate the Bactrian parameters and build the additive move.
--
-- The tuning parameter @t@ multiplies the standard deviation, so the move
-- actually uses a kernel with standard deviation @t * s@.
--
-- Calls 'error' if @m < 0@, @m >= 1@, or @s <= 0@.
bactrianAdditiveSimple ::
  SpikeParameter ->
  StandardDeviation Double ->
  TuningParameter ->
  ProposalSimple Double
bactrianAdditiveSimple m s t
  | m < 0 = error "bactrianAdditiveSimple: Spike parameter negative."
  | m >= 1 = error "bactrianAdditiveSimple: Spike parameter 1.0 or larger."
  | s <= 0 = error "bactrianAdditiveSimple: Standard deviation 0.0 or smaller."
  | otherwise = bactrianAdditive m (t * s)

-- | Additive symmetric proposal with kernel similar to the silhouette of a
-- Bactrian camel.
--
-- The [Bactrian
-- kernel](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3845170/figure/fig01) is
-- a mixture of two symmetrically arranged normal distributions. The spike
-- parameter \(m \in [0, 1)\) controls the shape of the humps: they are centered
-- at \(\pm m s\) and each has standard deviation \(\sqrt{1 - m^2}\, s\). The
-- second parameter \(s > 0\) is the standard deviation of the complete
-- Bactrian kernel (before it is multiplied by the tuning parameter).
--
-- Invalid parameters (\(m < 0\), \(m \ge 1\), or \(s \le 0\)) cause an 'error'
-- when the proposal function is evaluated.
--
-- See https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3845170/.
slideBactrian ::
  SpikeParameter ->
  StandardDeviation Double ->
  PName ->
  PWeight ->
  Tune ->
  Proposal Double
slideBactrian m s =
  createProposal description (bactrianAdditiveSimple m s) (PDimension 1)
  where
    description =
      PDescription $ "Slide Bactrian; spike: " ++ show m ++ ", sd: " ++ show s

-- | Increment of the reverse multiplicative move.
--
-- A forward move and its reverse satisfy
--
-- > x  (1 + dx ) = x'
-- > x' (1 + dx') = x
--
-- Hence @(1 + dx) (1 + dx') = 1@, that is, @dx' = 1/(1+dx) - 1@.
-- (The previous formula @1/(1-dx) - 1@ was wrong and made the Hastings ratio
-- of the multiplicative Bactrian move incorrect.) Not finite at @dx = -1@.
fInv :: Double -> Double
fInv dx = recip (1 + dx) - 1

-- | Multiplicative Bactrian move: @x' = x * u@ with @u = 1 + du@ and @du@ drawn
-- by 'genBactrian'.
--
-- Returns the proposed value, the proposal ratio @q(du') / q(du)@ where @du'@
-- is the increment of the reverse move (see 'fInv'), and the Jacobian @1/u@ of
-- the map @(x, du) -> (x', du')@.
--
-- NOTE(review): if @du <= -1@ then @u <= 0@. The proposed value then changes
-- sign (or becomes 0), and the Jacobian is infinite or NaN. Confirm whether
-- such moves should be rejected explicitly for strictly positive parameters.
bactrianMult ::
  SpikeParameter ->
  StandardDeviation Double ->
  ProposalSimple Double
bactrianMult m s x g = do
  du <- genBactrian m s g
  let u = 1.0 + du
      -- Forward and reverse kernel densities.
      qXY = logDensityBactrian m s du
      qYX = logDensityBactrian m s (fInv du)
      jac = Exp (log (recip u))
  pure (x * u, qYX / qXY, jac)

-- | Validate the Bactrian parameters and build the multiplicative move.
--
-- The tuning parameter @t@ multiplies the standard deviation, so the move
-- actually uses a kernel with standard deviation @t * s@.
--
-- Calls 'error' if @m < 0@, @m >= 1@, or @s <= 0@.
bactrianMultSimple ::
  SpikeParameter ->
  StandardDeviation Double ->
  TuningParameter ->
  ProposalSimple Double
bactrianMultSimple m s t
  | m < 0 = error "bactrianMultSimple: Spike parameter negative."
  | m >= 1 = error "bactrianMultSimple: Spike parameter 1.0 or larger."
  | s <= 0 = error "bactrianMultSimple: Standard deviation 0.0 or smaller."
  | otherwise = bactrianMult m (t * s)

-- | Multiplicative proposal with kernel similar to the silhouette of a Bactrian
-- camel.
--
-- The proposed value is @x * (1 + du)@, where @du@ is drawn from the Bactrian
-- kernel with spike parameter \(m \in [0, 1)\) and standard deviation
-- \(s > 0\) (multiplied by the tuning parameter). Invalid parameters cause an
-- 'error' when the proposal function is evaluated.
--
-- See 'Mcmc.Proposal.Scale.scale', and 'slideBactrian'.
scaleBactrian ::
  SpikeParameter ->
  StandardDeviation Double ->
  PName ->
  PWeight ->
  Tune ->
  Proposal Double
scaleBactrian m s =
  createProposal description (bactrianMultSimple m s) (PDimension 1)
  where
    description =
      PDescription $ "Scale Bactrian; spike: " ++ show m ++ ", sd: " ++ show s