module Data.Array.Accelerate.Trafo.Fusion (
DelayedAcc, DelayedOpenAcc(..),
DelayedAfun, DelayedOpenAfun,
DelayedExp, DelayedFun, DelayedOpenExp, DelayedOpenFun,
convertAcc, convertAfun,
) where
import Prelude hiding ( exp, until )
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Trafo.Base
import Data.Array.Accelerate.Trafo.Shrink
import Data.Array.Accelerate.Trafo.Simplify
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) )
import Data.Array.Accelerate.Array.Sugar ( Array, Arrays(..), ArraysR(..), ArrRepr', Elt, EltRepr, Shape )
import Data.Array.Accelerate.Tuple
import qualified Data.Array.Accelerate.Debug as Stats
#ifdef ACCELERATE_DEBUG
import System.IO.Unsafe
#endif
#include "accelerate.h"
-- | Run the fusion pipeline on a closed array computation: 'annealAcc' is
-- applied first, then 'quenchAcc', and the result is passed through
-- 'withSimplStats'.
convertAcc :: Arrays arrs => Acc arrs -> DelayedAcc arrs
convertAcc acc = withSimplStats (quenchAcc (annealAcc acc))
-- | Run the fusion pipeline on a function over arrays: 'annealAfun' is applied
-- first, then 'quenchAfun', and the result is passed through 'withSimplStats'.
convertAfun :: Afun f -> DelayedAfun f
convertAfun afun = withSimplStats (quenchAfun (annealAfun afun))
-- | Identity on its argument. When ACCELERATE_DEBUG is defined, forcing the
-- result first forces @Stats.resetSimplCount@ through 'unsafePerformIO';
-- without the flag this is exactly @id@.
--
-- NOTE(review): the 'unsafePerformIO' is only used to trigger the counter
-- reset as a side effect; confirm that re-running or dropping it is harmless.
withSimplStats :: a -> a
withSimplStats x =
#ifdef ACCELERATE_DEBUG
  seq (unsafePerformIO Stats.resetSimplCount) x
#else
  x
#endif
-- | Convert an array computation into the delayed representation.
--
-- Every node is rebuilt as a 'Manifest' term by the local 'cvtA'. The array
-- arguments of Map, Transform, Backpermute, the folds, the scans and the
-- source array of Permute are converted with the local 'embed', which
-- represents them as a 'Delayed' triple; all other array arguments are
-- converted with 'cvtA'.
quenchAcc :: Arrays arrs => OpenAcc aenv arrs -> DelayedOpenAcc aenv arrs
quenchAcc = cvtA
where
-- Represent an array term as Delayed (extent, index function, linear index
-- function). Only Avar, Generate, Map, Backpermute and Transform are
-- handled; any other term is reported as an internal error.
embed :: (Shape sh, Elt e) => OpenAcc aenv (Array sh e) -> DelayedOpenAcc aenv (Array sh e)
embed (OpenAcc pacc) =
case pacc of
-- An array variable: index it directly, both by shape index and linearly.
Avar v
-> Delayed (arrayShape v) (indexArray v) (linearIndex v)
-- The linear index function is obtained by first converting the Int into
-- a shape index with fromIndex.
Generate (cvtE -> sh) (cvtF -> f)
-> Delayed sh f (f `compose` fromIndex sh)
-- Map composes f onto both index functions of the embedded argument.
Map (cvtF -> f) (embed -> Delayed{..})
-> Delayed extentD (f `compose` indexD) (f `compose` linearIndexD)
Backpermute (cvtE -> sh) (cvtF -> p) (embed -> Delayed{..})
-> let p' = indexD `compose` p
in Delayed sh p'(p' `compose` fromIndex sh)
Transform (cvtE -> sh) (cvtF -> p) (cvtF -> f) (embed -> Delayed{..})
-> let f' = f `compose` indexD `compose` p
in Delayed sh f' (f' `compose` fromIndex sh)
_ -> INTERNAL_ERROR(error) "quench" "tried to consume a non-embeddable term"
-- Raised by cvtA for Reshape, Replicate, Slice and ZipWith nodes.
-- NOTE(review): presumably these are always removed by the preceding
-- annealAcc pass; confirm against the callers.
fusionError = INTERNAL_ERROR(error) "quench" "unexpected fusible materials"
-- Rebuild one node as a Manifest term, converting scalar parts with
-- cvtE / cvtF and array parts with cvtA or embed as listed above.
cvtA :: OpenAcc aenv a -> DelayedOpenAcc aenv a
cvtA (OpenAcc pacc) = Manifest $
case pacc of
Avar ix -> Avar ix
Use arr -> Use arr
Unit e -> Unit (cvtE e)
Alet bnd body -> Alet (cvtA bnd) (cvtA body)
Acond p t e -> Acond (cvtE p) (cvtA t) (cvtA e)
Atuple tup -> Atuple (cvtAT tup)
Aprj ix tup -> Aprj ix (cvtA tup)
Apply f a -> Apply (cvtAF f) (cvtA a)
Aforeign ff f a -> Aforeign ff (cvtAF f) (cvtA a)
Map f a -> Map (cvtF f) (embed a)
Generate sh f -> Generate (cvtE sh) (cvtF f)
Transform sh p f a -> Transform (cvtE sh) (cvtF p) (cvtF f) (embed a)
-- The unconverted argument is passed as well so that the local
-- backpermute helper can inspect it for the Reshape special case.
Backpermute sh p a -> backpermute (cvtE sh) (cvtF p) (embed a) a
Reshape{} -> fusionError
Replicate{} -> fusionError
Slice{} -> fusionError
ZipWith{} -> fusionError
Fold f z a -> Fold (cvtF f) (cvtE z) (embed a)
Fold1 f a -> Fold1 (cvtF f) (embed a)
FoldSeg f z a s -> FoldSeg (cvtF f) (cvtE z) (embed a) (embed s)
Fold1Seg f a s -> Fold1Seg (cvtF f) (embed a) (embed s)
Scanl f z a -> Scanl (cvtF f) (cvtE z) (embed a)
Scanl1 f a -> Scanl1 (cvtF f) (embed a)
Scanl' f z a -> Scanl' (cvtF f) (cvtE z) (embed a)
Scanr f z a -> Scanr (cvtF f) (cvtE z) (embed a)
Scanr1 f a -> Scanr1 (cvtF f) (embed a)
Scanr' f z a -> Scanr' (cvtF f) (cvtE z) (embed a)
-- Only the source array of Permute is embedded; the defaults array d is
-- converted with cvtA.
Permute f d p a -> Permute (cvtF f) (cvtA d) (cvtF p) (embed a)
Stencil f x a -> Stencil (cvtF f) x (cvtA a)
Stencil2 f x a y b -> Stencil2 (cvtF f) x (cvtA a) y (cvtA b)
-- If the source term x is an array variable and the permutation p matches
-- the simplified reindex function between the shape of that variable and
-- the target shape, emit a Reshape of the manifest variable instead of a
-- Backpermute.
backpermute sh p a x
| OpenAcc (Avar v) <- x
, Just REFL <- match p (simplify $ reindex (arrayShape v) sh)
= Reshape sh (Manifest (Avar v))
| otherwise
= Backpermute sh p a
-- Convert every component of a tuple of arrays with cvtA.
cvtAT :: Atuple (OpenAcc aenv) a -> Atuple (DelayedOpenAcc aenv) a
cvtAT NilAtup = NilAtup
cvtAT (SnocAtup t a) = cvtAT t `SnocAtup` cvtA a
-- Convert the body of an array function with cvtA.
cvtAF :: OpenAfun aenv f -> PreOpenAfun DelayedOpenAcc aenv f
cvtAF (Alam f) = Alam (cvtAF f)
cvtAF (Abody b) = Abody (cvtA b)
-- Convert the body of a scalar function with cvtE.
cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f
cvtF (Lam f) = Lam (cvtF f)
cvtF (Body b) = Body (cvtE b)
-- Structural traversal of a scalar expression. Array terms embedded in the
-- expression (Index, LinearIndex, Shape) are converted with cvtA.
cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t
cvtE exp =
case exp of
Let bnd body -> Let (cvtE bnd) (cvtE body)
Var ix -> Var ix
Const c -> Const c
Tuple tup -> Tuple (cvtT tup)
Prj ix t -> Prj ix (cvtE t)
IndexNil -> IndexNil
IndexCons sh sz -> IndexCons (cvtE sh) (cvtE sz)
IndexHead sh -> IndexHead (cvtE sh)
IndexTail sh -> IndexTail (cvtE sh)
IndexAny -> IndexAny
IndexSlice x ix sh -> IndexSlice x (cvtE ix) (cvtE sh)
IndexFull x ix sl -> IndexFull x (cvtE ix) (cvtE sl)
ToIndex sh ix -> ToIndex (cvtE sh) (cvtE ix)
FromIndex sh ix -> FromIndex (cvtE sh) (cvtE ix)
Cond p t e -> Cond (cvtE p) (cvtE t) (cvtE e)
Iterate n f x -> Iterate (cvtE n) (cvtE f) (cvtE x)
PrimConst c -> PrimConst c
PrimApp f x -> PrimApp f (cvtE x)
Index a sh -> Index (cvtA a) (cvtE sh)
LinearIndex a i -> LinearIndex (cvtA a) (cvtE i)
Shape a -> Shape (cvtA a)
ShapeSize sh -> ShapeSize (cvtE sh)
Intersect s t -> Intersect (cvtE s) (cvtE t)
Foreign ff f e -> Foreign ff (cvtF f) (cvtE e)
-- Convert every component of a scalar tuple with cvtE.
cvtT :: Tuple (OpenExp env aenv) t -> Tuple (DelayedOpenExp env aenv) t
cvtT NilTup = NilTup
cvtT (SnocTup t e) = cvtT t `SnocTup` cvtE e
-- | Apply 'quenchAcc' to the body of an array function, leaving the lambda
-- structure unchanged.
quenchAfun :: OpenAfun aenv f -> DelayedOpenAfun aenv f
quenchAfun afun =
  case afun of
    Alam  rest -> Alam  (quenchAfun rest)
    Abody body -> Abody (quenchAcc body)
