{-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TypeOperators #-} {-# OPTIONS_HADDOCK hide #-} -- | -- Module : Data.Array.Accelerate.Trafo.Substitution -- Copyright : [2012] Manuel M T Chakravarty, Gabriele Keller, Trevor L. McDonell -- License : BSD3 -- -- Maintainer : Manuel M T Chakravarty -- Stability : experimental -- Portability : non-portable (GHC extensions) -- module Data.Array.Accelerate.Trafo.Substitution ( -- ** Renaming & Substitution inline, substitute, compose, -- ** Weakening (:>), weakenA, weakenEA, weakenFA, weakenE, weakenFE, -- ** Rebuilding terms RebuildAcc, rebuildA, rebuildAfun, rebuildOpenAcc, rebuildE, rebuildEA, rebuildFA, rebuildFE, ) where import Prelude hiding ( exp ) import Data.Array.Accelerate.AST import Data.Array.Accelerate.Tuple import Data.Array.Accelerate.Array.Sugar ( Elt, Arrays ) import qualified Data.Array.Accelerate.Debug as Stats -- NOTE: [Renaming and Substitution] -- -- To do things like renaming and substitution, we need some operation on -- variables that we push structurally through terms, applying to each variable. -- We have a type preserving but environment changing operation: -- -- v :: forall t. Idx env t -> f env' aenv t -- -- The crafty bit is that 'f' might represent variables (for renaming) or terms -- (for substitutions). The demonic forall, --- which is to say that the -- quantifier is in a position which gives us obligation, not opportunity --- -- forces us to respect type: when pattern matching detects the variable we care -- about, happily we discover that it has the type we must respect. The demon is -- not so free to mess with us as one might fear at first. -- -- We then lift this to an operation which traverses terms and rebuild them -- after applying 'v' to the variables: -- -- rebuild v :: OpenExp env aenv t -> OpenExp env' aenv t -- -- The Syntactic class tells us what we need to know about 'f' if we want to be -- able to rebuild terms. In essence, the crucial functionality is to propagate -- a class of operations on variables that is closed under shifting. -- infixr `compose` infixr `substitute` -- | Replace the first variable with the given expression. The environment -- shrinks. -- inline :: Elt t => PreOpenExp acc (env, s) aenv t -> PreOpenExp acc env aenv s -> PreOpenExp acc env aenv t inline f g = Stats.substitution "inline" $ rebuildE (subTop g) f where subTop :: Elt t => PreOpenExp acc env aenv s -> Idx (env, s) t -> PreOpenExp acc env aenv t subTop s ZeroIdx = s subTop _ (SuccIdx ix) = Var ix -- | Replace an expression that uses the top environment variable with another. -- The result of the first is let bound into the second. -- substitute :: (Elt b, Elt c) => PreOpenExp acc (env, b) aenv c -> PreOpenExp acc (env, a) aenv b -> PreOpenExp acc (env, a) aenv c substitute f g | Stats.substitution "substitute" False = undefined | Var ZeroIdx <- g = f -- don't rebind an identity function | otherwise = Let g $ rebuildE split f where split :: Elt c => Idx (env,b) c -> PreOpenExp acc ((env,a),b) aenv c split ZeroIdx = Var ZeroIdx split (SuccIdx ix) = Var (SuccIdx (SuccIdx ix)) -- | Composition of unary functions. -- compose :: Elt c => PreOpenFun acc env aenv (b -> c) -> PreOpenFun acc env aenv (a -> b) -> PreOpenFun acc env aenv (a -> c) compose (Lam (Body f)) (Lam (Body g)) = Stats.substitution "compose" . Lam . Body $ substitute f g compose _ _ = error "compose: impossible evaluation" -- NOTE: [Weakening] -- -- Weakening is something we usually take for granted: every time you learn a -- new word, old sentences still make sense. If a conclusion is justified by a -- hypothesis, it is still justified if you add more hypotheses. Similarly, a -- term remains in scope if you bind more (fresh) variables. Weakening is the -- operation of shifting things from one scope to a larger scope in which new -- things have become meaningful, but no old things have vanished. -- -- When we use a named representation (or HOAS) we get weakening for free. But -- in the de Bruijn representation weakening takes work: you have to shift all -- variable references to make room for the new bindings. -- -- The type of shifting terms from one context into another -- type env :> env' = forall t'. Idx env t' -> Idx env' t' weakenA :: RebuildAcc acc -> aenv :> aenv' -> PreOpenAcc acc aenv a -> PreOpenAcc acc aenv' a weakenA k v = Stats.substitution "weakenA" . rebuildA k (Avar . v) weakenEA :: RebuildAcc acc -> aenv :> aenv' -> PreOpenExp acc env aenv t -> PreOpenExp acc env aenv' t weakenEA k v = Stats.substitution "weakenEA" . rebuildEA k (Avar . v) weakenFA :: RebuildAcc acc -> aenv :> aenv' -> PreOpenFun acc env aenv f -> PreOpenFun acc env aenv' f weakenFA k v = Stats.substitution "weakenFA" . rebuildFA k (Avar . v) weakenE :: env :> env' -> PreOpenExp acc env aenv t -> PreOpenExp acc env' aenv t weakenE v = Stats.substitution "weakenE" . rebuildE (Var . v) weakenFE :: env :> env' -> PreOpenFun acc env aenv f -> PreOpenFun acc env' aenv f weakenFE v = Stats.substitution "weakenFE" . rebuildFE (Var . v) {-# RULES "weakenA/weakenA" forall a (k :: RebuildAcc acc) (v1 :: env' :> env'') (v2 :: env :> env'). weakenA k v1 (weakenA k v2 a) = weakenA k (v1 . v2) a "weakenEA/weakenEA" forall a (k :: RebuildAcc acc) (v1 :: env' :> env'') (v2 :: env :> env'). weakenEA k v1 (weakenEA k v2 a) = weakenEA k (v1 . v2) a "weakenFA/weakenFA" forall a (k :: RebuildAcc acc) (v1 :: env' :> env'') (v2 :: env :> env'). weakenFA k v1 (weakenFA k v2 a) = weakenFA k (v1 . v2) a "weakenE/weakenE" forall e (v1 :: env' :> env'') (v2 :: env :> env'). weakenE v1 (weakenE v2 e) = weakenE (v1 . v2) e "weakenFE/weakenFE" forall e (v1 :: env' :> env'') (v2 :: env :> env'). weakenFE v1 (weakenFE v2 e) = weakenFE (v1 . v2) e #-} -- Simultaneous Substitution =================================================== -- -- Scalar expressions -- ------------------ -- SEE: [Renaming and Substitution] -- SEE: [Weakening] -- class SyntacticExp f where varIn :: Elt t => Idx env t -> f acc env aenv t expOut :: Elt t => f acc env aenv t -> PreOpenExp acc env aenv t weakenExp :: Elt t => f acc env aenv t -> f acc (env, s) aenv t newtype IdxE (acc :: * -> * -> *) env aenv t = IE { unIE :: Idx env t } instance SyntacticExp IdxE where varIn = IE expOut = Var . unIE weakenExp = IE . SuccIdx . unIE instance SyntacticExp PreOpenExp where varIn = Var expOut = id weakenExp = rebuildE (weakenExp . IE) shiftE :: (SyntacticExp f, Elt t) => (forall t'. Elt t' => Idx env t' -> f acc env' aenv t') -> Idx (env, s) t -> f acc (env', s) aenv t shiftE _ ZeroIdx = varIn ZeroIdx shiftE v (SuccIdx ix) = weakenExp (v ix) rebuildE :: SyntacticExp f => (forall t'. Elt t' => Idx env t' -> f acc env' aenv t') -> PreOpenExp acc env aenv t -> PreOpenExp acc env' aenv t rebuildE v exp = case exp of Let a b -> Let (rebuildE v a) (rebuildE (shiftE v) b) Var ix -> expOut (v ix) Const c -> Const c Tuple tup -> Tuple (rebuildTE v tup) Prj tup e -> Prj tup (rebuildE v e) IndexNil -> IndexNil IndexCons sh sz -> IndexCons (rebuildE v sh) (rebuildE v sz) IndexHead sh -> IndexHead (rebuildE v sh) IndexTail sh -> IndexTail (rebuildE v sh) IndexAny -> IndexAny IndexSlice x ix sh -> IndexSlice x (rebuildE v ix) (rebuildE v sh) IndexFull x ix sl -> IndexFull x (rebuildE v ix) (rebuildE v sl) ToIndex sh ix -> ToIndex (rebuildE v sh) (rebuildE v ix) FromIndex sh ix -> FromIndex (rebuildE v sh) (rebuildE v ix) Cond p t e -> Cond (rebuildE v p) (rebuildE v t) (rebuildE v e) Iterate n f x -> Iterate (rebuildE v n) (rebuildE (shiftE v) f) (rebuildE v x) PrimConst c -> PrimConst c PrimApp f x -> PrimApp f (rebuildE v x) Index a sh -> Index a (rebuildE v sh) LinearIndex a i -> LinearIndex a (rebuildE v i) Shape a -> Shape a ShapeSize sh -> ShapeSize (rebuildE v sh) Intersect s t -> Intersect (rebuildE v s) (rebuildE v t) Foreign ff f e -> Foreign ff f (rebuildE v e) rebuildTE :: SyntacticExp f => (forall t'. Elt t' => Idx env t' -> f acc env' aenv t') -> Tuple (PreOpenExp acc env aenv) t -> Tuple (PreOpenExp acc env' aenv) t rebuildTE v tup = case tup of NilTup -> NilTup SnocTup t e -> rebuildTE v t `SnocTup` rebuildE v e rebuildFE :: SyntacticExp f => (forall t'. Elt t' => Idx env t' -> f acc env' aenv t') -> PreOpenFun acc env aenv t -> PreOpenFun acc env' aenv t rebuildFE v fun = case fun of Body e -> Body (rebuildE v e) Lam f -> Lam (rebuildFE (shiftE v) f) -- Array expressions -- ----------------- type RebuildAcc acc = forall aenv aenv' f a. SyntacticAcc f => (forall a'. Arrays a' => Idx aenv a' -> f acc aenv' a') -> acc aenv a -> acc aenv' a class SyntacticAcc f where avarIn :: Arrays t => Idx aenv t -> f acc aenv t accOut :: Arrays t => f acc aenv t -> PreOpenAcc acc aenv t weakenAcc :: Arrays t => RebuildAcc acc -> f acc aenv t -> f acc (aenv, s) t newtype IdxA (acc :: * -> * -> *) aenv t = IA { unIA :: Idx aenv t } instance SyntacticAcc IdxA where avarIn = IA accOut = Avar . unIA weakenAcc _ = IA . SuccIdx . unIA instance SyntacticAcc PreOpenAcc where avarIn = Avar accOut = id weakenAcc k = rebuildA k (weakenAcc k . IA) rebuildOpenAcc :: SyntacticAcc f => (forall t'. Arrays t' => Idx aenv t' -> f OpenAcc aenv' t') -> OpenAcc aenv t -> OpenAcc aenv' t rebuildOpenAcc v (OpenAcc acc) = OpenAcc (rebuildA rebuildOpenAcc v acc) shiftA :: (SyntacticAcc f, Arrays t) => RebuildAcc acc -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t') -> Idx (aenv, s) t -> f acc (aenv', s) t shiftA _ _ ZeroIdx = avarIn ZeroIdx shiftA k v (SuccIdx ix) = weakenAcc k (v ix) rebuildA :: SyntacticAcc f => RebuildAcc acc -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t') -> PreOpenAcc acc aenv t -> PreOpenAcc acc aenv' t rebuildA rebuild v acc = case acc of Alet a b -> Alet (rebuild v a) (rebuild (shiftA rebuild v) b) Avar ix -> accOut (v ix) Atuple tup -> Atuple (rebuildATA rebuild v tup) Aprj tup a -> Aprj tup (rebuild v a) Apply f a -> Apply (rebuildAfun rebuild v f) (rebuild v a) Aforeign ff afun as -> Aforeign ff afun (rebuild v as) Acond p t e -> Acond (rebuildEA rebuild v p) (rebuild v t) (rebuild v e) Use a -> Use a Unit e -> Unit (rebuildEA rebuild v e) Reshape e a -> Reshape (rebuildEA rebuild v e) (rebuild v a) Generate e f -> Generate (rebuildEA rebuild v e) (rebuildFA rebuild v f) Transform sh ix f a -> Transform (rebuildEA rebuild v sh) (rebuildFA rebuild v ix) (rebuildFA rebuild v f) (rebuild v a) Replicate sl slix a -> Replicate sl (rebuildEA rebuild v slix) (rebuild v a) Slice sl a slix -> Slice sl (rebuild v a) (rebuildEA rebuild v slix) Map f a -> Map (rebuildFA rebuild v f) (rebuild v a) ZipWith f a1 a2 -> ZipWith (rebuildFA rebuild v f) (rebuild v a1) (rebuild v a2) Fold f z a -> Fold (rebuildFA rebuild v f) (rebuildEA rebuild v z) (rebuild v a) Fold1 f a -> Fold1 (rebuildFA rebuild v f) (rebuild v a) FoldSeg f z a s -> FoldSeg (rebuildFA rebuild v f) (rebuildEA rebuild v z) (rebuild v a) (rebuild v s) Fold1Seg f a s -> Fold1Seg (rebuildFA rebuild v f) (rebuild v a) (rebuild v s) Scanl f z a -> Scanl (rebuildFA rebuild v f) (rebuildEA rebuild v z) (rebuild v a) Scanl' f z a -> Scanl' (rebuildFA rebuild v f) (rebuildEA rebuild v z) (rebuild v a) Scanl1 f a -> Scanl1 (rebuildFA rebuild v f) (rebuild v a) Scanr f z a -> Scanr (rebuildFA rebuild v f) (rebuildEA rebuild v z) (rebuild v a) Scanr' f z a -> Scanr' (rebuildFA rebuild v f) (rebuildEA rebuild v z) (rebuild v a) Scanr1 f a -> Scanr1 (rebuildFA rebuild v f) (rebuild v a) Permute f1 a1 f2 a2 -> Permute (rebuildFA rebuild v f1) (rebuild v a1) (rebuildFA rebuild v f2) (rebuild v a2) Backpermute sh f a -> Backpermute (rebuildEA rebuild v sh) (rebuildFA rebuild v f) (rebuild v a) Stencil f b a -> Stencil (rebuildFA rebuild v f) b (rebuild v a) Stencil2 f b1 a1 b2 a2 -> Stencil2 (rebuildFA rebuild v f) b1 (rebuild v a1) b2 (rebuild v a2) -- Rebuilding array computations -- rebuildAfun :: SyntacticAcc f => RebuildAcc acc -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t') -> PreOpenAfun acc aenv t -> PreOpenAfun acc aenv' t rebuildAfun k v afun = case afun of Abody b -> Abody (k v b) Alam f -> Alam (rebuildAfun k (shiftA k v) f) rebuildATA :: SyntacticAcc f => RebuildAcc acc -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t') -> Atuple (acc aenv) t -> Atuple (acc aenv') t rebuildATA k v atup = case atup of NilAtup -> NilAtup SnocAtup t a -> rebuildATA k v t `SnocAtup` k v a -- Rebuilding scalar expressions -- rebuildEA :: SyntacticAcc f => RebuildAcc acc -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t') -> PreOpenExp acc env aenv t -> PreOpenExp acc env aenv' t rebuildEA k v exp = case exp of Let a b -> Let (rebuildEA k v a) (rebuildEA k v b) Var ix -> Var ix Const c -> Const c Tuple tup -> Tuple (rebuildTA k v tup) Prj tup e -> Prj tup (rebuildEA k v e) IndexNil -> IndexNil IndexCons sh sz -> IndexCons (rebuildEA k v sh) (rebuildEA k v sz) IndexHead sh -> IndexHead (rebuildEA k v sh) IndexTail sh -> IndexTail (rebuildEA k v sh) IndexAny -> IndexAny IndexSlice x ix sh -> IndexSlice x (rebuildEA k v ix) (rebuildEA k v sh) IndexFull x ix sl -> IndexFull x (rebuildEA k v ix) (rebuildEA k v sl) ToIndex sh ix -> ToIndex (rebuildEA k v sh) (rebuildEA k v ix) FromIndex sh ix -> FromIndex (rebuildEA k v sh) (rebuildEA k v ix) Cond p t e -> Cond (rebuildEA k v p) (rebuildEA k v t) (rebuildEA k v e) Iterate n f x -> Iterate (rebuildEA k v n) (rebuildEA k v f) (rebuildEA k v x) PrimConst c -> PrimConst c PrimApp f x -> PrimApp f (rebuildEA k v x) Index a sh -> Index (k v a) (rebuildEA k v sh) LinearIndex a i -> LinearIndex (k v a) (rebuildEA k v i) Shape a -> Shape (k v a) ShapeSize sh -> ShapeSize (rebuildEA k v sh) Intersect s t -> Intersect (rebuildEA k v s) (rebuildEA k v t) Foreign ff f e -> Foreign ff f (rebuildEA k v e) rebuildTA :: SyntacticAcc f => RebuildAcc acc -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t') -> Tuple (PreOpenExp acc env aenv) t -> Tuple (PreOpenExp acc env aenv') t rebuildTA k v tup = case tup of NilTup -> NilTup SnocTup t e -> rebuildTA k v t `SnocTup` rebuildEA k v e rebuildFA :: SyntacticAcc f => RebuildAcc acc -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t') -> PreOpenFun acc env aenv t -> PreOpenFun acc env aenv' t rebuildFA k v fun = case fun of Body e -> Body (rebuildEA k v e) Lam f -> Lam (rebuildFA k v f)