module Mcmc.Proposal.Bactrian
( slideBactrian,
scaleBactrian,
)
where
import Mcmc.Proposal
import Numeric.Log
import Statistics.Distribution
import Statistics.Distribution.Normal
import System.Random.MWC
import System.Random.MWC.Distributions
-- | Draw one variate from the Bactrian distribution with spike parameter @m@
-- and standard deviation @s@.
--
-- A value is drawn from the normal component centred at @m * s@ with standard
-- deviation @sqrt (1 - m^2) * s@. Its sign is then flipped with probability
-- 0.5, which mirrors the draw onto the component centred at @-(m * s)@.
genBactrian ::
  -- | Spike parameter @m@.
  Double ->
  -- | Standard deviation @s@.
  Double ->
  GenIO ->
  IO Double
genBactrian m s g = do
  -- The normal draw happens before the coin flip, so the order in which the
  -- generator is consumed is fixed.
  x <- genContVar component g
  keepSign <- bernoulli 0.5 g
  pure (if keepSign then x else negate x)
  where
    component = normalDistr (m * s) (sqrt (1 - m * m) * s)
-- | Density of the Bactrian distribution with spike parameter @m@ and
-- standard deviation @s@, evaluated at @x@ and returned in the log domain.
--
-- The Bactrian density is the equal-weight mixture
--
-- > 0.5 * N(x; -(m * s), sd) + 0.5 * N(x; m * s, sd),  sd = sqrt (1 - m^2) * s
--
-- Each component is evaluated with 'logDensity' and the two are added in the
-- 'Log' domain (log-sum-exp). Summing plain 'density' values and taking the
-- logarithm afterwards would underflow to 0, i.e. @log 0 = -Infinity@, far in
-- the tails.
logDensityBactrian :: Double -> Double -> Double -> Log Double
logDensityBactrian m s x = 0.5 * (component (negate mn) + component mn)
  where
    mn = m * s
    sd = sqrt (1 - m * m) * s
    -- Log-domain density at x of the normal component centred at mu.
    -- Literals such as 0.5 above denote linear-space values in 'Log Double'.
    component mu = Exp (logDensity (normalDistr mu sd) x)
-- | Additive (sliding) Bactrian move: shift @x@ by a variate drawn with
-- 'genBactrian' using spike @m@ and standard deviation @s@.
--
-- 'genBactrian' flips the sign of its draw with probability 0.5, so the
-- increment is symmetric around zero. Both the proposal ratio and the
-- Jacobian are therefore 1.
bactrianAdditive ::
  Double ->
  Double ->
  Double ->
  GenIO ->
  IO (Double, Log Double, Log Double)
bactrianAdditive m s x g = shift <$> genBactrian m s g
  where
    shift dx = (x + dx, 1.0, 1.0)
-- | Validate the Bactrian parameters and build the additive proposal kernel.
--
-- The spike parameter @m@ must lie in @[0, 1)@ and the standard deviation @s@
-- must be positive; otherwise 'error' is called. The third argument @t@ is
-- multiplied into @s@ before the move is built.
bactrianAdditiveSimple ::
  Double ->
  Double ->
  Double ->
  ProposalSimple Double
bactrianAdditiveSimple m s t =
  maybe (bactrianAdditive m (t * s)) error firstViolation
  where
    -- Checks are tried top to bottom; the first one that holds decides the
    -- error message.
    firstViolation =
      lookup
        True
        [ (m < 0, "bactrianAdditiveSimple: Spike parameter negative."),
          (m >= 1, "bactrianAdditiveSimple: Spike parameter 1.0 or larger."),
          (s <= 0, "bactrianAdditiveSimple: Standard deviation 0.0 or smaller.")
        ]
-- | Additive proposal with a Bactrian kernel.
--
-- The current value is shifted by a random variate from a Bactrian
-- distribution with spike parameter @m@ and standard deviation @s@.
-- Parameters outside the valid range cause 'bactrianAdditiveSimple' to call
-- 'error'.
slideBactrian ::
  -- | Spike parameter; must lie in @[0, 1)@.
  Double ->
  -- | Standard deviation; must be positive.
  Double ->
  PName ->
  PWeight ->
  Tune ->
  Proposal Double
slideBactrian m s = createProposal description (bactrianAdditiveSimple m s)
  where
    description =
      PDescription $
        concat ["Slide Bactrian; spike: ", show m, ", sd: ", show s]
-- | Map the increment @du@ of a multiplicative Bactrian move to the increment
-- of the move that reverses it.
--
-- A forward move sends @x@ to @y = x * (1 + du)@. The reverse move must send
-- @y@ back to @x@, i.e. @x = y * (1 + du')@. Hence @(1 + du) * (1 + du') = 1@
-- and @du' = 1 / (1 + du) - 1@.
fInv :: Double -> Double
fInv du = recip (1 + du) - 1
-- | Multiplicative (scaling) Bactrian move: scale @x@ by @u = 1 + du@, where
-- @du@ is drawn with 'genBactrian' using spike @m@ and standard deviation @s@.
--
-- Returns a triple:
--
-- * the proposed value @x * u@;
-- * the proposal ratio, i.e. the Bactrian density of the reverse increment
--   divided by that of the forward increment;
-- * the Jacobian of the transformation.
bactrianMult ::
  Double ->
  Double ->
  Double ->
  GenIO ->
  IO (Double, Log Double, Log Double)
bactrianMult m s x g = do
  du <- genBactrian m s g
  let u = 1 + du
      -- Densities of the forward increment and of the increment undoing it.
      qXY = logDensityBactrian m s du
      qYX = logDensityBactrian m s (fInv du)
      -- The map (x, du) -> (x * u, du'), with (1 + du) * (1 + du') = 1, has
      -- determinant -1/u. Green's acceptance ratio uses its absolute value.
      -- 'abs' also keeps the Jacobian finite instead of NaN when u < 0.
      jac = Exp (negate (log (abs u)))
  return (x * u, qYX / qXY, jac)
-- | Validate the Bactrian parameters and build the multiplicative proposal
-- kernel.
--
-- The spike parameter @m@ must lie in @[0, 1)@ and the standard deviation @s@
-- must be positive; otherwise 'error' is called. The third argument @t@ is
-- multiplied into @s@ before the move is built.
bactrianMultSimple :: Double -> Double -> Double -> ProposalSimple Double
bactrianMultSimple m s t =
  maybe (bactrianMult m (t * s)) error firstViolation
  where
    -- Checks are tried top to bottom; the first one that holds decides the
    -- error message.
    firstViolation =
      lookup
        True
        [ (m < 0, "bactrianMultSimple: Spike parameter negative."),
          (m >= 1, "bactrianMultSimple: Spike parameter 1.0 or larger."),
          (s <= 0, "bactrianMultSimple: Standard deviation 0.0 or smaller.")
        ]
-- | Multiplicative proposal with a Bactrian kernel.
--
-- The current value is multiplied by @1 + du@, where @du@ is a random variate
-- from a Bactrian distribution with spike parameter @m@ and standard
-- deviation @s@. Parameters outside the valid range cause
-- 'bactrianMultSimple' to call 'error'.
scaleBactrian ::
  -- | Spike parameter; must lie in @[0, 1)@.
  Double ->
  -- | Standard deviation; must be positive.
  Double ->
  PName ->
  PWeight ->
  Tune ->
  Proposal Double
scaleBactrian m s = createProposal description (bactrianMultSimple m s)
  where
    description =
      PDescription $
        concat ["Scale Bactrian; spike: ", show m, ", sd: ", show s]