{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.Kernels.Base
  ( KernelConstants (..)
  , keyWithEntryPoint
  , CallKernelGen
  , InKernelGen
  , HostEnv (..)
  , KernelEnv (..)
  , computeThreadChunkSize
  , groupReduce
  , groupScan
  , isActive
  , sKernelThread
  , sKernelGroup
  , sReplicate
  , sIota
  , sCopy
  , compileThreadResult
  , compileGroupResult
  , virtualiseGroups
  , groupLoop
  , kernelLoop
  , groupCoverSpace
  , precomputeSegOpIDs

  , atomicUpdateLocking
  , AtomicBinOp
  , Locking(..)
  , AtomicUpdate(..)
  , DoAtomicUpdate
  )
  where

import Control.Monad.Except
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.List (elemIndex, find, nub, zip4)

import Prelude hiding (quot, rem)

import Futhark.Error
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Futhark.Util (chunks, maybeNth, mapAccumLM, takeLast, dropLast)

-- | The environment carried by 'CallKernelGen', i.e. what 'askEnv'
-- returns while generating host-level code.
newtype HostEnv = HostEnv
  { hostAtomics :: AtomicBinOp
    -- ^ The 'AtomicBinOp' available during host-level code generation.
  }

-- | The environment carried by 'InKernelGen', i.e. what 'askEnv'
-- returns while generating in-kernel code.
data KernelEnv = KernelEnv
  { kernelAtomics :: AtomicBinOp
    -- ^ The 'AtomicBinOp' available during in-kernel code generation.
  , kernelConstants :: KernelConstants
    -- ^ The 'KernelConstants' of the kernel currently being generated.
  }

-- | Code generation monad whose environment is 'HostEnv' and whose
-- operations are 'Imp.HostOp'.
type CallKernelGen = ImpM KernelsMem HostEnv Imp.HostOp
-- | Code generation monad whose environment is 'KernelEnv' and whose
-- operations are 'Imp.KernelOp'.
type InKernelGen = ImpM KernelsMem KernelEnv Imp.KernelOp

-- | Values describing the kernel being generated.  They are reachable
-- from in-kernel code generation through the 'kernelConstants' field
-- of 'KernelEnv'.
data KernelConstants =
  KernelConstants
  { kernelGlobalThreadId :: Imp.Exp
  , kernelLocalThreadId :: Imp.Exp
  , kernelGroupId :: Imp.Exp
  , kernelGlobalThreadIdVar :: VName
  , kernelLocalThreadIdVar :: VName
  , kernelGroupIdVar :: VName
  , kernelNumGroups :: Imp.Exp
  , kernelGroupSize :: Imp.Exp
  , kernelNumThreads :: Imp.Exp
  , kernelWaveSize :: Imp.Exp
  , kernelThreadActive :: Imp.Exp
  , kernelLocalIdMap :: M.Map [SubExp] [Imp.Exp]
    -- ^ A mapping from dimensions of nested SegOps to already
    -- computed local thread IDs.
  }

-- | Collect the dimension lists of the 'SegOp's occurring in the
-- given statements.  The branches of an 'If' and the body of a
-- 'DoLoop' are searched recursively; a 'SegOp' contributes only the
-- dimensions of its own space (its body is not traversed here), and
-- every other expression contributes nothing.
segOpSizes :: Stms KernelsMem -> S.Set [SubExp]
segOpSizes = onStms
  where onStms :: Stms KernelsMem -> S.Set [SubExp]
        onStms = foldMap (onExp . stmExp)

        onExp :: Exp KernelsMem -> S.Set [SubExp]
        onExp (Op (Inner (SegOp op))) =
          S.singleton $ map snd $ unSegSpace $ segSpace op
        onExp (If _ tbranch fbranch _) =
          onStms (bodyStms tbranch) <> onStms (bodyStms fbranch)
        onExp (DoLoop _ _ _ body) =
          onStms (bodyStms body)
        onExp _ = mempty

-- | Run the given action with 'kernelLocalIdMap' replaced by a map
-- that, for every dimension list found by 'segOpSizes' in the given
-- statements, holds the local thread ID unflattened over those
-- dimensions.  Each component is bound with 'dPrimVE' under the name
-- hint @"ltid_pre"@ before the action runs.
precomputeSegOpIDs :: Stms KernelsMem -> InKernelGen a -> InKernelGen a
precomputeSegOpIDs stms m = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  new_ids <- M.fromList <$> mapM (mkMap ltid) (S.toList (segOpSizes stms))
  -- Note that the previous map is overwritten, not merged.
  let f env = env { kernelConstants =
                      (kernelConstants env) { kernelLocalIdMap = new_ids }
                  }
  localEnv f m
  where -- Pair a dimension list with the per-dimension IDs computed
        -- from the flat local thread ID.
        mkMap ltid dims = do
          dims' <- mapM toExp dims
          ids' <- mapM (dPrimVE "ltid_pre") $ unflattenIndex dims' ltid
          return (dims, ids')

-- | Qualify a key with an optional entry point name.  With @Just f@
-- the result is the name of @f@, a dot, and then the key; with
-- 'Nothing' the result is just the key.
keyWithEntryPoint :: Maybe Name -> Name -> Name
keyWithEntryPoint fname key =
  nameFromString $ maybe "" ((++".") . nameToString) fname ++ nameToString key

-- | Allocation compiler that emits an 'Imp.LocalAlloc' operation for
-- the given memory block name and size.
allocLocal :: AllocCompiler KernelsMem r Imp.KernelOp
allocLocal mem size =
  sOp $ Imp.LocalAlloc mem size

-- | Compile an allocation occurring inside a kernel.
--
-- * A single-element pattern in a 'ScalarSpace' emits nothing.
--
-- * A single-element pattern in @Space "local"@ is allocated with
--   'allocLocal'.
--
-- * A single-element pattern in any other space is reported via
--   'compilerLimitationS'.
--
-- * Any other pattern is rejected with 'error'.
kernelAlloc :: Pattern KernelsMem
            -> SubExp -> Space
            -> InKernelGen ()
kernelAlloc (Pattern _ [_]) _ ScalarSpace{} =
  -- Handled by the declaration of the memory block, which is then
  -- translated to an actual scalar variable during C code generation.
  return ()
kernelAlloc (Pattern _ [mem]) size (Space "local") = do
  size' <- toExp size
  allocLocal (patElemName mem) $ Imp.bytes size'
kernelAlloc (Pattern _ [mem]) _ _ =
  compilerLimitationS $ "Cannot allocate memory block " ++ pretty mem ++ " in kernel."
kernelAlloc dest _ _ =
  error $ "Invalid target for in-kernel allocation: " ++ show dest

-- | Compile a split of @w@ elements for index @i@, with
-- @elems_per_thread@ elements per thread, by delegating to
-- 'computeThreadChunkSize'.  The result is written to the single
-- value element of the pattern; any other pattern shape is rejected
-- with 'error'.
splitSpace :: (ToExp w, ToExp i, ToExp elems_per_thread) =>
              Pattern KernelsMem -> SplitOrdering -> w -> i -> elems_per_thread
           -> ImpM lore r op ()
splitSpace (Pattern [] [size]) o w i elems_per_thread = do
  num_elements <- Imp.elements <$> toExp w
  i' <- toExp i
  elems_per_thread' <- Imp.elements <$> toExp elems_per_thread
  computeThreadChunkSize o i' elems_per_thread' num_elements (patElemName size)
splitSpace pat _ _ _ _ =
  error $ "Invalid target for splitSpace: " ++ pretty pat

-- | Expression compiler for thread-level code.  Array literals bound
-- to a single pattern element are written element by element with
-- 'copyDWIMFix'; everything else goes to 'defCompileExp'.
compileThreadExp :: ExpCompiler KernelsMem KernelEnv Imp.KernelOp
compileThreadExp (Pattern _ [dest]) (BasicOp (ArrayLit es _)) =
  -- One write per literal element, indexed by its position.
  forM_ (zip [0..] es) $ \(i,e) ->
  copyDWIMFix (patElemName dest) [fromIntegral (i::Int32)] e []
compileThreadExp dest e =
  defCompileExp dest e


-- | Assign iterations of a for-loop to all threads in the kernel.
-- The passed-in function is invoked with the (symbolic) iteration.
-- 'threadOperations' will be in effect in the body.  For
-- multidimensional loops, use 'groupCoverSpace'.
--
-- The arguments are, in order: the ID of this thread, the number of
-- participating threads, the number of iterations, and the loop body.
kernelLoop :: Imp.Exp -> Imp.Exp -> Imp.Exp
           -> (Imp.Exp -> InKernelGen ()) -> InKernelGen ()
kernelLoop tid num_threads n f =
  localOps threadOperations $
  if n == num_threads then
    -- The iteration count and thread count are the same expression,
    -- so no loop is emitted and the body runs once with the thread ID.
    f tid
  else do
    -- Compute how many elements this thread is responsible for.
    -- Formula: (n - tid) / num_threads (rounded up).
    let elems_for_this = (n - tid) `divUp` num_threads

    -- Iteration i of this thread handles element i * num_threads + tid.
    sFor "i" elems_for_this $ \i -> f $
      i * num_threads + tid

-- | Assign iterations of a for-loop to threads in the workgroup.  The
-- passed-in function is invoked with the (symbolic) iteration.  For
-- multidimensional loops, use 'groupCoverSpace'.
--
-- This is 'kernelLoop' instantiated with the local thread ID and the
-- group size taken from the current 'KernelConstants'.
groupLoop :: Imp.Exp
          -> (Imp.Exp -> InKernelGen ()) -> InKernelGen ()
groupLoop n f = do
  constants <- kernelConstants <$> askEnv
  kernelLoop (kernelLocalThreadId constants) (kernelGroupSize constants) n f

-- | Iterate collectively through a multidimensional space, such that
-- all threads in the group participate.  The passed-in function is
-- invoked with a (symbolic) point in the index space.
--
-- Implemented as a 'groupLoop' over the product of the dimensions,
-- with each flat iteration turned back into a point by
-- 'unflattenIndex'.
groupCoverSpace :: [Imp.Exp]
                -> ([Imp.Exp] -> InKernelGen ()) -> InKernelGen ()
groupCoverSpace ds f =
  groupLoop (product ds) $ f . unflattenIndex ds

-- | Expression compiler for group-level code.
--
-- * Array literals are written element by element.
--
-- * 'Replicate' and 'Iota' are distributed over the group with
--   'groupCoverSpace' and 'groupLoop' respectively, each followed by
--   an 'Imp.Barrier' with 'Imp.FenceLocal'.
--
-- * Every other expression goes to 'defCompileExp'.
compileGroupExp :: ExpCompiler KernelsMem KernelEnv Imp.KernelOp
-- The static arrays stuff does not work inside kernels.
compileGroupExp (Pattern _ [dest]) (BasicOp (ArrayLit es _)) =
  forM_ (zip [0..] es) $ \(i,e) ->
  copyDWIMFix (patElemName dest) [fromIntegral (i::Int32)] e []
compileGroupExp (Pattern _ [dest]) (BasicOp (Replicate ds se)) = do
  ds' <- mapM toExp $ shapeDims ds
  -- The leading indices select the replica; whatever remains after
  -- dropping the replicate's rank indexes into the source value.
  groupCoverSpace ds' $ \is ->
    copyDWIMFix (patElemName dest) is se (drop (shapeRank ds) is)
  sOp $ Imp.Barrier Imp.FenceLocal
compileGroupExp (Pattern _ [dest]) (BasicOp (Iota n e s _)) = do
  n' <- toExp n
  e' <- toExp e
  s' <- toExp s
  groupLoop n' $ \i' -> do
    -- Element i' holds start + i' * stride.
    x <- dPrimV "x" $ e' + i' * s'
    copyDWIMFix (patElemName dest) [i'] (Var x) []
  sOp $ Imp.Barrier Imp.FenceLocal
compileGroupExp dest e =
  defCompileExp dest e

-- | Accept a thread-level 'SegLevel' and abort with 'error' on a
-- group-level one.
sanityCheckLevel :: SegLevel -> InKernelGen ()
sanityCheckLevel SegThread{} = return ()
sanityCheckLevel SegGroup{} =
  error "compileGroupOp: unexpected group-level SegOp."

-- | The per-dimension local thread IDs for a space with the given
-- dimensions.  If 'kernelLocalIdMap' has an entry for exactly these
-- dimensions, that entry is returned.  Otherwise the IDs are obtained
-- by unflattening the local thread ID over the dimensions.
localThreadIDs :: [SubExp] -> InKernelGen [Imp.Exp]
localThreadIDs dims = do
  -- The environment is read once and used for both the flat thread ID
  -- and the precomputed-ID map.
  constants <- kernelConstants <$> askEnv
  dims' <- mapM toExp dims
  return $ fromMaybe (unflattenIndex dims' (kernelLocalThreadId constants)) $
    M.lookup dims $ kernelLocalIdMap constants

-- | Declare the index variables of a 'SegSpace' for group-level code.
-- Each per-dimension variable is bound to the corresponding result of
-- 'localThreadIDs', and the flat index variable ('segFlat') is bound
-- to the local thread ID.  The level is first checked with
-- 'sanityCheckLevel'.
compileGroupSpace :: SegLevel -> SegSpace -> InKernelGen ()
compileGroupSpace lvl space = do
  sanityCheckLevel lvl
  let (ltids, dims) = unzip $ unSegSpace space
  zipWithM_ dPrimV_ ltids =<< localThreadIDs dims
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  dPrimV_ (segFlat space) ltid

-- Construct the necessary lock arrays for an intra-group histogram.
prepareIntraGroupSegHist :: Count GroupSize SubExp
                         -> [HistOp KernelsMem]
                         -> InKernelGen [[Imp.Exp] -> InKernelGen ()]
prepareIntraGroupSegHist :: Count GroupSize SubExp
-> [HistOp KernelsMem] -> InKernelGen [[Exp] -> InKernelGen ()]
prepareIntraGroupSegHist Count GroupSize SubExp
group_size =
  ((Maybe Locking, [[Exp] -> InKernelGen ()])
 -> [[Exp] -> InKernelGen ()])
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (Maybe Locking, [[Exp] -> InKernelGen ()])
-> InKernelGen [[Exp] -> InKernelGen ()]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Maybe Locking, [[Exp] -> InKernelGen ()])
-> [[Exp] -> InKernelGen ()]
forall a b. (a, b) -> b
snd (ImpM
   KernelsMem
   KernelEnv
   KernelOp
   (Maybe Locking, [[Exp] -> InKernelGen ()])
 -> InKernelGen [[Exp] -> InKernelGen ()])
-> ([HistOp KernelsMem]
    -> ImpM
         KernelsMem
         KernelEnv
         KernelOp
         (Maybe Locking, [[Exp] -> InKernelGen ()]))
-> [HistOp KernelsMem]
-> InKernelGen [[Exp] -> InKernelGen ()]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Maybe Locking
 -> HistOp KernelsMem
 -> ImpM
      KernelsMem
      KernelEnv
      KernelOp
      (Maybe Locking, [Exp] -> InKernelGen ()))
-> Maybe Locking
-> [HistOp KernelsMem]
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (Maybe Locking, [[Exp] -> InKernelGen ()])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM Maybe Locking
-> HistOp KernelsMem
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (Maybe Locking, [Exp] -> InKernelGen ())
onOp Maybe Locking
forall a. Maybe a
Nothing
  where
    onOp :: Maybe Locking
-> HistOp KernelsMem
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (Maybe Locking, [Exp] -> InKernelGen ())
onOp Maybe Locking
l HistOp KernelsMem
op = do

      KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
      AtomicBinOp
atomicBinOp <- KernelEnv -> AtomicBinOp
kernelAtomics (KernelEnv -> AtomicBinOp)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp AtomicBinOp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv

      let local_subhistos :: [VName]
local_subhistos = HistOp KernelsMem -> [VName]
forall lore. HistOp lore -> [VName]
histDest HistOp KernelsMem
op

      case (Maybe Locking
l, AtomicBinOp
-> Lambda KernelsMem -> AtomicUpdate KernelsMem KernelEnv
atomicUpdateLocking AtomicBinOp
atomicBinOp (Lambda KernelsMem -> AtomicUpdate KernelsMem KernelEnv)
-> Lambda KernelsMem -> AtomicUpdate KernelsMem KernelEnv
forall a b. (a -> b) -> a -> b
$ HistOp KernelsMem -> Lambda KernelsMem
forall lore. HistOp lore -> Lambda lore
histOp HistOp KernelsMem
op) of
        (Maybe Locking
_, AtomicPrim DoAtomicUpdate KernelsMem KernelEnv
f) -> (Maybe Locking, [Exp] -> InKernelGen ())
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (Maybe Locking, [Exp] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe Locking
l, DoAtomicUpdate KernelsMem KernelEnv
f (String -> Space
Space String
"local") [VName]
local_subhistos)
        (Maybe Locking
_, AtomicCAS DoAtomicUpdate KernelsMem KernelEnv
f) -> (Maybe Locking, [Exp] -> InKernelGen ())
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (Maybe Locking, [Exp] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe Locking
l, DoAtomicUpdate KernelsMem KernelEnv
f (String -> Space
Space String
"local") [VName]
local_subhistos)
        (Just Locking
l', AtomicLocking Locking -> DoAtomicUpdate KernelsMem KernelEnv
f) -> (Maybe Locking, [Exp] -> InKernelGen ())
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (Maybe Locking, [Exp] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe Locking
l, Locking -> DoAtomicUpdate KernelsMem KernelEnv
f Locking
l' (String -> Space
Space String
"local") [VName]
local_subhistos)
        (Maybe Locking
Nothing, AtomicLocking Locking -> DoAtomicUpdate KernelsMem KernelEnv
f) -> do
          VName
locks <- String -> ImpM KernelsMem KernelEnv KernelOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"locks"
          Exp
num_locks <- SubExp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp (SubExp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> SubExp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size

          let dims :: [Exp]
dims = (SubExp -> Exp) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) ([SubExp] -> [Exp]) -> [SubExp] -> [Exp]
forall a b. (a -> b) -> a -> b
$
                     Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims (HistOp KernelsMem -> Shape
forall lore. HistOp lore -> Shape
histShape HistOp KernelsMem
op) [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++
                     [HistOp KernelsMem -> SubExp
forall lore. HistOp lore -> SubExp
histWidth HistOp KernelsMem
op]
              l' :: Locking
l' = VName -> Exp -> Exp -> Exp -> ([Exp] -> [Exp]) -> Locking
Locking VName
locks Exp
0 Exp
1 Exp
0 (Exp -> [Exp]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp -> [Exp]) -> ([Exp] -> Exp) -> [Exp] -> [Exp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Exp
num_locks) (Exp -> Exp) -> ([Exp] -> Exp) -> [Exp] -> Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Exp] -> [Exp] -> Exp
forall num. IntegralExp num => [num] -> [num] -> num
flattenIndex [Exp]
dims)
              locks_t :: Type
