{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
module Inference.LW (
lw
, runLW
, handleObs) where
import qualified Data.Map as Map
import Env ( Env )
import Effects.ObsReader ( ObsReader )
import Control.Monad ( replicateM )
import Effects.Dist ( Dist, Observe(..), Sample )
import Prog ( discharge, Member, Prog(..) )
import PrimDist ( logProb )
import Model ( handleCore, Model )
import Sampler ( Sampler )
import Effects.State ( modify, handleState, State )
import Trace ( FromSTrace(..), STrace )
import Inference.SIM (traceSamples, handleSamp)
-- | Run the model @n@ times under 'runLW' and collect one weighted
--   environment per run.
--
--   The model is built by applying the supplied function to the first
--   component of the input pair; the second component is the environment
--   handed to 'runLW'. For every run, the recorded sample trace is converted
--   to an output environment with 'fromSTrace' and paired with the weight
--   that 'runLW' returned for that run. The model's own return value is
--   discarded.
lw :: (FromSTrace env, es ~ '[ObsReader env, Dist, State STrace, Observe, Sample])
   => Int                          -- ^ number of runs
   -> (b -> Model env es a)        -- ^ model, awaiting its input
   -> (b, Env env)                 -- ^ model input and environment
   -> Sampler [(Env env, Double)]  -- ^ one (environment, weight) pair per run
lw n model (x, env) = do
  runs <- replicateM n (runLW env (model x))
  -- Keep only the trace (as an environment) and the weight of each run.
  pure [ (fromSTrace strace, weight) | ((_, strace), weight) <- runs ]
-- | Execute a model once, returning its result, the sample trace that was
--   accumulated in the 'State' effect, and the weight computed by 'handleObs'.
--
--   The handlers are composed right-to-left, so they run in the order listed
--   in the comments below (bottom line first).
runLW :: es ~ '[ObsReader env, Dist, State STrace, Observe, Sample]
      => Env env                        -- ^ environment passed to 'handleCore'
      -> Model env es a                 -- ^ model to execute
      -> Sampler ((a, STrace), Double)  -- ^ ((result, trace), weight)
runLW env =
    handleSamp               -- 5. interpret the remaining Sample effect in Sampler
  . handleObs 0              -- 4. discharge Observe, starting from log-weight 0
  . handleState Map.empty    -- 3. discharge State, starting from an empty trace
  . traceSamples             -- 2. NOTE(review): presumably records samples into the State trace - confirm in Inference.SIM
  . handleCore env           -- 1. discharge ObsReader and Dist using the environment
-- | Discharge the 'Observe' effect while accumulating a log-weight.
--
--   The first argument is the running sum of log-probabilities. Each
--   'Observe' operation adds @'logProb' d y@ for its distribution @d@ and
--   observed value @y@, and the computation continues with @y@. Operations of
--   any other effect are passed through unchanged. When the program finishes,
--   its result is paired with @exp@ of the accumulated sum, i.e. the returned
--   weight is no longer in log-space.
handleObs :: Member Sample es
  => Double                  -- ^ accumulated log-weight so far
  -> Prog (Observe : es) a   -- ^ program that may still observe
  -> Prog es (a, Double)     -- ^ result paired with the exponentiated weight
handleObs logp (Val x) = return (x, exp logp)
handleObs logp (Op u k) = case discharge u of
  -- An observation: add its log-probability and resume with the observed
  -- value. The address carried by the operation is not used here.
  Right (Observe d y _) -> handleObs (logp + logProb d y) (k y)
  -- Some other effect: re-emit it and keep handling its continuation.
  Left op'              -> Op op' (handleObs logp . k)