{-# LANGUAGE DataKinds
           , GADTs
           , FlexibleContexts
           , KindSignatures
           , PolyKinds
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2016.04.21
-- |
-- Module      :  Language.Hakaru.Observe
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  ppaml@indiana.edu
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- A simpler version of the work done in 'Language.Hakaru.Disintegrate'
--
-- In principle, this module's functionality is entirely subsumed
-- by the work done in Language.Hakaru.Disintegrate, so we can hope
-- to define observe in terms of disintegrate. This is still useful
-- as a guide to those that want something more in line with what other
-- probabilistic programming systems support.
----------------------------------------------------------------
module Language.Hakaru.Observe where

import           Language.Hakaru.Syntax.AST
import           Language.Hakaru.Syntax.ABT
import           Language.Hakaru.Types.DataKind
import           Language.Hakaru.Types.Sing
import qualified Language.Hakaru.Syntax.Prelude as P
import           Language.Hakaru.Syntax.TypeOf

-- | Condition the measure term @m@ on the value term @a@.
--
-- This is a thin entry point: both arguments are wrapped in 'LC_'
-- and handed to 'observeAST', which does the actual traversal.
observe
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> abt '[] a
    -> abt '[] ('HMeasure a)
observe m a = observeAST (LC_ m) (LC_ a)

-- TODO: move this to ABT.hs
-- | Return @x@ with its 'varID' replaced by the larger of
-- @nextFree m@ and @nextBind m@; the hint and type of @x@ are kept.
--
-- NOTE(review): presumably this ID is unused by every free and bound
-- variable of @m@ -- confirm against the definitions of 'nextFree'
-- and 'nextBind' in ABT.hs.
freshenVarRe
    :: ABT syn abt
    => Variable (a :: k)
    -> abt '[] (b :: k)
    -> Variable a
freshenVarRe x m = x {varID = nextFree m `max` nextBind m}

-- | Like 'freshenVarRe', but the new 'varID' is the larger of the IDs
-- that 'freshenVarRe' picks for @m@ and for @n@ individually.
--
-- 'observeAST' re-binds a variable around a term that is built from
-- /both/ the measure being observed and the observed value, so the
-- binder has to avoid the variables of both terms; being fresh for
-- only one of them would let the binder capture a free variable of
-- the other.
freshenVarRe2
    :: ABT syn abt
    => Variable (a :: k)
    -> abt '[] (b :: k)
    -> abt '[] (c :: k)
    -> Variable a
freshenVarRe2 x m n =
    x {varID = varID (freshenVarRe x m) `max` varID (freshenVarRe x n)}

-- | Worker for 'observe': push the observation of @a@ down through the
-- syntax of the measure @m@.
--
-- * @Let_@ and @MBind@: the first argument is left untouched and the
--   observation is pushed into the body under the binder.
-- * @Plate@: the body is observed at @a@ indexed by the plate's
--   (renamed) index variable.
-- * @MeasureOp_@: delegated to 'observeMeasureOp'.
-- * A bare variable, or any other term, calls 'error'.
--
-- In every binding case the bound variable is first renamed to one
-- chosen by 'freshenVarRe2' with respect to both @m@ and @a@, because
-- the new body mentions @a@ and must not capture its free variables.
observeAST
    :: (ABT Term abt)
    => LC_ abt ('HMeasure a)
    -> LC_ abt a
    -> abt '[] ('HMeasure a)
observeAST (LC_ m) (LC_ a) =
    caseVarSyn m observeVar $ \ast ->
        case ast of
        -- TODO: Add a name supply
        Let_ :$ e1 :* e2 :* End ->
            caseBind e2 $ \x e2' ->
                let x'   = freshenVarRe2 x m a
                    e2'' = rename x x' e2'
                in syn (Let_ :$ e1 :* bind x' (observe e2'' a) :* End)
        --Dirac :$ e :* End -> P.if_ (e P.== a) (P.dirac a) P.reject
        -- TODO: Add a name supply
        MBind :$ e1 :* e2 :* End ->
            caseBind e2 $ \x e2' ->
                let x'   = freshenVarRe2 x m a
                    e2'' = rename x x' e2'
                in syn (MBind :$ e1 :* bind x' (observe e2'' a) :* End)
        Plate :$ e1 :* e2 :* End ->
            caseBind e2 $ \x e2' ->
                let x'   = freshenVarRe2 x m a
                    e2'' = rename x x' e2'
                    -- The observed array @a@ indexed at the plate's
                    -- index variable; the element type is read off the
                    -- type of the plate body.
                    a'   = syn (ArrayOp_ (Index (sUnMeasure $ typeOf e2''))
                                    :$ a :* var x' :* End)
                in syn (Plate :$ e1 :* bind x' (observe e2'' a') :* End)
        MeasureOp_ op :$ es -> observeMeasureOp op es a
        _ -> error "observe can only be applied to measure primitives"

-- | Handler used by 'observeAST' when the measure is a bare variable.
--
-- This function can't inspect a variable due to the calls to subst
-- that happen in Let_ and Bind_, so it always calls 'error'.
observeVar :: a
observeVar = error "observe can only be applied to measure primitives"

-- | Observe a value drawn from a primitive measure.
--
-- * @Normal@: @dirac a@ weighted by @densityNormal mu sd a@.
-- * @Uniform@: if @lo <= a && a <= hi@, @dirac a@ weighted by
--   @recip (hi - lo)@; otherwise @reject@.
-- * Every other primitive is not implemented yet and calls 'error'.
observeMeasureOp
    :: (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => MeasureOp typs a
    -> SArgs abt args
    -> abt '[] a
    -> abt '[] ('HMeasure a)
observeMeasureOp Normal  = \(mu :* sd :* End) a ->
    P.withWeight (P.densityNormal mu sd a) (P.dirac a)
observeMeasureOp Uniform = \(lo :* hi :* End) a ->
    P.if_ (lo P.<= a P.&& a P.<= hi)
          (P.withWeight (P.unsafeProb $ P.recip $ hi P.- lo) (P.dirac a))
          (P.reject (SMeasure SReal))
observeMeasureOp _ = error "TODO{Observe:observeMeasureOp}"