{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

{- | For recording samples and log-probabilities during model execution.
-}

module Trace (
  -- * Sample trace
    STrace
  , FromSTrace(..)
  , updateSTrace
  -- * Log-probability trace
  , LPTrace
  , updateLPTrace) where

import Data.Map (Map)
import Data.Maybe ( fromJust )
import Data.Proxy ( Proxy(..) )
import Effects.Dist ( Addr )
import PrimDist ( ErasedPrimDist(..), PrimVal, PrimDist, logProb )
import Env ( UniqueKey, Assign((:=)), Env(ECons), ObsVar(..), varToStr, nil )
import GHC.TypeLits ( KnownSymbol )
import OpenSum (OpenSum)
import qualified Data.Map as Map
import qualified OpenSum

{- | Sample trace. For every address of a sample/observe operation it stores
     the primitive distribution used there (type-erased as 'ErasedPrimDist')
     paired with the value that was drawn, held as an @'OpenSum' 'PrimVal'@.
-}
type STrace = Map.Map Addr (ErasedPrimDist, OpenSum.OpenSum PrimVal)

-- | Environments of type @'Env' vars@ that can be rebuilt from a sample trace.
class FromSTrace vars where
  -- | Read the values for every variable in @vars@ out of the given trace.
  fromSTrace :: STrace -> Env vars

-- | Base case: an environment with no variables needs nothing from the trace,
--   so the trace is ignored and the empty environment 'nil' is returned.
instance FromSTrace '[] where
  fromSTrace _ = nil

-- | Inductive case for an environment whose head binds variable @x@ to values
--   of type @a@.
--
--   It gathers every value recorded under @x@ in the trace (via
--   'extractSamples') and prepends them, with 'ECons', to the environment
--   built for the remaining variables. The whole trace is passed down
--   unchanged. Each step keeps only the entries whose address tag equals its
--   own variable's name.
instance (UniqueKey x env ~ 'True, KnownSymbol x, Eq a, OpenSum.Member a PrimVal, FromSTrace env)
  => FromSTrace ((x := a) : env) where
  fromSTrace sMap =
    -- The type applications select which variable / value type to extract;
    -- both come from the instance head via ScopedTypeVariables.
    ECons (extractSamples (ObsVar @x, Proxy @a) sMap) (fromSTrace sMap)

-- | Collect every value in the sample trace whose address tag equals the name
--   of the given observable variable, projected to type @a@.
--
--   Values come out in ascending address order, because that is the order
--   'Map.toList' produces. The 'Proxy' only fixes the result type; its value
--   is never inspected.
--
--   Calls 'error', naming the offending address and variable, if a matching
--   entry holds a value that is not of type @a@.
extractSamples ::  forall a x. (Eq a, OpenSum.Member a PrimVal) => (ObsVar x, Proxy a) -> STrace -> [a]
extractSamples (var, _) =
    map project
  . Map.toList
  . Map.filterWithKey (\(tag, _) _ -> tag == name)
  where
    -- Computed once rather than inside the filter predicate.
    name = varToStr var

    -- Project the stored open-sum value to @a@. A mismatch means the same
    -- variable name was recorded at a different type, so report where.
    project :: (Addr, (ErasedPrimDist, OpenSum PrimVal)) -> a
    project (addr, (_, val)) =
      case OpenSum.prj @a val of
        Just v  -> v
        Nothing -> error $ "Trace.extractSamples: value at address " ++ show addr
                        ++ " does not have the type expected for variable "
                        ++ show name

-- | Record a sample in the trace.
--
--   At the given address it stores the primitive distribution (type-erased
--   with 'ErasedPrimDist') together with the sampled value (injected into
--   @'OpenSum' 'PrimVal'@). Any entry already at that address is replaced.
updateSTrace :: (Show x, OpenSum.Member x PrimVal) =>
  -- | address of sample site
     Addr
  -- | primitive distribution at address
  -> PrimDist x
  -- | sampled value
  -> x
  -- | previous sample trace
  -> STrace
  -- | updated sample trace
  -> STrace
updateSTrace addr dist val = Map.insert addr (ErasedPrimDist dist, OpenSum.inj val)

{- | Log-probability trace. For every address of a sample/observe operation it
     stores the log probability computed at that site.
-}
type LPTrace = Map.Map Addr Double

-- | Compute the log probability of a value under a primitive distribution and
--   record it in the trace at the given address. Any entry already at that
--   address is replaced.
--
--   The log probability is evaluated before it is inserted. "Data.Map" is
--   value-lazy, so an unforced entry would keep the distribution and value
--   alive as a thunk until the trace is consumed.
updateLPTrace ::
  -- | address of sample/observe site
     Addr
  -- | primitive distribution at address
  -> PrimDist x
  -- | sampled or observed value
  -> x
  -- | previous log-prob trace
  -> LPTrace
  -- | updated log-prob trace
  -> LPTrace
updateLPTrace addr dist val lpTrace =
  let lp = logProb dist val
  in  lp `seq` Map.insert addr lp lpTrace