{-# LANGUAGE CPP , GADTs , KindSignatures , DataKinds , PolyKinds , TypeOperators , Rank2Types , FlexibleContexts , MultiParamTypeClasses , FlexibleInstances , UndecidableInstances , EmptyCase , ScopedTypeVariables #-} {-# OPTIONS_GHC -Wall -fwarn-tabs #-} ---------------------------------------------------------------- -- 2016.05.24 -- | -- Module : Language.Hakaru.Evaluation.DisintegrationMonad -- Copyright : Copyright (c) 2016 the Hakaru team -- License : BSD3 -- Maintainer : wren@community.haskell.org -- Stability : experimental -- Portability : GHC-only -- -- The 'EvaluationMonad' for "Language.Hakaru.Disintegrate" ---------------------------------------------------------------- module Language.Hakaru.Evaluation.DisintegrationMonad ( -- * The disintegration monad -- ** List-based version getStatements, putStatements , ListContext(..), Ans, Dis(..), runDis -- ** TODO: IntMap-based version -- * Operators on the disintegration monad -- ** The \"zero\" and \"one\" , bot --, reject -- ** Emitting code , emit , emitMBind , emitLet , emitLet' , emitUnpair -- TODO: emitUneither -- emitCaseWith , emit_ , emitMBind_ , emitGuard , emitWeight , emitFork_ , emitSuperpose , choose -- * Overrides for original in Evaluation.Types , pushPlate -- * For Arrays/Plate , getIndices , withIndices , extendIndices , extendLocInds , statementInds , sizeInnermostInd -- * Locs , Loc(..) 
, getLocs, putLocs, insertLoc, adjustLoc, mkLoc, freeLocError, apply
#ifdef __TRACE_DISINTEGRATE__
    , prettyLoc
    , prettyLocs
#endif
    ) where

import           Prelude              hiding (id, (.))
import           Control.Category     (Category(..))
#if __GLASGOW_HASKELL__ < 710
import           Data.Monoid          (Monoid(..))
import           Data.Functor         ((<$>))
import           Control.Applicative  (Applicative(..))
#endif
import           Data.Maybe
import qualified Data.Foldable        as F
import qualified Data.Traversable     as T
import           Data.List.NonEmpty   (NonEmpty(..))
import qualified Data.List.NonEmpty   as NE
import           Control.Applicative  (Alternative(..))
import           Control.Monad        (MonadPlus(..),foldM,guard)
import           Data.Text            (Text)
import qualified Data.Text            as Text
import           Data.Number.Nat

import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing    (Sing(..), sUnMeasure, sUnPair, sUnit)
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumABT
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.ABT
import qualified Language.Hakaru.Syntax.Prelude as P
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.PEvalMonad (ListContext(..))
import Language.Hakaru.Evaluation.Lazy (reifyPair)

#ifdef __TRACE_DISINTEGRATE__
import Debug.Trace (trace, traceM)
import qualified Text.PrettyPrint     as PP
import Language.Hakaru.Pretty.Haskell (ppVariable, pretty)
#endif

