{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | Do various kernel optimisations - mostly related to coalescing.
module Futhark.Pass.KernelBabysitting
       ( babysitKernels
       , nonlinearInMemory
       )
       where

import Control.Arrow (first)
import Control.Monad.State.Strict
import qualified Data.Map.Strict as M
import Data.Foldable
import Data.List (elemIndex, isPrefixOf, sort)
import Data.Maybe

import Futhark.MonadFreshNames
import Futhark.Representation.AST
import Futhark.Representation.Kernels
       hiding (Prog, Body, Stm, Pattern, PatElem,
               BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Tools
import Futhark.Pass
import Futhark.Util

-- | The kernel babysitting pass.  Each statement sequence handed over
-- by 'intraproceduralTransformation' is rewritten by 'transformStms',
-- starting from an empty 'ExpMap'.
babysitKernels :: Pass Kernels Kernels
babysitKernels = Pass "babysit kernels"
                 "Transpose kernel input arrays for better performance." $
                 intraproceduralTransformation onStms
  where onStms scope stms = do
          -- Run the rewrite with the enclosing scope made available.
          let m = localScope scope $ transformStms mempty stms
          -- 'runBinderT' also returns the statements accumulated in
          -- the binder; only the first component (the rewritten
          -- statements returned by 'transformStms') is kept.
          fmap fst $ modifyNameSource $ runState (runBinderT m M.empty)

-- | The monad in which the pass runs: a 'Binder' over the 'Kernels'
-- representation.
type BabysitM = Binder Kernels

-- | Transform a sequence of statements from left to right.  The
-- 'ExpMap' returned by 'transformStm' for one statement is passed on to
-- the next; the final map is discarded and the statements added to the
-- binder along the way are collected and returned.
transformStms :: ExpMap -> Stms Kernels -> BabysitM (Stms Kernels)
transformStms expmap stms = collectStms_ $ foldM_ transformStm expmap stms

-- | Transform the statements of a body with 'transformStms'; the body
-- result is left untouched.
transformBody :: ExpMap -> Body Kernels -> BabysitM (Body Kernels)
transformBody expmap (Body () stms res) = do
  stms' <- transformStms expmap stms
  return $ Body () stms' res

-- | Map from variable names to the statement that defines them.  We
-- use this to hackily determine whether something is transposed or
-- otherwise funky in memory (and we'd prefer it not to be).  If we
-- cannot find it in the map, we just assume it's all good.  HACK and
-- FIXME, I suppose.  We really should do this at the memory level.
type ExpMap = M.Map VName (Stm Kernels)

-- | Look up the defining statement of @name@ in the 'ExpMap' and try to
-- determine a permutation describing its layout:
--
-- * 'Opaque' of a variable and 'Reshape': the answer for the source
--   array is returned.
--
-- * 'Rearrange': @Just (Just (rearrangeInverse perm))@.
--
-- * 'Manifest': @Just (Just perm)@.
--
-- * 'SegMap': if @name@ is one of the pattern elements and the
--   corresponding return type has nonzero rank, the inverse of the
--   permutation that moves the outer dimensions behind the inner
--   ones.
--
-- In every other case (including when @name@ is not in the map) the
-- result is 'Nothing'.
nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m =
  case M.lookup name m of
    Just (Let _ _ (BasicOp (Opaque (Var arr)))) ->
      nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Rearrange perm _))) ->
      Just $ Just $ rearrangeInverse perm
    Just (Let _ _ (BasicOp (Reshape _ arr))) ->
      nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Manifest perm _))) ->
      Just $ Just perm
    Just (Let pat _ (Op (SegOp (SegMap _ _ ts _)))) ->
      -- Pair every pattern element with the matching return type and
      -- pick out the one that binds the name we are asking about.
      nonlinear =<< find ((==name) . patElemName . fst)
      (zip (patternElements pat) ts)
    _ -> Nothing
  where nonlinear (pe, t)
          | inner_r <- arrayRank t, inner_r > 0 =
              -- The rank of the pattern element minus the rank of the
              -- return type gives the number of outer dimensions.
              let outer_r = arrayRank (patElemType pe) - inner_r
              in Just $ Just $ rearrangeInverse $
                 [inner_r..inner_r+outer_r-1] ++ [0..inner_r-1]
          | otherwise = Nothing