locks_t = PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int32 ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size]) NoUniqueness
NoUniqueness

          VName
locks_mem <- String
-> Count Bytes Exp
-> Space
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String -> Count Bytes Exp -> Space -> ImpM lore r op VName
sAlloc String
"locks_mem" (Type -> Count Bytes Exp
typeSize Type
locks_t) (Space -> ImpM KernelsMem KernelEnv KernelOp VName)
-> Space -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"
          VName -> PrimType -> Shape -> MemBind -> InKernelGen ()
forall lore r op.
VName -> PrimType -> Shape -> MemBind -> ImpM lore r op ()
dArray VName
locks PrimType
int32 (Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
locks_t) (MemBind -> InKernelGen ()) -> MemBind -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            VName -> IxFun -> MemBind
ArrayIn VName
locks_mem (IxFun -> MemBind) -> IxFun -> MemBind
forall a b. (a -> b) -> a -> b
$ Shape (PrimExp VName) -> IxFun
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota (Shape (PrimExp VName) -> IxFun) -> Shape (PrimExp VName) -> IxFun
forall a b. (a -> b) -> a -> b
$
            (SubExp -> PrimExp VName) -> [SubExp] -> Shape (PrimExp VName)
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
int32) ([SubExp] -> Shape (PrimExp VName))
-> [SubExp] -> Shape (PrimExp VName)
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
locks_t

          String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"All locks start out unlocked" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [Exp] -> ([Exp] -> InKernelGen ()) -> InKernelGen ()
groupCoverSpace [KernelConstants -> Exp
kernelGroupSize KernelConstants
constants] (([Exp] -> InKernelGen ()) -> InKernelGen ())
-> ([Exp] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[Exp]
is ->
            VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
locks [Exp]
is (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) []

          (Maybe Locking, [Exp] -> InKernelGen ())
-> ImpM
     KernelsMem
     KernelEnv
     KernelOp
     (Maybe Locking, [Exp] -> InKernelGen ())
forall (m :: * -> *) a. Monad m => a -> m a
return (Locking -> Maybe Locking
forall a. a -> Maybe a
Just Locking
l', Locking -> DoAtomicUpdate KernelsMem KernelEnv
f Locking
l' (String -> Space
Space String
"local") [VName]
local_subhistos)

-- | Conditionally guard an in-kernel action.
--
-- If the virtualisation mode of the level ('segVirt') is
-- 'SegNoVirtFull', the action is emitted as-is.  Otherwise it is
-- wrapped in an 'sWhen' whose condition is 'isActive' applied to the
-- (index, dimension) pairs of the segment space.
whenActive :: SegLevel -> SegSpace -> InKernelGen () -> InKernelGen ()
whenActive lvl space m
  | SegNoVirtFull <- segVirt lvl = m
  | otherwise = sWhen (isActive $ unSegSpace space) m

-- | Operation compiler used inside kernels; it dispatches on the form
-- of the operation ('Alloc', size operations, and segmented
-- operations), one clause per case.
compileGroupOp :: OpCompiler KernelsMem KernelEnv Imp.KernelOp

-- Allocation: handed straight to 'kernelAlloc'.
compileGroupOp pat (Alloc size space) =
  kernelAlloc pat size space

-- 'SplitSpace' size operation: handed straight to 'splitSpace' with
-- the ordering, total width, index and per-thread element count.
compileGroupOp pat (Inner (SizeOp (SplitSpace o w i elems_per_thread))) =
  splitSpace pat o w i elems_per_thread

-- 'SegMap': set up the segment space, then (guarded by 'whenActive')
-- compile the body statements under 'threadOperations' and write each
-- kernel result to its pattern element via 'compileThreadResult'.
compileGroupOp pat (Inner (SegOp (SegMap lvl space _ body))) = do
  -- 'compileGroupSpace' already returns @InKernelGen ()@, so no 'void'
  -- is needed (this matches the 'SegScan' and 'SegRed' clauses).
  compileGroupSpace lvl space

  whenActive lvl space $ localOps threadOperations $
    compileStms mempty (kernelBodyStms body) $
    zipWithM_ (compileThreadResult space) (patternElements pat) $
    kernelBodyResult body

  -- Barrier with a local fence after the results have been written.
  sOp $ Imp.ErrorSync Imp.FenceLocal

