{-# LANGUAGE CPP
, GADTs
, EmptyCase
, KindSignatures
, DataKinds
, PolyKinds
, TypeOperators
, ScopedTypeVariables
, Rank2Types
, MultiParamTypeClasses
, TypeSynonymInstances
, FlexibleInstances
, FlexibleContexts
, UndecidableInstances
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Disintegrate
( lam_
, disintegrateWithVar
, disintegrate, disintegrateInCtx
, densityWithVar
, density, densityInCtx
, observe, observeInCtx
, determine
, perform
, atomize
, constrainValue
, constrainOutcome
) where
#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
import Data.Foldable (Foldable, foldMap)
import Data.Traversable (Traversable)
import Control.Applicative (Applicative(..))
#endif
import Control.Applicative (Alternative(..))
import Control.Monad ((<=<), guard)
import Data.Functor.Compose (Compose(..))
import qualified Data.Traversable as T
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as L
import qualified Data.Text as Text
import qualified Data.IntMap as IM
import Data.Sequence (Seq)
import qualified Data.Sequence as S
import Data.Proxy (KProxy(..))
import Data.Maybe (fromMaybe, fromJust)
import Language.Hakaru.Syntax.IClasses
import Data.Number.Natural
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import qualified Language.Hakaru.Types.Coercion as C
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumCase (DatumEvaluator, MatchResult(..), matchBranches)
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.Transform (TransformCtx(..), minimalCtx)
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.Lazy
import Language.Hakaru.Evaluation.DisintegrationMonad
import qualified Language.Hakaru.Syntax.Prelude as P
import qualified Language.Hakaru.Expect as E
#ifdef __TRACE_DISINTEGRATE__
import qualified Text.PrettyPrint as PP
import Language.Hakaru.Pretty.Haskell
import Debug.Trace (trace, traceM)
#endif
-- | Build the lambda abstraction that binds the variable @x@ in @body@.
lam_ :: (ABT Term abt) => Variable a -> abt '[] b -> abt '[] (a ':-> b)
lam_ x body = syn $ Lam_ :$ (bind x body :* End)
-- | Disintegrate a measure over pairs, naming the observed component.
--
-- A variable is built from the given hint and type, with an identifier
-- taken from 'nextFreeOrBind' of the input measure.  The measure is
-- performed, its outcome is unpaired, the first component is constrained
-- to equal that variable, and the second component is returned.  Every
-- resulting program is wrapped in a lambda over the variable.
disintegrateWithVar
    :: (ABT Term abt)
    => TransformCtx
    -> Text.Text
    -> Sing a
    -> abt '[] ('HMeasure (HPair a b))
    -> [abt '[] (a ':-> 'HMeasure b)]
disintegrateWithVar ctx hint typ m =
    lam_ x <$> runDisInCtx ctx conditioned [Some2 m, Some2 (var x)]
  where
    -- The variable standing for the observed value.
    x = Variable hint (nextFreeOrBind m) typ

    -- The computation run inside the disintegration monad.
    conditioned = do
        ab <- perform m
#ifdef __TRACE_DISINTEGRATE__
        ss <- getStatements
        trace ("-- disintegrate: finished perform\n"
            ++ show (pretty_Statements ss PP.$+$ PP.sep(prettyPrec_ 11 ab))
            ++ "\n") $ return ()
#endif
        (a,b) <- emitUnpair ab
#ifdef __TRACE_DISINTEGRATE__
        trace ("-- disintegrate: finished emitUnpair: "
            ++ show (pretty a, pretty b)) $ return ()
#endif
        constrainValue (var x) a
#ifdef __TRACE_DISINTEGRATE__
        ss <- getStatements
        extras <- getExtras
        traceM ("-- disintegrate: finished constrainValue\n"
            ++ show (pretty_Statements ss) ++ "\n"
            ++ show (prettyExtras extras)
            )
#endif
        return b
-- | 'disintegrateWithVar' with an empty name hint; the type of the
-- observed component is read off the type of the input measure.
disintegrateInCtx
    :: (ABT Term abt)
    => TransformCtx
    -> abt '[] ('HMeasure (HPair a b))
    -> [abt '[] (a ':-> 'HMeasure b)]
disintegrateInCtx ctx m = disintegrateWithVar ctx Text.empty obsType m
  where
    -- First component of the pair type under the measure type.
    obsType = fst (sUnPair (sUnMeasure (typeOf m)))
-- | 'disintegrateInCtx' specialised to 'minimalCtx'.
disintegrate
    :: (ABT Term abt)
    => abt '[] ('HMeasure (HPair a b))
    -> [abt '[] (a ':-> 'HMeasure b)]
disintegrate m = disintegrateInCtx minimalCtx m
-- | Density functions of a measure, naming the point of evaluation.
--
-- The measure is observed at a variable (see 'observeInCtx'); for each
-- resulting measure, 'E.total' is applied and the result is wrapped in a
-- lambda over that variable.
densityWithVar
    :: (ABT Term abt)
    => TransformCtx
    -> Text.Text
    -> Sing a
    -> abt '[] ('HMeasure a)
    -> [abt '[] (a ':-> 'HProb)]
densityWithVar ctx hint typ m =
    [ lam_ x (E.total observed)
    | observed <- observeInCtx ctx m (var x)
    ]
  where
    x = Variable hint (max (nextFree m) (nextBind m)) typ
-- | 'densityWithVar' with an empty name hint; the variable's type is
-- read off the type of the input measure.
densityInCtx
    :: (ABT Term abt)
    => TransformCtx
    -> abt '[] ('HMeasure a)
    -> [abt '[] (a ':-> 'HProb)]
densityInCtx ctx m = densityWithVar ctx Text.empty outcomeType m
  where
    outcomeType = sUnMeasure (typeOf m)
-- | 'densityInCtx' specialised to 'minimalCtx'.
density
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> [abt '[] (a ':-> 'HProb)]
density m = densityInCtx minimalCtx m
-- | Constrain the outcome of the measure @m@ to be @x@, and return the
-- resulting measures (each of which yields @x@).
observeInCtx
    :: (ABT Term abt)
    => TransformCtx
    -> abt '[] ('HMeasure a)
    -> abt '[] a
    -> [abt '[] ('HMeasure a)]
observeInCtx ctx m x = runDisInCtx ctx constrained [Some2 m, Some2 x]
  where
    constrained = do
        constrainOutcome x m
        return x
-- | 'observeInCtx' specialised to 'minimalCtx'.
observe
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> abt '[] a
    -> [abt '[] ('HMeasure a)]