-- | Fuse an array computation: delay the term with 'delayPreAcc' and
-- immediately turn the delayed result back into an array term with
-- 'computeAcc'.
annealAcc :: Arrays arrs => OpenAcc aenv arrs -> OpenAcc aenv arrs
annealAcc term = computeAcc (delayAcc term)
  where
    -- Tie the recursive knot for 'delayPreAcc' at the 'OpenAcc' type.
    delayAcc :: Arrays a => OpenAcc aenv a -> Delayed OpenAcc aenv a
    delayAcc (OpenAcc pacc) = delayPreAcc delayAcc elimAcc pacc

    -- Count uses of an array variable, recursing through 'OpenAcc'.
    countAcc :: UsesOfAcc OpenAcc
    countAcc ok idx (OpenAcc pacc) = usesOfPreAcc ok countAcc idx pacc

    -- A let-bound array may be eliminated when the use count reported by
    -- 'countAcc' (called with flag False) does not exceed the limit of one.
    elimAcc :: Idx aenv s -> OpenAcc aenv t -> Bool
    elimAcc ix body = countAcc False ix body <= useLimit
      where
        useLimit = 1
-- | Apply 'annealAcc' to the body of an array function, leaving the lambda
-- structure unchanged.
annealAfun :: OpenAfun aenv f -> OpenAfun aenv f
annealAfun afun =
  case afun of
    Alam  rest -> Alam  (annealAfun rest)
    Abody body -> Abody (annealAcc body)