-- 'SegScan': every active thread first writes its body results into
-- the pattern arrays at its own index; after a barrier the arrays are
-- viewed as flat and scanned in place with 'groupScan', using
-- @crossesSegment@ so that the scan does not carry across segments of
-- the innermost dimension.
compileGroupOp pat (Inner (SegOp (SegScan lvl space scans _ body))) = do
  compileGroupSpace lvl space
  let (ltids, dims) = unzip $ unSegSpace space
  dims' <- mapM toExp dims

  -- Stage 1: compute the per-thread values and store them at the
  -- position given by the thread's indices in the segment space.
  whenActive lvl space $
    compileStms mempty (kernelBodyStms body) $
    forM_ (zip (patternNames pat) $ kernelBodyResult body) $ \(dest, res) ->
    copyDWIMFix dest
    (map (`Imp.var` int32) ltids)
    (kernelResultSubExp res) []

  sOp $ Imp.ErrorSync Imp.FenceLocal

  -- A segment is one row of the innermost dimension; an operand pair
  -- (from, to) crosses a segment boundary when their distance exceeds
  -- the offset of @to@ within its segment.
  let segment_size = last dims'
      crossesSegment from to = (to-from) .>. (to `rem` segment_size)

  -- groupScan needs to treat the scan output as a one-dimensional
  -- array of scan elements, so we invent some new flattened arrays
  -- here.  XXX: this assumes that the original index function is just
  -- row-major, but does not actually verify it.
  dims_flat <- dPrimV "dims_flat" $ product dims'
  let flattened pe = do
        MemLocation mem _ _ <-
          entryArrayLocation <$> lookupArray (patElemName pe)
        let pe_t = typeOf pe
            arr_dims = Var dims_flat : drop (length dims') (arrayDims pe_t)
        sArray (baseString (patElemName pe) ++ "_flat")
          (elemType pe_t) (Shape arr_dims) $
          ArrayIn mem $ IxFun.iota $ map (primExpFromSubExp int32) arr_dims

      -- Number of result arrays belonging to each scan operator.
      results_per_scan = map (length . segBinOpNeutral) scans

  arrs_flat <-
    mapM flattened $ take (sum results_per_scan) $ patternElements pat

  -- Stage 2: run one group-level scan per operator, each over exactly
  -- the arrays that belong to it (same scheme as the per-operator
  -- temporaries in the 'SegRed' clause).  With a single operator this
  -- is the whole of @arrs_flat@.
  forM_ (zip scans $ chunks results_per_scan arrs_flat) $ \(scan, scan_arrs) ->
    groupScan (Just crossesSegment) (product dims') (product dims')
      (segBinOpLambda scan) scan_arrs

compileGroupOp Pattern KernelsMem
pat (Inner (SegOp (SegRed lvl space ops _ body))) = do
  SegLevel -> SegSpace -> InKernelGen ()
compileGroupSpace SegLevel
lvl SegSpace
space

  let ([VName]
ltids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      ([PatElemT LetDecMem]
red_pes, [PatElemT LetDecMem]
map_pes) =
        Int
-> [PatElemT LetDecMem]
-> ([PatElemT LetDecMem], [PatElemT LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp KernelsMem] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp KernelsMem]
ops) ([PatElemT LetDecMem]
 -> ([PatElemT LetDecMem], [PatElemT LetDecMem]))
-> [PatElemT LetDecMem]
-> ([PatElemT LetDecMem], [PatElemT LetDecMem])
forall a b. (a -> b) -> a -> b
$ PatternT LetDecMem -> [PatElemT LetDecMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern KernelsMem
PatternT LetDecMem
pat

  [Exp]
dims' <- (SubExp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> [SubExp] -> ImpM KernelsMem KernelEnv KernelOp [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [SubExp]
dims

  let mkTempArr :: Type -> ImpM KernelsMem KernelEnv KernelOp VName
mkTempArr Type
t =
        String
-> PrimType
-> Shape
-> Space
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String -> PrimType -> Shape -> Space -> ImpM lore r op VName
sAllocArray String
"red_arr" (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims Shape -> Shape -> Shape
forall a. Semigroup a => a -> a -> a
<> Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t) (Space -> ImpM KernelsMem KernelEnv KernelOp VName)
-> Space -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"
  [VName]
tmp_arrs <- (Type -> ImpM KernelsMem KernelEnv KernelOp VName)
-> [Type] -> ImpM KernelsMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Type -> ImpM KernelsMem KernelEnv KernelOp VName
mkTempArr ([Type] -> ImpM KernelsMem KernelEnv KernelOp [VName])
-> [Type] -> ImpM KernelsMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ (SegBinOp KernelsMem -> [Type]) -> [SegBinOp KernelsMem] -> [Type]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Lambda KernelsMem -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType (Lambda KernelsMem -> [Type])
-> (SegBinOp KernelsMem -> Lambda KernelsMem)
-> SegBinOp KernelsMem
-> [Type]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda) [SegBinOp KernelsMem]
ops
  let tmps_for_ops :: [[VName]]
tmps_for_ops = [Int] -> [VName] -> [[VName]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp KernelsMem -> Int) -> [SegBinOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp KernelsMem -> [SubExp]) -> SegBinOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral) [SegBinOp KernelsMem]
ops) [VName]
tmp_arrs

  SegLevel -> SegSpace -> InKernelGen () -> InKernelGen ()
whenActive SegLevel
lvl SegSpace
space (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
body) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    let ([KernelResult]
red_res, [KernelResult]
map_res) =
          Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp KernelsMem] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp KernelsMem]
