{-# LANGUAGE CPP
, GADTs
, KindSignatures
, DataKinds
, PolyKinds
, TypeOperators
, Rank2Types
, FlexibleContexts
, MultiParamTypeClasses
, FlexibleInstances
, UndecidableInstances
, EmptyCase
, ScopedTypeVariables
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Evaluation.DisintegrationMonad
(
getStatements, putStatements
, ListContext(..), Ans, Dis(..), runDis, runDisInCtx
, bot
, emit
, emitMBind , emitMBind2
, emitLet
, emitLet'
, emitUnpair
, emit_
, emitMBind_
, emitGuard
, emitWeight
, emitFork_
, emitSuperpose
, choose
, pushWeight
, pushGuard
, pushPlate
, getIndices
, withIndices
, extendIndices
, selectMore
, permutes
, statementInds
, sizeInnermostInd
, Extra(..)
, getExtras
, putExtras
, insertExtra
, adjustExtra
, mkLoc
, freeLocError
, zipInds
, apply
#ifdef __TRACE_DISINTEGRATE__
, prettyExtra
, prettyExtras
#endif
) where
import Prelude hiding (id, (.))
import Control.Category (Category(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (Monoid(..))
import Data.Functor ((<$>))
import Control.Applicative (Applicative(..))
#endif
import qualified Data.Set as Set (fromList)
import Data.Maybe
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import Control.Applicative (Alternative(..))
import Control.Monad (MonadPlus(..),foldM,guard)
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Number.Nat
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing (Sing(..), sUnMeasure, sUnPair, sUnit)
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumABT
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.Transform (TransformCtx(..), minimalCtx)
import Language.Hakaru.Syntax.ABT
import qualified Language.Hakaru.Syntax.Prelude as P
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.PEvalMonad (ListContext(..))
import Language.Hakaru.Evaluation.Lazy (reifyPair)
#ifdef __TRACE_DISINTEGRATE__
import Debug.Trace (trace, traceM)
import qualified Text.PrettyPrint as PP
import Language.Hakaru.Pretty.Haskell (ppVariable, pretty)
#endif
-- | Read the statements (the heap) of the current 'ListContext'. The
-- context and the extras map are handed to the continuation unchanged.
getStatements :: Dis abt [Statement abt Location 'Impure]
getStatements = Dis $ \_ k ctx ex -> k (statements ctx) ctx ex
-- | Replace the statements of the current 'ListContext' with the given
-- list, keeping the context's counter and the extras map unchanged.
putStatements :: [Statement abt Location 'Impure] -> Dis abt ()
putStatements newStmts =
    Dis (\_ k (ListContext n _) ex -> k () (ListContext n newStmts) ex)
-- | Replace every occurrence of the variable @x@ by the term @e@, walking
-- the whole syntax tree.
--
-- Binders are descended into without comparing them to @x@.
-- NOTE(review): callers presumably guarantee @x@ is not rebound and
-- that @e@ cannot be captured; confirm.
plug :: forall abt a xs b
     .  (ABT Term abt)
     => Variable a
     -> abt '[] a
     -> abt xs b
     -> abt xs b
plug x e = go
  where
    -- Rebuild one (sub)term, substituting @e@ for @x@.
    go :: forall ys c . abt ys c -> abt ys c
    go t =
        case viewABT t of
          Syn s    -> syn $! fmap21 go s
          Var z    ->
              case varEq x z of
                Just Refl -> e
                Nothing   -> t
          Bind _ _ -> caseBind t $ \z t' -> bind z (go t')
-- | Apply every substitution of @rho0@ to @e0@, one variable at a time
-- with 'plug', folding from the left over the association map.
plugs :: forall abt xs a
      .  (ABT Term abt)
      => Assocs (abt '[])
      -> abt xs a
      -> abt xs a
plugs rho0 e0 = F.foldl substOne e0 (unAssocs rho0)
  where
    substOne acc (Assoc x v) = plug x v acc
-- | Residualize a heap around a measure term.
--
-- First the substitution @rho@ is applied to @e0@ with 'plugs'. Then the
-- statements of @ss@ are folded over it from the left. Each step wraps
-- the current term in the code for one statement:
--
-- * 'SBind' becomes an @MBind@ of the statement's body.
-- * 'SLet' is dropped when its variable does not occur free in the term.
--   It is inlined with 'plug' when its body is a variable or a literal.
--   Otherwise it becomes a @Let_@.
-- * 'SGuard' becomes a @Case_@ on the scrutinee. The pattern's branch
--   continues with the term; a wildcard branch rejects.
-- * 'SWeight' becomes a one-element @Superpose_@.
--
-- Statements are pushed at the head of the list ('unsafePush' conses
-- them on), so the most recently pushed statement ends up innermost.
residualizeListContext
    :: forall abt a
    .  (ABT Term abt)
    => ListContext abt 'Impure
    -> Assocs (abt '[])
    -> abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure a)
residualizeListContext ss rho e0 =
#ifdef __TRACE_DISINTEGRATE__
    trace ("e0: " ++ show (pretty e0) ++ "\n"
        ++ show (pretty_Statements (statements ss))) $
#endif
    foldl step (plugs rho e0) (statements ss)
  where
    -- Wrap the term built so far in the residual code of one statement.
    step
        :: abt '[] ('HMeasure a)
        -> Statement abt Location 'Impure
        -> abt '[] ('HMeasure a)
    step e s =
#ifdef __TRACE_DISINTEGRATE__
        trace ("wrapping " ++ show (ppStatement 0 s) ++ "\n"
            ++ "around term " ++ show (pretty e)) $
#endif
        case s of
          SBind (Location x) body _ ->
              syn (MBind :$ plugs rho (fromLazy body) :* bind x e :* End)
          SLet (Location x) body _
              -- The bound variable is unused: drop the let entirely.
              | not (x `memberVarSet` freeVars e) ->
#ifdef __TRACE_DISINTEGRATE__
                  trace ("could not find location " ++ show x ++ "\n"
                      ++ "in term " ++ show (pretty e) ++ "\n"
                      ++ "given rho " ++ show (prettyAssocs rho)) $
#endif
                  e
              -- Inline cheap bodies (variables, literals); let-bind the rest.
              | otherwise ->
                  case getLazyVariable body of
                    Just y -> plug x (plugs rho (var y)) e
                    Nothing ->
                      case getLazyLiteral body of
                        Just v -> plug x (syn $ Literal_ v) e
                        Nothing ->
                          syn (Let_ :$ plugs rho (fromLazy body) :* bind x e :* End)
          SGuard xs pat scrutinee _ ->
              syn $ Case_ (plugs rho $ fromLazy scrutinee)
                  [ Branch pat (binds_ (fromLocations1 xs) e)
                  , Branch PWild (P.reject $ typeOf e)
                  ]
          SWeight body _ -> syn $ Superpose_ ((plugs rho $ fromLazy body, e) :| [])
-- | Extra information recorded for a variable created by 'mkLoc'.
-- @Loc l inds@ says the variable refers to location @l@ at the index
-- expressions @inds@. 'convertLocs' later turns it into the residual
-- array term for @l@ applied to those indices.
data Extra :: (Hakaru -> *) -> Hakaru -> * where
    Loc :: Location a -> [ast 'HNat] -> Extra ast a
-- | The index expressions stored in an 'Extra' entry.
extrasInds :: Extra ast a -> [ast 'HNat]
extrasInds extra =
    case extra of
      Loc _ is -> is
-- | Add one more index expression at the front of a list of indices.
selectMore :: [ast 'HNat] -> ast 'HNat -> [ast 'HNat]
selectMore inds i = i : inds
-- | Decide whether the index terms @ts@ are exactly the index variables
-- of @inds@, up to reordering.
--
-- The check requires three things:
--
-- * every term is a bare variable (as seen by 'caseVarSyn');
-- * both lists have the same length;
-- * both lists contain the same set of variables.
--
-- When the variables of @inds@ are distinct (as lists built with
-- 'extendIndices' are), this means @ts@ is a permutation of @inds@.
permutes :: (ABT Term abt)
         => [abt '[] 'HNat]
         -> [Index (abt '[])]
         -> Bool
permutes ts inds =
    -- 'mapM' in 'Maybe' yields 'Nothing' as soon as some term is not a
    -- variable, so no partial 'fromJust' is needed afterwards.
    case mapM asVar ts of
      Nothing -> False
      Just vs ->
          length vs == length inds
          && Set.fromList vs == Set.fromList (map indVar inds)
  where
    asVar t = caseVarSyn t Just (const Nothing)
#ifdef __TRACE_DISINTEGRATE__
-- | Debug rendering of one 'Extra' entry (only compiled when
-- __TRACE_DISINTEGRATE__ is defined).
prettyExtra :: (ABT Term abt) => Extra (abt '[]) a -> PP.Doc
prettyExtra (Loc (Location x) inds) = PP.text "Loc" PP.<+> ppVariable x
    PP.<+> ppList (map pretty inds)

-- | Debug rendering of the whole extras map, one variable per line.
prettyExtras :: (ABT Term abt)
             => Assocs (Extra (abt '[]))
             -> PP.Doc
prettyExtras a = PP.vcat $ map go (fromAssocs a)
  where
    go (Assoc x l) = ppVariable x PP.<+>
                     PP.text "->" PP.<+>
                     prettyExtra l
#endif
-- | The answer type of 'Dis'. Given the current heap and the extras map,
-- it produces a list of residual measure terms. The 'Alternative'
-- instance concatenates these lists.
type Ans abt a
    =  ListContext abt 'Impure
    -> Assocs (Extra (abt '[]))
    -> [abt '[] ('HMeasure a)]

-- | The disintegration monad, in continuation-passing style. The first
-- argument is the list of indices currently in scope (read by
-- 'getIndices', replaced by 'withIndices'). The second argument is the
-- continuation.
newtype Dis abt x =
    Dis { unDis :: forall a. [Index (abt '[])] -> (x -> Ans abt a) -> Ans abt a }
-- | Run a 'Dis' computation starting from an empty heap, no indices and
-- an empty extras map.
--
-- Once the computation returns, 'residualizeLocs' turns locations into
-- variables. The final heap is then residualized around @dirac@ of the
-- result. Fresh identifiers start above both the free variables of @es@
-- and the counter of the transformation context.
runDisInCtx
    :: (ABT Term abt, F.Foldable f)
    => TransformCtx
    -> Dis abt (abt '[] a)
    -> f (Some2 abt)
    -> [abt '[] ('HMeasure a)]
runDisInCtx ctx d es =
    unDis (d >>= residualizeLocs) [] finish (ListContext firstFree []) emptyAssocs
  where
    firstFree = max (maxNextFree es) (nextFreeVar ctx)
    finish (e, rho) heap _ =
        [residualizeListContext heap rho (syn (Dirac :$ e :* End))]
-- | 'runDisInCtx' with the 'minimalCtx' transformation context.
runDis
    :: (ABT Term abt, F.Foldable f)
    => Dis abt (abt '[] a)
    -> f (Some2 abt)
    -> [abt '[] ('HMeasure a)]
runDis d es = runDisInCtx minimalCtx d es
-- | Prepare the heap for residualization.
--
-- Every statement is rewritten with 'residualizeLoc', which lifts
-- indexed statements to whole-array statements. The map from the old
-- locations to their new names is collected along the way. 'convertLocs'
-- turns that map into a substitution from location variables to
-- residual terms. Returns @e@ unchanged, together with that substitution.
residualizeLocs :: forall a abt. (ABT Term abt)
                => abt '[] a
                -> Dis abt (abt '[] a, Assocs (abt '[]))
residualizeLocs e = do
    ss <- getStatements
    (ss', newlocs) <- foldM step ([], emptyLAssocs) ss
    rho <- convertLocs newlocs
    -- 'step' conses, so the rewritten list is in reverse order.
    putStatements (reverse ss')
#ifdef __TRACE_DISINTEGRATE__
    trace ("residualizeLocs: old heap:\n" ++ show (pretty_Statements ss )) $ return ()
    trace ("residualizeLocs: new heap:\n" ++ show (pretty_Statements ss')) $ return ()
    extras <- getExtras
    traceM ("oldlocs:\n" ++ show (prettyExtras extras) ++ "\n")
    traceM ("new assoc for renaming:\n" ++ show (prettyAssocs rho))
#endif
    return (e, rho)
  where
    -- Rewrite one statement and merge its location-to-name entries.
    step (ss',newlocs) s = do
        (s',newlocs') <- residualizeLoc s
        return (s':ss', insertLAssocs newlocs' newlocs)
-- | A variable name (hint and identifier) whose type index is phantom.
data Name (a :: Hakaru) = Name {nameHint :: Text, nameID :: Nat}

-- | The name of a location, keeping only its hint and identifier.
locName :: Location a -> Name b
locName (Location v) = Name { nameHint = varHint v, nameID = varID v }
-- | Rewrite one heap statement so it no longer depends on indices. The
-- result is paired with the map from the original location(s) to the
-- name of whatever now binds them.
--
-- * 'SBind' and 'SLet' are lifted with 'reifyStatement'.
-- * 'SWeight' is first recast as an 'SBind' of a fresh unit-typed
--   location to @weight w@, then lifted the same way.
-- * 'SGuard' is kept as is when it has no indices. A guard under an
--   array is not supported.
residualizeLoc :: (ABT Term abt)
               => Statement abt Location 'Impure
               -> Dis abt (Statement abt Location 'Impure, LAssocs Name)
residualizeLoc s =
    case s of
      SBind l _ _ ->
          (\(s', nm) -> (s', singletonLAssocs l nm)) <$> reifyStatement s
      SLet l _ _ ->
          (\(s', nm) -> (s', singletonLAssocs l nm)) <$> reifyStatement s
      SWeight w inds -> do
          u <- freshVar Text.empty sUnit
          let loc = Location u
          (s', nm) <- reifyStatement (SBind loc (Thunk $ P.weight (fromLazy w)) inds)
          return (s', singletonLAssocs loc nm)
      SGuard ls _ _ ixs
          | null ixs  -> return (s, toLAssocs1 ls (fmap11 locName ls))
          | otherwise -> error "undefined: case statement under an array"
-- | Lift an indexed 'SBind' or 'SLet' out of its indices.
--
-- Indices are consumed from the head of the list. Each one wraps the
-- body in a @plate@ (for 'SBind') or an @array@ (for 'SLet') over that
-- index, and the result is bound to a fresh array-typed location. When
-- no indices remain, the statement is returned with the name of the
-- location it binds. 'SWeight' and 'SGuard' are rejected.
reifyStatement :: (ABT Term abt)
               => Statement abt Location 'Impure
               -> Dis abt (Statement abt Location 'Impure, Name a)
reifyStatement s =
    case s of
      SBind l body (i:is) -> do
          l' <- Location <$> freshVar (locHint l) (SArray (locType l))
          let body' = P.plateWithVar (indSize i) (indVar i) (fromLazy body)
          reifyStatement (SBind l' (Thunk body') is)
      SBind l _ [] -> return (s, locName l)
      SLet l body (i:is) -> do
          l' <- Location <$> freshVar (locHint l) (SArray (locType l))
          let body' = P.arrayWithVar (indSize i) (indVar i) (fromLazy body)
          reifyStatement (SLet l' (Thunk body') is)
      SLet l _ [] -> return (s, locName l)
      SWeight _ _ -> error "reifyStatement called on SWeight"
      SGuard _ _ _ _ -> error "reifyStatement called on SGuard"
-- | Find the statement binding location @l@ and return the size of its
-- innermost index, i.e. the head of its index list. The statement is
-- pushed back onto the heap unchanged.
--
-- A binding statement without indices is skipped by the search, as if
-- it did not bind @l@. An unindexed location therefore ends in
-- 'freeLocError'. Guards are not supported yet.
sizeInnermostInd :: (ABT Term abt)
                 => Location (a :: Hakaru)
                 -> Dis abt (abt '[] 'HNat)
sizeInnermostInd l =
    (maybe (freeLocError l) return =<<) . select l $ \s -> do
        -- O(1) emptiness test; the original used an O(n) 'length'.
        guard (not (null (statementInds s)))
        case s of
          SBind l' _ (ix:_) -> do
              Refl <- locEq l l'
              -- 'select' does not restore the matched statement.
              Just $ unsafePush s >> return (indSize ix)
          SLet l' _ (ix:_) -> do
              Refl <- locEq l l'
              Just $ unsafePush s >> return (indSize ix)
          SGuard _ _ _ _ -> error "TODO: sizeInnermostInd{SGuard}"
          -- 'SWeight' binds nothing. Unindexed 'SBind'/'SLet' were
          -- already rejected by the guard above.
          _ -> Nothing
-- | The term denoted by a 'Name' with element type @typ@ at the given
-- indices.
--
-- With no indices it is just the variable. Each index adds one 'SArray'
-- layer to the variable's type, and the head of the list is applied
-- last. For example, @fromName n t [i1, i2]@ is @(x ! i2) ! i1@ for a
-- variable @x@ of type @array (array t)@.
fromName :: (ABT Term abt)
         => Name b
         -> Sing a
         -> [abt '[] 'HNat]
         -> abt '[] a
fromName name typ inds =
    case inds of
      []     -> var (Variable { varHint = nameHint name
                              , varID   = nameID name
                              , varType = typ })
      i : is -> fromName name (SArray typ) is P.! i
-- | Build the substitution mapping every variable in the extras map to
-- its residual term.
--
-- Each variable's location is looked up in @newlocs@. The resulting name
-- is turned into a term with 'fromName', applied to the variable's
-- recorded indices. A location missing from @newlocs@ is reported with
-- 'freeLocError'.
convertLocs :: (ABT Term abt)
            => LAssocs Name
            -> Dis abt (Assocs (abt '[]))
convertLocs newlocs = do
    extras <- getExtras
    return (F.foldr addEntry emptyAssocs (fromAssocs extras))
  where
    addEntry (Assoc x extra@(Loc l _)) =
        insertAssoc $
            case lookupLAssoc l newlocs of
              Nothing   -> freeLocError l
              Just name -> Assoc x (fromName name (varType x) (extrasInds extra))
-- | Abort, reporting a location whose binding statement could not be
-- found.
freeLocError :: Location (a :: Hakaru) -> b
freeLocError loc = error ("Found a free location " ++ show loc)
-- | Pair each index variable with an index term, giving a substitution.
-- The two lists must have the same length.
zipInds :: (ABT Term abt)
        => [Index (abt '[])] -> [abt '[] 'HNat] -> Assocs (abt '[])
zipInds inds ts =
    if length inds /= length ts
        then error "zipInds: argument lists must have the same length"
        else toAssocs [ Assoc (indVar i) t | (i, t) <- zip inds ts ]
-- | Substitute the index terms of @js@ for the index variables of @is@
-- (pairwise, via 'zipInds') in @e@, using 'extSubsts'.
apply :: (ABT Term abt)
      => [Index (abt '[])]
      -> [Index (abt '[])]
      -> abt '[] a
      -> Dis abt (abt '[] a)
apply is js = extSubsts (zipInds is (map fromIndex js))
-- | Push a new innermost index onto the front of an index list.
--
-- Adding an index that is already in the list is an error. The message
-- names the offending index variable and the variables already present;
-- previously the text was truncated and named nothing.
extendIndices
    :: (ABT Term abt)
    => Index (abt '[])
    -> [Index (abt '[])]
    -> [Index (abt '[])]
extendIndices j js
    | j `elem` js =
        error ("extendIndices: duplicate index " ++ show (indVar j)
               ++ " already in " ++ show (map indVar js))
    | otherwise = j : js
-- | The index list recorded on a statement, whatever its kind.
statementInds :: Statement abt Location p -> [Index (abt '[])]
statementInds stmt =
    case stmt of
      SBind _ _ is    -> is
      SLet _ _ is     -> is
      SWeight _ is    -> is
      SGuard _ _ _ is -> is
      SStuff0 _ is    -> is
      SStuff1 _ _ is  -> is
-- | Read the extras map (location variables and their indices).
getExtras :: (ABT Term abt)
          => Dis abt (Assocs (Extra (abt '[])))
getExtras = Dis $ \_ k ctx ex -> k ex ctx ex
-- | Replace the extras map, leaving the heap untouched.
putExtras :: (ABT Term abt)
          => Assocs (Extra (abt '[]))
          -> Dis abt ()
putExtras newExtras = Dis $ \_ k ctx _ -> k () ctx newExtras
-- | Record an 'Extra' entry for a variable in the extras map.
insertExtra :: (ABT Term abt)
            => Variable a
            -> Extra (abt '[]) a
            -> Dis abt ()
insertExtra v extra = getExtras >>= putExtras . insertAssoc (Assoc v extra)
-- | Update the extras entry of variable @x@ with @f@ (via 'adjustAssoc').
adjustExtra :: (ABT Term abt)
            => Variable (a :: Hakaru)
            -> (Assoc (Extra (abt '[])) -> Assoc (Extra (abt '[])))
            -> Dis abt ()
adjustExtra x f = getExtras >>= putExtras . adjustAssoc x f
-- | Create a fresh variable standing for location @l@ at the index
-- expressions @inds@. The association is recorded in the extras map and
-- the variable is returned.
mkLoc
    :: (ABT Term abt)
    => Text
    -> Location (a :: Hakaru)
    -> [abt '[] 'HNat]
    -> Dis abt (Variable a)
mkLoc hint l inds =
    freshVar hint (locType l) >>= \x ->
    insertExtra x (Loc l inds) >> return x
-- | 'mkLoc' (with an empty hint) for every location of a list, all at
-- the same index expressions. Locations are processed left to right.
mkLocs
    :: (ABT Term abt)
    => List1 Location (xs :: [Hakaru])
    -> [abt '[] 'HNat]
    -> Dis abt (List1 Variable xs)
mkLocs Nil1 _ = return Nil1
mkLocs (Cons1 l ls) inds = do
    v  <- mkLoc Text.empty l inds
    vs <- mkLocs ls inds
    return (Cons1 v vs)
-- | Map over the result by pre-composing the continuation.
instance Functor (Dis abt) where
    fmap f m = Dis $ \i k -> unDis m i (k . f)
-- | Both computations see the same indices; the function runs first.
instance Applicative (Dis abt) where
    pure x = Dis (\_ k -> k x)
    mf <*> mx = Dis $ \i k -> unDis mf i (\f -> unDis mx i (k . f))
-- | Sequencing passes the same indices to both stages.
instance Monad (Dis abt) where
    return = pure
    m >>= f = Dis $ \i k -> unDis m i (\x -> unDis (f x) i k)
-- | Failure yields no residual programs; choice concatenates the
-- residual programs of both branches, run from the same state.
instance Alternative (Dis abt) where
    empty = Dis (\_ _ _ _ -> [])
    m <|> n = Dis $ \i k ctx ex -> unDis m i k ctx ex ++ unDis n i k ctx ex
-- | Same as the 'Alternative' instance.
instance MonadPlus (Dis abt) where
    mzero     = empty
    mplus m n = m <|> n
-- | 'Dis' evaluates over impure heaps, whose statements may bind
-- measures and record weights and guards.
instance (ABT Term abt) => EvaluationMonad abt (Dis abt) 'Impure where
    -- Hand out the counter kept in the 'ListContext' and bump it.
    freshNat =
        Dis $ \_ c (ListContext n ss) ->
            c n (ListContext (n+1) ss)

    -- Rename the variables bound by a statement to fresh locations. Each
    -- one also gets a fresh location variable in the extras (with the
    -- statement's indices), and the returned 'Assocs' maps the old
    -- variable to it. 'SWeight' binds nothing.
    freshLocStatement s =
        case s of
          SWeight w e -> return (SWeight w e, mempty)
          SBind x body i -> do
              x' <- freshenVar x
              let l = Location x'
              v <- mkLoc (locHint l) l (map fromIndex i)
              return (SBind l body i, singletonAssocs x v)
          SLet x body i -> do
              x' <- freshenVar x
              let l = Location x'
              v <- mkLoc (locHint l) l (map fromIndex i)
              return (SLet l body i, singletonAssocs x v)
          SGuard xs pat scrutinee i -> do
              xs' <- freshenVars xs
              let ls = locations1 xs'
              vs <- mkLocs ls (map fromIndex i)
              return (SGuard ls pat scrutinee i, toAssocs1 xs vs)

    getIndices = Dis $ \i c -> c i

    -- Put one statement on top of the heap.
    unsafePush s =
        Dis $ \_ c (ListContext i ss) ->
            c () (ListContext i (s:ss))

    -- Push several statements; the last element of the list ends on top.
    unsafePushes ss =
        Dis $ \_ c (ListContext i ss') ->
            c () (ListContext i (reverse ss ++ ss'))

    -- Pop statements until one binds @l@ and @p@ accepts it. The
    -- statements popped before it are pushed back afterwards. The
    -- matching statement itself is NOT restored: the action returned by
    -- @p@ must push back whatever should take its place.
    select l p = loop []
      where
        loop ss = do
            ms <- unsafePop
            case ms of
              Nothing -> do
                  unsafePushes ss
                  return Nothing
              Just s ->
                  case l `isBoundBy` s >> p s of
                    Nothing -> loop (s:ss)
                    Just mr -> do
                        r <- mr
                        unsafePushes ss
                        return (Just r)

    -- A location variable @z@ whose indices mention @x@ is replaced by a
    -- fresh location variable whose indices have @x := e@ applied.
    -- Any other @z@ is left alone.
    substVar x e z = do
        extras <- getExtras
        let defaultResult = return (var z)
        case lookupAssoc z extras of
          Nothing -> defaultResult
          Just (Loc l inds) ->
              if any (memberVarSet x . freeVars) inds
                  then do inds' <- mapM (extSubst x e) inds
                          var <$> mkLoc Text.empty l inds'
                  else defaultResult

    -- Free variables of @e@, seen through location variables. The free
    -- variables of a location variable's indices are added first. Only
    -- variables registered in the extras are kept, each replaced by its
    -- location's variable.
    extFreeVars e = do
        extras <- getExtras
        let fvs1 = freeVars e
            indFVs (SomeVariable v) =
                case lookupAssoc v extras of
                  Nothing -> emptyVarSet
                  Just (Loc _ is) -> foldr (unionVarSet . freeVars) emptyVarSet is
            locVars (SomeVariable v) b =
                case lookupAssoc v extras of
                  Nothing -> b
                  Just (Loc l _) -> insertVarSet (fromLocation l) b
            fvs2 = foldr (unionVarSet . indFVs) fvs1 (fromVarSet fvs1)
        return $ foldr locVars emptyVarSet (fromVarSet fvs2)

    -- Offer both the default case evaluation and a branch-by-branch
    -- evaluation. In the latter, each branch pushes a guard on the
    -- scrutinee and continues with the branch body.
    evaluateCase evaluate_ = evaluateCase_
      where
        evaluateCase_ :: CaseEvaluator abt (Dis abt)
        evaluateCase_ e bs =
            defaultCaseEvaluator evaluate_ e bs
            <|> evaluateBranches e bs

        evaluateBranches :: CaseEvaluator abt (Dis abt)
        evaluateBranches e = choose . map evaluateBranch
          where
            evaluateBranch (Branch pat body) =
                let (vars,body') = caseBinds body
                in getIndices >>= \i ->
                    push (SGuard vars pat (Thunk e) i) body'
                    >>= evaluate_

    -- A variable not in the extras is neutral. Otherwise, find the
    -- statement binding its location:
    --
    --   * An 'SBind' is performed under its own indices.
    --   * An 'SLet' body is evaluated under its own indices.
    --
    -- In both cases the statement is replaced by an 'SLet' of the result,
    -- which is specialised to the variable's indices @jxs@ and evaluated.
    evaluateVar perform evaluate_ x =
        do extras <- getExtras
           maybe (return $ Neutral (var x)) lookForLoc (lookupAssoc x extras)
      where
        lookForLoc (Loc l jxs) =
            (maybe (freeLocError l) return =<<) . select l $ \s ->
                case s of
                  SBind l' e ixs -> do
                      Refl <- locEq l l'
                      Just $ do
                          w <- withIndices ixs $ perform (caseLazy e fromWhnf id)
                          unsafePush (SLet l (Whnf_ w) ixs)
#ifdef __TRACE_DISINTEGRATE__
                          trace ("-- updated "
                                 ++ show (ppStatement 11 s)
                                 ++ " to "
                                 ++ show (ppStatement 11 (SLet l (Whnf_ w) ixs))
                                ) $ return ()
#endif
                          extSubsts (zipInds ixs jxs) (fromWhnf w) >>= evaluate_
                  SLet l' e ixs -> do
                      Refl <- locEq l l'
                      Just $ do
                          w <- withIndices ixs $ caseLazy e return evaluate_
                          unsafePush (SLet l (Whnf_ w) ixs)
                          extSubsts (zipInds ixs jxs) (fromWhnf w) >>= evaluate_
                  SWeight _ _ -> Nothing
                  -- A guard binds @l@ to a pattern variable, which has no
                  -- expression to evaluate, so the result stays neutral.
                  -- 'select' has popped the guard, so it must be pushed
                  -- back. Otherwise it vanishes from the heap and its
                  -- locations become free.
                  SGuard _ _ _ _ -> Just $ unsafePush s >> return (Neutral (var x))
-- | Run a computation with @inds@ as its index list, ignoring the indices
-- currently in scope.
withIndices :: [Index (abt '[])] -> Dis abt a -> Dis abt a
withIndices inds m = Dis $ \_ k -> unDis m inds k
-- | Remove and return the top statement of the heap. On an empty heap it
-- returns 'Nothing' and leaves the heap as is.
unsafePop :: Dis abt (Maybe (Statement abt Location 'Impure))
unsafePop =
    Dis $ \_ k ctx@(ListContext n stmts) ex ->
        case stmts of
          top : rest -> k (Just top) (ListContext n rest) ex
          []         -> k Nothing ctx ex
-- | Push a @plate@ of size @n@ onto the heap as an indexed 'SBind'.
--
-- A fresh index of size @n@ is added (innermost) to the current indices.
-- The body, with its bound variable renamed to that index, is bound at
-- a fresh location. The result is built with @P.arrayWithVar@ over the
-- index, from a location variable for that location at the extended
-- indices.
pushPlate
    :: (ABT Term abt)
    => abt '[] 'HNat
    -> abt '[ 'HNat ] ('HMeasure a)
    -> Dis abt (abt '[] ('HArray a))
pushPlate n e =
    caseBind e $ \x body -> do
        outer <- getIndices
        i     <- freshInd n
        loc   <- Location <$> freshVar Text.empty (sUnMeasure $ typeOf body)
        let inds = extendIndices i outer
            ix   = indVar i
        unsafePush (SBind loc (Thunk $ rename x ix body) inds)
        v <- mkLoc Text.empty loc (map fromIndex inds)
        return (P.arrayWithVar n ix (var v))
-- | Record a weight statement on the heap at the current indices.
pushWeight :: (ABT Term abt) => abt '[] 'HProb -> Dis abt ()
pushWeight w = getIndices >>= unsafePush . SWeight (Thunk w)
-- | Record on the heap, at the current indices, a guard requiring @b@ to
-- match the @true@ pattern.
pushGuard :: (ABT Term abt) => abt '[] HBool -> Dis abt ()
pushGuard b = getIndices >>= unsafePush . SGuard Nil1 pTrue (Thunk b)
-- | The failing computation: the continuation is never called and no
-- residual program is produced.
bot :: (ABT Term abt) => Dis abt a
bot = Dis (\_ _ _ _ -> [])
-- | Allocate a fresh variable of type @typ@ and return it. Every residual
-- program produced by the rest of the computation is then wrapped with
-- @f@, after binding the variable in it.
emit
    :: (ABT Term abt)
    => Text
    -> Sing a
    -> (forall r. abt '[a] ('HMeasure r) -> abt '[] ('HMeasure r))
    -> Dis abt (Variable a)
emit hint typ f = do
    x <- freshVar hint typ
    Dis $ \_ k ctx ex -> map (f . bind x) (k x ctx ex)
-- | Sample from @m@ in the residual program with an @MBind@, returning
-- the fresh variable bound to the sample.
emitMBind :: (ABT Term abt) => abt '[] ('HMeasure a) -> Dis abt (Variable a)
emitMBind m =
    emit Text.empty (sUnMeasure (typeOf m)) (\rest -> syn (MBind :$ m :* rest :* End))
-- | Sample from @m@ in the residual program, lifted over all the
-- current indices.
--
-- An indexed 'SBind' of @m@ is lifted to whole arrays with
-- 'reifyStatement' and emitted as an @MBind@ around the residual
-- program. The returned term reads the sample back at the current
-- indices (see 'fromName').
--
-- @m@ must already be in weak-head normal form ('toWhnf'); otherwise
-- this fails with an error.
emitMBind2 :: (ABT Term abt) => abt '[] ('HMeasure a) -> Dis abt (abt '[] a)
emitMBind2 m = do
    inds <- getIndices
    let b   = Whnf_ $ fromMaybe (error "emitMBind2: non-hnf term") (toWhnf m)
        typ = sUnMeasure $ typeOf m
    l <- Location <$> freshVar Text.empty typ
    (s', name) <- reifyStatement (SBind l b inds)
    -- Match explicitly instead of with a failable do-pattern. This module
    -- defines no 'MonadFail' instance for 'Dis', which GHC >= 8.8 would
    -- require for such a pattern.
    case s' of
      SBind l' b' _ -> do
          let idx = fromName name typ (map fromIndex inds)
              p   = fromLazy b'
          Dis $ \_ c h ex ->
              c idx h ex >>= \e ->
              return $ syn (MBind :$ p :* bind (fromLocation l') e :* End)
      _ -> error "emitMBind2: reifyStatement did not return an SBind"
-- | Let-bind @e@ to a fresh variable in the residual program and return
-- that variable. A term that is already a variable is returned as is.
emitLet :: (ABT Term abt) => abt '[] a -> Dis abt (Variable a)
emitLet e =
    caseVarSyn e return $ \_ ->
        emit Text.empty (typeOf e) (\rest -> syn (Let_ :$ e :* rest :* End))
-- | Like 'emitLet', but returns a term. Variables and literals are
-- returned unchanged. Anything else is let-bound, and the fresh variable
-- is returned.
emitLet' :: (ABT Term abt) => abt '[] a -> Dis abt (abt '[] a)
emitLet' e =
    caseVarSyn e (\_ -> return e) $ \t ->
        case t of
          Literal_ _ -> return e
          _ -> var <$> emit Text.empty (typeOf e) (\rest -> syn (Let_ :$ e :* rest :* End))
-- | Split a pair into its components. A pair in head normal form is taken
-- apart with 'reifyPair'. A neutral pair is destructured in the residual
-- program by 'emitUnpair_', using two fresh variables.
emitUnpair
    :: (ABT Term abt)
    => Whnf abt (HPair a b)
    -> Dis abt (abt '[] a, abt '[] b)
emitUnpair (Head_ w) = return (reifyPair w)
emitUnpair (Neutral e) = do
    let (ta, tb) = sUnPair (typeOf e)
    x <- freshVar Text.empty ta
    y <- freshVar Text.empty tb
    emitUnpair_ x y e
-- | Destructure a pair-valued term, naming the components @x@ and @y@.
--
-- A literal pair constructor ('Datum_') is split directly with
-- 'reifyPair'. For a 'Case_' the split is pushed into every branch via
-- 'emitCaseWith'. Any other term (including a variable) is matched in
-- the residual program by a one-branch case on @pair x y@, and the
-- components are returned as @var x@ and @var y@.
emitUnpair_
    :: forall abt a b
    .  (ABT Term abt)
    => Variable a
    -> Variable b
    -> abt '[] (HPair a b)
    -> Dis abt (abt '[] a, abt '[] b)
emitUnpair_ x y = loop
  where
    -- Fallback: residualize a case that binds @x@ and @y@ around every
    -- program produced by the continuation.
    done :: abt '[] (HPair a b) -> Dis abt (abt '[] a, abt '[] b)
    done e =
#ifdef __TRACE_DISINTEGRATE__
        trace "-- emitUnpair: done (term is not Datum_ nor Case_)" $
#endif
        Dis $ \_ c h l ->
            ( syn
            . Case_ e
            . (:[])
            . Branch (pPair PVar PVar)
            . bind x
            . bind y
            ) <$> c (var x, var y) h l

    -- Inspect the term's head to pick the cheapest way to split it.
    loop :: abt '[] (HPair a b) -> Dis abt (abt '[] a, abt '[] b)
    loop e0 =
        caseVarSyn e0 (done . var) $ \t ->
            case t of
              Datum_ d -> do
#ifdef __TRACE_DISINTEGRATE__
                  trace "-- emitUnpair: found Datum_" $ return ()
#endif
                  return $ reifyPair (WDatum d)
              Case_ e bs -> do
#ifdef __TRACE_DISINTEGRATE__
                  trace "-- emitUnpair: going under Case_" $ return ()
#endif
                  emitCaseWith loop e bs
              _ -> done e0
-- | Wrap every residual program produced by the rest of the computation
-- with @f@.
emit_
    :: (ABT Term abt)
    => (forall r. abt '[] ('HMeasure r) -> abt '[] ('HMeasure r))
    -> Dis abt ()
emit_ f = Dis $ \_ k ctx ex -> map f (k () ctx ex)
-- | Sequence the unit-valued measure @m@ before the rest of the residual
-- program (with @P.>>@).
emitMBind_ :: (ABT Term abt) => abt '[] ('HMeasure HUnit) -> Dis abt ()
emitMBind_ m = emit_ (\rest -> m P.>> rest)
-- | Guard the rest of the residual program on @b@ (via @P.withGuard@).
emitGuard :: (ABT Term abt) => abt '[] HBool -> Dis abt ()
emitGuard b = emit_ (\rest -> P.withGuard b rest)
-- | Weight the rest of the residual program by @w@ (via @P.withWeight@).
emitWeight :: (ABT Term abt) => abt '[] 'HProb -> Dis abt ()
emitWeight w = emit_ (\rest -> P.withWeight w rest)
-- | Run every computation in @ms@ from the same indices, continuation,
-- heap and extras, and combine one residual program from each with @f@.
-- The traversal runs in the list applicative, so every combination of
-- alternatives is produced.
emitFork_
    :: (ABT Term abt, T.Traversable t)
    => (forall r. t (abt '[] ('HMeasure r)) -> abt '[] ('HMeasure r))
    -> t (Dis abt a)
    -> Dis abt a
emitFork_ f ms =
    Dis $ \i k ctx ex -> map f (T.for ms (\m -> unDis m i k ctx ex))
-- | Sample from the superposition of the given measures, each with weight
-- 'P.one', binding the result to a fresh variable. A single measure is
-- bound directly. An empty list is not supported yet.
emitSuperpose
    :: (ABT Term abt)
    => [abt '[] ('HMeasure a)]
    -> Dis abt (Variable a)
emitSuperpose es =
    case es of
      []       -> error "TODO: emitSuperpose[]"
      [e]      -> emitMBind e
      e : rest -> emitMBind (P.superpose (NE.map (\m -> (P.one, m)) (e :| rest)))
-- | Nondeterministic choice, residualized as a superposition (weight
-- 'P.one' each) via 'emitFork_'. A single alternative is run directly.
-- An empty list is not supported yet.
choose :: (ABT Term abt) => [Dis abt a] -> Dis abt a
choose ms =
    case ms of
      []  -> error "TODO: choose[]"
      [m] -> m
      _   -> emitFork_ (P.superpose . NE.map ((,) P.one) . NE.fromList) ms
-- | Residualize a case on @e@ with branches @bs@, running @f@ on each
-- branch body.
--
-- Each branch's pattern variables are renamed to fresh variables before
-- @f@ sees the body. Every branch's computation is run from the same
-- indices, continuation, heap and extras. The results are assembled into
-- a @Case_@ via 'GBranch' and 'fromGBranch'. 'T.for' runs in the list
-- applicative, so every combination of per-branch alternatives yields
-- one residual case.
emitCaseWith
    :: (ABT Term abt)
    => (abt '[] b -> Dis abt r)
    -> abt '[] a
    -> [Branch a abt b]
    -> Dis abt r
emitCaseWith f e bs = do
    gms <- T.for bs $ \(Branch pat body) ->
        let (vars, body') = caseBinds body
        in  (\vars' ->
                let rho = toAssocs1 vars vars'
                in  GBranch pat vars' (f $ renames rho body')
            ) <$> freshenVars vars
    Dis $ \i c h l ->
        (syn . Case_ e) <$> T.for gms (\gm ->
            fromGBranch <$> T.for gm (\m ->
                unDis m i c h l))
{-# INLINE emitCaseWith #-}