{-# LANGUAGE TypeFamilies #-}

-- | Do various kernel optimisations - mostly related to coalescing.
module Futhark.Pass.KernelBabysitting (babysitKernels) where

import Control.Arrow (first)
import Control.Monad.State.Strict
import Data.Foldable
import Data.List (elemIndex, isPrefixOf, sort)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.GPU hiding
  ( BasicOp,
    Body,
    Exp,
    FParam,
    FunDef,
    LParam,
    Lambda,
    Pat,
    PatElem,
    Prog,
    RetType,
    Stm,
  )
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import Futhark.Util

-- | The pass definition.
--
-- Applies 'transformStms' to every statement sequence handed over by
-- 'intraproceduralTransformation', starting each one with an empty
-- 'ExpMap'.
babysitKernels :: Pass GPU GPU
babysitKernels =
  Pass
    "babysit kernels"
    "Transpose kernel input arrays for better performance."
    $ intraproceduralTransformation onStms
  where
    -- Run the builder-based transformation for one statement sequence.
    -- The builder is started with an empty scope and the supplied
    -- scope is brought in through 'localScope'.  Only the first
    -- component of the builder's result is kept, i.e. the statements
    -- returned by 'transformStms'.
    onStms scope stms =
      fst <$> modifyNameSource (runState (runBuilderT m M.empty))
      where
        m = localScope scope $ transformStms mempty stms

-- | The monad in which the pass runs: a statement builder for the
-- 'GPU' representation.
type BabysitM = Builder GPU

-- | Transform a sequence of statements one at a time with
-- 'transformStm', threading the growing 'ExpMap' from each statement
-- to the next.  The result is the statements that were added to the
-- builder while doing so; the final 'ExpMap' is discarded.
transformStms :: ExpMap -> Stms GPU -> BabysitM (Stms GPU)
transformStms expmap stms =
  collectStms_ $ foldM_ transformStm expmap stms

-- | Transform the statements of a body with 'transformStms', leaving
-- the body result untouched.
transformBody :: ExpMap -> Body GPU -> BabysitM (Body GPU)
transformBody expmap (Body () stms res) = do
  stms' <- transformStms expmap stms
  pure $ Body () stms' res

-- | Map from variable names to the statement that defines them.  We
-- use this to hackily determine whether something is transposed or
-- otherwise funky in memory (and we'd prefer it not to be).  If we
-- cannot find a name in the map, we just assume it's all good.  HACK
-- and FIXME, I suppose.  We really should do this at the memory
-- level.
type ExpMap = M.Map VName (Stm GPU)

-- | Inspect the defining statement of @name@ (if it is recorded in
-- the 'ExpMap') and report a permutation describing its layout.
--
-- * 'Opaque' and 'Reshape' are looked through: the answer is that of
--   the underlying array.
--
-- * 'Rearrange' yields the inverse of its permutation.
--
-- * 'Manifest' yields its permutation as-is.
--
-- * A 'SegMap' result whose per-thread type is an array yields the
--   inverse of the permutation that puts the per-thread dimensions
--   before the remaining (outer) ones.
--
-- Anything else, including an unknown name, gives 'Nothing'.
--
-- NOTE(review): none of the cases here produce @Just Nothing@; the
-- meaning of that value is up to the consumers of this function,
-- which are not defined here.
nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m =
  case M.lookup name m of
    Just (Let _ _ (BasicOp (Opaque _ (Var arr)))) ->
      nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Rearrange perm _))) ->
      Just $ Just $ rearrangeInverse perm
    Just (Let _ _ (BasicOp (Reshape _ _ arr))) ->
      nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Manifest perm _))) ->
      Just $ Just perm
    Just (Let pat _ (Op (SegOp (SegMap _ _ ts _)))) ->
      -- Pair each pattern element with the corresponding per-thread
      -- return type and pick the one that binds 'name'.
      find ((== name) . patElemName . fst) (zip (patElems pat) ts)
        >>= nonlinear
    _ -> Nothing
  where
    nonlinear ::
      (ArrayShape shape, Typed dec) =>
      (PatElem dec, TypeBase shape u) ->
      Maybe (Maybe [Int])
    nonlinear (pe, t)
      | inner_r > 0 =
          Just . Just . rearrangeInverse $
            [inner_r .. inner_r + outer_r - 1] ++ [0 .. inner_r - 1]
      | otherwise = Nothing
      where
        -- Rank of the value produced by a single thread.
        inner_r = arrayRank t
        -- The remaining dimensions of the full result array.
        outer_r = arrayRank (patElemType pe) - inner_r