ops) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
body
    [(VName, KernelResult)]
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [KernelResult] -> [(VName, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
tmp_arrs [KernelResult]
red_res) (((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
      VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
dest ((VName -> Exp) -> [VName] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> PrimType -> Exp
`Imp.var` PrimType
int32) [VName]
ltids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) []
    (PatElemT LetDecMem -> KernelResult -> InKernelGen ())
-> [PatElemT LetDecMem] -> [KernelResult] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace
-> PatElemT (LetDec KernelsMem) -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) [PatElemT LetDecMem]
map_pes [KernelResult]
map_res

  KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

  case [Exp]
dims' of
    -- Nonsegmented case (or rather, a single segment) - this we can
    -- handle directly with a group-level reduction.
    [Exp
dim'] -> do
      [(SegBinOp KernelsMem, [VName])]
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp KernelsMem]
-> [[VName]] -> [(SegBinOp KernelsMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp KernelsMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
 -> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp KernelsMem
op, [VName]
tmps) ->
        Exp -> Lambda KernelsMem -> [VName] -> InKernelGen ()
groupReduce Exp
dim' (SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
op) [VName]
tmps

      KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

      [(PatElemT LetDecMem, VName)]
-> ((PatElemT LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [VName] -> [(PatElemT LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetDecMem]
red_pes [VName]
tmp_arrs) (((PatElemT LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElemT LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, VName
arr) ->
        VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe) [] (VName -> SubExp
Var VName
arr) [Exp
0]

    [Exp]
_ -> do
      -- Segmented intra-group reductions are turned into (regular)
      -- segmented scans.  It is possible that this can be done
      -- better, but at least this approach is simple.

      -- groupScan operates on flattened arrays.  This does not
      -- involve copying anything; merely playing with the index
      -- function.
      VName
dims_flat <- String -> Exp -> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op. String -> Exp -> ImpM lore r op VName
dPrimV String
"dims_flat" (Exp -> ImpM KernelsMem KernelEnv KernelOp VName)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
dims'
      let flatten :: VName -> ImpM KernelsMem KernelEnv KernelOp VName
flatten VName
arr = do
            ArrayEntry MemLocation
arr_loc PrimType
pt <- VName -> ImpM KernelsMem KernelEnv KernelOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray VName
arr
            let flat_shape :: Shape
flat_shape = [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
dims_flat SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:
                             Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
ltids) (MemLocation -> [SubExp]
memLocationShape MemLocation
arr_loc)
            String
-> PrimType
-> Shape
-> MemBind
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String -> PrimType -> Shape -> MemBind -> ImpM lore r op VName
sArray String
"red_arr_flat" PrimType
pt Shape
flat_shape (MemBind -> ImpM KernelsMem KernelEnv KernelOp VName)
-> MemBind -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
              VName -> IxFun -> MemBind
ArrayIn (MemLocation -> VName
memLocationName MemLocation
arr_loc) (IxFun -> MemBind) -> IxFun -> MemBind
forall a b. (a -> b) -> a -> b
$
              Shape (PrimExp VName) -> IxFun
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota (Shape (PrimExp VName) -> IxFun) -> Shape (PrimExp VName) -> IxFun
forall a b. (a -> b) -> a -> b
$ (SubExp -> PrimExp VName) -> [SubExp] -> Shape (PrimExp VName)
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
int32) ([SubExp] -> Shape (PrimExp VName))
-> [SubExp] -> Shape (PrimExp VName)
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
flat_shape

      let segment_size :: Exp
segment_size = [Exp] -> Exp
forall a. [a] -> a
last [Exp]
dims'
          crossesSegment :: Exp -> Exp -> Exp
crossesSegment Exp
from Exp
to = (Exp
toExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
from) Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.>. (Exp
to Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`rem` Exp
segment_size)

      [(SegBinOp KernelsMem, [VName])]
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp KernelsMem]
-> [[VName]] -> [(SegBinOp KernelsMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp KernelsMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
 -> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp KernelsMem
op, [VName]
tmps) -> do
        [VName]
tmps_flat <- (VName -> ImpM KernelsMem KernelEnv KernelOp VName)
-> [VName] -> ImpM KernelsMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ImpM KernelsMem KernelEnv KernelOp VName
flatten [VName]
tmps
        Maybe (Exp -> Exp -> Exp)
-> Exp -> Exp -> Lambda KernelsMem -> [VName] -> InKernelGen ()
groupScan ((Exp -> Exp -> Exp) -> Maybe (Exp -> Exp -> Exp)
forall a. a -> Maybe a
Just Exp -> Exp -> Exp
crossesSegment) ([Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
dims') ([Exp] -> Exp
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [Exp]
dims')
          (SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
op) [VName]
tmps_flat

      KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

      [(PatElemT LetDecMem, VName)]
-> ((PatElemT LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LetDecMem] -> [VName] -> [(PatElemT LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LetDecMem]
red_pes [VName]
tmp_arrs) (((PatElemT LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElemT LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LetDecMem
pe, VName
arr) ->
        VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (PatElemT LetDecMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LetDecMem
pe) [] (VName -> SubExp
Var VName
arr)
        ((Exp -> DimIndex Exp) -> [Exp] -> [DimIndex Exp]
forall a b. (a -> b) -> [a] -> [b]
map (Exp -> Exp -> DimIndex Exp
forall d. Num d => d -> d -> DimIndex d
unitSlice Exp
0) ([Exp] -> [Exp]
forall a. [a] -> [a]
init [Exp]
dims') [DimIndex Exp] -> [DimIndex Exp] -> [DimIndex Exp]
forall a. [a] -> [a] -> [a]
++ [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
forall a. [a] -> a
last [Exp]
dims'Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
1])

      KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

-- Intra-group histogram: every active thread computes its bucket
-- index and values by running the kernel body, and then applies the
-- per-operator update obtained from 'prepareIntraGroupSegHist'.
compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
  compileGroupSpace lvl space

  -- The pattern elements for the histogram results themselves are
  -- never written here, because it is guaranteed by our type rules
  -- that they occupy the same memory as the destinations for the ops.
  -- Only the trailing map-style pattern elements are needed.
  let num_hist_res = length ops + sum (map (length . histNeutral) ops)
      map_pes = drop num_hist_res (patternElements pat)
      -- All thread indices of the space except the innermost one.
      outer_ltids = init (map fst (unSegSpace space))

  updates <- prepareIntraGroupSegHist (segGroupSize lvl) ops

  -- Ensure that all locks have been initialised.
  sOp (Imp.Barrier Imp.FenceLocal)

  whenActive lvl space $
    compileStms mempty (kernelBodyStms kbody) $ do
      -- The kernel body returns one bucket index per operator, then
      -- the values for all operators, then the map-style results.
      let (hist_res, map_res) = splitAt num_hist_res (kernelBodyResult kbody)
          (bins, vals) = splitAt (length ops) (map kernelResultSubExp hist_res)
          vals_per_op = chunks (map (length . histDest) ops) vals

      zipWithM_ (compileThreadResult space) map_pes map_res

      forM_ (zip4 bins vals_per_op updates ops) $
        \(bin, op_vals, update, HistOp dest_w _ _ _ shape lam) -> do
          let bin' = toExp' int32 bin
              dest_w' = toExp' int32 dest_w
              in_bounds = 0 .<=. bin' .&&. bin' .<. dest_w'
              -- The bucket position replaces the innermost thread index.
              bucket_prefix = map (`Imp.var` int32) outer_ltids ++ [bin']
              -- The value parameters are the last parameters of the
              -- operator lambda.
              val_params = takeLast (length op_vals) (lambdaParams lam)

          sComment "perform atomic updates" $
            sWhen in_bounds $ do
              dLParams (lambdaParams lam)
              sLoopNest shape $ \is -> do
                zipWithM_
                  (\p v -> copyDWIMFix (paramName p) [] v is)
                  val_params
                  op_vals
                update (bucket_prefix ++ is)

  sOp (Imp.ErrorSync Imp.FenceLocal)

-- Any operation not handled by an earlier equation is reported as a
-- compiler bug, naming the pattern that was being bound.
compileGroupOp pat _ =
  compilerBugS
    ("compileGroupOp: cannot compile rhs of binding " ++ pretty pat)

-- | Compile an operation occurring at thread level.  Only allocations
-- and 'SplitSpace' size operations are handled; anything else is
-- reported as a compiler bug that names the pattern being bound.
compileThreadOp :: OpCompiler KernelsMem KernelEnv Imp.KernelOp
compileThreadOp pat op =
  case op of
    Alloc size space ->
      kernelAlloc pat size space
    Inner (SizeOp (SplitSpace o w i elems_per_thread)) ->
      splitSpace pat o w i elems_per_thread
    _ ->
      compilerBugS
        ("compileThreadOp: cannot compile rhs of binding " ++ pretty pat)

-- | Locking strategy used for an atomic update.
data Locking = Locking
  { -- | Array containing the lock.
    lockingArray :: VName
  , -- | Value for us to consider the lock free.
    lockingIsUnlocked :: Imp.Exp
  , -- | What to write when we lock it.
    lockingToLock :: Imp.Exp
  , -- | What to write when we unlock it.
    lockingToUnlock :: Imp.Exp
  , -- | A transformation from the logical lock index to the
    -- physical position in the array.  This can also be used
    -- to make the lock array smaller.
    lockingMapping :: [Imp.Exp] -> [Imp.Exp]
  }

-- | A function for generating code for an atomic update.  Assumes
-- that the bucket is in-bounds.
--
-- The arguments are, in order: the memory space handed to the
-- generated atomic operations, the arrays to update (paired
-- positionally with the components of the operator), and the index of
-- the bucket to update in each of those arrays.
type DoAtomicUpdate lore r =
  Space -> [VName] -> [Imp.Exp] -> ImpM lore r Imp.KernelOp ()

-- | The mechanism that will be used for performing the atomic update.
-- Approximates how efficient it will be.  Ordered from most to least
-- efficient.
--
-- Values of this type are produced by 'atomicUpdateLocking'.
data AtomicUpdate lore r
  = AtomicPrim (DoAtomicUpdate lore r)
    -- ^ Supported directly by primitive.  Chosen when every component
    -- of the operator has a corresponding 'AtomicBinOp'.
  | AtomicCAS (DoAtomicUpdate lore r)
    -- ^ Can be done by efficient swaps.
  | AtomicLocking (Locking -> DoAtomicUpdate lore r)
    -- ^ Requires explicit locking.  The 'Locking' describes the lock
    -- array to use and how to index it.

-- | Is there an atomic t'BinOp' corresponding to this t'BinOp'?
--
-- When there is, the returned function builds the atomic operation.
-- As called from 'atomicUpdateLocking', its arguments are: a freshly
-- declared variable (named @old@ there; presumably it receives the
-- previous contents of the cell - confirm against the definition of
-- 'Imp.AtomicOp'), the memory name and element offset obtained by
-- fully indexing the destination array, and the operand expression.
type AtomicBinOp =
  BinOp ->
  Maybe (VName -> VName -> Count Imp.Elements Imp.Exp -> Imp.Exp -> Imp.AtomicOp)

-- | Do an atomic update corresponding to a binary operator lambda.
--
-- The equations are tried in order, and the first one whose guards
-- succeed decides which 'AtomicUpdate' mechanism is returned.
atomicUpdateLocking ::
  AtomicBinOp ->
  Lambda KernelsMem ->
  AtomicUpdate KernelsMem KernelEnv
-- If the operator is a vectorised binary operator on 32-bit values,
-- we can use a particularly efficient implementation. If the
-- operator has an atomic implementation we use that, otherwise it
-- is still a binary operator which can be implemented by atomic
-- compare-and-swap if 32 bits.
atomicUpdateLocking atomicBinOp lam
  | Just binops <- splitOp lam,
    all is32Bit binops =
    classify binops $ \space arrs bucket ->
      forM_ (zip arrs binops) (updateElem space bucket)
  where
    is32Bit (_, t, _, _) = primBitSize t == 32

    -- Does this component have a primitive atomic counterpart?
    hasPrimitive (bop, _, _, _) = isJust (atomicBinOp bop)

    -- The update counts as primitive only when every component has a
    -- primitive atomic; a single exception makes it a CAS update.
    classify components
      | all hasPrimitive components = AtomicPrim
      | otherwise = AtomicCAS

    -- Generate the update of one destination array with one component
    -- of the operator.
    updateElem space bucket (arr, (bop, t, x, y)) = do
      -- Common variables.
      old <- dPrim "old" t
      (mem, _mem_space, offset) <- fullyIndexArray arr bucket

      case atomicBinOp bop of
        Just mkAtomic ->
          sOp (Imp.Atomic space (mkAtomic old mem offset (Imp.var y t)))
        Nothing ->
          atomicUpdateCAS space t arr old bucket x $
            x <-- Imp.BinOpExp bop (Imp.var x t) (Imp.var y t)

-- If the operator functions purely on single 32-bit values, we can
-- use an implementation based on CAS, no matter what the operator
-- does.
atomicUpdateLocking _ op
  | [Prim t] <- lambdaReturnType op,
    [xp, _] <- lambdaParams op,
    primBitSize t == 32 =
    let acc = paramName xp
        -- Runs the operator body, storing its result in the first
        -- parameter.
        combine = compileBody' [xp] (lambdaBody op)
     in -- Exactly one destination array is expected here; the lambda
        -- pattern is deliberately as strict as the guards above.
        AtomicCAS $ \space [arr] bucket -> do
          old <- dPrim "old" t
          atomicUpdateCAS space t arr old bucket acc combine

atomicUpdateLocking AtomicBinOp
_ Lambda KernelsMem
op = (Locking -> DoAtomicUpdate KernelsMem KernelEnv)
-> AtomicUpdate KernelsMem KernelEnv
forall lore r.
(Locking -> DoAtomicUpdate lore r) -> AtomicUpdate lore r
AtomicLocking ((Locking -> DoAtomicUpdate KernelsMem KernelEnv)
 -> AtomicUpdate KernelsMem KernelEnv)
-> (Locking -> DoAtomicUpdate KernelsMem KernelEnv)
-> AtomicUpdate KernelsMem KernelEnv
forall a b. (a -> b) -> a -> b
$ \Locking
locking Space
space [VName]
arrs [Exp]
bucket -> do
  VName
old <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"old" PrimType
int32
  VName
continue <- String -> ImpM KernelsMem KernelEnv KernelOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"continue"
  VName -> PrimType -> InKernelGen ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrimVol_ VName
continue PrimType
Bool
  VName
continue VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<-- Exp
forall v. PrimExp v
true

  -- Correctly index into locks.
  (VName
locks', Space
_locks_space, Count Elements Exp
locks_offset) <-
    VName
-> [Exp]
-> ImpM
     KernelsMem KernelEnv KernelOp (VName, Space, Count Elements Exp)
forall lore r op.
VName -> [Exp] -> ImpM lore r op (VName, Space, Count Elements Exp)
fullyIndexArray (Locking -> VName
lockingArray Locking
locking) ([Exp]
 -> ImpM
      KernelsMem KernelEnv KernelOp (VName, Space, Count Elements Exp))
-> [Exp]
-> ImpM
     KernelsMem KernelEnv KernelOp (VName, Space, Count Elements Exp)
forall a b. (a -> b) -> a -> b
$ Locking -> [Exp] -> [Exp]
lockingMapping Locking
locking [Exp]
bucket

  -- Critical section
  let try_acquire_lock :: InKernelGen ()
try_acquire_lock =
        KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Space -> AtomicOp -> KernelOp
Imp.Atomic Space
space (AtomicOp -> KernelOp) -> AtomicOp -> KernelOp
forall a b. (a -> b) -> a -> b
$
        PrimType
-> VName -> VName -> Count Elements Exp -> Exp -> Exp -> AtomicOp
Imp.AtomicCmpXchg PrimType
int32 VName
old VName
locks' Count Elements Exp
locks_offset
        (Locking -> Exp
lockingIsUnlocked Locking
locking) (Locking -> Exp
lockingToLock Locking
locking)
      lock_acquired :: Exp
lock_acquired = VName -> PrimType -> Exp
Imp.var VName
old PrimType
int32 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Locking -> Exp
lockingIsUnlocked Locking
locking
      -- Even the releasing is done with an atomic rather than a
      -- simple write, for memory coherency reasons.
      release_lock :: InKernelGen ()
release_lock =
        KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Space -> AtomicOp -> KernelOp
Imp.Atomic Space
space (AtomicOp -> KernelOp) -> AtomicOp -> KernelOp
forall a b. (a -> b) -> a -> b
$
        PrimType
-> VName -> VName -> Count Elements Exp -> Exp -> Exp -> AtomicOp
Imp.AtomicCmpXchg PrimType
int32 VName
old VName
locks' Count Elements Exp
locks_offset
        (Locking -> Exp
lockingToLock Locking
locking) (Locking -> Exp
lockingToUnlock Locking
locking)
      break_loop :: InKernelGen ()
break_loop = VName
continue VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<-- Exp
forall v. PrimExp v
false

  -- Preparing parameters. It is assumed that the caller has already
  -- filled the arr_params. We copy the current value to the
  -- accumulator parameters.
  --
  -- Note the use of 'everythingVolatile' when reading and writing the
  -- buckets.  This was necessary to ensure correct execution on a
  -- newer NVIDIA GPU (RTX 2080).  The 'volatile' modifiers likely
  -- make the writes pass through the (SM-local) L1 cache, which is
  -- necessary here, because we are really doing device-wide
  -- synchronisation without atomics (naughty!).
  let ([Param LetDecMem]
acc_params, [Param LetDecMem]
_arr_params) = Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
op
      bind_acc_params :: InKernelGen ()
bind_acc_params =
        InKernelGen () -> InKernelGen ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"bind lhs" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
acc_params [VName]
arrs) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
acc_p, VName
arr) ->
        VName -> [Exp] -> SubExp -> [Exp] -> InKernelGen ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
acc_p) [] (VName -> SubExp
Var VName
arr) [Exp]
bucket

  let op_body :: InKernelGen ()
op_body = String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"execute operation" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
                [Param LetDecMem] -> BodyT KernelsMem -> InKernelGen ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param LetDecMem]
acc_params (BodyT KernelsMem -> InKernelGen ())
-> BodyT KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
op

      do_hist :: InKernelGen ()
do_hist =
        InKernelGen () -> InKernelGen ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"update global result" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        (VName -> SubExp -> InKernelGen ())
-> [VName] -> [SubExp] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ ([Exp] -> VName -> SubExp -> InKernelGen ()
forall lore r op. [Exp] -> VName -> SubExp -> ImpM lore r op ()
writeArray [Exp]
bucket) [VName]
arrs ([SubExp] -> InKernelGen ()) -> [SubExp] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ (Param LetDecMem -> SubExp) -> [Param LetDecMem] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (Param LetDecMem -> VName) -> Param LetDecMem -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName) [Param LetDecMem]
acc_params

      fence :: InKernelGen ()
fence = case Space
space of Space String
"local" -> KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceLocal
                            Space
_             -> KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.MemFence Fence
Imp.FenceGlobal


  -- While-loop: Try to insert your value
  Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhile (VName -> PrimType -> Exp
Imp.var VName
continue PrimType
Bool) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    InKernelGen ()
try_acquire_lock
    Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
lock_acquired (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
      [LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams [LParam KernelsMem]
[Param LetDecMem]
acc_params
      InKernelGen ()
bind_acc_params
      InKernelGen ()
op_body
      InKernelGen ()
do_hist
      InKernelGen ()
fence
      InKernelGen ()
release_lock
      InKernelGen ()
break_loop
    InKernelGen ()
fence
  where writeArray :: [Exp] -> VName -> SubExp -> ImpM lore r op ()
writeArray [Exp]
bucket VName
arr SubExp
val = VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
forall lore r op.
VName -> [Exp] -> SubExp -> [Exp] -> ImpM lore r op ()
copyDWIMFix VName
arr [Exp]
bucket SubExp
val []

-- | Emit code that updates a single array element through a
-- compare-and-swap retry loop.
--
-- The element of @arr@ at index @bucket@ is first read into @old@
-- (as a volatile read).  Each iteration of the generated loop copies
-- @old@ into a fresh variable @assumed@ and into @x@, runs @do_op@,
-- and then issues an 'Imp.AtomicCmpXchg' on the element that compares
-- against @assumed@ and stores @x@.  The value produced by the
-- exchange is written back to @old@, and the loop stops once that
-- value equals (bitwise) what was assumed.
--
-- Values of type 'Float32' are passed through @to_bits32@ /
-- @from_bits32@ around the exchange; every other type is used as-is.
--
-- NOTE(review): the exchange is always emitted at type @int32@, so
-- this presumably expects @t@ to be a 32-bit type, and @do_op@ is
-- presumably expected to leave the combined value in @x@ - confirm
-- against the callers.
atomicUpdateCAS :: Space -> PrimType
                -> VName -> VName
                -> [Imp.Exp] -> VName
                -> InKernelGen ()
                -> InKernelGen ()
atomicUpdateCAS space t arr old bucket x do_op = do
  -- Code generation target:
  --
  -- old = d_his[idx];
  -- do {
  --   assumed = old;
  --   x = do_op(assumed, y);
  --   old = atomicCAS(&d_his[idx], assumed, x);
  -- } while(assumed != old);
  assumed <- dPrim "assumed" t
  run_loop <- dPrimV "run_loop" 1

  -- XXX: CUDA may generate really bad code if this is not a volatile
  -- read.  Unclear why.  The later reads are volatile, so maybe
  -- that's it.
  everythingVolatile $ copyDWIMFix old [] (Var arr) bucket

  (arr', _a_space, bucket_offset) <- fullyIndexArray arr bucket

  -- Conversions applied to the operands and result of the exchange,
  -- which is always performed on int32 values.
  let (toBits, fromBits) =
        case t of
          FloatType Float32 ->
            ( \v -> Imp.FunExp "to_bits32" [v] int32
            , \v -> Imp.FunExp "from_bits32" [v] t
            )
          _ -> (id, id)

  -- While-loop: Try to insert your value
  sWhile (Imp.var run_loop int32) $ do
    assumed <-- Imp.var old t
    x <-- Imp.var assumed t
    do_op
    old_bits <- dPrim "old_bits" int32
    sOp $ Imp.Atomic space $
      Imp.AtomicCmpXchg int32 old_bits arr' bucket_offset
      (toBits (Imp.var assumed t)) (toBits (Imp.var x t))
    old <-- fromBits (Imp.var old_bits int32)
    -- The exchange took effect iff the element still held the assumed
    -- value; only then do we leave the loop.
    sWhen (toBits (Imp.var assumed t) .==. Imp.var old_bits int32)
      (run_loop <-- 0)

-- | Horizontally fission a lambda that models a binary operator.
--
-- Returns one @(operator, type, x, y)@ tuple per result of the lambda
-- body, or 'Nothing' if any result does not have the required shape.
-- With @n@ the number of return types of the lambda and @i@ the
-- position of the result in the body result list, a result is
-- accepted only when:
--
-- * it is a variable;
--
-- * the body contains a statement binding exactly that name, with an
--   empty context pattern and a single value pattern element, whose
--   expression is a 'BinOp' applied to two variables;
--
-- * those two variables are the names of the @i@th and @(n+i)@th
--   lambda parameters, respectively;
--
-- * the bound pattern element has primitive type.
splitOp :: ASTLore lore => Lambda lore -> Maybe [(BinOp, PrimType, VName, VName)]
splitOp lam = mapM splitStm $ bodyResult $ lambdaBody lam
  where n = length $ lambdaReturnType lam
        splitStm (Var res) = do
          -- A failed pattern match here yields Nothing in the Maybe monad.
          Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
            find (([res] ==) . patternNames . stmPattern) $
            stmsToList $ bodyStms $ lambdaBody lam
          i <- Var res `elemIndex` bodyResult (lambdaBody lam)
          xp <- maybeNth i $ lambdaParams lam
          yp <- maybeNth (n + i) $ lambdaParams lam
          guard $ paramName xp == x
          guard $ paramName yp == y
          Prim t <- Just $ patElemType pe
          return (op, t, paramName xp, paramName yp)
        splitStm _ = Nothing

-- | Compute the list of 'Imp.KernelUse's for a kernel: the free
-- variables of @kernel_body@, minus the names in @bound_in_kernel@,
-- classified by 'readsFromSet' and with duplicates removed.
computeKernelUses :: FreeIn a =>
                     a -> [VName]
                  -> CallKernelGen [Imp.KernelUse]
computeKernelUses kernel_body bound_in_kernel = do
  let actually_free =
        freeIn kernel_body `namesSubtract` namesFromList bound_in_kernel
  -- Compute the variables that we need to pass to the kernel.
  nub <$> readsFromSet actually_free

-- | Classify each name in @free@ by its type and produce the
-- corresponding 'Imp.KernelUse', if any:
--
-- * arrays and memory blocks in the @"local"@ space produce nothing;
--
-- * any other memory block produces an 'Imp.MemoryUse';
--
-- * a primitive produces an 'Imp.ConstUse' when 'isConstExp' can
--   express it as a kernel constant; otherwise nothing if its type is
--   'Cert', and an 'Imp.ScalarUse' in all remaining cases.
readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
readsFromSet free =
  fmap catMaybes $
  forM (namesToList free) $ \var -> do
    t <- lookupType var
    vtable <- getVTable
    case t of
      Array {} -> return Nothing
      Mem (Space "local") -> return Nothing
      Mem {} -> return $ Just $ Imp.MemoryUse var
      Prim bt ->
        isConstExp vtable (Imp.var var bt) >>= \case
          Just ce -> return $ Just $ Imp.ConstUse var ce
          Nothing | bt == Cert -> return Nothing
                  | otherwise  -> return $ Just $ Imp.ScalarUse var bt

-- | Try to rewrite @size@ into an 'Imp.KernelConstExp' by replacing
-- every leaf of the expression; the result is 'Nothing' as soon as
-- one leaf cannot be replaced.
--
-- * A scalar variable is looked up in @vtable@ and must have a
--   recorded defining expression.  A 'GetSize' becomes an
--   'Imp.SizeConst' whose key is combined with the current function
--   name via 'keyWithEntryPoint'; any other expression is converted
--   with 'primExpFromExp', resolving its variables the same way.
--
-- * A 'Imp.SizeOf' leaf becomes the byte size of the type.
--
-- * An 'Imp.Index' leaf is never constant.
isConstExp :: VTable KernelsMem -> Imp.Exp
           -> ImpM lore r op (Maybe Imp.KernelConstExp)
isConstExp vtable size = do
  fname <- askFunction
  let onLeaf (Imp.ScalarVar name) _ = lookupConstExp name
      onLeaf (Imp.SizeOf pt) _ = Just $ primByteSize pt
      onLeaf Imp.Index{} _ = Nothing
      lookupConstExp name =
        constExp =<< hasExp =<< M.lookup name vtable
      constExp (Op (Inner (SizeOp (GetSize key _)))) =
        Just $ LeafExp (Imp.SizeConst $ keyWithEntryPoint fname key) int32
      constExp e = primExpFromExp lookupConstExp e
  return $ replaceInPrimExpM onLeaf size
  where -- The defining expression stored in a variable table entry, if any.
        hasExp (ArrayVar e _) = e
        hasExp (ScalarVar e _) = e
        hasExp (MemVar e _) = e

-- | Emit code that assigns to @chunk_var@ the number of elements the
-- thread with index @thread_index@ should process.
--
-- * For 'SplitStrided', the chunk is the signed 32-bit minimum of
--   @elements_per_thread@ and
--   @(num_elements - thread_index) \`divUp\` stride@.
--
-- * For 'SplitContiguous', the thread starts at
--   @thread_index * elements_per_thread@.  The chunk is 0 if no
--   elements remain from that point or the starting point is at or
--   beyond @num_elements@; otherwise it is the remainder
--   @num_elements - thread_index * elements_per_thread@ when
--   @num_elements < (thread_index + 1) * elements_per_thread@, and
--   @elements_per_thread@ in every other case.
computeThreadChunkSize :: SplitOrdering
                       -> Imp.Exp
                       -> Imp.Count Imp.Elements Imp.Exp
                       -> Imp.Count Imp.Elements Imp.Exp
                       -> VName
                       -> ImpM lore r op ()
computeThreadChunkSize (SplitStrided stride) thread_index elements_per_thread num_elements chunk_var = do
  stride' <- toExp stride
  chunk_var <--
    Imp.BinOpExp (SMin Int32)
    (Imp.unCount elements_per_thread)
    ((Imp.unCount num_elements - thread_index) `divUp` stride')

computeThreadChunkSize SplitContiguous thread_index elements_per_thread num_elements chunk_var = do
  starting_point <- dPrimV "starting_point" $
    thread_index * Imp.unCount elements_per_thread
  remaining_elements <- dPrimV "remaining_elements" $
    Imp.unCount num_elements - Imp.var starting_point int32

  let no_remaining_elements = Imp.var remaining_elements int32 .<=. 0
      beyond_bounds = Imp.unCount num_elements .<=. Imp.var starting_point int32

  sIf (no_remaining_elements .||. beyond_bounds)
    (chunk_var <-- 0)
    (sIf is_last_thread
       (chunk_var <-- Imp.unCount last_thread_elements)
       (chunk_var <-- Imp.unCount elements_per_thread))
  where last_thread_elements =
          num_elements - Imp.elements thread_index * elements_per_thread
        is_last_thread =
          Imp.unCount num_elements .<.
          (thread_index + 1) * Imp.unCount elements_per_thread

-- | Create the 'KernelConstants' for a kernel with the given number
-- of groups and group size, together with the in-kernel code that
-- declares and initialises the variables those constants refer to.
--
-- Fresh names are generated for the global thread id, local thread
-- id, group id, lockstep width and local size.  The returned
-- 'InKernelGen' action declares all five as @int32@ and fills them in
-- with the corresponding 'Imp.KernelOp's (dimension 0 where a
-- dimension is required).
--
-- Note that the locally queried size variable is only declared and
-- assigned; the constants themselves are built from the @group_size@
-- and @num_groups@ expressions passed in.
kernelInitialisationSimple :: Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
                           -> CallKernelGen (KernelConstants, InKernelGen ())
kernelInitialisationSimple (Count num_groups) (Count group_size) = do
  global_tid <- newVName "global_tid"
  local_tid <- newVName "local_tid"
  group_id <- newVName "group_tid"
  wave_size <- newVName "wave_size"
  inner_group_size <- newVName "group_size"
  -- Positional construction; the argument order follows the field
  -- order of the KernelConstants record.
  let constants =
        KernelConstants
        (Imp.var global_tid int32)
        (Imp.var local_tid int32)
        (Imp.var group_id int32)
        global_tid local_tid group_id
        num_groups group_size (group_size * num_groups)
        (Imp.var wave_size int32)
        true
        mempty

  let set_constants = do
        dPrim_ global_tid int32
        dPrim_ local_tid int32
        dPrim_ inner_group_size int32
        dPrim_ wave_size int32
        dPrim_ group_id int32

        sOp (Imp.GetGlobalId global_tid 0)
        sOp (Imp.GetLocalId local_tid 0)
        sOp (Imp.GetLocalSize inner_group_size 0)
        sOp (Imp.GetLockstepWidth wave_size)
        sOp (Imp.GetGroupId group_id 0)

  return (constants, set_constants)

-- | Build the condition under which a thread is active for the given
-- index space: the conjunction of @i < w@ for every @(i, w)@ pair in
-- @limit@.  An empty list yields the constant @True@.
--
-- The indices are read as @int32@ variables, and the bounds are
-- converted at @int32@ as well so that both operands of the
-- comparison carry the same type.
isActive :: [(VName, SubExp)] -> Imp.Exp
isActive limit = case actives of
                    [] -> Imp.ValueExp $ BoolValue True
                    x:xs -> foldl (.&&.) x xs
  where (is, ws) = unzip limit
        actives = zipWith active is $ map (toExp' int32) ws
        active i = (Imp.var i int32 .<.)

-- | Change every memory block to be in the global address space,
-- except those that are in the local memory space.  This only affects
-- generated code - we still need to make sure that the memory is
-- actually present on the device (and declared as variables in the
-- kernel).
--
-- The action is run with the default space set to @"global"@ and
-- with a variable table in which every memory entry whose space is
-- not @"local"@ has its space replaced by @"global"@ and its recorded
-- defining expression dropped.
makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a
makeAllMemoryGlobal =
  localDefaultSpace (Imp.Space "global") . localVTable (M.map globalMemory)
  where globalMemory (MemVar _ entry)
          | entryMemSpace entry /= Space "local" =
              MemVar Nothing entry { entryMemSpace = Imp.Space "global" }
        globalMemory entry =
          entry

-- | Group-level reduction: declares a fresh @int32@ variable named
-- @"offset"@ and delegates to 'groupReduceWithOffset' with it, passing
-- @w@, @lam@ and @arrs@ through unchanged.
groupReduce :: Imp.Exp
            -> Lambda KernelsMem
            -> [VName]
            -> InKernelGen ()
groupReduce w lam arrs = do
  offset <- dPrim "offset" int32
  groupReduceWithOffset offset w lam arrs

-- | Generate code for a group-level reduction, using the caller-supplied
-- variable @offset@ as the running stride.
--
-- Arguments:
--
-- * @offset@: name of an already declared @int32@ variable; it is
--   overwritten repeatedly by the generated code.
--
-- * @w@: number of elements taking part; a thread only combines with a
--   partner when @local_tid + offset < w@.
--
-- * @lam@: the reduction operator.  Its first @length arrs@ parameters
--   are used as accumulators and the remaining ones as the elements
--   being folded in.
--
-- * @arrs@: the arrays being reduced, one per accumulator parameter.
--
-- The generated code has two phases:
--
-- 1. In-wave phase: @offset@ starts at 1 and doubles while it is less
--    than the wave size.  No barrier is emitted in this phase; instead
--    the reduction step is wrapped in 'everythingVolatile'.
--
-- 2. Cross-wave phase: @skip_waves@ starts at 1 and doubles while it is
--    less than the number of waves.  Every iteration begins with a
--    barrier, sets @offset@ to @skip_waves * wave_size@, and lets the
--    first thread of each non-skipped wave perform a reduction step.
--
-- Primitive results are written back to @arrs@ at the thread's local
-- index, so after both phases local thread 0 has folded in every
-- in-bounds element.
groupReduceWithOffset :: VName
                      -> Imp.Exp
                      -> Lambda KernelsMem
                      -> [VName]
                      -> InKernelGen ()
groupReduceWithOffset offset w lam arrs = do
  constants <- kernelConstants <$> askEnv

  let local_tid = kernelLocalThreadId constants
      global_tid = kernelGlobalThreadId constants

      -- A local fence is used when every result of the operator is
      -- primitive; otherwise a global fence is used.
      barrier
        | all primType $ lambdaReturnType lam = sOp $ Imp.Barrier Imp.FenceLocal
        | otherwise                           = sOp $ Imp.Barrier Imp.FenceGlobal

      -- Load the element @offset@ positions past this thread into the
      -- given lambda parameter.  Primitive parameters are indexed by
      -- the local thread id, all others by the global thread id.
      readReduceArgument param arr =
        let tid | Prim _ <- paramType param = local_tid
                | otherwise                 = global_tid
        in copyDWIMFix (paramName param) [] (Var arr) [tid + Imp.vi32 offset]

      -- Store a primitive accumulator back at this thread's local
      -- index.  Nothing is written for non-primitive parameters.
      -- NOTE(review): presumably those are updated in place by the
      -- operator body - confirm.
      writeReduceOpResult param arr
        | Prim _ <- paramType param =
            copyDWIMFix arr [local_tid] (Var $ paramName param) []
        | otherwise =
            return ()

      (reduce_acc_params, reduce_arr_params) =
        splitAt (length arrs) $ lambdaParams lam

  skip_waves <- dPrim "skip_waves" int32
  dLParams $ lambdaParams lam

  offset <-- 0

  -- With the offset at zero, this reads each thread's own element.
  comment "participating threads read initial accumulator" $
    sWhen (local_tid .<. w) $
    zipWithM_ readReduceArgument reduce_acc_params arrs

  let -- One reduction step: fetch the partner element, apply the
      -- operator, and write the result back.
      do_reduce = do
        comment "read array element" $
          zipWithM_ readReduceArgument reduce_arr_params arrs
        comment "apply reduction operation" $
          compileBody' reduce_acc_params $ lambdaBody lam
        comment "write result of operation" $
          zipWithM_ writeReduceOpResult reduce_acc_params arrs
      in_wave_reduce = everythingVolatile do_reduce

      offset_e = Imp.var offset int32
      skip_waves_e = Imp.var skip_waves int32

      wave_size = kernelWaveSize constants
      group_size = kernelGroupSize constants
      wave_id = local_tid `quot` wave_size
      in_wave_id = local_tid - wave_id * wave_size
      -- Number of waves in the group, rounded up.
      num_waves = (group_size + wave_size - 1) `quot` wave_size
      arg_in_bounds = local_tid + offset_e .<. w

      doing_in_wave_reductions =
        offset_e .<. wave_size
      -- Holds for threads whose in-wave index is a multiple of
      -- @2 * offset@ (the offset is always a power of two here).
      apply_in_in_wave_iteration =
        (in_wave_id .&. (2 * offset_e - 1)) .==. 0
      in_wave_reductions = do
        offset <-- 1
        sWhile doing_in_wave_reductions $ do
          sWhen (arg_in_bounds .&&. apply_in_in_wave_iteration)
            in_wave_reduce
          offset <-- offset_e * 2

      doing_cross_wave_reductions =
        skip_waves_e .<. num_waves
      is_first_thread_in_wave =
        in_wave_id .==. 0
      -- Holds for waves whose index is a multiple of @2 * skip_waves@.
      wave_not_skipped =
        (wave_id .&. (2 * skip_waves_e - 1)) .==. 0
      apply_in_cross_wave_iteration =
        arg_in_bounds .&&. is_first_thread_in_wave .&&. wave_not_skipped
      cross_wave_reductions = do
        skip_waves <-- 1
        sWhile doing_cross_wave_reductions $ do
          barrier
          offset <-- skip_waves_e * wave_size
          sWhen apply_in_cross_wave_iteration
            do_reduce
          skip_waves <-- skip_waves_e * 2

  in_wave_reductions
  cross_wave_reductions

groupScan :: Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
          -> Imp.Exp
          -> Imp.Exp
          -> Lambda KernelsMem
          -> [VName]
          -> InKernelGen ()
groupScan :: Maybe (Exp -> Exp -> Exp)
-> Exp -> Exp -> Lambda KernelsMem -> [VName] -> InKernelGen ()
groupScan Maybe (Exp -> Exp -> Exp)
seg_flag Exp
arrs_full_size Exp
w Lambda KernelsMem
lam [VName]
arrs = do
  KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
  Lambda KernelsMem
renamed_lam <- Lambda KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda KernelsMem
lam

  let ltid :: Exp
ltid = KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
      ([Param LetDecMem]
x_params, [Param LetDecMem]
y_params) = Int -> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs) ([Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem]))
-> [Param LetDecMem] -> ([Param LetDecMem], [Param LetDecMem])
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
lam

  [LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams (Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
lam[Param LetDecMem] -> [Param LetDecMem] -> [Param LetDecMem]
forall a. [a] -> [a] -> [a]
++Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
renamed_lam)

  -- The scan works by splitting the group into blocks, which are
  -- scanned separately.  Typically, these blocks are smaller than
  -- the lockstep width, which enables barrier-free execution inside
  -- them.
  --
  -- We hardcode the block size here.  The only requirement is that
  -- it should not be less than the square root of the group size.
  -- With 32, we will work on groups of size 1024 or smaller, which
  -- fits every device Troels has seen.  Still, it would be nicer if
  -- it were a runtime parameter.  Some day.
  let block_size :: PrimExp v
block_size = PrimValue -> PrimExp v
forall v. PrimValue -> PrimExp v
Imp.ValueExp (PrimValue -> PrimExp v) -> PrimValue -> PrimExp v
forall a b. (a -> b) -> a -> b
$ IntValue -> PrimValue
IntValue (IntValue -> PrimValue) -> IntValue -> PrimValue
forall a b. (a -> b) -> a -> b
$ Int32 -> IntValue
Int32Value Int32
32
      simd_width :: Exp
simd_width = KernelConstants -> Exp
kernelWaveSize KernelConstants
constants
      block_id :: Exp
block_id = Exp
ltid Exp -> Exp -> Exp
forall e. IntegralExp e => e -> e -> e
`quot` Exp
forall v. PrimExp v
block_size
      in_block_id :: Exp
in_block_id = Exp
ltid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
block_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
forall v. PrimExp v
block_size
      doInBlockScan :: Maybe (Exp -> Exp -> Exp)
-> Exp -> Lambda KernelsMem -> InKernelGen ()
doInBlockScan Maybe (Exp -> Exp -> Exp)
seg_flag' Exp
active =
        KernelConstants
-> Maybe (Exp -> Exp -> Exp)
-> Exp
-> Exp
-> Exp
-> Exp
-> [VName]
-> InKernelGen ()
-> Lambda KernelsMem
-> InKernelGen ()
inBlockScan KernelConstants
constants Maybe (Exp -> Exp -> Exp)
seg_flag' Exp
arrs_full_size
        Exp
simd_width Exp
forall v. PrimExp v
block_size Exp
active [VName]
arrs InKernelGen ()
barrier
      ltid_in_bounds :: Exp
ltid_in_bounds = Exp
ltid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
w
      array_scan :: Bool
array_scan = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([Type] -> Bool) -> [Type] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda KernelsMem
lam
      barrier :: InKernelGen ()
barrier | Bool
array_scan =
                  KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal
              | Bool
otherwise =
                  KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

      group_offset :: Exp
group_offset = KernelConstants -> Exp
kernelGroupId KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants

      writeBlockResult :: Param LetDecMem -> VName -> InKernelGen ()
writeBlockResult Param LetDecMem
p VName
arr
        | Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param LetDecMem
p =
            VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix Exp
block_id] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []
        | Bool
otherwise =
            VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
block_id] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []

      readPrevBlockResult :: Param LetDecMem -> VName -> InKernelGen ()
readPrevBlockResult Param LetDecMem
p VName
arr
        | Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param LetDecMem
p =
            VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
block_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1]
        | Bool
otherwise =
            VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) [] (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
block_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1]

  Maybe (Exp -> Exp -> Exp)
-> Exp -> Lambda KernelsMem -> InKernelGen ()
doInBlockScan Maybe (Exp -> Exp -> Exp)
seg_flag Exp
ltid_in_bounds Lambda KernelsMem
lam
  InKernelGen ()
barrier

  let is_first_block :: Exp
is_first_block = Exp
block_id Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0
  Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save correct values for first block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
x_params [VName]
arrs) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
x, VName
arr) ->
      Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param LetDecMem
x) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
arrs_full_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
forall v. PrimExp v
block_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
x) []

    InKernelGen ()
barrier

  let last_in_block :: Exp
last_in_block = Exp
in_block_id Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
forall v. PrimExp v
block_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1
  String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"last thread of block 'i' writes its result to offset 'i'" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
last_in_block Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. Exp
ltid_in_bounds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ InKernelGen () -> InKernelGen ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    (Param LetDecMem -> VName -> InKernelGen ())
-> [Param LetDecMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LetDecMem -> VName -> InKernelGen ()
writeBlockResult [Param LetDecMem]
x_params [VName]
arrs

  InKernelGen ()
barrier

  let first_block_seg_flag :: Maybe (Exp -> Exp -> Exp)
first_block_seg_flag = do
        Exp -> Exp -> Exp
flag_true <- Maybe (Exp -> Exp -> Exp)
seg_flag
        (Exp -> Exp -> Exp) -> Maybe (Exp -> Exp -> Exp)
forall a. a -> Maybe a
Just ((Exp -> Exp -> Exp) -> Maybe (Exp -> Exp -> Exp))
-> (Exp -> Exp -> Exp) -> Maybe (Exp -> Exp -> Exp)
forall a b. (a -> b) -> a -> b
$ \Exp
from Exp
to ->
          Exp -> Exp -> Exp
flag_true (Exp
fromExp -> Exp -> Exp
forall a. Num a => a -> a -> a
*Exp
forall v. PrimExp v
block_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
+Exp
forall v. PrimExp v
block_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
1) (Exp
toExp -> Exp -> Exp
forall a. Num a => a -> a -> a
*Exp
forall v. PrimExp v
block_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
+Exp
forall v. PrimExp v
block_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
1)
  String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment
    String
"scan the first block, after which offset 'i' contains carry-in for block 'i+1'" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Maybe (Exp -> Exp -> Exp)
-> Exp -> Lambda KernelsMem -> InKernelGen ()
doInBlockScan Maybe (Exp -> Exp -> Exp)
first_block_seg_flag (Exp
is_first_block Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. Exp
ltid_in_bounds) Lambda KernelsMem
renamed_lam

  InKernelGen ()
barrier

  Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"move correct values for first block back a block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
x_params [VName]
arrs) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
x, VName
arr) ->
      Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param LetDecMem
x) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM
      VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
arrs_full_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid]
      (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
arrs_full_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
forall v. PrimExp v
block_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid]

    InKernelGen ()
barrier

  let read_carry_in :: InKernelGen ()
read_carry_in = do
        [(Param LetDecMem, Param LetDecMem)]
-> ((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [Param LetDecMem] -> [(Param LetDecMem, Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
x_params [Param LetDecMem]
y_params) (((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
x,Param LetDecMem
y) ->
          VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
y) [] (VName -> SubExp
Var (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
x)) []
        (Param LetDecMem -> VName -> InKernelGen ())
-> [Param LetDecMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LetDecMem -> VName -> InKernelGen ()
readPrevBlockResult [Param LetDecMem]
x_params [VName]
arrs

      y_to_x :: InKernelGen ()
y_to_x = [(Param LetDecMem, Param LetDecMem)]
-> ((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [Param LetDecMem] -> [(Param LetDecMem, Param LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
x_params [Param LetDecMem]
y_params) (((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetDecMem, Param LetDecMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
x,Param LetDecMem
y) ->
        Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Param LetDecMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param LetDecMem
x)) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
x) [] (VName -> SubExp
Var (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
y)) []

      op_to_x :: InKernelGen ()
op_to_x
        | Maybe (Exp -> Exp -> Exp)
Nothing <- Maybe (Exp -> Exp -> Exp)
seg_flag =
            [Param LetDecMem] -> BodyT KernelsMem -> InKernelGen ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param LetDecMem]
x_params (BodyT KernelsMem -> InKernelGen ())
-> BodyT KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
lam
        | Just Exp -> Exp -> Exp
flag_true <- Maybe (Exp -> Exp -> Exp)
seg_flag = do
            Exp
inactive <-
              String -> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"inactive" (Exp -> ImpM KernelsMem KernelEnv KernelOp Exp)
-> Exp -> ImpM KernelsMem KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp -> Exp
flag_true (Exp
block_idExp -> Exp -> Exp
forall a. Num a => a -> a -> a
*Exp
forall v. PrimExp v
block_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
1) Exp
ltid
            Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
inactive InKernelGen ()
y_to_x
            Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
            Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless Exp
inactive (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LetDecMem] -> BodyT KernelsMem -> InKernelGen ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param LetDecMem]
x_params (BodyT KernelsMem -> InKernelGen ())
-> BodyT KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
lam

      write_final_result :: InKernelGen ()
write_final_result =
        [(Param LetDecMem, VName)]
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem] -> [VName] -> [(Param LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LetDecMem]
x_params [VName]
arrs) (((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
p, VName
arr) ->
        Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param LetDecMem
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix Exp
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
p) []

  String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"carry-in for every block except the first" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless (Exp
is_first_block Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.||. UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
Not Exp
ltid_in_bounds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"read operands" InKernelGen ()
read_carry_in
    String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform operation" InKernelGen ()
op_to_x
    String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"write final result" InKernelGen ()
write_final_result

  InKernelGen ()
barrier

  String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"restore correct values for first block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [(Param LetDecMem, Param LetDecMem, VName)]
-> ((Param LetDecMem, Param LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LetDecMem]
-> [Param LetDecMem]
-> [VName]
-> [(Param LetDecMem, Param LetDecMem, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Param LetDecMem]
x_params [Param LetDecMem]
y_params [VName]
arrs) (((Param LetDecMem, Param LetDecMem, VName) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LetDecMem, Param LetDecMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LetDecMem
x, Param LetDecMem
y, VName
arr) ->
      if Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Param LetDecMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param LetDecMem
y)
      then VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix Exp
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
y) []
      else VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param LetDecMem -> VName
forall dec. Param dec -> VName
paramName Param LetDecMem
x) [] (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
arrs_full_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid]

  InKernelGen ()
barrier

-- | Generate code for a scan within each block of @block_size@
-- consecutive local thread IDs.
--
-- The generated code is wrapped in 'everythingVolatile'.  It declares
-- a counter @skip_threads@ that starts at 1 and is doubled after each
-- round for as long as it is smaller than @block_size@.  In every
-- round, each thread that is @active@ and whose position inside its
-- block is at least @skip_threads@ reads the element @skip_threads@
-- positions behind it into the lambda's first-half parameters, runs
-- the lambda body, and writes the result back.
--
-- Parameters of primitive type are read from and written to @arrs@ at
-- the local thread ID.  Parameters of non-primitive type are read at
-- an index derived from the global thread ID instead, and are never
-- written back to @arrs@ here.
inBlockScan :: KernelConstants
            -- ^ Source of the local and global thread ID expressions.
            -> Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
            -- ^ Optional flag function.  When present it is applied to
            -- @ltid - skip_threads@ and @ltid@; if the resulting
            -- expression is true the thread copies its own value
            -- instead of applying the operator.  (Presumably this
            -- marks a segment boundary between the two indexes --
            -- confirm against the caller.)
            -> Imp.Exp
            -- ^ @arrs_full_size@: offset added to the index when
            -- reading a non-primitive operand from a preceding thread.
            -> Imp.Exp
            -- ^ @lockstep_width@: a barrier is only emitted in rounds
            -- where this is less than or equal to @skip_threads@.
            -> Imp.Exp
            -- ^ @block_size@: number of threads per block.
            -> Imp.Exp
            -- ^ @active@: whether this thread takes part in the scan.
            -> [VName]
            -- ^ The arrays being scanned, one per operator operand.
            -> InKernelGen ()
            -- ^ The barrier action to emit where synchronisation is needed.
            -> Lambda KernelsMem
            -- ^ The scan operator; the first half of its parameters
            -- receive the result, the second half hold this thread's
            -- own value.
            -> InKernelGen ()
inBlockScan constants seg_flag arrs_full_size lockstep_width block_size active arrs barrier scan_lam = everythingVolatile $ do
  skip_threads <- dPrim "skip_threads" int32
  let in_block_thread_active =
        Imp.var skip_threads int32 .<=. in_block_id
      actual_params = lambdaParams scan_lam
      (x_params, y_params) =
        splitAt (length actual_params `div` 2) actual_params
      -- Copy each primitive-typed y parameter into the matching x
      -- parameter; non-primitive parameters are left alone.
      y_to_x =
        forM_ (zip x_params y_params) $ \(x, y) ->
        when (primType (paramType x)) $
        copyDWIM (paramName x) [] (Var (paramName y)) []

  -- Set initial y values
  sComment "read input for in-block scan" $
    sWhen active $ do
    zipWithM_ readInitial y_params arrs
    -- Since the final result is expected to be in x_params, we may
    -- need to copy it there for the first thread in the block.
    sWhen (in_block_id .==. 0) y_to_x

  when array_scan barrier

  let op_to_x
        | Nothing <- seg_flag =
            compileBody' x_params $ lambdaBody scan_lam
        | Just flag_true <- seg_flag = do
            inactive <- dPrimVE "inactive" $
                        flag_true (ltid - Imp.var skip_threads int32) ltid
            sWhen inactive y_to_x
            when array_scan barrier
            sUnless inactive $ compileBody' x_params $ lambdaBody scan_lam

      -- Only synchronise once the stride has reached the lockstep width.
      maybeBarrier = sWhen (lockstep_width .<=. Imp.var skip_threads int32)
                     barrier

  sComment "in-block scan (hopefully no barriers needed)" $ do
    skip_threads <-- 1
    sWhile (Imp.var skip_threads int32 .<. block_size) $ do
      sWhen (in_block_thread_active .&&. active) $ do
        sComment "read operands" $
          zipWithM_ (readParam (Imp.vi32 skip_threads)) x_params arrs
        sComment "perform operation" op_to_x

      maybeBarrier

      sWhen (in_block_thread_active .&&. active) $
        sComment "write result" $
        sequence_ $ zipWith3 writeResult x_params y_params arrs

      maybeBarrier

      skip_threads <-- Imp.var skip_threads int32 * 2

  where block_id = ltid `quot` block_size
        -- Position of this thread within its block.
        in_block_id = ltid - block_id * block_size
        ltid = kernelLocalThreadId constants
        gtid = kernelGlobalThreadId constants
        -- True when the operator returns at least one non-primitive value.
        array_scan = not $ all primType $ lambdaReturnType scan_lam

        -- Load this thread's own element into parameter @p@.
        readInitial p arr
          | primType $ paramType p =
              copyDWIM (paramName p) [] (Var arr) [DimFix ltid]
          | otherwise =
              copyDWIM (paramName p) [] (Var arr) [DimFix gtid]

        -- Load the element @behind@ positions before this thread into
        -- parameter @p@.
        readParam behind p arr
          | primType $ paramType p =
              copyDWIM (paramName p) [] (Var arr) [DimFix $ ltid - behind]
          | otherwise =
              copyDWIM (paramName p) [] (Var arr) [DimFix $ gtid - behind + arrs_full_size]

        -- Store the operator result held in @x@: primitive results go
        -- to both @arr@ and @y@, non-primitive results only to @y@.
        writeResult x y arr
          | primType $ paramType x = do
              copyDWIM arr [DimFix ltid] (Var $ paramName x) []
              copyDWIM (paramName y) [] (Var $ paramName x) []
          | otherwise =
              copyDWIM (paramName y) [] (Var $ paramName x) []

-- | Given the total number of threads a kernel needs, declare a
-- @group_size@ variable whose value is obtained through an
-- 'Imp.GetSize' host operation of class 'Imp.SizeGroup', and a
-- @num_groups@ variable holding @kernel_size@ divided by the group
-- size, rounded up.
--
-- The size key is the pretty-printed name of the group size variable,
-- passed through 'keyWithEntryPoint' together with the name of the
-- function currently being compiled (if any).
--
-- Returns @(num_groups, group_size)@ as expressions reading those two
-- variables.
computeMapKernelGroups :: Imp.Exp -> CallKernelGen (Imp.Exp, Imp.Exp)
computeMapKernelGroups kernel_size = do
  group_size <- dPrim "group_size" int32
  fname <- askFunction
  let group_size_var = Imp.var group_size int32
      group_size_key = keyWithEntryPoint fname $ nameFromString $ pretty group_size
  sOp $ Imp.GetSize group_size group_size_key Imp.SizeGroup
  num_groups <- dPrimV "num_groups" $ kernel_size `divUp` group_size_var
  return (Imp.var num_groups int32, Imp.var group_size int32)

-- | Build the 'KernelConstants' for a kernel with @kernel_size@
-- threads, using 'computeMapKernelGroups' to pick the number of groups
-- and the group size.
--
-- Fresh names for the global thread ID, local thread ID and group ID
-- are derived from @desc@ (with the suffixes @_gtid@, @_ltid@ and
-- @_gid@).  The second component of the result is an in-kernel action
-- that declares those three variables as @int32@ and initialises them
-- with 'Imp.GetGlobalId', 'Imp.GetLocalId' and 'Imp.GetGroupId'
-- (dimension 0); it must be run inside the kernel before the
-- constants are used.
simpleKernelConstants :: Imp.Exp -> String
                      -> CallKernelGen (KernelConstants, InKernelGen ())
simpleKernelConstants kernel_size desc = do
  thread_gtid <- newVName $ desc ++ "_gtid"
  thread_ltid <- newVName $ desc ++ "_ltid"
  group_id <- newVName $ desc ++ "_gid"
  (num_groups, group_size) <- computeMapKernelGroups kernel_size
  let set_constants = do
        dPrim_ thread_gtid int32
        dPrim_ thread_ltid int32
        dPrim_ group_id int32
        sOp (Imp.GetGlobalId thread_gtid 0)
        sOp (Imp.GetLocalId thread_ltid 0)
        sOp (Imp.GetGroupId group_id 0)

  -- The constructor is applied positionally: the three ID expressions,
  -- the three ID variables, the number of groups, and then (in
  -- declaration order -- NOTE(review): confirm against the record
  -- definition) the group size, the total thread count, a constant 0,
  -- the "thread is in bounds" predicate, and an empty map.
  return (KernelConstants
          (Imp.var thread_gtid int32) (Imp.var thread_ltid int32) (Imp.var group_id int32)
          thread_gtid thread_ltid group_id
          num_groups group_size (group_size * num_groups) 0
          (Imp.var thread_gtid int32 .<. kernel_size)
          mempty,

          set_constants)

-- | For many kernels, we may not have enough physical groups to cover
-- the logical iteration space.  Some groups thus have to perform
-- double duty; we put an outer loop to accomplish this.  The
-- advantage over just launching a bazillion threads is that the cost
-- of memory expansion should be proportional to the number of
-- *physical* threads (hardware parallelism), not the amount of
-- application parallelism.
--
-- With 'SegVirt', the physical group ID is fetched with
-- 'Imp.GetGroupId' and the body @m@ is run in a loop, once for each
-- virtual group ID @phys_group_id + i * kernelNumGroups@, with a
-- global-fence barrier after every iteration.  For any other
-- virtualisation mode, @m@ is run exactly once with the group ID
-- variable from the kernel constants.
virtualiseGroups :: SegVirt
                 -- ^ Whether group virtualisation is requested.
                 -> Imp.Exp
                 -- ^ The number of (virtual) groups required.
                 -> (VName -> InKernelGen ())
                 -- ^ The body, given the variable holding the group ID.
                 -> InKernelGen ()
virtualiseGroups SegVirt required_groups m = do
  constants <- kernelConstants <$> askEnv
  phys_group_id <- dPrim "phys_group_id" int32
  sOp $ Imp.GetGroupId phys_group_id 0
  -- Number of virtual groups this physical group is responsible for.
  let iterations = (required_groups - Imp.vi32 phys_group_id) `divUp`
                   kernelNumGroups constants

  sFor "i" iterations $ \i -> do
    m =<< dPrimV "virt_group_id" (Imp.vi32 phys_group_id + i * kernelNumGroups constants)
    -- Make sure the virtual group is actually done before we let
    -- another virtual group have its way with it.
    sOp $ Imp.Barrier Imp.FenceGlobal
virtualiseGroups _ _ m = do
  gid <- kernelGroupIdVar . kernelConstants <$> askEnv
  m gid

-- | Generate a kernel via 'sKernel' using 'threadOperations', with
-- 'kernelGlobalThreadId' as the index selector.  (Presumably the
-- 'VName' argument is bound to the global thread ID inside the kernel
-- -- confirm against 'sKernel'.)
sKernelThread :: String
              -> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
              -> VName
              -> InKernelGen ()
              -> CallKernelGen ()
sKernelThread = sKernel threadOperations kernelGlobalThreadId

-- | Generate a kernel through 'sKernel' using 'groupOperations'.  The
-- given 'VName' is bound to 'kernelGroupId' at the start of the kernel
-- body.
sKernelGroup :: String
             -> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
             -> VName
             -> InKernelGen ()
             -> CallKernelGen ()
sKernelGroup = sKernel groupOperations kernelGroupId

-- | Compile the kernel body @m@ with the operations @ops@ and emit a
-- call to the resulting kernel, named @name@.
--
-- The body is compiled in a 'KernelEnv' built from the atomics of the
-- host environment and the provided @constants@.  The number of groups
-- and the group size of the emitted kernel are taken from @constants@,
-- and @tol@ is stored verbatim in 'Imp.kernelFailureTolerant'.
sKernelFailureTolerant :: Bool
                       -> Operations KernelsMem KernelEnv Imp.KernelOp
                       -> KernelConstants
                       -> Name
                       -> InKernelGen ()
                       -> CallKernelGen ()
sKernelFailureTolerant tol ops constants name m = do
  HostEnv atomics <- askEnv
  body <- makeAllMemoryGlobal $ subImpM_ (KernelEnv atomics constants) ops m
  -- No names are passed as already bound: the second argument is empty.
  uses <- computeKernelUses body mempty
  emit $ Imp.Op $ Imp.CallKernel Imp.Kernel
    { Imp.kernelBody = body
    , Imp.kernelUses = uses
    , Imp.kernelNumGroups = [kernelNumGroups constants]
    , Imp.kernelGroupSize = [kernelGroupSize constants]
    , Imp.kernelName = name
    , Imp.kernelFailureTolerant = tol
    }

-- | Emit a (non-failure-tolerant) kernel with the given operations.
--
-- Kernel constants are initialised from @num_groups@ and @group_size@.
-- The kernel name is derived from @name@ and the tag of @v@.  Inside
-- the kernel, the constant-setting code runs first, then @v@ is bound
-- to @flatf constants@, and finally the body @f@ is run.
sKernel :: Operations KernelsMem KernelEnv Imp.KernelOp
        -> (KernelConstants -> Imp.Exp)
        -> String
        -> Count NumGroups Imp.Exp
        -> Count GroupSize Imp.Exp
        -> VName
        -> InKernelGen ()
        -> CallKernelGen ()
sKernel ops flatf name num_groups group_size v f = do
  (constants, set_constants) <- kernelInitialisationSimple num_groups group_size
  name' <- nameForFun $ name ++ "_" ++ show (baseTag v)
  sKernelFailureTolerant False ops constants name' $ do
    set_constants
    dPrimV_ v $ flatf constants
    f

-- | Copy compiler used at group level.
--
-- When both the destination and the source memory are in a
-- 'ScalarSpace', the copy is done element-wise: only the last
-- @length ds@ entries of each slice are kept (where @ds@ are the
-- dimensions of the respective scalar space), and all leading entries
-- are replaced by @DimFix 0@.
--
-- In every other case the dimensions of the destination slice are
-- covered with 'groupCoverSpace', copying a single element per index,
-- and a local-fence barrier is issued afterwards.
copyInGroup :: CopyCompiler KernelsMem KernelEnv Imp.KernelOp
copyInGroup pt destloc destslice srcloc srcslice = do
  dest_space <- entryMemSpace <$> lookupMemory (memLocationName destloc)
  src_space <- entryMemSpace <$> lookupMemory (memLocationName srcloc)

  case (dest_space, src_space) of
    (ScalarSpace destds _, ScalarSpace srcds _) -> do
      let destslice' =
            replicate (length destslice - length destds) (DimFix 0) ++
            takeLast (length destds) destslice
          srcslice' =
            replicate (length srcslice - length srcds) (DimFix 0) ++
            takeLast (length srcds) srcslice
      copyElementWise pt destloc destslice' srcloc srcslice'

    _ -> do
      groupCoverSpace (sliceDims destslice) $ \is ->
        copyElementWise pt
          destloc (map DimFix $ fixSlice destslice is)
          srcloc (map DimFix $ fixSlice srcslice is)
      sOp $ Imp.Barrier Imp.FenceLocal

-- | Code generation operations for thread-level ('threadOperations')
-- and group-level ('groupOperations') kernel bodies.  Both start from
-- 'defaultOperations' and share the statement compiler and the
-- allocator for the @"local"@ space; they differ in the op compiler,
-- the expression compiler and the copy compiler.
threadOperations, groupOperations :: Operations KernelsMem KernelEnv Imp.KernelOp
threadOperations =
  (defaultOperations compileThreadOp)
  { opsCopyCompiler = copyElementWise
  , opsExpCompiler = compileThreadExp
    -- The incoming set of names is ignored; an empty set is passed on.
  , opsStmsCompiler = \_ -> defCompileStms mempty
  , opsAllocCompilers =
      M.fromList [ (Space "local", allocLocal) ]
  }
groupOperations =
  (defaultOperations compileGroupOp)
  { opsCopyCompiler = copyInGroup
  , opsExpCompiler = compileGroupExp
    -- The incoming set of names is ignored; an empty set is passed on.
  , opsStmsCompiler = \_ -> defCompileStms mempty
  , opsAllocCompilers =
      M.fromList [ (Space "local", allocLocal) ]
  }

-- | Perform a Replicate with a kernel.
--
-- The index space is formed by the outer dimensions of @arr@ (its
-- dimensions with the last @arrayRank t@ dropped, where @t@ is the type
-- of @se@) followed by the dimensions of @se@ itself.  Kernel constants
-- are derived from the product of these dimensions.  Every active
-- thread unflattens its global thread ID into that space and copies
-- the corresponding element of @se@ (indexed by the trailing part of
-- the index) into @arr@.
sReplicateKernel :: VName -> SubExp -> CallKernelGen ()
sReplicateKernel arr se = do
  t <- subExpType se
  ds <- dropLast (arrayRank t) . arrayDims <$> lookupType arr

  dims <- mapM toExp $ ds ++ arrayDims t
  (constants, set_constants) <-
    simpleKernelConstants (product dims) "replicate"

  fname <- askFunction
  let name = keyWithEntryPoint fname $ nameFromString $
             "replicate_" ++ show (baseTag $ kernelGlobalThreadIdVar constants)
      is' = unflattenIndex dims $ kernelGlobalThreadId constants

  sKernelFailureTolerant True threadOperations constants name $ do
    set_constants
    sWhen (kernelThreadActive constants) $
      -- The first @length ds@ indices only address @arr@; the rest
      -- also address @se@.
      copyDWIMFix arr is' se $ drop (length ds) is'

-- | The base name of the replicate function for the given element
-- type: @"replicate_"@ followed by the pretty-printed type.
replicateName :: PrimType -> String
replicateName bt = "replicate_" ++ pretty bt

-- | Return the name of the builtin replicate function for the given
-- element type, defining the function first if 'hasFunction' reports
-- that it does not yet exist.
--
-- The function takes a memory block in the @"device"@ space, an
-- @int32@ number of elements, and a value of type @bt@.  Its body
-- declares a one-dimensional array of @num_elems@ elements over that
-- memory and fills it using 'sReplicateKernel'.
replicateForType :: PrimType -> CallKernelGen Name
replicateForType bt = do
  let fname = nameFromString $ "builtin#" <> replicateName bt

  exists <- hasFunction fname
  unless exists $ do
    mem <- newVName "mem"
    num_elems <- newVName "num_elems"
    val <- newVName "val"

    let params = [Imp.MemParam mem (Space "device"),
                  Imp.ScalarParam num_elems int32,
                  Imp.ScalarParam val bt]
        shape = Shape [Var num_elems]
    function fname [] params $ do
      arr <- sArray "arr" bt shape $ ArrayIn mem $ IxFun.iota $
             map (primExpFromSubExp int32) $ shapeDims shape
      sReplicateKernel arr $ Var val

  return fname

-- | Check whether replicating @v@ into @arr@ can be done by a call to
-- a shared per-type fill function.
--
-- If @v@ has primitive type and the index function of @arr@ is linear
-- (per 'IxFun.isLinear'), return an action that calls the function
-- obtained from 'replicateForType' with the memory block of @arr@, the
-- product of its dimensions (as @int32@ expressions), and the value
-- itself.  Otherwise return 'Nothing'.
replicateIsFill :: VName -> SubExp -> CallKernelGen (Maybe (CallKernelGen ()))
replicateIsFill arr v = do
  ArrayEntry (MemLocation arr_mem arr_shape arr_ixfun) _ <- lookupArray arr
  v_t <- subExpType v
  case v_t of
    Prim v_t'
      | IxFun.isLinear arr_ixfun -> return $ Just $ do
          fname <- replicateForType v_t'
          emit $ Imp.Call [] fname
            [Imp.MemArg arr_mem,
             Imp.ExpArg $ product $ map (toExp' int32) arr_shape,
             Imp.ExpArg $ toExp' v_t' v]
    _ -> return Nothing

-- | Perform a Replicate, either by calling a shared fill function
-- (when 'replicateIsFill' provides one) or by generating a dedicated
-- kernel with 'sReplicateKernel'.
sReplicate :: VName -> SubExp -> CallKernelGen ()
sReplicate arr se = do
  -- If the replicate is of a particularly common and simple form
  -- (morally a memset()/fill), then we use a common function.
  is_fill <- replicateIsFill arr se

  case is_fill of
    Just m -> m
    Nothing -> sReplicateKernel arr se

-- | Perform an Iota with a kernel.
--
-- Kernel constants are derived from the element count @n@.  Every
-- active thread writes, at the position of @arr@ given by its global
-- thread ID, the value @sExt et gtid * s + x@ as an integer of type
-- @et@; that is, @x@ is the start value and @s@ the stride.
sIotaKernel :: VName -> Imp.Exp -> Imp.Exp -> Imp.Exp -> IntType
            -> CallKernelGen ()
sIotaKernel arr n x s et = do
  destloc <- entryArrayLocation <$> lookupArray arr
  (constants, set_constants) <- simpleKernelConstants n "iota"

  fname <- askFunction
  let name = keyWithEntryPoint fname $ nameFromString $
             iotaName et ++ "_" ++
             show (baseTag $ kernelGlobalThreadIdVar constants)

  sKernelFailureTolerant True threadOperations constants name $ do
    set_constants
    let gtid = kernelGlobalThreadId constants
    sWhen (kernelThreadActive constants) $ do
      (destmem, destspace, destidx) <- fullyIndexArray' destloc [gtid]

      emit $
        Imp.Write destmem destidx (IntType et) destspace Imp.Nonvolatile $
        Imp.sExt et gtid * s + x

-- | The base name of the iota function for the given integer type:
-- @"iota_"@ followed by the pretty-printed type.
iotaName :: IntType -> String
iotaName bt = "iota_" ++ pretty bt

iotaForType :: IntType -> CallKernelGen Name
iotaForType :: IntType -> ImpM KernelsMem HostEnv HostOp Name
iotaForType IntType
bt = do
  let fname :: Name
fname = String -> Name
nameFromString (String -> Name) -> String -> Name
forall a b. (a -> b) -> a -> b
$ String
"builtin#" String -> String -> String
forall a. Semigroup a => a -> a -> a
<> IntType -> String
iotaName IntType
bt

  Bool
exists <- Name -> ImpM KernelsMem HostEnv HostOp Bool
forall lore r op. Name -> ImpM lore r op Bool
hasFunction Name
fname
  Bool
-> ImpM KernelsMem HostEnv HostOp ()
-> ImpM KernelsMem HostEnv HostOp ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless Bool
exists (ImpM KernelsMem HostEnv HostOp ()
 -> ImpM KernelsMem HostEnv HostOp ())
-> ImpM KernelsMem HostEnv HostOp ()
-> ImpM KernelsMem HostEnv HostOp ()
forall a b. (a -> b) -> a -> b
$ do
    VName
mem <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"mem"
    VName
n <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"n"
    VName
x <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"x"
    VName
s <- String -> ImpM KernelsMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"s"

    let params :: [Param]
params = [VName -> Space -> Param
Imp.MemParam VName
mem (String -> Space
Space String
"device"),
                  VName -> PrimType -> Param
Imp.ScalarParam VName
n PrimType
int32,
                  VName -> PrimType -> Param
Imp.ScalarParam VName
x (PrimType -> Param) -> PrimType -> Param
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
bt,
                  VName -> PrimType -> Param
Imp.ScalarParam VName
s (PrimType -> Param) -> PrimType -> Param
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType