module Data.Array.Accelerate.Trafo.Substitution (
inline, substitute, compose,
(:>),
weakenA, weakenEA, weakenFA,
weakenE, weakenFE,
RebuildAcc,
rebuildA, rebuildAfun, rebuildOpenAcc,
rebuildE, rebuildEA,
rebuildFA, rebuildFE,
) where
import Prelude hiding ( exp )
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Tuple
import Data.Array.Accelerate.Array.Sugar ( Elt, Arrays )
import qualified Data.Array.Accelerate.Debug as Stats
-- Fixity of the backquoted operator forms of 'compose' and 'substitute'.
-- Both are declared right associative. No precedence level is written, so the
-- language default applies.
infixr `compose`
infixr `substitute`
-- | Replace the innermost scalar variable of an expression with a term.
--
-- The first argument is typed in an environment extended by one slot of type
-- @s@; the second argument is the term put in place of that slot. All other
-- variables lose one 'SuccIdx', so the result is typed in the unextended
-- environment @env@.
--
-- NOTE(review): 'Stats.substitution' is presumably a debugging hook that
-- returns its second argument -- confirm in "Data.Array.Accelerate.Debug".
--
inline :: Elt t
       => PreOpenExp acc (env, s) aenv t
       -> PreOpenExp acc env      aenv s
       -> PreOpenExp acc env      aenv t
inline body bound = Stats.substitution "inline" (rebuildE (subTop bound) body)
  where
    -- The top index becomes the replacement term; every other index is moved
    -- down by one position.
    subTop :: Elt t => PreOpenExp acc env aenv s -> Idx (env, s) t -> PreOpenExp acc env aenv t
    subTop top ix =
      case ix of
        ZeroIdx      -> top
        SuccIdx rest -> Var rest
-- | Feed the result of @g@ into the innermost variable of @f@.
--
-- Both arguments are open in exactly one extra variable. Unless @g@ is just
-- that variable, the result let-binds @g@ and rebuilds @f@ underneath the new
-- binder, so the result is open in the variable of @g@.
--
substitute :: (Elt b, Elt c)
           => PreOpenExp acc (env, b) aenv c
           -> PreOpenExp acc (env, a) aenv b
           -> PreOpenExp acc (env, a) aenv c
substitute f g
  -- This guard only exists so that 'Stats.substitution' gets evaluated.
  -- NOTE(review): this assumes it hands back the 'False' it is given, so the
  -- 'undefined' branch is never taken -- confirm in the Debug module.
  | Stats.substitution "substitute" False = undefined
  -- g is the bare innermost variable: nothing to bind, f is already the answer.
  | Var ZeroIdx <- g                      = f
  | otherwise                             = Let g (rebuildE underLet f)
  where
    -- Re-index f for its new position beneath the Let: its own variable stays
    -- on top, every other index is pushed past two slots.
    underLet :: Elt c => Idx (env,b) c -> PreOpenExp acc ((env,a),b) aenv c
    underLet ix =
      case ix of
        ZeroIdx      -> Var ZeroIdx
        SuccIdx rest -> Var (SuccIdx (SuccIdx rest))
-- | Compose two unary scalar functions: the result applies the second
-- function and passes its value to the first.
--
-- Both arguments must be a single 'Lam' wrapped directly around a 'Body';
-- any other shape raises an error.
--
compose :: Elt c
        => PreOpenFun acc env aenv (b -> c)
        -> PreOpenFun acc env aenv (a -> b)
        -> PreOpenFun acc env aenv (a -> c)
compose outer inner =
  case (outer, inner) of
    (Lam (Body f), Lam (Body g)) -> Stats.substitution "compose" (Lam (Body (substitute f g)))
    _                            -> error "compose: impossible evaluation"
