{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExpandAllocations (expandAllocations) where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (find, foldl')
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.Rephrase
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Error
import Futhark.IR
import qualified Futhark.IR.GPU.Simplify as GPU
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Rep (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToGPU (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Util (mapAccumLM)
import Futhark.Util.IntegralExp
import Prelude hiding (quot)
-- | The allocation-expansion pass over 'GPUMem' programs.
--
-- The program constants are transformed first, starting from an empty
-- scope.  Each function definition is then transformed with the scope of
-- the rewritten constants.  A 'Left' produced while transforming the
-- constants is handed to 'limitationOnLeft'.
expandAllocations :: Pass GPUMem GPUMem
expandAllocations =
  Pass "expand allocations" "Expand allocations" $
    \(Prog consts funs) -> do
      consts' <-
        modifyNameSource $
          limitationOnLeft . runStateT (runReaderT (transformStms consts) mempty)
      Prog consts' <$> mapM (transformFunDef $ scopeOf consts') funs
-- | The monad used by the transformation functions in this module: a
-- reader over a 'Scope', on top of a 'VNameSource' state, with failures
-- reported as a 'String' in 'Either'.
type ExpandM = ReaderT (Scope GPUMem) (StateT VNameSource (Either String))
-- | Return the payload of a 'Right'; the message of a 'Left' is passed to
-- 'compilerLimitationS'.
limitationOnLeft :: Either String a -> a
limitationOnLeft = either compilerLimitationS id
-- | Transform the body of one function definition and then run
-- 'copyPropagateInFun' over the updated definition.
--
-- The first argument is the scope in which the function is transformed
-- ('expandAllocations' passes the scope of the program constants).  A
-- 'Left' from the body transformation is handed to 'limitationOnLeft'.
transformFunDef ::
  Scope GPUMem ->
  FunDef GPUMem ->
  PassM (FunDef GPUMem)
transformFunDef scope fundec = do
  body' <- modifyNameSource $ limitationOnLeft . runStateT (runReaderT m mempty)
  copyPropagateInFun
    simpleGPUMem
    (ST.fromScope (addScopeWisdom scope))
    fundec {funDefBody = body'}
  where
    -- The body transformation, run with both the supplied scope and the
    -- scope of the function definition itself in effect.
    m :: ExpandM (Body GPUMem)
    m =
      localScope scope $
        inScopeOf fundec $
          transformBody $ funDefBody fundec
-- | Transform the statements of a body with 'transformStms'; the body
-- result is kept unchanged.
transformBody :: Body GPUMem -> ExpandM (Body GPUMem)
transformBody (Body () stms res) = Body () <$> transformStms stms <*> pure res
-- | Transform the body of a lambda with its parameters in scope.  The
-- parameters and return types are kept unchanged.
transformLambda :: Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda (Lambda params body ret) =
  Lambda params
    <$> localScope (scopeOfLParams params) (transformBody body)
    <*> pure ret
-- | Transform every statement with 'transformStm' and concatenate the
-- results.  All of the given statements are brought into scope while
-- doing so.
transformStms :: Stms GPUMem -> ExpandM (Stms GPUMem)
transformStms stms =
  inScopeOf stms $ mconcat <$> mapM transformStm (stmsToList stms)
-- | Transform a single statement.  The result may consist of several
-- statements.
transformStm :: Stm GPUMem -> ExpandM (Stms GPUMem)
-- Conditionals tagged 'IfEquiv': each branch is transformed on its own and
-- a failure in one branch is caught rather than propagated.  If exactly one
-- branch fails, the conditional is replaced by the statements of the other
-- branch, followed by bindings of the pattern elements to that branch's
-- results.  If both succeed the conditional is rebuilt.  If both fail, the
-- error of the true branch is rethrown.
-- NOTE(review): this presumably relies on the two branches of an 'IfEquiv'
-- being interchangeable -- confirm against the definition of 'IfSort'.
transformStm (Let pat aux (If cond tbranch fbranch (IfDec ts IfEquiv))) = do
  tbranch' <- (Right <$> transformBody tbranch) `catchError` (return . Left)
  fbranch' <- (Right <$> transformBody fbranch) `catchError` (return . Left)
  case (tbranch', fbranch') of
    (Left _, Right fbranch'') ->
      return $ useBranch fbranch''
    (Right tbranch'', Left _) ->
      return $ useBranch tbranch''
    (Right tbranch'', Right fbranch'') ->
      return $ oneStm $ Let pat aux $ If cond tbranch'' fbranch'' (IfDec ts IfEquiv)
    (Left e, _) ->
      throwError e
  where
    -- Bind a single pattern element to a branch result.
    bindRes pe se = Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se

    -- Inline a branch: its statements, then one binding per pattern element.
    useBranch b =
      bodyStms b
        <> stmsFromList (zipWith bindRes (patternElements pat) (bodyResult b))
-- Any other statement: nested bodies are transformed first (via 'mapExpM'),
-- then the expression itself goes through 'transformExp'.  The statements
-- returned by 'transformExp' are placed before the rebuilt binding.
transformStm (Let pat aux e) = do
  (bnds, e') <- transformExp =<< mapExpM transform e
  return $ bnds <> oneStm (Let pat aux e')
  where
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }
-- | Transform a single expression.  Returns statements together with the
-- rewritten expression; 'transformStm' places those statements before the
-- binding of the expression.
--
-- The segmented-operation clauses below all delegate to 'transformScanRed',
-- passing the lambdas of the operation (none for 'SegMap') and the kernel
-- body, and then rebuild the operation from the returned lambdas and body.
-- Further clauses of this function follow these four.
transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp (Op (Inner (SegOp (SegMap lvl space ts kbody)))) = do
  (alloc_stms, (_, kbody')) <- transformScanRed lvl space [] kbody
  return
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegMap lvl space ts kbody'
    )
transformExp (Op (Inner (SegOp (SegRed lvl space reds ts kbody)))) = do
  (alloc_stms, (lams, kbody')) <-
    transformScanRed lvl space (map segBinOpLambda reds) kbody
  -- Put each returned lambda back into its reduction operator.
  let reds' = zipWith (\red lam -> red {segBinOpLambda = lam}) reds lams
  return
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegRed lvl space reds' ts kbody'
    )
transformExp (Op (Inner (SegOp (SegScan lvl space scans ts kbody)))) = do
  (alloc_stms, (lams, kbody')) <-
    transformScanRed lvl space (map segBinOpLambda scans) kbody
  -- Put each returned lambda back into its scan operator.
  let scans' = zipWith (\red lam -> red {segBinOpLambda = lam}) scans lams
  return
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegScan lvl space scans' ts kbody'
    )
transformExp (Op (Inner (SegOp (SegHist lvl space ops ts kbody)))) = do
  (alloc_stms, (lams', kbody')) <- transformScanRed lvl space lams kbody
  let ops' = zipWith onOp ops lams'
  return
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegHist lvl space ops' ts kbody'
    )
  where
    -- The combining lambdas of the histogram operations.
    lams = map histOp ops
    -- Put a returned lambda back into its histogram operation.
    onOp op lam = op {histOp = lam}
-- Accumulator expression: transform the body lambda, and for every input that
-- carries an update operator, extract the allocations from that operator and
-- rewrite its memory references.  The statements produced for the individual
-- inputs are concatenated in input order.
transformExp (WithAcc inputs lam) = do
  lam' <- transformLambda lam
  (stms_per_input, inputs') <- unzip <$> mapM onInput inputs
  pure (mconcat stms_per_input, WithAcc inputs' lam')
  where
    -- Inputs without an update operator are passed through untouched.
    onInput (shape, arrs, Nothing) =
      pure (mempty, (shape, arrs, Nothing))
    onInput (shape, arrs, Just (op_lam, nes)) = do
      -- Everything currently in scope counts as bound outside the operator.
      bound_outside <- asks (namesFromList . M.keys)
      let -- NOTE(review): presumably a placeholder level (zero groups, zero
          -- group size), since no real 'SegLevel' is available here - confirm.
          dummy_lvl =
            SegThread (Count (intConst Int64 0)) (Count (intConst Int64 0)) SegNoVirt
          (op_lam', lam_allocs) =
            extractLambdaAllocations (dummy_lvl, [0]) bound_outside mempty op_lam
          -- An allocation is variant when its size is a variable that is not
          -- bound outside; any other size is treated as invariant.
          isVariant (_, Var size, _) = not (size `nameIn` bound_outside)
          isVariant _ = False
          (variant_allocs, invariant_allocs) = M.partition isVariant lam_allocs

      -- Variant allocations are not supported inside update operators.
      case M.elems variant_allocs of
        [] -> pure ()
        (_, size, _) : _ ->
          throwError $
            "Cannot handle un-sliceable allocation size: " ++ pretty size
              ++ "\nLikely cause: irregular nested operations inside accumulator update operator."

      -- The first 'shapeRank shape' parameters of the operator are used as
      -- the indexes handed to 'genericExpandedInvariantAllocations'.
      let index_params = take (shapeRank shape) (lambdaParams op_lam)
          indexes = map (le64 . paramName) index_params
      (alloc_stms, alloc_offsets) <-
        genericExpandedInvariantAllocations (const (shape, indexes)) invariant_allocs

      outer_scope <- askScope
      let lam_scope = scopeOf op_lam <> outer_scope
          rebased = do
            op_lam'' <- offsetMemoryInLambda op_lam'
            pure (alloc_stms, (shape, arrs, Just (op_lam'', nes)))
      -- Failures of the offsetting computation become errors of this pass.
      either throwError pure (runOffsetM lam_scope alloc_offsets rebased)
-- Any other expression is returned unchanged, with no extra statements.
transformExp other = pure (mempty, other)
-- | Extract the allocations from a kernel body and from the given operator
-- lambdas, and rewrite both to use the expanded memory.
--
-- Returns the statements produced by 'allocsForBody' together with the
-- rewritten lambdas and kernel body.
--
-- Fails (via 'throwError') when
--
-- * some variant allocation has a size that is not bound inside the kernel, or
--
-- * the level is 'SegGroup' and there is any variant allocation at all.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda GPUMem] ->
  KernelBody GPUMem ->
  ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed lvl space ops kbody = do
  -- Everything currently in scope counts as bound outside the kernel.
  bound_outside <- asks $ namesFromList . M.keys
  let user = (lvl, [le64 $ segFlat space])
      (kbody', kbody_allocs) =
        extractKernelBodyAllocations user bound_outside bound_in_kernel kbody
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations user bound_outside mempty) ops
      -- Variant: the size is a variable that is not bound outside the kernel.
      variantAlloc (_, Var v, _) = not $ v `nameIn` bound_outside
      variantAlloc _ = False
      (variant_allocs, invariant_allocs) =
        M.partition variantAlloc $ kbody_allocs <> mconcat ops_allocs
      -- A variant size must at least be bound by the kernel itself.
      badVariant (_, Var v, _) = not $ v `nameIn` bound_in_kernel
      badVariant _ = False

  case find badVariant $ M.elems variant_allocs of
    Just v ->
      throwError $
        "Cannot handle un-sliceable allocation size: " ++ pretty v
          ++ "\nLikely cause: irregular nested operations inside parallel constructs."
    Nothing ->
      return ()

  case lvl of
    SegGroup {}
      -- The guard rejects *variant* allocations; the message used to say
      -- "invariant", which contradicted the condition being checked.
      | not $ null variant_allocs ->
        throwError "Cannot handle variant allocations in SegGroup."
    _ ->
      return ()

  allocsForBody variant_allocs invariant_allocs lvl space kbody' $ \alloc_stms kbody'' -> do
    -- Each lambda is rewritten with its own parameters brought into scope.
    ops'' <- forM ops' $ \op' ->
      localScope (scopeOf op') $ offsetMemoryInLambda op'
    return (alloc_stms, (ops'', kbody''))
  where
    -- Names introduced by the segment space plus those bound by the
    -- statements of the original kernel body.
    bound_in_kernel =
      namesFromList (M.keys $ scopeOfSegSpace space)
        <> boundInKernelBody kbody
-- | The names bound by the statements of a kernel body, i.e. the keys of the
-- scope of its 'kernelBodyStms'.
boundInKernelBody :: KernelBody GPUMem -> Names
boundInKernelBody kbody =
  namesFromList (M.keys (scopeOf (kernelBodyStms kbody)))
-- | Compute the memory requirements for the given variant and invariant
-- allocations, rewrite the kernel body with the resulting offsets, and hand
-- the allocation statements and the rewritten body to the continuation.
--
-- The continuation runs in 'OffsetM' with the same scope (the segment space
-- on top of the ambient scope) and the same rebase map as the body rewrite.
-- A failure reported by 'runOffsetM' is rethrown in 'ExpandM'.
allocsForBody ::
  Extraction ->
  Extraction ->
  SegLevel ->
  SegSpace ->
  KernelBody GPUMem ->
  (Stms GPUMem -> KernelBody GPUMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs lvl space kbody continue = do
  (offsets, alloc_stms) <-
    memoryRequirements lvl space (kernelBodyStms kbody) variant_allocs invariant_allocs
  ambient_scope <- askScope
  let kernel_scope = scopeOfSegSpace space <> ambient_scope
      rewritten = offsetMemoryInKernelBody kbody >>= continue alloc_stms
  either throwError pure (runOffsetM kernel_scope offsets rewritten)
-- | Produce the statements and the rebase map needed to expand the given
-- allocations for a kernel at level @lvl@.
--
-- First a @num_threads@ value is bound as the 64-bit product of the number
-- of groups and the group size of @lvl@.  With those statements in scope,
-- the invariant allocations are expanded first and the variant allocations
-- second.
--
-- The result pairs the combined rebase map (invariant '<>' variant) with the
-- statements in the order: @num_threads@, invariant, variant.
memoryRequirements ::
  SegLevel ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms GPUMem)
memoryRequirements lvl space kstms variant_allocs invariant_allocs = do
  let num_groups = segNumGroups lvl
      group_size = segGroupSize lvl
  (num_threads, num_threads_stms) <-
    runBinder $
      letSubExp "num_threads" $
        BasicOp $
          BinOp (Mul Int64 OverflowUndef) (unCount num_groups) (unCount group_size)

  (inv_stms, inv_offsets) <-
    inScopeOf num_threads_stms $
      expandedInvariantAllocations num_threads num_groups group_size invariant_allocs

  (var_stms, var_offsets) <-
    inScopeOf num_threads_stms $
      expandedVariantAllocations num_threads space kstms variant_allocs

  pure
    ( inv_offsets <> var_offsets,
      num_threads_stms <> inv_stms <> var_stms
    )
-- | The site an allocation belongs to: the 'SegLevel' it occurs at, paired
-- with a list of index expressions ('transformScanRed' passes the flat index
-- of the segment space here).
type User = (SegLevel, [TPrimExp Int64 VName])

-- | Allocations extracted from a body, keyed by name.  Each entry records the
-- 'User' of the allocation, its size, and its memory 'Space'.
-- NOTE(review): the key is presumably the name of the allocated memory block;
-- confirm against 'extractStmAllocations'.
type Extraction = M.Map VName (User, SubExp, Space)
-- | 'extractGenericBodyAllocations' specialised to kernel bodies: statements
-- are read from and written back to the 'kernelBodyStms' field.
extractKernelBodyAllocations ::
  User ->
  Names ->
  Names ->
  KernelBody GPUMem ->
  ( KernelBody GPUMem,
    Extraction
  )
extractKernelBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel kernelBodyStms replaceStms
  where
    replaceStms new_stms kbody = kbody {kernelBodyStms = new_stms}
-- | 'extractGenericBodyAllocations' specialised to ordinary bodies:
-- statements are read from and written back to the 'bodyStms' field.
extractBodyAllocations ::
  User ->
  Names ->
  Names ->
  Body GPUMem ->
  (Body GPUMem, Extraction)
extractBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel bodyStms replaceStms
  where
    replaceStms new_stms body = body {bodyStms = new_stms}
-- | Extract the allocations from the body of a lambda.
--
-- Returns the lambda with its body replaced by the one produced by
-- 'extractBodyAllocations', together with the extracted allocations.  The
-- parameters and return type of the lambda are left as they were.
extractLambdaAllocations ::
  User ->
  Names ->
  Names ->
  Lambda GPUMem ->
  (Lambda GPUMem, Extraction)
extractLambdaAllocations user bound_outside bound_kernel lam =
  (lam {lambdaBody = body'}, allocs)
  where
    (body', allocs) =
      extractBodyAllocations user bound_outside bound_kernel $ lambdaBody lam
-- | Extract allocations from any body-like value.
--
-- The body representation is abstract: @get_stms@ reads its
-- statements and @set_stms@ writes a new statement sequence back.
--
-- Every statement is passed through 'extractStmAllocations'.
-- Statements for which it yields 'Nothing' are dropped from the
-- rebuilt body; the accumulated writer output is returned as the
-- 'Extraction'.
extractGenericBodyAllocations ::
  User ->
  Names ->
  Names ->
  (body -> Stms GPUMem) ->
  (Stms GPUMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations user bound_outside bound_kernel get_stms set_stms body =
  let body_stms = get_stms body
      -- Names bound by this body's own statements count as bound in
      -- the kernel for every statement processed below.
      bound_kernel' = bound_kernel <> boundByStms body_stms
      (stms, allocs) =
        runWriter $
          catMaybes
            <$> mapM
              (extractStmAllocations user bound_outside bound_kernel')
              (stmsToList body_stms)
   in (set_stms (stmsFromList stms) body, allocs)
-- | Predicates on memory spaces used when deciding whether an
-- allocation is extracted.
--
-- 'expandable' is 'False' for the space named @"local"@ and for any
-- 'ScalarSpace'; it is 'True' for every other space.
--
-- 'notScalar' is 'False' only for a 'ScalarSpace'.
expandable, notScalar :: Space -> Bool
expandable (Space "local") = False
expandable ScalarSpace {} = False
expandable _ = True
notScalar ScalarSpace {} = False
notScalar _ = True
-- | Process a single statement, extracting allocations from it.
--
-- * If the statement binds exactly one value pattern element to an
--   'Alloc' and the guard below holds, the allocation is recorded in
--   the writer output (keyed by the bound name, together with the
--   current @user@, the size and the space) and 'Nothing' is returned.
--
-- * Otherwise (including an 'Alloc' whose guard fails) the statement's
--   expression is traversed with 'mapExpM', extracting allocations
--   from nested bodies and segmented operations, and the statement is
--   returned with the rewritten expression.
extractStmAllocations ::
  User ->
  Names ->
  Names ->
  Stm GPUMem ->
  Writer Extraction (Maybe (Stm GPUMem))
extractStmAllocations user bound_outside bound_kernel (Let (Pattern [] [patElem]) _ (Op (Alloc size space)))
  -- Extract when the space is expandable and the size is expandable,
  -- or when the size is bound inside the kernel and the space is not
  -- a scalar space.
  | expandable space && expandableSize size
      || (boundInKernel size && notScalar space) = do
    tell $ M.singleton (patElemName patElem) (user, size, space)
    return Nothing
  where
    -- A size is expandable if it is a constant, or a variable bound
    -- either outside or inside the kernel.
    expandableSize :: SubExp -> Bool
    expandableSize (Var v) = v `nameIn` bound_outside || v `nameIn` bound_kernel
    expandableSize Constant {} = True

    -- Only variables can be bound in the kernel.
    boundInKernel :: SubExp -> Bool
    boundInKernel (Var v) = v `nameIn` bound_kernel
    boundInKernel Constant {} = False
extractStmAllocations user bound_outside bound_kernel stm = do
  e <- mapExpM (expMapper user) $ stmExp stm
  return $ Just $ stm {stmExp = e}
  where
    -- Mapper that rewrites nested bodies and ops; everything else is
    -- left to 'identityMapper'.
    expMapper user' =
      identityMapper
        { mapOnBody = const $ onBody user',
          mapOnOp = onOp user'
        }

    -- Extract from a nested body and forward what was found.
    onBody user' body = do
      let (body', allocs) =
            extractBodyAllocations user' bound_outside bound_kernel body
      tell allocs
      return body'

    -- For a nested 'SegOp' the user becomes the level of that SegOp,
    -- with its flat index appended to the ids collected so far.
    -- Any other op is returned unchanged.
    onOp (_, user_ids) (Inner (SegOp op)) =
      Inner . SegOp <$> mapSegOpM (opMapper user'') op
      where
        user'' =
          (segLevel op, user_ids ++ [le64 (segFlat (segSpace op))])
    onOp _ op = return op

    -- Mapper for the lambdas and kernel body of a 'SegOp'.
    opMapper user' =
      identitySegOpMapper
        { mapOnSegOpLambda = onLambda user',
          mapOnSegOpBody = onKernelBody user'
        }

    -- Extract from a kernel body and forward what was found.
    onKernelBody user' body = do
      let (body', allocs) =
            extractKernelBodyAllocations user' bound_outside bound_kernel body
      tell allocs
      return body'

    -- Extract from the body of a lambda via 'onBody'.
    onLambda user' lam = do
      body <- onBody user' $ lambdaBody lam
      return lam {lambdaBody = body}
-- | Emit one enlarged allocation for every entry of the given
-- 'Extraction', and build the corresponding 'RebaseMap'.
--
-- @getNumUsers@ maps a 'User' to the shape of the space of users and
-- the index expressions identifying the current user within it.
--
-- For every extracted memory block a @total_size@ is computed as the
-- product of the recorded per-user size and all dimensions of the
-- users' shape, and an 'Alloc' of that size in the original space is
-- bound to the original memory name.  The returned statements are
-- those allocation statements (including the size computations); the
-- returned map sends each memory name to a function constructing a new
-- base index function from an old shape.
genericExpandedInvariantAllocations ::
  (User -> (Shape, [TPrimExp Int64 VName])) -> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations getNumUsers invariant_allocs = do
  (rebases, alloc_stms) <- runBinder $ mapM expand $ M.toList invariant_allocs
  return (alloc_stms, mconcat rebases)
  where
    expand (mem, (user, per_thread_size, space)) = do
      let num_users = fst $ getNumUsers user
          allocpat = Pattern [] [PatElem mem $ MemMem space]
      total_size <-
        letExp "total_size" <=< toExp . product $
          pe64 per_thread_size : map pe64 (shapeDims num_users)
      letBind allocpat $ Op $ Alloc (Var total_size) space
      pure $ M.singleton mem $ newBase user

    -- A full slice of a dimension of size @d@: offset 0, stride 1.
    untouched d = DimSlice 0 d 1

    -- SegThread: the root index function covers the old shape followed
    -- by the user dimensions; it is then permuted so that the user
    -- dimensions come first, and finally sliced by fixing the user
    -- indices while keeping the old dimensions whole.
    newBase user@(SegThread {}, _) (old_shape, _) =
      let (users_shape, user_ids) = getNumUsers user
          num_dims = length old_shape
          perm =
            [num_dims .. num_dims + shapeRank users_shape - 1]
              ++ [0 .. num_dims - 1]
          root_ixfun = IxFun.iota (old_shape ++ map pe64 (shapeDims users_shape))
          permuted_ixfun = IxFun.permute root_ixfun perm
          offset_ixfun =
            IxFun.slice permuted_ixfun $
              map DimFix user_ids ++ map untouched old_shape
       in offset_ixfun
    -- SegGroup: the root index function already has the user
    -- dimensions first, so no permutation is applied before slicing.
    newBase user@(SegGroup {}, _) (old_shape, _) =
      let (users_shape, user_ids) = getNumUsers user
          root_ixfun = IxFun.iota $ map pe64 (shapeDims users_shape) ++ old_shape
          offset_ixfun =
            IxFun.slice root_ixfun $
              map DimFix user_ids ++ map untouched old_shape
       in offset_ixfun
-- | Specialisation of 'genericExpandedInvariantAllocations' where the
-- shape of the user space is taken from the given number of threads,
-- number of groups and group size.
--
-- The mapping from a 'User' to its shape and indices is:
--
-- * 'SegThread' with one id: shape @[num_threads]@.
--
-- * 'SegThread' with two ids: shape @[num_groups, group_size]@.
--
-- * 'SegGroup' with one id: shape @[num_groups]@.
--
-- Any other combination of level and id list calls 'error'.
expandedInvariantAllocations ::
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedInvariantAllocations num_threads (Count num_groups) (Count group_size) =
  genericExpandedInvariantAllocations getNumUsers
  where
    getNumUsers (SegThread {}, [gtid]) =
      (Shape [num_threads], [gtid])
    getNumUsers (SegThread {}, [gid, ltid]) =
      (Shape [num_groups, group_size], [gid, ltid])
    getNumUsers (SegGroup {}, [gid]) =
      (Shape [num_groups], [gid])
    getNumUsers user =
      error $ "getNumUsers: unhandled " ++ show user
expandedVariantAllocations ::
SubExp ->
SegSpace ->
Stms GPUMem ->
Extraction ->
ExpandM (Stms GPUMem, RebaseMap)
expandedVariantAllocations :: SubExp
-> SegSpace
-> Stms GPUMem
-> Extraction
-> ExpandM (Stms GPUMem, RebaseMap)
expandedVariantAllocations SubExp
_ SegSpace
_ Stms GPUMem
_ Extraction
variant_allocs
| Extraction -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Extraction
variant_allocs = (Stms GPUMem, RebaseMap) -> ExpandM (Stms GPUMem, RebaseMap)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms GPUMem
forall a. Monoid a => a
mempty, RebaseMap
forall a. Monoid a => a
mempty)
expandedVariantAllocations SubExp
num_threads SegSpace
kspace Stms GPUMem
kstms Extraction
variant_allocs = do
let sizes_to_blocks :: [(SubExp, [(VName, Space)])]
sizes_to_blocks = Extraction -> [(SubExp, [(VName, Space)])]
removeCommonSizes Extraction
variant_allocs
variant_sizes :: Result
variant_sizes = ((SubExp, [(VName, Space)]) -> SubExp)
-> [(SubExp, [(VName, Space)])] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (SubExp, [(VName, Space)]) -> SubExp
forall a b. (a, b) -> a
fst [(SubExp, [(VName, Space)])]
sizes_to_blocks
(Stms GPU
slice_stms, [VName]
offsets, [VName]
size_sums) <-
SubExp
-> Result
-> SegSpace
-> Stms GPUMem
-> ExpandM (Stms GPU, [VName], [VName])
sliceKernelSizes SubExp
num_threads Result
variant_sizes SegSpace
kspace Stms GPUMem
kstms
(SymbolTable (Wise GPUMem)
_, Stms GPUMem
slice_stms_tmp) <-
Stms GPUMem
-> ReaderT
(Scope GPUMem)
(StateT VNameSource (Either [Char]))
(SymbolTable (Wise GPUMem), Stms GPUMem)
forall (m :: * -> *).
(HasScope GPUMem m, MonadFreshNames m) =>
Stms GPUMem -> m (SymbolTable (Wise GPUMem), Stms GPUMem)
simplifyStms (Stms GPUMem
-> ReaderT
(Scope GPUMem)
(StateT VNameSource (Either [Char]))
(SymbolTable (Wise GPUMem), Stms GPUMem))
-> ReaderT
(Scope GPUMem) (StateT VNameSource (Either [Char])) (Stms GPUMem)
-> ReaderT
(Scope GPUMem)
(StateT VNameSource (Either [Char]))
(SymbolTable (Wise GPUMem), Stms GPUMem)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Stms GPU
-> ReaderT
(Scope GPUMem) (StateT VNameSource (Either [Char])) (Stms GPUMem)
forall (m :: * -> *).
(MonadFreshNames m, HasScope GPUMem m) =>
Stms GPU -> m (Stms GPUMem)
explicitAllocationsInStms Stms GPU
slice_stms
Stms GPUMem
slice_stms' <- Stms GPUMem
-> ReaderT
(Scope GPUMem) (StateT VNameSource (Either [Char])) (Stms GPUMem)
transformStms Stms GPUMem
slice_stms_tmp
let variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
variant_allocs' =
[[(VName, (SubExp, SubExp, Space))]]
-> [(VName, (SubExp, SubExp, Space))]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[(VName, (SubExp, SubExp, Space))]]
-> [(VName, (SubExp, SubExp, Space))])
-> [[(VName, (SubExp, SubExp, Space))]]
-> [(VName, (SubExp, SubExp, Space))]
forall a b. (a -> b) -> a -> b
$
([(VName, Space)]
-> (VName, VName) -> [(VName, (SubExp, SubExp, Space))])
-> [[(VName, Space)]]
-> [(VName, VName)]
-> [[(VName, (SubExp, SubExp, Space))]]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
[(VName, Space)]
-> (VName, VName) -> [(VName, (SubExp, SubExp, Space))]
forall {a} {c}.
[(a, c)] -> (VName, VName) -> [(a, (SubExp, SubExp, c))]
memInfo
(((SubExp, [(VName, Space)]) -> [(VName, Space)])
-> [(SubExp, [(VName, Space)])] -> [[(VName, Space)]]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp, [(VName, Space)]) -> [(VName, Space)]
forall a b. (a, b) -> b
snd [(SubExp, [(VName, Space)])]
sizes_to_blocks)
([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
offsets [VName]
size_sums)
memInfo :: [(a, c)] -> (VName, VName) -> [(a, (SubExp, SubExp, c))]
memInfo [(a, c)]
blocks (VName
offset, VName
total_size) =
[(a
mem, (VName -> SubExp
Var VName
offset, VName -> SubExp
Var VName
total_size, c
space)) | (a
mem, c
space) <- [(a, c)]
blocks]
([Stm GPUMem]
alloc_bnds, [RebaseMap]
rebases) <- [(Stm GPUMem, RebaseMap)] -> ([Stm GPUMem], [RebaseMap])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Stm GPUMem, RebaseMap)] -> ([Stm GPUMem], [RebaseMap]))
-> ReaderT
(Scope GPUMem)
(StateT VNameSource (Either [Char]))
[(Stm GPUMem, RebaseMap)]
-> ReaderT
(Scope GPUMem)
(StateT VNameSource (Either [Char]))
([Stm GPUMem], [RebaseMap])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((VName, (SubExp, SubExp, Space))
-> ReaderT
(Scope GPUMem)
(StateT VNameSource (Either [Char]))
(Stm GPUMem, RebaseMap))
-> [(VName, (SubExp, SubExp, Space))]
-> ReaderT
(Scope GPUMem)
(StateT VNameSource (Either [Char]))
[(Stm GPUMem, RebaseMap)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (VName, (SubExp, SubExp, Space))
-> ReaderT
(Scope GPUMem)
(StateT VNameSource (Either [Char]))
(Stm GPUMem, RebaseMap)
expand [(VName, (SubExp, SubExp, Space))]
variant_allocs'
(Stms GPUMem, RebaseMap) -> ExpandM (Stms GPUMem, RebaseMap)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms GPUMem
slice_stms' Stms GPUMem -> Stms GPUMem -> Stms GPUMem
forall a. Semigroup a => a -> a -> a
<> [Stm GPUMem] -> Stms GPUMem
forall rep. [Stm rep] -> Stms rep
stmsFromList [Stm GPUMem]
alloc_bnds, [RebaseMap] -> RebaseMap
forall a. Monoid a => [a] -> a
mconcat [RebaseMap]
rebases)
where
expand :: (VName, (SubExp, SubExp, Space))
-> ReaderT
(Scope GPUMem)
(StateT VNameSource (Either [Char]))
(Stm GPUMem, RebaseMap)
expand (VName
mem, (SubExp
offset, SubExp
total_size, Space
space)) = do
let allocpat :: PatternT (MemInfo SubExp NoUniqueness MemBind)
allocpat = [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> PatternT (MemInfo SubExp NoUniqueness MemBind)
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall dec. VName -> dec -> PatElemT dec
PatElem VName
mem (MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$ Space -> MemInfo SubExp NoUniqueness MemBind
forall d u ret. Space -> MemInfo d u ret
MemMem Space
space]
(Stm GPUMem, RebaseMap)
-> ReaderT
(Scope GPUMem)
(StateT VNameSource (Either [Char]))
(Stm GPUMem, RebaseMap)
forall (m :: * -> *) a. Monad m => a -> m a
return
( Pattern GPUMem
-> StmAux (ExpDec GPUMem) -> ExpT GPUMem -> Stm GPUMem
forall rep.
Pattern rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pattern GPUMem
PatternT (MemInfo SubExp NoUniqueness MemBind)
allocpat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (ExpT GPUMem -> Stm GPUMem) -> ExpT GPUMem -> Stm GPUMem
forall a b. (a -> b) -> a -> b
$ Op GPUMem -> ExpT GPUMem
forall rep. Op rep -> ExpT rep
Op (Op GPUMem -> ExpT GPUMem) -> Op GPUMem -> ExpT GPUMem
forall a b. (a -> b) -> a -> b
$ SubExp -> Space -> MemOp (HostOp GPUMem ())
forall inner. SubExp -> Space -> MemOp inner
Alloc SubExp
total_size Space
space,
VName
-> (([TPrimExp Int64 VName], PrimType)
-> IxFun (TPrimExp Int64 VName))
-> RebaseMap
forall k a. k -> a -> Map k a
M.singleton VName
mem ((([TPrimExp Int64 VName], PrimType)
-> IxFun (TPrimExp Int64 VName))
-> RebaseMap)
-> (([TPrimExp Int64 VName], PrimType)
-> IxFun (TPrimExp Int64 VName))
-> RebaseMap
forall a b. (a -> b) -> a -> b
$ SubExp
-> ([TPrimExp Int64 VName], PrimType)
-> IxFun (TPrimExp Int64 VName)
newBase SubExp
offset
)
num_threads' :: TPrimExp Int64 VName
num_threads' = SubExp -> TPrimExp Int64 VName
pe64 SubExp
num_threads
gtid :: TPrimExp Int64 VName
gtid = VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (VName -> TPrimExp Int64 VName) -> VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ SegSpace -> VName
segFlat SegSpace
kspace
newBase :: SubExp
-> ([TPrimExp Int64 VName], PrimType)
-> IxFun (TPrimExp Int64 VName)
newBase SubExp
size_per_thread ([TPrimExp Int64 VName]
old_shape, PrimType
pt) =
let elems_per_thread :: TPrimExp Int64 VName
elems_per_thread =
SubExp -> TPrimExp Int64 VName
pe64 SubExp
size_per_thread TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`quot` PrimType -> TPrimExp Int64 VName
forall a. Num a => PrimType -> a
primByteSize PrimType
pt
root_ixfun :: IxFun (TPrimExp Int64 VName)
root_ixfun = [TPrimExp Int64 VName] -> IxFun (TPrimExp Int64 VName)
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota [TPrimExp Int64 VName
elems_per_thread, TPrimExp Int64 VName
num_threads']
offset_ixfun :: IxFun (TPrimExp Int64 VName)
offset_ixfun =
IxFun (TPrimExp Int64 VName)
-> Slice (TPrimExp Int64 VName) -> IxFun (TPrimExp Int64 VName)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice
IxFun (TPrimExp Int64 VName)
root_ixfun
[ TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> DimIndex (TPrimExp Int64 VName)
forall d. d -> d -> d -> DimIndex d
DimSlice TPrimExp Int64 VName
0 TPrimExp Int64 VName
num_threads' TPrimExp Int64 VName
1,
TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix TPrimExp Int64 VName
gtid
]
shapechange :: [DimChange (TPrimExp Int64 VName)]
shapechange =
if [TPrimExp Int64 VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
old_shape Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1
then (TPrimExp Int64 VName -> DimChange (TPrimExp Int64 VName))
-> [TPrimExp Int64 VName] -> [DimChange (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map TPrimExp Int64 VName -> DimChange (TPrimExp Int64 VName)
forall d. d -> DimChange d
DimCoercion [TPrimExp Int64 VName]
old_shape
else (TPrimExp Int64 VName -> DimChange (TPrimExp Int64 VName))
-> [TPrimExp Int64 VName] -> [DimChange (TPrimExp Int64 VName)]
forall a b. (a -> b) -> [a] -> [b]
map TPrimExp Int64 VName -> DimChange (TPrimExp Int64 VName)
forall d. d -> DimChange d
DimNew [TPrimExp Int64 VName]
old_shape
in IxFun (TPrimExp Int64 VName)
-> [DimChange (TPrimExp Int64 VName)]
-> IxFun (TPrimExp Int64 VName)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> ShapeChange num -> IxFun num
IxFun.reshape IxFun (TPrimExp Int64 VName)
offset_ixfun [DimChange (TPrimExp Int64 VName)]
shapechange
-- | Maps the name of a memory block to a function that, given the base
-- shape of an existing index function and the element type stored in
-- the block, produces the new base index function to rebase onto.
type RebaseMap = M.Map VName (([TPrimExp Int64 VName], PrimType) -> IxFun)
-- | The monad in which memory references are offset.  It is a reader
-- over the current 'Scope' layered on top of a reader over the
-- 'RebaseMap'; failures are reported as 'String' errors via 'Either'.
--
-- All instances are obtained through newtype deriving from the
-- underlying transformer stack.
newtype OffsetM a
  = OffsetM
      ( ReaderT
          (Scope GPUMem)
          (ReaderT RebaseMap (Either String))
          a
      )
  deriving
    ( Applicative,
      Functor,
      Monad,
      HasScope GPUMem,
      LocalScope GPUMem,
      MonadError String
    )
-- | Run an 'OffsetM' action.
--
-- The first argument is the scope supplied to the outer reader, the
-- second is the 'RebaseMap' supplied to the inner reader.  The result
-- is 'Left' with an error message if the action failed.
runOffsetM :: Scope GPUMem -> RebaseMap -> OffsetM a -> Either String a
runOffsetM scope offsets (OffsetM m) =
  runReaderT (runReaderT m scope) offsets
-- | Retrieve the 'RebaseMap' from the environment.
askRebaseMap :: OffsetM RebaseMap
-- The rebase map lives in the inner reader, hence the 'lift' past the
-- outer scope reader.
askRebaseMap = OffsetM $ lift ask
-- | Look up @name@ in the 'RebaseMap'.  If an entry exists, apply the
-- stored function to @x@ and return the resulting index function;
-- otherwise return 'Nothing'.
lookupNewBase :: VName -> ([TPrimExp Int64 VName], PrimType) -> OffsetM (Maybe IxFun)
lookupNewBase name x = do
  offsets <- askRebaseMap
  return $ ($ x) <$> M.lookup name offsets
-- | Rewrite every statement of a kernel body with 'offsetMemoryInStm'.
--
-- Statements are processed in order.  The scope returned for one
-- statement is the accumulator handed to 'localScope' when processing
-- the next one.  Only 'kernelBodyStms' is replaced; all other fields of
-- the kernel body are kept as they were.
offsetMemoryInKernelBody :: KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
offsetMemoryInKernelBody kbody = do
  scope <- askScope
  stms' <-
    stmsFromList . snd
      <$> mapAccumLM
        (\scope' -> localScope scope' . offsetMemoryInStm)
        scope
        (stmsToList $ kernelBodyStms kbody)
  return kbody {kernelBodyStms = stms'}
-- | Rewrite every statement of a body with 'offsetMemoryInStm'.
--
-- Statements are processed in order.  The scope returned for one
-- statement is the accumulator handed to 'localScope' when processing
-- the next one.  The body decoration and the result are left untouched.
offsetMemoryInBody :: Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody (Body dec stms res) = do
  scope <- askScope
  stms' <-
    stmsFromList . snd
      <$> mapAccumLM
        (\scope' -> localScope scope' . offsetMemoryInStm)
        scope
        (stmsToList stms)
  return $ Body dec stms' res
-- | Rewrite a single statement.
--
-- The pattern is rewritten with 'offsetMemoryInPattern' and the
-- expression with 'offsetMemoryInExp' (the latter with the rewritten
-- pattern in scope).  Afterwards the memory returns of the rewritten
-- expression are computed with 'expReturns' and used, where possible,
-- to replace the memory annotation of the value pattern elements.
--
-- Returns the rewritten statement together with a scope made of the
-- names bound by that statement prepended to the scope in effect.
offsetMemoryInStm :: Stm GPUMem -> OffsetM (Scope GPUMem, Stm GPUMem)
offsetMemoryInStm (Let pat dec e) = do
  pat' <- offsetMemoryInPattern pat
  e' <- localScope (scopeOfPattern pat') $ offsetMemoryInExp e
  scope <- askScope
  -- Prefer the memory information reported by the rewritten expression
  -- itself; 'pick' falls back to the already rewritten pattern element
  -- whenever that information is not usable.
  rts <- runReaderT (expReturns e') scope
  let pat'' =
        Pattern
          (patternContextElements pat')
          (zipWith pick (patternValueElements pat') rts)
      stm = Let pat'' dec e'
  let scope' = scopeOf stm <> scope
  return (scope', stm)
  where
    -- Use the block and index function reported for the expression when
    -- both sides are arrays, the expression returns in a known block and
    -- its index function has no existential parts.  Otherwise keep the
    -- pattern element unchanged.
    pick ::
      PatElemT (MemInfo SubExp NoUniqueness MemBind) ->
      ExpReturns ->
      PatElemT (MemInfo SubExp NoUniqueness MemBind)
    pick
      (PatElem name (MemArray pt s u _ret))
      (MemArray _ _ _ (Just (ReturnsInBlock m extixfun)))
        | Just ixfun <- instantiateIxFun extixfun =
          PatElem name (MemArray pt s u (ArrayIn m ixfun))
    pick p _ = p

    -- Convert an existential index function to an ordinary one.  Yields
    -- 'Nothing' as soon as any leaf is an 'Ext'.
    instantiateIxFun :: ExtIxFun -> Maybe IxFun
    instantiateIxFun = traverse (traverse inst)
      where
        inst :: Ext a -> Maybe a
        inst Ext {} = Nothing
        inst (Free x) = return x
-- | Rewrite the memory annotations of a pattern.
--
-- Every context element is first checked with @inspectCtx@: if it is a
-- memory block in a space for which 'expandable' holds, the whole
-- action fails with 'throwError'.  The context elements themselves are
-- returned unchanged; each value element has its annotation rewritten
-- with 'offsetMemoryInMemBound'.
offsetMemoryInPattern :: Pattern GPUMem -> OffsetM (Pattern GPUMem)
offsetMemoryInPattern (Pattern ctx vals) = do
  mapM_ inspectCtx ctx
  Pattern ctx <$> mapM inspectVal vals
  where
    -- Replace the annotation of a value element by its offset version.
    inspectVal :: PatElemT (MemBound u) -> OffsetM (PatElemT (MemBound u))
    inspectVal patElem = do
      new_dec <- offsetMemoryInMemBound $ patElemDec patElem
      return patElem {patElemDec = new_dec}

    -- Reject existential memory blocks in an expandable space.
    inspectCtx :: PatElemT (MemInfo SubExp NoUniqueness MemBind) -> OffsetM ()
    inspectCtx patElem
      | Mem space <- patElemType patElem,
        expandable space =
        throwError $
          unwords
            [ "Cannot deal with existential memory block",
              pretty (patElemName patElem),
              "when expanding inside kernels."
            ]
      | otherwise = return ()
-- | Rewrite the memory annotation of a parameter with
-- 'offsetMemoryInMemBound', leaving the rest of the parameter as is.
offsetMemoryInParam :: Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam fparam = do
  fparam' <- offsetMemoryInMemBound $ paramDec fparam
  return fparam {paramDec = fparam'}
-- | Rewrite a memory annotation.
--
-- For an array stored in block @mem@, the 'RebaseMap' is consulted via
-- 'lookupNewBase' with the base of the array's index function and its
-- element type.  If an entry exists, the index function is rebased onto
-- the new base with 'IxFun.rebase' (the block name is kept).  If there
-- is no entry, or the annotation is not an array in a block, it is
-- returned unchanged.
offsetMemoryInMemBound :: MemBound u -> OffsetM (MemBound u)
offsetMemoryInMemBound summary@(MemArray pt shape u (ArrayIn mem ixfun)) = do
  new_base <- lookupNewBase mem (IxFun.base ixfun, pt)
  return $
    fromMaybe summary $ do
      new_base' <- new_base
      return $ MemArray pt shape u $ ArrayIn mem $ IxFun.rebase new_base' ixfun
offsetMemoryInMemBound summary = return summary
-- | Rewrite the memory annotation of a body return.
--
-- Only arrays returned in a known block whose index function is
-- accepted by 'isStaticIxFun' are considered.  For those, the
-- 'RebaseMap' is consulted via 'lookupNewBase'; if an entry exists, the
-- (existential) index function is rebased onto the new base, which is
-- first lifted into the existential representation by wrapping every
-- leaf in 'Free'.  In all other cases the annotation is returned
-- unchanged.
offsetMemoryInBodyReturns :: BodyReturns -> OffsetM BodyReturns
offsetMemoryInBodyReturns br@(MemArray pt shape u (ReturnsInBlock mem ixfun))
  | Just ixfun' <- isStaticIxFun ixfun = do
    new_base <- lookupNewBase mem (IxFun.base ixfun', pt)
    return $
      fromMaybe br $ do
        new_base' <- new_base
        return $
          MemArray pt shape u $
            ReturnsInBlock mem $
              IxFun.rebase (fmap (fmap Free) new_base') ixfun
offsetMemoryInBodyReturns br = return br
-- | Rewrite the body of a lambda with 'offsetMemoryInBody', with the
-- lambda itself brought into scope via 'inScopeOf'.  Everything but
-- 'lambdaBody' is left unchanged.
offsetMemoryInLambda :: Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda lam = inScopeOf lam $ do
  body <- offsetMemoryInBody $ lambdaBody lam
  return $ lam {lambdaBody = body}
-- | Apply the memory-offsetting rewrite to a single expression.
--
-- Loops get their own case: the context and value merge parameters are
-- first rewritten with 'offsetMemoryInParam', and the loop body is then
-- traversed with the rewritten parameters and the names bound by the
-- loop form in scope.  The initial values of the merge parameters are
-- left untouched.
--
-- All other expressions are traversed generically with 'mapExpM'; only
-- bodies, branch types and ops are rewritten, everything else goes
-- through 'identityMapper' unchanged.
offsetMemoryInExp :: Exp GPUMem -> OffsetM (Exp GPUMem)
offsetMemoryInExp (DoLoop ctxMerge valMerge form loopBody) = do
  -- Context parameters are processed before value parameters.
  ctxParams <- mapM (offsetMemoryInParam . fst) ctxMerge
  valParams <- mapM (offsetMemoryInParam . fst) valMerge
  let loopScope =
        scopeOfFParams ctxParams
          <> scopeOfFParams valParams
          <> scopeOf form
  loopBody' <- localScope loopScope $ offsetMemoryInBody loopBody
  return $
    DoLoop
      (zip ctxParams (map snd ctxMerge))
      (zip valParams (map snd valMerge))
      form
      loopBody'
offsetMemoryInExp e = mapExpM recurse e
  where
    recurse =
      identityMapper
        { mapOnBody = \bscope body ->
            localScope bscope (offsetMemoryInBody body),
          mapOnBranchType = offsetMemoryInBodyReturns,
          mapOnOp = onOp
        }

    -- Only segmented operations are descended into; any other op is
    -- returned unchanged.
    onOp (Inner (SegOp segop)) =
      fmap (Inner . SegOp) $
        localScope (scopeOfSegSpace $ segSpace segop) $
          mapSegOpM segOpMapper segop
    onOp other = return other

    -- Inside a segmented operation, kernel bodies and lambdas are
    -- rewritten; the remaining components use the identity mapper.
    segOpMapper =
      identitySegOpMapper
        { mapOnSegOpBody = offsetMemoryInKernelBody,
          mapOnSegOpLambda = offsetMemoryInLambda
        }
-- | Translate statements from the 'GPUMem' representation to the
-- memory-free GPU representation.
--
-- Allocation statements at the outermost level are dropped.  The
-- conversion fails with a 'Left' error message when it meets
--
-- * an allocation statement inside a nested body, kernel body or lambda,
--
-- * a pattern element, function/loop parameter, return type or branch
--   type that is a memory block (see 'unMem'), or
--
-- * an @Alloc@ or @OtherOp@ reached while mapping over an expression.
--
-- Memory-typed parameters of lambdas inside segmented operations are
-- silently removed rather than reported.
unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU.GPU)
unAllocGPUStms = unAllocStms False
  where
    -- Statements of an ordinary body count as nested.
    unAllocBody (Body dec stms res) = do
      stms' <- unAllocStms True stms
      pure $ Body dec stms' res

    -- Statements of a kernel body count as nested as well.
    unAllocKernelBody (KernelBody dec stms res) = do
      stms' <- unAllocStms True stms
      pure $ KernelBody dec stms' res

    -- Convert every statement in order, keeping only those that
    -- survive the conversion.
    unAllocStms nested stms =
      stmsFromList . catMaybes
        <$> mapM (unAllocStm nested) (stmsToList stms)

    -- An allocation is dropped at the outermost level and is an error
    -- anywhere else.  All other statements are converted piecewise;
    -- the pattern is checked before the expression.
    unAllocStm nested stm@(Let _ _ (Op Alloc {})) =
      if nested
        then Left $ "Cannot handle nested allocation: " ++ pretty stm
        else Right Nothing
    unAllocStm _ (Let pat aux e) = do
      pat' <- unAllocPattern pat
      e' <- mapExpM unAlloc' e
      pure $ Just $ Let pat' aux e'

    unAllocLambda (Lambda params body ret) = do
      body' <- unAllocBody body
      pure $ Lambda (unParams params) body' ret

    -- Parameters whose type is a memory block are discarded.
    unParams = mapMaybe (traverse unMem)

    -- Context elements are converted before value elements; a memory
    -- block in either list fails the whole pattern.
    unAllocPattern pat@(Pattern ctx val) = do
      ctx' <- convertElems ctx
      val' <- convertElems val
      pure $ Pattern ctx' val'
      where
        convertElems elems =
          case mapM (rephrasePatElem unMem) elems of
            Just elems' -> Right elems'
            Nothing -> Left $ "Cannot handle memory in pattern " ++ pretty pat

    unAllocOp Alloc {} = Left "unAllocOp: unhandled Alloc"
    unAllocOp (Inner hostop) =
      case hostop of
        OtherOp {} -> Left "unAllocOp: unhandled OtherOp"
        SizeOp op -> Right $ SizeOp op
        SegOp op -> SegOp <$> mapSegOpM segMapper op
      where
        segMapper =
          identitySegOpMapper
            { mapOnSegOpLambda = unAllocLambda,
              mapOnSegOpBody = unAllocKernelBody
            }

    -- Used for both function and lambda parameters reached through the
    -- expression mapper; here a memory-typed parameter is an error.
    unParam p =
      case traverse unMem p of
        Just p' -> Right p'
        Nothing ->
          Left $ "Cannot handle memory-typed parameter '" ++ pretty p ++ "'"

    -- Used for both return types and branch types.
    unT t =
      case unMem t of
        Just t' -> Right t'
        Nothing -> Left $ "Cannot handle memory type '" ++ pretty t ++ "'"

    unAlloc' =
      Mapper
        { mapOnSubExp = Right,
          mapOnVName = Right,
          mapOnBody = \_ body -> unAllocBody body,
          mapOnRetType = unT,
          mapOnBranchType = unT,
          mapOnFParam = unParam,
          mapOnLParam = unParam,
          mapOnOp = unAllocOp
        }
-- | Strip the memory annotation from a 'MemInfo', giving the
-- corresponding plain type.  Primitives, arrays and accumulators keep
-- their element type, shape and uniqueness; the memory binding of an
-- array is discarded.  A memory block has no plain counterpart and
-- yields 'Nothing'.
unMem :: MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
unMem info = case info of
  MemPrim pt -> Just (Prim pt)
  MemArray pt shape u _ -> Just (Array pt shape u)
  MemAcc acc ispace ts u -> Just (Acc acc ispace ts u)
  MemMem {} -> Nothing
-- | Convert a 'GPUMem' scope to a GPU scope by stripping memory
-- annotations with 'unMem'.  Names for which 'unMem' returns 'Nothing'
-- (memory blocks) are removed from the scope; index names are kept
-- as they are.
unAllocScope :: Scope GPUMem -> Scope GPU.GPU
unAllocScope = M.mapMaybe stripInfo
  where
    stripInfo :: NameInfo GPUMem -> Maybe (NameInfo GPU.GPU)
    stripInfo info = case info of
      LetName dec -> fmap LetName (unMem dec)
      FParamName dec -> fmap FParamName (unMem dec)
      LParamName dec -> fmap LParamName (unMem dec)
      IndexName it -> Just (IndexName it)
-- | Group the entries of an 'Extraction' by their size.
--
-- The result has one element per distinct size, in ascending order of
-- the size key, paired with the @(name, space)@ of every entry that has
-- that size.  Entries are visited in ascending key order and each new
-- one is put at the front of its group.
removeCommonSizes ::
  Extraction ->
  [(SubExp, [(VName, Space)])]
removeCommonSizes extraction =
  M.toList $ M.foldlWithKey' addEntry M.empty extraction
  where
    addEntry groups mem (_, size, space) =
      M.insertWith (++) size [(mem, space)] groups
-- | Build host-level statements that compute, for every expression in
-- @sizes@, the maximum value it takes over the whole index space
-- described by @space@, and that maximum multiplied by @num_threads@.
--
-- The kernel statements @kstms@ are first converted with
-- 'unAllocGPUStms'; a 'Left' from that conversion is rethrown with
-- 'throwError'.  The converted statements become the body of a lambda
-- that is reduced with a commutative signed-maximum operator (neutral
-- element 0) over an iota spanning the product of the dimensions of
-- @space@.
--
-- Returns a triple of:
--
--   * the statements performing the reduction and the multiplications,
--
--   * the names bound to the per-size maxima (one per element of
--     @sizes@),
--
--   * the names bound to each maximum multiplied by @num_threads@.
sliceKernelSizes ::
  SubExp ->
  [SubExp] ->
  SegSpace ->
  Stms GPUMem ->
  ExpandM (Stms GPU.GPU, [VName], [VName])
sliceKernelSizes num_threads sizes space kstms = do
  kstms' <- case unAllocGPUStms kstms of
    Left err -> throwError err
    Right converted -> pure converted

  kernels_scope <- asks unAllocScope

  -- Reduction operator: element-wise signed maximum of two tuples of
  -- num_sizes 64-bit integers.
  (max_lam, _) <- (`runBinderT` kernels_scope) $ do
    xs <- int64Params "x"
    ys <- int64Params "y"
    (zs, stms) <-
      localScope (scopeOfLParams (xs ++ ys)) . collectStms $
        zipWithM
          ( \x y ->
              letSubExp "z" . BasicOp $
                BinOp (SMax Int64) (Var (paramName x)) (Var (paramName y))
          )
          xs
          ys
    pure $ Lambda (xs ++ ys) (mkBody stms zs) i64s

  flat_gtid_lparam <-
    (`Param` Prim (IntType Int64)) <$> newVName "flat_gtid"

  -- Map function: given a flat index, produce the values of 'sizes'
  -- for that point of the index space.
  (size_lam', _) <- (`runBinderT` kernels_scope) $ do
    -- These parameters only enter the local scope below; the lambda
    -- itself takes just the flat index.
    params <- int64Params "x"
    let body_scope =
          scopeOfLParams params <> scopeOfLParams [flat_gtid_lparam]
    (zs, stms) <- localScope body_scope . collectStms $ do
      -- The lambda receives a single flat index, but the kernel
      -- statements refer to the per-dimension indexes of 'space', so
      -- those are recomputed here by unflattening.
      let (kspace_gtids, kspace_dims) = unzip (unSegSpace space)
          flat_gtid = pe64 (Var (paramName flat_gtid_lparam))
          new_inds = unflattenIndex (map pe64 kspace_dims) flat_gtid
      ind_exps <- mapM toExp new_inds
      forM_ (zip kspace_gtids ind_exps) $ \(gtid, ind_exp) ->
        letBindNames [gtid] ind_exp

      forM_ kstms' addStm
      pure sizes

    let raw_size_lam = Lambda [flat_gtid_lparam] (Body () stms zs) i64s
    localScope (scopeOfSegSpace space) $ GPU.simplifyLambda raw_size_lam

  -- The reduction itself, followed by scaling each maximum by the
  -- number of threads.
  ((maxes_per_thread, size_sums), slice_stms) <-
    (`runBinderT` kernels_scope) $ do
      max_idents <- replicateM num_sizes (newIdent "max_per_thread" i64)
      let pat = basicPattern [] max_idents

      w_exp <- foldBinOp mul64 (const64 1) (segSpaceDims space)
      w <- letSubExp "size_slice_w" w_exp

      thread_space_iota <-
        letExp "thread_space_iota" . BasicOp $
          Iota w (const64 0) (const64 1) Int64

      let neutral = replicate num_sizes (const64 0)
          red_op = SegBinOp Commutative max_lam neutral mempty
      lvl <- segThread "segred"

      red_stms <- nonSegRed lvl pat w [red_op] size_lam' [thread_space_iota]
      renamed_stms <- mapM renameStm red_stms
      addStms renamed_stms

      let max_names = patternNames pat
      sums <- forM max_names $ \threads_max ->
        letExp "size_sum" . BasicOp $
          BinOp mul64 (Var threads_max) num_threads

      pure (max_names, sums)

  pure (slice_stms, maxes_per_thread, size_sums)
  where
    -- Number of size expressions being tracked.
    num_sizes = length sizes

    i64 :: Type
    i64 = Prim int64

    -- Result types shared by both lambdas: one i64 per size.
    i64s = replicate num_sizes i64

    -- One fresh i64 parameter per size, all with the same base name.
    int64Params :: MonadFreshNames m => String -> m [Param Type]
    int64Params name = replicateM num_sizes (newParam name i64)

    mul64 = Mul Int64 OverflowUndef

    const64 = intConst Int64