-- | Transform a single statement, add the result to the builder, and
-- return the 'ExpMap' extended with an entry (pointing at the
-- transformed statement) for every name bound by its pattern.
--
-- A 'SegOp' at 'SegThread' level has its kernel body rewritten by
-- 'transformKernelBody'.  Every other expression is traversed with
-- 'transform', which recurses into nested bodies.
transformStm :: ExpMap -> Stm GPU -> BabysitM ExpMap
transformStm expmap (Let pat aux e) = do
  e' <- case e of
    -- FIXME: We only make coalescing optimisations for SegThread
    -- SegOps, because that's what the analysis assumes.  For SegGroup
    -- we should probably look at the component SegThreads, but it
    -- apparently hasn't come up in practice yet.
    Op (SegOp op)
      | SegThread {} <- segLevel op ->
          let mapper =
                identitySegOpMapper
                  { mapOnSegOpBody =
                      transformKernelBody expmap (segSpace op)
                  }
           in Op . SegOp <$> mapSegOpM mapper op
    -- Also reached by SegOps that are not at SegThread level.
    _ -> mapExpM (transform expmap) e
  let stm' = Let pat aux e'
  addStm stm'
  -- '<>' on maps is left-biased, so the new entries take precedence.
  pure $ M.fromList [(name, stm') | name <- patNames pat] <> expmap

-- | A 'Mapper' that applies 'transformBody' to every nested body,
-- with the scope supplied by the traversal brought into effect via
-- 'localScope'.  All other components are left as in
-- 'identityMapper'.
transform :: ExpMap -> Mapper GPU GPU BabysitM
transform expmap =
  identityMapper
    { mapOnBody = \scope body ->
        localScope scope $ transformBody expmap body
    }

-- | Rewrite the array indexing operations of a kernel body using
-- 'ensureCoalescedAccess'.
--
-- The traversal runs in a 'StateT' carrying the 'Replacements' made
-- so far; it starts out empty and the final state is discarded.
transformKernelBody ::
  ExpMap ->
  SegSpace ->
  KernelBody GPU ->
  BabysitM (KernelBody GPU)
transformKernelBody expmap space kbody = do
  -- Go spelunking for accesses to arrays that are defined outside the
  -- kernel body and where the indices are kernel thread indices.
  scope <- askScope
  evalStateT
    ( traverseKernelBodyArrayIndexes
        free_ker_vars
        thread_local
        (scope <> scopeOfSegSpace space)
        (ensureCoalescedAccess expmap space_dims)
        kbody
    )
    mempty
  where
    space_dims = unSegSpace space
    thread_gids = map fst space_dims
    -- The flat thread index together with the per-dimension indices.
    thread_local = namesFromList $ segFlat space : thread_gids
    -- Names free in the kernel body, except those bound by the
    -- kernel space itself.
    free_ker_vars = freeIn kbody `namesSubtract` getKerVariantIds space
    getKerVariantIds = namesFromList . M.keys . scopeOfSegSpace

-- | A function applied to every array index operation found by
-- 'traverseKernelBodyArrayIndexes'.  Returning 'Nothing' keeps the
-- index operation as it is; returning @Just (arr', slice')@ makes the
-- operation index @arr'@ with @slice'@ instead.
type ArrayIndexTransform m =
  Names -> -- the free kernel variables given to the traversal
  (VName -> Bool) -> -- thread local?
  (VName -> SubExp -> Bool) -> -- variant to a certain gid (given as first param)?
  Scope GPU -> -- type environment
  VName -> -- the array being indexed
  Slice SubExp -> -- the slice it is indexed with
  m (Maybe (VName, Slice SubExp))

-- | Apply an 'ArrayIndexTransform' to every 'Index' statement in a
-- kernel body, including those in nested bodies and in the lambdas of
-- SOACs ('OtherOp').  Other operations are not descended into, and
-- the kernel results are returned unchanged.
--
-- The first 'Names' argument is handed to the transform as-is.  The
-- second is the set of names against which a variable's variance is
-- intersected to decide whether it is thread local.
--
-- Note that the transform always receives the scope given to this
-- function, not the scope accumulated during the traversal.
traverseKernelBodyArrayIndexes ::
  Monad f =>
  Names ->
  Names ->
  Scope GPU ->
  ArrayIndexTransform f ->
  KernelBody GPU ->
  f (KernelBody GPU)
traverseKernelBodyArrayIndexes free_ker_vars thread_variant outer_scope f (KernelBody () kstms kres) = do
  kstms' <- mapM (onStm top_ctx) (stmsToList kstms)
  pure $ KernelBody () (stmsFromList kstms') kres
  where
    -- Each traversal context is a variance table paired with a scope.
    top_ctx = (varianceInStms mempty kstms, outer_scope)

    onLambda (variance, scope) lam = do
      let lam_scope = scope <> scopeOfLParams (lambdaParams lam)
      body' <- onBody (variance, lam_scope) (lambdaBody lam)
      pure lam {lambdaBody = body'}

    onBody (variance, scope) (Body bdec stms bres) = do
      -- Extend both the variance table and the scope with the
      -- statements of this body before visiting them.
      let body_ctx = (varianceInStms variance stms, scope <> scopeOf stms)
      stms' <- mapM (onStm body_ctx) (stmsToList stms)
      pure $ Body bdec (stmsFromList stms') bres

    onStm (variance, _) (Let pat dec (BasicOp (Index arr is))) =
      -- Fall back to the original array and slice if the transform
      -- has nothing to say.
      Let pat dec . BasicOp . uncurry Index . fromMaybe (arr, is)
        <$> f free_ker_vars isThreadLocal isGidVariant outer_scope arr is
      where
        -- A name missing from the variance table is treated as
        -- variant only in itself.
        varianceOf v = M.findWithDefault (oneName v) v variance

        isGidVariant gid (Var v) =
          gid == v || gid `nameIn` varianceOf v
        isGidVariant _ _ = False

        isThreadLocal v =
          thread_variant `namesIntersect` varianceOf v
    onStm ctx (Let pat dec e) =
      Let pat dec <$> mapExpM (mapper ctx) e

    onOp ctx (OtherOp soac) =
      OtherOp
        <$> mapSOACM identitySOACMapper {mapOnSOACLambda = onLambda ctx} soac
    onOp _ op = pure op

    mapper ctx =
      identityMapper
        { mapOnBody = const (onBody ctx),
          mapOnOp = onOp ctx
        }

-- | The replacement arrays produced so far, keyed by the original
-- array and the slice it was indexed with.
type Replacements = M.Map (VName, Slice SubExp) VName

-- | The 'ArrayIndexTransform' that does the actual babysitting.
--
-- For an index operation on @arr@ with @slice@ it decides, by a
-- series of heuristics, whether the array should be replaced by a
-- rearranged or row-major copy (created through 'rearrangeInput' or
-- 'rowMajorArray').  Replacements are recorded in the 'Replacements'
-- state, so that the same array/slice combination is reused when it
-- is encountered again.
--
-- The slice itself is never changed: the result is either 'Nothing'
-- or the replacement array paired with the original slice.  Arrays
-- that are thread local, or that are not present in the given scope,
-- are left alone.
ensureCoalescedAccess ::
  MonadBuilder m =>
  ExpMap ->
  [(VName, SubExp)] ->
  ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess
  expmap
  thread_space
  free_ker_vars
  isThreadLocal
  isGidVariant
  outer_scope
  arr
  slice = do
    seen <- gets $ M.lookup (arr, slice)

    case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
      -- Already took care of this case elsewhere.
      (Just arr', _, _) ->
        pure $ Just (arr', slice)
      (Nothing, False, Just t)
        -- We are fully indexing the array with thread IDs, but the
        -- indices are in a permuted order.
        | Just is <- sliceIndices slice,
          length is == arrayRank t,
          Just is' <- gidCoalesced is,
          Just perm <- is' `isPermutationOf` is ->
            rearranged perm
        -- Check whether the access is already coalesced because of a
        -- previous rearrange being applied to the current array:
        -- 1. get the permutation of the source-array rearrange
        -- 2. apply it to the slice
        -- 3. check that the innermost index is actually the gid
        --    of the innermost kernel dimension.
        -- If so, the access is already coalesced, nothing to do!
        -- (Cosmin's Heuristic.)
        | Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
          not $ null perm,
          not $ null thread_gids,
          inner_gid <- last thread_gids,
          -- Count dimensions via 'unSlice'.  'length' applied to the
          -- 'Slice' itself uses its 'Foldable' instance, which is not
          -- evidently a count of dimensions and so need not bound the
          -- lookup below.
          dims <- unSlice slice,
          length dims >= length perm,
          -- Only the innermost dimension of the permuted slice
          -- matters.
          DimFix inner_ind <- dims !! last perm,
          isGidVariant inner_gid inner_ind ->
            pure Nothing
        -- We are not fully indexing an array, but the remaining slice
        -- is invariant to the innermost-kernel dimension. We assume
        -- the remaining slice will be sequentially streamed, hence
        -- tiling will be applied later and will solve coalescing.
        -- Hence nothing to do at this point. (Cosmin's Heuristic.)
        | (is, rem_slice) <- splitSlice slice,
          not $ null rem_slice,
          allDimAreSlice rem_slice,
          Nothing <- M.lookup arr expmap,
          not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
          notProperGidPrefix is,
          not (null thread_gids || null is),
          last thread_gids `notNameIn` (freeIn is <> freeIn rem_slice) ->
            pure Nothing
        -- We are not fully indexing the array, and the indices are not
        -- a proper prefix of the thread indices, and some indices are
        -- thread local, so we assume (HEURISTIC!)  that the remaining
        -- dimensions will be traversed sequentially.
        | (is, rem_slice) <- splitSlice slice,
          not $ null rem_slice,
          not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
          notProperGidPrefix is,
          any isThreadLocal (namesToList $ freeIn is) ->
            rearranged $ coalescingPermutation (length is) (arrayRank t)
        -- Everything is fine... assuming that the array is in row-major
        -- order!  Make sure that is the case.
        | Just {} <- nonlinearInMemory arr expmap ->
            case sliceIndices slice of
              Just is
                | Just _ <- gidCoalesced is -> rowMajor
                | otherwise -> pure Nothing
              _ -> rowMajor
      _ -> pure Nothing
    where
      thread_gids = map fst thread_space

      -- The indices are either not a prefix of the thread indices, or
      -- they cover all of them.
      notProperGidPrefix is =
        is /= map Var (take (length is) thread_gids)
          || length is == length thread_gids

      gidCoalesced =
        coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids)

      -- Replace 'arr' by a copy rearranged according to 'perm'.
      rearranged perm =
        replace
          =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- Replace 'arr' by a row-major copy.
      rowMajor = replace =<< lift (rowMajorArray arr)

      -- Remember the replacement for this array/slice combination and
      -- report it.
      replace arr' = do
        modify $ M.insert (arr, slice) arr'
        pure $ Just (arr', slice)

-- Heuristic for avoiding rearranging too small arrays.
--
-- Starting from the element size in bytes, the dimensions of the
-- slice are multiplied in one after another.  The slice is "too
-- small" if every dimension is a constant and the running product is
-- below 4 each time a dimension is multiplied in.  A slice without
-- dimensions is also considered too small.
--
-- NOTE(review): only 32-bit integer constants are recognised; a
-- dimension that is a constant of any other width makes the result
-- False.  Confirm that this matches the integer type actually used
-- for slice dimensions.
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs = fst . foldl' comb (True, bs) . sliceDims
  where
    comb (True, x) (Constant (IntValue (Int32Value d))) =
      let x' = d * x
       in (x' < 4, x')
    -- Once the flag is False it stays False.
    comb (_, x) _ = (False, x)

-- | Split a slice into its leading run of fixed indices ('DimFix')
-- and whatever remains, beginning with the first dimension that is
-- not a 'DimFix'.
splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice = go . unSlice
  where
    go (DimFix i : rest) = first (i :) $ go rest
    go rest = ([], Slice rest)

-- | True if no dimension of the slice is a fixed index ('DimFix').
-- In particular, this holds for the empty slice.
allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice = all notFixed . unSlice
  where
    notFixed DimFix {} = False
    notFixed _ = True

-- | Try to move thread indexes into their proper position.
--
-- The arguments are the free variables of the kernel, a predicate
-- telling whether an index is variant to a given thread index, the
-- thread indexes @tgids@, and the indexes @is@ of an array access.
-- The result is:
--
-- 1. 'Nothing' if any of the indexes is a constant or a kernel free
--    variable (because it would transpose a bigger array than needed
--    -- big overhead).
--
-- 2. 'Nothing' if the indexes are a prefix of the thread indexes,
--    because that means multiple threads will be accessing the same
--    element.
--
-- 3. @'Just' is@, with the indexes unchanged, if the innermost index
--    is variant to the innermost thread index (because the access is
--    likely to be already coalesced).
--
-- 4. Otherwise, 'Just' the indexes rearranged in an attempt to fix
--    coalescing.
coalescedIndexes :: Names -> (VName -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> Maybe [SubExp]
coalescedIndexes free_ker_vars isGidVariant tgids is
  | any isCt is =
      Nothing
  | any (`nameIn` free_ker_vars) (subExpVars is) =
      Nothing
  | is `isPrefixOf` tgids =
      Nothing
  -- Matching on the reversed lists picks out the innermost thread
  -- index and the innermost index without resorting to the partial
  -- 'last'; the guard simply fails if either list is empty.
  | Var innergid : _ <- reverse tgids,
    inner_i : _ <- reverse is,
    isGidVariant innergid inner_i =
      Just is
  | otherwise =
      -- Work on the reversed index list, so that position 0 is the
      -- innermost index, and pair each thread index with the position
      -- it should occupy counted from the innermost end.
      Just $ reverse $ foldl' move (reverse is) $ zip [0 ..] (reverse tgids)
  where
    num_is :: Int
    num_is = length is

    move :: [SubExp] -> (Int, SubExp) -> [SubExp]
    move is_rev (i, tgid)
      -- If tgid is in is_rev anywhere but at position i, and
      -- position i exists, we move it to position i instead.
      | Just j <- elemIndex tgid is_rev,
        i /= j,
        i < num_is =
          swap i j is_rev
      | otherwise =
          is_rev

    -- Exchange the elements at positions i and j of the list.
    swap :: Int -> Int -> [SubExp] -> [SubExp]
    swap i j l
      | Just ix <- maybeNth i l,
        Just jx <- maybeNth j l =
          update i jx $ update j ix l
      | otherwise =
          error $ "coalescedIndexes swap: invalid indices" ++ show (i, j, l)

    -- Replace the element at the given position of the list.
    update :: Int -> SubExp -> [SubExp] -> [SubExp]
    update 0 x (_ : ys) = x : ys
    update i x (y : ys) = y : update (i - 1) x ys
    update _ _ [] = error "coalescedIndexes: update"

    isCt :: SubExp -> Bool
    isCt (Constant _) = True
    isCt (Var _) = False

-- | The dimension numbers from @num_is@ up to @rank - 1@, followed by
-- the dimension numbers from @0@ up to @num_is - 1@.
coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank = trailing_dims ++ leading_dims
  where
    leading_dims = enumFromTo 0 (num_is - 1)
    trailing_dims = enumFromTo num_is (rank - 1)

-- | Produce a version of the array @arr@ laid out according to
-- @perm@.  The first argument describes what is known about the
-- array's current representation: 'Nothing' means nothing is known,
-- while @'Just' ('Just' p)@ carries a current permutation @p@.
rearrangeInput ::
  MonadBuilder m =>
  Maybe (Maybe [Int]) ->
  [Int] ->
  VName ->
  m VName
rearrangeInput manifest perm arr
  -- Already has desired representation.
  | Just (Just current_perm) <- manifest,
    current_perm == perm =
      pure arr
  -- We don't know the current representation, but the indexing is
  -- linear, so let's hope the array is too.
  | Nothing <- manifest,
    perm_is_sorted =
      pure arr
  -- We just want a row-major array, no tricks.
  | Just (Just _) <- manifest,
    perm_is_sorted =
      rowMajorArray arr
  | otherwise = do
      -- We may first manifest the array to ensure that it is flat in
      -- memory.  This is sometimes unnecessary, in which case the copy
      -- will hopefully be removed by the simplifier.
      flat <- maybe (pure arr) (const (rowMajorArray arr)) manifest
      letExp (baseString arr ++ "_coalesced") (BasicOp (Manifest perm flat))
  where
    perm_is_sorted = sort perm == perm

-- | Bind a manifestation of the array under the permutation
-- @[0 .. rank - 1]@, naming the result after the array with a
-- @_rowmajor@ suffix.
rowMajorArray ::
  MonadBuilder m =>
  VName ->
  m VName
rowMajorArray arr = do
  arr_t <- lookupType arr
  let dims_in_order = [0 .. arrayRank arr_t - 1]
  letExp (baseString arr ++ "_rowmajor") (BasicOp (Manifest dims_in_order arr))

--- Computing variance.

-- | Maps a variable name to a set of names.  As built by
-- 'varianceInStm', the entry for a name bound by a statement holds
-- the names free in that statement, together with the entries
-- already recorded in the table for those names.
type VarianceTable = M.Map VName Names

-- | Extend a variance table with the statements of a sequence,
-- processing them in order with 'varianceInStm'.
--
-- The fold is strict ('foldl'') so that the table is built up
-- statement by statement rather than as a chain of suspended
-- updates.
varianceInStms :: VarianceTable -> Stms GPU -> VarianceTable
varianceInStms t stms = foldl' varianceInStm t (stmsToList stms)

-- | Extend a variance table with one statement: every name bound by
-- the statement's pattern is mapped to the same set, namely the
-- names free in the statement plus whatever the incoming table
-- already records for each of those names.
varianceInStm :: VarianceTable -> Stm GPU -> VarianceTable
varianceInStm variance stm =
  foldl' (\tbl v -> M.insert v stm_variance tbl) variance bound_names
  where
    bound_names = patNames (stmPat stm)

    -- Lookups deliberately go through the table as it was before this
    -- statement, not the one being built up by the fold.
    stm_variance = foldMap withKnownVariance (namesToList (freeIn stm))
    withKnownVariance v = oneName v <> M.findWithDefault mempty v variance