{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Facilities for converting a 'Kernels' program to 'KernelsMem'.
module Futhark.Pass.ExplicitAllocations.Kernels
  ( explicitAllocations,
    explicitAllocationsInStms,
  )
where

import qualified Data.Map as M
import qualified Data.Set as S
import Futhark.IR.Kernels
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Pass.ExplicitAllocations
import Futhark.Pass.ExplicitAllocations.SegOp

-- | Size substitution information for host-level operations.
instance SizeSubst (HostOp lore op) where
  -- A 'SplitSpace' bound by a pattern with exactly one value element
  -- maps that element's name to the per-thread element count.  Every
  -- other combination of pattern and operation contributes nothing.
  opSizeSubst (Pattern _ [size]) (SizeOp (SplitSpace _ _ _ elems_per_thread)) =
    M.singleton (patElemName size) elems_per_thread
  opSizeSubst _ _ = mempty

  -- Only the 'GetSize' and 'GetSizeMax' size operations are treated
  -- as constant.
  opIsConst (SizeOp GetSize {}) = True
  opIsConst (SizeOp GetSizeMax {}) = True
  opIsConst _ = False

-- | Run an allocation action in an environment adjusted for the given
-- 'SegLevel': the allocation space becomes 'DefaultSpace' for
-- 'SegThread' and @Space "local"@ for 'SegGroup', and
-- 'aggressiveReuse' is switched on.
allocAtLevel :: SegLevel -> AllocM fromlore tlore a -> AllocM fromlore tlore a
allocAtLevel lvl = local $ \env ->
  env
    { allocSpace = space,
      aggressiveReuse = True
    }
  where
    space = case lvl of
      SegThread {} -> DefaultSpace
      SegGroup {} -> Space "local"

-- | Convert a 'SegOp' from 'Kernels' to 'KernelsMem'.
--
-- First binds @num_threads@ as the product of the level's number of
-- groups and its group size, then maps over the operation with
-- 'mapSegOpM' inside the environment set up by 'allocAtLevel'.
handleSegOp ::
  SegOp SegLevel Kernels ->
  AllocM Kernels KernelsMem (SegOp SegLevel KernelsMem)
handleSegOp op = do
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp
          (Mul Int64 OverflowUndef)
          (unCount (segNumGroups lvl))
          (unCount (segGroupSize lvl))
  allocAtLevel lvl $ mapSegOpM (mapper num_threads) op
  where
    scope = scopeOfSegSpace $ segSpace op
    lvl = segLevel op
    mapper num_threads =
      identitySegOpMapper
        { -- Bodies see the names bound by the segment space and use
          -- the hint function selected by the level (see 'f').
          mapOnSegOpBody =
            localScope scope . local f . allocInKernelBody,
          -- Operator lambdas always use the in-thread hints,
          -- whatever the level of the enclosing operation.
          mapOnSegOpLambda =
            local inThread
              . allocInBinOpLambda num_threads (segSpace op)
        }
    -- Pick the expression-hint environment from the segment level.
    f = case segLevel op of
      SegThread {} -> inThread
      SegGroup {} -> inGroup
    inThread env = env {envExpHints = inThreadExpHints}
    inGroup env = env {envExpHints = inGroupExpHints}

-- | Convert a host-level operation from 'Kernels' to 'KernelsMem'.
--
-- * 'SizeOp's are passed through unchanged, wrapped in 'Inner'.
--
-- * 'OtherOp' (a SOAC) is rejected with a call to 'error'.
--
-- * 'SegOp's are converted by 'handleSegOp' and re-wrapped.
handleHostOp ::
  HostOp Kernels (SOAC Kernels) ->
  AllocM Kernels KernelsMem (MemOp (HostOp KernelsMem ()))
handleHostOp (SizeOp op) =
  return $ Inner $ SizeOp op
handleHostOp (OtherOp op) =
  error $ "Cannot allocate memory in SOAC: " ++ pretty op
handleHostOp (SegOp op) =
  Inner . SegOp <$> handleSegOp op

-- | Allocation hints for expressions at host level.
--
-- * 'Manifest': a single hint in 'DefaultSpace' whose index function
--   is an iota over the permuted dimensions, permuted back with the
--   inverse permutation.
--
-- * Thread-level 'SegMap': one hint per result, from 'mapResultHint'.
--
-- * Thread-level 'SegRed': 'NoHint' for each reduction result,
--   followed by 'mapResultHint' hints for the remaining results.
--
-- * Anything else: 'NoHint', repeated 'expExtTypeSize' times.
kernelExpHints :: Allocator KernelsMem m => Exp KernelsMem -> m [ExpHint]
kernelExpHints (BasicOp (Manifest perm v)) = do
  dims <- arrayDims <$> lookupType v
  let perm_inv = rearrangeInverse perm
      dims' = rearrangeShape perm dims
      ixfun = IxFun.permute (IxFun.iota $ map pe64 dims') perm_inv
  return [Hint ixfun DefaultSpace]
kernelExpHints (Op (Inner (SegOp (SegMap lvl@SegThread {} space ts body)))) =
  zipWithM (mapResultHint lvl space) ts $ kernelBodyResult body
kernelExpHints (Op (Inner (SegOp (SegRed lvl@SegThread {} space reds ts body)))) =
  (map (const NoHint) red_res <>)
    <$> zipWithM (mapResultHint lvl space) (drop num_reds ts) map_res
  where
    -- The first 'num_reds' results belong to the reduction operators;
    -- only the results after them are given map-style hints.
    num_reds = segBinOpResults reds
    (red_res, map_res) = splitAt num_reds $ kernelBodyResult body
kernelExpHints e =
  return $ replicate (expExtTypeSize e) NoHint

-- | Compute the allocation hint for a single kernel result of the
-- given type, for a segmented operation with the given level and
-- space.  All hints produced here are in 'DefaultSpace'; results
-- matched by no case get 'NoHint'.
mapResultHint ::
  Allocator lore m =>
  SegLevel ->
  SegSpace ->
  Type ->
  KernelResult ->
  m ExpHint
mapResultHint lvl space = hint
  where
    num_threads =
      pe64 (unCount $ segNumGroups lvl) * pe64 (unCount $ segGroupSize lvl)

    -- Heuristic: do not rearrange for returned arrays that are
    -- sufficiently small.
    coalesceReturnOfShape _ [] = False
    coalesceReturnOfShape bs [Constant (IntValue (Int64Value d))] = bs * d > 4
    coalesceReturnOfShape _ _ = True

    -- Plain returns are rearranged (via 'innermost') only when the
    -- heuristic above accepts them; otherwise they fall through to
    -- the final 'NoHint' case.
    hint t Returns {}
      | coalesceReturnOfShape (primByteSize (elemType t)) $ arrayDims t = do
        let space_dims = segSpaceDims space
        t_dims <- mapM dimAllocationSize $ arrayDims t
        return $ Hint (innermost space_dims t_dims) DefaultSpace
    hint t (ConcatReturns SplitStrided {} w _ _) = do
      t_dims <- mapM dimAllocationSize $ arrayDims t
      return $ Hint (innermost [w] t_dims) DefaultSpace
    -- Contiguous concatenation of primitive results: start from an
    -- iota over [num_threads, elems_per_thread], swap the two
    -- dimensions, and reshape to the single dimension 'w'.
    hint Prim {} (ConcatReturns SplitContiguous w elems_per_thread _) = do
      let ixfun_base = IxFun.iota [sExt64 num_threads, pe64 elems_per_thread]
          ixfun_tr = IxFun.permute ixfun_base [1, 0]
          ixfun = IxFun.reshape ixfun_tr $ map (DimNew . pe64) [w]
      return $ Hint ixfun DefaultSpace
    hint _ _ = return NoHint

-- | Build an index function of logical shape @space_dims ++ t_dims@.
--
-- The base is an iota over the shape with the @t_dims@ moved in front
-- of the @space_dims@; it is then permuted with the inverse
-- permutation to give back the logical dimension order.
innermost :: [SubExp] -> [SubExp] -> IxFun
innermost space_dims t_dims =
  let r = length t_dims
      dims = space_dims ++ t_dims
      -- Indices of the 't_dims' first, then those of the 'space_dims'.
      perm =
        [length space_dims .. length space_dims + r - 1]
          ++ [0 .. length space_dims - 1]
      perm_inv = rearrangeInverse perm
      dims_perm = rearrangeShape perm dims
      ixfun_base = IxFun.iota $ map pe64 dims_perm
      ixfun_rearranged = IxFun.permute ixfun_base perm_inv
   in ixfun_rearranged

-- | Is this size "semi-static"?  True for any constant, and for a
-- variable exactly when its name is a member of the given set.
semiStatic :: S.Set VName -> SubExp -> Bool
semiStatic _ Constant {} = True
semiStatic consts (Var v) = v `S.member` consts

-- | Allocation hints for expressions inside a group-level body.
--
-- For a 'SegMap' with at least one 'ResultPrivate' result, each
-- result that is private and has only semi-static dimensions (see
-- 'semiStatic') gets a hint in a 'ScalarSpace' sized by the result's
-- array dimensions; every other result of that 'SegMap' gets
-- 'NoHint'.  All other expressions get 'NoHint', repeated
-- 'expExtTypeSize' times.
inGroupExpHints :: Allocator KernelsMem m => Exp KernelsMem -> m [ExpHint]
inGroupExpHints (Op (Inner (SegOp (SegMap _ space ts body))))
  | any private $ kernelBodyResult body = do
    consts <- askConsts
    -- The inner 'do' is in the list monad: one hint per
    -- (type, result) pair.
    return $ do
      (t, r) <- zip ts $ kernelBodyResult body
      return $
        if private r && all (semiStatic consts) (arrayDims t)
          then
            let seg_dims = map pe64 $ segSpaceDims space
                dims = seg_dims ++ map pe64 (arrayDims t)
                nilSlice d = DimSlice 0 d 0
             in -- Iota over segment dims followed by array dims, with
                -- each segment dimension sliced by 'nilSlice'.
                Hint
                  ( IxFun.slice (IxFun.iota dims) $
                      fullSliceNum dims $ map nilSlice seg_dims
                  )
                  $ ScalarSpace (arrayDims t) $ elemType t
          else NoHint
  where
    private (Returns ResultPrivate _) = True
    private _ = False
inGroupExpHints e = return $ replicate (expExtTypeSize e) NoHint

-- | Allocation hints for expressions inside a thread-level body.
--
-- Produces one hint per type returned by 'expExtType'.  An array
-- type with a static shape whose dimensions are all semi-static (see
-- 'semiStatic') gets an iota index function in a 'ScalarSpace' of
-- that shape and element type; every other type gets 'NoHint'.
inThreadExpHints :: Allocator KernelsMem m => Exp KernelsMem -> m [ExpHint]
inThreadExpHints e = do
  consts <- askConsts
  mapM (maybePrivate consts) =<< expExtType e
  where
    maybePrivate consts t
      | Just (Array pt shape _) <- hasStaticShape t,
        all (semiStatic consts) $ shapeDims shape = do
        let ixfun = IxFun.iota $ map pe64 $ shapeDims shape
        return $ Hint ixfun $ ScalarSpace (shapeDims shape) pt
      | otherwise =
        return NoHint

-- | The pass from 'Kernels' to 'KernelsMem'.
--
-- Instantiates 'explicitAllocationsGeneric' with 'handleHostOp' as the
-- operation handler and 'kernelExpHints' as the hint function.
explicitAllocations :: Pass Kernels KernelsMem
explicitAllocations = explicitAllocationsGeneric handleHostOp kernelExpHints

-- | Convert some 'Kernels' stms to 'KernelsMem'.
--
-- Instantiates 'explicitAllocationsInStmsGeneric' with the same
-- operation handler ('handleHostOp') and hint function
-- ('kernelExpHints') as 'explicitAllocations'.
explicitAllocationsInStms ::
  (MonadFreshNames m, HasScope KernelsMem m) =>
  Stms Kernels ->
  m (Stms KernelsMem)
explicitAllocationsInStms =
  explicitAllocationsInStmsGeneric handleHostOp kernelExpHints