module Data.Array.Accelerate.Trafo.Fusion (
DelayedAcc, DelayedOpenAcc(..),
DelayedAfun, DelayedOpenAfun,
DelayedExp, DelayedFun, DelayedOpenExp, DelayedOpenFun,
convertAcc, convertAfun,
) where
import Prelude hiding ( exp, until )
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Trafo.Base
import Data.Array.Accelerate.Trafo.Shrink
import Data.Array.Accelerate.Trafo.Simplify
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) )
import Data.Array.Accelerate.Array.Sugar ( Array, Arrays(..), ArraysR(..), ArrRepr', Elt, EltRepr, Shape )
import Data.Array.Accelerate.Tuple
import qualified Data.Array.Accelerate.Debug as Stats
#ifdef ACCELERATE_DEBUG
import System.IO.Unsafe
#endif
#include "accelerate.h"
-- | Convert a closed array program into the delayed representation.
--
-- The 'Bool' is forwarded to the embedding pass and selects whether array
-- fusion is performed.
convertAcc :: Arrays arrs => Bool -> Acc arrs -> DelayedAcc arrs
convertAcc fuseAcc acc = withSimplStats (convertOpenAcc fuseAcc acc)
-- | Convert a closed array function into the delayed representation.
--
-- The 'Bool' is forwarded to the embedding pass and selects whether array
-- fusion is performed.
convertAfun :: Bool -> Afun f -> DelayedAfun f
convertAfun fuseAcc afun = withSimplStats (convertOpenAfun fuseAcc afun)
-- | Return the argument unchanged. In debug builds, the simplifier
-- statistics counters are reset first, at the moment the result is
-- demanded.
--
-- The reset action is written so that it depends on the argument. A
-- constant subexpression such as @unsafePerformIO Stats.resetSimplCount@
-- can be floated out of the function body into a top-level thunk by GHC's
-- full-laziness transformation. It would then run only once per program,
-- not once per conversion.
withSimplStats :: a -> a
#ifdef ACCELERATE_DEBUG
withSimplStats x = unsafePerformIO (Stats.resetSimplCount >> return x)
#else
withSimplStats x = x
#endif
-- | Convert an open array term into the delayed representation.
--
-- The conversion runs in three stages:
--
--   1. The term is run through the embedding pass ('embedOpenAcc').
--   2. The resulting 'Embed' is turned back into a term ('computeAcc').
--   3. That term is translated constructor-by-constructor into
--      'DelayedOpenAcc' ('manifest').
convertOpenAcc :: Arrays arrs => Bool -> OpenAcc aenv arrs -> DelayedOpenAcc aenv arrs
convertOpenAcc fuseAcc = manifest . computeAcc . embedOpenAcc fuseAcc
  where
    -- Convert an array argument that the enclosing operation reads
    -- element-wise into a 'Delayed' node. The node carries the extent, an
    -- index function and a linear-index function.
    --
    -- The view pattern only matches when embedding the argument produces no
    -- let-bound environment ('BaseEnv'). Any other result is a
    -- pattern-match failure.
    delayed :: (Shape sh, Elt e) => OpenAcc aenv (Array sh e) -> DelayedOpenAcc aenv (Array sh e)
    delayed (embedOpenAcc fuseAcc -> Embed BaseEnv cc) =
      case cc of
        Done v -> Delayed (arrayShape v) (indexArray v) (linearIndex v)
        Yield (cvtE -> sh) (cvtF -> f) -> Delayed sh f (f `compose` fromIndex sh)
        Step (cvtE -> sh) (cvtF -> p) (cvtF -> f) v
          -- The extent equals the source array's and the index mapping is the
          -- identity, so linear indexing can go straight to the source.
          | Just REFL <- match sh (arrayShape v)
          , Just REFL <- isIdentity p
          -> Delayed sh (f `compose` indexArray v) (f `compose` linearIndex v)
          -- General case: linear indices are converted back to
          -- multidimensional ones via 'fromIndex'.
          | f' <- f `compose` indexArray v `compose` p
          -> Delayed sh f' (f' `compose` fromIndex sh)

    -- Translate an embedded term into 'DelayedOpenAcc'. Arguments that the
    -- operation consumes element-wise are converted with 'delayed'. All
    -- other array subterms stay manifest.
    --
    -- Replicate, Slice and ZipWith are rejected with an internal error. The
    -- embedding pass ('embedPreAcc') rewrites them into delayed forms, so
    -- they are not expected here.
    manifest :: OpenAcc aenv a -> DelayedOpenAcc aenv a
    manifest (OpenAcc pacc) =
      let fusionError = INTERNAL_ERROR(error) "manifest" "unexpected fusible materials"
      in
      Manifest $ case pacc of
        -- Non-computation forms
        Avar ix                 -> Avar ix
        Use arr                 -> Use arr
        Unit e                  -> Unit (cvtE e)
        Alet bnd body           -> alet (manifest bnd) (manifest body)
        Acond p t e             -> Acond (cvtE p) (manifest t) (manifest e)
        Awhile p f a            -> Awhile (cvtAF p) (cvtAF f) (manifest a)
        Atuple tup              -> Atuple (cvtAT tup)
        Aprj ix tup             -> Aprj ix (manifest tup)
        Apply f a               -> Apply (cvtAF f) (manifest a)
        Aforeign ff f a         -> Aforeign ff (cvtAF f) (manifest a)

        -- Producers that survive embedding
        Map f a                 -> Map (cvtF f) (delayed a)
        Generate sh f           -> Generate (cvtE sh) (cvtF f)
        Transform sh p f a      -> Transform (cvtE sh) (cvtF p) (cvtF f) (delayed a)
        Backpermute sh p a      -> Backpermute (cvtE sh) (cvtF p) (delayed a)
        Reshape sl a            -> Reshape (cvtE sl) (manifest a)

        -- Producers the embedding pass is expected to have removed
        Replicate{}             -> fusionError
        Slice{}                 -> fusionError
        ZipWith{}               -> fusionError

        -- Consumers: element-wise inputs are delayed
        Fold f z a              -> Fold (cvtF f) (cvtE z) (delayed a)
        Fold1 f a               -> Fold1 (cvtF f) (delayed a)
        FoldSeg f z a s         -> FoldSeg (cvtF f) (cvtE z) (delayed a) (delayed s)
        Fold1Seg f a s          -> Fold1Seg (cvtF f) (delayed a) (delayed s)
        Scanl f z a             -> Scanl (cvtF f) (cvtE z) (delayed a)
        Scanl1 f a              -> Scanl1 (cvtF f) (delayed a)
        Scanl' f z a            -> Scanl' (cvtF f) (cvtE z) (delayed a)
        Scanr f z a             -> Scanr (cvtF f) (cvtE z) (delayed a)
        Scanr1 f a              -> Scanr1 (cvtF f) (delayed a)
        Scanr' f z a            -> Scanr' (cvtF f) (cvtE z) (delayed a)
        -- The default array of Permute and the arrays of the stencil
        -- operations are kept manifest.
        Permute f d p a         -> Permute (cvtF f) (manifest d) (cvtF p) (delayed a)
        Stencil f x a           -> Stencil (cvtF f) x (manifest a)
        Stencil2 f x a y b      -> Stencil2 (cvtF f) x (manifest a) y (manifest b)

    -- Build an array let. The trivial form @let x = bnd in x@ collapses to
    -- the bound term itself.
    alet bnd body
      | Manifest (Avar ZeroIdx) <- body
      , Manifest x <- bnd
      = x
      | otherwise
      = Alet bnd body

    -- Convert each component of an array tuple with 'manifest'.
    cvtAT :: Atuple (OpenAcc aenv) a -> Atuple (DelayedOpenAcc aenv) a
    cvtAT NilAtup = NilAtup
    cvtAT (SnocAtup t a) = cvtAT t `SnocAtup` manifest a

    -- Convert an array function by converting its body with 'manifest'.
    cvtAF :: OpenAfun aenv f -> PreOpenAfun DelayedOpenAcc aenv f
    cvtAF (Alam f) = Alam (cvtAF f)
    cvtAF (Abody b) = Abody (manifest b)

    -- Convert a scalar function by converting its body with 'cvtE'.
    cvtF :: OpenFun env aenv f -> DelayedOpenFun env aenv f
    cvtF (Lam f) = Lam (cvtF f)
    cvtF (Body b) = Body (cvtE b)

    -- Convert a scalar expression. The structure is copied unchanged; only
    -- the array terms embedded in 'Index', 'LinearIndex' and 'Shape' are
    -- translated with 'manifest'.
    cvtE :: OpenExp env aenv t -> DelayedOpenExp env aenv t
    cvtE exp =
      case exp of
        Let bnd body            -> Let (cvtE bnd) (cvtE body)
        Var ix                  -> Var ix
        Const c                 -> Const c
        Tuple tup               -> Tuple (cvtT tup)
        Prj ix t                -> Prj ix (cvtE t)
        IndexNil                -> IndexNil
        IndexCons sh sz         -> IndexCons (cvtE sh) (cvtE sz)
        IndexHead sh            -> IndexHead (cvtE sh)
        IndexTail sh            -> IndexTail (cvtE sh)
        IndexAny                -> IndexAny
        IndexSlice x ix sh      -> IndexSlice x (cvtE ix) (cvtE sh)
        IndexFull x ix sl       -> IndexFull x (cvtE ix) (cvtE sl)
        ToIndex sh ix           -> ToIndex (cvtE sh) (cvtE ix)
        FromIndex sh ix         -> FromIndex (cvtE sh) (cvtE ix)
        Cond p t e              -> Cond (cvtE p) (cvtE t) (cvtE e)
        While p f x             -> While (cvtF p) (cvtF f) (cvtE x)
        PrimConst c             -> PrimConst c
        PrimApp f x             -> PrimApp f (cvtE x)
        Index a sh              -> Index (manifest a) (cvtE sh)
        LinearIndex a i         -> LinearIndex (manifest a) (cvtE i)
        Shape a                 -> Shape (manifest a)
        ShapeSize sh            -> ShapeSize (cvtE sh)
        Intersect s t           -> Intersect (cvtE s) (cvtE t)
        Foreign ff f e          -> Foreign ff (cvtF f) (cvtE e)

    -- Convert each component of a scalar tuple with 'cvtE'.
    cvtT :: Tuple (OpenExp env aenv) t -> Tuple (DelayedOpenExp env aenv) t
    cvtT NilTup = NilTup
    cvtT (SnocTup t e) = cvtT t `SnocTup` cvtE e