-- | A type-preserving mapping of environment indices: every 'Idx' into @env@
-- is sent to an 'Idx' of the same element type into @env'@.
type env :> env' = forall t'. Idx env t' -> Idx env' t'
-- | Re-index the array variables of an array computation along the given
-- index mapping. Each variable is rebuilt as an 'Avar' at its mapped index.
weakenA :: RebuildAcc acc -> aenv :> aenv' -> PreOpenAcc acc aenv a -> PreOpenAcc acc aenv' a
weakenA k v acc = Stats.substitution "weakenA" (rebuildA k (\ix -> Avar (v ix)) acc)
-- | Re-index the array variables that occur inside a scalar expression along
-- the given index mapping. The scalar environment is left as it is.
weakenEA :: RebuildAcc acc -> aenv :> aenv' -> PreOpenExp acc env aenv t -> PreOpenExp acc env aenv' t
weakenEA k v e = Stats.substitution "weakenEA" (rebuildEA k (\ix -> Avar (v ix)) e)
-- | Re-index the array variables that occur inside a scalar function along
-- the given index mapping. The scalar environment is left as it is.
weakenFA :: RebuildAcc acc -> aenv :> aenv' -> PreOpenFun acc env aenv f -> PreOpenFun acc env aenv' f
weakenFA k v fun = Stats.substitution "weakenFA" (rebuildFA k (\ix -> Avar (v ix)) fun)
-- | Re-index the scalar variables of an expression along the given index
-- mapping. Each variable is rebuilt as a 'Var' at its mapped index.
weakenE :: env :> env' -> PreOpenExp acc env aenv t -> PreOpenExp acc env' aenv t
weakenE v e = Stats.substitution "weakenE" (rebuildE (\ix -> Var (v ix)) e)
-- | Re-index the scalar variables of a scalar function along the given index
-- mapping. Each variable is rebuilt as a 'Var' at its mapped index.
weakenFE :: env :> env' -> PreOpenFun acc env aenv f -> PreOpenFun acc env' aenv f
weakenFE v fun = Stats.substitution "weakenFE" (rebuildFE (\ix -> Var (v ix)) fun)
-- | The operations 'rebuildE' needs from whatever is substituted for a scalar
-- variable. The instances in this module are plain indices ('IdxE') and full
-- expressions ('PreOpenExp').
class SyntacticExp f where
  -- | Embed a variable index.
  varIn :: Elt t => Idx env t -> f acc env aenv t
  -- | Convert to a scalar expression.
  expOut :: Elt t => f acc env aenv t -> PreOpenExp acc env aenv t
  -- | Move into an environment extended by one more scalar slot.
  weakenExp :: Elt t => f acc env aenv t -> f acc (env, s) aenv t
-- | A scalar environment index wrapped so that it has the parameter shape
-- required by 'SyntacticExp'; @acc@ and @aenv@ are not used by the payload.
newtype IdxE (acc :: * -> * -> *) env aenv t = IE { unIE :: Idx env t }
-- Indices as the substituted thing: this is the renaming case.
instance SyntacticExp IdxE where
  varIn ix          = IE ix
  expOut (IE ix)    = Var ix
  -- Entering one more binder shifts the index up by one.
  weakenExp (IE ix) = IE (SuccIdx ix)
-- Expressions as the substituted thing: this is the substitution case.
instance SyntacticExp PreOpenExp where
  varIn ix    = Var ix
  expOut e    = e
  -- Weakening an expression is a renaming that shifts every variable by one,
  -- expressed through the 'IdxE' instance.
  weakenExp e = rebuildE (\ix -> weakenExp (IE ix)) e
