{-# LANGUAGE CPP
           , DataKinds
           , FlexibleContexts
           , GADTs
           , GeneralizedNewtypeDeriving
           , MultiParamTypeClasses
           , RankNTypes
           , ScopedTypeVariables
           , TypeOperators
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2017.02.01
-- |
-- Module      :  Language.Hakaru.Syntax.Unroll
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- Peels the first iteration off every 'Summate' and 'Product' loop whose
-- index type is 'HNat' or 'HInt'. A loop
--
-- > summate lo hi body
--
-- is rewritten into a term that yields 'zero' when the index range is
-- empty, and otherwise adds @body@ instantiated at @lo@ (via a @let@) to
--
-- > summate (lo + one) hi body
--
-- 'Product' is handled the same way with 'one' and multiplication.
--
-- While traversing, every binder of the input term is re-created through
-- 'binderM' ('renameInEnv'), and variable occurrences are rewritten to
-- refer to the re-created variables.
--
----------------------------------------------------------------
module Language.Hakaru.Syntax.Unroll (renameInEnv, unroll) where

import Control.Monad.Reader
import Data.Maybe                      (fromMaybe)
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.AST.Eq   (Varmap)
import Language.Hakaru.Syntax.Prelude  hiding ((>>=))
import Language.Hakaru.Types.HClasses
import Prelude                         hiding (product, (*), (+), (-), (==), (>=))

#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif

-- | The traversal monad: a reader over the 'Varmap' that maps variables of
-- the input term to the variables replacing them in the output term.
newtype Unroll a = Unroll { runUnroll :: Reader Varmap a }
    deriving (Functor, Applicative, Monad, MonadReader Varmap, MonadFix)

-- | Create a binder for a new variable with the same hint and type as
-- @source@. The new variable is passed to @f@, which builds the body.
rebind
    :: (ABT Term abt, MonadFix m)
    => Variable a
    -> (Variable a -> m (abt xs b))
    -> m (abt (a ': xs) b)
rebind source f =
    binderM (varHint source) (varType source) $ \var' ->
        -- 'binderM' hands the bound variable over as a term; unwrap it.
        -- The non-variable branch is not expected to be reachable.
        let v = caseVarSyn var' id $ const $
                    error "Language.Hakaru.Syntax.Unroll.rebind: binderM did not supply a variable"
        in f v

-- | Rebind @source@ to a new variable (see 'rebind'), and run @action@ with
-- the environment extended to map @source@ to that variable. Occurrences of
-- @source@ inside @action@'s result are therefore renamed.
renameInEnv
    :: (ABT Term abt, MonadReader Varmap m, MonadFix m)
    => Variable a
    -> m (abt xs b)
    -> m (abt (a ': xs) b)
renameInEnv source action =
    rebind source $ \v -> local (insertAssoc $ Assoc source v) action

-- | Unroll the first iteration of every 'HNat'/'HInt'-indexed 'Summate' and
-- 'Product' in a term, starting from an empty renaming environment.
unroll :: forall abt xs a . (ABT Term abt) => abt xs a -> abt xs a
unroll abt = runReader (runUnroll $ unroll' abt) emptyAssocs

-- | Traversal via 'cataABTM'.
--
-- * Variables are looked up in the renaming environment.
-- * Binders are re-created with 'renameInEnv'.
-- * Every rebuilt 'Term' node is passed to 'unrollTerm'.
unroll' :: forall abt xs a . (ABT Term abt) => abt xs a -> Unroll (abt xs a)
unroll' = cataABTM var_ renameInEnv (>>= unrollTerm)
  where
    -- Variables without an entry in the environment are kept as they are.
    var_ :: Variable b -> Unroll (abt '[] b)
    var_ v = asks (var . fromMaybe v . lookupAssoc v)

-- | Build @let x = rhs in body@, where @body@ binds @x@.
mklet :: ABT Term abt => abt '[] b -> abt '[b] a -> abt '[] a
mklet rhs body = syn (Let_ :$ rhs :* body :* End)

-- | Build a 'Summate' / 'Product' node from its bounds and its binding body.
mksummate, mkproduct
    :: (ABT Term abt)
    => HDiscrete a
    -> HSemiring b
    -> abt '[] a
    -> abt '[] a
    -> abt '[a] b
    -> abt '[] b
mksummate a b lo hi body = syn (Summate a b :$ lo :* hi :* body :* End)
mkproduct a b lo hi body = syn (Product a b :$ lo :* hi :* body :* End)

-- | Unroll a single 'Summate' or 'Product' node; any other node is rebuilt
-- unchanged.
--
-- Every (index, result) combination is listed explicitly, with identical
-- right-hand sides. Matching on the 'HDiscrete' / 'HSemiring' singleton
-- constructors refines the index and result types to concrete Hakaru types.
-- That lets GHC discharge the class constraints required by
-- 'unrollSummate' and 'unrollProduct'.
unrollTerm :: (ABT Term abt) => Term abt a -> Unroll (abt '[] a)
unrollTerm (Summate disc semi :$ lo :* hi :* body :* End) =
    case (disc, semi) of
        (HDiscrete_Nat, HSemiring_Nat)  -> unrollSummate disc semi lo hi body
        (HDiscrete_Nat, HSemiring_Int)  -> unrollSummate disc semi lo hi body
        (HDiscrete_Nat, HSemiring_Prob) -> unrollSummate disc semi lo hi body
        (HDiscrete_Nat, HSemiring_Real) -> unrollSummate disc semi lo hi body
        (HDiscrete_Int, HSemiring_Nat)  -> unrollSummate disc semi lo hi body
        (HDiscrete_Int, HSemiring_Int)  -> unrollSummate disc semi lo hi body
        (HDiscrete_Int, HSemiring_Prob) -> unrollSummate disc semi lo hi body
        (HDiscrete_Int, HSemiring_Real) -> unrollSummate disc semi lo hi body
unrollTerm (Product disc semi :$ lo :* hi :* body :* End) =
    case (disc, semi) of
        (HDiscrete_Nat, HSemiring_Nat)  -> unrollProduct disc semi lo hi body
        (HDiscrete_Nat, HSemiring_Int)  -> unrollProduct disc semi lo hi body
        (HDiscrete_Nat, HSemiring_Prob) -> unrollProduct disc semi lo hi body
        (HDiscrete_Nat, HSemiring_Real) -> unrollProduct disc semi lo hi body
        (HDiscrete_Int, HSemiring_Nat)  -> unrollProduct disc semi lo hi body
        (HDiscrete_Int, HSemiring_Int)  -> unrollProduct disc semi lo hi body
        (HDiscrete_Int, HSemiring_Prob) -> unrollProduct disc semi lo hi body
        (HDiscrete_Int, HSemiring_Real) -> unrollProduct disc semi lo hi body
unrollTerm term = return (syn term)

-- Conditionally introduce a variable for the rhs: if the rhs is already a
-- variable (or a literal) it is passed through unchanged.
-- Be careful that the provided variable has been remapped to its equivalent
-- in the target term if altering the binding structure of the program.
letM' :: (Functor m, MonadFix m, ABT Term abt)
      => abt '[] a
      -> (abt '[] a -> m (abt '[] b))
      -> m (abt '[] b)
letM' e f =
    case viewABT e of
        Var _            -> f e
        Syn (Literal_ _) -> f e
        _                -> letM e f

-- | Peel the first iteration off a loop whose index ranges over @[lo, hi)@.
--
-- The result is 'if_' @(lo >= hi)@ @unit@ @(body[lo] \`combine\` loop')@,
-- where @loop'@ is the loop rebuilt (with @mkLoop@) to start at @lo + one@.
-- Testing @>=@ rather than @==@ means a reversed range (@lo > hi@) is also
-- treated as empty, so it yields @unit@ instead of a spurious
-- @body[lo]@ term.
--
-- @lo@ and @hi@ are each bound once with 'letM'' because they are referenced
-- more than once in the output.
unrollLoop
    :: (ABT Term abt, HSemiring_ a, HOrd_ a)
    => (abt '[] a -> abt '[] a -> abt '[a] b -> abt '[] b) -- ^ rebuild the loop from new bounds
    -> abt '[] b                                           -- ^ result for an empty range
    -> (abt '[] b -> abt '[] b -> abt '[] b)               -- ^ combine first iteration with the rest
    -> abt '[] a                                           -- ^ lower bound (inclusive)
    -> abt '[] a                                           -- ^ upper bound (exclusive)
    -> abt '[a] b                                          -- ^ loop body, binding the index
    -> Unroll (abt '[] b)
unrollLoop mkLoop unit combine lo hi body =
    letM' lo $ \loVar ->
    letM' hi $ \hiVar ->
        let peeled = mklet loVar body
            rest   = mkLoop (loVar + one) hiVar body
        in return $ if_ (loVar >= hiVar) unit (peeled `combine` rest)

-- | Unroll the first iteration of a 'Summate': empty range gives 'zero',
-- otherwise @body[lo] + summate (lo + one) hi body@.
unrollSummate
    :: (ABT Term abt, HSemiring_ a, HSemiring_ b, HOrd_ a)
    => HDiscrete a
    -> HSemiring b
    -> abt '[] a
    -> abt '[] a
    -> abt '[a] b
    -> Unroll (abt '[] b)
unrollSummate disc semi = unrollLoop (mksummate disc semi) zero (+)

-- | Unroll the first iteration of a 'Product': empty range gives 'one',
-- otherwise @body[lo] * product (lo + one) hi body@.
unrollProduct
    :: (ABT Term abt, HSemiring_ a, HSemiring_ b, HOrd_ a)
    => HDiscrete a
    -> HSemiring b
    -> abt '[] a
    -> abt '[] a
    -> abt '[a] b
    -> Unroll (abt '[] b)
unrollProduct disc semi = unrollLoop (mkproduct disc semi) one (*)