-- | Convert an open array function into the delayed representation.
--
-- Every lambda is kept as is. The function body is converted with
-- 'convertOpenAcc'.
convertOpenAfun :: Bool -> OpenAfun aenv f -> DelayedOpenAfun aenv f
convertOpenAfun fuseAcc afun =
  case afun of
    Alam inner -> Alam (convertOpenAfun fuseAcc inner)
    Abody body -> Abody (convertOpenAcc fuseAcc body)
-- | An embedding function that works in any array environment. It turns an
-- array term into its 'Embed' form.
type EmbedAcc acc = forall env a. Arrays a => acc env a -> Embed acc env a

-- | The let-elimination predicate. It receives a bound term and the body
-- that scopes over it. A 'False' result means the binding is kept as a let.
type ElimAcc acc = forall env s t. acc env s -> acc (env,s) t -> Bool
-- | Run the embedding (fusion) pass over an 'OpenAcc' term.
--
-- 'elimOpenAcc' decides which let-bindings are eliminated.
embedOpenAcc :: Arrays arrs => Bool -> OpenAcc aenv arrs -> Embed OpenAcc aenv arrs
embedOpenAcc fuseAcc (OpenAcc pacc) =
  embedPreAcc fuseAcc (embedOpenAcc fuseAcc) elimOpenAcc pacc
  where
    -- A binding is eliminated in either of two cases:
    --
    --   * It is a 'Map' of a scalar tuple projection over an array variable
    --     (the "unzipD" rule).
    --   * The use count of the bound variable in the body, as computed by
    --     'usesOfPreAcc', is at most 'lIMIT'.
    elimOpenAcc :: ElimAcc OpenAcc
    elimOpenAcc bnd body
      | Map f a <- extract bnd
      , Avar _ <- extract a
      , Lam (Body (Prj _ _)) <- f
      = Stats.ruleFired "unzipD" True
      | otherwise
      = count False ZeroIdx body <= lIMIT
      where
        -- Largest use count for which a binding is still inlined.
        lIMIT = 1

        count :: UsesOfAcc OpenAcc
        count ok idx (OpenAcc pacc) = usesOfPreAcc ok count idx pacc
