{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Futhark.Pass.ExplicitAllocations.GPU
( explicitAllocations,
explicitAllocationsInStms,
)
where
import Control.Monad
import Data.Set qualified as S
import Futhark.IR.GPU
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.Pass.ExplicitAllocations
import Futhark.Pass.ExplicitAllocations.SegOp
-- | 'SizeSubst' instance for host-level operations.
--
-- 'opIsConst' is 'True' exactly for the 'SizeOp' variants 'GetSize' and
-- 'GetSizeMax'; every other 'HostOp' yields 'False'.
--
-- NOTE(review): the module enables @-fno-warn-orphans@, so this is
-- presumably an orphan instance -- confirm where 'SizeSubst' and 'HostOp'
-- are defined before moving it.
instance SizeSubst (HostOp rep op) where
  opIsConst (SizeOp GetSize {}) = True
  opIsConst (SizeOp GetSizeMax {}) = True
  opIsConst _ = False
-- | Run an allocation action in an environment adjusted for the given
-- 'SegLevel'.
--
-- The environment is modified (via 'local') as follows:
--
-- * 'allocSpace' becomes @Space "local"@ for 'SegGroup' and
--   @Space "device"@ for 'SegThread' and 'SegThreadInGroup';
--
-- * 'aggressiveReuse' is switched on;
--
-- * 'allocInOp' becomes 'handleHostOp' with this level recorded as the
--   enclosing (outer) level.
allocAtLevel :: SegLevel -> AllocM GPU GPUMem a -> AllocM GPU GPUMem a
allocAtLevel lvl = local $ \env ->
  env
    { allocSpace = space,
      aggressiveReuse = True,
      allocInOp = handleHostOp (Just lvl)
    }
  where
    -- Memory space in which allocations made at this level are placed.
    space = case lvl of
      SegGroup {} -> Space "local"
      SegThread {} -> Space "device"
      SegThreadInGroup {} -> Space "device"
-- | Insert allocations into a 'SegOp'.
--
-- First a @num_threads@ value is bound: if a 'KernelGrid' is known it is
-- the product of the grid's number of groups and group size; otherwise it
-- is the product of the dimensions of the operation's 'SegSpace'.  The
-- operation is then traversed with 'mapSegOpM' inside 'allocAtLevel' for
-- the operation's own level.
--
-- The first argument is the level of the enclosing 'SegOp', if any.
handleSegOp ::
  Maybe SegLevel ->
  SegOp SegLevel GPU ->
  AllocM GPU GPUMem (SegOp SegLevel GPUMem)
handleSegOp outer_lvl op = do
  num_threads <-
    letSubExp "num_threads"
      =<< case maybe_grid of
        Just grid ->
          pure . BasicOp $
            BinOp
              (Mul Int64 OverflowUndef)
              (unCount (gridNumGroups grid))
              (unCount (gridGroupSize grid))
        Nothing ->
          foldBinOp
            (Mul Int64 OverflowUndef)
            (intConst Int64 1)
            (segSpaceDims $ segSpace op)
  allocAtLevel (segLevel op) $ mapSegOpM (mapper num_threads) op
  where
    -- A grid attached to the enclosing level takes priority over one
    -- attached to this operation's own level (the case alternatives are
    -- tried in order).
    maybe_grid =
      case (outer_lvl, segLevel op) of
        (Just (SegThread _ (Just grid)), _) -> Just grid
        (Just (SegGroup _ (Just grid)), _) -> Just grid
        (_, SegThread _ (Just grid)) -> Just grid
        (_, SegGroup _ (Just grid)) -> Just grid
        _ -> Nothing

    -- Names bound by the segment space, made visible while allocating
    -- inside the body.
    scope = scopeOfSegSpace $ segSpace op

    mapper num_threads =
      identitySegOpMapper
        { mapOnSegOpBody =
            localScope scope . local f . allocInKernelBody,
          -- Lambdas always use the per-thread hints, whatever the level
          -- of the operation itself.
          mapOnSegOpLambda =
            local inThread
              . allocInBinOpLambda num_threads (segSpace op)
        }

    -- Hint selection for the body, by the level of this operation.
    f = case segLevel op of
      SegThread {} -> inThread
      SegThreadInGroup {} -> inThread
      SegGroup {} -> inGroup

    inThread env = env {envExpHints = inThreadExpHints}
    inGroup env = env {envExpHints = inGroupExpHints}
