{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Futhark.Pass.ExplicitAllocations.GPU
( explicitAllocations,
explicitAllocationsInStms,
)
where
import qualified Data.Map as M
import qualified Data.Set as S
import Futhark.IR.GPU
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Pass.ExplicitAllocations
import Futhark.Pass.ExplicitAllocations.SegOp
-- | Size-substitution information for host-level operations.
--
-- A 'SplitSpace' size operation whose pattern binds exactly one element
-- contributes a 'ChunkMap' entry mapping that element's name to the
-- @elems_per_thread@ operand; every other operation contributes nothing.
-- Only 'GetSize' and 'GetSizeMax' are considered constant.
instance SizeSubst (HostOp rep op) where
  opSizeSubst (Pat [size]) (SizeOp (SplitSpace _ _ _ elems_per_thread)) =
    M.singleton (patElemName size) elems_per_thread
  opSizeSubst _ _ = mempty

  opIsConst op = case op of
    SizeOp GetSize {} -> True
    SizeOp GetSizeMax {} -> True
    _ -> False
-- | Run an allocation action with the environment adjusted for a
-- segment level.
--
-- Thread-level code ('SegThread') allocates in 'DefaultSpace', and
-- group-level code ('SegGroup') allocates in the @"local"@ space.
-- In both cases 'aggressiveReuse' is switched on.
allocAtLevel :: SegLevel -> AllocM fromrep trep a -> AllocM fromrep trep a
allocAtLevel lvl = local adjust
  where
    adjust env = env {allocSpace = space, aggressiveReuse = True}
    space = case lvl of
      SegThread {} -> DefaultSpace
      SegGroup {} -> Space "local"
-- | Add memory allocations to a 'SegOp'.
--
-- First a @num_threads@ variable is bound: the number of groups times
-- the group size, computed as an 'Int64' multiplication with undefined
-- overflow. It is passed to 'allocInBinOpLambda'. The operator is then
-- traversed with 'mapSegOpM' under 'allocAtLevel' for its level:
--
-- * Kernel bodies are allocated with the segment space in scope.
--   Expression hints are chosen by level: 'inThreadExpHints' for
--   'SegThread' and 'inGroupExpHints' for 'SegGroup'.
--
-- * Operator lambdas always use 'inThreadExpHints'.
handleSegOp ::
  SegOp SegLevel GPU ->
  AllocM GPU GPUMem (SegOp SegLevel GPUMem)
handleSegOp op = do
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp
          (Mul Int64 OverflowUndef)
          (unCount (segNumGroups lvl))
          (unCount (segGroupSize lvl))
  allocAtLevel lvl $ mapSegOpM (mapper num_threads) op
  where
    lvl = segLevel op
    scope = scopeOfSegSpace $ segSpace op

    mapper num_threads =
      identitySegOpMapper
        { mapOnSegOpBody =
            localScope scope . local withLevelHints . allocInKernelBody,
          mapOnSegOpLambda =
            local inThread . allocInBinOpLambda num_threads (segSpace op)
        }

    -- Kernel bodies get hints matching the level they run at.
    withLevelHints = case segLevel op of
      SegThread {} -> inThread
      SegGroup {} -> inGroup

    inThread env = env {envExpHints = inThreadExpHints}
    inGroup env = env {envExpHints = inGroupExpHints}
-- | Add memory allocations to a host-level operation.
--
-- * Size operations are passed through unchanged.
-- * 'SegOp's are handled by 'handleSegOp'.
-- * The statements of a 'GPUBody' are allocated inside a freshly built
--   body that returns the original result.
-- * A SOAC ('OtherOp') is reported with 'error'.
handleHostOp ::
  HostOp GPU (SOAC GPU) ->
  AllocM GPU GPUMem (MemOp (HostOp GPUMem ()))
handleHostOp (SizeOp op) =
  pure (Inner (SizeOp op))
handleHostOp (OtherOp op) =
  error $ "Cannot allocate memory in SOAC: " ++ pretty op
handleHostOp (SegOp op) =
  Inner . SegOp <$> handleSegOp op
handleHostOp (GPUBody ts (Body _ stms res)) = do
  body' <- buildBody_ $ allocInStms stms $ pure res
  pure $ Inner $ GPUBody ts body'
-- | Allocation hints for expressions at host level.
--
-- * A 'Manifest' gets an index function built from its permutation:
--   'IxFun.iota' over the permuted dimensions, permuted back by the
--   inverse permutation.
-- * A thread-level 'SegMap' gets one 'mapResultHint' per result.
-- * A thread-level 'SegRed' gets 'NoHint' for its reduction results and
--   'mapResultHint' for the remaining map results.
-- * Every other expression gets one 'NoHint' per result.
kernelExpHints :: Exp GPUMem -> AllocM GPU GPUMem [ExpHint]
kernelExpHints (BasicOp (Manifest perm v)) = do
  dims <- arrayDims <$> lookupType v
  let perm_inv = rearrangeInverse perm
      dims' = rearrangeShape perm dims
      ixfun = IxFun.permute (IxFun.iota $ map pe64 dims') perm_inv
  pure [Hint ixfun DefaultSpace]
kernelExpHints (Op (Inner (SegOp (SegMap lvl@SegThread {} space ts body)))) =
  zipWithM (mapResultHint lvl space) ts $ kernelBodyResult body
kernelExpHints (Op (Inner (SegOp (SegRed lvl@SegThread {} space reds ts body)))) = do
  map_hints <- zipWithM (mapResultHint lvl space) (drop num_reds ts) map_res
  pure $ map (const NoHint) red_res <> map_hints
  where
    num_reds = segBinOpResults reds
    (red_res, map_res) = splitAt num_reds $ kernelBodyResult body
kernelExpHints e =
  pure $ replicate (expExtTypeSize e) NoHint
-- | Allocation hint for one result of a thread-level map-like 'SegOp'.
--
-- * A 'Returns' result whose shape passes 'coalesceReturnOfShape' is laid
--   out with 'innermost'. The segment-space dimensions are then innermost
--   in memory, presumably so that neighbouring threads write neighbouring
--   elements.
-- * A strided 'ConcatReturns' is laid out with 'innermost' using the
--   concatenation width @w@ as the only "space" dimension.
-- * A contiguous 'ConcatReturns' of primitive type is laid out as the
--   transpose of a @[num_threads][elems_per_thread]@ array, reshaped to
--   @[w]@.
-- * Everything else gets 'NoHint'.
--
-- Result dimensions are passed through 'dimAllocationSize' with the
-- environment's 'chunkMap' before being used.
mapResultHint ::
  SegLevel ->
  SegSpace ->
  Type ->
  KernelResult ->
  AllocM GPU GPUMem ExpHint
mapResultHint lvl space = hint
  where
    num_threads =
      pe64 (unCount $ segNumGroups lvl) * pe64 (unCount $ segGroupSize lvl)

    -- Heuristic: do not rearrange returned values that are scalars, or
    -- whose single constant dimension covers at most 4 bytes.
    coalesceReturnOfShape _ [] = False
    coalesceReturnOfShape bs [Constant (IntValue (Int64Value d))] = bs * d > 4
    coalesceReturnOfShape _ _ = True

    hint t Returns {}
      | coalesceReturnOfShape (primByteSize (elemType t)) (arrayDims t) = do
          chunkmap <- asks chunkMap
          let space_dims = segSpaceDims space
              t_dims = map (dimAllocationSize chunkmap) (arrayDims t)
          pure $ Hint (innermost space_dims t_dims) DefaultSpace
    hint t (ConcatReturns _ SplitStrided {} w _ _) = do
      chunkmap <- asks chunkMap
      let t_dims = map (dimAllocationSize chunkmap) (arrayDims t)
      pure $ Hint (innermost [w] t_dims) DefaultSpace
    hint Prim {} (ConcatReturns _ SplitContiguous w elems_per_thread _) = do
      let ixfun_base = IxFun.iota [sExt64 num_threads, pe64 elems_per_thread]
          ixfun_tr = IxFun.permute ixfun_base [1, 0]
          ixfun = IxFun.reshape ixfun_tr [DimNew (pe64 w)]
      pure $ Hint ixfun DefaultSpace
    hint _ _ = pure NoHint
-- | Index function for an array of shape @space_dims ++ t_dims@ whose
-- physical layout puts the @t_dims@ outermost and the @space_dims@
-- innermost.
--
-- It is built as 'IxFun.iota' over the dimensions reordered by @perm@
-- (the @t_dims@ first), then permuted by @perm@'s inverse. This restores
-- the logical order @space_dims ++ t_dims@.
innermost :: [SubExp] -> [SubExp] -> IxFun
innermost space_dims t_dims =
  IxFun.permute ixfun_base perm_inv
  where
    n = length space_dims
    r = length t_dims
    perm = [n .. n + r - 1] ++ [0 .. n - 1]
    perm_inv = rearrangeInverse perm
    dims_perm = rearrangeShape perm (space_dims ++ t_dims)
    ixfun_base = IxFun.iota (map pe64 dims_perm)
-- | A size is semi-static if it is a constant, or a variable contained
-- in the given set of names.
semiStatic :: S.Set VName -> SubExp -> Bool
semiStatic consts se = case se of
  Constant {} -> True
  Var v -> v `S.member` consts
-- | Allocation hints for expressions inside group-level code.
--
-- This applies to a 'SegMap' with at least one @Returns ResultPrivate@
-- result. Each result that is private and whose dimensions are all
-- 'semiStatic' (w.r.t. 'envConsts') gets a hint in
-- @ScalarSpace (arrayDims t) (elemType t)@. Its index function is
-- @iota (seg_dims ++ arrayDims t)@ sliced with @DimSlice 0 d 0@ (stride 0)
-- on every segment dimension. Every other result gets 'NoHint'.
--
-- Any other expression gets one 'NoHint' per result.
inGroupExpHints :: Exp GPUMem -> AllocM GPU GPUMem [ExpHint]
inGroupExpHints (Op (Inner (SegOp (SegMap _ space ts body))))
  | any private results = do
      consts <- asks envConsts
      pure $ zipWith (resultHint consts) ts results
  where
    results = kernelBodyResult body

    private (Returns ResultPrivate _ _) = True
    private _ = False

    resultHint consts t r
      | private r,
        all (semiStatic consts) (arrayDims t) =
          let seg_dims = map pe64 $ segSpaceDims space
              dims = seg_dims ++ map pe64 (arrayDims t)
              nilSlice d = DimSlice 0 d 0
              ixfun =
                IxFun.slice (IxFun.iota dims) $
                  fullSliceNum dims (map nilSlice seg_dims)
           in Hint ixfun $ ScalarSpace (arrayDims t) (elemType t)
      | otherwise = NoHint
inGroupExpHints e =
  pure $ replicate (expExtTypeSize e) NoHint
-- | Allocation hints for expressions inside thread-level code.
--
-- Each result type of the expression is examined. If it is an array
-- with a static shape whose dimensions are all 'semiStatic' (w.r.t.
-- 'envConsts'), it gets an 'IxFun.iota' index function in
-- @ScalarSpace dims pt@. Otherwise it gets 'NoHint'.
inThreadExpHints :: Exp GPUMem -> AllocM GPU GPUMem [ExpHint]
inThreadExpHints e = do
  consts <- asks envConsts
  ets <- expExtType e
  traverse (maybePrivate consts) ets
  where
    maybePrivate consts t
      | Just (Array pt shape _) <- hasStaticShape t,
        dims <- shapeDims shape,
        all (semiStatic consts) dims =
          pure $ Hint (IxFun.iota (map pe64 dims)) (ScalarSpace dims pt)
      | otherwise =
          pure NoHint
-- | The pass from 'GPU' to 'GPUMem'. Operations are handled by
-- 'handleHostOp' and expression hints come from 'kernelExpHints'.
explicitAllocations :: Pass GPU GPUMem
explicitAllocations =
  explicitAllocationsGeneric handleHostOp kernelExpHints
-- | Convert some 'GPU' statements to 'GPUMem'. It uses the same
-- operation handler ('handleHostOp') and hints ('kernelExpHints') as
-- 'explicitAllocations'.
explicitAllocationsInStms ::
  (MonadFreshNames m, HasScope GPUMem m) =>
  Stms GPU ->
  m (Stms GPUMem)
explicitAllocationsInStms =
  explicitAllocationsInStmsGeneric handleHostOp kernelExpHints