{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Do various kernel optimisations - mostly related to coalescing.
module Futhark.Pass.KernelBabysitting (babysitKernels) where

import Control.Arrow (first)
import Control.Monad.State.Strict
import Data.Foldable
import Data.List (elemIndex, isPrefixOf, sort)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.GPU hiding
  ( BasicOp,
    Body,
    Exp,
    FParam,
    FunDef,
    LParam,
    Lambda,
    Pat,
    PatElem,
    Prog,
    RetType,
    Stm,
  )
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import Futhark.Util

-- | The pass definition.  Runs 'transformStms' over the statements of
-- every function (via 'intraproceduralTransformation'), starting from an
-- empty 'ExpMap'.
babysitKernels :: Pass GPU GPU
babysitKernels =
  Pass
    "babysit kernels"
    "Transpose kernel input arrays for better performance."
    $ intraproceduralTransformation onStms
  where
    onStms :: Scope GPU -> Stms GPU -> PassM (Stms GPU)
    onStms scope stms = do
      let m = localScope scope $ transformStms mempty stms
      -- 'transformStms' returns the statements it produced itself (it
      -- wraps its work in 'collectStms_'), so the statements accumulated
      -- by the outer builder (the second component) are dropped.
      fst <$> modifyNameSource (runState (runBuilderT m M.empty))

-- | The monad the pass runs in: a 'Builder' over the 'GPU'
-- representation, so rewrites can emit auxiliary statements.
type BabysitM = Builder GPU

-- | Transform a sequence of statements in order, threading the 'ExpMap'
-- produced by each statement into the next, and return every statement
-- that was emitted along the way.
transformStms :: ExpMap -> Stms GPU -> BabysitM (Stms GPU)
transformStms expmap stms =
  collectStms_ $ foldM_ transformStm expmap stms

-- | Transform the statements of a body; the body's result is kept as-is.
transformBody :: ExpMap -> Body GPU -> BabysitM (Body GPU)
transformBody expmap (Body () stms res) =
  (\stms' -> Body () stms' res) <$> transformStms expmap stms

-- | Map from variable names to the statement that defines them.  We use
-- this to hackily determine whether something is transposed or otherwise
-- funky in memory (and we'd prefer it not to be).  If we cannot find it
-- in the map, we just assume it's all good.  HACK and FIXME, I suppose.
-- We really should do this at the memory level.
type ExpMap = M.Map VName (Stm GPU)

-- | Try to determine, from the statement that defines @name@ in the
-- 'ExpMap', whether the array is stored in a non-row-major layout.
--
-- * 'Opaque' and 'Reshape' are looked through, recursing on their source
--   array.
-- * 'Rearrange' yields the inverse of its permutation; 'Manifest' yields
--   its permutation directly.
-- * A 'SegMap' result with a non-scalar per-thread type yields a
--   permutation derived from the per-thread rank (see @nonlinear@).
-- * Anything else, including names missing from the map, gives 'Nothing'.
--
-- NOTE(review): no case here produces @Just Nothing@; the outer 'Maybe'
-- presumably matters to callers elsewhere in the file — confirm.
nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m =
  case M.lookup name m of
    Just (Let _ _ (BasicOp (Opaque _ (Var arr)))) ->
      nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Rearrange perm _))) ->
      Just (Just (rearrangeInverse perm))
    Just (Let _ _ (BasicOp (Reshape _ arr))) ->
      nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Manifest perm _))) ->
      Just (Just perm)
    Just (Let pat _ (Op (SegOp (SegMap _ _ ts _)))) ->
      -- Pair each pattern element with its per-thread return type and
      -- pick the one that binds @name@.
      nonlinear
        =<< find
          ((== name) . patElemName . fst)
          (zip (patElems pat) ts)
    _ -> Nothing
  where
    -- For a per-thread type of rank @inner_r > 0@, build the permutation
    -- listing the outer dimensions (indices @inner_r@ onwards) before the
    -- inner ones (@0 .. inner_r - 1@), and return its inverse.  Scalar
    -- per-thread results are not considered non-linear.
    nonlinear (pe, t)
      | inner_r <- arrayRank t,
        inner_r > 0 = do
          let outer_r = arrayRank (patElemType pe) - inner_r
          pure . Just . rearrangeInverse $
            [inner_r .. inner_r + outer_r - 1] ++ [0 .. inner_r - 1]
      | otherwise = Nothing