-- | Insert allocations into a host-level operation.
--
-- * 'SizeOp's are passed through unchanged (rewrapped in 'Inner').
--
-- * 'OtherOp' (a SOAC) is not supported and aborts with 'error'.
--
-- * 'SegOp's are delegated to 'handleSegOp', forwarding the enclosing
--   level.
--
-- * For a 'GPUBody', allocations are inserted into the body's statements
--   with 'allocInStms' and the body is rebuilt with 'buildBody_', keeping
--   the declared result types and the original result.
--
-- The first argument is the level of the enclosing 'SegOp', if any; it is
-- only consulted in the 'SegOp' case.
handleHostOp ::
  Maybe SegLevel ->
  HostOp SOAC GPU ->
  AllocM GPU GPUMem (MemOp (HostOp NoOp) GPUMem)
handleHostOp _ (SizeOp op) =
  pure $ Inner $ SizeOp op
handleHostOp _ (OtherOp op) =
  error $ "Cannot allocate memory in SOAC: " ++ prettyString op
handleHostOp outer_lvl (SegOp op) =
  Inner . SegOp <$> handleSegOp outer_lvl op
handleHostOp _ (GPUBody ts (Body _ stms res)) =
  fmap (Inner . GPUBody ts) . buildBody_ . allocInStms stms $ pure res