-- | The core of the embedding (fusion) pass. It is parameterised over the
-- array term representation @acc@.
--
-- Each kind of node is handled as follows:
--
--   * Producers ('Generate', 'Map', 'ZipWith', 'Transform', 'Backpermute',
--     'Slice', 'Replicate') become delayed 'Cunctation's, which later
--     operations can fuse into. 'Reshape' is handled by 'reshapeD'.
--   * Consumers ('Fold', the scans, 'Permute', the stencils) materialise
--     their array arguments as terms ('compute''). The operation itself is
--     then let-bound.
--   * Let-bindings, conditionals and tuple projections have dedicated
--     handlers ('aletD', 'acondD', 'aprjD').
--   * All remaining constructors only have their subterms converted and
--     are wrapped with 'done'.
--
-- When @fuseAcc@ is 'False', every node's result is materialised
-- immediately ('unembed') instead of being left delayed.
embedPreAcc
    :: forall acc aenv arrs. (Kit acc, Arrays arrs)
    => Bool
    -> EmbedAcc acc
    -> ElimAcc acc
    -> PreOpenAcc acc aenv arrs
    -> Embed acc aenv arrs
embedPreAcc fuseAcc embedAcc elimAcc pacc
  = unembed
  $ case pacc of
      -- Forms with dedicated handling, and opaque forms
      Alet bnd body           -> aletD embedAcc elimAcc bnd body
      Acond p at ae           -> acondD embedAcc (cvtE p) at ae
      Aprj ix tup             -> aprjD embedAcc ix tup
      Awhile p f a            -> done $ Awhile (cvtAF p) (cvtAF f) (cvtA a)
      Atuple tup              -> done $ Atuple (cvtAT tup)
      Apply f a               -> done $ Apply (cvtAF f) (cvtA a)
      Aforeign ff f a         -> done $ Aforeign ff (cvtAF f) (cvtA a)

      -- Array injection
      Avar v                  -> done $ Avar v
      Use arrs                -> done $ Use arrs
      Unit e                  -> done $ Unit (cvtE e)

      -- Producers: kept delayed so they can be fused
      Generate sh f           -> generateD (cvtE sh) (cvtF f)
      Map f a                 -> fuse (into mapD (cvtF f)) a
      ZipWith f a b           -> fuse2 (into zipWithD (cvtF f)) a b
      Transform sh p f a      -> fuse (into3 transformD (cvtE sh) (cvtF p) (cvtF f)) a
      Backpermute sl p a      -> fuse (into2 backpermuteD (cvtE sl) (cvtF p)) a
      Slice slix a sl         -> fuse (into (sliceD slix) (cvtE sl)) a
      Replicate slix sh a     -> fuse (into (replicateD slix) (cvtE sh)) a
      Reshape sl a            -> reshapeD (embedAcc a) (cvtE sl)

      -- Consumers: arguments are materialised and the result is let-bound
      Fold f z a              -> embed (into2 Fold (cvtF f) (cvtE z)) a
      Fold1 f a               -> embed (into Fold1 (cvtF f)) a
      FoldSeg f z a s         -> embed2 (into2 FoldSeg (cvtF f) (cvtE z)) a s
      Fold1Seg f a s          -> embed2 (into Fold1Seg (cvtF f)) a s
      Scanl f z a             -> embed (into2 Scanl (cvtF f) (cvtE z)) a
      Scanl1 f a              -> embed (into Scanl1 (cvtF f)) a
      Scanl' f z a            -> embed (into2 Scanl' (cvtF f) (cvtE z)) a
      Scanr f z a             -> embed (into2 Scanr (cvtF f) (cvtE z)) a
      Scanr1 f a              -> embed (into Scanr1 (cvtF f)) a
      Scanr' f z a            -> embed (into2 Scanr' (cvtF f) (cvtE z)) a
      Permute f d p a         -> embed2 (into2 permute (cvtF f) (cvtF p)) d a
      Stencil f x a           -> embed (into (stencil x) (cvtF f)) a
      Stencil2 f x a y b      -> embed2 (into (stencil2 x y) (cvtF f)) a b
  where
    -- With fusion disabled, materialise every node's result immediately.
    unembed :: Embed acc aenv arrs -> Embed acc aenv arrs
    unembed x
      | fuseAcc   = x
      | otherwise = done (compute x)

    -- Embed a subterm, then turn the result back into a term.
    cvtA :: Arrays a => acc aenv' a -> acc aenv' a
    cvtA = computeAcc . embedAcc

    cvtAT :: Atuple (acc aenv') a -> Atuple (acc aenv') a
    cvtAT NilAtup          = NilAtup
    cvtAT (SnocAtup tup a) = cvtAT tup `SnocAtup` cvtA a

    cvtAF :: PreOpenAfun acc aenv' f -> PreOpenAfun acc aenv' f
    cvtAF (Alam f)  = Alam (cvtAF f)
    cvtAF (Abody a) = Abody (cvtA a)

    -- Argument-order adaptors. 'into' / 'into2' supply the scalar arguments
    -- first, and 'embed' / 'embed2' then supply the arrays.
    permute f p d a     = Permute f d p a
    stencil x f a       = Stencil f x a
    stencil2 x y f a b  = Stencil2 f x a y b

    -- Simplify a scalar function or expression, then rebuild it with
    -- 'cvtF'' / 'cvtE''.
    cvtF :: PreFun acc aenv t -> PreFun acc aenv t
    cvtF = cvtF' . simplify

    cvtE :: PreExp acc aenv' t -> PreExp acc aenv' t
    cvtE = cvtE' . simplify

    cvtF' :: PreOpenFun acc env aenv' t -> PreOpenFun acc env aenv' t
    cvtF' (Lam f)  = Lam (cvtF' f)
    cvtF' (Body b) = Body (cvtE' b)

    -- Structural rebuild of a scalar expression. Array subterms ('Index',
    -- 'LinearIndex', 'Shape') are passed through untouched.
    -- NOTE(review): as written, every alternative rebuilds its node
    -- unchanged, so this traversal is an identity. Confirm whether it is
    -- meant as a hook for further conversion.
    cvtE' :: PreOpenExp acc env aenv' t -> PreOpenExp acc env aenv' t
    cvtE' exp =
      case exp of
        Let bnd body            -> Let (cvtE' bnd) (cvtE' body)
        Var ix                  -> Var ix
        Const c                 -> Const c
        Tuple tup               -> Tuple (cvtT tup)
        Prj ix t                -> Prj ix (cvtE' t)
        IndexNil                -> IndexNil
        IndexCons sh sz         -> IndexCons (cvtE' sh) (cvtE' sz)
        IndexHead sh            -> IndexHead (cvtE' sh)
        IndexTail sh            -> IndexTail (cvtE' sh)
        IndexAny                -> IndexAny
        IndexSlice x ix sh      -> IndexSlice x (cvtE' ix) (cvtE' sh)
        IndexFull x ix sl       -> IndexFull x (cvtE' ix) (cvtE' sl)
        ToIndex sh ix           -> ToIndex (cvtE' sh) (cvtE' ix)
        FromIndex sh ix         -> FromIndex (cvtE' sh) (cvtE' ix)
        Cond p t e              -> Cond (cvtE' p) (cvtE' t) (cvtE' e)
        While p f x             -> While (cvtF' p) (cvtF' f) (cvtE' x)
        PrimConst c             -> PrimConst c
        PrimApp f x             -> PrimApp f (cvtE' x)
        Index a sh              -> Index a (cvtE' sh)
        LinearIndex a i         -> LinearIndex a (cvtE' i)
        Shape a                 -> Shape a
        ShapeSize sh            -> ShapeSize (cvtE' sh)
        Intersect s t           -> Intersect (cvtE' s) (cvtE' t)
        Foreign ff f e          -> Foreign ff (cvtF' f) (cvtE' e)

    cvtT :: Tuple (PreOpenExp acc env aenv') t -> Tuple (PreOpenExp acc env aenv') t
    cvtT NilTup          = NilTup
    cvtT (SnocTup tup e) = cvtT tup `SnocTup` cvtE' e

    -- Sink the scalar argument(s) past the environment extension produced
    -- by embedding the array argument, then apply the operation.
    into :: Sink f => (f env' a -> b) -> f env a -> Extend acc env env' -> b
    into op a env = op (sink env a)

    into2 :: (Sink f1, Sink f2)
          => (f1 env' a -> f2 env' b -> c) -> f1 env a -> f2 env b -> Extend acc env env' -> c
    into2 op a b env = op (sink env a) (sink env b)

    into3 :: (Sink f1, Sink f2, Sink f3)
          => (f1 env' a -> f2 env' b -> f3 env' c -> d) -> f1 env a -> f2 env b -> f3 env c -> Extend acc env env' -> d
    into3 op a b c env = op (sink env a) (sink env b) (sink env c)

    -- Embed the argument array and apply a delayed producer to its
    -- 'Cunctation'. The argument's environment extension is kept.
    fuse :: Arrays as
         => (forall aenv'. Extend acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs)
         -> acc aenv as
         -> Embed acc aenv bs
    fuse op (embedAcc -> Embed env cc) = Embed env (op env cc)

    -- Two-argument version of 'fuse':
    --
    --   1. The second argument is sunk past the first argument's extension
    --      and then embedded.
    --   2. The two extensions are joined.
    --   3. The first argument's 'Cunctation' is sunk past the second
    --      extension.
    fuse2 :: (Arrays as, Arrays bs)
          => (forall aenv'. Extend acc aenv aenv' -> Cunctation acc aenv' as -> Cunctation acc aenv' bs -> Cunctation acc aenv' cs)
          -> acc aenv as
          -> acc aenv bs
          -> Embed acc aenv cs
    fuse2 op a1 a0
      | Embed env1 cc1 <- embedAcc a1
      , Embed env0 cc0 <- embedAcc (sink env1 a0)
      , env <- env1 `join` env0
      = Embed env (op env (sink env0 cc1) cc0)

    -- Embed the argument array, materialise it as a term, apply the
    -- consumer, and let-bind the consumer's result. The returned
    -- 'Cunctation' refers to that new binding.
    embed :: (Arrays as, Arrays bs)
          => (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> PreOpenAcc acc aenv' bs)
          -> acc aenv as
          -> Embed acc aenv bs
    embed op (embedAcc -> Embed env cc)
      = Embed (env `PushEnv` op env (inject (compute' cc))) (Done ZeroIdx)

    -- Two-argument version of 'embed'. The environment handling is the same
    -- as in 'fuse2'.
    embed2 :: forall aenv as bs cs. (Arrays as, Arrays bs, Arrays cs)
           => (forall aenv'. Extend acc aenv aenv' -> acc aenv' as -> acc aenv' bs -> PreOpenAcc acc aenv' cs)
           -> acc aenv as
           -> acc aenv bs
           -> Embed acc aenv cs
    embed2 op (embedAcc -> Embed env1 cc1) (embedAcc . sink env1 -> Embed env0 cc0)
      | env <- env1 `join` env0
      , acc1 <- inject . compute' $ sink env0 cc1
      , acc0 <- inject . compute' $ cc0
      = Embed (env `PushEnv` op env acc1 acc0) (Done ZeroIdx)
-- | Result of the embedding pass. It has two parts:
--
--   * a sequence of let-bound array terms ('Extend') that extends the
--     environment;
--   * a description ('Cunctation') of the result, valid in that extended
--     environment.
data Embed acc env t where
  Embed
    :: Extend acc env env'
    -> Cunctation acc env' t
    -> Embed acc env t
-- | A possibly-delayed description of an array computation.
data Cunctation acc env t where
  -- | The result is the array(s) bound at the given environment index.
  Done
    :: Arrays arrs
    => Idx env arrs
    -> Cunctation acc env arrs

  -- | An array described by its extent and a function from index to
  -- element.
  Yield
    :: (Shape ix, Elt el)
    => PreExp acc env ix
    -> PreFun acc env (ix -> el)
    -> Cunctation acc env (Array ix el)

  -- | An array described by four parts:
  --
  --   * its extent;
  --   * a mapping from result index to source index;
  --   * a function applied to each element read from the source;
  --   * the source array, bound in the environment.
  Step
    :: (Shape src, Shape dst, Elt a, Elt b)
    => PreExp acc env dst
    -> PreFun acc env (dst -> src)
    -> PreFun acc env (a -> b)
    -> Idx env (Array src a)
    -> Cunctation acc env (Array dst b)
-- | Simplify the scalar components of a delayed array. Environment indices
-- are left unchanged.
instance Kit acc => Simplify (Cunctation acc aenv a) where
  simplify cc =
    case cc of
      Done ix             -> Done ix
      Yield ext gen       -> Yield (simplify ext) (simplify gen)
      Step ext ixf fn src -> Step (simplify ext) (simplify ixf) (simplify fn) src
-- | Wrap an already-built array term as an 'Embed' with no delayed part.
--
-- A bare variable is referenced directly. Any other term is let-bound, and
-- the result refers to that new binding.
done :: Arrays a => PreOpenAcc acc aenv a -> Embed acc aenv a
done pacc =
  case pacc of
    Avar ix -> Embed BaseEnv (Done ix)
    _       -> Embed (PushEnv BaseEnv pacc) (Done ZeroIdx)
-- | View a delayed array in 'Yield' form: an extent plus an index-to-element
-- function.
--
-- * A 'Step' becomes the composition of its element function, the source
--   lookup and its index mapping.
-- * A 'Done' reads its elements straight from the bound array.
yield :: Kit acc
      => Cunctation acc aenv (Array sh e)
      -> Cunctation acc aenv (Array sh e)
yield cc@Yield{}      = cc
yield (Step sh p f v) = Yield sh (f `compose` indexArray v `compose` p)
yield cc@(Done v)
  | ArraysRarray <- accType' cc = Yield (arrayShape v) (indexArray v)
  | otherwise                   = error "yield: impossible case"
-- | Try to view a delayed array in 'Step' form.
--
-- * A 'Done' becomes a 'Step' with the identity index mapping and the
--   identity element function.
-- * A 'Yield' has no 'Step' form, so the result is 'Nothing'.
step :: Kit acc
     => Cunctation acc aenv (Array sh e)
     -> Maybe (Cunctation acc aenv (Array sh e))
step Yield{}   = Nothing
step cc@Step{} = Just cc
step cc@(Done v)
  | ArraysRarray <- accType' cc = Just (Step (arrayShape v) identity identity v)
  | otherwise                   = error "step: impossible case"
-- | Extract the extent of a delayed array. The 'Step' view is tried first,
-- then the 'Yield' view.
shape :: Kit acc => Cunctation acc aenv (Array sh e) -> PreExp acc aenv sh
shape cc =
  case step cc of
    Just (Step sh _ _ _) -> sh
    _ ->
      case yield cc of
        Yield sh _ -> sh
-- | Representation of the array type @a@ carried by a 'Cunctation'. The
-- argument is only used to fix the type and is never inspected.
accType' :: forall acc aenv a. Arrays a => Cunctation acc aenv a -> ArraysR (ArrRepr' a)
accType' = const (arrays' (undefined :: a))
-- | A sequence of array let-bindings that extends environment @env@ to
-- @env'@. Each 'PushEnv' adds one binding, whose right-hand side lives in
-- the environment built so far.
data Extend acc env env' where
  BaseEnv
    :: Extend acc env env
  PushEnv
    :: Arrays t
    => Extend acc env env'
    -> PreOpenAcc acc env' t
    -> Extend acc env (env', t)
-- | Concatenate two environment extensions: the bindings of @outer@,
-- followed by those of @inner@.
join :: Extend acc env env' -> Extend acc env' env'' -> Extend acc env env''
join outer inner =
  case inner of
    BaseEnv          -> outer
    PushEnv rest bnd -> PushEnv (join outer rest) bnd
-- | Wrap a term in the let-bindings recorded by an 'Extend'. The most
-- recently pushed binding becomes the innermost 'Alet'.
bind :: (Kit acc, Arrays a)
     => Extend acc aenv aenv'
     -> PreOpenAcc acc aenv' a
     -> PreOpenAcc acc aenv a
bind BaseEnv           body = body
bind (PushEnv env bnd) body = bind env (Alet (inject bnd) (inject body))
-- | Weaken a term from @env@ into the extended environment @env'@. Every
-- array index is shifted past all the bindings in the 'Extend'.
sink :: Sink f => Extend acc env env' -> f env t -> f env' t
sink ext term = weaken (shift ext) term
  where
    shift :: Extend acc env env' -> Idx env t -> Idx env' t
    shift BaseEnv          = Stats.substitution "sink" id
    shift (PushEnv rest _) = SuccIdx . shift rest
-- | Like 'sink', but for a term with one extra binding @s@ on top of its
-- environment. The topmost index ('ZeroIdx') stays fixed, and every other
-- index is shifted past the bindings in the 'Extend'.
sink1 :: Sink f => Extend acc env env' -> f (env,s) t -> f (env',s) t
sink1 ext term = weaken (shift ext) term
  where
    shift :: Extend acc env env' -> Idx (env,s) t -> Idx (env',s) t
    shift BaseEnv          = Stats.substitution "sink1" id
    shift (PushEnv rest _) = skipOver . shift rest

    -- Insert one binding underneath the topmost one.
    skipOver :: Idx (env,s) t -> Idx ((env,u),s) t
    skipOver ZeroIdx      = ZeroIdx
    skipOver (SuccIdx ix) = SuccIdx (SuccIdx ix)
-- | Types whose array-environment indices can be shifted by an environment
-- weakening @env :> env'@.
class Sink f where
  weaken :: env :> env' -> f env t -> f env' t

-- Array variables: apply the index transformation directly.
instance Sink Idx where
  weaken shift = shift

-- Scalar expressions and functions, and array terms: delegate to the
-- weakenEA / weakenFA / weakenA helpers, using rebuildAcc.
instance Kit acc => Sink (PreOpenExp acc env) where
  weaken shift = weakenEA rebuildAcc shift

instance Kit acc => Sink (PreOpenFun acc env) where
  weaken shift = weakenFA rebuildAcc shift

instance Kit acc => Sink (PreOpenAcc acc) where
  weaken shift = weakenA rebuildAcc shift

-- Any array term representation: rebuild it, replacing each variable
-- reference with the shifted 'Avar'.
instance Kit acc => Sink acc where
  weaken shift = rebuildAcc (Avar . shift)

-- Delayed arrays: weaken every component, including the source variable.
instance Kit acc => Sink (Cunctation acc) where
  weaken shift (Done v)        = Done (weaken shift v)
  weaken shift (Step sh p f v) = Step (weaken shift sh) (weaken shift p) (weaken shift f) (weaken shift v)
  weaken shift (Yield sh f)    = Yield (weaken shift sh) (weaken shift f)
-- | Materialise an 'Embed': turn the delayed part into a term and wrap it in
-- the recorded let-bindings.
compute :: (Kit acc, Arrays arrs) => Embed acc aenv arrs -> PreOpenAcc acc aenv arrs
compute embedded =
  case embedded of
    Embed env cc -> bind env (compute' cc)
-- | Turn a (simplified) delayed array into the cheapest matching term:
--
-- * an 'Avar' when nothing is delayed;
-- * 'Generate' for a 'Yield';
-- * for a 'Step', one of 'Avar', 'Map', 'Backpermute' or 'Transform',
--   depending on which of the extent, index mapping and element function
--   are trivial.
compute' :: (Kit acc, Arrays arrs) => Cunctation acc aenv arrs -> PreOpenAcc acc aenv arrs
compute' cc =
  case simplify cc of
    Done ix             -> Avar ix
    Yield ext gen       -> Generate ext gen
    Step ext ixf fn src
      -- Same extent as the source and an identity index mapping.
      | Just REFL <- match ext (arrayShape src)
      , Just REFL <- isIdentity ixf
      -> case isIdentity fn of
           Just REFL -> Avar src
           Nothing   -> Map fn (avarIn src)
      | Just REFL <- isIdentity fn -> Backpermute ext ixf (avarIn src)
      | otherwise                  -> Transform ext ixf fn (avarIn src)
-- | Materialise an 'Embed' and inject the result into the term
-- representation @acc@.
computeAcc :: (Kit acc, Arrays arrs) => Embed acc aenv arrs -> acc aenv arrs
computeAcc embedded = inject (compute embedded)
-- | Embed a 'Generate' as a 'Yield' with no extra bindings.
generateD :: (Shape sh, Elt e)
          => PreExp acc aenv sh
          -> PreFun acc aenv (sh -> e)
          -> Embed acc aenv (Array sh e)
generateD sh f = Stats.ruleFired "generateD" (Embed BaseEnv (Yield sh f))
-- | Fuse an element-wise function into a delayed array.
--
-- * If the array has a 'Step' view (a 'Step', or a 'Done' seen as an
--   identity 'Step'), the function is composed onto the element function.
-- * Otherwise the array is viewed as a 'Yield', and the function is
--   composed onto its generator.
mapD :: (Kit acc, Elt b)
     => PreFun acc aenv (a -> b)
     -> Cunctation acc aenv (Array sh a)
     -> Cunctation acc aenv (Array sh b)
mapD f = Stats.ruleFired "mapD" . fuseMap
  where
    fuseMap (step -> Just (Step ext ixf g src)) = Step ext ixf (f `compose` g) src
    fuseMap (yield -> Yield ext g)              = Yield ext (f `compose` g)
-- | Fuse a backwards permutation into a delayed array.
--
-- The result takes the new extent. The permutation @p@ is composed before
-- the existing index mapping (for a 'Step') or before the generator (for a
-- 'Yield').
backpermuteD
    :: (Kit acc, Shape sh')
    => PreExp acc aenv sh'
    -> PreFun acc aenv (sh' -> sh)
    -> Cunctation acc aenv (Array sh e)
    -> Cunctation acc aenv (Array sh' e)
backpermuteD sh' p = Stats.ruleFired "backpermuteD" . permuteIndices
  where
    permuteIndices (step -> Just (Step _ inner fn src)) = Step sh' (inner `compose` p) fn src
    permuteIndices (yield -> Yield _ gen)               = Yield sh' (gen `compose` p)
-- | Fuse a 'Transform'. This is the element-wise map ('mapD') followed by
-- the backwards permutation ('backpermuteD').
transformD
    :: (Kit acc, Shape sh', Elt b)
    => PreExp acc aenv sh'
    -> PreFun acc aenv (sh' -> sh)
    -> PreFun acc aenv (a -> b)
    -> Cunctation acc aenv (Array sh a)
    -> Cunctation acc aenv (Array sh' b)
transformD sh' p f cc =
  Stats.ruleFired "transformD" (backpermuteD sh' p (mapD f cc))
-- | Fuse a 'Replicate' as a backwards permutation:
--
-- * the result extent is the 'IndexFull' of the slice specification and
--   the source extent;
-- * each result index is mapped back to the source with 'extend'.
replicateD
    :: (Kit acc, Shape sh, Shape sl, Elt slix, Elt e)
    => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
    -> PreExp acc aenv slix
    -> Cunctation acc aenv (Array sl e)
    -> Cunctation acc aenv (Array sh e)
replicateD sliceIdx slixE src =
  Stats.ruleFired "replicateD"
    (backpermuteD (IndexFull sliceIdx slixE (shape src)) (extend sliceIdx slixE) src)
-- | Fuse a 'Slice' as a backwards permutation:
--
-- * the result extent is the 'IndexSlice' of the slice specification and
--   the source extent;
-- * each result index is mapped back to the source with 'restrict'.
sliceD
    :: (Kit acc, Shape sh, Shape sl, Elt slix, Elt e)
    => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
    -> PreExp acc aenv slix
    -> Cunctation acc aenv (Array sh e)
    -> Cunctation acc aenv (Array sl e)
sliceD sliceIdx slixE src =
  Stats.ruleFired "sliceD"
    (backpermuteD (IndexSlice sliceIdx slixE (shape src)) (restrict sliceIdx slixE) src)
-- | Embed a 'Reshape'.
--
-- * For a manifest array ('Done'), a 'Reshape' node is let-bound.
-- * Otherwise the reshape becomes a backwards permutation that maps each
--   new index to the old one via 'reindex'.
--
-- The new extent is first sunk past the argument's environment extension.
reshapeD
    :: (Kit acc, Shape sh, Shape sl, Elt e)
    => Embed acc aenv (Array sh e)
    -> PreExp acc aenv sl
    -> Embed acc aenv (Array sl e)
reshapeD (Embed env cc) (sink env -> sl) =
  case cc of
    Done v -> Embed (PushEnv env (Reshape sl (avarIn v))) (Done ZeroIdx)
    _      -> Stats.ruleFired "reshapeD"
                (Embed env (backpermuteD sl (reindex (shape cc) sl) cc))
-- | Fuse a 'ZipWith' of two delayed arrays.
--
-- * Both arguments may have a 'Step' view ('step'; a 'Done' counts as an
--   identity 'Step') over the same source array with the same index
--   mapping. In that case the result stays a 'Step' over that source, with
--   the two element functions combined.
-- * Otherwise both arguments are viewed as 'Yield's, and their generators
--   are combined.
--
-- In both cases the result extent is the 'Intersect' of the two argument
-- extents.
zipWithD :: (Kit acc, Shape sh, Elt a, Elt b, Elt c)
         => PreFun acc aenv (a -> b -> c)
         -> Cunctation acc aenv (Array sh a)
         -> Cunctation acc aenv (Array sh b)
         -> Cunctation acc aenv (Array sh c)
zipWithD f cc1 cc0
  | Just (Step sh1 p1 f1 v1) <- step cc1
  , Just (Step sh0 p0 f0 v0) <- step cc0
  , Just REFL <- match v1 v0
  , Just REFL <- match p1 p0
  = Stats.ruleFired "zipWithD/step"
  $ Step (sh1 `Intersect` sh0) p0 (combine f f1 f0) v0
  | Yield sh1 f1 <- yield cc1
  , Yield sh0 f0 <- yield cc0
  = Stats.ruleFired "zipWithD"
  $ Yield (sh1 `Intersect` sh0) (combine f f1 f0)
  where
    -- Build @\x -> let a = ixa x in let b = ixb x in c a b@:
    --
    --   * the binary function @c@ is weakened past the new lambda-bound
    --     variable @x@;
    --   * the bodies of @ixa@ and @ixb@ are let-bound in front of its body,
    --     with @ixb@'s body weakened past the first let.
    combine :: forall acc aenv a b c e. (Elt a, Elt b, Elt c)
            => PreFun acc aenv (a -> b -> c)
            -> PreFun acc aenv (e -> a)
            -> PreFun acc aenv (e -> b)
            -> PreFun acc aenv (e -> c)
    combine c ixa ixb
      | Lam (Lam (Body c')) <- weakenFE SuccIdx c :: PreOpenFun acc ((),e) aenv (a -> b -> c)
      , Lam (Body ixa') <- ixa
      , Lam (Body ixb') <- ixb
      = Lam $ Body $ Let ixa' $ Let (weakenE SuccIdx ixb') c'
-- | Embed an array let-binding whose bound term is @acc1@ (matched through
-- the view pattern) and whose body is @acc0@.
--
-- * If embedding the bound term yields a plain variable ('Done'), the body
--   is first sunk past the bound term's extension ('sink1'). It is then
--   rebuilt with @subAtop (Avar v1)@, which presumably substitutes that
--   variable for the let-bound index. The result is re-embedded and both
--   extensions are joined, so no binding remains for the let.
-- * Otherwise the embedded bound term and body are passed to 'aletD'',
--   which decides whether to keep or eliminate the binding.
aletD :: (Kit acc, Arrays arrs, Arrays brrs)
      => EmbedAcc acc
      -> ElimAcc acc
      -> acc aenv arrs
      -> acc (aenv,arrs) brrs
      -> Embed acc aenv brrs
aletD embedAcc elimAcc (embedAcc -> Embed env1 cc1) acc0
  | Done v1 <- cc1
  , Embed env0 cc0 <- embedAcc $ rebuildAcc (subAtop (Avar v1) . sink1 env1) acc0
  = Stats.ruleFired "aletD/float"
  $ Embed (env1 `join` env0) cc0
  | otherwise
  = aletD' embedAcc elimAcc (Embed env1 cc1) (embedAcc acc0)
-- | Keep or eliminate a let-binding, given the already-embedded bound term
-- and body.
--
-- The bound term is materialised and passed, with the materialised body,
-- to the @elimAcc@ predicate:
--
--   * 'False': the binding is kept. It is pushed onto a fresh environment
--     in front of the body's extension.
--   * Otherwise: the binding is eliminated by rewriting the body
--     ('eliminate').
--
-- The elimination case only handles 'Step' and 'Yield'. A 'Done' bound
-- term is dealt with by 'aletD' before this function is called.
aletD' :: forall acc aenv arrs brrs. (Kit acc, Arrays arrs, Arrays brrs)
       => EmbedAcc acc
       -> ElimAcc acc
       -> Embed acc aenv arrs
       -> Embed acc (aenv, arrs) brrs
       -> Embed acc aenv brrs
aletD' embedAcc elimAcc (Embed env1 cc1) (Embed env0 cc0)
  | acc1 <- compute (Embed env1 cc1)
  , False <- elimAcc (inject acc1) acc0
  = Stats.ruleFired "aletD/bind"
  $ Embed (BaseEnv `PushEnv` acc1 `join` env0) cc0
  | acc0' <- sink1 env1 acc0
  = Stats.ruleFired "aletD/eliminate"
  -- Matching on Step / Yield brings the (Shape sh, Elt e) constraints into
  -- scope and fixes arrs to an Array type, as 'eliminate' requires.
  $ case cc1 of
      Step{} -> eliminate env1 cc1 acc0'
      Yield{} -> eliminate env1 cc1 acc0'
  where
    -- The body, materialised as a term.
    acc0 = computeAcc (Embed env0 cc0)

    -- Eliminate the binding of @cc1@ in @body@ in three steps:
    --
    --   1. Uses of the bound variable in 'Shape', 'Index' and 'LinearIndex'
    --      are replaced by the bound array's extent and element function
    --      ('replaceA').
    --   2. The bound index is then rebuilt with @subAtop bnd@, where @bnd@
    --      is the materialised bound term.
    --   3. The result is re-embedded, and its extension is joined after
    --      @env1@.
    eliminate :: forall aenv aenv' sh e brrs. (Kit acc, Shape sh, Elt e, Arrays brrs)
              => Extend acc aenv aenv'
              -> Cunctation acc aenv' (Array sh e)
              -> acc (aenv', Array sh e) brrs
              -> Embed acc aenv brrs
    eliminate env1 cc1 body
      | Done v1 <- cc1 = elim (arrayShape v1) (indexArray v1)
      | Step sh1 p1 f1 v1 <- cc1 = elim sh1 (f1 `compose` indexArray v1 `compose` p1)
      | Yield sh1 f1 <- cc1 = elim sh1 f1
      where
        bnd :: PreOpenAcc acc aenv' (Array sh e)
        bnd = compute' cc1

        -- @sh1@ and @f1@ live in aenv'. They are weakened by one index
        -- because they are used inside @body@, which sits under the bound
        -- variable.
        elim :: PreExp acc aenv' sh -> PreFun acc aenv' (sh -> e) -> Embed acc aenv brrs
        elim sh1 f1
          | sh1' <- weakenEA rebuildAcc SuccIdx sh1
          , f1' <- weakenFA rebuildAcc SuccIdx f1
          , Embed env0' cc0' <- embedAcc $ rebuildAcc (subAtop bnd) $ kmap (replaceA sh1' f1' ZeroIdx) body
          = Embed (env1 `join` env0') cc0'

    -- Replace uses of the array variable @avar@ in a scalar expression:
    --
    --   * @Shape avar@ becomes the extent @sh'@.
    --   * @Index avar ix@ becomes @f'@ applied to @ix@, written as a 'Let'.
    --   * @LinearIndex avar i@ first converts @i@ with 'FromIndex' @sh'@.
    --
    -- @sh'@ and @f'@ are weakened under every scalar 'Let'.
    replaceE :: forall env aenv sh e t. (Kit acc, Shape sh, Elt e)
             => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> e) -> Idx aenv (Array sh e)
             -> PreOpenExp acc env aenv t
             -> PreOpenExp acc env aenv t
    replaceE sh' f' avar exp =
      case exp of
        Let x y -> Let (cvtE x) (replaceE (weakenE SuccIdx sh') (weakenFE SuccIdx f') avar y)
        Var i -> Var i
        Foreign ff f e -> Foreign ff f (cvtE e)
        Const c -> Const c
        Tuple t -> Tuple (cvtT t)
        Prj ix e -> Prj ix (cvtE e)
        IndexNil -> IndexNil
        IndexCons sl sz -> IndexCons (cvtE sl) (cvtE sz)
        IndexHead sh -> IndexHead (cvtE sh)
        IndexTail sz -> IndexTail (cvtE sz)
        IndexAny -> IndexAny
        IndexSlice x ix sh -> IndexSlice x (cvtE ix) (cvtE sh)
        IndexFull x ix sl -> IndexFull x (cvtE ix) (cvtE sl)
        ToIndex sh ix -> ToIndex (cvtE sh) (cvtE ix)
        FromIndex sh i -> FromIndex (cvtE sh) (cvtE i)
        Cond p t e -> Cond (cvtE p) (cvtE t) (cvtE e)
        PrimConst c -> PrimConst c
        PrimApp g x -> PrimApp g (cvtE x)
        ShapeSize sh -> ShapeSize (cvtE sh)
        Intersect sh sl -> Intersect (cvtE sh) (cvtE sl)
        While p f x -> While (replaceF sh' f' avar p) (replaceF sh' f' avar f) (cvtE x)
        Shape a
          | Just REFL <- match a a' -> Stats.substitution "replaceE/shape" sh'
          | otherwise -> exp
        Index a sh
          | Just REFL <- match a a'
          , Lam (Body b) <- f' -> Stats.substitution "replaceE/!" . cvtE $ Let sh b
          | otherwise -> Index a (cvtE sh)
        LinearIndex a i
          | Just REFL <- match a a'
          , Lam (Body b) <- f' -> Stats.substitution "replaceE/!!" . cvtE $ Let (Let i (FromIndex (weakenE SuccIdx sh') (Var ZeroIdx))) b
          | otherwise -> LinearIndex a (cvtE i)
      where
        a' = avarIn avar

        cvtE :: PreOpenExp acc env aenv s -> PreOpenExp acc env aenv s
        cvtE = replaceE sh' f' avar

        cvtT :: Tuple (PreOpenExp acc env aenv) s -> Tuple (PreOpenExp acc env aenv) s
        cvtT NilTup = NilTup
        cvtT (SnocTup t e) = cvtT t `SnocTup` cvtE e

    -- 'replaceE' lifted to scalar functions. @sh'@ and @f'@ are weakened
    -- under each lambda.
    replaceF :: forall env aenv sh e t. (Kit acc, Shape sh, Elt e)
             => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> e) -> Idx aenv (Array sh e)
             -> PreOpenFun acc env aenv t
             -> PreOpenFun acc env aenv t
    replaceF sh' f' avar fun =
      case fun of
        Body e -> Body (replaceE sh' f' avar e)
        Lam f -> Lam (replaceF (weakenE SuccIdx sh') (weakenFE SuccIdx f') avar f)

    -- Apply 'replaceE' / 'replaceF' to every scalar expression and function
    -- in an array term.
    --
    -- Under an array 'Alet', @sh'@, @f'@ and @avar@ are weakened by one
    -- index. A direct 'Avar' reference to @avar@ is left in place.
    --
    -- NOTE(review): the afuns of 'Awhile' and 'Apply' are not traversed, so
    -- uses of @avar@ inside them are not replaced here. They are only
    -- handled by the later 'subAtop' substitution in 'elim'. Confirm this is
    -- intended.
    replaceA :: forall aenv sh e a. (Kit acc, Shape sh, Elt e)
             => PreExp acc aenv sh -> PreFun acc aenv (sh -> e) -> Idx aenv (Array sh e)
             -> PreOpenAcc acc aenv a
             -> PreOpenAcc acc aenv a
    replaceA sh' f' avar pacc =
      case pacc of
        Avar v
          | Just REFL <- match v avar -> Avar avar
          | otherwise -> Avar v
        Alet bnd body ->
          let sh'' = weakenEA rebuildAcc SuccIdx sh'
              f'' = weakenFA rebuildAcc SuccIdx f'
          in
          Alet (cvtA bnd) (kmap (replaceA sh'' f'' (SuccIdx avar)) body)
        Use arrs -> Use arrs
        Unit e -> Unit (cvtE e)
        Acond p at ae -> Acond (cvtE p) (cvtA at) (cvtA ae)
        Aprj ix tup -> Aprj ix (cvtA tup)
        Atuple tup -> Atuple (cvtAT tup)
        Awhile p f a -> Awhile p f (cvtA a)
        Apply f a -> Apply f (cvtA a)
        Aforeign ff f a -> Aforeign ff f (cvtA a)
        Generate sh f -> Generate (cvtE sh) (cvtF f)
        Map f a -> Map (cvtF f) (cvtA a)
        ZipWith f a b -> ZipWith (cvtF f) (cvtA a) (cvtA b)
        Backpermute sh p a -> Backpermute (cvtE sh) (cvtF p) (cvtA a)
        Transform sh p f a -> Transform (cvtE sh) (cvtF p) (cvtF f) (cvtA a)
        Slice slix a sl -> Slice slix (cvtA a) (cvtE sl)
        Replicate slix sh a -> Replicate slix (cvtE sh) (cvtA a)
        Reshape sl a -> Reshape (cvtE sl) (cvtA a)
        Fold f z a -> Fold (cvtF f) (cvtE z) (cvtA a)
        Fold1 f a -> Fold1 (cvtF f) (cvtA a)
        FoldSeg f z a s -> FoldSeg (cvtF f) (cvtE z) (cvtA a) (cvtA s)
        Fold1Seg f a s -> Fold1Seg (cvtF f) (cvtA a) (cvtA s)
        Scanl f z a -> Scanl (cvtF f) (cvtE z) (cvtA a)
        Scanl1 f a -> Scanl1 (cvtF f) (cvtA a)
        Scanl' f z a -> Scanl' (cvtF f) (cvtE z) (cvtA a)
        Scanr f z a -> Scanr (cvtF f) (cvtE z) (cvtA a)
        Scanr1 f a -> Scanr1 (cvtF f) (cvtA a)
        Scanr' f z a -> Scanr' (cvtF f) (cvtE z) (cvtA a)
        Permute f d p a -> Permute (cvtF f) (cvtA d) (cvtF p) (cvtA a)
        Stencil f x a -> Stencil (cvtF f) x (cvtA a)
        Stencil2 f x a y b -> Stencil2 (cvtF f) x (cvtA a) y (cvtA b)
      where
        cvtA :: acc aenv s -> acc aenv s
        cvtA = kmap (replaceA sh' f' avar)

        cvtE :: PreExp acc aenv s -> PreExp acc aenv s
        cvtE = replaceE sh' f' avar

        cvtF :: PreFun acc aenv s -> PreFun acc aenv s
        cvtF = replaceF sh' f' avar

        cvtAT :: Atuple (acc aenv) s -> Atuple (acc aenv) s
        cvtAT NilAtup = NilAtup
        cvtAT (SnocAtup tup a) = cvtAT tup `SnocAtup` cvtA a
-- | Embed an array conditional.
--
-- * A constant predicate selects its branch statically.
-- * Branches that 'match' each other make the test redundant.
-- * Otherwise both branches are embedded and materialised, and the
--   resulting 'Acond' is let-bound.
acondD :: (Kit acc, Arrays arrs)
       => EmbedAcc acc
       -> PreExp acc aenv Bool
       -> acc aenv arrs
       -> acc aenv arrs
       -> Embed acc aenv arrs
acondD embedAcc p t e =
  case p of
    Const ((),True)  -> Stats.knownBranch "True" (embedAcc t)
    Const ((),False) -> Stats.knownBranch "False" (embedAcc e)
    _ | Just REFL <- match t e -> Stats.knownBranch "redundant" (embedAcc e)
      | otherwise ->
          done (Acond p (computeAcc (embedAcc t)) (computeAcc (embedAcc e)))
-- | Embed a projection out of an array tuple.
--
-- * A projection from a literal 'Atuple' selects the component directly
--   and embeds it.
-- * Otherwise the tuple is embedded and materialised, and the 'Aprj' is
--   let-bound.
aprjD :: forall acc aenv arrs a. (Kit acc, IsTuple arrs, Arrays arrs, Arrays a)
      => EmbedAcc acc
      -> TupleIdx (TupleRepr arrs) a
      -> acc aenv arrs
      -> Embed acc aenv a
aprjD embedAcc ix a =
  case extract a of
    Atuple tup -> Stats.ruleFired "aprj/Atuple" (embedAcc (aprjAT ix tup))
    _          -> done (Aprj ix (cvtA a))
  where
    cvtA :: acc aenv arrs -> acc aenv arrs
    cvtA = computeAcc . embedAcc

    aprjAT :: TupleIdx atup a -> Atuple (acc aenv) atup -> acc aenv a
    aprjAT ZeroTupIdx     (SnocAtup _ x)    = x
    aprjAT (SuccTupIdx i) (SnocAtup rest _) = aprjAT i rest
-- | Recognise the syntactic identity function @\x -> x@. A match yields a
-- proof that the argument and result types coincide.
isIdentity :: PreFun acc aenv (a -> b) -> Maybe (a :=: b)
isIdentity (Lam (Body (Var ZeroIdx))) = Just REFL
isIdentity _                          = Nothing

-- | The identity scalar function @\x -> x@.
identity :: Elt a => PreOpenFun acc env aenv (a -> a)
identity = Lam $ Body $ Var ZeroIdx
-- | @\ix -> ToIndex sh ix@: the linear index of @ix@ within the extent
-- @sh@. The extent is weakened past the lambda.
toIndex :: Shape sh => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (sh -> Int)
toIndex sh = Lam $ Body $ ToIndex (weakenE SuccIdx sh) (Var ZeroIdx)

-- | @\i -> FromIndex sh i@: the inverse of 'toIndex' for the extent @sh@.
fromIndex :: Shape sh => PreOpenExp acc env aenv sh -> PreOpenFun acc env aenv (Int -> sh)
fromIndex sh = Lam $ Body $ FromIndex (weakenE SuccIdx sh) (Var ZeroIdx)
-- | Index mapping from extent @sh@ to extent @sh'@ that preserves linear
-- order.
--
-- * If the two extent expressions match, the mapping is the identity.
-- * Otherwise an index is converted to its linear position in @sh@, then
--   back to a multidimensional index in @sh'@.
reindex :: (Kit acc, Shape sh, Shape sh')
        => PreOpenExp acc env aenv sh'
        -> PreOpenExp acc env aenv sh
        -> PreOpenFun acc env aenv (sh -> sh')
reindex sh' sh =
  case match sh sh' of
    Just REFL -> identity
    Nothing   -> fromIndex sh' `compose` toIndex sh
-- | Map an index of the full extent down to the slice extent, using
-- 'IndexSlice' with the given slice specification. 'replicateD' uses this
-- as its backwards permutation.
extend :: (Shape sh, Shape sl, Elt slix)
       => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
       -> PreExp acc aenv slix
       -> PreFun acc aenv (sh -> sl)
extend sliceIndex slix =
  Lam $ Body $ IndexSlice sliceIndex (weakenE SuccIdx slix) (Var ZeroIdx)

-- | Map an index of the slice extent up to the full extent, using
-- 'IndexFull' with the given slice specification. 'sliceD' uses this as its
-- backwards permutation.
restrict :: (Shape sh, Shape sl, Elt slix)
         => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
         -> PreExp acc aenv slix
         -> PreFun acc aenv (sl -> sh)
restrict sliceIndex slix =
  Lam $ Body $ IndexFull sliceIndex (weakenE SuccIdx slix) (Var ZeroIdx)
-- | The extent of the array bound at the given environment index.
arrayShape :: (Kit acc, Shape sh, Elt e) => Idx aenv (Array sh e) -> PreExp acc aenv sh
arrayShape v = Shape (avarIn v)

-- | @\ix -> v ! ix@ for the array bound at the given environment index.
indexArray :: (Kit acc, Shape sh, Elt e) => Idx aenv (Array sh e) -> PreFun acc aenv (sh -> e)
indexArray v = Lam $ Body $ Index (avarIn v) (Var ZeroIdx)

-- | @\i -> v !! i@ (linear indexing) for the array bound at the given
-- environment index.
linearIndex :: (Kit acc, Shape sh, Elt e) => Idx aenv (Array sh e) -> PreFun acc aenv (Int -> e)
linearIndex v = Lam $ Body $ LinearIndex (avarIn v) (Var ZeroIdx)