-- | Transform a single statement, add the result to the binder, and
-- return the 'ExpMap' extended with an entry (the transformed
-- statement) for every name bound by the statement's pattern.
--
-- For a 'SegOp' the kernel body is rewritten by 'transformKernelBody'.
-- For any other expression, the bodies nested inside it are
-- transformed via the 'transform' mapper.
transformStm :: ExpMap -> Stm Kernels -> BabysitM ExpMap

transformStm expmap (Let pat aux (Op (SegOp op))) = do
  let mapper = identitySegOpMapper
               { mapOnSegOpBody =
                   transformKernelBody expmap (segLevel op) (segSpace op)
               }
  op' <- mapSegOpM mapper op
  let stm' = Let pat aux $ Op $ SegOp op'
  addStm stm'
  -- The union is left-biased, so the new entries take precedence.
  return $ M.fromList [ (name, stm') | name <- patternNames pat ] <> expmap

transformStm expmap (Let pat aux e) = do
  e' <- mapExpM (transform expmap) e
  let bnd' = Let pat aux e'
  addStm bnd'
  -- The union is left-biased, so the new entries take precedence.
  return $ M.fromList [ (name, bnd') | name <- patternNames pat ] <> expmap

-- | A 'Mapper' that is the identity except on bodies, which are
-- rewritten with 'transformBody' under the scope supplied by the
-- mapper.
transform :: ExpMap -> Mapper Kernels Kernels BabysitM
transform expmap =
  identityMapper { mapOnBody = \scope -> localScope scope . transformBody expmap }

-- | Rewrite the array indexing operations of a kernel body by running
-- 'traverseKernelBodyArrayIndexes' with 'ensureCoalescedAccess' as the
-- index transformation.
--
-- A statement binding @num_threads@ (the 32-bit product of the number
-- of groups and the group size of the given 'SegLevel') is added to
-- the binder first, and the traversal runs with an initially empty
-- 'Replacements' state.
transformKernelBody :: ExpMap -> SegLevel -> SegSpace -> KernelBody Kernels
                    -> BabysitM (KernelBody Kernels)
transformKernelBody expmap lvl space kbody = do
  -- Go spelunking for accesses to arrays that are defined outside the
  -- kernel body and where the indices are kernel thread indices.
  scope <- askScope
  let thread_gids = map fst $ unSegSpace space
      -- The flat thread index together with the per-dimension indices.
      thread_local = namesFromList $ segFlat space : thread_gids
      -- Names free in the body, minus those bound by the kernel space.
      free_ker_vars = freeIn kbody `namesSubtract` getKerVariantIds space
  num_threads <- letSubExp "num_threads" $ BasicOp $ BinOp (Mul Int32 OverflowUndef)
                 (unCount $ segNumGroups lvl) (unCount $ segGroupSize lvl)
  evalStateT (traverseKernelBodyArrayIndexes
              free_ker_vars
              thread_local
              (scope <> scopeOfSegSpace space)
              (ensureCoalescedAccess expmap (unSegSpace space) num_threads)
              kbody)
    mempty
  where getKerVariantIds = namesFromList . M.keys . scopeOfSegSpace

-- | A rewriter for a single array indexing operation inside a kernel
-- body.  'traverseKernelBodyArrayIndexes' calls it with, in order: the
-- set of names it was given as free in the kernel, the three helper
-- functions annotated below, a scope, and finally the array and slice
-- being indexed.  A result of @Just (arr', slice')@ makes the traversal
-- index @arr'@ with @slice'@ instead; 'Nothing' leaves the indexing as
-- it was.
type ArrayIndexTransform m =
  Names ->
  (VName -> Bool) ->           -- thread local?
  (VName -> SubExp -> Bool)->  -- variant to a certain gid (given as first param)?
  (SubExp -> Maybe SubExp) ->  -- split substitution?
  Scope Kernels ->            -- type environment
  VName -> Slice SubExp -> m (Maybe (VName, Slice SubExp))

-- | Apply an 'ArrayIndexTransform' to every 'Index' statement in a
-- kernel body, including those in nested bodies and in the lambdas of
-- nested SOACs ('OtherOp').  Other operations are left untouched, as
-- are the kernel results.
--
-- The arguments are: the names free in the kernel (passed through to
-- the transformation), the names considered thread-variant, the scope
-- outside the kernel body, the transformation itself, and the kernel
-- body to traverse.
traverseKernelBodyArrayIndexes :: (Applicative f, Monad f) =>
                                  Names
                               -> Names
                               -> Scope Kernels
                               -> ArrayIndexTransform f
                               -> KernelBody Kernels
                               -> f (KernelBody Kernels)
traverseKernelBodyArrayIndexes free_ker_vars thread_variant outer_scope f
  (KernelBody () kstms kres) =
  KernelBody () . stmsFromList <$>
  mapM (onStm (varianceInStms mempty kstms,
               mkSizeSubsts kstms,
               outer_scope)) (stmsToList kstms) <*>
  pure kres
  where -- The context threaded through the traversal is a triple of
        -- the variance table, the size substitutions, and the scope.

        -- Rewrite the body of a lambda, with its parameters in scope.
        onLambda (variance, szsubst, scope) lam =
          (\body' -> lam { lambdaBody = body' }) <$>
          onBody (variance, szsubst, scope') (lambdaBody lam)
          where scope' = scope <> scopeOfLParams (lambdaParams lam)

        -- Rewrite the statements of a nested body, after extending
        -- the context with what the body's own statements contribute.
        onBody (variance, szsubst, scope) (Body battr stms bres) = do
          stms' <- stmsFromList <$>
                   mapM (onStm (variance', szsubst', scope')) (stmsToList stms)
          return $ Body battr stms' bres
          where variance' = varianceInStms variance stms
                szsubst' = mkSizeSubsts stms <> szsubst
                scope' = scope <> scopeOf stms

        -- An indexing statement: ask the transformation whether the
        -- array and slice should be replaced.  Note that the scope
        -- handed to the transformation is always 'outer_scope'.
        onStm (variance, szsubst, _) (Let pat attr (BasicOp (Index arr is))) =
          Let pat attr . oldOrNew <$>
          f free_ker_vars isThreadLocal isGidVariant sizeSubst outer_scope arr is
          where oldOrNew Nothing =
                  BasicOp $ Index arr is
                oldOrNew (Just (arr', is')) =
                  BasicOp $ Index arr' is'

                -- A variable is variant to a gid if it is that gid,
                -- or if the gid is in its variance-table entry.
                isGidVariant gid (Var v) =
                  gid == v || nameIn gid (M.findWithDefault (oneName v) v variance)
                isGidVariant _ _ = False

                -- Does the variance of the variable (by default just
                -- the variable itself) overlap the thread-variant
                -- names?
                isThreadLocal v =
                  thread_variant `namesIntersect`
                  M.findWithDefault (oneName v) v variance

                -- Constants and variables from the outer scope stand
                -- for themselves; variables with an entry in the size
                -- substitutions are resolved recursively; anything
                -- else has no substitution.
                sizeSubst (Constant v) = Just $ Constant v
                sizeSubst (Var v)
                  | v `M.member` outer_scope      = Just $ Var v
                  | Just v' <- M.lookup v szsubst = sizeSubst v'
                  | otherwise                     = Nothing

        -- Any other statement: descend into nested bodies and ops.
        onStm (variance, szsubst, scope) (Let pat attr e) =
          Let pat attr <$> mapExpM (mapper (variance, szsubst, scope)) e

        -- Only the lambdas of SOACs are traversed; every other op is
        -- returned unchanged.
        onOp ctx (OtherOp soac) =
          OtherOp <$> mapSOACM identitySOACMapper{ mapOnSOACLambda = onLambda ctx } soac
        onOp _ op = return op

        mapper ctx = identityMapper { mapOnBody = const (onBody ctx)
                                    , mapOnOp = onOp ctx }

        -- Map the name bound by each single-result 'SplitSpace'
        -- statement to the last operand of the 'SplitSpace'.
        mkSizeSubsts = foldMap mkStmSizeSubst
          where mkStmSizeSubst (Let (Pattern [] [pe]) _
                                (Op (SizeOp (SplitSpace _ _ _ elems_per_i)))) =
                  M.singleton (patElemName pe) elems_per_i
                mkStmSizeSubst _ = mempty

-- | Arrays for which a replacement has already been produced, keyed by
-- the original array name together with the slice it was indexed with.
type Replacements = M.Map (VName, Slice SubExp) VName

-- | Decide whether the array @arr@, indexed with @slice@ inside a
-- kernel, should be replaced by a rearranged copy that gives a better
-- memory access pattern.
--
-- The first three arguments are the expression map, the thread space
-- (pairs of thread index variable and corresponding dimension size)
-- and the number of threads.  The remaining arguments
-- (@free_ker_vars@, @isThreadLocal@, @isGidVariant@, @sizeSubst@,
-- @outer_scope@, @arr@, @slice@) are those of 'ArrayIndexTransform'.
--
-- The result is @Just (arr', slice)@ when the array should be read
-- from @arr'@ instead, and @Nothing@ when no replacement was produced
-- (presumably the caller then keeps the original index -- confirm
-- against the user of 'ArrayIndexTransform').  Every replacement is
-- recorded in the 'Replacements' state so that a repeated
-- @(arr, slice)@ pair reuses the same copy.
ensureCoalescedAccess :: MonadBinder m =>
                         ExpMap
                      -> [(VName,SubExp)]
                      -> SubExp
                      -> ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess expmap thread_space num_threads free_ker_vars isThreadLocal
                      isGidVariant sizeSubst outer_scope arr slice = do
  seen <- gets $ M.lookup (arr, slice)

  case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
    -- Already took care of this case elsewhere.
    (Just arr', _, _) ->
      pure $ Just (arr', slice)

    (Nothing, False, Just t)
      -- We are fully indexing the array with thread IDs, but the
      -- indices are in a permuted order.
      | Just is <- sliceIndices slice,
        length is == arrayRank t,
        Just is' <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is,
        Just perm <- is' `isPermutationOf` is ->
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- Check whether the access is already coalesced because of a
      -- previous rearrange being applied to the current array:
      -- 1. get the permutation of the source-array rearrange
      -- 2. apply it to the slice
      -- 3. check that the innermost index is actually the gid
      --    of the innermost kernel dimension.
      -- If so, the access is already coalesced, nothing to do!
      -- (Cosmin's Heuristic.)
      --
      -- The uses of 'last' and '!!' below are guarded: 'perm' and
      -- 'thread_gids' are checked to be non-empty first, and 'slice'
      -- is checked to be at least as long as 'perm'.
      | Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
        not $ null perm,
        not $ null thread_gids,
        inner_gid <- last thread_gids,
        length slice >= length perm,
        slice' <- map (slice !!) perm,
        DimFix inner_ind <- last slice',
        isGidVariant inner_gid inner_ind ->
          return Nothing

      -- We are not fully indexing an array, but the remaining slice
      -- is invariant to the innermost-kernel dimension. We assume
      -- the remaining slice will be sequentially streamed, hence
      -- tiling will be applied later and will solve coalescing.
      -- Hence nothing to do at this point. (Cosmin's Heuristic.)
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        allDimAreSlice rem_slice,
        Nothing <- M.lookup arr expmap,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
        not (null thread_gids || null is),
        not (last thread_gids `nameIn` (freeIn is <> freeIn rem_slice)) ->
          return Nothing

      -- We are not fully indexing the array, and the indices are not
      -- a proper prefix of the thread indices, and some indices are
      -- thread local, so we assume (HEURISTIC!)  that the remaining
      -- dimensions will be traversed sequentially.
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
        any isThreadLocal (namesToList $ freeIn is) -> do
          let perm = coalescingPermutation (length is) $ arrayRank t
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- We are taking a slice of the array with a unit stride.  We
      -- assume that the slice will be traversed sequentially.
      --
      -- We will really want to treat the sliced dimension like two
      -- dimensions so we can transpose them.  This may require
      -- padding.
      | (is, rem_slice) <- splitSlice slice,
        and $ zipWith (==) is $ map Var thread_gids,
        DimSlice offset len (Constant stride):_ <- rem_slice,
        isThreadLocalSubExp offset,
        Just {} <- sizeSubst len,
        oneIsh stride -> do
          let num_chunks = if null is
                           then primExpFromSubExp int32 num_threads
                           else coerceIntPrimExp Int32 $
                                product $ map (primExpFromSubExp int32) $
                                drop (length is) thread_gdims
          replace =<< lift (rearrangeSlice (length is) (arraySize (length is) t) num_chunks arr)

      -- Everything is fine... assuming that the array is in row-major
      -- order!  Make sure that is the case.
      | Just{} <- nonlinearInMemory arr expmap ->
          case sliceIndices slice of
            Just is | Just _ <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is ->
                        replace =<< lift (rowMajorArray arr)
                    | otherwise ->
                        return Nothing
            _ -> replace =<< lift (rowMajorArray arr)

    _ -> return Nothing

  where (thread_gids, thread_gdims) = unzip thread_space

        -- Remember that (arr, slice) is now served by arr', and
        -- report the replacement.
        replace arr' = do
          modify $ M.insert (arr, slice) arr'
          return $ Just (arr', slice)

        isThreadLocalSubExp (Var v) = isThreadLocal v
        isThreadLocalSubExp Constant{} = False

-- | Heuristic for avoiding rearranging too small arrays.
--
-- Starting from the element size in bytes @bs@, multiply in the
-- dimensions of the slice one at a time.  The result is 'True' only
-- if every dimension is an 'Int32' constant and the running product
-- stays below 4 after each step; the first dimension that is not such
-- a constant (or that reaches the limit) makes the result 'False'.
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs = fst . foldl' comb (True,bs) . sliceDims
  where comb (True, x) (Constant (IntValue (Int32Value d))) = (d*x < 4, d*x)
        comb (_, x)     _                                   = (False, x)

-- | Split a slice into its longest leading run of fixed indices
-- ('DimFix') and whatever remains, starting at the first dimension
-- that is not a 'DimFix'.  An empty slice yields @([], [])@.
splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice (DimFix i:is) = first (i:) $ splitSlice is
splitSlice is = ([], is)

-- | 'True' when no dimension of the slice is a fixed index ('DimFix').
-- Vacuously 'True' for the empty slice.
allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice = all notFixed
  where notFixed DimFix{} = False
        notFixed _ = True

-- | Try to move thread indexes into their proper position.
--
-- @tgids@ are the thread indexes (outermost first) and @is@ the
-- indexes used to access the array.  The result is 'Nothing' when no
-- reordering should be attempted, and otherwise the indexes in the
-- order they should appear (which may be @is@ unchanged).
coalescedIndexes :: Names -> (VName -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> Maybe [SubExp]
coalescedIndexes free_ker_vars isGidVariant tgids is
  -- Give up (Nothing) if:
  -- 1. any of the indices is a constant or a kernel free variable
  --    (because it would transpose a bigger array than needed -- big overhead).
  -- 2. the indexes are a prefix of the thread indexes, because that
  --    means multiple threads will be accessing the same element.
  | any isCt is =
      Nothing
  | any (`nameIn` free_ker_vars) (subExpVars is) =
      Nothing
  | is `isPrefixOf` tgids =
      Nothing
  -- 3. Keep the indexes as they are if the innermost index is variant
  --    to the innermost-thread gid (because access is likely to be
  --    already coalesced).  The emptiness checks make 'last' safe.
  | not (null tgids),
    not (null is),
    Var innergid <- last tgids,
    isGidVariant innergid (last is) =
      Just is
  -- 4. Otherwise try fix coalescing
  | otherwise =
      Just $ reverse $ foldl' move (reverse is) $ zip [0..] (reverse tgids)
  where num_is = length is

        move is_rev (i, tgid)
          -- If tgid is in is_rev anywhere but at position i, and
          -- position i exists, we move it to position i instead.
          | Just j <- elemIndex tgid is_rev, i /= j, i < num_is =
              swap i j is_rev
          | otherwise =
              is_rev

        -- Exchange the elements at positions i and j of l.
        swap i j l
          | Just ix <- maybeNth i l,
            Just jx <- maybeNth j l =
              update i jx $ update j ix l
          | otherwise =
              error $ "coalescedIndexes swap: invalid indices" ++ show (i, j, l)

        -- Replace the element at the given position with x.
        update 0 x (_:ys) = x : ys
        update i x (y:ys) = y : update (i-1) x ys
        update _ _ []     = error "coalescedIndexes: update"

        isCt :: SubExp -> Bool
        isCt (Constant _) = True
        isCt (Var      _) = False

-- | The permutation that rotates the first @num_is@ dimensions of a
-- rank-@rank@ array to the end: dimensions @num_is .. rank-1@ come
-- first, followed by dimensions @0 .. num_is-1@.
coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank =
  [num_is..rank-1] ++ [0..num_is-1]

-- | Produce an array holding the elements of @arr@ laid out according
-- to @perm@, avoiding a copy when one is evidently unnecessary.
--
-- The first argument describes what is known about the current
-- representation of @arr@ (callers pass the result of
-- 'nonlinearInMemory'): @Just (Just p)@ carries a known current
-- permutation @p@; 'Nothing' means the current representation is not
-- known.  Any 'Just' value causes the array to be made row-major
-- before the final 'Manifest'.
rearrangeInput :: MonadBinder m =>
                  Maybe (Maybe [Int]) -> [Int] -> VName -> m VName
rearrangeInput (Just (Just current_perm)) perm arr
  | current_perm == perm = return arr -- Already has desired representation.

rearrangeInput Nothing perm arr
  | sort perm == perm = return arr -- We don't know the current
                                   -- representation, but the indexing
                                   -- is linear, so let's hope the
                                   -- array is too.
rearrangeInput (Just Just{}) perm arr
  | sort perm == perm = rowMajorArray arr -- We just want a row-major array, no tricks.
rearrangeInput manifest perm arr = do
  -- We may first manifest the array to ensure that it is flat in
  -- memory.  This is sometimes unnecessary, in which case the copy
  -- will hopefully be removed by the simplifier.
  manifested <- if isJust manifest then rowMajorArray arr else return arr
  letExp (baseString arr ++ "_coalesced") $
    BasicOp $ Manifest perm manifested

-- | Bind a new array, named after @arr@ with a @_rowmajor@ suffix, to
-- a 'Manifest' of @arr@ with the identity permutation
-- @[0 .. rank-1]@, and return its name.
rowMajorArray :: MonadBinder m =>
                 VName -> m VName
rowMajorArray arr = do
  rank <- arrayRank <$> lookupType arr
  letExp (baseString arr ++ "_rowmajor") $ BasicOp $ Manifest [0..rank-1] arr

-- | Pad, chunk and re-manifest dimension @d@ of the array @arr@.
--
-- The steps, in the order the statements are emitted:
--
--   1. Bind @num_chunks@ (a 'PrimExp') to a 'SubExp'.
--
--   2. Ask 'paddedScanReduceInput' for a padded size and padding amount
--      for @w@ with respect to the number of chunks.
--
--   3. Compute the per-chunk size as the signed 32-bit quotient of the
--      padded size and the number of chunks.
--
--   4. Concatenate a 'Scratch' array onto @arr@ along dimension @d@.
--
--   5. Reshape dimension @d@ into two dimensions @[num_chunks, per_chunk]@,
--      'Manifest' the result with a permutation, reshape back to a single
--      dimension of the padded size, and finally slice dimension @d@ with
--      'eSliceArray' using offset @0@ and size @w@.
--
-- Parameters:
--
--   * @d@: index of the dimension to rearrange.
--
--   * @w@: presumably the current size of dimension @d@ (it is what gets
--     padded and what the final slice is cut back to) -- confirm at callers.
--
--   * @num_chunks@: number of chunks to split the padded dimension into.
--
--   * @arr@: the array to rearrange.
--
-- Returns the name of the final sliced array.
rearrangeSlice :: MonadBinder m =>
                  Int -> SubExp -> PrimExp VName -> VName
               -> m VName
rearrangeSlice d w num_chunks arr = do
  chunks <- letSubExp "num_chunks" =<< toExp num_chunks

  (padded_w, pad_amount) <- paddedScanReduceInput w chunks

  chunk_size <-
    letSubExp "per_chunk" $
    BasicOp $ BinOp (SQuot Int32) padded_w chunks
  t <- lookupType arr
  padded <- padArray padded_w pad_amount t
  rearrange chunks padded_w chunk_size (baseString arr) padded t

  where
    -- Extend @arr@ along dimension @d@ by concatenating a scratch array
    -- whose dimension @d@ has size @pad_amount@; the concatenation is
    -- given @padded_w@ as its size argument.
    padArray padded_w pad_amount t = do
      let scratch_dims = shapeDims $ setDim d (arrayShape t) pad_amount
      filler <-
        letExp (baseString arr <> "_padding") $
        BasicOp $ Scratch (elemType t) scratch_dims
      letExp (baseString arr <> "_padded") $
        BasicOp $ Concat d arr [filler] padded_w

    -- Split dimension @d@ of the padded array in two, manifest it under
    -- a permutation, merge the two dimensions again, and slice.
    rearrange chunks padded_w chunk_size name padded t = do
      let dims = arrayDims t
          outer = take d dims
          inner = drop (d + 1) dims
          split_dims = outer ++ [chunks, chunk_size] ++ inner
          extradim_shape = Shape split_dims
          -- The first @d@ dimensions stay where they are.  Among the
          -- remaining ones, the per-chunk dimension and everything after
          -- it are listed first, and the chunk dimension is listed last.
          tr_perm =
            [0 .. d - 1]
            ++ [ d + i
               | i <- 1 : [2 .. shapeRank extradim_shape - 1 - d] ++ [0]
               ]
      with_extradim <-
        letExp (name <> "_extradim") $
        BasicOp $ Reshape (map DimNew split_dims) padded
      manifested <-
        letExp (name <> "_extradim_tr") $
        BasicOp $ Manifest tr_perm with_extradim
      let merged_dims =
            map DimCoercion outer ++ map DimNew (padded_w : inner)
      merged <-
        letExp (name <> "_inv_tr") $
        BasicOp $ Reshape merged_dims manifested
      sliced <-
        eSliceArray d merged (eSubExp $ constant (0 :: Int32)) (eSubExp w)
      letExp (name <> "_inv_tr_init") sliced

-- | Compute a padded size for @w@ with respect to @stride@.
--
-- The padded size is whatever 'eRoundToMultipleOf' produces for @w@ and
-- @stride@ at type 'Int32'; it is bound under the name @padded_size@.
-- The padding is the 32-bit difference @padded_size - w@, bound under
-- the name @padding@.
--
-- Returns the pair @(padded size, padding)@.
paddedScanReduceInput :: MonadBinder m =>
                         SubExp -> SubExp
                      -> m (SubExp, SubExp)
paddedScanReduceInput w stride = do
  rounded <- eRoundToMultipleOf Int32 (eSubExp w) (eSubExp stride)
  w_padded <- letSubExp "padded_size" rounded
  let difference = BinOp (Sub Int32 OverflowUndef) w_padded w
  padding <- letSubExp "padding" (BasicOp difference)
  pure (w_padded, padding)

--- Computing variance.

-- | For each bound name, the set of names it depends on.  As populated
-- by 'varianceInStm', the entry for a name bound by a statement holds
-- every name free in that statement, together with the entries already
-- recorded for those free names.
type VarianceTable = M.Map VName Names

-- | Extend the variance table @t@ with every statement of a statement
-- sequence, processing the statements in order with 'varianceInStm'.
--
-- A strict left fold is used so the intermediate tables are forced as
-- we go instead of accumulating a chain of thunks; this also matches
-- the fold used inside 'varianceInStm'.
varianceInStms :: VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms t = foldl' varianceInStm t . stmsToList

-- | Extend the variance table with the names bound by a single statement.
--
-- Every name in the statement's pattern is mapped to the same set: the
-- names free in the statement, plus whatever the incoming table already
-- records for each of those free names.
varianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
varianceInStm variance bnd =
  foldl' record variance (patternNames (stmPattern bnd))
  where
    -- Associate one bound name with the statement's dependency set.
    record table name = M.insert name deps table

    -- A free name contributes itself and its known entry (if any) in
    -- the table that was passed in.
    dependsOn v = oneName v <> M.findWithDefault mempty v variance

    -- Dependency set shared by all names this statement binds.
    deps = mconcat [dependsOn v | v <- namesToList (freeIn bnd)]