observe m x = observeInCtx minimalCtx m x
-- | Pick the first candidate from a list of results, if there is one.
determine :: (ABT Term abt) => [abt '[] a] -> Maybe (abt '[] a)
determine = foldr (\m _ -> Just m) Nothing
-- | The term evaluator used throughout this module: 'evaluate' with
-- 'perform' as its measure evaluator.
evaluate_ :: (ABT Term abt) => TermEvaluator abt (Dis abt)
evaluate_ e = evaluate perform e
-- | Evaluate a term with 'evaluate_' and view the result as a datum
-- via 'viewWhnfDatum'.
evaluateDatum :: (ABT Term abt) => DatumEvaluator (abt '[]) (Dis abt)
evaluateDatum e = fmap viewWhnfDatum (evaluate_ e)
-- | Measure evaluator for the disintegration monad.
--
-- Given a term of measure type, produce the weak-head normal form of its
-- outcome.  Variables are handled by 'performVar' and syntactic terms by
-- 'performTerm'; any measure term without a dedicated case is first
-- evaluated with 'evaluate_' and the resulting WHNF is handled by
-- 'performWhnf'.
perform :: forall abt. (ABT Term abt) => MeasureEvaluator abt (Dis abt)
perform = \e0 ->
#ifdef __TRACE_DISINTEGRATE__
    getStatements >>= \ss ->
    getExtras >>= \extras ->
    getIndices >>= \inds ->
    trace ("\n-- perform --\n"
        ++ "at " ++ show (ppInds inds) ++ "\n"
        ++ show (prettyExtras extras) ++ "\n"
        ++ show (pretty_Statements_withTerm ss e0)
        ++ "\n") $
#endif
    caseVarSyn e0 performVar performTerm
  where
    -- Dispatch on the syntactic form of the measure.
    performTerm :: forall a. Term abt ('HMeasure a) -> Dis abt (Whnf abt a)
    performTerm (Dirac :$ e1 :* End) = evaluate_ e1
    performTerm (MeasureOp_ o :$ es) = performMeasureOp o es
    performTerm (MBind :$ e1 :* e2 :* End) =
        -- Push the bound measure as a lazy binding, then perform the body.
        caseBind e2 $ \x e2' -> do
            inds <- getIndices
            push (SBind x (Thunk e1) inds) e2' >>= perform
    performTerm (Plate :$ e1 :* e2 :* End) = do
        x1 <- pushPlate e1 e2
        -- Use the result directly when it is already a WHNF; otherwise
        -- evaluate it instead of crashing through a partial 'fromJust'.
        maybe (evaluate_ x1) return (toWhnf x1)
    performTerm (Superpose_ pes) = do
        inds <- getIndices
        -- A superposition with several branches is not handled when the
        -- current index list is non-empty.
        if not (null inds) && L.length pes > 1 then bot else
            emitFork_ (P.superpose . fmap ((,) P.one))
                (fmap (\(p,e) -> push (SWeight (Thunk p) inds) e >>= perform)
                    pes)
    performTerm (Let_ :$ e1 :* e2 :* End) =
        caseBind e2 $ \x e2' -> do
            inds <- getIndices
            push (SLet x (Thunk e1) inds) e2' >>= perform
    performTerm t0 = do
        w <- evaluate_ (syn t0)
#ifdef __TRACE_DISINTEGRATE__
        trace ("-- perform: finished evaluate, with:\n" ++ show (PP.sep(prettyPrec_ 11 w))) $ return ()
#endif
        performWhnf w

    -- Evaluate the variable, then perform whatever it evaluates to.
    performVar :: forall a. Variable ('HMeasure a) -> Dis abt (Whnf abt a)
    performVar = performWhnf <=< evaluateVar perform evaluate_

    -- A head is converted back to a term and performed; a neutral term is
    -- atomized and emitted as a bind whose variable is the outcome.
    performWhnf
        :: forall a. Whnf abt ('HMeasure a) -> Dis abt (Whnf abt a)
    performWhnf (Head_ v) = perform $ fromHead v
    performWhnf (Neutral e) = (Neutral . var) <$>
        (emitMBind . fromWhnf =<< atomize e)

    -- Try the residualising strategy ('nice') first and fall back to the
    -- explicit-density strategy ('complete').
    performMeasureOp
        :: forall typs args a
        .  (typs ~ UnLCs args, args ~ LCs typs)
        => MeasureOp typs a
        -> SArgs abt args
        -> Dis abt (Whnf abt a)
    performMeasureOp = \o es -> nice o es <|> complete o es
      where
        -- Atomize the arguments and emit a bind of the primitive itself.
        nice
            :: MeasureOp typs a
            -> SArgs abt args
            -> Dis abt (Whnf abt a)
        nice o es = do
            es' <- traverse21 atomizeCore es
            x <- emitMBind2 $ syn (MeasureOp_ o :$ es')
            return (Neutral x)

        -- Draw from 'P.lebesgue' and push the density as a weight (plus a
        -- range guard for the uniform case); other primitives fail.
        complete
            :: MeasureOp typs a
            -> SArgs abt args
            -> Dis abt (Whnf abt a)
        complete Normal = \(mu :* sd :* End) -> do
            x <- var <$> emitMBind P.lebesgue
            pushWeight (P.densityNormal mu sd x)
            return (Neutral x)
        complete Uniform = \(lo :* hi :* End) -> do
            x <- var <$> emitMBind P.lebesgue
            pushGuard (lo P.< x P.&& x P.< hi)
            pushWeight (P.densityUniform lo hi x)
            return (Neutral x)
        complete _ = \_ -> bot
-- | Evaluate a term to WHNF with 'evaluate_', then run 'atomizeCore'
-- over the immediate subterms of the result.
atomize :: (ABT Term abt) => TermEvaluator abt (Dis abt)
atomize e =
#ifdef __TRACE_DISINTEGRATE__
    trace ("\n-- atomize --\n" ++ show (pretty e)) $
#endif
    evaluate_ e >>= \whnf ->
        case whnf of
            Head_ v ->
                fmap Head_ (traverse21 atomizeCore v)
            Neutral e' ->
                fmap (Neutral . unviewABT)
                    (traverse12 (traverse21 atomizeCore) (viewABT e'))
-- | Atomize a possibly-open term.
--
-- When none of the term's free variables (per 'extFreeVars') occur among
-- the heap variables, the term is returned untouched.  Otherwise its
-- binders are stripped with 'caseBinds', the body is atomized, and the
-- binders are put back.
atomizeCore :: (ABT Term abt) => abt xs a -> Dis abt (abt xs a)
atomizeCore e = do
    heapVars <- getHeapVars
    freeVars <- extFreeVars e
    if overlaps heapVars freeVars
        then
            let (ys, body) = caseBinds e
            in
#ifdef __TRACE_DISINTEGRATE__
            trace ("\n-- atomizeCore --\n" ++ show (pretty body)) $
#endif
            (binds_ ys . fromWhnf) <$> atomize body
        else return e
  where
    -- True when the two variable sets share at least one key.
    overlaps xs ys =
        not (IM.null (IM.intersection (unVarSet xs) (unVarSet ys)))
-- | Collect 'statementVars' over every statement of the current heap.
-- The heap itself is passed on unchanged.
getHeapVars :: Dis abt (VarSet ('KProxy :: KProxy Hakaru))
getHeapVars =
    Dis (\_ k heap -> k (foldMap statementVars (statements heap)) heap)
-- | Constrain the term @e0@ to evaluate to the value @v0@.
--
-- A variable is handed to 'constrainVariable'; any other term is
-- dispatched on its head constructor.  Forms that cannot be constrained
-- here yield 'bot'; forms that are not implemented raise an 'error'.
constrainValue :: (ABT Term abt) => abt '[] a -> abt '[] a -> Dis abt ()
constrainValue v0 e0 =
#ifdef __TRACE_DISINTEGRATE__
getStatements >>= \ss ->
getExtras >>= \extras ->
getIndices >>= \inds ->
trace ("\n-- constrainValue: " ++ show (pretty v0) ++ "\n"
++ show (pretty_Statements_withTerm ss e0) ++ "\n"
++ "at " ++ show (ppInds inds) ++ "\n"
++ show (prettyExtras extras) ++ "\n"
) $
#endif
caseVarSyn e0 (constrainVariable v0) $ \t ->
case t of
Empty_ _ -> error "TODO: disintegrate empty arrays"
-- Array comprehension: allocate a fresh index for the array size,
-- substitute the index variable for the bound variable, and constrain
-- the body against @v0@ indexed at that variable, under the extended
-- index list.
Array_ n e ->
caseBind e $ \x body -> do j <- freshInd n
let x' = indVar j
body' <- extSubst x (var x') body
inds <- getIndices
withIndices (extendIndices j inds) $
constrainValue (v0 P.! (var x')) body'
ArrayLiteral_ _ -> error "TODO: disintegrate literal arrays"
ArrayOp_ o :$ args -> constrainValueArrayOp v0 o args
Lam_ :$ _ :* End -> error "TODO: disintegrate lambdas"
App_ :$ _ :* _ :* End -> error "TODO: disintegrate lambdas"
Integrate :$ _ :* _ :* _ :* End ->
error "TODO: disintegrate integration"
Summate _ _ :$ _ :* _ :* _ :* End ->
error "TODO: disintegrate integration"
-- Literals and measure-building forms give up with 'bot'.
Literal_ v -> bot
Datum_ d -> constrainDatum v0 d
Dirac :$ _ :* End -> bot
MBind :$ _ :* _ :* End -> bot
MeasureOp_ o :$ es -> constrainValueMeasureOp v0 o es
Superpose_ pes -> bot
Reject_ _ -> bot
-- Push the let-binding (with an empty index list) and constrain the body.
Let_ :$ e1 :* e2 :* End ->
caseBind e2 $ \x e2' ->
push (SLet x (Thunk e1) []) e2' >>= constrainValue v0
-- Coercions are undone on the value side: each direction applies the
-- opposite coercion to @v0@.
CoerceTo_ c :$ e1 :* End ->
constrainValue (P.unsafeFrom_ c v0) e1
UnsafeFrom_ c :$ e1 :* End ->
constrainValue (P.coerceTo_ c v0) e1
NaryOp_ o es -> constrainNaryOp v0 o es
PrimOp_ o :$ es -> constrainPrimOp v0 o es
Transform_ t :$ _ -> error $
concat["constrainValue{", show t, "}"
,": cannot yet disintegrate transforms; expand them first"]
-- Try to resolve the scrutinee with 'matchBranches': a stuck match
-- constrains all branches, a successful match pushes the bindings and
-- constrains the chosen body.  'constrainBranches' is also offered as
-- an alternative via '<|>'.
-- NOTE(review): SOURCE's indentation is lost, so whether the '<|>'
-- applies to the whole do-block or only to its last alternative cannot
-- be read off here -- confirm against the formatted file.
Case_ e bs ->
do match <- matchBranches evaluateDatum e bs
case match of
Nothing ->
error "constrainValue{Case_}: nothing matched!"
Just GotStuck ->
constrainBranches v0 e bs
Just (Matched rho body) ->
pushes (toVarStatements rho) body >>= constrainValue v0
<|> constrainBranches v0 e bs
-- Catch-all for any remaining @_ :$ _@ form.
_ :$ _ -> error "constrainValue: the impossible happened"
-- | Constrain a case expression by trying every branch.
--
-- For each branch, a guard statement relating the scrutinee to the
-- branch pattern is pushed (with an empty index list) and the branch
-- body is constrained to equal @v0@; the per-branch computations are
-- combined with 'choose'.
constrainBranches
    :: (ABT Term abt)
    => abt '[] a
    -> abt '[] b
    -> [Branch b abt a]
    -> Dis abt ()
constrainBranches v0 scrutinee branches = choose (map viaBranch branches)
  where
    viaBranch (Branch pat body) = do
        let (vars, body') = caseBinds body
        guarded <- push (SGuard vars pat (Thunk scrutinee) []) body'
        constrainValue v0 guarded
-- | Constrain a datum constructor application to equal @v0@.
--
-- The datum is turned into a pattern plus its subterms, fresh variables
-- are allocated for the subterms, and a case on @v0@ is emitted whose
-- first branch binds those variables and whose wildcard branch rejects.
-- Each fresh variable is then constrained against its subterm.
constrainDatum
    :: (ABT Term abt) => abt '[] a -> Datum (abt '[]) a -> Dis abt ()
constrainDatum v0 d =
    case patternOfDatum d of
        PatternOfDatum pat subterms -> do
            xs <- freshVars
                (fmap11 (\e -> Hint Text.empty (typeOf e)) subterms)
            emit_ (\body ->
                syn (Case_ v0
                    [ Branch pat (binds_ xs body)
                    , Branch PWild (P.reject $ (typeOf body))
                    ]))
            constrainValues xs subterms
-- | Pairwise 'constrainValue': each term is constrained to equal the
-- variable at the same position.
constrainValues
    :: (ABT Term abt)
    => List1 Variable xs
    -> List1 (abt '[]) xs
    -> Dis abt ()
constrainValues Nil1 Nil1 = return ()
constrainValues (Cons1 x xs) (Cons1 e es) = do
    constrainValue (var x) e
    constrainValues xs es
constrainValues _ _ = error "constrainValues: the impossible happened"
-- | The result of 'patternOfDatum': a pattern paired with a list of
-- subterms.  The index list @xs@ is existentially quantified and is
-- shared by both fields; both fields are strict.
data PatternOfDatum (ast :: Hakaru -> *) (a :: Hakaru) =
forall xs. PatternOfDatum
-- The pattern, indexed by the hidden list @xs@.
!(Pattern xs a)
-- The subterms, indexed by the same @xs@.
!(List1 ast xs)
-- | Split a datum into a pattern and the list of its subterms.
--
-- Every subterm position becomes a 'PVar' in the pattern; the subterms
-- are collected left to right.  The helpers walk the code, struct and
-- function layers of the datum in continuation-passing style.
patternOfDatum :: Datum ast a -> PatternOfDatum ast a
patternOfDatum (Datum hint _ code) =
    podCode code (\pat es -> PatternOfDatum (PDatum hint pat) es)
  where
    podCode
        :: DatumCode xss ast a
        -> (forall bs. PDatumCode xss bs a -> List1 ast bs -> r)
        -> r
    podCode (Inl struct) ret = podStruct struct (\p es -> ret (PInl p) es)
    podCode (Inr rest)   ret = podCode rest (\p es -> ret (PInr p) es)

    podStruct
        :: DatumStruct xs ast a
        -> (forall bs. PDatumStruct xs bs a -> List1 ast bs -> r)
        -> r
    podStruct Done ret = ret PDone Nil1
    podStruct (Et field rest) ret =
        podFun field (\p1 es1 ->
            podStruct rest (\p2 es2 ->
                ret (PEt p1 p2) (es1 `append1` es2)))

    podFun
        :: DatumFun x ast a
        -> (forall bs. PDatumFun x bs a -> List1 ast bs -> r)
        -> r
    podFun (Konst e) ret = ret (PKonst PVar) (Cons1 e Nil1)
    podFun (Ident e) ret = ret (PIdent PVar) (Cons1 e Nil1)
-- | Constrain the variable @x@ to have the value @v0@.
--
-- The variable is looked up in the extras; when it has no association
-- the result is 'bot'.  Otherwise the statement for the associated
-- location is selected and constrained, and replaced by a let-binding
-- of that location to @v0@.
constrainVariable
:: (ABT Term abt) => abt '[] a -> Variable a -> Dis abt ()
constrainVariable v0 x =
do extras <- getExtras
maybe bot lookForLoc (lookupAssoc x extras)
-- When 'select' yields Nothing, 'freeLocError' is invoked for the location.
where lookForLoc (Loc l jxs) =
(maybe (freeLocError l) return =<<) . select l $ \s ->
case s of
-- A bind statement: the location must match and the statement's
-- index list must have the same length as the location's; both checks
-- appear to run in the Maybe monad, before the returned action.
SBind l' e ixs -> do
Refl <- locEq l l'
guard (length ixs == length jxs)
Just $ do
inds <- getIndices
-- This guard is inside the returned action, not the Maybe block.
guard (jxs `permutes` inds)
e' <- apply ixs inds (fromLazy e)
-- The bound term is a measure, so its outcome is constrained.
constrainOutcome v0 e'
unsafePush (SLet l (Whnf_ (Neutral v0)) inds)
-- A let statement: same checks, but the bound term is constrained as
-- a value.
SLet l' e ixs -> do
Refl <- locEq l l'
guard (length ixs == length jxs)
Just $ do
inds <- getIndices
guard (jxs `permutes` inds)
e' <- apply ixs inds (fromLazy e)
constrainValue v0 e'
unsafePush (SLet l (Whnf_ (Neutral v0)) inds)
-- Weight statements never match.
SWeight _ _ -> Nothing
SGuard ls' pat scrutinee i -> error "TODO: constrainVariable{SGuard}"
-- | Constrain an array operation applied to its arguments to equal @v0@.
--
-- Only indexing is implemented: the array argument is evaluated and the
-- result is dispatched on.  A neutral array or an empty array yields
-- 'bot'; an array comprehension has the index substituted into its body,
-- which is then constrained; an array literal is delegated to
-- 'constrainValueIdxArrLit'.
constrainValueArrayOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] a
    -> ArrayOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainValueArrayOp v0 = go
  where
    go :: ArrayOp typs a -> SArgs abt args -> Dis abt ()
    go (Index _) (arr :* ix :* End) = do
        w <- evaluate_ arr
        case w of
            Neutral _ -> bot
            Head_ h ->
                case h of
                    WArray _ b ->
                        caseBind b $ \x body -> do
                            body' <- extSubst x ix body
                            constrainValue v0 body'
                    WEmpty _ -> bot
                    WArrayLiteral _ -> constrainValueIdxArrLit v0 ix h
                    _ -> error "constrainValue {ArrayOp Index}: uknown whnf of array type"
    go (Size _) _ = error "TODO: disintegrate {ArrayOp Size}"
    go (Reduce _) _ = error "TODO: disintegrate {ArrayOp Reduce}"
    go _ _ = error "constrainValueArrayOp: unknown arrayOp"
-- | Constrain the indexing of an array literal at index @e2@ to equal
-- @v0@.
--
-- Only a two-element literal of booleans holding one true and one false
-- literal is inverted: the index is constrained to the position selected
-- by @v0@.  Array literals of any other length yield 'bot'.
constrainValueIdxArrLit
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a
    -> abt '[] 'HNat
    -> Head abt ('HArray a)
    -> Dis abt ()
constrainValueIdxArrLit v0 e2 = go
  where
    go :: Head abt ('HArray a) -> Dis abt ()
    go (WArrayLiteral [a1,a2]) =
        case jmEq1 (typeOf v0) sBool of
            Nothing -> error "TODO: constrainValue (Index [a1,a2] i)"
            Just Refl ->
                case (isLitBool a1, isLitBool a2) of
                    (Just b1, Just b2)
                        | isLitTrue b1 && isLitFalse b2 ->
                            constrainValue
                                (P.if_ v0 (P.nat_ 0) (P.nat_ 1)) e2
                        | isLitTrue b2 && isLitFalse b1 ->
                            constrainValue
                                (P.if_ v0 (P.nat_ 1) (P.nat_ 0)) e2
                        | otherwise ->
                            error "constrainValue: cannot invert (Index [b,b] i)"
                    _ -> error "TODO: constrainValue (Index [b1,b2] i)"
    go (WArrayLiteral _) = bot
    go _ = error "constrainValueIdxArrLit: unknown ArrayLiteral form"
-- | Return the datum when the term is syntactically a datum whose type
-- is the boolean type; variables and all other terms give 'Nothing'.
isLitBool :: (ABT Term abt) => abt '[] a -> Maybe (Datum (abt '[]) HBool)
isLitBool e =
    caseVarSyn e (\_ -> Nothing) $ \t ->
        case t of
            Datum_ d@(Datum _ typ _)
                | Just Refl <- jmEq1 typ sBool -> Just d
            _ -> Nothing
-- | Is this boolean datum built with the first constructor?
--
-- Only the datum's code is inspected.  The hint and type fields are
-- deliberately wildcards: naming them in the pattern would bind fresh
-- variables that match anything, not compare against existing values.
isLitTrue :: (ABT Term abt) => Datum (abt '[]) HBool -> Bool
isLitTrue (Datum _ _ (Inl Done)) = True
isLitTrue _ = False
-- | Is this boolean datum built with the second constructor?
--
-- Only the datum's code is inspected.  The hint and type fields are
-- deliberately wildcards: naming them in the pattern would bind fresh
-- variables that match anything, not compare against existing values.
isLitFalse :: (ABT Term abt) => Datum (abt '[]) HBool -> Bool
isLitFalse (Datum _ _ (Inr (Inl Done))) = True
isLitFalse _ = False
-- | Constrain a primitive measure (as a value of measure type) to equal
-- @v0@.
--
-- The counting measure yields 'bot'.  Every other primitive is rebuilt
-- as a term with the corresponding "Language.Hakaru.Syntax.Prelude"
-- function and handed back to 'constrainValue'.
--
-- NOTE(review): if any of those Prelude functions produces the very same
-- 'MeasureOp_' term, this re-enters this function with the same inputs --
-- confirm against their definitions.
constrainValueMeasureOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] ('HMeasure a)
    -> MeasureOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainValueMeasureOp v0 = go
  where
    -- Constrain @v0@ against a rebuilt measure term.
    viaTerm :: abt '[] ('HMeasure a) -> Dis abt ()
    viaTerm = constrainValue v0

    go :: MeasureOp typs a -> SArgs abt args -> Dis abt ()
    go Lebesgue = \(arg1 :* arg2 :* End) ->
        viaTerm (P.lebesgue' arg1 arg2)
    go Counting = \End -> bot
    go Categorical = \(arg1 :* End) ->
        viaTerm (P.categorical arg1)
    go Uniform = \(arg1 :* arg2 :* End) ->
        viaTerm (P.uniform' arg1 arg2)
    go Normal = \(arg1 :* arg2 :* End) ->
        viaTerm (P.normal' arg1 arg2)
    go Poisson = \(arg1 :* End) ->
        viaTerm (P.poisson' arg1)
    go Gamma = \(arg1 :* arg2 :* End) ->
        viaTerm (P.gamma' arg1 arg2)
    go Beta = \(arg1 :* arg2 :* End) ->
        viaTerm (P.beta' arg1 arg2)
-- | Constrain an n-ary operation over a sequence of operands to equal
-- @v0@ by solving for one operand at a time.
--
-- For sums and products, the remaining operands are atomized, their
-- combined contribution is removed from @v0@ (subtracted, or divided
-- out with a reciprocal weight emitted), and the chosen operand is
-- constrained to the remainder; the candidates are combined with
-- 'lubSeq'.  For maximum, a guard that the other operands are less than
-- @v0@ is emitted and the chosen operand is constrained to @v0@; the
-- candidates are combined with 'chooseSeq'.
constrainNaryOp
    :: (ABT Term abt)
    => abt '[] a
    -> NaryOp a
    -> Seq (abt '[] a)
    -> Dis abt ()
constrainNaryOp v0 o =
    case o of
        Sum theSemi ->
            lubSeq $ \before e after -> do
                others <- atomize . syn $
                    NaryOp_ (Sum theSemi) (before S.>< after)
                target <- evaluate_
                    (P.unsafeMinus_ theSemi v0 (fromWhnf others))
                constrainValue (fromWhnf target) e
        Prod theSemi ->
            lubSeq $ \before e after -> do
                others <- atomize . syn $
                    NaryOp_ (Prod theSemi) (before S.>< after)
                let others' = fromWhnf others
                emitWeight (P.recip (toProb_abs theSemi others'))
                target <- evaluate_ (P.unsafeDiv_ theSemi v0 others')
                constrainValue (fromWhnf target) e
        Max theOrd ->
            chooseSeq $ \before e after -> do
                others <- atomize . syn $
                    NaryOp_ (Max theOrd) (before S.>< after)
                emitGuard (P.primOp2_ (Less theOrd) (fromWhnf others) v0)
                constrainValue v0 e
        _ -> error $ "TODO: constrainNaryOp{" ++ show o ++ "}"
-- | Convert a semiring value to a probability: naturals are coerced,
-- integers and reals go through 'P.abs_', probabilities are unchanged.
toProb_abs :: (ABT Term abt) => HSemiring a -> abt '[] a -> abt '[] 'HProb
toProb_abs theSemi e =
    case theSemi of
        HSemiring_Nat  -> P.nat2prob e
        HSemiring_Int  -> P.nat2prob (P.abs_ e)
        HSemiring_Prob -> e
        HSemiring_Real -> P.abs_ e
-- | Apply @f@ at every position of the sequence, passing the elements
-- before the position, the element itself, and the elements after it.
-- The results are combined left to right with '<|>'; an empty sequence
-- gives 'empty'.
lubSeq :: (Alternative m) => (Seq a -> a -> Seq a -> m b) -> Seq a -> m b
lubSeq f whole =
    foldr (<|>) empty
        [ f before focus after
        | i <- [0 .. S.length whole - 1]
        , let (before, rest) = S.splitAt i whole
        , focus S.:< after <- [S.viewl rest]
        ]
-- | Like 'lubSeq', but the per-position computations are collected into
-- a list (in sequence order) and combined with 'choose'.
chooseSeq :: (ABT Term abt)
    => (Seq a -> a -> Seq a -> Dis abt b)
    -> Seq a
    -> Dis abt b
chooseSeq f whole =
    choose
        [ f before focus after
        | i <- [0 .. S.length whole - 1]
        , let (before, rest) = S.splitAt i whole
        , focus S.:< after <- [S.viewl rest]
        ]
-- | Constrain a primitive operation applied to its arguments to equal
-- @v0@, by inverting the operation.
--
-- Each implemented case binds @v0@ (or a function of it) with
-- 'emitLet'', emits whatever guard and weight the case needs, and then
-- constrains the argument to the computed preimage.  Operations with no
-- inverse implemented raise a "TODO" 'error'; 'Pi' and the real-valued
-- 'Signum' yield 'bot'.
constrainPrimOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] a
    -> PrimOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainPrimOp v0 = go
  where
    error_TODO op = error $ "TODO: constrainPrimOp{" ++ op ++"}"

    go :: PrimOp typs a -> SArgs abt args -> Dis abt ()
    go Not = \(e1 :* End) -> error_TODO "Not"
    go Impl = \(e1 :* e2 :* End) -> error_TODO "Impl"
    go Diff = \(e1 :* e2 :* End) -> error_TODO "Diff"
    go Nand = \(e1 :* e2 :* End) -> error_TODO "Nand"
    go Nor = \(e1 :* e2 :* End) -> error_TODO "Nor"
    go Pi = \End -> bot
    -- sin y = x0 has the solutions  asin x0 + 2*pi*n  and
    -- pi - asin x0 + 2*pi*n  for integer n.
    go Sin = \(e1 :* End) -> do
        x0 <- emitLet' v0
        n <- var <$> emitMBind P.counting
        let tau_n = P.real_ 2 P.* P.fromInt n P.* P.pi
        emitGuard (P.negate P.one P.< x0 P.&& x0 P.< P.one)
        v <- var <$> emitSuperpose
            [ P.dirac (tau_n P.+ P.asin x0)
            , P.dirac (tau_n P.+ P.pi P.- P.asin x0)
            ]
        emitWeight
            . P.recip
            . P.sqrt
            . P.unsafeProb
            $ (P.one P.- x0 P.^ P.nat_ 2)
        constrainValue v e1
    -- cos y = x0 has the solutions  acos x0 + 2*pi*n  and
    -- -acos x0 + 2*pi*n  for integer n.  (Adding pi to a solution flips
    -- the sign of the cosine, so it does not give another solution.)
    go Cos = \(e1 :* End) -> do
        x0 <- emitLet' v0
        n <- var <$> emitMBind P.counting
        let tau_n = P.real_ 2 P.* P.fromInt n P.* P.pi
        emitGuard (P.negate P.one P.< x0 P.&& x0 P.< P.one)
        r <- emitLet' (tau_n P.+ P.acos x0)
        v <- var <$> emitSuperpose
            [ P.dirac r
            , P.dirac (tau_n P.- P.acos x0)
            ]
        emitWeight
            . P.recip
            . P.sqrt
            . P.unsafeProb
            $ (P.one P.- x0 P.^ P.nat_ 2)
        constrainValue v e1
    -- tan y = x0 has the solutions  atan x0 + pi*n  for integer n.
    go Tan = \(e1 :* End) -> do
        x0 <- emitLet' v0
        n <- var <$> emitMBind P.counting
        r <- emitLet' (P.fromInt n P.* P.pi P.+ P.atan x0)
        emitWeight $ P.recip (P.one P.+ P.square x0)
        constrainValue r e1
    go Asin = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight $ P.unsafeProb (P.cos x0)
        constrainValue (P.sin x0) e1
    go Acos = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight $ P.unsafeProb (P.sin x0)
        constrainValue (P.cos x0) e1
    go Atan = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight $ P.recip (P.unsafeProb (P.cos x0 P.^ P.nat_ 2))
        constrainValue (P.tan x0) e1
    go Sinh = \(e1 :* End) -> error_TODO "Sinh"
    go Cosh = \(e1 :* End) -> error_TODO "Cosh"
    go Tanh = \(e1 :* End) -> error_TODO "Tanh"
    go Asinh = \(e1 :* End) -> error_TODO "Asinh"
    go Acosh = \(e1 :* End) -> error_TODO "Acosh"
    go Atanh = \(e1 :* End) -> error_TODO "Atanh"
    go Choose = \(e1 :* e2 :* End) -> error_TODO "Choose"
    go Floor = \(e1 :* End) -> error_TODO "Floor"
    -- Two alternatives: solve for the exponent @e2@, or solve for the
    -- base @e1@.
    go RealPow = \(e1 :* e2 :* End) ->
        do
            u <- emitLet' v0
            let w = P.recip (u P.* P.unsafeProb (P.abs (P.log e1)))
            emitWeight w
            constrainValue (P.log u P./ P.log e1) e2
        <|> do
            u <- emitLet' v0
            let ex = v0 P.** P.recip e2
            let w = P.abs (P.fromProb ex P./ (e2 P.* P.fromProb u))
            emitWeight $ P.unsafeProb w
            constrainValue ex e1
    go Exp = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight (P.recip x0)
        constrainValue (P.log x0) e1
    go Log = \(e1 :* End) -> do
        exp_x0 <- emitLet' (P.exp v0)
        emitWeight exp_x0
        constrainValue exp_x0 e1
    go (Infinity _) = \End -> error_TODO "Infinity"
    go GammaFunc = \(e1 :* End) -> error_TODO "GammaFunc"
    go BetaFunc = \(e1 :* e2 :* End) -> error_TODO "BetaFunc"
    go (Equal theOrd) = \(e1 :* e2 :* End) -> error_TODO "Equal"
    go (Less theOrd) = \(e1 :* e2 :* End) -> error_TODO "Less"
    go (NatPow theSemi) = \(e1 :* e2 :* End) -> error_TODO "NatPow"
    -- Negation is its own inverse; no weight is emitted.
    go (Negate theRing) = \(e1 :* End) ->
        let negate_v0 = syn (PrimOp_ (Negate theRing) :$ v0 :* End)
        in constrainValue negate_v0 e1
    -- |y| = x0: both x0 and its negation when x0 is positive, zero when
    -- x0 is zero, and rejection otherwise.
    go (Abs theRing) = \(e1 :* End) -> do
        let theSemi = hSemiring_HRing theRing
            theOrd =
                case theRing of
                    HRing_Int -> HOrd_Int
                    HRing_Real -> HOrd_Real
            theEq = hEq_HOrd theOrd
            signed = C.singletonCoercion (C.Signed theRing)
            zero = P.zero_ theSemi
            lt = P.primOp2_ $ Less theOrd
            eq = P.primOp2_ $ Equal theEq
            neg = P.primOp1_ $ Negate theRing
        x0 <- emitLet' (P.coerceTo_ signed v0)
        v <- var <$> emitMBind
            (P.if_ (lt zero x0)
                (P.dirac x0 P.<|> P.dirac (neg x0))
                (P.if_ (eq zero x0)
                    (P.dirac zero)
                    (P.reject . SMeasure $ typeOf zero)))
        constrainValue v e1
    go (Signum theRing) = \(e1 :* End) ->
        case theRing of
            HRing_Real -> bot
            HRing_Int -> do
                x <- var <$> emitMBind P.counting
                emitGuard $ P.signum x P.== v0
                constrainValue x e1
    go (Recip theFractional) = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight
            . P.recip
            . P.unsafeProbFraction_ theFractional
            $ square (hSemiring_HFractional theFractional) x0
        constrainValue (P.primOp1_ (Recip theFractional) x0) e1
    go (NatRoot theRadical) = \(e1 :* e2 :* End) ->
        case theRadical of
            HRadical_Prob -> do
                x0 <- emitLet' v0
                u2 <- fromWhnf <$> atomize e2
                -- NOTE(review): u2 * x0 is the derivative of x0 ^ u2 only
                -- when u2 == 2 -- confirm the intended weight for other
                -- roots.
                emitWeight (P.nat2prob u2 P.* x0)
                constrainValue (x0 P.^ u2) e1
    go (Erf theContinuous) = \(e1 :* End) ->
        error "TODO: constrainPrimOp: need InvErf to disintegrate Erf"
-- | Raise a term to the second power using the semiring's 'NatPow'.
square :: (ABT Term abt) => HSemiring a -> abt '[] a -> abt '[] a
square theSemiring e =
    syn $ PrimOp_ (NatPow theSemiring) :$ (e :* (P.nat_ 2 :* End))
-- | Constrain the outcome of the measure @e0@ to be the value @v0@.
--
-- The measure is evaluated to WHNF first.  A neutral result yields
-- 'bot'; a head is dispatched on by 'go'.
constrainOutcome
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a
    -> abt '[] ('HMeasure a)
    -> Dis abt ()
constrainOutcome v0 e0 =
#ifdef __TRACE_DISINTEGRATE__
    getExtras >>= \extras ->
    getIndices >>= \inds ->
    trace (
        let s = "-- constrainOutcome"
        in "\n" ++ s ++ ": "
            ++ show (pretty v0)
            ++ "\n" ++ replicate (length s) ' ' ++ ": "
            ++ show (pretty e0) ++ "\n"
            ++ "at " ++ show (ppInds inds) ++ "\n"
            ++ show (prettyExtras extras)
    ) $
#endif
    evaluate_ e0 >>= viaWhnf
  where
    viaWhnf :: Whnf abt ('HMeasure a) -> Dis abt ()
    viaWhnf (Neutral _) = bot
    viaWhnf (Head_ hd) = go hd

    impossible = error "constrainOutcome: the impossible happened"

    go :: Head abt ('HMeasure a) -> Dis abt ()
    go (WLiteral _) = impossible
    go (WCoerceTo _ _) = impossible
    go (WUnsafeFrom _ _) = impossible
    go (WMeasureOp o es) = constrainOutcomeMeasureOp v0 o es
    -- The outcome of a dirac is its argument.
    go (WDirac e1) = constrainValue v0 e1
    -- Push the bound measure, then constrain the outcome of the body.
    go (WMBind e1 e2) =
        caseBind e2 $ \x body -> do
            inds <- getIndices
            body' <- push (SBind x (Thunk e1) inds) body
            constrainOutcome v0 body'
    go (WPlate e1 e2) = pushPlate e1 e2 >>= constrainValue v0
    go (WChain e1 e2 e3) = error "TODO: constrainOutcome{Chain}"
    -- Replace the rest of the program by a rejection of the same type.
    go (WReject _) = emit_ $ \m -> P.reject (typeOf m)
    -- A superposition with several branches is not handled when the
    -- current index list is non-empty.
    go (WSuperpose pes) =
        getIndices >>= \inds ->
            if not (null inds) && L.length pes > 1
                then bot
                else
                    emitFork_ (P.superpose . fmap ((,) P.one))
                        (fmap
                            (\(p,e) ->
                                push (SWeight (Thunk p) inds) e
                                    >>= constrainOutcome v0)
                            pes)
-- | Constrain the outcome of a primitive measure to be @v0@.
--
-- Depending on the primitive, this pushes a guard on @v0@ and the
-- arguments, a weight given by the primitive's density function at
-- @v0@, both, or (for the counting measure) nothing at all.
constrainOutcomeMeasureOp
    :: (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] a
    -> MeasureOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainOutcomeMeasureOp v0 = go
  where
    go Lebesgue = \(lower :* upper :* End) -> do
        obs <- emitLet' v0
        pushGuard (lower P.<= obs P.&& obs P.<= upper)
    go Counting = \End -> return ()
    go Categorical = \(arg1 :* End) ->
        pushWeight (P.densityCategorical arg1 v0)
    go Uniform = \(lower :* upper :* End) -> do
        obs <- emitLet' v0
        pushGuard (lower P.<= obs P.&& obs P.<= upper)
        pushWeight (P.densityUniform lower upper obs)
    go Normal = \(arg1 :* arg2 :* End) ->
        pushWeight (P.densityNormal arg1 arg2 v0)
    go Poisson = \(arg1 :* End) -> do
        obs <- emitLet' v0
        pushGuard (P.nat_ 0 P.<= obs P.&& P.prob_ 0 P.< arg1)
        pushWeight (P.densityPoisson arg1 obs)
    go Gamma = \(arg1 :* arg2 :* End) -> do
        obs <- emitLet' v0
        pushGuard (P.prob_ 0 P.< obs P.&&
                   P.prob_ 0 P.< arg1 P.&&
                   P.prob_ 0 P.< arg2)
        pushWeight (P.densityGamma arg1 arg2 obs)
    go Beta = \(arg1 :* arg2 :* End) -> do
        obs <- emitLet' v0
        pushGuard (P.prob_ 0 P.<= obs P.&&
                   P.prob_ 1 P.>= obs P.&&
                   P.prob_ 0 P.< arg1 P.&&
                   P.prob_ 0 P.< arg2)
        pushWeight (P.densityBeta arg1 arg2 obs)