-- | Transform a single statement, add the result to the builder, and
-- return the 'ExpMap' extended with the rewritten statement for every name
-- it binds.
--
-- 'SegThread' SegOps get their kernel bodies babysat by
-- 'transformKernelBody'; every other expression is traversed with
-- 'transform' so nested bodies are processed too.
transformStm :: ExpMap -> Stm GPU -> BabysitM ExpMap
transformStm expmap (Let pat aux (Op (SegOp op)))
  -- FIXME: We only make coalescing optimisations for SegThread
  -- SegOps, because that's what the analysis assumes.  For SegGroup
  -- we should probably look at the component SegThreads, but it
  -- apparently hasn't come up in practice yet.
  | SegThread {} <- segLevel op = do
      let mapper =
            identitySegOpMapper
              { mapOnSegOpBody =
                  transformKernelBody expmap (segLevel op) (segSpace op)
              }
      op' <- mapSegOpM mapper op
      emitAndRecord expmap $ Let pat aux $ Op $ SegOp op'
transformStm expmap (Let pat aux e) = do
  e' <- mapExpM (transform expmap) e
  emitAndRecord expmap $ Let pat aux e'

-- | Add a statement to the builder and record it as the definition of each
-- name in its pattern.  New entries take precedence over existing ones
-- because '<>' on 'M.Map' is a left-biased union.
emitAndRecord :: ExpMap -> Stm GPU -> BabysitM ExpMap
emitAndRecord expmap stm@(Let pat _ _) = do
  addStm stm
  pure $ M.fromList [(name, stm) | name <- patNames pat] <> expmap

-- | An expression mapper that babysits every nested body, first bringing
-- the scope the body is defined in into the builder.
transform :: ExpMap -> Mapper GPU GPU BabysitM
transform expmap =
  identityMapper
    { mapOnBody = \scope -> localScope scope . transformBody expmap
    }

-- | Babysit a kernel body.  Array-indexing statements in the body are
-- handed to 'ensureCoalescedAccess' (via
-- 'traverseKernelBodyArrayIndexes'), which may redirect each one to a
-- different array and slice.
--
-- As a side effect, a @num_threads@ statement (number of groups times
-- group size, as an 'Int64' multiplication) is added to the enclosing
-- builder before the body is traversed.
transformKernelBody ::
  ExpMap ->
  SegLevel ->
  SegSpace ->
  KernelBody GPU ->
  BabysitM (KernelBody GPU)
transformKernelBody expmap lvl space kbody = do
  -- Go spelunking for accesses to arrays that are defined outside the
  -- kernel body and where the indices are kernel thread indices.
  scope <- askScope
  let thread_gids = map fst $ unSegSpace space
      -- The flat thread index together with the per-dimension indices.
      thread_local = namesFromList $ segFlat space : thread_gids
      -- Variables free in the body, excluding those bound by the space.
      free_ker_vars = freeIn kbody `namesSubtract` getKerVariantIds space
  num_threads <-
    letSubExp "num_threads" . BasicOp $
      BinOp
        (Mul Int64 OverflowUndef)
        (unCount $ segNumGroups lvl)
        (unCount $ segGroupSize lvl)
  -- The cache of already-performed replacements starts empty per kernel.
  evalStateT
    ( traverseKernelBodyArrayIndexes
        free_ker_vars
        thread_local
        (scope <> scopeOfSegSpace space)
        (ensureCoalescedAccess expmap (unSegSpace space) num_threads)
        kbody
    )
    mempty
  where
    getKerVariantIds :: SegSpace -> Names
    getKerVariantIds = namesFromList . M.keys . scopeOfSegSpace

