{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}

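{- | Simulation: run a model forwards by drawing a value for every @Sample@
     operation, passing every @Observe@ operation its observed value, and
     recording all sampled values in a sample trace.
-}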

module Inference.SIM (
    simulate
  , runSimulate
  , traceSamples
  , handleSamp
  , handleObs) where

import Data.Map (Map)
import Effects.Dist ( Observe(..), Sample(..), Dist )
import Effects.ObsReader ( ObsReader )
import Effects.State ( State, modify, handleState )
import Env ( Env )
import Model ( Model, handleCore )
import OpenSum (OpenSum)
import PrimDist
import Prog ( Member(prj), Prog(..), discharge )
import qualified Data.Map as Map
import qualified OpenSum
import Sampler ( Sampler )
import Trace ( FromSTrace(..), STrace, updateSTrace )
import Unsafe.Coerce (unsafeCoerce)

-- | Top-level wrapper for simulating from a model.
--
-- Applies the model to its input, runs it with 'runSimulate', and converts
-- the resulting sample trace into an output environment via 'fromSTrace'.
simulate
  :: (FromSTrace env, es ~ '[ObsReader env, Dist, State STrace, Observe, Sample])
  => (b -> Model env es a)  -- ^ A model awaiting an input
  -> Env env                -- ^ A model environment
  -> b                      -- ^ Model input
  -> Sampler (a, Env env)   -- ^ Sampler generating: (model output, output environment)
simulate model env x =
  -- 'fmap' over the pair (a, STrace) only touches the second component,
  -- so the model output is passed through untouched.
  fmap fromSTrace <$> runSimulate env (model x)

-- | Handler for simulating once from a probabilistic program.
--
-- The effects of the model are discharged right-to-left by the composition
-- below, leaving only @Sample@, which is interpreted in the 'Sampler' monad.
runSimulate
  :: (es ~ '[ObsReader env, Dist, State STrace, Observe, Sample])
  => Env env                -- ^ Model environment
  -> Model env es a         -- ^ Model
  -> Sampler (a, STrace)    -- ^ Sampler generating: (model output, sample trace)
runSimulate env =
    handleSamp              -- 4. draw every Sample using the Sampler monad
  . handleObs               -- 3. give each Observe its observed value, no side effects
  . handleState Map.empty   -- 2. thread the STrace state, starting from Map.empty
  . traceSamples            -- 1b. record each sampled value into the STrace state
  . handleCore env          -- 1a. discharge ObsReader and Dist using the environment

-- | Trace sampled values for each @Sample@ operation.
--
-- Every @Sample@ operation is re-emitted unchanged; when its continuation is
-- resumed with the drawn value, the address, distribution and value are
-- recorded via 'updateSTrace' before continuing. All other operations are
-- forwarded untouched.
traceSamples :: (Member (State STrace) es, Member Sample es) => Prog es a -> Prog es a
traceSamples (Val x) = return x
traceSamples (Op op k) = case prj op of
  Just (Sample (PrimDistPrf d) α) ->
    -- Record the value *after* it has been drawn by a later handler.
    Op op (\x -> modify (updateSTrace α d x) >> traceSamples (k x))
  Nothing ->
    Op op (traceSamples . k)

-- | Handle @Observe@ operations by simply passing forward their observed
-- value, performing no side-effects.
--
-- The distribution and address carried by each @Observe@ are ignored here.
-- Operations from the remaining effects are re-emitted with the handler
-- applied to their continuation.
handleObs :: Prog (Observe : es) a -> Prog es a
handleObs (Val x) = return x
handleObs (Op op k) = case discharge op of
  Right (Observe _ y _) -> handleObs (k y)
  Left op'              -> Op op' (handleObs . k)

-- | Handle @Sample@ operations by using the @Sampler@ monad to draw from
-- primitive distributions.
--
-- The sample address is not needed to draw a value and is ignored.
handleSamp :: Prog '[Sample] a -> Sampler a
handleSamp (Val x) = return x
handleSamp (Op op k) = case discharge op of
  Right (Sample (PrimDistPrf d) _) -> sample d >>= handleSamp . k
  -- A 'Left' here would carry an @EffectSum '[]@ (an operation from an
  -- empty effect list), which should be impossible to construct.
  _ -> error "Impossible: Nothing cannot occur"