-- | The type of a function that delays an array term of representation @acc@,
-- polymorphic in the array environment and in the result type.
type DelayAcc acc = forall aenv arrs. Arrays arrs => acc aenv arrs -> Delayed acc aenv arrs
-- | The type of a predicate that, given an array variable and a term, answers
-- whether the binding of that variable should be eliminated (it is consulted
-- by the let-binding case of the fusion pass).
type ElimAcc acc = forall aenv s t. Idx aenv s -> acc aenv t -> Bool
-- | Delay a single array-level node.
--
-- The first argument is used to delay sub-terms recursively; the second is
-- handed to 'aletD' to decide on let-elimination. The result is a 'Delayed'
-- term: an environment extension together with a 'Cunctation'.
delayPreAcc
:: forall acc aenv arrs. (Kit acc, Arrays arrs)
=> DelayAcc acc
-> ElimAcc acc
-> PreOpenAcc acc aenv arrs
-> Delayed acc aenv arrs
delayPreAcc delayAcc elimAcc pacc =
case pacc of
-- Control flow and projection have dedicated handlers.
Alet bnd body -> aletD delayAcc elimAcc bnd body
Acond p at ae -> acondD delayAcc (cvtE p) at ae
Aprj ix tup -> aprjD delayAcc ix tup
-- These forms are rebuilt from recursively converted sub-terms and wrapped
-- with 'done'.
Atuple tup -> done $ Atuple (cvtAT tup)
Apply f a -> done $ Apply (cvtAF f) (cvtA a)
Aforeign ff f a -> done $ Aforeign ff (cvtAF f) (cvtA a)
Avar v -> done $ Avar v
Use arrs -> done $ Use arrs
Unit e -> done $ Unit (cvtE e)
-- These forms are expressed on the Cunctation of their (delayed) argument
-- via the ...D combinators.
Generate sh f -> generateD (cvtE sh) (cvtF f)
Map f a -> fuse (into mapD (cvtF f)) a
ZipWith f a b -> fuse2 (into zipWithD (cvtF f)) a b
Transform sh p f a -> fuse (into3 transformD (cvtE sh) (cvtF p) (cvtF f)) a
Backpermute sl p a -> fuse (into2 backpermuteD (cvtE sl) (cvtF p)) a
Slice slix a sl -> fuse (into (sliceD slix) (cvtE sl)) a
Replicate slix sh a -> fuse (into (replicateD slix) (cvtE sh)) a
Reshape sl a -> fuse (into reshapeD (cvtE sl)) a
-- These forms keep their own node; the node is pushed onto the environment
-- extension by embed / embed2 and referred to by ZeroIdx.
Fold f z a -> embed (into2 Fold (cvtF f) (cvtE z)) a
Fold1 f a -> embed (into Fold1 (cvtF f)) a
FoldSeg f z a s -> embed2 (into2 FoldSeg (cvtF f) (cvtE z)) a s
Fold1Seg f a s -> embed2 (into Fold1Seg (cvtF f)) a s
Scanl f z a -> embed (into2 Scanl (cvtF f) (cvtE z)) a
Scanl1 f a -> embed (into Scanl1 (cvtF f)) a
Scanl' f z a -> embed (into2 Scanl' (cvtF f) (cvtE z)) a
Scanr f z a -> embed (into2 Scanr (cvtF f) (cvtE z)) a
Scanr1 f a -> embed (into Scanr1 (cvtF f)) a
Scanr' f z a -> embed (into2 Scanr' (cvtF f) (cvtE z)) a
Permute f d p a -> embed2 (into2 permute (cvtF f) (cvtF p)) d a
Stencil f x a -> embed (into (stencil x) (cvtF f)) a
Stencil2 f x a y b -> embed2 (into (stencil2 x y) (cvtF f)) a b
where
-- Delay a sub-term and immediately turn it back into an array term.
cvtA :: Arrays a => acc aenv' a -> acc aenv' a
cvtA = computeAcc . delayAcc
cvtAT :: Atuple (acc aenv') a -> Atuple (acc aenv') a
cvtAT NilAtup = NilAtup
cvtAT (SnocAtup tup a) = cvtAT tup `SnocAtup` cvtA a
cvtAF :: PreOpenAfun acc aenv' f -> PreOpenAfun acc aenv' f
cvtAF (Alam f) = Alam (cvtAF f)
cvtAF (Abody a) = Abody (cvtA a)
-- Argument-reordering wrappers so that the scalar arguments come first and
-- the array arguments last, as expected by into / into2 with embed / embed2.
permute f p d a = Permute f d p a
stencil x f a = Stencil f x a
stencil2 x y f a b = Stencil2 f x a y b
-- Closed scalar functions and expressions are simplified first and then
-- traversed.
cvtF :: PreFun acc aenv t -> PreFun acc aenv t
cvtF = cvtF' . simplify
cvtE :: PreExp acc aenv t -> PreExp acc aenv t
cvtE = cvtE' . simplify
cvtF' :: PreOpenFun acc env aenv' t -> PreOpenFun acc env aenv' t
cvtF' (Lam f) = Lam (cvtF' f)
cvtF' (Body b) = Body (cvtE' b)
-- Structural traversal of a scalar expression. Note that the array terms
-- inside Index, LinearIndex and Shape are returned unchanged.
cvtE' :: PreOpenExp acc env aenv' t -> PreOpenExp acc env aenv' t
cvtE' exp =
case exp of
Let bnd body -> Let (cvtE' bnd) (cvtE' body)
Var ix -> Var ix
Const c -> Const c
Tuple tup -> Tuple (cvtT tup)
Prj tup ix -> Prj tup (cvtE' ix)
IndexNil -> IndexNil
IndexCons sh sz -> IndexCons (cvtE' sh) (cvtE' sz)
IndexHead sh -> IndexHead (cvtE' sh)
IndexTail sh -> IndexTail (cvtE' sh)
IndexAny -> IndexAny
IndexSlice x ix sh -> IndexSlice x (cvtE' ix) (cvtE' sh)
IndexFull x ix sl -> IndexFull x (cvtE' ix) (cvtE' sl)
ToIndex sh ix -> ToIndex (cvtE' sh) (cvtE' ix)
FromIndex sh ix -> FromIndex (cvtE' sh) (cvtE' ix)
Cond p t e -> Cond (cvtE' p) (cvtE' t) (cvtE' e)
Iterate n f x -> Iterate (cvtE' n) (cvtE' f) (cvtE' x)
PrimConst c -> PrimConst c
PrimApp f x -> PrimApp f (cvtE' x)
Index a sh -> Index a (cvtE' sh)
LinearIndex a i -> LinearIndex a (cvtE' i)
Shape a -> Shape a
ShapeSize sh -> ShapeSize (cvtE' sh)
Intersect s t -> Intersect (cvtE' s) (cvtE' t)
Foreign ff f e -> Foreign ff (cvtF' f) (cvtE' e)
cvtT :: Tuple (PreOpenExp acc env aenv') t -> Tuple (PreOpenExp acc env aenv') t
cvtT NilTup = NilTup
cvtT (SnocTup tup e) = cvtT tup `SnocTup` cvtE' e
-- into / into2 / into3: sink the given scalar argument(s) through an
-- environment extension and then apply the operation to them.
into :: Sink f => (f env' a -> b) -> f env a -> Extend acc env env' -> b
into op a env = op (sink env a)
into2 :: (Sink f1, Sink f2)
=> (f1 env' a -> f2 env' b -> c) -> f1 env a -> f2 env b -> Extend acc env env' -> c
into2 op a b env = op (sink env a) (sink env b)
into3 :: (Sink f1, Sink f2, Sink f3)
=> (f1 env' a -> f2 env' b -> f3 env' c -> d) -> f1 env a -> f2 env b -> f3 env c -> Extend acc env env' -> d
into3 op a b c env = op (sink env a) (sink env b) (sink env c)
-- Delay the argument and apply the operation to its Cunctation, keeping
-- the environment extension produced by delaying.
fuse :: Arrays as
=> (forall aenv'. Extend acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs)
-> acc aenv as
-> Delayed acc aenv bs
fuse op (delayAcc -> Term env cc) = Term env (op env cc)
-- Two-argument version of fuse. The second argument is sunk under the
-- extension of the first before it is delayed, and the first Cunctation is
-- then sunk under the extension of the second, so that both live in the
-- joined environment.
fuse2 :: (Arrays as, Arrays bs)
=> (forall aenv'. Extend acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs -> Cunctation acc aenv' cs)
-> acc aenv as
-> acc aenv bs
-> Delayed acc aenv cs
fuse2 op a1 a0
| Term env1 cc1 <- delayAcc a1
, Term env0 cc0 <- delayAcc (sink env1 a0)
, env <- env1 `join` env0
= Term env (op env (sink env0 cc1) cc0)
-- Delay the argument, build the operation node on top of it, push that node
-- onto the environment extension and return a reference to it. A Done
-- argument is passed as a variable; Step and Yield arguments are turned
-- into array terms with computeAcc in an empty extension.
embed :: (Arrays as, Arrays bs)
=> (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> PreOpenAcc acc aenv' bs)
-> acc aenv as
-> Delayed acc aenv bs
embed op (delayAcc -> Term env cc) = case cc of
Done v -> Term (env `PushEnv` op env (avarIn v)) (Done ZeroIdx)
Step sh p f v -> Term (env `PushEnv` op env (computeAcc (Term BaseEnv (Step sh p f v)))) (Done ZeroIdx)
Yield sh f -> Term (env `PushEnv` op env (computeAcc (Term BaseEnv (Yield sh f)))) (Done ZeroIdx)
-- Two-argument version of embed. The first argument is always referred to
-- by variable: if it is not already Done, its computed form is pushed onto
-- the extension and referenced as ZeroIdx. The second argument is handled
-- as in embed, inside the extension of the first.
embed2 :: forall aenv as bs cs. (Arrays as, Arrays bs, Arrays cs)
=> (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> acc aenv' bs -> PreOpenAcc acc aenv' cs)
-> acc aenv as
-> acc aenv bs
-> Delayed acc aenv cs
embed2 op (delayAcc -> Term env1 cc1) a0 = case cc1 of
Done v -> inner env1 v a0
Step sh p f v -> inner (env1 `PushEnv` compute (Term BaseEnv (Step sh p f v))) ZeroIdx a0
Yield sh f -> inner (env1 `PushEnv` compute (Term BaseEnv (Yield sh f))) ZeroIdx a0
where
inner :: Extend acc aenv aenv' -> Idx aenv' as -> acc aenv bs -> Delayed acc aenv cs
inner env1 v1 (delayAcc . sink env1 -> Term env0 cc0) = case cc0 of
Done v0 -> let env = env1 `join` env0 in Term (env `PushEnv` op env (avarIn (sink env0 v1)) (avarIn v0)) (Done ZeroIdx)
Step sh p f v -> let env = env1 `join` env0 in Term (env `PushEnv` op env (avarIn (sink env0 v1)) (computeAcc (Term BaseEnv (Step sh p f v)))) (Done ZeroIdx)
Yield sh f -> let env = env1 `join` env0 in Term (env `PushEnv` op env (avarIn (sink env0 v1)) (computeAcc (Term BaseEnv (Yield sh f)))) (Done ZeroIdx)
-- | A delayed array term: an extension of the array environment from @aenv@
-- to some (hidden) @aenv'@, together with a 'Cunctation' that is valid in the
-- extended environment.
data Delayed acc aenv a where
Term :: Extend acc aenv aenv'
-> Cunctation acc aenv' a
-> Delayed acc aenv a
-- | The three forms a delayed array can take.
data Cunctation acc aenv a where
-- | A reference to an array (or tuple of arrays) variable in the
-- environment.
Done :: Arrays a
=> Idx aenv a
-> Cunctation acc aenv a
-- | An array described by its shape and a function from index to element.
Yield :: (Shape sh, Elt e)
=> PreExp acc aenv sh
-> PreFun acc aenv (sh -> e)
-> Cunctation acc aenv (Array sh e)
-- | An array variable viewed through a new shape, an index transformation
-- from the new index space to that of the variable, and a function applied
-- to each element read from the variable.
Step :: (Shape sh, Shape sh', Elt a, Elt b)
=> PreExp acc aenv sh'
-> PreFun acc aenv (sh' -> sh)
-> PreFun acc aenv (a -> b)
-> Idx aenv (Array sh a)
-> Cunctation acc aenv (Array sh' b)
-- | Wrap a finished array term as a 'Delayed' value. A bare variable is
-- referenced directly in an empty extension; any other term is pushed onto
-- the extension and referenced as 'ZeroIdx'.
done :: Arrays a => PreOpenAcc acc aenv a -> Delayed acc aenv a
done pacc =
  case pacc of
    Avar ix -> Term BaseEnv (Done ix)
    _       -> Term (PushEnv BaseEnv pacc) (Done ZeroIdx)
-- | View any delayed array in 'Yield' form. A 'Yield' is returned as is; a
-- 'Step' has its element function, the array lookup and its index
-- transformation composed into one function; a 'Done' variable becomes its
-- shape plus a direct indexing function.
yield :: Kit acc
      => Cunctation acc aenv (Array sh e)
      -> Cunctation acc aenv (Array sh e)
yield cc@Yield{}       = cc
yield (Step sh p f ix) = Yield sh (f `compose` indexArray ix `compose` p)
yield cc@(Done ix)
  | ArraysRarray <- accType' cc = Yield (arrayShape ix) (indexArray ix)
  | otherwise                   = error "yield: impossible case"
-- | Try to view a delayed array in 'Step' form. A 'Yield' has no underlying
-- array variable, so the answer is 'Nothing'; a 'Step' is returned as is; a
-- 'Done' variable becomes a 'Step' with identity index and element functions.
step :: Kit acc
     => Cunctation acc aenv (Array sh e)
     -> Maybe (Cunctation acc aenv (Array sh e))
step Yield{}   = Nothing
step cc@Step{} = Just cc
step cc@(Done ix)
  | ArraysRarray <- accType' cc = Just (Step (arrayShape ix) identity identity ix)
  | otherwise                   = error "step: impossible case"
-- | The shape expression of a delayed array: taken from its 'Step' view when
-- one exists, otherwise from its 'Yield' view.
shape :: Kit acc => Cunctation acc aenv (Array sh e) -> PreExp acc aenv sh
shape cc =
  case step cc of
    Just (Step extent _ _ _) -> extent
    _                        ->
      case yield cc of
        Yield extent _ -> extent
-- | Obtain the array-representation witness for the result type of a
-- 'Cunctation'. The argument is ignored; only its type is used, via an
-- undefined proxy value passed to @arrays'@.
accType' :: forall acc aenv a. Arrays a => Cunctation acc aenv a -> ArraysR (ArrRepr' a)
accType' _cc = arrays' proxy
  where
    proxy :: a
    proxy = undefined
-- | An extension of an array environment @aenv@ to @aenv'@: a sequence of
-- array terms, each of which is typed in the environment built so far and
-- adds one more entry on top of it.
data Extend acc aenv aenv' where
-- | No additional bindings.
BaseEnv :: Extend acc aenv aenv
-- | One more binding on top of an existing extension.
PushEnv :: Arrays a
=> Extend acc aenv aenv' -> PreOpenAcc acc aenv' a -> Extend acc aenv (aenv', a)
-- | Concatenate two environment extensions: the bindings of the second are
-- pushed, in order, on top of the first.
join :: Extend acc env env' -> Extend acc env' env'' -> Extend acc env env''
join outer inner =
  case inner of
    BaseEnv          -> outer
    PushEnv rest top -> PushEnv (join outer rest) top
-- | Close a term over an environment extension by wrapping it in one 'Alet'
-- per binding of the extension, innermost binding first.
bind :: (Kit acc, Arrays a)
     => Extend acc aenv aenv'
     -> PreOpenAcc acc aenv' a
     -> PreOpenAcc acc aenv  a
bind env pacc =
  case env of
    BaseEnv          -> pacc
    PushEnv rest bnd -> bind rest (Alet (inject bnd) (inject pacc))
-- | Things that can be moved from an array environment into an extension of
-- that environment.
class Sink f where
  sink :: Extend acc env env' -> f env t -> f env' t

-- Each binding of the extension shifts the index by one.
instance Sink Idx where
  sink env =
    case env of
      BaseEnv        -> Stats.substitution "sink" id
      PushEnv rest _ -> SuccIdx . sink rest

instance Kit acc => Sink (PreOpenExp acc env) where
  sink env e = weakenEA rebuildAcc (sink env) e

instance Kit acc => Sink (PreOpenFun acc env) where
  sink env f = weakenFA rebuildAcc (sink env) f

instance Kit acc => Sink (PreOpenAcc acc) where
  sink env pacc = weakenA rebuildAcc (sink env) pacc

instance Kit acc => Sink acc where
  sink env a = rebuildAcc (Avar . sink env) a

-- Sink every component of the delayed representation.
instance Kit acc => Sink (Cunctation acc) where
  sink env (Done v)        = Done  (sink env v)
  sink env (Step sh p f v) = Step  (sink env sh) (sink env p) (sink env f) (sink env v)
  sink env (Yield sh f)    = Yield (sink env sh) (sink env f)
-- | Like 'Sink', but for things typed in an environment with one extra
-- binding on top: the extension is inserted underneath that top binding.
class Sink1 f where
  sink1 :: Extend acc env env' -> f (env,s) t -> f (env',s) t

instance Sink1 Idx where
  sink1 env =
    case env of
      BaseEnv        -> Stats.substitution "sink1" id
      PushEnv rest _ -> skipOne . sink1 rest
    where
      -- The top index stays on top; every other index moves down by one.
      skipOne :: Idx (env,s) t -> Idx ((env,u),s) t
      skipOne ZeroIdx      = ZeroIdx
      skipOne (SuccIdx ix) = SuccIdx (SuccIdx ix)

instance Kit acc => Sink1 (PreOpenExp acc env) where
  sink1 env e = weakenEA rebuildAcc (sink1 env) e

instance Kit acc => Sink1 (PreOpenFun acc env) where
  sink1 env f = weakenFA rebuildAcc (sink1 env) f

instance Kit acc => Sink1 (PreOpenAcc acc) where
  sink1 env pacc = weakenA rebuildAcc (sink1 env) pacc

instance Kit acc => Sink1 acc where
  sink1 env a = rebuildAcc (Avar . sink1 env) a
-- | Turn a delayed term back into an array term, binding the environment
-- extension around it with 'bind'.
--
-- A 'Yield' becomes a Generate. A 'Step' is emitted as the cheapest node
-- that the simplified pieces allow: the bare variable, a Map, a Backpermute,
-- or a Transform.
compute :: (Kit acc, Arrays arrs) => Delayed acc aenv arrs -> PreOpenAcc acc aenv arrs
compute (Term env cc)
  = bind env
  $ case cc of
      Done v                                    -> Avar v
      Yield (simplify -> sh) (simplify -> f)    -> Generate sh f
      Step (simplify -> sh) (simplify -> p) (simplify -> f) v
        -- same shape, identity index transform, identity element function
        | Just REFL <- match sh (arrayShape v)
        , Just REFL <- isIdentity p
        , Just REFL <- isIdentity f             -> Avar v
        -- same shape and identity index transform
        | Just REFL <- match sh (arrayShape v)
        , Just REFL <- isIdentity p             -> Map f (avarIn v)
        -- identity element function only
        | Just REFL <- isIdentity f             -> Backpermute sh p (avarIn v)
        | otherwise                             -> Transform sh p f (avarIn v)
-- | 'compute' followed by 'inject', yielding a term of the @acc@ type.
computeAcc :: (Kit acc, Arrays arrs) => Delayed acc aenv arrs -> acc aenv arrs
computeAcc delayed = inject (compute delayed)
-- | A generate is already in delayed form: a 'Yield' of its shape and
-- function in an empty environment extension.
generateD :: (Shape sh, Elt e)
          => PreExp acc aenv sh
          -> PreFun acc aenv (sh -> e)
          -> Delayed acc aenv (Array sh e)
generateD sh f = Stats.ruleFired "generateD" (Term BaseEnv (Yield sh f))
-- | Fuse a map into a delayed array by composing the mapped function onto the
-- element function of the 'Step' view when there is one, and onto the
-- generator of the 'Yield' view otherwise.
mapD :: (Kit acc, Elt b)
     => PreFun acc aenv (a -> b)
     -> Cunctation acc aenv (Array sh a)
     -> Cunctation acc aenv (Array sh b)
mapD f cc = Stats.ruleFired "mapD" (fuseMap cc)
  where
    fuseMap (step  -> Just (Step sh perm g v)) = Step sh perm (f `compose` g) v
    fuseMap (yield -> Yield sh g)              = Yield sh (f `compose` g)
-- | Fuse a backwards permutation into a delayed array. The new shape replaces
-- the old one, and the permutation is composed in front of the existing
-- index transformation ('Step' view) or generator ('Yield' view).
backpermuteD
    :: (Kit acc, Shape sh')
    => PreExp acc aenv sh'
    -> PreFun acc aenv (sh' -> sh)
    -> Cunctation acc aenv (Array sh  e)
    -> Cunctation acc aenv (Array sh' e)
backpermuteD sh' p cc = Stats.ruleFired "backpermuteD" (fusePerm cc)
  where
    fusePerm (step  -> Just (Step _ q f v)) = Step sh' (q `compose` p) f v
    fusePerm (yield -> Yield _ g)           = Yield sh' (g `compose` p)
-- | Fuse a combined map and backwards permutation: 'mapD' is applied to the
-- delayed array first and 'backpermuteD' to its result.
transformD
    :: (Kit acc, Shape sh', Elt b)
    => PreExp acc aenv sh'
    -> PreFun acc aenv (sh' -> sh)
    -> PreFun acc aenv (a   -> b)
    -> Cunctation acc aenv (Array sh  a)
    -> Cunctation acc aenv (Array sh' b)
transformD sh' p f cc =
  Stats.ruleFired "transformD" (backpermuteD sh' p (mapD f cc))
-- | Fuse a replicate as a backwards permutation: the result shape is the
-- source shape expanded with IndexFull, and indices are mapped back with
-- 'extend'.
replicateD
    :: (Kit acc, Shape sh, Shape sl, Elt slix, Elt e)
    => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
    -> PreExp     acc aenv slix
    -> Cunctation acc aenv (Array sl e)
    -> Cunctation acc aenv (Array sh e)
replicateD sliceIndex slix cc =
  Stats.ruleFired "replicateD"
    (backpermuteD
       (IndexFull sliceIndex slix (shape cc))
       (extend sliceIndex slix)
       cc)
-- | Fuse a slice as a backwards permutation: the result shape is the source
-- shape reduced with IndexSlice, and indices are mapped back with 'restrict'.
sliceD
    :: (Kit acc, Shape sh, Shape sl, Elt slix, Elt e)
    => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
    -> PreExp     acc aenv slix
    -> Cunctation acc aenv (Array sh e)
    -> Cunctation acc aenv (Array sl e)
sliceD sliceIndex slix cc =
  Stats.ruleFired "sliceD"
    (backpermuteD
       (IndexSlice sliceIndex slix (shape cc))
       (restrict sliceIndex slix)
       cc)
-- | Fuse a reshape as a backwards permutation whose index function is the
-- 'reindex' between the source shape and the requested shape.
reshapeD
    :: (Kit acc, Shape sh, Shape sl)
    => PreExp     acc aenv sl
    -> Cunctation acc aenv (Array sh e)
    -> Cunctation acc aenv (Array sl e)
reshapeD sl cc =
  Stats.ruleFired "reshapeD"
    (backpermuteD sl (reindex (shape cc) sl) cc)
-- | Fuse a zipWith of two delayed arrays. The result shape is the
-- intersection of the two argument shapes.
zipWithD :: (Kit acc, Shape sh, Elt a, Elt b, Elt c)
=> PreFun acc aenv (a -> b -> c)
-> Cunctation acc aenv (Array sh a)
-> Cunctation acc aenv (Array sh b)
-> Cunctation acc aenv (Array sh c)
zipWithD f cc1 cc0
-- Both arguments have a Step view over the same array variable with the
-- same index transformation: keep the Step form and combine only the
-- element functions.
| Just (Step sh1 p1 f1 v1) <- step cc1
, Just (Step sh0 p0 f0 v0) <- step cc0
, Just REFL <- match v1 v0
, Just REFL <- match p1 p0
= Stats.ruleFired "zipWithD/step"
$ Step (sh1 `Intersect` sh0) p0 (combine f f1 f0) v0
-- Otherwise use the Yield view of both arguments and combine the
-- generator functions.
| Yield sh1 f1 <- yield cc1
, Yield sh0 f0 <- yield cc0
= Stats.ruleFired "zipWithD"
$ Yield (sh1 `Intersect` sh0) (combine f f1 f0)
where
-- Build the function that applies both argument functions to the same
-- input, let-binds the two results, and applies the binary function to
-- them. The binary function and the second body are weakened so that they
-- are typed under the additional scalar binders.
combine :: forall acc aenv a b c e. (Elt a, Elt b, Elt c)
=> PreFun acc aenv (a -> b -> c)
-> PreFun acc aenv (e -> a)
-> PreFun acc aenv (e -> b)
-> PreFun acc aenv (e -> c)
combine c ixa ixb
-- NOTE(review): the explicit annotation on the weakened function
-- presumably exists to fix the scalar environment type; confirm before
-- removing it.
| Lam (Lam (Body c')) <- weakenFE SuccIdx c :: PreOpenFun acc ((),e) aenv (a -> b -> c)
, Lam (Body ixa') <- ixa
, Lam (Body ixb') <- ixb
= Lam $ Body $ Let ixa' $ Let (weakenE SuccIdx ixb') c'
-- | Fusion for a let binding @let bnd in body@.
--
-- The bound term is delayed first. Two cases follow:
--
--  * let-floating: if the bound term delays to a plain variable ('Done'), that
--    variable is substituted for the top index of the body and the
--    environment extension of the binding is floated outwards.
--
--  * let-elimination: otherwise the body is delayed underneath the extension
--    of the binding, and the local @aletD'@ decides, using the 'ElimAcc'
--    predicate, whether the binding is kept or inlined into the body.
--
-- Arguments: the recursive delay function, the elimination predicate, the
-- bound term, and the body (typed with the bound array on top of the
-- environment).
aletD :: forall acc aenv arrs brrs. (Kit acc, Arrays arrs, Arrays brrs)
      => DelayAcc acc
      -> ElimAcc  acc
      -> acc aenv        arrs
      -> acc (aenv,arrs) brrs
      -> Delayed acc aenv brrs
aletD delayAcc elimAcc (delayAcc -> Term env1 cc1) acc0

  -- let-floating
  | Done v1             <- cc1
  , Term env0 cc0       <- delayAcc $ rebuildAcc (subTop (Avar v1) . sink1 env1) acc0
  = Stats.ruleFired "aletD/float"
  $ Term (env1 `join` env0) cc0

  -- let-elimination. Matching on the constructor here exposes the fact that
  -- the bound term is a single array, which the helper requires.
  | otherwise
  , Term env0 cc0       <- delayAcc $ sink1 env1 acc0
  = case cc1 of
      Step{}    -> aletD' env1 cc1 env0 cc0
      Yield{}   -> aletD' env1 cc1 env0 cc0

  where
    -- Substitute a term for the top-most array variable and shift every
    -- other variable down by one.
    subTop :: forall aenv s t. Arrays t => PreOpenAcc acc aenv s -> Idx (aenv,s) t -> PreOpenAcc acc aenv t
    subTop t ZeroIdx       = t
    subTop _ (SuccIdx idx) = Avar idx

    -- Decide whether to keep or inline the delayed binding cc1 (valid under
    -- env1) given the delayed body cc0 (valid under env0, which extends the
    -- environment that has the bound array on top).
    aletD' :: forall aenv aenv' aenv'' sh e brrs. (Kit acc, Shape sh, Elt e, Arrays brrs)
           => Extend     acc aenv                aenv'
           -> Cunctation acc      aenv'          (Array sh e)
           -> Extend     acc     (aenv', Array sh e) aenv''
           -> Cunctation acc                     aenv'' brrs
           -> Delayed    acc aenv                brrs
    aletD' env1 cc1 env0 cc0

      -- Keep the binding: push the computed bound term between the two
      -- environment extensions.
      | not shouldInline
      = Term (env1 `PushEnv` bnd `join` env0) cc0

      -- This guard is only evaluated for the statistics call; with a False
      -- argument the alternative is not expected to be taken.
      -- NOTE(review): assumes Stats.ruleFired returns its second argument.
      | Stats.ruleFired "aletD/eliminate" False
      = undefined

      -- Inline the binding, described by its shape and indexing function.
      | Done v1           <- cc1 = eliminate (arrayShape v1) (indexArray v1)
      | Step sh1 p1 f1 v1 <- cc1 = eliminate sh1 (f1 `compose` indexArray v1 `compose` p1)
      | Yield sh1 f1      <- cc1 = eliminate sh1 f1

      where
        shouldInline    = elimAcc ZeroIdx body
        body            = computeAcc (Term env0 cc0)
        bnd             = compute (Term BaseEnv cc1)

        -- Replace uses of the bound array in the body by the given shape and
        -- index function, substitute the bound term for any remaining
        -- reference to the top variable, and delay the result again.
        eliminate :: PreExp  acc aenv' sh
                  -> PreFun  acc aenv' (sh -> e)
                  -> Delayed acc aenv  brrs
        eliminate sh1 f1
          | sh1'                <- weakenEA rebuildAcc SuccIdx sh1
          , f1'                 <- weakenFA rebuildAcc SuccIdx f1
          , Term env0' cc0'     <- delayAcc $ rebuildAcc (subTop bnd) $ kmap (replaceA sh1' f1' ZeroIdx) body
          = Term (env1 `join` env0') cc0'

        -- Within a scalar expression, replace shape queries and indexing
        -- operations on the array variable avar by the expression sh' and
        -- the function f' respectively.
        replaceE :: forall env aenv sh e t. (Kit acc, Shape sh, Elt e)
                 => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> e) -> Idx aenv (Array sh e)
                 -> PreOpenExp acc env aenv t
                 -> PreOpenExp acc env aenv t
        replaceE sh' f' avar exp =
          case exp of
            -- Under a scalar binder the replacement terms must be weakened.
            Let x y                         -> Let (cvtE x) (replaceE (weakenE SuccIdx sh') (weakenFE SuccIdx f') avar y)
            Var i                           -> Var i
            Foreign ff f e                  -> Foreign ff f (cvtE e)
            Const c                         -> Const c
            Tuple t                         -> Tuple (cvtT t)
            Prj ix e                        -> Prj ix (cvtE e)
            IndexNil                        -> IndexNil
            IndexCons sl sz                 -> IndexCons (cvtE sl) (cvtE sz)
            IndexHead sh                    -> IndexHead (cvtE sh)
            IndexTail sz                    -> IndexTail (cvtE sz)
            IndexAny                        -> IndexAny
            IndexSlice x ix sh              -> IndexSlice x (cvtE ix) (cvtE sh)
            IndexFull x ix sl               -> IndexFull x (cvtE ix) (cvtE sl)
            ToIndex sh ix                   -> ToIndex (cvtE sh) (cvtE ix)
            FromIndex sh i                  -> FromIndex (cvtE sh) (cvtE i)
            Cond p t e                      -> Cond (cvtE p) (cvtE t) (cvtE e)
            Iterate n f x                   -> Iterate (cvtE n) (replaceE (weakenE SuccIdx sh') (weakenFE SuccIdx f') avar f) (cvtE x)
            PrimConst c                     -> PrimConst c
            PrimApp g x                     -> PrimApp g (cvtE x)
            ShapeSize sh                    -> ShapeSize (cvtE sh)
            Intersect sh sl                 -> Intersect (cvtE sh) (cvtE sl)
            Shape a
              | Just REFL <- match a a'     -> Stats.substitution "replaceE/shape" sh'
              | otherwise                   -> exp

            -- The index expression may itself mention the array being
            -- replaced (for example through its shape), so it is converted
            -- before being bound as the argument of the inlined function
            -- body. The non-matching alternative converts it as well.
            Index a sh
              | Just REFL    <- match a a'
              , Lam (Body b) <- f'          -> Stats.substitution "replaceE/!" $ Let (cvtE sh) b
              | otherwise                   -> Index a (cvtE sh)

            -- As above; the linear index is first turned into a shape index
            -- with FromIndex over the replacement shape.
            LinearIndex a i
              | Just REFL    <- match a a'
              , Lam (Body b) <- f'          -> Stats.substitution "replaceE/!!" $ Let (Let (cvtE i) (FromIndex (weakenE SuccIdx sh') (Var ZeroIdx))) b
              | otherwise                   -> LinearIndex a (cvtE i)

          where
            a' = avarIn avar

            cvtE :: PreOpenExp acc env aenv s -> PreOpenExp acc env aenv s
            cvtE = replaceE sh' f' avar

            cvtT :: Tuple (PreOpenExp acc env aenv) s -> Tuple (PreOpenExp acc env aenv) s
            cvtT NilTup        = NilTup
            cvtT (SnocTup t e) = cvtT t `SnocTup` cvtE e

        -- Lift 'replaceE' to scalar functions, weakening the replacement
        -- terms under each lambda.
        replaceF :: forall env aenv sh e t. (Kit acc, Shape sh, Elt e)
                 => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> e) -> Idx aenv (Array sh e)
                 -> PreOpenFun acc env aenv t
                 -> PreOpenFun acc env aenv t
        replaceF sh' f' avar fun =
          case fun of
            Body e -> Body (replaceE sh' f' avar e)
            Lam f  -> Lam  (replaceF (weakenE SuccIdx sh') (weakenFE SuccIdx f') avar f)

        -- Apply the replacement to every scalar expression and function in
        -- an array term. Array-level references to the variable are left as
        -- they are.
        replaceA :: forall aenv sh e a. (Kit acc, Shape sh, Elt e)
                 => PreExp acc aenv sh -> PreFun acc aenv (sh -> e) -> Idx aenv (Array sh e)
                 -> PreOpenAcc acc aenv a
                 -> PreOpenAcc acc aenv a
        replaceA sh' f' avar pacc =
          case pacc of
            Avar v
              | Just REFL <- match v avar   -> Avar avar
              | otherwise                   -> Avar v

            -- Under an array binder both the replacement terms and the
            -- variable being replaced are shifted by one.
            Alet bnd body                   ->
              let sh'' = weakenEA rebuildAcc SuccIdx sh'
                  f''  = weakenFA rebuildAcc SuccIdx f'
              in
              Alet (cvtA bnd) (kmap (replaceA sh'' f'' (SuccIdx avar)) body)

            Use arrs                -> Use arrs
            Unit e                  -> Unit (cvtE e)
            Acond p at ae           -> Acond (cvtE p) (cvtA at) (cvtA ae)
            Aprj ix tup             -> Aprj ix (cvtA tup)
            Atuple tup              -> Atuple (cvtAT tup)
            Apply f a               -> Apply f (cvtA a)
            Aforeign ff f a         -> Aforeign ff f (cvtA a)
            Generate sh f           -> Generate (cvtE sh) (cvtF f)
            Map f a                 -> Map (cvtF f) (cvtA a)
            ZipWith f a b           -> ZipWith (cvtF f) (cvtA a) (cvtA b)
            Backpermute sh p a      -> Backpermute (cvtE sh) (cvtF p) (cvtA a)
            Transform sh p f a      -> Transform (cvtE sh) (cvtF p) (cvtF f) (cvtA a)
            Slice slix a sl         -> Slice slix (cvtA a) (cvtE sl)
            Replicate slix sh a     -> Replicate slix (cvtE sh) (cvtA a)
            Reshape sl a            -> Reshape (cvtE sl) (cvtA a)
            Fold f z a              -> Fold (cvtF f) (cvtE z) (cvtA a)
            Fold1 f a               -> Fold1 (cvtF f) (cvtA a)
            FoldSeg f z a s         -> FoldSeg (cvtF f) (cvtE z) (cvtA a) (cvtA s)
            Fold1Seg f a s          -> Fold1Seg (cvtF f) (cvtA a) (cvtA s)
            Scanl f z a             -> Scanl (cvtF f) (cvtE z) (cvtA a)
            Scanl1 f a              -> Scanl1 (cvtF f) (cvtA a)
            Scanl' f z a            -> Scanl' (cvtF f) (cvtE z) (cvtA a)
            Scanr f z a             -> Scanr (cvtF f) (cvtE z) (cvtA a)
            Scanr1 f a              -> Scanr1 (cvtF f) (cvtA a)
            Scanr' f z a            -> Scanr' (cvtF f) (cvtE z) (cvtA a)
            Permute f d p a         -> Permute (cvtF f) (cvtA d) (cvtF p) (cvtA a)
            Stencil f x a           -> Stencil (cvtF f) x (cvtA a)
            Stencil2 f x a y b      -> Stencil2 (cvtF f) x (cvtA a) y (cvtA b)

          where
            cvtA :: acc aenv s -> acc aenv s
            cvtA = kmap (replaceA sh' f' avar)

            cvtE :: PreExp acc aenv s -> PreExp acc aenv s
            cvtE = replaceE sh' f' avar

            cvtF :: PreFun acc aenv s -> PreFun acc aenv s
            cvtF = replaceF sh' f' avar

            cvtAT :: Atuple (acc aenv) s -> Atuple (acc aenv) s
            cvtAT NilAtup          = NilAtup
            cvtAT (SnocAtup tup a) = cvtAT tup `SnocAtup` cvtA a
-- | Fusion for array-level conditionals. A constant predicate selects its
-- branch directly, and two matching branches collapse to one; otherwise the
-- conditional is rebuilt from the computed branches and wrapped with 'done'.
acondD :: (Kit acc, Arrays arrs)
       => DelayAcc acc
       -> PreExp acc aenv Bool
       -> acc aenv arrs
       -> acc aenv arrs
       -> Delayed acc aenv arrs
acondD delayAcc p t e =
  case p of
    Const ((),True)     -> Stats.knownBranch "True"  (delayAcc t)
    Const ((),False)    -> Stats.knownBranch "False" (delayAcc e)
    _ | Just REFL <- match t e
                        -> Stats.knownBranch "redundant" (delayAcc e)
      | otherwise
                        -> done (Acond p (computeAcc (delayAcc t))
                                         (computeAcc (delayAcc e)))
-- | Fusion for projection out of a tuple of arrays. If the tuple term is a
-- literal tuple, the selected component is delayed directly; otherwise the
-- projection is rebuilt around the converted tuple term and wrapped with
-- 'done'.
aprjD :: forall acc aenv arrs a. (Kit acc, IsTuple arrs, Arrays arrs, Arrays a)
      => DelayAcc acc
      -> TupleIdx (TupleRepr arrs) a
      -> acc aenv arrs
      -> Delayed acc aenv a
aprjD delayAcc ix a =
  case extract a of
    Atuple tup -> Stats.ruleFired "aprj/Atuple" (delayAcc (select ix tup))
    _          -> done (Aprj ix (computeAcc (delayAcc a)))
  where
    -- Walk the tuple from the right until the requested component is found.
    select :: TupleIdx atup a -> Atuple (acc aenv) atup -> acc aenv a
    select ZeroTupIdx       (SnocAtup _ x)    = x
    select (SuccTupIdx ix') (SnocAtup rest _) = select ix' rest
-- | Recognise the syntactic identity function @\x -> x@, returning a proof
-- that argument and result types coincide.
isIdentity :: PreFun acc aenv (a -> b) -> Maybe (a :=: b)
isIdentity f =
  case f of
    Lam (Body (Var ZeroIdx)) -> Just REFL
    _                        -> Nothing
-- | The identity function as a scalar term: a lambda returning its argument.
identity :: Elt a => PreOpenFun acc env aenv (a -> a)
identity = Lam (Body bound)
  where
    bound = Var ZeroIdx
-- | A scalar function converting a multidimensional index into a linear
-- index with respect to the given shape. The shape is weakened because it is
-- used underneath the new lambda.
toIndex :: Shape sh => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> Int)
toIndex sh = Lam (Body linear)
  where
    linear = ToIndex (weakenE SuccIdx sh) (Var ZeroIdx)
-- | A scalar function converting a linear index into a multidimensional
-- index with respect to the given shape. The shape is weakened because it is
-- used underneath the new lambda.
fromIndex :: Shape sh => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (Int -> sh)
fromIndex sh = Lam (Body multi)
  where
    multi = FromIndex (weakenE SuccIdx sh) (Var ZeroIdx)
-- | Index conversion between two shapes. When the two shape expressions
-- match this is the identity; otherwise the index is linearised with respect
-- to the second shape and rebuilt with respect to the first.
reindex :: (Kit acc, Shape sh, Shape sh')
        => PreOpenExp acc env aenv sh'
        -> PreOpenExp acc env aenv sh
        -> PreOpenFun acc env aenv (sh -> sh')
reindex sh' sh =
  case match sh sh' of
    Just REFL -> identity
    _         -> fromIndex sh' `compose` toIndex sh
-- | A scalar function applying IndexSlice, with the given slice descriptor,
-- to its argument. The slice expression is weakened because it is used
-- underneath the new lambda.
extend :: (Shape sh, Shape sl, Elt slix)
       => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
       -> PreExp acc aenv slix
       -> PreFun acc aenv (sh -> sl)
extend sliceIndex slix = Lam (Body sliced)
  where
    sliced = IndexSlice sliceIndex (weakenE SuccIdx slix) (Var ZeroIdx)
-- | A scalar function applying IndexFull, with the given slice descriptor,
-- to its argument. The slice expression is weakened because it is used
-- underneath the new lambda.
restrict :: (Shape sh, Shape sl, Elt slix)
         => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
         -> PreExp acc aenv slix
         -> PreFun acc aenv (sl -> sh)
restrict sliceIndex slix = Lam (Body full)
  where
    full = IndexFull sliceIndex (weakenE SuccIdx slix) (Var ZeroIdx)
-- | The shape of the array bound to the given environment variable.
arrayShape :: (Kit acc, Shape sh, Elt e) => Idx aenv (Array sh e) -> PreExp acc aenv sh
arrayShape v = Shape (avarIn v)
-- | A scalar function that reads the array bound to the given environment
-- variable at a multidimensional index.
indexArray :: (Kit acc, Shape sh, Elt e) => Idx aenv (Array sh e) -> PreFun acc aenv (sh -> e)
indexArray v = Lam (Body lookupAt)
  where
    lookupAt = Index (avarIn v) (Var ZeroIdx)
-- | A scalar function that reads the array bound to the given environment
-- variable at a linear index.
linearIndex :: (Kit acc, Shape sh, Elt e) => Idx aenv (Array sh e) -> PreFun acc aenv (Int -> e)
linearIndex v = Lam (Body lookupAt)
  where
    lookupAt = LinearIndex (avarIn v) (Var ZeroIdx)