{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Trace (
STrace
, FromSTrace(..)
, updateSTrace
, LPTrace
, updateLPTrace) where
import Data.Map (Map)
import Data.Maybe ( fromJust )
import Data.Proxy ( Proxy(..) )
import Effects.Dist ( Addr )
import PrimDist ( ErasedPrimDist(..), PrimVal, PrimDist, logProb )
import Env ( UniqueKey, Assign((:=)), Env(ECons), ObsVar(..), varToStr, nil )
import GHC.TypeLits ( KnownSymbol )
import OpenSum (OpenSum)
import qualified Data.Map as Map
import qualified OpenSum
-- | Sample trace: maps each address to the (type-erased) primitive
-- distribution recorded there, paired with the recorded value wrapped in an
-- open sum over the primitive value types.
type STrace = Map Addr (ErasedPrimDist, OpenSum PrimVal)
-- | Environments that can be rebuilt from a sample trace.
class FromSTrace env where
  -- | Build an 'Env' of shape @env@ from the entries of a sample trace.
  fromSTrace :: STrace -> Env env
-- | Base case: the empty environment is 'nil', whatever the trace contains.
instance FromSTrace '[] where
  fromSTrace _ = nil
-- | Inductive case: the head variable @x@ receives every value that
-- 'extractSamples' finds for it (projected at type @a@), and the remaining
-- environment is rebuilt from the same, unmodified trace.
instance (UniqueKey x env ~ 'True, KnownSymbol x, Eq a, OpenSum.Member a PrimVal, FromSTrace env) => FromSTrace ((x := a) : env) where
  fromSTrace sMap =
    ECons
      -- @x and @a refer to the instance head's type variables
      -- (ScopedTypeVariables + TypeApplications).
      (extractSamples (ObsVar @x, Proxy @a) sMap)
      (fromSTrace sMap)
-- | Collect every value recorded in a sample trace for one observable
-- variable.
--
-- A trace entry is selected when the tag (first component) of its address
-- equals @varToStr x@; the index component of the address is ignored. The
-- selected values are returned in the ascending-address order produced by
-- 'Map.toList'.
--
-- Each selected value is projected out of its 'OpenSum' at type @a@. The
-- 'Proxy' argument only fixes @a@ for callers and is not inspected. If a
-- selected value does not have type @a@ the function calls 'error' with a
-- message naming the offending address and variable (previously this was a
-- bare 'Data.Maybe.fromJust' failure with no context).
extractSamples :: forall a x. (Eq a, OpenSum.Member a PrimVal) => (ObsVar x, Proxy a) -> STrace -> [a]
extractSamples (x, _) =
    map project
  . Map.toList
  . Map.filterWithKey (\(tag, _) _ -> tag == name)
  where
    -- String form of the observable variable, compared against address tags.
    name = varToStr x

    -- Project one trace entry at type @a@, failing loudly on a mismatch.
    project (addr, (_, val)) =
      case OpenSum.prj @a val of
        Just v  -> v
        Nothing ->
          error $
            "Trace.extractSamples: value at address " ++ show addr
              ++ " does not have the type expected for variable " ++ show name
-- | Record a sampled value in a sample trace.
--
-- Inserts, at the given address, the distribution (wrapped in
-- 'ErasedPrimDist') together with the value injected into the 'OpenSum' of
-- primitive values. Any entry already stored at that address is replaced.
updateSTrace :: (Show x, OpenSum.Member x PrimVal) =>
     Addr        -- ^ address of the sample site
  -> PrimDist x  -- ^ distribution at that address
  -> x           -- ^ value to record
  -> STrace      -- ^ trace to update
  -> STrace
updateSTrace addr dist val =
  Map.insert addr (ErasedPrimDist dist, OpenSum.inj val)
-- | Log-probability trace: maps each address to a 'Double'
-- ('updateLPTrace' stores the result of 'logProb' there).
type LPTrace = Map Addr Double
-- | Record a log-probability in a log-probability trace.
--
-- Inserts @logProb d x@ at the given address, replacing any entry already
-- stored there.
updateLPTrace ::
     Addr        -- ^ address of the sample site
  -> PrimDist x  -- ^ distribution at that address
  -> x           -- ^ value whose log-probability is recorded
  -> LPTrace     -- ^ trace to update
  -> LPTrace
updateLPTrace addr dist val =
  Map.insert addr (logProb dist val)