-- | Extend a scalar substitution so that it can be used underneath one more
-- binder: the newly bound variable maps to itself, and the image of every
-- other variable is weakened by one slot.
shiftE
    :: (SyntacticExp f, Elt t)
    => (forall t'. Elt t' => Idx env t' -> f acc env' aenv t')
    -> Idx       (env,  s)      t
    -> f     acc (env', s) aenv t
shiftE v ix =
  case ix of
    ZeroIdx      -> varIn ZeroIdx
    SuccIdx rest -> weakenExp (v rest)
-- | Apply a substitution on scalar variables throughout a scalar expression.
--
-- Each 'Var' is replaced by the image of its index. The body of a 'Let' and
-- the function argument of 'Iterate' are rebuilt with the substitution shifted
-- by 'shiftE'. Array-valued arguments ('Index', 'LinearIndex', 'Shape') and
-- the function carried by 'Foreign' are passed through unchanged.
--
rebuildE
    :: SyntacticExp f
    => (forall t'. Elt t' => Idx env t' -> f acc env' aenv t')
    -> PreOpenExp acc env  aenv t
    -> PreOpenExp acc env' aenv t
rebuildE sub (Let bnd body)       = Let (rebuildE sub bnd) (rebuildE (shiftE sub) body)
rebuildE sub (Var ix)             = expOut (sub ix)
rebuildE _   (Const c)            = Const c
rebuildE sub (Tuple tup)          = Tuple (rebuildTE sub tup)
rebuildE sub (Prj tix e)          = Prj tix (rebuildE sub e)
rebuildE _   IndexNil             = IndexNil
rebuildE sub (IndexCons sh sz)    = IndexCons (rebuildE sub sh) (rebuildE sub sz)
rebuildE sub (IndexHead sh)       = IndexHead (rebuildE sub sh)
rebuildE sub (IndexTail sh)       = IndexTail (rebuildE sub sh)
rebuildE _   IndexAny             = IndexAny
rebuildE sub (IndexSlice x ix sh) = IndexSlice x (rebuildE sub ix) (rebuildE sub sh)
rebuildE sub (IndexFull x ix sl)  = IndexFull x (rebuildE sub ix) (rebuildE sub sl)
rebuildE sub (ToIndex sh ix)      = ToIndex (rebuildE sub sh) (rebuildE sub ix)
rebuildE sub (FromIndex sh ix)    = FromIndex (rebuildE sub sh) (rebuildE sub ix)
rebuildE sub (Cond p t e)         = Cond (rebuildE sub p) (rebuildE sub t) (rebuildE sub e)
rebuildE sub (Iterate n f x)      = Iterate (rebuildE sub n) (rebuildE (shiftE sub) f) (rebuildE sub x)
rebuildE _   (PrimConst c)        = PrimConst c
rebuildE sub (PrimApp f x)        = PrimApp f (rebuildE sub x)
rebuildE sub (Index a sh)         = Index a (rebuildE sub sh)
rebuildE sub (LinearIndex a i)    = LinearIndex a (rebuildE sub i)
rebuildE _   (Shape a)            = Shape a
rebuildE sub (ShapeSize sh)       = ShapeSize (rebuildE sub sh)
rebuildE sub (Intersect s t)      = Intersect (rebuildE sub s) (rebuildE sub t)
rebuildE sub (Foreign ff f e)     = Foreign ff f (rebuildE sub e)
-- | Apply a scalar-variable substitution to every component of a tuple of
-- scalar expressions.
rebuildTE
    :: SyntacticExp f
    => (forall t'. Elt t' => Idx env t' -> f acc env' aenv t')
    -> Tuple (PreOpenExp acc env  aenv) t
    -> Tuple (PreOpenExp acc env' aenv) t
rebuildTE _   NilTup            = NilTup
rebuildTE sub (SnocTup rest e)  = SnocTup (rebuildTE sub rest) (rebuildE sub e)
-- | Apply a scalar-variable substitution to a scalar function. Each 'Lam'
-- shifts the substitution with 'shiftE' before descending.
rebuildFE
    :: SyntacticExp f
    => (forall t'. Elt t' => Idx env t' -> f acc env' aenv t')
    -> PreOpenFun acc env  aenv t
    -> PreOpenFun acc env' aenv t
rebuildFE sub (Body e) = Body (rebuildE sub e)
rebuildFE sub (Lam f)  = Lam (rebuildFE (shiftE sub) f)
-- | The type of a function that applies a substitution on array variables to
-- a term of the array representation @acc@. The traversals in this module
-- take one as a parameter and call it wherever they meet a sub-term of type
-- @acc@ ('rebuildOpenAcc' is the one for 'OpenAcc').
type RebuildAcc acc =
  forall aenv aenv' f a. SyntacticAcc f
  => (forall a'. Arrays a' => Idx aenv a' -> f acc aenv' a')
  -> acc aenv a
  -> acc aenv' a
-- | The operations 'rebuildA' needs from whatever is substituted for an array
-- variable. The instances in this module are plain indices ('IdxA') and full
-- array terms ('PreOpenAcc').
class SyntacticAcc f where
  -- | Embed an array variable index.
  avarIn :: Arrays t => Idx aenv t -> f acc aenv t
  -- | Convert to an array term.
  accOut :: Arrays t => f acc aenv t -> PreOpenAcc acc aenv t
  -- | Move into an array environment extended by one more slot.
  weakenAcc :: Arrays t => RebuildAcc acc -> f acc aenv t -> f acc (aenv, s) t
-- | An array environment index wrapped so that it has the parameter shape
-- required by 'SyntacticAcc'; @acc@ is not used by the payload.
newtype IdxA (acc :: * -> * -> *) aenv t = IA { unIA :: Idx aenv t }
-- Indices as the substituted thing: this is the renaming case.
instance SyntacticAcc IdxA where
  avarIn ix            = IA ix
  accOut (IA ix)       = Avar ix
  -- Entering one more binder shifts the index up by one; the rebuild function
  -- is not needed for that.
  weakenAcc _ (IA ix)  = IA (SuccIdx ix)
-- Array terms as the substituted thing: this is the substitution case.
instance SyntacticAcc PreOpenAcc where
  avarIn ix       = Avar ix
  accOut acc      = acc
  -- Weakening a term is a renaming that shifts every array variable by one,
  -- expressed through the 'IdxA' instance.
  weakenAcc k acc = rebuildA k (\ix -> weakenAcc k (IA ix)) acc
-- | Apply a substitution on array variables to an 'OpenAcc' term: unwrap it,
-- rebuild the underlying term with 'rebuildA' (using this function for the
-- recursive positions) and wrap the result again.
rebuildOpenAcc
    :: SyntacticAcc f
    => (forall t'. Arrays t' => Idx aenv t' -> f OpenAcc aenv' t')
    -> OpenAcc aenv  t
    -> OpenAcc aenv' t
rebuildOpenAcc v open =
  case open of
    OpenAcc pacc -> OpenAcc (rebuildA rebuildOpenAcc v pacc)
-- | Extend an array substitution so that it can be used underneath one more
-- array binder: the newly bound variable maps to itself, and the image of
-- every other variable is weakened by one slot.
shiftA
    :: (SyntacticAcc f, Arrays t)
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t')
    -> Idx     (aenv,  s) t
    -> f   acc (aenv', s) t
shiftA k v ix =
  case ix of
    ZeroIdx      -> avarIn ZeroIdx
    SuccIdx rest -> weakenAcc k (v rest)
-- | Apply a substitution on array variables to one layer of an array term.
--
-- The first argument rebuilds the recursive @acc@ positions. An 'Avar' is
-- replaced by the image of its index, and the body of an 'Alet' is rebuilt
-- with the substitution shifted by 'shiftA'. Embedded scalar expressions and
-- functions are handled by 'rebuildEA' and 'rebuildFA'. The argument of 'Use',
-- the function carried by 'Aforeign' and the stencil boundary arguments are
-- passed through unchanged.
--
rebuildA
    :: SyntacticAcc f
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t')
    -> PreOpenAcc acc aenv  t
    -> PreOpenAcc acc aenv' t
rebuildA k v (Alet a b)                = Alet (k v a) (k (shiftA k v) b)
rebuildA _ v (Avar ix)                 = accOut (v ix)
rebuildA k v (Atuple tup)              = Atuple (rebuildATA k v tup)
rebuildA k v (Aprj tix a)              = Aprj tix (k v a)
rebuildA k v (Apply f a)               = Apply (rebuildAfun k v f) (k v a)
rebuildA k v (Aforeign ff afun as)     = Aforeign ff afun (k v as)
rebuildA k v (Acond p t e)             = Acond (rebuildEA k v p) (k v t) (k v e)
rebuildA _ _ (Use a)                   = Use a
rebuildA k v (Unit e)                  = Unit (rebuildEA k v e)
rebuildA k v (Reshape e a)             = Reshape (rebuildEA k v e) (k v a)
rebuildA k v (Generate e f)            = Generate (rebuildEA k v e) (rebuildFA k v f)
rebuildA k v (Transform sh ix f a)     = Transform (rebuildEA k v sh) (rebuildFA k v ix) (rebuildFA k v f) (k v a)
rebuildA k v (Replicate sl slix a)     = Replicate sl (rebuildEA k v slix) (k v a)
rebuildA k v (Slice sl a slix)         = Slice sl (k v a) (rebuildEA k v slix)
rebuildA k v (Map f a)                 = Map (rebuildFA k v f) (k v a)
rebuildA k v (ZipWith f a1 a2)         = ZipWith (rebuildFA k v f) (k v a1) (k v a2)
rebuildA k v (Fold f z a)              = Fold (rebuildFA k v f) (rebuildEA k v z) (k v a)
rebuildA k v (Fold1 f a)               = Fold1 (rebuildFA k v f) (k v a)
rebuildA k v (FoldSeg f z a s)         = FoldSeg (rebuildFA k v f) (rebuildEA k v z) (k v a) (k v s)
rebuildA k v (Fold1Seg f a s)          = Fold1Seg (rebuildFA k v f) (k v a) (k v s)
rebuildA k v (Scanl f z a)             = Scanl (rebuildFA k v f) (rebuildEA k v z) (k v a)
rebuildA k v (Scanl' f z a)            = Scanl' (rebuildFA k v f) (rebuildEA k v z) (k v a)
rebuildA k v (Scanl1 f a)              = Scanl1 (rebuildFA k v f) (k v a)
rebuildA k v (Scanr f z a)             = Scanr (rebuildFA k v f) (rebuildEA k v z) (k v a)
rebuildA k v (Scanr' f z a)            = Scanr' (rebuildFA k v f) (rebuildEA k v z) (k v a)
rebuildA k v (Scanr1 f a)              = Scanr1 (rebuildFA k v f) (k v a)
rebuildA k v (Permute f1 a1 f2 a2)     = Permute (rebuildFA k v f1) (k v a1) (rebuildFA k v f2) (k v a2)
rebuildA k v (Backpermute sh f a)      = Backpermute (rebuildEA k v sh) (rebuildFA k v f) (k v a)
rebuildA k v (Stencil f b a)           = Stencil (rebuildFA k v f) b (k v a)
rebuildA k v (Stencil2 f b1 a1 b2 a2)  = Stencil2 (rebuildFA k v f) b1 (k v a1) b2 (k v a2)
-- | Apply a substitution on array variables to an array function. Each
-- 'Alam' shifts the substitution with 'shiftA' before descending; the
-- 'Abody' is handed to the supplied rebuild function.
rebuildAfun
    :: SyntacticAcc f
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t')
    -> PreOpenAfun acc aenv  t
    -> PreOpenAfun acc aenv' t
rebuildAfun k v (Abody b) = Abody (k v b)
rebuildAfun k v (Alam f)  = Alam (rebuildAfun k (shiftA k v) f)
-- | Apply a substitution on array variables to every component of a tuple of
-- array terms, using the supplied rebuild function on each component.
rebuildATA
    :: SyntacticAcc f
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t')
    -> Atuple (acc aenv)  t
    -> Atuple (acc aenv') t
rebuildATA _ _ NilAtup           = NilAtup
rebuildATA k v (SnocAtup rest a) = SnocAtup (rebuildATA k v rest) (k v a)
-- | Apply a substitution on array variables to the array terms embedded in a
-- scalar expression.
--
-- The scalar structure is copied as it is: 'Var' indices are untouched and no
-- shifting happens under 'Let' or 'Iterate', because the substitution only
-- concerns the array environment. The array arguments of 'Index',
-- 'LinearIndex' and 'Shape' are handed to the supplied rebuild function.
--
rebuildEA
    :: SyntacticAcc f
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t')
    -> PreOpenExp acc env aenv  t
    -> PreOpenExp acc env aenv' t
rebuildEA k v (Let a b)             = Let (rebuildEA k v a) (rebuildEA k v b)
rebuildEA _ _ (Var ix)              = Var ix
rebuildEA _ _ (Const c)             = Const c
rebuildEA k v (Tuple tup)           = Tuple (rebuildTA k v tup)
rebuildEA k v (Prj tix e)           = Prj tix (rebuildEA k v e)
rebuildEA _ _ IndexNil              = IndexNil
rebuildEA k v (IndexCons sh sz)     = IndexCons (rebuildEA k v sh) (rebuildEA k v sz)
rebuildEA k v (IndexHead sh)        = IndexHead (rebuildEA k v sh)
rebuildEA k v (IndexTail sh)        = IndexTail (rebuildEA k v sh)
rebuildEA _ _ IndexAny              = IndexAny
rebuildEA k v (IndexSlice x ix sh)  = IndexSlice x (rebuildEA k v ix) (rebuildEA k v sh)
rebuildEA k v (IndexFull x ix sl)   = IndexFull x (rebuildEA k v ix) (rebuildEA k v sl)
rebuildEA k v (ToIndex sh ix)       = ToIndex (rebuildEA k v sh) (rebuildEA k v ix)
rebuildEA k v (FromIndex sh ix)     = FromIndex (rebuildEA k v sh) (rebuildEA k v ix)
rebuildEA k v (Cond p t e)          = Cond (rebuildEA k v p) (rebuildEA k v t) (rebuildEA k v e)
rebuildEA k v (Iterate n f x)       = Iterate (rebuildEA k v n) (rebuildEA k v f) (rebuildEA k v x)
rebuildEA _ _ (PrimConst c)         = PrimConst c
rebuildEA k v (PrimApp f x)         = PrimApp f (rebuildEA k v x)
rebuildEA k v (Index a sh)          = Index (k v a) (rebuildEA k v sh)
rebuildEA k v (LinearIndex a i)     = LinearIndex (k v a) (rebuildEA k v i)
rebuildEA k v (Shape a)             = Shape (k v a)
rebuildEA k v (ShapeSize sh)        = ShapeSize (rebuildEA k v sh)
rebuildEA k v (Intersect s t)       = Intersect (rebuildEA k v s) (rebuildEA k v t)
rebuildEA k v (Foreign ff f e)      = Foreign ff f (rebuildEA k v e)
-- | Apply a substitution on array variables to every component of a tuple of
-- scalar expressions.
rebuildTA
    :: SyntacticAcc f
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t')
    -> Tuple (PreOpenExp acc env aenv)  t
    -> Tuple (PreOpenExp acc env aenv') t
rebuildTA _ _ NilTup           = NilTup
rebuildTA k v (SnocTup rest e) = SnocTup (rebuildTA k v rest) (rebuildEA k v e)
-- | Apply a substitution on array variables to a scalar function. The same
-- substitution is used under each 'Lam', since a scalar binder does not
-- change the array environment.
rebuildFA
    :: SyntacticAcc f
    => RebuildAcc acc
    -> (forall t'. Arrays t' => Idx aenv t' -> f acc aenv' t')
    -> PreOpenFun acc env aenv  t
    -> PreOpenFun acc env aenv' t
rebuildFA k v (Body e) = Body (rebuildEA k v e)
rebuildFA k v (Lam f)  = Lam (rebuildFA k v f)