-- | A rewrite applied to each array-indexing statement visited by
-- 'traverseKernelBodyArrayIndexes'.  Given context about the kernel, the
-- indexed array and the slice, it returns @Just (arr', slice')@ to replace
-- the index with an index of @arr'@ by @slice'@, or @Nothing@ to keep the
-- original statement.
type ArrayIndexTransform m =
  Names -> -- free variables of the kernel body, as passed by the traversal
  (VName -> Bool) -> -- thread local?
  (VName -> SubExp -> Bool) -> -- variant to a certain gid (given as first param)?
  (SubExp -> Maybe SubExp) -> -- split substitution?
  Scope GPU -> -- type environment (the scope outside the kernel body)
  VName -> -- the array being indexed
  Slice SubExp -> -- the slice it is indexed with
  m (Maybe (VName, Slice SubExp))

-- | Apply an 'ArrayIndexTransform' to every 'Index' statement in a kernel
-- body.  The traversal also descends into the bodies of nested
-- expressions and into the lambdas of 'OtherOp' SOACs; other host ops are
-- left untouched.
--
-- The variance table, the split-size substitutions and the scope are
-- extended as the traversal enters nested bodies.  The transform itself,
-- however, always receives @outer_scope@, not the nested scope.
--
-- Arguments: the kernel's free variables (forwarded to the transform),
-- the names considered thread-variant (used to decide thread locality),
-- the outer scope, the transform, and the kernel body.
traverseKernelBodyArrayIndexes ::
  Monad f =>
  Names ->
  Names ->
  Scope GPU ->
  ArrayIndexTransform f ->
  KernelBody GPU ->
  f (KernelBody GPU)