-- | Allocation hints for expressions at the top (host) level.
--
-- * @Manifest perm v@: a single hint in @Space "device"@ whose index
--   function is an 'IxFun.iota' over the dimensions of @v@ rearranged by
--   @perm@, then permuted by the inverse of @perm@.
--
-- * 'SegMap' at 'SegThread' level: one hint per result, computed by
--   'mapResultHint'.
--
-- * 'SegRed' at 'SegThread' level: 'NoHint' for each of the leading
--   reduction results, followed by 'mapResultHint' hints for the
--   remaining results.
--
-- * Anything else: 'defaultExpHints'.
kernelExpHints :: Exp GPUMem -> AllocM GPU GPUMem [ExpHint]
kernelExpHints (BasicOp (Manifest perm v)) = do
  dims <- arrayDims <$> lookupType v
  let perm_inv = rearrangeInverse perm
      dims' = rearrangeShape perm dims
      ixfun = IxFun.permute (IxFun.iota $ map pe64 dims') perm_inv
  pure [Hint ixfun $ Space "device"]
kernelExpHints (Op (Inner (SegOp (SegMap lvl@(SegThread _ _) space ts body)))) =
  zipWithM (mapResultHint lvl space) ts $ kernelBodyResult body
kernelExpHints (Op (Inner (SegOp (SegRed lvl@(SegThread _ _) space reds ts body)))) =
  (map (const NoHint) red_res <>)
    <$> zipWithM (mapResultHint lvl space) (drop num_reds ts) map_res
  where
    -- The first 'num_reds' results (and types) belong to the reduction
    -- operators; the rest are handled like map results.
    num_reds = segBinOpResults reds
    (red_res, map_res) = splitAt num_reds $ kernelBodyResult body
kernelExpHints e = defaultExpHints e
-- | Allocation hint for a single result of a thread-level map.
--
-- A 'Returns' result whose type passes the size heuristic gets a 'Hint'
-- in @Space "device"@ with the index function built by 'innermost' from
-- the segment-space dimensions and the result's own array dimensions.
-- Every other result gets 'NoHint'.
--
-- The 'SegLevel' argument is currently unused.
mapResultHint ::
  SegLevel ->
  SegSpace ->
  Type ->
  KernelResult ->
  AllocM GPU GPUMem ExpHint
mapResultHint _lvl space = hint
  where
    -- Heuristic on (element byte size, per-result array dimensions):
    -- never rearrange results without array dimensions; for a single
    -- constant 64-bit dimension, rearrange only when the result exceeds
    -- four bytes in total; rearrange in all remaining cases.
    coalesceReturnOfShape _ [] = False
    coalesceReturnOfShape bs [Constant (IntValue (Int64Value d))] = bs * d > 4
    coalesceReturnOfShape _ _ = True

    hint t Returns {}
      | coalesceReturnOfShape (primByteSize (elemType t)) $ arrayDims t = do
          let space_dims = segSpaceDims space
          pure $ Hint (innermost space_dims (arrayDims t)) $ Space "device"
    hint _ _ = pure NoHint
-- | Build an index function over @space_dims ++ t_dims@.
--
-- The permutation @perm@ lists the positions of @t_dims@ first, followed
-- by the positions of @space_dims@.  The result is an 'IxFun.iota' over
-- the dimensions rearranged by @perm@, then permuted by the inverse of
-- @perm@.
--
-- NOTE(review): judging by the name, the intent is presumably that the
-- segment-space dimensions end up innermost in memory -- confirm against
-- the semantics of 'IxFun.permute'.
innermost :: [SubExp] -> [SubExp] -> IxFun
innermost space_dims t_dims =
  IxFun.permute ixfun_base perm_inv
  where
    num_space = length space_dims
    num_t = length t_dims
    dims = space_dims ++ t_dims
    -- Positions of the per-result dimensions, then those of the space
    -- dimensions.
    perm =
      [num_space .. num_space + num_t - 1]
        ++ [0 .. num_space - 1]
    perm_inv = rearrangeInverse perm
    ixfun_base = IxFun.iota $ map pe64 $ rearrangeShape perm dims
-- | Is this size either a constant, or a variable contained in the given
-- set of names?
semiStatic :: S.Set VName -> SubExp -> Bool
semiStatic _ Constant {} = True
semiStatic consts (Var v) = v `S.member` consts
-- | Allocation hints for expressions inside a group-level 'SegOp'.
--
-- For a 'SegMap' with at least one private result (a 'Returns' marked
-- 'ResultPrivate'), one hint is produced per result: a private result
-- whose array dimensions are all 'semiStatic' (with respect to
-- 'envConsts') gets a 'Hint' in a 'ScalarSpace' of those dimensions and
-- the result's element type; every other result gets 'NoHint'.
--
-- All other expressions fall back to 'defaultExpHints'.
inGroupExpHints :: Exp GPUMem -> AllocM GPU GPUMem [ExpHint]
inGroupExpHints (Op (Inner (SegOp (SegMap _ space ts body))))
  | any private results = do
      consts <- asks envConsts
      pure $ zipWith (resultHint consts) ts results
  where
    results = kernelBodyResult body

    private (Returns ResultPrivate _ _) = True
    private _ = False

    seg_dims = map pe64 $ segSpaceDims space

    -- NOTE(review): presumably offset 0, extent d, stride 0 -- confirm
    -- against the field order of 'DimSlice'.
    nilSlice d = DimSlice 0 d 0

    resultHint consts t r
      | private r && all (semiStatic consts) (arrayDims t) =
          let dims = seg_dims ++ map pe64 (arrayDims t)
           in Hint
                ( IxFun.slice (IxFun.iota dims) $
                    fullSliceNum dims $
                      map nilSlice seg_dims
                )
                (ScalarSpace (arrayDims t) (elemType t))
      | otherwise = NoHint
inGroupExpHints e = defaultExpHints e
-- | Allocation hints for expressions inside a thread-level context.
--
-- One hint is produced per result type of the expression (as given by
-- 'expExtType'): an array with a static shape whose dimensions are all
-- 'semiStatic' (with respect to 'envConsts') gets a 'Hint' with an
-- 'IxFun.iota' index function in a 'ScalarSpace' of that shape and
-- element type; anything else gets 'NoHint'.
inThreadExpHints :: Exp GPUMem -> AllocM GPU GPUMem [ExpHint]
inThreadExpHints e = do
  consts <- asks envConsts
  mapM (maybePrivate consts) =<< expExtType e
  where
    maybePrivate consts t
      | Just (Array pt shape _) <- hasStaticShape t,
        all (semiStatic consts) $ shapeDims shape = do
          let ixfun = IxFun.iota $ map pe64 $ shapeDims shape
          pure $ Hint ixfun $ ScalarSpace (shapeDims shape) pt
      | otherwise =
          pure NoHint
-- | The pass from 'GPU' to 'GPUMem'.
--
-- Built from 'explicitAllocationsGeneric' with @Space "device"@ as the
-- space argument, 'handleHostOp' (with no enclosing level) for
-- operations, and 'kernelExpHints' for expression hints.
explicitAllocations :: Pass GPU GPUMem
explicitAllocations =
  explicitAllocationsGeneric
    (Space "device")
    (handleHostOp Nothing)
    kernelExpHints
-- | Convert some 'GPU' statements to 'GPUMem'.
--
-- Uses the same configuration as 'explicitAllocations': @Space "device"@,
-- 'handleHostOp' with no enclosing level, and 'kernelExpHints'.
explicitAllocationsInStms ::
  (MonadFreshNames m, HasScope GPUMem m) =>
  Stms GPU ->
  m (Stms GPUMem)
explicitAllocationsInStms =
  explicitAllocationsInStmsGeneric
    (Space "device")
    (handleHostOp Nothing)
    kernelExpHints