-- | Read the statements currently stored in the list context,
-- leaving both the context and the location table untouched.
getStatements :: Dis abt [Statement abt 'Impure]
getStatements =
    Dis $ \_ k ctx locs -> k (statements ctx) ctx locs

-- | Overwrite the statements stored in the list context. The first
-- field of the 'ListContext' and the location table are passed
-- along unchanged.
putStatements :: [Statement abt 'Impure] -> Dis abt ()
putStatements newStmts =
    Dis $ \_ k (ListContext next _) locs ->
        k () (ListContext next newStmts) locs

----------------------------------------------------------------
----------------------------------------------------------------
-- | Capturing substitution: replace every occurrence of the
-- variable by the given term, walking straight through binders
-- without any freshening.
plug :: forall abt a xs b
    .  (ABT Term abt)
    => Variable a
    -> abt '[] a
    -> abt xs b
    -> abt xs b
plug x e = go
    where
    -- View the term once and dispatch on the resulting shape.
    go :: forall xs' b' . abt xs' b' -> abt xs' b'
    go t = descend t (viewABT t)

    descend
        :: forall xs' b'
        .  abt xs' b'
        -> View (Term abt) xs' b'
        -> abt xs' b'
    descend _ (Syn s) = syn $! fmap21 go s
    descend t (Var z) =
        case varEq x z of
        Nothing   -> t
        Just Refl -> e
    descend t (Bind _ _) =
        caseBind t $ \z body -> bind z (go body)


-- | Perform multiple capturing substitutions, one 'plug' per
-- association, threading the term through a left fold.
plugs :: forall abt xs a
    .  (ABT Term abt)
    => Assocs (abt '[])
    -> abt xs a
    -> abt xs a
plugs rho0 e0 = F.foldl substOne e0 (unAssocs rho0)
    where
    substOne :: abt xs a -> Assoc (abt '[]) -> abt xs a
    substOne acc (Assoc x v) = plug x v acc


-- | Plug a term into a context. That is, the 'statements' of the
-- context specifies a program with a hole in it; so we plug the
-- given term into that hole, returning the complete program.
--
-- The association table is substituted (via 'plugs') into the
-- initial term and into every statement body before that body is
-- wrapped around the accumulated program.
residualizeListContext
    :: forall abt a
    .  (ABT Term abt)
    => ListContext abt 'Impure
    -> Assocs (abt '[])
    -> abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure a)
residualizeListContext ss rho e0 =
    -- N.B., we use a left fold because the head of the list of
    -- statements is the one closest to the hole.
#ifdef __TRACE_DISINTEGRATE__
    trace ("e0: " ++ show (pretty e0) ++ "\n"
          ++ show (pretty_Statements (statements ss))) $
#endif
    foldl wrap (plugs rho e0) (statements ss)
    where
    -- Wrap one statement around the program built so far.
    wrap
        :: abt '[] ('HMeasure a)
        -> Statement abt 'Impure
        -> abt '[] ('HMeasure a)
    wrap acc stmt =
#ifdef __TRACE_DISINTEGRATE__
        trace ("wrapping " ++ show (ppStatement 0 stmt) ++ "\n"
             ++ "around term " ++ show (pretty acc)) $
#endif
        case stmt of
        SWeight body _ ->
            syn $ Superpose_ ((plugs rho $ fromLazy body, acc) :| [])
        SGuard xs pat scrutinee _ ->
            syn $ Case_ (plugs rho $ fromLazy scrutinee)
                [ Branch pat   (binds_ xs acc)
                , Branch PWild (P.reject $ typeOf acc)
                ]
        SBind x body _ ->
            -- TODO: if @body@ is dirac, then treat as 'SLet'
            syn (MBind :$ plugs rho (fromLazy body) :* bind x acc :* End)
        SLet x body _
            -- TODO: if used exactly once in @acc@, then inline.
            | x `memberVarSet` freeVars acc ->
                -- Variables and literals are substituted directly;
                -- anything else becomes a real let-binding.
                case (getLazyVariable body, getLazyLiteral body) of
                (Just y , _      ) -> plug x (plugs rho (var y)) acc
                (Nothing, Just v ) -> plug x (syn $ Literal_ v) acc
                (Nothing, Nothing) ->
                    syn (Let_ :$ plugs rho (fromLazy body)
                              :* bind x acc
                              :* End)
            | otherwise ->
                -- The bound variable is unused, so drop the binding.
#ifdef __TRACE_DISINTEGRATE__
                trace ("could not find location" ++ show x ++ "\n"
                      ++ "in term " ++ show (pretty acc) ++ "\n"
                      ++ "given rho " ++ show (prettyAssocs rho)) $
#endif
                acc

----------------------------------------------------------------
-- A location is a variable *use* instantiated at some list of indices.
data Loc :: (Hakaru -> *) -> Hakaru -> * where
     Loc      :: Variable a -> [Variable 'HNat] -> Loc ast a
     MultiLoc :: Variable a -> [Variable 'HNat] -> Loc ast ('HArray a)

-- | The list of index variables stored in a location, whichever
-- constructor built it.
locIndices :: Loc ast a -> [Variable 'HNat]
locIndices loc =
    case loc of
    Loc      _ inds -> inds
    MultiLoc _ inds -> inds

-- | Put one more index variable at the front of a list of them.
extendLocInds
    :: Variable 'HNat
    -> [Variable 'HNat]
    -> [Variable 'HNat]
extendLocInds i inds = i : inds

#ifdef __TRACE_DISINTEGRATE__
-- | Pretty-print one location: constructor name, variable, indices.
prettyLoc :: Loc ast (a :: Hakaru) -> PP.Doc
prettyLoc loc =
    case loc of
    Loc l inds ->
        PP.text "Loc" PP.<+> ppVariable l
                      PP.<+> ppList (map ppVariable inds)
    MultiLoc l inds ->
        PP.text "MultiLoc" PP.<+> ppVariable l
                           PP.<+> ppList (map ppVariable inds)

-- | Pretty-print a whole location table, one @x -> loc@ per line.
prettyLocs :: (ABT Term abt) => Assocs (Loc (abt '[])) -> PP.Doc
prettyLocs = PP.vcat . map entry . fromAssocs
    where
    entry (Assoc x l) =
        ppVariable x PP.<+> PP.text "->" PP.<+> prettyLoc l
#endif

-- In the paper we say that result must be a 'Whnf'; however, in
-- the paper it's also always @HMeasure a@ and everything of that
-- type is a WHNF (via 'WMeasure') so that's a trivial statement
-- to make. If we turn it back into some sort of normal form, then
-- it must be one preserved by 'residualizeContext'.
--
-- Also, we add the list in order to support "lub" without it living in the AST.
-- TODO: really we should use LogicT...
type Ans abt a
    =  ListContext abt 'Impure     -- the list context the answer is computed in
    -> Assocs (Loc (abt '[]))      -- the table mapping variables to their locations
    -> [abt '[] ('HMeasure a)]     -- zero or more resulting measure terms

----------------------------------------------------------------
-- TODO: defunctionalize the continuation. In particular, the only
-- heap modifications we need are 'push' and a variant of 'update'
-- for finding\/replacing a binding once we have the value in hand;
-- and the only 'freshNat' modifications are to allocate new 'Nat'.
-- We could defunctionalize the second arrow too by relying on the
-- @Codensity (ReaderT e m) ~= StateT e (Codensity m)@ isomorphism,
-- which makes explicit that the only thing other than 'ListContext'
-- updates is emitting something like @[Statement]@ to serve as the
-- beginning of the final result.
--
-- TODO: give this a better, more informative name!
--
-- N.B., This monad is used not only for both 'perform' and
-- 'constrainOutcome', but also for 'constrainValue'.
--
-- A computation receives the list of indices currently in scope
-- and a continuation expecting its result, and produces an 'Ans'.
-- The answer type @a@ is universally quantified, so a computation
-- cannot inspect the answers its continuation returns except
-- through operations that work at every @a@.
newtype Dis abt x =
    Dis { unDis :: forall a. [Index (abt '[])] -> (x -> Ans abt a) -> Ans abt a }
    -- == @Codensity (Ans abt)@, assuming 'Codensity' is poly-kinded like it should be
    -- If we don't want to allow continuations that can make
    -- nondeterministic choices, then we should use the right Kan
    -- extension itself, rather than the Codensity specialization of it.


-- | Run a computation in the 'Dis' monad, residualizing out all the
-- statements in the final evaluation context. The second argument
-- should include all the terms altered by the 'Dis' expression; this
-- is necessary to ensure proper hygiene; for example(s):
--
-- > runDis (perform e) [Some2 e]
-- > runDis (constrainOutcome e v) [Some2 e, Some2 v]
--
-- We use 'Some2' on the inputs because it doesn't matter what their
-- type or locally-bound variables are, so we want to allow @f@ to
-- contain terms with different indices.
runDis :: (ABT Term abt, F.Foldable f)
    => Dis abt (abt '[] a)
    -> f (Some2 abt)
    -> [abt '[] ('HMeasure a)]
runDis d es =
    -- Start with no indices, an empty heap and an empty location table.
    m0 [] c0 (ListContext i0 []) emptyAssocs
    where
    (Dis m0) = d >>= residualizeLocs

    -- TODO: we only use dirac because 'residualizeListContext'
    -- requires it to already be a measure; unfortunately this can
    -- result in an extraneous @(>>= \x -> dirac x)@ redex at the end
    -- of the program. In principle, we should be able to eliminate
    -- that redex by changing the type of 'residualizeListContext'...
    c0 (e,rho) ss _ = [residualizeListContext ss rho (syn(Dirac :$ e :* End))]

    -- The first identifier not used by any of the input terms.
    i0 = maxNextFree es

{----------------------------------------------------------------------------------
 residualizeLocs does the following:
 1. update the heap by constructing plate/array around statements with nonempty indices
 2. use locations to construct terms out of var and "!" (for indexing into arrays)

 For example, consider the state:

 list context (aka heap) =
   l1 <- lebesgue            []
   l2 <- plate (normal 0 1)  []
   l3 <- lebesgue            [i]
   l4 <- dirac x3            []
   l5 <- normal 0 1          [j]

 assocs (aka locs) =
   x1 -> Loc      l1 []
   x2 -> Loc      l2 []
   x3 -> MultiLoc l3 []
   x4 -> Loc      l4 []
   x5 -> Loc      l5 [k]

 Here the types of the above variables are:

   l1, x1 :: Real
   l2, x2 :: Array Real
   l3     :: Real
   x3     :: Array Real
   l4, x4 :: Array Real
   l5, x5 :: Real

 Then residualizeLocs does two things:

 1. Change the heap

 list context =
   l1' <- lebesgue            []
   l2' <- plate (normal 0 1)  []
   l3' <- plate i (lebesgue)
   l4' <- dirac x3
   l5' <- plate j (normal 0 1)

 2. Create new association table

 rho =
   x1 -> var l1'
   x2 -> var l2'
   x3 -> array i' (l3' ! i')
   x4 -> var l4'
   x5 -> var l5' ! k
----------------------------------------------------------------------------------}
-- | Rewrite every statement on the heap with 'residualizeLoc',
-- store the rewritten heap, and return the input term together
-- with the substitution built by 'convertLocs' from the resulting
-- names.
residualizeLocs
    :: forall a abt
    .  (ABT Term abt)
    => abt '[] a
    -> Dis abt (abt '[] a, Assocs (abt '[]))
residualizeLocs e = do
    ss <- getStatements
    (ss', newlocs) <- foldM step ([], emptyAssocs) ss
    rho <- convertLocs newlocs
    -- 'step' conses onto the front, so undo that reversal here.
    putStatements (reverse ss')
#ifdef __TRACE_DISINTEGRATE__
    trace ("residualizeLocs: old heap:\n" ++ show (pretty_Statements ss )) $ return ()
    trace ("residualizeLocs: new heap:\n" ++ show (pretty_Statements ss')) $ return ()
    locs <- getLocs
    traceM ("oldlocs:\n" ++ show (prettyLocs locs) ++ "\n")
    traceM ("new assoc for renaming:\n" ++ show (prettyAssocs rho))
#endif
    return (e, rho)
    where
    step (ss',newlocs) s = do
        (s',newlocs') <- residualizeLoc s
        return (s':ss', insertAssocs newlocs' newlocs)

-- | The hint and identifier of a variable, without its type. The
-- type parameter is phantom: neither field mentions it.
data Name (a :: Hakaru) = Name {nameHint :: Text, nameID :: Nat}

-- | Copy a variable's hint and identifier into a 'Name' (at any
-- phantom index).
varName :: Variable a -> Name b
varName x = Name (varHint x) (varID x)

-- | Rewrite one heap statement and report, for each variable it
-- binds, the 'Name' that should be used in its place. Binds and
-- lets go through 'reifyStatement'; a weight is first turned into a
-- bind of @P.weight@ at a fresh unit-typed variable; a guard is
-- returned as-is when it has no indices and is an error otherwise.
residualizeLoc
    :: (ABT Term abt)
    => Statement abt 'Impure
    -> Dis abt (Statement abt 'Impure, Assocs Name)
residualizeLoc s =
    case s of
    SBind l _ _ -> do
        (s', newname) <- reifyStatement s
        return (s', singletonAssocs l newname)
    SLet l _ _ -> do
        (s', newname) <- reifyStatement s
        return (s', singletonAssocs l newname)
    SWeight w inds -> do
        l <- freshVar Text.empty sUnit
        let bodyW = Thunk $ P.weight (fromLazy w)
        (s', newname) <- reifyStatement (SBind l bodyW inds)
        return (s', singletonAssocs l newname)
    SGuard ls _ _ ixs
        | null ixs  -> return (s, toAssocs1 ls (fmap11 varName ls))
        | otherwise -> error "undefined: case statement under an array"

-- | Peel the indices off a bind or let one at a time. Each index
-- wraps the body in @P.plateWithVar@ (for a bind) or
-- @P.arrayWithVar@ (for a let) and moves the binding to a fresh
-- array-typed variable. Once no indices remain, the statement is
-- returned together with the name of the variable it binds.
--
-- Calling this on a weight or a guard is an error.
reifyStatement
    :: (ABT Term abt)
    => Statement abt 'Impure
    -> Dis abt (Statement abt 'Impure, Name a)
reifyStatement s =
    case s of
    SBind l _    []     -> return (s, varName l)
    SBind l body (i:is) -> do
        let plate = Thunk . P.plateWithVar (indSize i) (indVar i)
        l' <- freshVar (varHint l) (SArray (varType l))
        reifyStatement (SBind l' (plate $ fromLazy body) is)
    SLet l _    []     -> return (s, varName l)
    SLet l body (i:is) -> do
        let array = Thunk . P.arrayWithVar (indSize i) (indVar i)
        l' <- freshVar (varHint l) (SArray (varType l))
        reifyStatement (SLet l' (array $ fromLazy body) is)
    SWeight _ _    -> error "reifyStatement called on SWeight"
    SGuard _ _ _ _ -> error "reifyStatement called on SGuard"

-- | Find the statement binding the given location and return the
-- size of the first index in that statement's index list. The
-- statement is pushed back onto the heap unchanged. If 'select'
-- finds no suitable statement, this calls 'freeLocError'.
sizeInnermostInd
    :: (ABT Term abt)
    => Variable (a :: Hakaru)
    -> Dis abt (abt '[] 'HNat)
sizeInnermostInd l =
    (maybe (freeLocError l) return =<<) . select l $ \s -> do
        -- A statement without indices can never answer the query.
        -- This check deliberately comes before the 'SGuard' error.
        guard (not (null (statementInds s)))
        case s of
            SBind l' _ (i:_) -> do
                Refl <- varEq l l'
                Just $ unsafePush s >> return (indSize i)
            SLet  l' _ (i:_) -> do
                Refl <- varEq l l'
                Just $ unsafePush s >> return (indSize i)
            SWeight _ _    -> Nothing
            SGuard _ _ _ _ -> error "TODO: sizeInnermostInd{SGuard}"
            -- Only an 'SBind'/'SLet' with no indices can get here,
            -- and the guard above has already ruled those out.
            _ -> Nothing

-- | Build the term for a location: a variable with the given name,
-- indexed (with @P.!@) by each of the given index variables. Every
-- index wraps the variable's type in one more 'SArray', so that
-- the finished expression has the requested type.
fromLoc
    :: (ABT Term abt)
    => Name b
    -> Sing a
    -> [Variable 'HNat]
    -> abt '[] a
fromLoc name typ []     =
    var $ Variable
        { varHint = nameHint name
        , varID   = nameID name
        , varType = typ
        }
fromLoc name typ (i:is) = fromLoc name (SArray typ) is P.! var i

-- | Turn the current location table into a substitution: each
-- entry's location variable is looked up in the given names and
-- the result is rebuilt with 'fromLoc' at the entry's own indices.
-- A location with no name in the table is a 'freeLocError'.
convertLocs
    :: (ABT Term abt)
    => Assocs Name
    -> Dis abt (Assocs (abt '[]))
convertLocs newlocs = F.foldr step emptyAssocs . fromAssocs <$> getLocs
    where
    build :: (ABT Term abt)
          => Assoc (Loc (abt '[]))
          -> Name a
          -> Assoc (abt '[])
    build (Assoc x loc) name =
        Assoc x (fromLoc name (varType x) (locIndices loc))

    -- The two alternatives read the same, but the location variable
    -- has a different type in each, so they cannot be merged.
    step assoc@(Assoc _ loc) = insertAssoc $
        case loc of
        Loc      l _ -> maybe (freeLocError l) (build assoc) (lookupAssoc l newlocs)
        MultiLoc l _ -> maybe (freeLocError l) (build assoc) (lookupAssoc l newlocs)

-- | Abort with an error naming a location that could not be resolved.
freeLocError :: Variable (a :: Hakaru) -> b
freeLocError l = error $ "Found a free location " ++ show l

-- | Rename indices inside a term. For every entry of the location
-- table that mentions one of the indices being replaced, a fresh
-- location is made with those indices swapped for their
-- replacements, and the entry's variable is renamed to it in the
-- term. The index variables themselves are renamed in the term too.
apply
    :: (ABT Term abt)
    => [(Index (abt '[]), Index (abt '[]))]
    -> abt '[] a
    -> Dis abt (abt '[] a)
apply ijs e = do
    locs <- fromAssocs <$> getLocs
    rho' <- foldM step rho locs
    return (renames rho' e)
    where
    -- The index-to-index renaming requested by the caller.
    rho = toAssocs $ map (\(i,j) -> Assoc (indVar i) (indVar j)) ijs

    step r (Assoc x loc) =
        let inds    = locIndices loc
            check i = lookupAssoc i rho
            inds'   = map (\i -> fromMaybe i (check i)) inds
        in  if any (isJust . check) inds
            then do
                x' <- case loc of
                      Loc      l _ -> mkLoc      Text.empty l inds'
                      MultiLoc l _ -> mkMultiLoc Text.empty l inds'
                return (insertAssoc (Assoc x x') r)
            else return r

-- | Put a new index at the front of a list of indices. It is an
-- error for the new index to already be in the list.
extendIndices
    :: (ABT Term abt)
    => Index (abt '[])
    -> [Index (abt '[])]
    -> [Index (abt '[])]
-- TODO: check all Indices are unique
extendIndices j js
    | j `elem` js
        -- TODO: include the offending index in this message once
        -- 'Show' is defined for 'Index'.
        = error "extendIndices: duplicate index (already in the list)"
    | otherwise
        = j : js

-- give better name
-- | The list of indices carried by a statement.
statementInds :: Statement abt p -> [Index (abt '[])]
statementInds (SBind   _ _   i) = i
statementInds (SLet    _ _   i) = i
statementInds (SWeight _     i) = i
statementInds (SGuard  _ _ _ i) = i
statementInds (SStuff0 _     i) = i
statementInds (SStuff1 _ _   i) = i

-- | Read the current location table.
getLocs :: (ABT Term abt)
        => Dis abt (Assocs (Loc (abt '[])))
getLocs = Dis $ \_ c h l -> c l h l

-- | Replace the current location table.
putLocs :: (ABT Term abt)
        => Assocs (Loc (abt '[]))
        -> Dis abt ()
putLocs l = Dis $ \_ c h _ -> c () h l

-- | Add one entry to the location table.
insertLoc :: (ABT Term abt)
          => Variable a
          -> Loc (abt '[]) a
          -> Dis abt ()
insertLoc v loc =
    Dis $ \_ c h l -> c () h $ insertAssoc (Assoc v loc) l

-- | Apply a function to the location-table entry of a variable,
-- via 'adjustAssoc'.
adjustLoc :: (ABT Term abt)
          => Variable (a :: Hakaru)
          -> (Assoc (Loc (abt '[])) -> Assoc (Loc (abt '[])))
          -> Dis abt ()
adjustLoc x f = do
    locs <- getLocs
    putLocs $ adjustAssoc x f locs

-- | Allocate a fresh variable of the same type as the given one
-- and record it in the location table as a 'Loc' of that variable
-- at the given indices.
mkLoc
    :: (ABT Term abt)
    => Text
    -> Variable (a :: Hakaru)
    -> [Variable 'HNat]
    -> Dis abt (Variable a)
mkLoc hint s inds = do
    x <- freshVar hint (varType s)
    insertLoc x (Loc s inds)
    return x

-- | 'mkLoc' (with an empty hint) for every variable in a list, all
-- at the same indices.
mkLocs
    :: (ABT Term abt)
    => List1 Variable (xs :: [Hakaru])
    -> [Variable 'HNat]
    -> Dis abt (List1 Variable xs)
mkLocs Nil1         _    = return Nil1
mkLocs (Cons1 x xs) inds =
    Cons1 <$> mkLoc Text.empty x inds <*> mkLocs xs inds

-- | Like 'mkLoc', but the fresh variable has array type and is
-- recorded as a 'MultiLoc'.
mkMultiLoc
    :: (ABT Term abt)
    => Text
    -> Variable a
    -> [Variable 'HNat]
    -> Dis abt (Variable ('HArray a))
mkMultiLoc hint s inds = do
    x' <- freshVar hint (SArray $ varType s)
    insertLoc x' (MultiLoc s inds)
    return x'

instance Functor (Dis abt) where
    fmap f (Dis m)  = Dis $ \i c -> m i (c . f)

instance Applicative (Dis abt) where
    pure x            = Dis $ \_ c -> c x
    Dis mf <*> Dis mx = Dis $ \i c -> mf i $ \f -> mx i $ \x -> c (f x)

instance Monad (Dis abt) where
    return      = pure
    Dis m >>= k = Dis $ \i c -> m i $ \x -> unDis (k x) i c

-- | Failure produces no answers; choice runs both alternatives with
-- the same indices, continuation, heap and locations, and
-- concatenates their answers.
instance Alternative (Dis abt) where
    empty           = Dis $ \_ _ _ _ -> []
    Dis m <|> Dis n = Dis $ \i c h l -> m i c h l ++ n i c h l

instance MonadPlus (Dis abt) where
    mzero = empty -- aka "bot"
    mplus = (<|>) -- aka "lub"

instance (ABT Term abt) => EvaluationMonad abt (Dis abt) 'Impure where
    -- Hand out the context's counter and store its successor.
    freshNat =
        Dis $ \_ c (ListContext n ss) ->
            c n (ListContext (n+1) ss)

    -- Freshen the variables bound by a statement. The freshened
    -- variables are what the statement binds afterwards; each
    -- original variable is mapped to a new location variable that
    -- refers to its freshened counterpart at the statement's indices.
    freshenStatement s =
        case s of
        SWeight _ _    -> return (s, mempty)
        SBind x body i -> do
            l  <- freshenVar x
            x' <- mkLoc (varHint x) l (map indVar i)
            return (SBind l body i, singletonAssocs x x')
        SLet  x body i -> do
            l  <- freshenVar x
            x' <- mkLoc (varHint x) l (map indVar i)
            return (SLet l body i, singletonAssocs x x')
        SGuard xs pat scrutinee i -> do
            ls  <- freshenVars xs
            xs' <- mkLocs ls (map indVar i)
            return (SGuard ls pat scrutinee i, toAssocs1 xs xs')

    getIndices = Dis $ \i c -> c i

    unsafePush s =
        Dis $ \_ c (ListContext i ss) ->
            c () (ListContext i (s:ss))

    -- N.B., the use of 'reverse' is necessary so that the order
    -- of pushing matches that of 'pushes'
    unsafePushes ss =
        Dis $ \_ c (ListContext i ss') ->
            c () (ListContext i (reverse ss ++ ss'))

    select l p = loop []
        where
        -- TODO: use a DList to avoid reversing inside 'unsafePushes'
        --
        -- Pop statements until one both binds @l@ and is accepted
        -- by @p@. Everything popped along the way is pushed back
        -- once the search ends, whether or not it succeeded.
        loop ss = do
            ms <- unsafePop
            case ms of
                Nothing -> do
                    unsafePushes ss
                    return Nothing
                Just s  ->
                    -- Alas, @p@ will have to recheck 'isBoundBy'
                    -- in order to grab the 'Refl' proof we erased;
                    -- but there's nothing to be done for it.
                    case l `isBoundBy` s >> p s of
                    Nothing -> loop (s:ss)
                    Just mr -> do
                        r <- mr
                        unsafePushes ss
                        return (Just r)

-- | Run a computation with the given list of indices in place of
-- the ones it would otherwise be given.
withIndices :: [Index (abt '[])] -> Dis abt a -> Dis abt a
withIndices inds (Dis m) = Dis $ \_ c -> m inds c

-- | Not exported because we only need it for defining 'select' on 'Dis'.
--
-- Remove and return the statement at the head of the heap, or
-- return 'Nothing' (leaving the heap alone) when there is none.
unsafePop :: Dis abt (Maybe (Statement abt 'Impure))
unsafePop =
    Dis $ \_ c h@(ListContext i ss) loc ->
        case ss of
        []    -> c Nothing  h loc
        s:ss' -> c (Just s) (ListContext i ss') loc

-- | Push the body of a plate onto the heap. A fresh index of the
-- given size is allocated and the body's bound variable is renamed
-- to that index's variable. The body is then bound to a fresh
-- variable under the current indices extended with the new one.
-- The result is a new 'MultiLoc' variable referring to that
-- binding at the current (unextended) indices.
pushPlate
    :: (ABT Term abt)
    => abt '[] 'HNat
    -> abt '[ 'HNat ] ('HMeasure a)
    -> Dis abt (Variable ('HArray a))
pushPlate n e =
    caseBind e $ \x body -> do
        inds <- getIndices
        i    <- freshInd n
        p    <- freshVar Text.empty (sUnMeasure $ typeOf body)
        unsafePush (SBind p (Thunk $ rename x (indVar i) body)
                    (extendIndices i inds))
        mkMultiLoc Text.empty p (map indVar inds)

----------------------------------------------------------------
----------------------------------------------------------------

-- | It is impossible to satisfy the constraints, or at least we
-- give up on trying to do so. This function is identical to 'empty'
-- and 'mzero' for 'Dis'; we just give it its own name since this is
-- the name used in our papers.
--
-- TODO: add some sort of trace information so we can get a better
-- idea what caused a disintegration to fail.
bot :: (ABT Term abt) => Dis abt a
bot = empty


-- | The empty measure is a solution to the constraints.
-- reject :: (ABT Term abt) => Dis abt a
-- reject = Dis $ \_ _ -> [syn (Superpose_ [])]


-- Something essentially like this function was called @insert_@
-- in the finally-tagless code.
--
-- | Emit some code that binds a variable, and return the variable
-- thus bound. The function says what to wrap the result of the
-- continuation with; i.e., what we're actually emitting.
--
-- A fresh variable is allocated and handed to the continuation;
-- every answer that comes back is bound at that variable and then
-- passed through the wrapper.
emit
    :: (ABT Term abt)
    => Text
    -> Sing a
    -> (forall r. abt '[a] ('HMeasure r) -> abt '[] ('HMeasure r))
    -> Dis abt (Variable a)
emit hint typ f =
    freshVar hint typ >>= \x ->
        Dis $ \_ k ctx locs ->
            fmap (\body -> f (bind x body)) (k x ctx locs)


-- This function was called @lift@ in the finally-tagless code.
-- | Emit an 'MBind' (i.e., \"@m >>= \x ->@\") and return the
-- variable thus bound (i.e., @x@).
emitMBind :: (ABT Term abt) => abt '[] ('HMeasure a) -> Dis abt (Variable a)
emitMBind m =
    emit Text.empty (sUnMeasure (typeOf m))
        (\rest -> syn (MBind :$ m :* rest :* End))


-- | A smart constructor for emitting let-bindings. If the input
-- is already a variable then we just return it; otherwise we emit
-- the let-binding. N.B., this function provides the invariant that
-- the result is in fact a variable; whereas 'emitLet'' does not.
emitLet :: (ABT Term abt) => abt '[] a -> Dis abt (Variable a)
emitLet e = caseVarSyn e return (const emitBinding)
    where
    -- Only reached when @e@ is not itself a variable.
    emitBinding =
        emit Text.empty (typeOf e)
            (\rest -> syn (Let_ :$ e :* rest :* End))

-- | A smart constructor for emitting let-bindings. If the input
-- is already a variable or a literal constant, then we just return
-- it; otherwise we emit the let-binding. N.B., this function
-- provides weaker guarantees on the type of the result; if you
-- require the result to always be a variable, then see 'emitLet'
-- instead.
emitLet' :: (ABT Term abt) => abt '[] a -> Dis abt (abt '[] a)
emitLet' e =
    caseVarSyn e (const (return e)) $ \t ->
        case t of
        Literal_ _ -> return e
        _          ->
            var <$> emit Text.empty (typeOf e)
                        (\rest -> syn (Let_ :$ e :* rest :* End))

-- | A smart constructor for emitting \"unpair\". If the input
-- argument is actually a constructor then we project out the two
-- components; otherwise we emit the case-binding and return the
-- two variables.
emitUnpair
    :: (ABT Term abt)
    => Whnf abt (HPair a b)
    -> Dis abt (abt '[] a, abt '[] b)
emitUnpair w0 =
    case w0 of
    Head_   w -> return (reifyPair w)
    Neutral e -> do
        let (a,b) = sUnPair (typeOf e)
        x <- freshVar Text.empty a
        y <- freshVar Text.empty b
        emitUnpair_ x y e

-- | The worker behind 'emitUnpair'. It receives the two variables
-- that an emitted case-expression should bind the components to.
emitUnpair_
    :: forall abt a b
    .  (ABT Term abt)
    => Variable a
    -> Variable b
    -> abt '[] (HPair a b)
    -> Dis abt (abt '[] a, abt '[] b)
emitUnpair_ x y = go
    where
    -- Emit a one-branch case on @e@ that binds both components,
    -- and continue with the two bound variables.
    emitCase :: abt '[] (HPair a b) -> Dis abt (abt '[] a, abt '[] b)
    emitCase e =
#ifdef __TRACE_DISINTEGRATE__
        trace "-- emitUnpair: done (term is not Datum_ nor Case_)" $
#endif
        Dis $ \_ k ctx locs ->
            fmap
                (\body ->
                    syn (Case_ e
                        [Branch (pPair PVar PVar) (bind x (bind y body))]))
                (k (var x, var y) ctx locs)

    go :: abt '[] (HPair a b) -> Dis abt (abt '[] a, abt '[] b)
    go e0 =
        caseVarSyn e0 (emitCase . var) $ \t ->
            case t of
            Datum_ d   -> do
#ifdef __TRACE_DISINTEGRATE__
                trace "-- emitUnpair: found Datum_" $ return ()
#endif
                return (reifyPair (WDatum d))
            Case_ e bs -> do
#ifdef __TRACE_DISINTEGRATE__
                trace "-- emitUnpair: going under Case_" $ return ()
#endif
                -- TODO: we want this to duplicate the current
                -- continuation for (the evaluation of @go@ in)
                -- all branches. So far our traces all end up
                -- returning @bot@ on the first branch, and hence
                -- @bot@ for the whole case-expression, so we can't
                -- quite tell whether it does what is intended.
                --
                -- N.B., the only 'Dis'-effects in 'applyBranch'
                -- are to freshen variables; thus this use of
                -- 'traverse' is perfectly sound.
                emitCaseWith go e bs
            _ -> emitCase e0


-- TODO: emitUneither


-- This function was called @insert_@ in the old finally-tagless code.
-- | Emit some code that doesn't bind any variables. This function
-- provides an optimisation over using 'emit' and then discarding
-- the generated variable.
emit_
    :: (ABT Term abt)
    => (forall r. abt '[] ('HMeasure r) -> abt '[] ('HMeasure r))
    -> Dis abt ()
emit_ f = Dis $ \_ k ctx locs -> fmap f (k () ctx locs)


-- | Emit an 'MBind' that discards its result (i.e., \"@m >>@\").
-- We restrict the type of the argument to be 'HUnit' so as to avoid
-- accidentally dropping things.
emitMBind_ :: (ABT Term abt) => abt '[] ('HMeasure HUnit) -> Dis abt ()
emitMBind_ m = emit_ (\rest -> m P.>> rest)


-- TODO: if the argument is a value, then we can evaluate the 'P.if_'
-- immediately rather than emitting it.
-- | Emit an assertion that the condition is true.
emitGuard :: (ABT Term abt) => abt '[] HBool -> Dis abt ()
emitGuard b = emit_ (\rest -> P.withGuard b rest)
    -- == emit_ $ \m -> P.if_ b m P.reject

-- TODO: if the argument is the literal 1, then we can avoid emitting anything.
-- | Emit a weighting of the rest of the program by the given factor.
emitWeight :: (ABT Term abt) => abt '[] 'HProb -> Dis abt ()
emitWeight w = emit_ (\rest -> P.withWeight w rest)


-- N.B., this use of 'T.traverse' is definitely correct. It's
-- sequentializing @t [abt '[] ('HMeasure a)]@ into @[t (abt '[]
-- ('HMeasure a))]@ by choosing one of the possibilities at each
-- position in @t@. No heap\/context effects can escape to mess
-- things up. In contrast, using 'T.traverse' to sequentialize @t
-- (Dis abt a)@ as @Dis abt (t a)@ is /wrong/! Doing that would give
-- the conjunctive semantics where we have effects from one position
-- in @t@ escape to affect the other positions. This has to do with
-- the general issue in partial evaluation where we need to duplicate
-- downstream work (as we do by passing the same heap to everyone)
-- because there's no general way to combine the resulting heaps
-- for each branch.
--
-- | Run each of the elements of the traversable using the same
-- heap and continuation for each one, then pass the results to a
-- function for emitting code.
--
-- The traversal happens in the list applicative, so the outcome has
-- one entry for every way of picking a single answer from each
-- element. If any element has no answers, neither does the whole.
emitFork_
    :: (ABT Term abt, T.Traversable t)
    => (forall r. t (abt '[] ('HMeasure r)) -> abt '[] ('HMeasure r))
    -> t (Dis abt a)
    -> Dis abt a
emitFork_ f ms =
    Dis $ \i c h l -> f <$> T.traverse (\m -> unDis m i c h l) ms


-- | Emit a 'Superpose_' of the alternatives, each with unit weight,
-- bind its outcome, and return the bound variable. A single
-- alternative is bound directly, with no superpose around it. The
-- empty list is not supported yet and raises an error.
emitSuperpose
    :: (ABT Term abt)
    => [abt '[] ('HMeasure a)]
    -> Dis abt (Variable a)
emitSuperpose []     = error "TODO: emitSuperpose[]"
emitSuperpose [e]    = emitMBind e
-- The cons cell already proves the list is non-empty, so the
-- 'NonEmpty' is built directly rather than via partial 'NE.fromList'.
emitSuperpose (e:es) =
    emitMBind . P.superpose . NE.map ((,) P.one) $ e :| es


-- | Run every alternative with the same heap and continuation and
-- emit a 'Superpose_' of the outcomes, each with unit weight. A
-- single alternative is returned as-is. The empty list is not
-- supported yet and raises an error.
choose :: (ABT Term abt) => [Dis abt a] -> Dis abt a
choose []     = error "TODO: choose[]"
choose [m]    = m
-- Forking over a 'NonEmpty' lets the wrapper receive a 'NonEmpty'
-- directly, with no partial conversion from a list.
choose (m:ms) =
    emitFork_ (P.superpose . NE.map ((,) P.one)) (m :| ms)


-- | Given some function we can call on the bodies of the branches,
-- freshen all the pattern-bound variables and then run the function
-- on all the branches in parallel (i.e., with the same continuation
-- and heap) and then emit a case-analysis expression with the
-- results of the continuations as the bodies of the branches. This
-- function is useful for when we really do want to emit a 'Case_'
-- expression, rather than doing the superpose of guard patterns
-- thing that 'constrainValue' does.
--
-- N.B., this function assumes (and does not verify) that the second
-- argument is emissible. So callers must guarantee this invariant,
-- by calling 'atomize' as necessary.
--
-- TODO: capture the emissibility requirement on the second argument
-- in the types.
emitCaseWith
    :: (ABT Term abt)
    => (abt '[] b -> Dis abt r)
    -> abt '[] a
    -> [Branch a abt b]
    -> Dis abt r
emitCaseWith f e bs = do
    -- Step 1: open every branch, freshen its pattern variables and
    -- set up (but do not yet run) @f@ on the renamed body.
    pending <- T.for bs $ \(Branch pat body) ->
        let (vars, body') = caseBinds body
        in  do
            vars' <- freshenVars vars
            return $
                GBranch pat vars'
                    (f (renames (toAssocs1 vars vars') body'))
    -- Step 2: run each pending body with the same indices,
    -- continuation, heap and locations, then rebuild the
    -- case-expression around each combination of answers.
    Dis $ \i c h l ->
        let runWith m = unDis m i c h l
        in  fmap (syn . Case_ e) $
                T.traverse (fmap fromGBranch . T.traverse runWith) pending
{-# INLINE emitCaseWith #-}

----------------------------------------------------------------
----------------------------------------------------------- fin.