traverseKernelBodyArrayIndexes free_ker_vars thread_variant outer_scope f (KernelBody () kstms kres) =
  KernelBody () . stmsFromList
    <$> mapM
      ( onStm
          ( varianceInStms mempty kstms,
            mkSizeSubsts kstms,
            outer_scope
          )
      )
      (stmsToList kstms)
    <*> pure kres
  where
    -- The traversal context is (variance table, size substitutions, scope).

    onLambda (variance, szsubst, scope) lam =
      (\body' -> lam {lambdaBody = body'})
        <$> onBody (variance, szsubst, scope') (lambdaBody lam)
      where
        scope' = scope <> scopeOfLParams (lambdaParams lam)

    onBody (variance, szsubst, scope) (Body bdec stms bres) = do
      stms' <-
        stmsFromList
          <$> mapM (onStm (variance', szsubst', scope')) (stmsToList stms)
      pure $ Body bdec stms' bres
      where
        variance' = varianceInStms variance stms
        szsubst' = mkSizeSubsts stms <> szsubst
        scope' = scope <> scopeOf stms

    onStm (variance, szsubst, _) (Let pat dec (BasicOp (Index arr is))) =
      Let pat dec . oldOrNew
        <$> f free_ker_vars isThreadLocal isGidVariant sizeSubst outer_scope arr is
      where
        -- Keep the original index unless the transform supplied a
        -- replacement array and slice.
        oldOrNew = BasicOp . uncurry Index . fromMaybe (arr, is)

        -- Is the subexpression equal to, or does it depend on, @gid@?
        isGidVariant gid (Var v) =
          gid == v || nameIn gid (M.findWithDefault (oneName v) v variance)
        isGidVariant _ _ = False

        -- Does @v@ depend on any thread-variant name?
        isThreadLocal v =
          thread_variant
            `namesIntersect` M.findWithDefault (oneName v) v variance

        -- Express a size in terms of names from the outer scope, following
        -- split-size substitutions; constants are returned unchanged.
        sizeSubst (Constant v) = Just $ Constant v
        sizeSubst (Var v)
          | v `M.member` outer_scope = Just $ Var v
          | Just v' <- M.lookup v szsubst = sizeSubst v'
          | otherwise = Nothing
    onStm (variance, szsubst, scope) (Let pat dec e) =
      Let pat dec <$> mapExpM (mapper (variance, szsubst, scope)) e

    -- Only SOACs are entered (through their lambdas).
    onOp ctx (OtherOp soac) =
      OtherOp <$> mapSOACM identitySOACMapper {mapOnSOACLambda = onLambda ctx} soac
    onOp _ op = pure op

    mapper ctx =
      identityMapper
        { mapOnBody = const (onBody ctx),
          mapOnOp = onOp ctx
        }

    -- Map the name bound by each 'SplitSpace' size operation to its
    -- elements-per-thread expression.
    mkSizeSubsts :: Stms GPU -> M.Map VName SubExp
    mkSizeSubsts = foldMap mkStmSizeSubst
      where
        mkStmSizeSubst (Let (Pat [pe]) _ (Op (SizeOp (SplitSpace _ _ _ elems_per_i)))) =
          M.singleton (patElemName pe) elems_per_i
        mkStmSizeSubst _ = mempty

-- | Records, for each array and the slice used to index it, the name of
-- the array that 'ensureCoalescedAccess' substituted for it.  A pair that
-- is already present is reused rather than rearranged a second time.
type Replacements = M.Map (VName, Slice SubExp) VName

-- | Build an 'ArrayIndexTransform' that tries to make an array access
-- inside a kernel coalesced.
--
-- When asked about an index of @arr@ by @slice@, the transform either
-- returns 'Nothing' (leave the access alone) or @Just (arr', slice)@,
-- where @arr'@ is a rearranged or manifested version of @arr@ bound in
-- the underlying 'MonadBuilder'.  Every substitution is recorded in the
-- 'Replacements' state, so each @(arr, slice)@ pair is rearranged at
-- most once.
--
-- * @expmap@ maps names to the statements binding them.  It is used to
--   spot arrays produced by a 'Rearrange' and to query their layout.
-- * @thread_space@ pairs each thread gid with its dimension size.  The
--   last gid is treated as the innermost kernel dimension.
-- * @num_threads@ is the chunk count used when a unit-stride slice is
--   not preceded by any thread indices.
ensureCoalescedAccess ::
  MonadBuilder m =>
  ExpMap ->
  [(VName, SubExp)] ->
  SubExp ->
  ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess
  expmap
  thread_space
  num_threads
  free_ker_vars
  isThreadLocal
  isGidVariant
  sizeSubst
  outer_scope
  arr
  slice = do
    seen <- gets $ M.lookup (arr, slice)

    case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
      -- Already took care of this case elsewhere.
      (Just arr', _, _) ->
        pure $ Just (arr', slice)
      (Nothing, False, Just t)
        -- We are fully indexing the array with thread IDs, but the
        -- indices are in a permuted order.
        | Just is <- sliceIndices slice,
          length is == arrayRank t,
          Just is' <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is,
          Just perm <- is' `isPermutationOf` is ->
            replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)
        -- Check whether the access is already coalesced because of a
        -- previous rearrange being applied to the current array:
        -- 1. get the permutation of the source-array rearrange
        -- 2. apply it to the slice
        -- 3. check that the innermost index is actually the gid
        --    of the innermost kernel dimension.
        -- If so, the access is already coalesced, nothing to do!
        -- (Cosmin's Heuristic.)
        | Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
          not $ null perm,
          inner_gid : _ <- reverse thread_gids,
          -- Count dimensions, not the SubExps inside them: the Foldable
          -- length of a Slice counts three per DimSlice.  This guard is
          -- what keeps the (!!) below in bounds.
          length (unSlice slice) >= length perm,
          DimFix inner_ind : _ <- reverse $ map (unSlice slice !!) perm,
          isGidVariant inner_gid inner_ind ->
            pure Nothing
        -- We are not fully indexing an array, but the remaining slice
        -- is invariant to the innermost-kernel dimension. We assume
        -- the remaining slice will be sequentially streamed, hence
        -- tiling will be applied later and will solve coalescing.
        -- Hence nothing to do at this point. (Cosmin's Heuristic.)
        | (is, rem_slice) <- splitSlice slice,
          not $ null $ unSlice rem_slice,
          allDimAreSlice rem_slice,
          Nothing <- M.lookup arr expmap,
          pt <- elemType t,
          not $ tooSmallSlice (primByteSize pt) rem_slice,
          is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
          not (null is),
          inner_gid : _ <- reverse thread_gids,
          not (inner_gid `nameIn` (freeIn is <> freeIn rem_slice)) ->
            pure Nothing
        -- We are not fully indexing the array, and the indices are not
        -- a proper prefix of the thread indices, and some indices are
        -- thread local, so we assume (HEURISTIC!)  that the remaining
        -- dimensions will be traversed sequentially.
        | (is, rem_slice) <- splitSlice slice,
          not $ null $ unSlice rem_slice,
          pt <- elemType t,
          not $ tooSmallSlice (primByteSize pt) rem_slice,
          is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
          any isThreadLocal (namesToList $ freeIn is) -> do
            let perm = coalescingPermutation (length is) $ arrayRank t
            replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

        -- We are taking a slice of the array with a unit stride.  We
        -- assume that the slice will be traversed sequentially.
        --
        -- We will really want to treat the sliced dimension like two
        -- dimensions so we can transpose them.  This may require
        -- padding.
        | (is, rem_slice) <- splitSlice slice,
          and $ zipWith (==) is $ map Var thread_gids,
          DimSlice offset len (Constant stride) : _ <- unSlice rem_slice,
          isThreadLocalSubExp offset,
          Just {} <- sizeSubst len,
          oneIsh stride -> do
            -- Both branches are 64-bit, matching the 64-bit division
            -- ('SQuot' 'Int64') performed by 'rearrangeSlice'.
            let num_chunks
                  | null is = untyped $ pe64 num_threads
                  | otherwise = untyped $ product $ map pe64 $ drop (length is) thread_gdims
            replace =<< lift (rearrangeSlice (length is) (arraySize (length is) t) num_chunks arr)

        -- Everything is fine... assuming that the array is in row-major
        -- order!  Make sure that is the case.
        | Just {} <- nonlinearInMemory arr expmap ->
            case sliceIndices slice of
              Just is
                | Just _ <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is ->
                    replace =<< lift (rowMajorArray arr)
                | otherwise ->
                    pure Nothing
              _ -> replace =<< lift (rowMajorArray arr)
      _ -> pure Nothing
    where
      (thread_gids, thread_gdims) = unzip thread_space

      -- Record the substitution of arr' for (arr, slice) and report it.
      replace arr' = do
        modify $ M.insert (arr, slice) arr'
        pure $ Just (arr', slice)

      isThreadLocalSubExp (Var v) = isThreadLocal v
      isThreadLocalSubExp Constant {} = False

-- | Heuristic for avoiding rearranging too small arrays.
--
-- @bs@ is the element size in bytes (callers pass 'primByteSize' of the
-- element type).  Walks the slice's dimensions ('sliceDims') while
-- multiplying them into a running byte count starting at @bs@.
--
-- Returns 'True' only if every dimension is an integer constant and the
-- running count stays below 4 after each step.  An empty dimension list
-- counts as too small.
--
-- Constants of any integer width are accepted, since sizes are 64-bit.
-- The product is computed in 'Integer' so it cannot wrap around into a
-- small or negative value.
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs = go (toInteger bs) . sliceDims
  where
    go :: Integer -> [SubExp] -> Bool
    go _ [] = True
    go acc (Constant (IntValue d) : ds) =
      let acc' = acc * valueIntegral d
       in acc' < 4 && go acc' ds
    go _ _ = False

-- | Split a slice into the indices of its leading run of 'DimFix'
-- dimensions and the remaining slice.  The remaining slice begins at the
-- first dimension that is not a 'DimFix'.
splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice (Slice dims) = go dims
  where
    go (DimFix i : rest) = first (i :) $ go rest
    go rest = ([], Slice rest)

-- | True when no dimension of the slice is a 'DimFix'.  This is
-- vacuously true for an empty slice.
allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice = not . any isFixed . unSlice
  where
    isFixed DimFix {} = True
    isFixed _ = False

-- | Try to move thread indexes into their proper position.
--
-- @tgids@ are the kernel's thread gids (the last one is the innermost)
-- and @is@ are the indices of a full array access.  The result is one
-- of:
--
-- * 'Nothing' when the access should be left alone.
-- * @Just is@ unchanged when it already looks coalesced.
-- * @Just is'@, a permutation of @is@ in which the function tries to
--   swap each thread gid into the position (counted from the innermost
--   dimension) that it has in @tgids@.
--
-- We do nothing (return 'Nothing') if:
--
-- 1. any of the indices is a constant or a kernel free variable
--    (because it would transpose a bigger array than needed -- big
--    overhead);
--
-- 2. the indexes are a prefix of the thread indexes, because that
--    means multiple threads will be accessing the same element.
--
-- If the innermost index is variant to the innermost thread gid, the
-- access is likely to be coalesced already, and the indices are
-- returned unchanged.
coalescedIndexes :: Names -> (VName -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> Maybe [SubExp]
coalescedIndexes free_ker_vars isGidVariant tgids is
  | any isCt is =
      Nothing
  | any (`nameIn` free_ker_vars) (subExpVars is) =
      Nothing
  -- Also covers an empty 'is', since [] is a prefix of anything.
  | is `isPrefixOf` tgids =
      Nothing
  | Var innergid : _ <- reverse tgids,
    innermost_i : _ <- reverse is,
    isGidVariant innergid innermost_i =
      Just is
  -- Otherwise try to fix coalescing.
  | otherwise =
      Just $ reverse $ foldl' move (reverse is) $ zip [0 ..] (reverse tgids)
  where
    num_is = length is

    move is_rev (i, tgid)
      -- If tgid is in is_rev anywhere but at position i, and
      -- position i exists, we move it to position i instead.
      | Just j <- elemIndex tgid is_rev,
        i /= j,
        i < num_is =
          swap i j is_rev
      | otherwise =
          is_rev

    swap i j l
      | Just ix <- maybeNth i l,
        Just jx <- maybeNth j l =
          update i jx $ update j ix l
      | otherwise =
          error $ "coalescedIndexes swap: invalid indices" ++ show (i, j, l)

    update 0 x (_ : ys) = x : ys
    update i x (y : ys) = y : update (i - 1) x ys
    update _ _ [] = error "coalescedIndexes: update"

    isCt :: SubExp -> Bool
    isCt (Constant _) = True
    isCt (Var _) = False

-- | Permutation that moves the first @num_is@ dimensions of an array of
-- rank @rank@ to the back, keeping the relative order of both groups.
-- For example, @coalescingPermutation 1 3 == [1, 2, 0]@.
coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank =
  let trailing = [num_is .. pred rank]
      leading = take num_is [0 ..]
   in trailing ++ leading

-- | Return a version of @arr@ laid out according to @perm@.
--
-- The first argument describes what is known about @arr@'s current
-- layout (call sites pass @nonlinearInMemory arr expmap@):
--
-- * @Just (Just p)@ with @p == perm@: @arr@ is returned as-is.
-- * 'Nothing' with an already-sorted @perm@: @arr@ is returned as-is.
-- * @Just (Just _)@ with an already-sorted @perm@: a row-major copy is
--   made.
-- * Otherwise a @Manifest perm@ is bound.  It is applied to a row-major
--   copy if the layout information is a 'Just', and to @arr@ directly
--   if it is 'Nothing'.
rearrangeInput ::
  MonadBuilder m =>
  Maybe (Maybe [Int]) ->
  [Int] ->
  VName ->
  m VName
rearrangeInput layout perm arr =
  case layout of
    -- Already has desired representation.
    Just (Just current_perm)
      | current_perm == perm -> pure arr
    -- We don't know the current representation, but the indexing is
    -- linear, so let's hope the array is too.
    Nothing
      | perm_is_sorted -> pure arr
    -- We just want a row-major array, no tricks.
    Just (Just _)
      | perm_is_sorted -> rowMajorArray arr
    _ -> do
      -- We may first manifest the array to ensure that it is flat in
      -- memory.  This is sometimes unnecessary, in which case the copy
      -- will hopefully be removed by the simplifier.
      manifested <- if isJust layout then rowMajorArray arr else pure arr
      letExp (baseString arr ++ "_coalesced") . BasicOp $ Manifest perm manifested
  where
    perm_is_sorted = sort perm == perm

-- | Bind a 'Manifest' of @arr@ with the identity permutation, i.e. a
-- row-major copy.  The new array is named after @arr@ with a
-- @_rowmajor@ suffix, and its name is returned.
rowMajorArray ::
  MonadBuilder m =>
  VName ->
  m VName
rowMajorArray arr = do
  arr_t <- lookupType arr
  let identity_perm = [0 .. arrayRank arr_t - 1]
  letExp (baseString arr ++ "_rowmajor") . BasicOp $ Manifest identity_perm arr

rearrangeSlice ::
  MonadBuilder m =>
  Int ->
  SubExp ->
  PrimExp VName ->
  VName ->
  m VName
-- | Produce a copy of @arr@ whose dimension @d@ has been re-laid out for
-- access in @num_chunks@ chunks.  The returned array has the same shape as
-- @arr@.
--
-- Each step below is emitted as a statement in the builder:
--
-- 1. Pad @w@ to a multiple of @num_chunks@ (via 'paddedScanReduceInput')
--    by appending uninitialised ('Scratch') elements along dimension @d@.
-- 2. Reshape dimension @d@ into @[num_chunks][per_chunk]@.
-- 3. 'Manifest' the result with a permutation that moves the @num_chunks@
--    dimension to the innermost position.
-- 4. Merge the two split dimensions back into one, then slice along @d@ to
--    keep only the first @w@ elements (dropping the padding).
--
-- @w@ is expected to be the size of dimension @d@ of @arr@, because the
-- padded concatenation is declared to have size @w_padded@.
-- NOTE(review): presumably the layout change is what makes the chunked
-- accesses coalesced (see the module header) — confirm against the caller.
rearrangeSlice ::
  MonadBuilder m =>
  Int ->
  SubExp ->
  PrimExp VName ->
  VName ->
  m VName
rearrangeSlice d w num_chunks arr = do
  chunk_count <- toSubExp "num_chunks" num_chunks
  (w_padded, padding) <- paddedScanReduceInput w chunk_count
  -- w_padded is a multiple of chunk_count, so the quotient is exact.
  per_chunk <-
    letSubExp "per_chunk" . BasicOp $
      BinOp (SQuot Int64 Unsafe) w_padded chunk_count
  arr_t <- lookupType arr
  arr_padded <- padAlongD arr_t w_padded padding
  chunkAndRestore arr_t chunk_count per_chunk w_padded arr_padded
  where
    arr_name = baseString arr

    -- Append @padding@ uninitialised elements to @arr@ along dimension @d@,
    -- giving an array whose dimension @d@ has size @w_padded@.
    padAlongD arr_t w_padded padding = do
      let filler_dims = shapeDims $ setDim d (arrayShape arr_t) padding
      arr_filler <-
        letExp (arr_name <> "_padding") . BasicOp $
          Scratch (elemType arr_t) filler_dims
      letExp (arr_name <> "_padded") . BasicOp $
        Concat d (arr :| [arr_filler]) w_padded

    -- Split dimension @d@ into (chunk_count, per_chunk), manifest with the
    -- chunk dimension innermost, merge the split back and drop the padding.
    chunkAndRestore arr_t chunk_count per_chunk w_padded arr_padded = do
      let arr_dims = arrayDims arr_t
          pre_dims = take d arr_dims
          post_dims = drop (d + 1) arr_dims
          chunked_dims = pre_dims ++ [chunk_count, per_chunk] ++ post_dims
          chunked_rank = length chunked_dims
          -- Outer dimensions stay put; per_chunk and the inner dimensions
          -- shift left by one; the chunk_count dimension goes last.
          perm = [0 .. d - 1] ++ [d + 1] ++ [d + 2 .. chunked_rank - 1] ++ [d]
          merged_shape =
            map DimCoercion pre_dims ++ map DimNew (w_padded : post_dims)
      arr_chunked <-
        letExp (arr_name <> "_extradim") . BasicOp $
          Reshape (map DimNew chunked_dims) arr_padded
      arr_chunked_tr <-
        letExp (arr_name <> "_extradim_tr") . BasicOp $
          Manifest perm arr_chunked
      arr_merged <-
        letExp (arr_name <> "_inv_tr") . BasicOp $
          Reshape merged_shape arr_chunked_tr
      slice_e <-
        eSliceArray d arr_merged (eSubExp (constant (0 :: Int64))) (eSubExp w)
      letExp (arr_name <> "_inv_tr_init") slice_e

-- | Compute a padded version of the size @w@ that is a multiple of @stride@
-- (via 'eRoundToMultipleOf'). Returns the padded size together with the
-- padding, i.e. the difference @padded - w@.
--
-- Both results are bound as @Int64@ statements named @padded_size@ and
-- @padding@.  NOTE(review): callers use the padding as an array size, so
-- the rounding is presumably upwards — confirm in 'eRoundToMultipleOf'.
paddedScanReduceInput ::
  MonadBuilder m =>
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
paddedScanReduceInput w stride = do
  w_rounded <-
    eRoundToMultipleOf Int64 (eSubExp w) (eSubExp stride)
      >>= letSubExp "padded_size"
  (,) w_rounded
    <$> letSubExp "padding" (BasicOp (BinOp (Sub Int64 OverflowUndef) w_rounded w))

--- Computing variance.

-- | Maps each bound variable to the set of names its value may vary with:
-- the free variables of the statement that bound it, together with the
-- variance already recorded for those free variables (see 'varianceInStm').
type VarianceTable = M.Map VName Names

-- | Extend a variance table with the bindings of every statement in a
-- sequence.  Statements are processed in order, so each one sees the
-- variance recorded for the statements before it.
--
-- A strict left fold ('foldl'') is used so that the table is built
-- incrementally instead of as a chain of suspended 'varianceInStm' calls.
varianceInStms :: VarianceTable -> Stms GPU -> VarianceTable
varianceInStms t = foldl' varianceInStm t . stmsToList

-- | Record the variance of every name bound by a single statement.
--
-- Every name in the statement's pattern is mapped (overwriting any previous
-- entry) to the same set: each free variable of the statement plus the
-- variance that free variable already has in the table.
varianceInStm :: VarianceTable -> Stm GPU -> VarianceTable
varianceInStm variance stm =
  foldl' (\table v -> M.insert v deps table) variance (patNames (stmPat stm))
  where
    deps = foldMap depsOf (namesToList (freeIn stm))
    -- A variable varies with itself and with everything it already varies with.
    depsOf v = oneName v <> M.findWithDefault mempty v variance