{-# LANGUAGE CPP
           , GADTs
           , EmptyCase
           , KindSignatures
           , DataKinds
           , PolyKinds
           , TypeOperators
           , ScopedTypeVariables
           , Rank2Types
           , MultiParamTypeClasses
           , TypeSynonymInstances
           , FlexibleInstances
           , FlexibleContexts
           , UndecidableInstances
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2016.06.29
-- |
-- Module      :  Language.Hakaru.Disintegrate
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- Disintegration via lazy partial evaluation.
--
-- N.B., the forward direction of disintegration is /not/ just
-- partial evaluation! In the version discussed in the paper we
-- must also ensure that no heap-bound variables occur in the result,
-- which requires using HNFs rather than WHNFs. That condition is
-- sound, but a bit too strict; we could generalize this to handle
-- cases where there may be heap-bound variables remaining in neutral
-- terms, provided we (a) don't end up trying to go both forward
-- and backward on the same variable, and (more importantly) (b)
-- end up with the proper Jacobian. The paper version is rigged to
-- ensure that whenever we recurse into two subexpressions (e.g.,
-- the arguments to addition) one of them has a Jacobian of zero,
-- thus when going from @x*y@ to @dx*y + x*dy@ one of the terms
-- cancels out.
--
-- /Developer's Note:/ To help keep the code clean, we use the
-- worker\/wrapper transform. However, due to complexities in
-- typechecking GADTs, this often confuses GHC if you don't give
-- just the right type signature on definitions. This confusion
-- shows up whenever you get error messages about an \"ambiguous\"
-- choice of 'ABT' instance, or certain types of \"couldn't match
-- @a@ with @a1@\" error messages. To eliminate these issues we use
-- @-XScopedTypeVariables@. In particular, the @abt@ type variable
-- must be bound by the wrapper (i.e., the top-level definition),
-- and the workers should just refer to that same type variable
-- rather than quantifying over another @abt@ type. In addition,
-- whatever other type variables there are (e.g., the @xs@ and @a@
-- of an @abt xs a@ argument) should be polymorphic in the workers
-- and should /not/ reuse the other analogous type variables bound
-- by the wrapper.
--
-- /Developer's Note:/ In general, we'd like to emit weights and
-- guards \"as early as possible\"; however, determining when that
-- actually is can be tricky. If we emit them way-too-early then
-- we'll get hygiene errors because bindings for the variables they
-- use have not yet been emitted. We can fix these hygiene errors
-- by calling 'atomize', to ensure that all the necessary bindings
-- have already been emitted. But that may still emit things too
-- early, because emitting the variable-binding statements now means
-- that we can't go forward\/backward on them later on; which may
-- cause us to bot unnecessarily. One way to avoid this bot issue
-- is to emit guards\/weights later than necessary, by actually
-- pushing them onto the context (and then emitting them whenever
-- we residualize the context). This guarantees we don't emit too
-- early; but the tradeoff is we may end up generating duplicate
-- code by emitting too late. One possible (currently unimplemented)
-- solution to that code duplication issue is to let these statements
-- be emitted too late, but then have a post-processing step to
-- lift guards\/weights up as high as they can go. To avoid problems
-- about testing programs\/expressions for equality, we can use a
-- hash-consing trick so we keep track of the identity of guard\/weight
-- statements, then we can simply compare those identities and only
-- after the lifting do we replace the identity hash with the actual
-- statement.
----------------------------------------------------------------
module Language.Hakaru.Disintegrate
    (
    -- * the Hakaru API
      disintegrateWithVar
    , disintegrate
    , densityWithVar
    , density
    , observe
    , determine

    -- * Implementation details
    , perform
    , atomize
    , constrainValue
    , constrainOutcome
    ) where

#if __GLASGOW_HASKELL__ < 710
import           Data.Functor         ((<$>))
import           Data.Foldable        (Foldable, foldMap)
import           Data.Traversable     (Traversable)
import           Control.Applicative  (Applicative(..))
#endif
import           Control.Applicative  (Alternative(..))
import           Control.Monad        ((<=<), guard)
import           Data.Functor.Compose (Compose(..))
import qualified Data.Traversable     as T
import           Data.List.NonEmpty   (NonEmpty(..))
import qualified Data.List.NonEmpty   as L
import qualified Data.Text            as Text
import qualified Data.IntMap          as IM
import           Data.Sequence        (Seq)
import qualified Data.Sequence        as S
import           Data.Proxy           (KProxy(..))
import qualified Data.Set             as Set (fromList)
import           Data.Maybe           (fromMaybe)

import Language.Hakaru.Syntax.IClasses
import Data.Number.Natural
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import qualified Language.Hakaru.Types.Coercion as C
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumCase (DatumEvaluator, MatchResult(..), matchBranches)
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.Lazy hiding (evaluate,update)
import Language.Hakaru.Evaluation.DisintegrationMonad
import qualified Language.Hakaru.Syntax.Prelude as P
import qualified Language.Hakaru.Expect         as E

#ifdef __TRACE_DISINTEGRATE__
import qualified Text.PrettyPrint     as PP
import Language.Hakaru.Pretty.Haskell
import Debug.Trace                    (trace, traceM)
#endif

----------------------------------------------------------------

-- | Build a Hakaru lambda abstraction: bind the variable @x@ in the
-- body @e@ and wrap the result in the 'Lam_' syntactic form.
lam_ :: (ABT Term abt) => Variable a -> abt '[] b -> abt '[] (a ':-> b)
lam_ x e = syn (Lam_ :$ bind x e :* End)


-- | Disintegrate a measure over pairs with respect to the lebesgue
-- measure on the first component. That is, for each measure kernel
-- @n <- disintegrate m@ we have that @m == bindx lebesgue n@. The
-- first two arguments give the hint and type of the lambda-bound
-- variable in the result. If you want to automatically fill those
-- in, then see 'disintegrate'.
--
-- N.B., the resulting functions from @a@ to @'HMeasure b@ are
-- indeed measurable, thus it is safe\/appropriate to use Hakaru's
-- @(':->)@ rather than Haskell's @(->)@.
--
-- BUG: Actually, disintegration is with respect to the /Borel/
-- measure on the first component of the pair! Alas, we don't really
-- have a clean way of describing this since we've no primitive
-- 'MeasureOp' for Borel measures.
--
-- /Developer's Note:/ This function fills the role that the old
-- @runDisintegrate@ did (as opposed to the old function called
-- @disintegrate@). [Once people are familiar enough with the new
-- code base and no longer recall what the old code base was doing,
-- this note should be deleted.]
disintegrateWithVar
    :: (ABT Term abt)
    => Text.Text
    -> Sing a
    -> abt '[] ('HMeasure (HPair a b))
    -> [abt '[] (a ':-> 'HMeasure b)]
disintegrateWithVar hint typ m =
    -- @x@ is the lambda-bound variable of every result; its identifier
    -- comes from 'nextFreeOrBind' applied to @m@.
    let x = Variable hint (nextFreeOrBind m) typ
    in map (lam_ x) . flip runDis [Some2 m, Some2 (var x)] $ do
        ab    <- perform m
#ifdef __TRACE_DISINTEGRATE__
        ss <- getStatements
        trace ("-- disintegrate: finished perform\n"
            ++ show (pretty_Statements ss PP.$+$ PP.sep(prettyPrec_ 11 ab))
            ++ "\n") $ return ()
#endif
        -- BUG: Why does 'testDisintegrate1a' return no solutions?
        --
        -- In older code (up to git#38889a5): It's because 'emitUnpair'
        -- isn't quite smart enough. When the @ab@ expression is a
        -- 'Neutral' case expression, we need to go underneath the
        -- case expression and call 'constrainValue' on each branch.
        -- Instead, what we currently do is emit an @unpair@ case
        -- statement with the scrutinee being the 'Neutral' case
        -- expression, and then just return the pair of variables
        -- bound by the emitted @unpair@; but, of course,
        -- 'constrainValue' can't do anything with those variables
        -- (since they appear to be free, given as they've already
        -- been emitted). Another way to think about what it is we
        -- need to do to correct this is that we need to perform
        -- the case-of-case transformation (where one of the cases
        -- is the 'Neutral' one, and the other is the @unpair@).
        --
        -- In newer code (git#e8a0c66 and later): When we call 'perform'
        -- on an 'SBind' statement we emit some code and update the
        -- binding to become an 'SLet' of some local variable to
        -- the emitted variable. Later on when we call 'constrainVariable'
        -- on the local variable, we will look that 'SLet' statement
        -- up; and then when we call 'constrainVariable' on the
        -- emitted variable, things will @bot@ because we cannot
        -- constrain free variables in general.
        (a,b) <- emitUnpair ab
#ifdef __TRACE_DISINTEGRATE__
        trace ("-- disintegrate: finished emitUnpair: "
            ++ show (pretty a, pretty b)) $ return ()
#endif
        -- Constrain the first component to equal the lambda-bound
        -- variable; the second component is what each result returns.
        constrainValue (var x) a
#ifdef __TRACE_DISINTEGRATE__
        ss <- getStatements
        locs <- getLocs
        traceM ("-- disintegrate: finished constrainValue\n"
            ++ show (pretty_Statements ss) ++ "\n"
            ++ show (prettyLocs locs)
            )
#endif
        return b


-- | A variant of 'disintegrateWithVar' which automatically computes
-- the type via 'typeOf'.
disintegrate
    :: (ABT Term abt)
    => abt '[] ('HMeasure (HPair a b))
    -> [abt '[] (a ':-> 'HMeasure b)]
disintegrate m =
    disintegrateWithVar
        Text.empty
        (fst . sUnPair . sUnMeasure $ typeOf m) -- TODO: change the exception thrown from 'typeOf' so that we know it comes from here
        m


-- | Return the density function for a given measure. The first two
-- arguments give the hint and type of the lambda-bound variable
-- in the result.
-- If you want to automatically fill those in, then
-- see 'density'.
--
-- TODO: is the resulting function guaranteed to be measurable? If
-- so, update this documentation to reflect that fact; if not, then
-- we should make it into a Haskell function instead.
--
-- Each alternative produced by 'observe' is collapsed to its total
-- mass with 'E.total' and then abstracted over the observed variable.
densityWithVar
    :: (ABT Term abt)
    => Text.Text
    -> Sing a
    -> abt '[] ('HMeasure a)
    -> [abt '[] (a ':-> 'HProb)]
densityWithVar hint typ m =
    -- Pick the variable's identifier with 'nextFreeOrBind', the same
    -- helper 'disintegrateWithVar' uses for its lambda-bound variable.
    let x = Variable hint (nextFreeOrBind m) typ
    in (lam_ x . E.total) <$> observe m (var x)


-- | A variant of 'densityWithVar' which automatically computes the
-- type via 'typeOf'.
density
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> [abt '[] (a ':-> 'HProb)]
density m =
    densityWithVar
        Text.empty
        (sUnMeasure $ typeOf m)
        m


-- | Constrain a measure such that it must return the observed
-- value. In other words, the resulting measure returns the observed
-- value with weight according to its density in the original
-- measure, and gives all other values weight zero.
--
-- One result is returned per alternative found by 'constrainOutcome';
-- the list is empty when there is none.
observe
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> abt '[] a
    -> [abt '[] ('HMeasure a)]
observe m x =
    runDis (constrainOutcome x m >> return x) [Some2 m, Some2 x]


-- | Arbitrarily choose one of the possible alternatives. In the
-- future, this function should be replaced by a better one that
-- takes some sort of strategy for deciding which alternative to
-- choose.
--
-- Currently this returns the first alternative, or 'Nothing' when
-- the list is empty.
determine :: (ABT Term abt) => [abt '[] a] -> Maybe (abt '[] a)
determine []    = Nothing
determine (m:_) = Just m


----------------------------------------------------------------
----------------------------------------------------------------

-- | Apply an effectful function to the first component of a pair,
-- leaving the second component untouched.
firstM :: Functor f => (a -> f b) -> (a,c) -> f (b,c)
firstM f (x, y) = flip (,) y <$> f x


-- N.B., forward disintegration is not identical to partial evaluation,
-- as noted at the top of the file. For correctness we need to
-- ensure the result is emissible (i.e., has no heap-bound variables).
-- More specifically, we need to ensure emissibility in the places
-- where we call 'emitMBind'
evaluate_ :: (ABT Term abt) => TermEvaluator abt (Dis abt)
evaluate_ = evaluate perform evaluateCase


-- Copying `evaluate` and `update` from LH.Evaluation.Lazy for now (2016-06-28)
-- Beginning of copied code -------------------------------------------------

-- | Lazily evaluate a term to weak-head normal form inside the 'Dis'
-- monad.
--
-- The first argument is the evaluator used (via 'update') when a
-- variable turns out to be bound to a measure; the second builds the
-- evaluator for 'Case_' expressions from the term evaluator itself.
--
-- Variables are resolved with 'update'. Syntactic forms that are
-- already in WHNF are wrapped in 'Head_' unchanged. Applications and
-- let-bindings push an 'SLet' for the argument\/bound expression
-- (call-by-name) and continue into the body.
evaluate
    :: forall abt
    .  (ABT Term abt)
    => MeasureEvaluator abt (Dis abt)
    -> (TermEvaluator abt (Dis abt) -> CaseEvaluator abt (Dis abt))
    -> TermEvaluator abt (Dis abt)
{-# INLINE evaluate #-}
evaluate perform evaluateCase = goEvaluate
    where
    -- Tie the knot: the case evaluator recurses through 'goEvaluate'.
    evaluateCase_ :: CaseEvaluator abt (Dis abt)
    evaluateCase_ = evaluateCase goEvaluate

    goEvaluate :: TermEvaluator abt (Dis abt)
    goEvaluate e0 =
#ifdef __TRACE_DISINTEGRATE__
      getIndices >>= \inds ->
      trace ("-- goEvaluate: " ++ show (pretty e0) ++ " at " ++ show (ppInds inds)) $
#endif
      caseVarSyn e0 (update perform goEvaluate) $ \t ->
        case t of
            -- Things which are already WHNFs
            Literal_ v               -> return . Head_ $ WLiteral v
            Datum_   d               -> return . Head_ $ WDatum   d
            Empty_   typ             -> return . Head_ $ WEmpty   typ
            Array_   e1 e2           -> return . Head_ $ WArray e1 e2
            Lam_  :$ e1 :* End       -> return . Head_ $ WLam   e1
            Dirac :$ e1 :* End       -> return . Head_ $ WDirac e1
            MBind :$ e1 :* e2 :* End -> return . Head_ $ WMBind e1 e2
            Plate :$ e1 :* e2 :* End -> return . Head_ $ WPlate e1 e2
            MeasureOp_ o :$ es       -> return . Head_ $ WMeasureOp o es
            Superpose_ pes           -> return . Head_ $ WSuperpose pes
            Reject_ typ              -> return . Head_ $ WReject typ
            -- We don't bother evaluating these, even though we could...
            Integrate :$ e1 :* e2 :* e3 :* End ->
                return . Head_ $ WIntegrate e1 e2 e3
            Summate h1 h2 :$ e1 :* e2 :* e3 :* End ->
                return . Neutral $ syn t
                --return . Head_ $ WSummate   e1 e2 e3

            -- Everything else needs some evaluation

            App_ :$ e1 :* e2 :* End -> do
                w1 <- goEvaluate e1
                case w1 of
                    Neutral e1' -> return . Neutral $ P.app e1' e2
                    Head_   v1  -> evaluateApp v1
              where
                evaluateApp (WLam f)   =
                    -- call-by-name:
                    caseBind f $ \x f' -> do
                        i <- getIndices
                        push (SLet x (Thunk e2) i) f' goEvaluate
                evaluateApp _ = error "evaluate{App_}: the impossible happened"

            Let_ :$ e1 :* e2 :* End -> do
                i <- getIndices
                caseBind e2 $ \x e2' ->
                    push (SLet x (Thunk e1) i) e2' goEvaluate

            CoerceTo_   c :$ e1 :* End -> C.coerceTo   c <$> goEvaluate e1
            UnsafeFrom_ c :$ e1 :* End -> C.coerceFrom c <$> goEvaluate e1

            -- TODO: will maybe clean up the code to map 'evaluate' over @es@ before calling the evaluateFooOp helpers?
            NaryOp_  o    es -> evaluateNaryOp  goEvaluate o es
            ArrayOp_ o :$ es -> evaluateArrayOp goEvaluate o es
            PrimOp_  o :$ es -> evaluatePrimOp  goEvaluate o es

            -- BUG: avoid the chance of looping in case 'E.expect' residualizes!
            -- TODO: use 'evaluate' in 'E.expect' for the evaluation of @e1@
            Expect :$ e1 :* e2 :* End ->
                error "TODO: evaluate{Expect}: unclear how to handle this without cyclic dependencies"
            {-
            -- BUG: can't call E.expect because of cyclic dependency
                goEvaluate . E.expect e1 $ \e3 ->
                    syn (Let_ :$ e3 :* e2 :* End)
            -}

            Case_ e bs -> evaluateCase_ e bs

            _ :$ _ -> error "evaluate: the impossible happened"


-- | Evaluate an n-ary operator applied to a sequence of operands.
--
-- Operands are evaluated left to right. A neutral operand that is
-- itself an application of the same operator is flattened into the
-- remaining work list (see 'matchNaryOp'); adjacent head values are
-- folded together with 'evalOp' (see 'snocLoop'). An empty result
-- becomes 'identityElement' and a singleton result is returned as is.
evaluateNaryOp
    :: (ABT Term abt)
    => TermEvaluator abt (Dis abt)
    -> NaryOp a
    -> Seq (abt '[] a)
    -> Dis abt (Whnf abt a)
evaluateNaryOp evaluate_ = \o es -> mainLoop o (evalOp o) S.empty es
    where
    -- TODO: there's got to be a more efficient way to do this...
    -- @ws@ accumulates the already-evaluated operands; @es@ is the
    -- work list still to be evaluated.
    mainLoop o op ws es =
        case S.viewl es of
            S.EmptyL   -> return $
                case S.viewl ws of
                    S.EmptyL         -> identityElement o -- Avoid empty naryOps
                    w S.:< ws'
                        | S.null ws' -> w -- Avoid singleton naryOps
                        | otherwise  ->
                            Neutral . syn . NaryOp_ o $ fmap fromWhnf ws
            e S.:< es' -> do
                w <- evaluate_ e
                case matchNaryOp o w of
                    Nothing  -> mainLoop o op (snocLoop op ws w) es'
                    Just es2 -> mainLoop o op ws (es2 S.>< es')


-- | Append @w1@ to @ws@. If both @w1@ and the current last element
-- of @ws@ are head values they are combined with @op@ and the
-- combined value is appended in their place (repeating as long as
-- the new last element is also a head value).
snocLoop
    :: (ABT syn abt)
    => (Head abt a -> Head abt a -> Head abt a)
    -> Seq (Whnf abt a)
    -> Whnf abt a
    -> Seq (Whnf abt a)
snocLoop op ws w1 =
    -- TODO: immediately return @ws@ if @w1 == identityElement o@ (whenever identityElement is defined)
    case S.viewr ws of
        S.EmptyR    -> S.singleton w1
        ws' S.:> w2 ->
            case (w1,w2) of
                (Head_ v1, Head_ v2) -> snocLoop op ws' (Head_ (op v1 v2))
                _                    -> ws S.|> w1


-- | If the WHNF is a neutral term whose outermost form is the given
-- n-ary operator, return its operands; otherwise 'Nothing'.
matchNaryOp
    :: (ABT Term abt)
    => NaryOp a
    -> Whnf abt a
    -> Maybe (Seq (abt '[] a))
matchNaryOp o w =
    case w of
        Head_   _ -> Nothing
        Neutral e ->
            caseVarSyn e (const Nothing) $ \t ->
                case t of
                    NaryOp_ o' es | o' == o -> Just es
                    _                       -> Nothing


-- TODO: move this off to Prelude.hs or somewhere...
-- | The value an n-ary operator denotes when applied to no operands.
-- For 'Min' and 'Max' there is no identity in general, so the empty
-- application itself is returned as a neutral term.
identityElement :: (ABT Term abt) => NaryOp a -> Whnf abt a
identityElement o =
    case o of
        And    -> Head_ (WDatum dTrue)
        Or     -> Head_ (WDatum dFalse)
        Xor    -> Head_ (WDatum dFalse)
        Iff    -> Head_ (WDatum dTrue)
        Min  _ -> Neutral (syn (NaryOp_ o S.empty)) -- no identity in general (but we could do it by cases...)
        Max  _ -> Neutral (syn (NaryOp_ o S.empty)) -- no identity in general (but we could do it by cases...)
        -- TODO: figure out how to reuse 'P.zero_' and 'P.one_' here; requires converting the @(syn . Literal_)@ into @(Head_ . WLiteral)@. Maybe we should change 'P.zero_' and 'P.one_' so they just return the 'Literal' itself rather than the @abt@?
        Sum  HSemiring_Nat  -> Head_ (WLiteral (LNat  0))
        Sum  HSemiring_Int  -> Head_ (WLiteral (LInt  0))
        Sum  HSemiring_Prob -> Head_ (WLiteral (LProb 0))
        Sum  HSemiring_Real -> Head_ (WLiteral (LReal 0))
        Prod HSemiring_Nat  -> Head_ (WLiteral (LNat  1))
        Prod HSemiring_Int  -> Head_ (WLiteral (LInt  1))
        Prod HSemiring_Prob -> Head_ (WLiteral (LProb 1))
        Prod HSemiring_Real -> Head_ (WLiteral (LReal 1))


-- | The evaluation interpretation of each NaryOp
--
-- 'Min' and 'Max' are not implemented yet. 'Sum' and 'Prod' only
-- know how to combine two literals; any other pair of head values is
-- reported with an explicit error message rather than an anonymous
-- pattern-match failure.
evalOp
    :: (ABT Term abt)
    => NaryOp a
    -> Head abt a
    -> Head abt a
    -> Head abt a
-- TODO: something more efficient\/direct if we can...
evalOp And = \v1 v2 -> reflect (reify v1 && reify v2)
evalOp Or  = \v1 v2 -> reflect (reify v1 || reify v2)
evalOp Xor = \v1 v2 -> reflect (reify v1 /= reify v2)
evalOp Iff = \v1 v2 -> reflect (reify v1 == reify v2)
evalOp (Min  _) = error "TODO: evalOp{Min}"
evalOp (Max  _) = error "TODO: evalOp{Max}"
{-
evalOp (Min  _) = \v1 v2 -> reflect (reify v1 `min` reify v2)
evalOp (Max  _) = \v1 v2 -> reflect (reify v1 `max` reify v2)
evalOp (Sum  _) = \v1 v2 -> reflect (reify v1 + reify v2)
evalOp (Prod _) = \v1 v2 -> reflect (reify v1 * reify v2)
-}
-- HACK: this is just to have something to test. We really should reduce\/remove all this boilerplate...
evalOp (Sum  theSemi) = \h1 h2 ->
    case (h1, h2) of
        (WLiteral v1, WLiteral v2) -> WLiteral $ evalSum theSemi v1 v2
        _ -> error "evalOp{Sum}: expected both arguments to be literals"
evalOp (Prod theSemi) = \h1 h2 ->
    case (h1, h2) of
        (WLiteral v1, WLiteral v2) -> WLiteral $ evalProd theSemi v1 v2
        _ -> error "evalOp{Prod}: expected both arguments to be literals"

-- TODO: even if only one of the arguments is a literal, if that literal is zero\/one, then we can still partially evaluate it.
-- (As is done in the old finally-tagless code)

-- | Addition ('evalSum') and multiplication ('evalProd') of two
-- literals, selected by the semiring witness.
evalSum, evalProd :: HSemiring a -> Literal a -> Literal a -> Literal a
evalSum  HSemiring_Nat  = \(LNat  n1) (LNat  n2) -> LNat  (n1 + n2)
evalSum  HSemiring_Int  = \(LInt  i1) (LInt  i2) -> LInt  (i1 + i2)
evalSum  HSemiring_Prob = \(LProb p1) (LProb p2) -> LProb (p1 + p2)
evalSum  HSemiring_Real = \(LReal r1) (LReal r2) -> LReal (r1 + r2)
evalProd HSemiring_Nat  = \(LNat  n1) (LNat  n2) -> LNat  (n1 * n2)
evalProd HSemiring_Int  = \(LInt  i1) (LInt  i2) -> LInt  (i1 * i2)
evalProd HSemiring_Prob = \(LProb p1) (LProb p2) -> LProb (p1 * p2)
evalProd HSemiring_Real = \(LReal r1) (LReal r2) -> LReal (r1 * r2)


-- | Is the given variable one of the index variables currently
-- reported by 'getIndices'?
isIndex :: (ABT Term abt) => Variable 'HNat -> Dis abt Bool
isIndex v = do
    inds <- getIndices
    return $ v `elem` map indVar inds


-- | For {evaluate, constrainValue v0} ArrayOp_ (Index _) :$ e1 :* e2 :* End
--
-- Evaluates the array expression @e1@ and then dispatches to one of
-- the supplied continuations depending on what the array and the
-- index @e2@ turned out to be. Calls 'bot' when the index evaluates
-- to a head value or to neutral syntax that is not a variable.
indexArrayOp
    :: forall abt typs args a r
    .  ( ABT Term abt
       , typs ~ UnLCs args, args ~ LCs typs )
    => ArrayOp typs a
    -> SArgs abt args
    -> TermEvaluator abt (Dis abt)
    -> (abt '[] a -> Dis abt r) -- evaluate or (constrainValue v0)
    -> (Head abt ('HArray a)    -- e1 is in whnf, and
        -> Variable 'HNat       -- e2 is a free under current indices
        -> Dis abt r)
    -> (Term abt ('HArray a)    -- e1 is neutral syntax
        -> Dis abt r)
    -> (abt '[] ('HArray a)     -- e1 is a free variable
        -> Dis abt r)
    -> (abt '[] ('HArray a)     -- e1 is a multiloc, and
        -> Variable 'HNat       -- e2 is a free under current indices
        -> Dis abt r)
    -> Dis abt r
indexArrayOp o@(Index _) (e1 :* e2 :* End) evaluate_ kInd kArr kSyn kFree kMultiLoc = do
    w1 <- evaluate_ e1
    case w1 of
        Head_ arr@(WArray _ b) -> caseBind b $ \x body ->
                                    evalIndex (kInd . flip (rename x) body)
                                              (kArr arr)
        Head_ (WEmpty _)       -> error "TODO: indexArrayOp o (Empty_ :* _ :* End)"
        Head_ _                -> error "indexArrayOp: unknown whnf of array type"
        Neutral e1'            ->
            flip (caseVarSyn e1') kSyn $ \x ->
                do locs <- getLocs
                   case (lookupAssoc x locs) of
                       Nothing               -> kFree e1'
                       Just (Loc _ _)        -> error "indexArrayOp: impossible, we have a Neutral term"
                       Just (MultiLoc l js)  ->
                           evalIndex ((kInd . var =<<) . mkLoc Text.empty l . flip extendLocInds js)
                                     (kMultiLoc e1')
  where
    -- Evaluate the index @e2@; when it is a variable, run @thenCase@
    -- if 'isIndex' holds for it and @elseCase@ otherwise.
    evalIndex :: (ABT Term abt)
              => (Variable 'HNat -> Dis abt r)
              -> (Variable 'HNat -> Dis abt r)
              -> Dis abt r
    evalIndex thenCase elseCase = do
        w2 <- evaluate_ e2
        caseWhnf w2 (const bot) $ \term -> -- bot if index is in whnf (eg. a literal num)
            flip (caseVarSyn term) (const bot) $ \v -> -- bot if index is neutral syntax
                do isInd <- isIndex v
                   if isInd then thenCase v else elseCase v
indexArrayOp _ _ _ _ _ _ _ _ = error "indexArrayOp called on incorrect ArrayOp"


-- | Evaluate an array operator applied to its arguments.
--
-- 'Index' is delegated to 'indexArrayOp', with every non-reducing
-- continuation rebuilding the index expression as a neutral term.
-- 'Size' of an empty array is the literal zero, and of an array
-- head is the evaluation of its size expression. 'Reduce' is not
-- implemented.
evaluateArrayOp
    :: ( ABT Term abt
       , typs ~ UnLCs args, args ~ LCs typs)
    => TermEvaluator abt (Dis abt)
    -> ArrayOp typs a
    -> SArgs abt args
    -> Dis abt (Whnf abt a)
evaluateArrayOp evaluate_ = go
    where
    go o@(Index _) = \args@(_ :* e2 :* End) ->
        let returnIndex = return . Neutral . syn
        in  indexArrayOp o args evaluate_ evaluate_
              (\arr v -> returnIndex (ArrayOp_ o :$ fromHead arr :* var v :* End))
              (\s     -> returnIndex (ArrayOp_ o :$ syn s :* e2 :* End))
              (\e1'   -> returnIndex (ArrayOp_ o :$ e1' :* e2 :* End))
              (\e1' v -> returnIndex (ArrayOp_ o :$ e1' :* var v :* End))

    go o@(Size _) = \(e1 :* End) -> do
        w1 <- evaluate_ e1
        case w1 of
            Neutral e1' -> return . Neutral $ syn (ArrayOp_ o :$ e1' :* End)
            Head_   v1  ->
                case head2array v1 of
                    WAEmpty      -> return . Head_ $ WLiteral (LNat 0)
                    WAArray e3 _ -> evaluate_ e3

    go (Reduce _) = \(e1 :* e2 :* e3 :* End) ->
        error "TODO: evaluateArrayOp{Reduce}"


-- | A view of the head values of array type: either the empty array
-- or an array given by a size expression and a body with one bound
-- index variable.
data ArrayHead :: ([Hakaru] -> Hakaru -> *) -> Hakaru -> * where
    WAEmpty :: ArrayHead abt a
    WAArray
        :: !(abt '[] 'HNat)
        -> !(abt '[ 'HNat] a)
        -> ArrayHead abt a

-- | Convert an array-typed 'Head' into its 'ArrayHead' view.
head2array :: Head abt ('HArray a) -> ArrayHead abt a
head2array (WEmpty _)     = WAEmpty
head2array (WArray e1 e2) = WAArray e1 e2


-- | Boolean implication, difference (@x@ and not @y@), nand, and nor.
impl, diff, nand, nor :: Bool -> Bool -> Bool
impl x y = not x || y
diff x y = x && not y
nand x y = not (x && y)
nor  x y = not (x || y)


-- | Evaluate a primitive operator applied to its arguments.
--
-- The helpers come in two flavours: @neu1@\/@neu2@ force the
-- arguments but always rebuild the operator application as a
-- (nominally) neutral term, whereas @rr1@\/@rr2@ additionally compute
-- the result with a Haskell function (via 'reify' and 'reflect') when
-- every argument evaluated to a head value.
evaluatePrimOp
    :: forall abt p typs args a
    .  ( ABT Term abt
       , typs ~ UnLCs args, args ~ LCs typs)
    => TermEvaluator abt (Dis abt)
    -> PrimOp typs a
    -> SArgs abt args
    -> Dis abt (Whnf abt a)
evaluatePrimOp evaluate_ = go
    where
    -- HACK: we don't have any way of saying these functions haven't reduced even though it's not actually a neutral term.
    neu1 :: forall b c
        .  (abt '[] b -> abt '[] c)
        -> abt '[] b
        -> Dis abt (Whnf abt c)
    neu1 f e = (Neutral . f . fromWhnf) <$> evaluate_ e

    neu2 :: forall b c d
        .  (abt '[] b -> abt '[] c -> abt '[] d)
        -> abt '[] b
        -> abt '[] c
        -> Dis abt (Whnf abt d)
    neu2 f e1 e2 = do
        e1' <- fromWhnf <$> evaluate_ e1
        e2' <- fromWhnf <$> evaluate_ e2
        return . Neutral $ f e1' e2'

    -- @f'@ is the Haskell-level interpretation, @f@ the syntactic one.
    rr1 :: forall b b' c c'
        .  (Interp b b', Interp c c')
        => (b' -> c')
        -> (abt '[] b -> abt '[] c)
        -> abt '[] b
        -> Dis abt (Whnf abt c)
    rr1 f' f e = do
        w <- evaluate_ e
        return $
            case w of
                Neutral e' -> Neutral $ f e'
                Head_   v  -> Head_ . reflect $ f' (reify v)

    rr2 :: forall b b' c c' d d'
        .  (Interp b b', Interp c c', Interp d d')
        => (b' -> c' -> d')
        -> (abt '[] b -> abt '[] c -> abt '[] d)
        -> abt '[] b
        -> abt '[] c
        -> Dis abt (Whnf abt d)
    rr2 f' f e1 e2 = do
        w1 <- evaluate_ e1
        w2 <- evaluate_ e2
        return $
            case w1 of
                Neutral e1' -> Neutral $ f e1' (fromWhnf w2)
                Head_   v1  ->
                    case w2 of
                        Neutral e2' -> Neutral $ f (fromWhnf w1) e2'
                        Head_   v2  -> Head_ . reflect $ f' (reify v1) (reify v2)

    -- Rebuild a binary primitive-operator application.
    primOp2_
        :: forall b c d
        .  PrimOp '[ b, c ] d -> abt '[] b -> abt '[] c -> abt '[] d
    primOp2_ o e1 e2 = syn (PrimOp_ o :$ e1 :* e2 :* End)

    -- TODO: something more efficient\/direct if we can...
    go Not  (e1 :* End)       = rr1 not  P.not  e1
    go Impl (e1 :* e2 :* End) = rr2 impl (primOp2_ Impl) e1 e2
    go Diff (e1 :* e2 :* End) = rr2 diff (primOp2_ Diff) e1 e2
    go Nand (e1 :* e2 :* End) = rr2 nand P.nand e1 e2
    go Nor  (e1 :* e2 :* End) = rr2 nor  P.nor  e1 e2

    -- HACK: we don't have a way of saying that 'Pi' (or 'Infinity',...) is in fact a head; so we're forced to call it neutral which is a lie. We should add constructor(s) to 'Head' to cover these magic constants; probably grouped together under a single constructor called something like @Constant@. Maybe should group them like that in the AST as well?
    go Pi        End               = return $ Neutral P.pi

    -- We treat trig functions as strict, thus forcing their
    -- arguments; however, to avoid fuzz issues we don't actually
    -- evaluate the trig functions.
    --
    -- HACK: we might should have some other way to make these
    -- 'Whnf' rather than calling them neutral terms; since they
    -- aren't, in fact, neutral!
    go Sin       (e1 :* End)       = neu1 P.sin   e1
    go Cos       (e1 :* End)       = neu1 P.cos   e1
    go Tan       (e1 :* End)       = neu1 P.tan   e1
    go Asin      (e1 :* End)       = neu1 P.asin  e1
    go Acos      (e1 :* End)       = neu1 P.acos  e1
    go Atan      (e1 :* End)       = neu1 P.atan  e1
    go Sinh      (e1 :* End)       = neu1 P.sinh  e1
    go Cosh      (e1 :* End)       = neu1 P.cosh  e1
    go Tanh      (e1 :* End)       = neu1 P.tanh  e1
    go Asinh     (e1 :* End)       = neu1 P.asinh e1
    go Acosh     (e1 :* End)       = neu1 P.acosh e1
    go Atanh     (e1 :* End)       = neu1 P.atanh e1

    -- TODO: deal with how we have better types for these three ops than Haskell does...
    -- go RealPow   (e1 :* e2 :* End) = rr2 (**) (P.**) e1 e2
    go RealPow   (e1 :* e2 :* End) = neu2 (P.**) e1 e2

    -- HACK: these aren't actually neutral!
    -- BUG: we should try to cancel out @(exp . log)@ and @(log . exp)@
    go Exp       (e1 :* End)       = neu1 P.exp e1
    go Log       (e1 :* End)       = neu1 P.log e1

    -- HACK: these aren't actually neutral!
    go (Infinity h)     End        =
        case h of
            HIntegrable_Nat  -> return . Neutral $ P.primOp0_ (Infinity h)
            HIntegrable_Prob -> return $ Neutral P.infinity

    go GammaFunc   (e1 :* End)            = neu1 P.gammaFunc e1
    go BetaFunc    (e1 :* e2 :* End)      = neu2 P.betaFunc  e1 e2

    go (Equal  theEq)   (e1 :* e2 :* End) = rrEqual theEq  e1 e2
    go (Less   theOrd)  (e1 :* e2 :* End) = rrLess  theOrd e1 e2
    go (NatPow theSemi) (e1 :* e2 :* End) =
        case theSemi of
            HSemiring_Nat    -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
            HSemiring_Int    -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
            HSemiring_Prob   -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
            HSemiring_Real   -> rr2 (\v1 v2 -> v1 ^ fromNatural v2) (P.^) e1 e2
    go (Negate theRing) (e1 :* End) =
        case theRing of
            HRing_Int        -> rr1 negate P.negate e1
            HRing_Real       -> rr1 negate P.negate e1
    go (Abs    theRing) (e1 :* End) =
        case theRing of
            HRing_Int        -> rr1 (unsafeNatural . abs) P.abs_ e1
            HRing_Real       -> rr1 (unsafeNonNegativeRational  . abs) P.abs_ e1
    go (Signum theRing) (e1 :* End) =
        case theRing of
            HRing_Int        -> rr1 signum P.signum e1
            HRing_Real       -> rr1 signum P.signum e1
    go (Recip  theFractional) (e1 :* End) =
        case theFractional of
            HFractional_Prob -> rr1 recip  P.recip  e1
            HFractional_Real -> rr1 recip  P.recip  e1
    go (NatRoot theRadical) (e1 :* e2 :* End) =
        case theRadical of
            HRadical_Prob -> neu2 (flip P.thRootOf) e1 e2
    {-
    go (NatRoot theRadical) (e1 :* e2 :* End) =
        case theRadical of
        HRadical_Prob -> rr2 natRoot (flip P.thRootOf) e1 e2
    go (Erf theContinuous) (e1 :* End) =
        case theContinuous of
        HContinuous_Prob -> rr1 erf P.erf e1
        HContinuous_Real -> rr1 erf P.erf e1
    -}
    go op _ = error $ "TODO: evaluatePrimOp{" ++ show op ++ "}"


    -- Equality, by cases on the equality witness. Pairs are compared
    -- componentwise; the two component results are combined so that a
    -- known-false component short-circuits to 'dFalse'.
    rrEqual
        :: forall b. HEq b -> abt '[] b -> abt '[] b -> Dis abt (Whnf abt HBool)
    rrEqual theEq =
        case theEq of
            HEq_Nat    -> rr2 (==) (P.==)
            HEq_Int    -> rr2 (==) (P.==)
            HEq_Prob   -> rr2 (==) (P.==)
            HEq_Real   -> rr2 (==) (P.==)
            HEq_Array aEq -> error "TODO: rrEqual{HEq_Array}"
            HEq_Bool   -> rr2 (==) (P.==)
            HEq_Unit   -> rr2 (==) (P.==)
            HEq_Pair   aEq bEq ->
                \e1 e2 -> do
                    w1 <- evaluate_ e1
                    w2 <- evaluate_ e2
                    case w1 of
                        Neutral e1' ->
                            return . Neutral
                                $ P.primOp2_ (Equal theEq) e1' (fromWhnf w2)
                        Head_   v1  ->
                            case w2 of
                                Neutral e2' ->
                                    return . Neutral
                                        $ P.primOp2_ (Equal theEq) (fromHead v1) e2'
                                Head_ v2 -> do
                                    let (v1a, v1b) = reifyPair v1
                                    let (v2a, v2b) = reifyPair v2
                                    wa <- rrEqual aEq v1a v2a
                                    wb <- rrEqual bEq v1b v2b
                                    return $
                                        case wa of
                                            Neutral ea ->
                                                case wb of
                                                    Neutral eb -> Neutral (ea P.&& eb)
                                                    Head_   vb
                                                        | reify vb  -> wa
                                                        | otherwise -> Head_ $ WDatum dFalse
                                            Head_ va
                                                | reify va  -> wb
                                                | otherwise -> Head_ $ WDatum dFalse

            HEq_Either aEq bEq -> error "TODO: rrEqual{HEq_Either}"

    -- Strict less-than, by cases on the ordering witness. The pair
    -- case is not implemented once both sides are head values.
    rrLess
        :: forall b. HOrd b -> abt '[] b -> abt '[] b -> Dis abt (Whnf abt HBool)
    rrLess theOrd =
        case theOrd of
            HOrd_Nat    -> rr2 (<) (P.<)
            HOrd_Int    -> rr2 (<) (P.<)
            HOrd_Prob   -> rr2 (<) (P.<)
            HOrd_Real   -> rr2 (<) (P.<)
            HOrd_Array aOrd -> error "TODO: rrLess{HOrd_Array}"
            HOrd_Bool   -> rr2 (<) (P.<)
            HOrd_Unit   -> rr2 (<) (P.<)
            HOrd_Pair aOrd bOrd ->
                \e1 e2 -> do
                    w1 <- evaluate_ e1
                    w2 <- evaluate_ e2
                    case w1 of
                        Neutral e1' ->
                            return . Neutral
                                $ P.primOp2_ (Less theOrd) e1' (fromWhnf w2)
                        Head_   v1  ->
                            case w2 of
                                Neutral e2' ->
                                    return . Neutral
                                        $ P.primOp2_ (Less theOrd) (fromHead v1) e2'
                                Head_ v2 -> do
                                    let (v1a, v1b) = reifyPair v1
                                    let (v2a, v2b) = reifyPair v2
                                    error "TODO: rrLess{HOrd_Pair}"
                                    -- BUG: The obvious recursion won't work because we need to know when the first components are equal before recursing (to implement lexicographic ordering). We really need a ternary comparison operator like 'compare'.
            HOrd_Either aOrd bOrd -> error "TODO: rrLess{HOrd_Either}"


-- | Evaluate a variable.
--
-- A variable with no entry in the current locations is free and is
-- returned as a neutral term. For a 'Loc', the statement binding it
-- is selected: an 'SBind' is run with the measure evaluator and an
-- 'SLet' is evaluated with the term evaluator (both under the
-- statement's own indices), after which the statement is replaced by
-- an 'SLet' holding the resulting WHNF and that WHNF is returned with
-- the statement's index variables renamed to @jxs@. A 'MultiLoc' is
-- returned as a neutral variable.
update
    :: forall abt
    .  (ABT Term abt)
    => MeasureEvaluator  abt (Dis abt)
    -> TermEvaluator     abt (Dis abt)
    -> VariableEvaluator abt (Dis abt)
update perform evaluate_ x =
    do locs <- getLocs
       -- If we get 'Nothing', then it turns out @x@ is a free variable
       maybe (return $ Neutral (var x)) lookForLoc (lookupAssoc x locs)
  where
    lookForLoc (Loc l jxs) =
        (maybe (freeLocError l) return =<<) . select l $ \s ->
            case s of
                SBind l' e ixs -> do
                    Refl <- varEq l l'
                    Just $ do
                        w <- withIndices ixs $ perform (caseLazy e fromWhnf id)
                        unsafePush (SLet l (Whnf_ w) ixs)
#ifdef __TRACE_DISINTEGRATE__
                        trace ("-- updated "
                            ++ show (ppStatement 11 s)
                            ++ " to "
                            ++ show (ppStatement 11 (SLet l (Whnf_ w) ixs))
                            ) $ return ()
#endif
                        let as = toAssocs $ zipWith Assoc (map indVar ixs) jxs
                            w' = renames as (fromWhnf w)
                        inds <- getIndices
                        withIndices inds $ return (fromMaybe (Neutral w') (toWhnf w'))
                SLet l' e ixs -> do
                    Refl <- varEq l l'
                    Just $ do
                        w <- withIndices ixs $ caseLazy e return evaluate_
                        unsafePush (SLet l (Whnf_ w) ixs)
                        let as = toAssocs $ zipWith Assoc (map indVar ixs) jxs
                            w' = renames as (fromWhnf w)
                        inds <- getIndices
                        withIndices inds $ return (fromMaybe (Neutral w') (toWhnf w'))
                -- This does not bind any variables, so it definitely can't match.
                SWeight   _ _ -> Nothing
                -- This does bind variables,
                -- but there's no expression we can return for it
                -- because the variables are untouchable\/abstract.
                SGuard ls pat scrutinee i -> Just . return . Neutral $ var x
    -- Case for MultiLocs
    lookForLoc (MultiLoc l jxs) = return (Neutral $ var x)

---------------------------------------------------------- End of copied code --

-- | The forward disintegrator's function for evaluating case
-- expressions. First we try calling 'defaultCaseEvaluator' which
-- will evaluate the scrutinee and select the matching branch (if
-- any). But that doesn't work out in general, since the scrutinee
-- may contain heap-bound variables.
-- So our fallback definition
-- will push a 'SGuard' onto the heap and then continue evaluating
-- each branch (thereby duplicating the continuation, calling it
-- once on each branch).
--
-- The argument is the term evaluator that is handed to
-- 'defaultCaseEvaluator' and that is also used as the continuation
-- for each branch body.
evaluateCase
    :: forall abt
    .  (ABT Term abt)
    => TermEvaluator abt (Dis abt)
    -> CaseEvaluator abt (Dis abt)
{-# INLINE evaluateCase #-}
evaluateCase evaluate_ = evaluateCase_
    where
    -- Try the default evaluator first; if that fails, explore every branch.
    evaluateCase_ :: CaseEvaluator abt (Dis abt)
    evaluateCase_ e bs =
        defaultCaseEvaluator evaluate_ e bs
        <|> evaluateBranches e bs

    -- One alternative per branch, combined with 'choose'.
    evaluateBranches :: CaseEvaluator abt (Dis abt)
    evaluateBranches e = choose . map evaluateBranch
        where
        -- Push a guard for this branch's pattern (at the current
        -- indices) and continue evaluating the branch body.
        evaluateBranch (Branch pat body) =
            let (vars,body') = caseBinds body
            in  getIndices >>= \i ->
                    push (SGuard vars pat (Thunk e) i) body' evaluate_


-- | Evaluate a term with 'evaluate_' and view the resulting WHNF
-- through 'viewWhnfDatum'.
evaluateDatum :: (ABT Term abt) => DatumEvaluator (abt '[]) (Dis abt)
evaluateDatum e = viewWhnfDatum <$> evaluate_ e


-- | Simulate performing 'HMeasure' actions by simply emitting code
-- for those actions, returning the bound variable.
--
-- This is the function called @(|>>)@ in the disintegration paper.
perform :: forall abt. (ABT Term abt) => MeasureEvaluator abt (Dis abt)
perform = \e0 ->
#ifdef __TRACE_DISINTEGRATE__
    getStatements >>= \ss ->
    getLocs >>= \locs ->
    getIndices >>= \inds ->
    trace ("\n-- perform --\n"
        ++ "at " ++ show (ppInds inds) ++ "\n"
        ++ show (prettyLocs locs) ++ "\n"
        ++ show (pretty_Statements_withTerm ss e0)
        ++ "\n") $
#endif
    caseVarSyn e0 performVar performTerm
    where
    -- Dispatch on the syntactic form of a measure term.
    performTerm :: forall a. Term abt ('HMeasure a) -> Dis abt (Whnf abt a)
    performTerm (Dirac :$ e1 :* End)       = evaluate_ e1
    performTerm (MeasureOp_ o :$ es)       = performMeasureOp o es
    performTerm (MBind :$ e1 :* e2 :* End) =
        caseBind e2 $ \x e2' -> do
            inds <- getIndices
            push (SBind x (Thunk e1) inds) e2' perform
    performTerm (Plate :$ e1 :* e2 :* End) = do
        x1 <- pushPlate e1 e2
        return (Neutral (var x1))
    performTerm (Superpose_ pes) = do
        inds <- getIndices
        -- We give up ('bot') on a superpose with more than one
        -- component whenever the current index list is non-empty.
        if not (null inds) && L.length pes > 1 then bot else
            emitFork_ (P.superpose . fmap ((,) P.one))
                (fmap (\(p,e) -> push (SWeight (Thunk p) inds) e perform)
                    pes)

    -- Avoid falling through to the @performWhnf <=< evaluate_@ case
    performTerm (Let_ :$ e1 :* e2 :* End) =
        caseBind e2 $ \x e2' -> do
            inds <- getIndices
            push (SLet x (Thunk e1) inds) e2' perform

    -- TODO: we could optimize this by calling some @evaluateTerm@
    -- directly, rather than calling 'syn' to rebuild @e0@ from
    -- @t0@ and then calling 'evaluate_' (which will just use
    -- 'caseVarSyn' to get the @t0@ back out from the @e0@).
    --
    -- BUG: when @t0@ is a 'Case_', this doesn't work right. This
    -- is the source of the hygiene bug in 'testPerform1b'. Alas,
    -- we cannot use 'emitCaseWith' here since that would require
    -- the scrutinee to be emissible; but we'd want something pretty
    -- similar...
    performTerm t0 = do
        w <- evaluate_ (syn t0)
#ifdef __TRACE_DISINTEGRATE__
        trace ("-- perform: finished evaluate, with:\n" ++ show (PP.sep(prettyPrec_ 11 w))) $ return ()
#endif
        performWhnf w

    -- Variables are first resolved with 'update', then performed.
    performVar :: forall a. Variable ('HMeasure a) -> Dis abt (Whnf abt a)
    performVar = performWhnf <=< update perform evaluate_

    -- BUG: it's not clear this is actually doing the right thing for its call-sites. In particular, we should handle 'Case_' specially, to deal with the hygiene bug in 'testPerform1b'...
    --
    -- BUG: found the 'testPerform1b' hygiene bug! We can't simply call 'emitMBind' on @e@, because @e@ may not be emissible!
    performWhnf :: forall a. Whnf abt ('HMeasure a) -> Dis abt (Whnf abt a)
    performWhnf (Head_   v) = perform $ fromHead v
    performWhnf (Neutral e) = (Neutral . var) <$> emitMBind e


    -- TODO: right now we do the simplest thing. However, for better
    -- coverage and cleaner produced code we'll need to handle each
    -- of the ops separately. (For example, see how 'Uniform' is
    -- handled in the old code; how it has two options for what to
    -- do.)
    performMeasureOp
        :: forall typs args a
        .  (typs ~ UnLCs args, args ~ LCs typs)
        => MeasureOp typs a
        -> SArgs abt args
        -> Dis abt (Whnf abt a)
    performMeasureOp = \o es -> nice o es <|> complete o es
        where
        -- Try to generate nice pretty output.
        nice
            :: MeasureOp typs a
            -> SArgs abt args
            -> Dis abt (Whnf abt a)
        nice o es = do
            es' <- traverse21 atomizeCore es
            x   <- emitMBind $ syn (MeasureOp_ o :$ es')
            return (Neutral $ var x)

        -- Try to be as complete as possible (i.e., 'bot' as little as possible), no matter how ugly the output code gets.
        complete
            :: MeasureOp typs a
            -> SArgs abt args
            -> Dis abt (Whnf abt a)
        complete Normal = \(mu :* sd :* End) -> do
            x <- var <$> emitMBind P.lebesgue
            pushWeight (P.densityNormal mu sd x)
            return (Neutral x)
        complete Uniform = \(lo :* hi :* End) -> do
            x <- var <$> emitMBind P.lebesgue
            pushGuard (lo P.< x P.&& x P.< hi)
            pushWeight (P.densityUniform lo hi x)
            return (Neutral x)
        -- Every other measure op has no fallback here.
        complete _ = \_ -> bot


-- | Wrap the given weight in a 'Thunk' and push it as an 'SWeight'
-- statement tagged with the current indices.
--
-- Calls unsafePush
pushWeight :: (ABT Term abt) => abt '[] 'HProb -> Dis abt ()
pushWeight w = do
    inds <- getIndices
    unsafePush $ SWeight (Thunk w) inds

-- | Wrap the given boolean in a 'Thunk' and push it as an 'SGuard'
-- statement that binds no variables ('Nil1') and matches against
-- 'pTrue', tagged with the current indices.
--
-- Calls unsafePush
pushGuard :: (ABT Term abt) => abt '[] HBool -> Dis abt ()
pushGuard b = do
    inds <- getIndices
    unsafePush $ SGuard Nil1 pTrue (Thunk b) inds


-- | The goal of this function is to ensure the correctness criterion
-- that given any term to be emitted, the resulting term is
-- semantically equivalent but contains no heap-bound variables.
-- That correctness criterion is necessary to ensure hygiene\/scoping.
--
-- This particular implementation calls 'evaluate' recursively,
-- giving us something similar to full-beta reduction. However,
-- that is considered an implementation detail rather than part of
-- the specification of what the function should do. Also, it's a
-- gross hack and probably a big part of why we keep running into
-- infinite looping issues.
--
-- This name is taken from the old finally tagless code, where
-- \"atomic\" terms are (among other things) emissible; i.e., contain
-- no heap-bound variables.
--
-- BUG: this function infinitely loops in certain circumstances
-- (namely when dealing with neutral terms)
atomize :: (ABT Term abt) => TermEvaluator abt (Dis abt)
atomize e =
#ifdef __TRACE_DISINTEGRATE__
    trace ("\n-- atomize --\n" ++ show (pretty e)) $
#endif
    evaluate_ e >>= traverse21 atomizeCore


-- | The workhorse behind 'atomize'. Unlike 'atomize' it places no
-- restriction on the locally bound variables @xs@ (whereas 'atomize'
-- needs @xs ~ '[]@), which is the shape required when mapping over
-- subterms with the indexed @TraversableMN@ classes.
atomizeCore :: (ABT Term abt) => abt xs a -> Dis abt (abt xs a)
atomizeCore e =
    -- HACK: returning early when the term mentions no heap-bound
    -- variable is enough to stop the particular infinite loops we
    -- ran into, but it does not stop all of them. Whenever the
    -- 'evaluate_' call inside 'atomize' hands back a neutral term
    -- that still mentions heap-bound variables we keep looping,
    -- because the top-level constructor of a neutral term is never
    -- traversed.
    getHeapVars >>= \heapVars ->
        if sharesNoVars heapVars (freeVars e)
            then return e
            else underBinders
    where
    -- TODO: does @IM.null . IM.intersection@ fuse correctly?
    sharesNoVars as bs =
        IM.null (unVarSet as `IM.intersection` unVarSet bs)

    -- Strip the binders, atomize the body, then put the binders back.
    underBinders =
        let (ys, e') = caseBinds e
        in  fmap (binds_ ys . fromWhnf) (atomize e')


-- HACK: this recomputes the set of heap-bound variables on every
-- call. If we keep this approach, the set should be memoized in
-- the 'ListContext' itself.
getHeapVars :: Dis abt (VarSet ('KProxy :: KProxy Hakaru))
getHeapVars =
    Dis $ \_ k heap ->
        k (foldMap statementVars (statements heap)) heap

----------------------------------------------------------------
----------------------------------------------------------------
-- | Given an emissible term @v0@ (the first argument) and another
-- term @e0@ (the second argument), compute the constraints such
-- that @e0@ must evaluate to @v0@.
-- This is the function called
-- @(<|)@ in the disintegration paper, though notably we swap the
-- argument order so that the \"value\" is the first argument.
--
-- N.B., this function assumes (and does not verify) that the first
-- argument is emissible. So callers (including recursive calls)
-- must guarantee this invariant, by calling 'atomize' as necessary.
--
-- TODO: capture the emissibility requirement on the first argument
-- in the types, to help avoid accidentally passing the arguments
-- in the wrong order!
constrainValue :: (ABT Term abt) => abt '[] a -> abt '[] a -> Dis abt ()
constrainValue v0 e0 =
#ifdef __TRACE_DISINTEGRATE__
    getStatements >>= \ss ->
    getLocs >>= \locs ->
    getIndices >>= \inds ->
    trace ("\n-- constrainValue: " ++ show (pretty v0) ++ "\n"
        ++ show (pretty_Statements_withTerm ss e0) ++ "\n"
        ++ "at " ++ show (ppInds inds) ++ "\n"
        ++ show (prettyLocs locs) ++ "\n"
        ) $
#endif
    -- Variables are handled by 'constrainVariable'; everything else
    -- is dispatched on the term's head constructor.
    caseVarSyn e0 (constrainVariable v0) $ \t ->
        case t of
        -- There's a bunch of stuff we don't even bother trying to handle
        Empty_ _               -> error "TODO: disintegrate arrays"
        Array_ n e             ->
            caseBind e $ \x body -> do
                j <- freshInd n
                let x'    = indVar j
                    body' = rename x x' body
                inds <- getIndices
                withIndices (extendIndices j inds) $
                    constrainValue (v0 P.! (var x')) body'
                -- TODO use meta-index
        ArrayOp_ o@(Index _) :$ args ->
            indexArrayOp o args evaluate_ (constrainValue v0) (const $ const bot) (const bot) (const bot) (const $ const bot)
        ArrayOp_ _ :$ _        -> error "TODO: disintegrate arrays"
        Lam_  :$ _  :* End     -> error "TODO: disintegrate lambdas"
        App_  :$ _  :* _ :* End -> error "TODO: disintegrate lambdas"
        Integrate :$ _ :* _ :* _ :* End ->
            error "TODO: disintegrate integration"
        Summate _ _ :$ _ :* _ :* _ :* End ->
            error "TODO: disintegrate integration"


        -- N.B., the semantically correct definition is:
        --
        -- > Literal_ v
        -- >     | "dirac v has a density wrt the ambient measure" -> ...
        -- >     | otherwise -> bot
        --
        -- For the case where the ambient measure is Lebesgue, dirac
        -- doesn't have a density, so we return 'bot'. However, we
        -- will need to generalize this when we start handling other
        -- ambient measures.
        Literal_ v               -> bot -- unsolvable. (kinda; see note)
        Datum_   d               -> constrainDatum v0 d
        Dirac :$ _ :* End        -> bot -- giving up.
        MBind :$ _ :* _ :* End   -> bot -- giving up.
        MeasureOp_ o :$ es       -> constrainValueMeasureOp v0 o es
        Superpose_ pes           -> bot -- giving up.
        Reject_ _                -> bot -- giving up.
        -- NOTE(review): this pushes the 'SLet' with an empty index
        -- list, whereas 'perform' uses 'getIndices' for its 'Let_'
        -- case; confirm that the difference is intended.
        Let_ :$ e1 :* e2 :* End ->
            caseBind e2 $ \x e2' ->
                push (SLet x (Thunk e1) []) e2' (constrainValue v0)

        CoerceTo_   c :$ e1 :* End ->
            -- TODO: we need to insert some kind of guard that says
            -- @v0@ is in the range of @coerceTo c@, or equivalently
            -- that @unsafeFrom c v0@ will always succeed. We need
            -- to emit this guard (for correctness of the generated
            -- program) because if @v0@ isn't in the range of the
            -- coercion, then there's no possible way the program
            -- @e1@ could in fact be observed at @v0@. The only
            -- question is how to perform that check; for the
            -- 'Signed' coercions it's easy enough, but for the
            -- 'Continuous' coercions it's not really clear.
            constrainValue (P.unsafeFrom_ c v0) e1
        UnsafeFrom_ c :$ e1 :* End ->
            -- TODO: to avoid returning garbage, we'd need to place
            -- some constraint on @e1@ so that if the original
            -- program would've crashed due to a bad unsafe-coercion,
            -- then we won't return a disintegrated program (since
            -- it too should always crash). Avoiding this check is
            -- sound (i.e., if the input program is well-formed
            -- then the output program is a well-formed disintegration),
            -- it just overgeneralizes.
            constrainValue  (P.coerceTo_ c v0) e1
        NaryOp_     o    es       -> constrainNaryOp v0 o es
        PrimOp_     o :$ es       -> constrainPrimOp v0 o es
        Expect  :$ e1 :* e2 :* End -> error "TODO: constrainValue{Expect}"

        Case_ e bs ->
            -- First we try going forward on the scrutinee, to make
            -- pretty resulting programs; but if that doesn't work
            -- out, we fall back to just constraining the branches.
            do  match <- matchBranches evaluateDatum e bs
                case match of
                    Nothing ->
                        -- If desired, we could return the Hakaru program
                        -- that always crashes, instead of throwing a
                        -- Haskell error.
                        error "constrainValue{Case_}: nothing matched!"
                    Just GotStuck ->
                        constrainBranches v0 e bs
                    Just (Matched rho body) ->
                        pushes (toStatements rho) body (constrainValue v0)
            <|> constrainBranches v0 e bs

        -- Any remaining @(:$)@ form ends up here.
        _ :$ _ -> error "constrainValue: the impossible happened"


-- | The default way of doing 'constrainValue' on a 'Case_' expression:
-- by constraining each branch. To do this we rely on the fact that
-- we're in a 'HMeasure' context (i.e., the continuation produces
-- programs of 'HMeasure' type). For each branch we first assert the
-- branch's pattern holds (via 'SGuard') and then call 'constrainValue'
-- on the body of the branch; and the final program is the superposition
-- of all these branches.
--
-- TODO: how can we avoid duplicating the scrutinee expression?
-- Would pushing a 'SLet' statement before the superpose be sufficient
-- to achieve maximal sharing?
constrainBranches
    :: (ABT Term abt)
    => abt '[] a
    -> abt '[] b
    -> [Branch b abt a]
    -> Dis abt ()
constrainBranches v0 e = choose . map constrainBranch
    where
    -- NOTE(review): the guard is pushed with an empty index list,
    -- unlike 'evaluateCase' which uses 'getIndices'; confirm intent.
    constrainBranch (Branch pat body) =
        let (vars,body') = caseBinds body
        in push (SGuard vars pat (Thunk e) []) body' (constrainValue v0)


-- | Constrain a datum to equal @v0@: emit a case expression on @v0@
-- whose first branch matches the datum's pattern (binding fresh
-- variables) and whose wildcard branch rejects, then constrain each
-- subexpression of the datum against the corresponding fresh variable.
constrainDatum
    :: (ABT Term abt) => abt '[] a -> Datum (abt '[]) a -> Dis abt ()
constrainDatum v0 d =
    case patternOfDatum d of
    PatternOfDatum pat es -> do
        xs <- freshVars $ fmap11 (Hint Text.empty . typeOf) es
        emit_ $ \body ->
            syn $ Case_ v0
                [ Branch pat (binds_ xs body)
                , Branch PWild (P.reject $ (typeOf body))
                ]
        constrainValues xs es

-- | Pairwise 'constrainValue' of each variable against the expression
-- at the same position, left to right.
constrainValues
    :: (ABT Term abt)
    => List1 Variable  xs
    -> List1 (abt '[]) xs
    -> Dis abt ()
constrainValues (Cons1 x xs) (Cons1 e es) =
    constrainValue (var x) e >> constrainValues xs es
constrainValues Nil1 Nil1 = return ()
constrainValues _ _ = error "constrainValues: the impossible happened"


-- | A pattern paired with the list of subexpressions for the
-- variables it binds; the list of bound types @xs@ is existential.
data PatternOfDatum (ast :: Hakaru -> *) (a :: Hakaru) =
    forall xs. PatternOfDatum
        !(Pattern xs a)
        !(List1 ast xs)

-- | Given a datum, return the pattern which will match it along
-- with the subexpressions which would be bound to pattern-variables.
patternOfDatum :: Datum ast a -> PatternOfDatum ast a
patternOfDatum =
    \(Datum hint _typ d) ->
        podCode d $ \p es ->
        PatternOfDatum (PDatum hint p) es
    where
    -- Each helper is written in continuation-passing style so that
    -- the existential list of bound types never escapes.
    podCode
        :: DatumCode xss ast a
        -> (forall bs. PDatumCode xss bs a -> List1 ast bs -> r)
        -> r
    podCode (Inr d) k = podCode   d $ \ p es -> k (PInr p) es
    podCode (Inl d) k = podStruct d $ \ p es -> k (PInl p) es

    podStruct
        :: DatumStruct xs ast a
        -> (forall bs. PDatumStruct xs bs a -> List1 ast bs -> r)
        -> r
    podStruct (Et d1 d2) k =
        podFun    d1 $ \p1 es1 ->
        podStruct d2 $ \p2 es2 ->
        k (PEt p1 p2) (es1 `append1` es2)
    podStruct Done k = k PDone Nil1

    -- Every field becomes a single pattern variable.
    podFun
        :: DatumFun x ast a
        -> (forall bs. PDatumFun x bs a -> List1 ast bs -> r)
        -> r
    podFun (Konst e) k = k (PKonst PVar) (Cons1 e Nil1)
    podFun (Ident e) k = k (PIdent PVar) (Cons1 e Nil1)


----------------------------------------------------------------
-- | N.B., as with 'constrainValue', we assume that the first
-- argument is emissible. So it is the caller's responsibility to
-- ensure this (by calling 'atomize' as appropriate).
--
-- TODO: capture the emissibility requirement on the first argument
-- in the types.
constrainVariable
    :: (ABT Term abt) => abt '[] a -> Variable a -> Dis abt ()
constrainVariable v0 x =
    do  locs <- getLocs
        -- If we get 'Nothing', then it turns out @x@ is a free variable.
        -- If @x@ is a free variable, then it's a neutral term; and we
        -- return 'bot' for neutral terms
        maybe bot lookForLoc (lookupAssoc x locs)
    where
    lookForLoc (Loc l jxs) =
        let -- Assumption: js has no duplicates
            permutes is js =
                length is == length js &&
                Set.fromList is == Set.fromList (map indVar js)
            -- If we get 'Nothing', then it turns out @l@ is a free location.
            -- This is an error because of the invariant:
            --   if there exists an 'Assoc x ({Multi}Loc l _)' inside @locs@
            --   then there must be a statement on the 'ListContext' that binds @l@
        in (maybe (freeLocError l) return =<<) . select l $ \s ->
            case s of
            SBind l' e ixs -> do
                Refl <- varEq l l'
                guard (length ixs == length jxs) -- will error otherwise
                Just $ do
                    inds <- getIndices
                    guard (jxs `permutes` inds) -- will bot otherwise
                    e' <- apply (zip ixs inds) (fromLazy e)
                    constrainOutcome v0 e'
                    -- Rebind the location to the (neutral) observed value.
                    unsafePush (SLet l (Whnf_ (Neutral v0)) inds)
            SLet l' e ixs -> do
                Refl <- varEq l l'
                guard (length ixs == length jxs) -- will error otherwise
                Just $ do
                    inds <- getIndices
                    guard (jxs `permutes` inds) -- will bot otherwise
                    e' <- apply (zip ixs inds) (fromLazy e)
                    constrainValue v0 e'
                    -- Rebind the location to the (neutral) observed value.
                    unsafePush (SLet l (Whnf_ (Neutral v0)) inds)
            SWeight _ _ -> Nothing
            SGuard ls' pat scrutinee i ->
                error "TODO: constrainVariable{SGuard}"
    -- Case for MultiLoc
    lookForLoc (MultiLoc l jxs) = do
#ifdef __TRACE_DISINTEGRATE__
        traceM $ "looking for MultiLoc: " ++ show (prettyLoc (MultiLoc l jxs))
#endif
        n <- sizeInnermostInd l
        j <- freshInd n
        x' <- mkLoc Text.empty l (extendLocInds (indVar j) jxs)
        inds <- getIndices
        withIndices (extendIndices j inds) $
            constrainValue (v0 P.! (var $ indVar j)) (var x')
        -- TODO use meta-index


----------------------------------------------------------------
-- | N.B., as with 'constrainValue', we assume that the first
-- argument is emissible. So it is the caller's responsibility to
-- ensure this (by calling 'atomize' as appropriate).
--
-- TODO: capture the emissibility requirement on the first argument
-- in the types.
--
-- Constrains a measure-typed expression built from a primitive
-- 'MeasureOp' to equal the (measure-typed) value @v0@. Most ops are
-- handled by re-expressing them through their primed Prelude
-- definition and handing that back to 'constrainValue'; ops for
-- which no such definition is used here give up with 'bot'.
constrainValueMeasureOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] ('HMeasure a)
    -> MeasureOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainValueMeasureOp v0 = go
    where
    -- TODO: for Lebesgue and Counting we use @bot@ because that's
    -- what the old finally-tagless code seems to have been doing.
    -- But is that right, or should they really be @return ()@?
    go :: MeasureOp typs a -> SArgs abt args -> Dis abt ()
    go Lebesgue    = \End               -> bot -- TODO: see note above
    go Counting    = \End               -> bot -- TODO: see note above
    -- This case used to call @constrainValue v0 (P.categorical e1)@.
    -- Unlike the primed constructors used below, the unprimed
    -- 'P.categorical' presumably rebuilds this very 'MeasureOp_'
    -- term, which 'constrainValue' dispatches straight back here,
    -- so the call never terminated. We give up instead, as for the
    -- other ops we cannot unfold.
    -- NOTE(review): confirm against the Prelude; if a derived
    -- @categorical'@ with this type exists, prefer it over 'bot'.
    go Categorical = \(_  :* End)       -> bot
    go Uniform     = \(e1 :* e2 :* End) -> constrainValue v0 (P.uniform' e1 e2)
    go Normal      = \(e1 :* e2 :* End) -> constrainValue v0 (P.normal'  e1 e2)
    go Poisson     = \(e1 :* End)       -> constrainValue v0 (P.poisson' e1)
    go Gamma       = \(e1 :* e2 :* End) -> constrainValue v0 (P.gamma'   e1 e2)
    go Beta        = \(e1 :* e2 :* End) -> constrainValue v0 (P.beta'    e1 e2)


----------------------------------------------------------------
-- | N.B., We assume that the first argument, @v0@, is already
-- atomized. So, this must be ensured before recursing, but we can
-- assume it's already been done by the IH.
--
-- N.B., we also rely on the fact that our 'HSemiring' instances
-- are actually all /commutative/ semirings. If that ever becomes
-- not the case, then we'll need to fix things here.
--
-- As written, this will do a lot of redundant work in atomizing
-- the subterms other than the one we choose to go backward on.
-- Because evaluation has side-effects on the heap and is heap
-- dependent, it seems like there may not be a way around that
-- issue. (I.e., we could use dynamic programming to efficiently
-- build up the 'M' computations, but not to evaluate them.)
-- Of course, really we shouldn't be relying on the structure of the
-- program here; really we should be looking at the heap-bound
-- variables in the term: choosing each @x@ to go backward on, treat
-- the term as a function of @x@, atomize that function (hence going
-- forward on the rest of the variables), and then invert it and
-- get the Jacobian.
--
-- TODO: find some way to capture in the type that the first argument
-- must be emissible.
constrainNaryOp
    :: (ABT Term abt)
    => abt '[] a
    -> NaryOp a
    -> Seq (abt '[] a)
    -> Dis abt ()
constrainNaryOp v0 o =
    case o of
    -- v0 = rest + e, so go backward on @e@ with @v0 - rest@.
    Sum theSemi -> lubSeq $ \before e after -> do
        rest   <- atomizeOthers before after
        target <- evaluate_ (P.unsafeMinus_ theSemi v0 rest)
        constrainValue (fromWhnf target) e

    -- v0 = rest * e, so go backward on @e@ with @v0 / rest@,
    -- weighting by the reciprocal of @|rest|@.
    Prod theSemi -> lubSeq $ \before e after -> do
        rest <- atomizeOthers before after -- TODO: emitLet?
        emitWeight (P.recip (toProb_abs theSemi rest))
        target <- evaluate_ (P.unsafeDiv_ theSemi v0 rest)
        constrainValue (fromWhnf target) e

    -- @e@ is the maximum: guard that the others are below @v0@.
    Max theOrd -> chooseSeq $ \before e after -> do
        rest <- atomizeOthers before after
        emitGuard (P.primOp2_ (Less theOrd) rest v0)
        constrainValue v0 e

    _ -> error ("TODO: constrainNaryOp{" ++ show o ++ "}")
    where
    -- Atomize the same n-ary operation applied to every operand
    -- except the one currently being constrained.
    atomizeOthers before after =
        fromWhnf <$> atomize (syn (NaryOp_ o (before S.>< after)))


-- TODO: if this function (or the component @toProb@ and @semiringAbs@
-- parts) turn out to be useful elsewhere, then we should move it
-- to the Prelude.
toProb_abs :: (ABT Term abt) => HSemiring a -> abt '[] a -> abt '[] 'HProb
toProb_abs theSemi =
    case theSemi of
    HSemiring_Nat  -> P.nat2prob
    HSemiring_Int  -> P.nat2prob . P.abs_
    HSemiring_Prob -> id
    HSemiring_Real -> P.abs_


-- TODO: is there any way to optimise the zippering over the Seq, a la 'S.inits' or 'S.tails'?
-- TODO: really we want a dynamic programming approach to avoid unnecessary repetition of calling @evaluateNaryOp evaluate_@ on the two subsequences...
-- | Zipper over a sequence: for every element @y@, call the given
-- function with the elements before @y@, @y@ itself, and the elements
-- after @y@, and combine all the results with '<|>'. An empty
-- sequence yields 'empty'.
lubSeq :: (Alternative m) => (Seq a -> a -> Seq a -> m b) -> Seq a -> m b
lubSeq f = go S.empty
    where
    go xs ys =
        case S.viewl ys of
        S.EmptyL   -> empty
        y S.:< ys' -> f xs y ys' <|> go (xs S.|> y) ys'

-- | Like 'lubSeq', but the per-element results are collected into a
-- list (in element order) and combined with 'choose'.
chooseSeq :: (ABT Term abt)
          => (Seq a -> a -> Seq a -> Dis abt b)
          -> Seq a
          -> Dis abt b
chooseSeq f = choose . go S.empty
    where
    go xs ys =
        case S.viewl ys of
        S.EmptyL   -> []
        y S.:< ys' -> f xs y ys' : go (xs S.|> y) ys'

----------------------------------------------------------------
-- HACK: for a lot of these, we can't use the prelude functions
-- because Haskell can't figure out our polymorphism, so we have
-- to define our own versions for manually passing dictionaries
-- around.
--
-- | N.B., We assume that the first argument, @v0@, is already
-- atomized. So, this must be ensured before recursing, but we can
-- assume it's already been done by the IH.
--
-- Each implemented case inverts the primitive operation: it emits
-- whatever guards and (Jacobian) weights the inversion needs and
-- then calls 'constrainValue' on the argument with the inverted
-- value. Operations that are not implemented yet raise a Haskell
-- 'error'; operations with no solution return 'bot'.
--
-- TODO: find some way to capture in the type that the first argument
-- must be emissible.
constrainPrimOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] a
    -> PrimOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainPrimOp v0 = go
    where
    error_TODO op = error $ "TODO: constrainPrimOp{" ++ op ++"}"

    go :: PrimOp typs a -> SArgs abt args -> Dis abt ()
    go Not  = \(e1 :* End)       -> error_TODO "Not"
    go Impl = \(e1 :* e2 :* End) -> error_TODO "Impl"
    go Diff = \(e1 :* e2 :* End) -> error_TODO "Diff"
    go Nand = \(e1 :* e2 :* End) -> error_TODO "Nand"
    go Nor  = \(e1 :* e2 :* End) -> error_TODO "Nor"

    go Pi = \End -> bot -- because @dirac pi@ has no density wrt lebesgue

    -- sin y = x0  <=>  y = 2*pi*n + asin x0  or  y = 2*pi*n + pi - asin x0
    go Sin = \(e1 :* End) -> do
        x0 <- emitLet' v0
        n  <- var <$> emitMBind P.counting
        let tau_n = P.real_ 2 P.* P.fromInt n P.* P.pi -- TODO: emitLet?
        emitGuard (P.negate P.one P.< x0 P.&& x0 P.< P.one)
        v  <- var <$> emitSuperpose
            [ P.dirac (tau_n P.+ P.asin x0)
            , P.dirac (tau_n P.+ P.pi P.- P.asin x0)
            ]
        emitWeight
            . P.recip
            . P.sqrt
            . P.unsafeProb
            $ (P.one P.- x0 P.^ P.nat_ 2)
        constrainValue v e1

    -- cos y = x0  <=>  y = 2*pi*n + acos x0  or  y = 2*pi*n - acos x0.
    -- (The second solution used to be @r + pi@, but
    -- @cos (r + pi) = negate x0@, so that branch constrained the
    -- argument to a point where the cosine is @-x0@.)
    go Cos = \(e1 :* End) -> do
        x0 <- emitLet' v0
        n  <- var <$> emitMBind P.counting
        let tau_n = P.real_ 2 P.* P.fromInt n P.* P.pi
        emitGuard (P.negate P.one P.< x0 P.&& x0 P.< P.one)
        r  <- emitLet' (tau_n P.+ P.acos x0)
        v  <- var <$> emitSuperpose
            [ P.dirac r
            , P.dirac (tau_n P.- P.acos x0)
            ]
        emitWeight
            . P.recip
            . P.sqrt
            . P.unsafeProb
            $ (P.one P.- x0 P.^ P.nat_ 2)
        constrainValue v e1

    -- tan y = x0  <=>  y = pi*n + atan x0
    go Tan = \(e1 :* End) -> do
        x0 <- emitLet' v0
        n  <- var <$> emitMBind P.counting
        r  <- emitLet' (P.fromInt n P.* P.pi P.+ P.atan x0)
        emitWeight $ P.recip (P.one P.+ P.square x0)
        constrainValue r e1

    go Asin = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight $ P.unsafeProb (P.cos x0)
        -- TODO: bounds check for -pi/2 <= v0 < pi/2
        constrainValue (P.sin x0) e1

    go Acos = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight $ P.unsafeProb (P.sin x0)
        constrainValue (P.cos x0) e1

    go Atan = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight $ P.recip (P.unsafeProb (P.cos x0 P.^ P.nat_ 2))
        constrainValue (P.tan x0) e1

    go Sinh  = \(e1 :* End) -> error_TODO "Sinh"
    go Cosh  = \(e1 :* End) -> error_TODO "Cosh"
    go Tanh  = \(e1 :* End) -> error_TODO "Tanh"
    go Asinh = \(e1 :* End) -> error_TODO "Asinh"
    go Acosh = \(e1 :* End) -> error_TODO "Acosh"
    go Atanh = \(e1 :* End) -> error_TODO "Atanh"

    -- Two alternatives: go backward on the exponent, or on the base.
    go RealPow = \(e1 :* e2 :* End) ->
        -- TODO: There's a discrepancy between @(**)@ and @pow_@ in the old code...
        do
            -- TODO: if @v1@ is 0 or 1 then bot. Maybe the @log v1@ in @w@ takes care of the 0 case?
            u <- emitLet' v0
            -- either this from @(**)@:
            -- emitGuard  $ P.zero P.< u
            -- w <- atomize $ P.recip (P.abs (v0 P.* P.log v1))
            -- emitWeight $ P.unsafeProb (fromWhnf w)
            -- constrainValue (P.logBase v1 v0) e2
            -- or this from @pow_@:
            let w = P.recip (u P.* P.unsafeProb (P.abs (P.log e1)))
            emitWeight w
            constrainValue (P.log u P./ P.log e1) e2
            -- end.
        <|> do
            -- TODO: if @v2@ is 0 then bot. Maybe the weight @w@ takes care of this case?
            u <- emitLet' v0
            let ex = v0 P.** P.recip e2
            -- either this from @(**)@:
            -- emitGuard $ P.zero P.< u
            -- w <- atomize $ abs (ex / (v2 * v0))
            -- or this from @pow_@:
            let w = P.abs (P.fromProb ex P./ (e2 P.* P.fromProb u))
            -- end.
            emitWeight $ P.unsafeProb w
            constrainValue ex e1

    go Exp = \(e1 :* End) -> do
        x0 <- emitLet' v0
        -- TODO: do we still want\/need the @emitGuard (0 < x0)@ which is now equivalent to @emitGuard (0 /= x0)@ thanks to the types?
        emitWeight (P.recip x0)
        constrainValue (P.log x0) e1

    go Log = \(e1 :* End) -> do
        exp_x0 <- emitLet' (P.exp v0)
        emitWeight exp_x0
        constrainValue exp_x0 e1

    go (Infinity _)     = \End               -> error_TODO "Infinity" -- scalar0
    go GammaFunc        = \(e1 :* End)       -> error_TODO "GammaFunc" -- scalar1
    go BetaFunc         = \(e1 :* e2 :* End) -> error_TODO "BetaFunc" -- scalar2
    go (Equal  theOrd)  = \(e1 :* e2 :* End) -> error_TODO "Equal"
    go (Less   theOrd)  = \(e1 :* e2 :* End) -> error_TODO "Less"
    go (NatPow theSemi) = \(e1 :* e2 :* End) -> error_TODO "NatPow"

    go (Negate theRing) = \(e1 :* End) ->
        -- TODO: figure out how to merge this implementation of @rr1 negate@ with the one in 'evaluatePrimOp' to DRY
        -- TODO: just emitLet the @v0@ and pass the neutral term to the recursive call?
        let negate_v0 = syn (PrimOp_ (Negate theRing) :$ v0 :* End)
                -- case v0 of
                -- Neutral e ->
                --     Neutral $ syn (PrimOp_ (Negate theRing) :$ e :* End)
                -- Head_ v ->
                --     case theRing of
                --     HRing_Int  -> Head_ . reflect . negate $ reify v
                --     HRing_Real -> Head_ . reflect . negate $ reify v
        in constrainValue negate_v0 e1

    -- abs y = x0: for positive x0 either sign is possible, for zero
    -- only zero, and otherwise there is no solution.
    go (Abs theRing) = \(e1 :* End) -> do
        let theSemi = hSemiring_HRing theRing
            theOrd  =
                case theRing of
                HRing_Int  -> HOrd_Int
                HRing_Real -> HOrd_Real
            theEq   = hEq_HOrd theOrd
            signed  = C.singletonCoercion (C.Signed theRing)
            zero    = P.zero_ theSemi
            lt      = P.primOp2_ $ Less   theOrd
            eq      = P.primOp2_ $ Equal  theEq
            neg     = P.primOp1_ $ Negate theRing
        x0 <- emitLet' (P.coerceTo_ signed v0)
        v  <- var <$> emitMBind
            (P.if_ (lt zero x0)
                (P.dirac x0 P.<|> P.dirac (neg x0))
                (P.if_ (eq zero x0)
                    (P.dirac zero)
                    (P.reject . SMeasure $ typeOf zero)))
        constrainValue v e1

    go (Signum theRing) = \(e1 :* End) ->
        case theRing of
        HRing_Real -> bot
        HRing_Int  -> do
            x <- var <$> emitMBind P.counting
            emitGuard $ P.signum x P.== v0
            constrainValue x e1

    go (Recip theFractional) = \(e1 :* End) -> do
        x0 <- emitLet' v0
        emitWeight
            . P.recip
            . P.unsafeProbFraction_ theFractional
            -- TODO: define a dictionary-passing variant of 'P.square' instead, to include the coercion in there explicitly...
            $ square (hSemiring_HFractional theFractional) x0
        constrainValue (P.primOp1_ (Recip theFractional) x0) e1

    go (NatRoot theRadical) = \(e1 :* e2 :* End) ->
        case theRadical of
        HRadical_Prob -> do
            x0 <- emitLet' v0
            u2 <- fromWhnf <$> atomize e2
            -- NOTE(review): if @NatRoot e1 e2@ means the @e2@-th root
            -- of @e1@, the Jacobian of @x0 ^ u2@ is
            -- @u2 * x0 ^ (u2 - 1)@, which equals this weight only
            -- for square roots; confirm before relying on other orders.
            emitWeight (P.nat2prob u2 P.* x0)
            constrainValue (x0 P.^ u2) e1

    go (Erf theContinuous) = \(e1 :* End) ->
        error "TODO: constrainPrimOp: need InvErf to disintegrate Erf"


-- | Square a term using an explicitly passed semiring dictionary.
--
-- HACK: can't use @(P.^)@ because Haskell can't figure out our polymorphism
square :: (ABT Term abt) => HSemiring a -> abt '[] a -> abt '[] a
square theSemiring e =
    syn (PrimOp_ (NatPow theSemiring) :$ e :* P.nat_ 2 :* End)


----------------------------------------------------------------
----------------------------------------------------------------
-- TODO: do we really want the first argument to be a term at all,
-- or do we want something more general like patterns for capturing
-- measurable events?
--
-- | This is a helper function for 'constrainValue' to handle 'SBind'
-- statements (just as the 'perform' argument to 'evaluate' is a
-- helper for handling 'SBind' statements).
--
-- N.B., We assume that the first argument, @v0@, is already
-- atomized. So, this must be ensured before recursing, but we can
-- assume it's already been done by the IH. Technically, we don't
-- care whether the first argument is in normal form or not, just
-- so long as it doesn't contain any heap-bound variables.
--
-- This is the function called @(<<|)@ in the paper, though notably
-- we swap the argument order.
--
-- TODO: find some way to capture in the type that the first argument
-- must be emissible, to help avoid accidentally passing the arguments
-- in the wrong order!
--
-- TODO: under what circumstances is @constrainOutcome x m@ different
-- from @constrainValue x =<< perform m@? If they're always the
-- same, then we should just use that as the definition in order
-- to avoid repeating ourselves
constrainOutcome
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a
    -> abt '[] ('HMeasure a)
    -> Dis abt ()
constrainOutcome v0 e0 =
#ifdef __TRACE_DISINTEGRATE__
    getLocs >>= \locs ->
    getIndices >>= \inds ->
    trace (
        let s = "-- constrainOutcome"
        in "\n" ++ s ++ ": " ++ show (pretty v0) ++ "\n"
            ++ replicate (length s) ' ' ++ ": " ++ show (pretty e0) ++ "\n"
            ++ "at " ++ show (ppInds inds) ++ "\n"
            ++ show (prettyLocs locs)
        ) $
#endif
    -- Evaluate the measure expression first; neutral terms cannot be
    -- constrained, head-normal forms are dispatched by 'go'.
    do  w0 <- evaluate_ e0
        case w0 of
            Neutral _ -> bot
            Head_   v -> go v
    where
    impossible = error "constrainOutcome: the impossible happened"

    go :: Head abt ('HMeasure a) -> Dis abt ()
    go (WLiteral _)          = impossible
    -- go (WDatum _)         = impossible
    -- go (WEmpty _)         = impossible
    -- go (WArray _ _)       = impossible
    -- go (WLam _)           = impossible
    -- go (WIntegrate _ _ _) = impossible
    -- go (WSummate   _ _ _) = impossible
    go (WCoerceTo   _ _)     = impossible
    go (WUnsafeFrom _ _)     = impossible
    go (WMeasureOp o es)     = constrainOutcomeMeasureOp v0 o es
    go (WDirac e1)           = constrainValue v0 e1
    go (WMBind e1 e2)        =
        caseBind e2 $ \x e2' -> do
            i <- getIndices
            push (SBind x (Thunk e1) i) e2' (constrainOutcome v0)
    go (WPlate e1 e2)        = do
        x' <- pushPlate e1 e2
        constrainValue v0 (var x')
    go (WChain e1 e2 e3)     = error "TODO: constrainOutcome{Chain}"
    -- Discard the continuation's program and emit a reject of its type.
    go (WReject typ)         = emit_ $ \m -> P.reject (typeOf m)
    go (WSuperpose pes) = do
        i <- getIndices
        -- We give up ('bot') on a superpose with more than one
        -- component whenever the current index list is non-empty.
        if not (null i) && L.length pes > 1 then bot else
            emitFork_ (P.superpose . fmap ((,) P.one))
                (fmap (\(p,e) -> push (SWeight (Thunk p) i) e (constrainOutcome v0))
                    pes)


-- TODO: should this really be different from 'constrainValueMeasureOp'?
--
-- TODO: find some way to capture in the type that the first argument
-- must be emissible.
--
-- Constrains the outcome of a primitive measure to be the observed
-- value @v0@, by pushing whatever guard and density weight the
-- measure requires at that point.
constrainOutcomeMeasureOp
    :: (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] a
    -> MeasureOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainOutcomeMeasureOp v0 = go
    where
    -- Per the paper
    go Lebesgue = \End -> return ()

    -- TODO: I think, based on Hakaru v0.2.0
    go Counting = \End -> return ()

    -- TODO: check that v0' is < then length of e1
    go Categorical = \(e1 :* End) ->
        pushWeight (P.densityCategorical e1 v0)

    -- Per the paper
    go Uniform = \(lo :* hi :* End) ->
        emitLet' v0 >>= \obs ->
            pushGuard (lo P.<= obs P.&& obs P.<= hi)
                >> pushWeight (P.densityUniform lo hi obs)

    -- TODO: Add fallback handling of Normal that does not atomize mu,sd.
    -- This fallback is as if Normal were defined in terms of Lebesgue
    -- and a density Weight. This fallback is present in Hakaru v0.2.0
    -- in order to disintegrate a program such as
    --   x <~ normal(0,1)
    --   y <~ normal(x,1)
    --   return ((x+(y+y),x)::pair(real,real))
    --
    -- N.B., if\/when extending this to higher dimensions, the real equation is @recip (sqrt (2*pi*sd^2) ^ n) * exp (negate (norm_n (v0 - mu) ^ 2) / (2*sd^2))@ for @Real^n@.
    go Normal = \(mu :* sd :* End) ->
        pushWeight (P.densityNormal mu sd v0)

    go Poisson = \(e1 :* End) ->
        emitLet' v0 >>= \obs ->
            pushGuard (P.nat_ 0 P.<= obs P.&& P.prob_ 0 P.< e1)
                >> pushWeight (P.densityPoisson e1 obs)

    go Gamma = \(e1 :* e2 :* End) ->
        emitLet' v0 >>= \obs ->
            pushGuard (P.prob_ 0 P.< obs P.&& P.prob_ 0 P.< e1 P.&& P.prob_ 0 P.< e2)
                >> pushWeight (P.densityGamma e1 e2 obs)

    go Beta = \(e1 :* e2 :* End) ->
        emitLet' v0 >>= \obs ->
            pushGuard (P.prob_ 0 P.<= obs P.&& P.prob_ 1 P.>= obs P.&& P.prob_ 0 P.< e1 P.&& P.prob_ 0 P.< e2)
                >> pushWeight (P.densityBeta e1 e2 obs)

----------------------------------------------------------------
----------------------------------------------------------- fin.