{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- | Expand allocations inside of maps when possible.
module Futhark.Pass.ExpandAllocations (expandAllocations) where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (find, foldl')
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.Rephrase
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Error
import Futhark.IR
import qualified Futhark.IR.GPU.Simplify as GPU
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Rep (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToGPU (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Util (mapAccumLM)
import Futhark.Util.IntegralExp
import Prelude hiding (quot)

-- | The memory expansion pass definition.
--
-- Top-level constants are transformed first, starting from an empty
-- scope.  Every function is then transformed in the scope of the
-- transformed constants.  A failure ('Left') reported while
-- transforming the constants is turned into a compiler-limitation
-- error by 'limitationOnLeft'.
expandAllocations :: Pass GPUMem GPUMem
expandAllocations =
  Pass "expand allocations" "Expand allocations" $
    \(Prog consts funs) -> do
      consts' <-
        modifyNameSource $
          limitationOnLeft . runStateT (runReaderT (transformStms consts) mempty)
      Prog consts' <$> mapM (transformFunDef $ scopeOf consts') funs

-- We cannot use intraproceduralTransformation here, because it might
-- create duplicate size keys: these are not fixed by the renamer, and
-- size keys must currently be globally unique.

-- | The monad in which the expansion runs.  It provides:
--
-- * a read-only 'Scope' of bindings;
-- * a fresh-name source threaded as state;
-- * failure with a 'String' message.
--
-- Failures can be intercepted with 'catchError'.  Otherwise they are
-- turned into compiler-limitation errors by 'limitationOnLeft' at the
-- top level.
type ExpandM = ReaderT (Scope GPUMem) (StateT VNameSource (Either String))

-- | Extract the result of an expansion.  A 'Left' message is reported
-- as a compiler-limitation error via 'compilerLimitationS'.
limitationOnLeft :: Either String a -> a
limitationOnLeft = either compilerLimitationS id

-- | Expand allocations in the body of a single function.
--
-- The body is transformed in the given scope, extended with whatever
-- 'inScopeOf' the function definition brings into scope.  Expansion
-- failures become compiler-limitation errors.  The rewritten function
-- is then passed through 'copyPropagateInFun', using a symbol table
-- built from the given scope.
transformFunDef ::
  Scope GPUMem ->
  FunDef GPUMem ->
  PassM (FunDef GPUMem)
transformFunDef scope fundec = do
  body' <- modifyNameSource $ limitationOnLeft . runStateT (runReaderT m mempty)
  copyPropagateInFun
    simpleGPUMem
    (ST.fromScope (addScopeWisdom scope))
    fundec {funDefBody = body'}
  where
    m :: ExpandM (Body GPUMem)
    m =
      localScope scope $
        inScopeOf fundec $
          transformBody $ funDefBody fundec

-- | Transform every statement of a body.  The body result is kept
-- unchanged.
transformBody :: Body GPUMem -> ExpandM (Body GPUMem)
transformBody (Body () stms res) = Body () <$> transformStms stms <*> pure res

-- | Transform the body of a lambda, with the lambda's parameters in
-- scope.  The parameters and return types are kept unchanged.
transformLambda :: Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda (Lambda params body ret) =
  Lambda params
    <$> localScope (scopeOfLParams params) (transformBody body)
    <*> pure ret

-- | Transform a sequence of statements.  The statements each input
-- statement expands to are concatenated in order.  The bindings of all
-- input statements are added to the scope before any of them is
-- transformed.
transformStms :: Stms GPUMem -> ExpandM (Stms GPUMem)
transformStms stms =
  inScopeOf stms $ mconcat <$> mapM transformStm (stmsToList stms)

-- | Transform a single statement.  The result is the statement itself,
-- preceded by any extra statements its transformation produced.
transformStm :: Stm GPUMem -> ExpandM (Stms GPUMem)
-- It is possible that we are unable to expand allocations in some
-- code versions.  If so, we can remove the offending branch.  Only if
-- both versions fail do we propagate the error.
transformStm (Let pat aux (If cond tbranch fbranch (IfDec ts IfEquiv))) = do
  tbranch' <- attempt tbranch
  fbranch' <- attempt fbranch
  case (tbranch', fbranch') of
    (Left _, Right fbranch'') ->
      pure $ useBranch fbranch''
    (Right tbranch'', Left _) ->
      pure $ useBranch tbranch''
    (Right tbranch'', Right fbranch'') ->
      pure $ oneStm $ Let pat aux $ If cond tbranch'' fbranch'' (IfDec ts IfEquiv)
    (Left e, _) ->
      -- Both branches failed; report the error from the true branch.
      throwError e
  where
    -- Transform one branch.  A failure is captured as 'Left' instead of
    -- aborting the whole statement.
    attempt :: Body GPUMem -> ExpandM (Either String (Body GPUMem))
    attempt b = (Right <$> transformBody b) `catchError` (pure . Left)

    -- Bind one pattern element to one branch result, keeping the
    -- result's certificates.
    bindRes pe (SubExpRes cs se) =
      certify cs $ Let (Pat [pe]) (defAux ()) $ BasicOp $ SubExp se

    -- Replace the whole 'If' with the statements of the surviving
    -- branch, followed by bindings of the pattern to its results.
    useBranch :: Body GPUMem -> Stms GPUMem
    useBranch b =
      bodyStms b
        <> stmsFromList (zipWith bindRes (patElems pat) (bodyResult b))
transformStm (Let pat aux e) = do
  (stms, e') <- transformExp =<< mapExpM transform e
  pure $ stms <> oneStm (Let pat aux e')
  where
    -- Recurse into every nested body, with that body's scope added.
    transform :: Mapper GPUMem GPUMem ExpandM
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }

-- | Transform an expression.  The result pairs the rewritten
-- expression with statements that the caller places before it.
--
-- For 'SegMap', 'SegRed', 'SegScan' and 'SegHist':
--
-- * the kernel body and operator lambdas are passed to
--   'transformScanRed';
-- * its returned statements ('alloc_stms') become the preceding
--   statements;
-- * its returned lambdas replace the original operator lambdas,
--   position by position.
transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp (Op (Inner (SegOp (SegMap lvl space ts kbody)))) = do
  (alloc_stms, (_, kbody')) <- transformScanRed lvl space [] kbody
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegMap lvl space ts kbody'
    )
transformExp (Op (Inner (SegOp (SegRed lvl space reds ts kbody)))) = do
  (alloc_stms, (lams, kbody')) <-
    transformScanRed lvl space (map segBinOpLambda reds) kbody
  let reds' = zipWith (\red lam -> red {segBinOpLambda = lam}) reds lams
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegRed lvl space reds' ts kbody'
    )
transformExp (Op (Inner (SegOp (SegScan lvl space scans ts kbody)))) = do
  (alloc_stms, (lams, kbody')) <-
    transformScanRed lvl space (map segBinOpLambda scans) kbody
  let scans' = zipWith (\scan lam -> scan {segBinOpLambda = lam}) scans lams
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegScan lvl space scans' ts kbody'
    )
transformExp (Op (Inner (SegOp (SegHist lvl space ops ts kbody)))) = do
  (alloc_stms, (lams', kbody')) <-
    transformScanRed lvl space (map histOp ops) kbody
  let ops' = zipWith (\op lam -> op {histOp = lam}) ops lams'
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegHist lvl space ops' ts kbody'
    )
transformExp (WithAcc [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))]
inputs Lambda GPUMem
lam) = do
  Lambda GPUMem
lam' <- Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda Lambda GPUMem
lam
  ([Stms GPUMem]
input_alloc_stms, [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))]
inputs') <- [(Stms GPUMem, (Shape, [VName], Maybe (Lambda GPUMem, [SubExp])))]
-> ([Stms GPUMem],
    [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Stms GPUMem, (Shape, [VName], Maybe (Lambda GPUMem, [SubExp])))]
 -> ([Stms GPUMem],
     [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))]))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     [(Stms GPUMem, (Shape, [VName], Maybe (Lambda GPUMem, [SubExp])))]
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     ([Stms GPUMem],
      [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Stms GPUMem, (Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))))
-> [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))]
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     [(Stms GPUMem, (Shape, [VName], Maybe (Lambda GPUMem, [SubExp])))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (Shape, [VName], Maybe (Lambda GPUMem, [SubExp])))
forall b b.
(Shape, b, Maybe (Lambda GPUMem, b))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
onInput [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))]
inputs
  (Stms GPUMem, ExpT GPUMem) -> ExpandM (Stms GPUMem, ExpT GPUMem)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( [Stms GPUMem] -> Stms GPUMem
forall a. Monoid a => [a] -> a
mconcat [Stms GPUMem]
input_alloc_stms,
      [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))]
-> Lambda GPUMem -> ExpT GPUMem
forall rep.
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
-> Lambda rep -> ExpT rep
WithAcc [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))]
inputs' Lambda GPUMem
lam'
    )
  where
    onInput :: (Shape, b, Maybe (Lambda GPUMem, b))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
onInput (Shape
shape, b
arrs, Maybe (Lambda GPUMem, b)
Nothing) =
      (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPUMem
forall a. Monoid a => a
mempty, (Shape
shape, b
arrs, Maybe (Lambda GPUMem, b)
forall a. Maybe a
Nothing))
    onInput (Shape
shape, b
arrs, Just (Lambda GPUMem
op_lam, b
nes)) = do
      Names
bound_outside <- (Scope GPUMem -> Names)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) Names
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ((Scope GPUMem -> Names)
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) Names)
-> (Scope GPUMem -> Names)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) Names
forall a b. (a -> b) -> a -> b
$ [VName] -> Names
namesFromList ([VName] -> Names)
-> (Scope GPUMem -> [VName]) -> Scope GPUMem -> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope GPUMem -> [VName]
forall k a. Map k a -> [k]
M.keys
      let -- XXX: fake a SegLevel, which we don't have here.  We will not
          -- use it for anything, as we will not allow irregular
          -- allocations inside the update function.
          lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count (SubExp -> Count NumGroups SubExp)
-> SubExp -> Count NumGroups SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Count (SubExp -> Count GroupSize SubExp)
-> SubExp -> Count GroupSize SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) SegVirt
SegNoVirt
          (Lambda GPUMem
op_lam', Extraction
lam_allocs) =
            User
-> Names -> Names -> Lambda GPUMem -> (Lambda GPUMem, Extraction)
extractLambdaAllocations (SegLevel
lvl, [TPrimExp Int64 VName
0]) Names
bound_outside Names
forall a. Monoid a => a
mempty Lambda GPUMem
op_lam
          variantAlloc :: (User, SubExp, Space) -> Bool
variantAlloc (User
_, Var VName
v, Space
_) = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
v VName -> Names -> Bool
`nameIn` Names
bound_outside
          variantAlloc (User, SubExp, Space)
_ = Bool
False
          (Extraction
variant_allocs, Extraction
invariant_allocs) = ((User, SubExp, Space) -> Bool)
-> Extraction -> (Extraction, Extraction)
forall a k. (a -> Bool) -> Map k a -> (Map k a, Map k a)
M.partition (User, SubExp, Space) -> Bool
variantAlloc Extraction
lam_allocs

      case Extraction -> [(User, SubExp, Space)]
forall k a. Map k a -> [a]
M.elems Extraction
variant_allocs of
        (User
_, SubExp
v, Space
_) : [(User, SubExp, Space)]
_ ->
          String
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) ()
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (String
 -> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) ())
-> String
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) ()
forall a b. (a -> b) -> a -> b
$
            String
"Cannot handle un-sliceable allocation size: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ SubExp -> String
forall a. Pretty a => a -> String
pretty SubExp
v
              String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"\nLikely cause: irregular nested operations inside accumulator update operator."
        [] ->
          ()
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

      let num_is :: Int
num_is = Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank Shape
shape
          is :: [VName]
is = (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int -> [Param LParamMem] -> [Param LParamMem]
forall a. Int -> [a] -> [a]
take Int
num_is ([Param LParamMem] -> [Param LParamMem])
-> [Param LParamMem] -> [Param LParamMem]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda GPUMem
op_lam
      (Stms GPUMem
alloc_stms, RebaseMap
alloc_offsets) <-
        (User -> (Shape, [TPrimExp Int64 VName]))
-> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations ((Shape, [TPrimExp Int64 VName])
-> User -> (Shape, [TPrimExp Int64 VName])
forall a b. a -> b -> a
const (Shape
shape, (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 [VName]
is)) Extraction
invariant_allocs

      Scope GPUMem
scope <- ReaderT
  (Scope GPUMem) (StateT VNameSource (Either String)) (Scope GPUMem)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
      let scope' :: Scope GPUMem
scope' = Lambda GPUMem -> Scope GPUMem
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Lambda GPUMem
op_lam Scope GPUMem -> Scope GPUMem -> Scope GPUMem
forall a. Semigroup a => a -> a -> a
<> Scope GPUMem
scope
      (String
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b))))
-> ((Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
    -> ReaderT
         (Scope GPUMem)
         (StateT VNameSource (Either String))
         (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b))))
-> Either
     String (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either String
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Either String (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b))))
-> Either
     String (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
forall a b. (a -> b) -> a -> b
$
        Scope GPUMem
-> RebaseMap
-> OffsetM (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
-> Either
     String (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
forall a. Scope GPUMem -> RebaseMap -> OffsetM a -> Either String a
runOffsetM Scope GPUMem
scope' RebaseMap
alloc_offsets (OffsetM (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
 -> Either
      String (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b))))
-> OffsetM (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
-> Either
     String (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
forall a b. (a -> b) -> a -> b
$ do
          Lambda GPUMem
op_lam'' <- Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda Lambda GPUMem
op_lam'
          (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
-> OffsetM (Stms GPUMem, (Shape, b, Maybe (Lambda GPUMem, b)))
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms GPUMem
alloc_stms, (Shape
shape, b
arrs, (Lambda GPUMem, b) -> Maybe (Lambda GPUMem, b)
forall a. a -> Maybe a
Just (Lambda GPUMem
op_lam'', b
nes)))
transformExp ExpT GPUMem
e =
  (Stms GPUMem, ExpT GPUMem) -> ExpandM (Stms GPUMem, ExpT GPUMem)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms GPUMem
forall a. Monoid a => a
mempty, ExpT GPUMem
e)

-- | Extract the allocations from the body and operator lambdas of a
-- scan/reduction-like 'SegOp', and rewrite them to use expanded memory.
--
-- Returns the statements that allocate the extracted memory, together
-- with the rewritten operator lambdas and the rewritten kernel body.
--
-- Fails (via 'throwError') when:
--
-- * a variant allocation has a size that is bound neither outside nor
--   inside the kernel, so it cannot be sliced; or
--
-- * the kernel runs at 'SegGroup' level and contains variant
--   allocations.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda GPUMem] ->
  KernelBody GPUMem ->
  ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed lvl space ops kbody = do
  bound_outside <- asks $ namesFromList . M.keys
  let user = (lvl, [le64 $ segFlat space])
      (kbody', kbody_allocs) =
        extractKernelBodyAllocations user bound_outside bound_in_kernel kbody
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations user bound_outside mempty) ops
      -- An allocation is variant when its size is a variable that is
      -- not bound outside the kernel.
      variantAlloc (_, Var v, _) = not $ v `nameIn` bound_outside
      variantAlloc _ = False
      (variant_allocs, invariant_allocs) =
        M.partition variantAlloc $ kbody_allocs <> mconcat ops_allocs
      -- A variant allocation whose size is not bound inside the kernel
      -- either cannot be handled.
      badVariant (_, Var v, _) = not $ v `nameIn` bound_in_kernel
      badVariant _ = False

  case find badVariant $ M.elems variant_allocs of
    Just (_, size, _) ->
      -- Report only the size, matching the message text (and the
      -- accumulator-operator handler, which also prints just the size).
      throwError $
        "Cannot handle un-sliceable allocation size: "
          ++ pretty size
          ++ "\nLikely cause: irregular nested operations inside parallel constructs."
    Nothing ->
      pure ()

  -- The condition tests the *variant* allocations, so the message says so.
  case lvl of
    SegGroup {}
      | not $ null variant_allocs ->
          throwError "Cannot handle variant allocations in SegGroup."
    _ ->
      pure ()

  allocsForBody variant_allocs invariant_allocs lvl space kbody' $ \alloc_stms kbody'' -> do
    ops'' <- forM ops' $ \op' ->
      localScope (scopeOf op') $ offsetMemoryInLambda op'
    pure (alloc_stms, (ops'', kbody''))
  where
    -- Names bound by the segment space plus those bound by the kernel
    -- body's statements.
    bound_in_kernel =
      namesFromList (M.keys $ scopeOfSegSpace space)
        <> boundInKernelBody kbody

-- | The set of names that 'scopeOf' reports for the statements of a
-- kernel body.
boundInKernelBody :: KernelBody GPUMem -> Names
boundInKernelBody kbody =
  namesFromList $ M.keys $ scopeOf $ kernelBodyStms kbody

-- | Compute the memory requirements of a kernel's variant and invariant
-- allocations, then run the continuation in 'OffsetM'.
--
-- The kernel body is first rewritten with 'offsetMemoryInKernelBody'.
-- The continuation then receives the allocation statements produced by
-- 'memoryRequirements' and the rewritten body.  The 'OffsetM' scope is
-- the segment space's scope layered over the current scope.  Errors
-- from 'runOffsetM' are rethrown in 'ExpandM'.
allocsForBody ::
  Extraction ->
  Extraction ->
  SegLevel ->
  SegSpace ->
  KernelBody GPUMem ->
  (Stms GPUMem -> KernelBody GPUMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs lvl space kbody' m = do
  (alloc_offsets, alloc_stms) <-
    memoryRequirements lvl space (kernelBodyStms kbody') variant_allocs invariant_allocs

  outer_scope <- askScope
  let inner_scope = scopeOfSegSpace space <> outer_scope
      rewrite = do
        kbody'' <- offsetMemoryInKernelBody kbody'
        m alloc_stms kbody''

  case runOffsetM inner_scope alloc_offsets rewrite of
    Left err -> throwError err
    Right res -> pure res

-- | Produce the statements that allocate expanded memory for a kernel's
-- extracted allocations, together with the 'RebaseMap' describing the
-- new offsets.
--
-- A @num_threads@ value is first bound as the product of the level's
-- number of groups and group size.  The returned statements are, in
-- order: the @num_threads@ binding, the invariant allocations, then the
-- variant allocations.
memoryRequirements ::
  SegLevel ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms GPUMem)
memoryRequirements lvl space kstms variant_allocs invariant_allocs = do
  let num_groups = segNumGroups lvl
      group_size = segGroupSize lvl

  (num_threads, num_threads_stms) <-
    runBuilder $
      letSubExp "num_threads" $
        BasicOp $
          BinOp (Mul Int64 OverflowUndef) (unCount num_groups) (unCount group_size)

  (invariant_alloc_stms, invariant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedInvariantAllocations num_threads num_groups group_size invariant_allocs

  (variant_alloc_stms, variant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedVariantAllocations num_threads space kstms variant_allocs

  pure
    ( invariant_alloc_offsets <> variant_alloc_offsets,
      mconcat [num_threads_stms, invariant_alloc_stms, variant_alloc_stms]
    )

-- | Identifies where an allocation occurs: the 'SegLevel' of the
-- enclosing operation, together with index expressions identifying the
-- executing thread (e.g. the flat thread ID of a 'SegSpace').
type User = (SegLevel, [TPrimExp Int64 VName])

-- | Allocations that have been extracted from a body, keyed by the name
-- of the memory block each allocation statement bound.  Each entry
-- records where the allocation occurred ('User'), its size, and the
-- memory 'Space' it lives in.
type Extraction = M.Map VName (User, SubExp, Space)

-- | Extract the allocations found in the statements of a kernel body,
-- returning the body with those statements removed plus the extracted
-- allocations.  This is 'extractGenericBodyAllocations' specialised to
-- 'KernelBody'.
extractKernelBodyAllocations ::
  User ->
  Names ->
  Names ->
  KernelBody GPUMem ->
  ( KernelBody GPUMem,
    Extraction
  )
extractKernelBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel kernelBodyStms replaceStms
  where
    replaceStms new_stms kb = kb {kernelBodyStms = new_stms}

-- | Extract the allocations found in the statements of an ordinary
-- 'Body'.  This is 'extractGenericBodyAllocations' specialised to 'Body'.
extractBodyAllocations ::
  User ->
  Names ->
  Names ->
  Body GPUMem ->
  (Body GPUMem, Extraction)
extractBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel bodyStms replaceStms
  where
    replaceStms new_stms b = b {bodyStms = new_stms}

-- | Extract the allocations found in a lambda's body, returning the
-- lambda with its body rewritten and the extracted allocations.
extractLambdaAllocations ::
  User ->
  Names ->
  Names ->
  Lambda GPUMem ->
  (Lambda GPUMem, Extraction)
extractLambdaAllocations user bound_outside bound_kernel lam =
  let (new_body, extracted) =
        extractBodyAllocations user bound_outside bound_kernel (lambdaBody lam)
   in (lam {lambdaBody = new_body}, extracted)

-- | Run 'extractStmAllocations' over every statement of a body-like
-- value, collecting the extracted allocations.
--
-- Statements for which 'extractStmAllocations' returns 'Nothing' are
-- dropped from the rebuilt body.  The names bound by the body's own
-- statements are added to the kernel-bound name set before the
-- statements are processed.
extractGenericBodyAllocations ::
  User ->
  Names ->
  Names ->
  (body -> Stms GPUMem) ->
  (Stms GPUMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations user bound_outside bound_kernel get_stms set_stms b =
  (set_stms (stmsFromList kept) b, extracted)
  where
    original_stms = get_stms b
    bound_kernel' = bound_kernel <> boundByStms original_stms
    (kept, extracted) =
      runWriter $
        catMaybes
          <$> mapM
            (extractStmAllocations user bound_outside bound_kernel')
            (stmsToList original_stms)

-- expandable: False for local memory and for scalar space, True for
-- every other space.
-- notScalar: False only for scalar space.
expandable, notScalar :: Space -> Bool
expandable space =
  case space of
    Space "local" -> False
    ScalarSpace {} -> False
    _ -> True
notScalar space =
  case space of
    ScalarSpace {} -> False
    _ -> True

extractStmAllocations ::
  User ->
  Names ->
  Names ->
  Stm GPUMem ->
  Writer Extraction (Maybe (Stm GPUMem))
extractStmAllocations :: User
-> Names
-> Names
-> Stm GPUMem
-> WriterT Extraction Identity (Maybe (Stm GPUMem))
extractStmAllocations User
user Names
bound_outside Names
bound_kernel (Let (Pat [PatElemT (LetDec GPUMem)
patElem]) StmAux (ExpDec GPUMem)
_ (Op (Alloc size space)))
  | Space -> Bool
expandable Space
space Bool -> Bool -> Bool
&& SubExp -> Bool
expandableSize SubExp
size
      -- FIXME: the '&& notScalar space' part is a hack because we
      -- don't otherwise hoist the sizes out far enough, and we
      -- promise to be super-duper-careful about not having variant
      -- scalar allocations.
      Bool -> Bool -> Bool
|| (SubExp -> Bool
boundInKernel SubExp
size Bool -> Bool -> Bool
&& Space -> Bool
notScalar Space
space) = do
    Extraction -> WriterT Extraction Identity ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell (Extraction -> WriterT Extraction Identity ())
-> Extraction -> WriterT Extraction Identity ()
forall a b. (a -> b) -> a -> b
$ VName -> (User, SubExp, Space) -> Extraction
forall k a. k -> a -> Map k a
M.singleton (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec GPUMem)
PatElemT LParamMem
patElem) (User
user, SubExp
size, Space
space)
    Maybe (Stm GPUMem)
-> WriterT Extraction Identity (Maybe (Stm GPUMem))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (Stm GPUMem)
forall a. Maybe a
Nothing
  where
    expandableSize :: SubExp -> Bool
expandableSize (Var VName
v) = VName
v VName -> Names -> Bool
`nameIn` Names
bound_outside Bool -> Bool -> Bool
|| VName
v VName -> Names -> Bool
`nameIn` Names
bound_kernel
    expandableSize Constant {} = Bool
True
    boundInKernel :: SubExp -> Bool
boundInKernel (Var VName
v) = VName
v VName -> Names -> Bool
`nameIn` Names
bound_kernel
    boundInKernel Constant {} = Bool
False
extractStmAllocations User
user Names
bound_outside Names
bound_kernel Stm GPUMem
stm = do
  ExpT GPUMem
e <- Mapper GPUMem GPUMem (WriterT Extraction Identity)
-> ExpT GPUMem -> WriterT Extraction Identity (ExpT GPUMem)
forall (m :: * -> *) frep trep.
(Applicative m, Monad m) =>
Mapper frep trep m -> Exp frep -> m (Exp trep)
mapExpM (User -> Mapper GPUMem GPUMem (WriterT Extraction Identity)
expMapper User
user) (ExpT GPUMem -> WriterT Extraction Identity (ExpT GPUMem))
-> ExpT GPUMem -> WriterT Extraction Identity (ExpT GPUMem)
forall a b. (a -> b) -> a -> b
$ Stm GPUMem -> ExpT GPUMem
forall rep. Stm rep -> Exp rep
stmExp Stm GPUMem
stm
  Maybe (Stm GPUMem)
-> WriterT Extraction Identity (Maybe (Stm GPUMem))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (Stm GPUMem)
 -> WriterT Extraction Identity (Maybe (Stm GPUMem)))
-> Maybe (Stm GPUMem)
-> WriterT Extraction Identity (Maybe (Stm GPUMem))
forall a b. (a -> b) -> a -> b
$ Stm GPUMem -> Maybe (Stm GPUMem)
forall a. a -> Maybe a
Just (Stm GPUMem -> Maybe (Stm GPUMem))
-> Stm GPUMem -> Maybe (Stm GPUMem)
forall a b. (a -> b) -> a -> b
$ Stm GPUMem
stm {stmExp :: ExpT GPUMem
stmExp = ExpT GPUMem
e}
  where
    expMapper :: User -> Mapper GPUMem GPUMem (WriterT Extraction Identity)
expMapper User
user' =
      Mapper GPUMem GPUMem (WriterT Extraction Identity)
forall (m :: * -> *) rep. Monad m => Mapper rep rep m
identityMapper
        { mapOnBody :: Scope GPUMem
-> Body GPUMem -> WriterT Extraction Identity (Body GPUMem)
mapOnBody = (Body GPUMem -> WriterT Extraction Identity (Body GPUMem))
-> Scope GPUMem
-> Body GPUMem
-> WriterT Extraction Identity (Body GPUMem)
forall a b. a -> b -> a
const ((Body GPUMem -> WriterT Extraction Identity (Body GPUMem))
 -> Scope GPUMem
 -> Body GPUMem
 -> WriterT Extraction Identity (Body GPUMem))
-> (Body GPUMem -> WriterT Extraction Identity (Body GPUMem))
-> Scope GPUMem
-> Body GPUMem
-> WriterT Extraction Identity (Body GPUMem)
forall a b. (a -> b) -> a -> b
$ User -> Body GPUMem -> WriterT Extraction Identity (Body GPUMem)
onBody User
user',
          mapOnOp :: Op GPUMem -> WriterT Extraction Identity (Op GPUMem)
mapOnOp = User
-> MemOp (HostOp GPUMem ())
-> WriterT Extraction Identity (MemOp (HostOp GPUMem ()))
onOp User
user'
        }

    onBody :: User -> Body GPUMem -> WriterT Extraction Identity (Body GPUMem)
onBody User
user' Body GPUMem
body = do
      let (Body GPUMem
body', Extraction
allocs) = User -> Names -> Names -> Body GPUMem -> (Body GPUMem, Extraction)
extractBodyAllocations User
user' Names
bound_outside Names
bound_kernel Body GPUMem
body
      Extraction -> WriterT Extraction Identity ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell Extraction
allocs
      Body GPUMem -> WriterT Extraction Identity (Body GPUMem)
forall (m :: * -> *) a. Monad m => a -> m a
return Body GPUMem
body'

    onOp :: User
-> MemOp (HostOp GPUMem ())
-> WriterT Extraction Identity (MemOp (HostOp GPUMem ()))
onOp (SegLevel
_, [TPrimExp Int64 VName]
user_ids) (Inner (SegOp SegOp SegLevel GPUMem
op)) =
      HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall inner. inner -> MemOp inner
Inner (HostOp GPUMem () -> MemOp (HostOp GPUMem ()))
-> (SegOp SegLevel GPUMem -> HostOp GPUMem ())
-> SegOp SegLevel GPUMem
-> MemOp (HostOp GPUMem ())
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPUMem -> MemOp (HostOp GPUMem ()))
-> WriterT Extraction Identity (SegOp SegLevel GPUMem)
-> WriterT Extraction Identity (MemOp (HostOp GPUMem ()))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegOpMapper SegLevel GPUMem GPUMem (WriterT Extraction Identity)
-> SegOp SegLevel GPUMem
-> WriterT Extraction Identity (SegOp SegLevel GPUMem)
forall (m :: * -> *) lvl frep trep.
(Applicative m, Monad m) =>
SegOpMapper lvl frep trep m -> SegOp lvl frep -> m (SegOp lvl trep)
mapSegOpM (User
-> SegOpMapper SegLevel GPUMem GPUMem (WriterT Extraction Identity)
opMapper User
user'') SegOp SegLevel GPUMem
op
      where
        user'' :: User
user'' =
          (SegOp SegLevel GPUMem -> SegLevel
forall lvl rep. SegOp lvl rep -> lvl
segLevel SegOp SegLevel GPUMem
op, [TPrimExp Int64 VName]
user_ids [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. [a] -> [a] -> [a]
++ [VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (SegSpace -> VName
segFlat (SegOp SegLevel GPUMem -> SegSpace
forall lvl rep. SegOp lvl rep -> SegSpace
segSpace SegOp SegLevel GPUMem
op))])
    onOp User
_ MemOp (HostOp GPUMem ())
op = MemOp (HostOp GPUMem ())
-> WriterT Extraction Identity (MemOp (HostOp GPUMem ()))
forall (m :: * -> *) a. Monad m => a -> m a
return MemOp (HostOp GPUMem ())
op

    opMapper :: User
-> SegOpMapper SegLevel GPUMem GPUMem (WriterT Extraction Identity)
opMapper User
user' =
      SegOpMapper SegLevel Any Any (WriterT Extraction Identity)
forall (m :: * -> *) lvl rep. Monad m => SegOpMapper lvl rep rep m
identitySegOpMapper
        { mapOnSegOpLambda :: Lambda GPUMem -> WriterT Extraction Identity (Lambda GPUMem)
mapOnSegOpLambda = User
-> Lambda GPUMem -> WriterT Extraction Identity (Lambda GPUMem)
onLambda User
user',
          mapOnSegOpBody :: KernelBody GPUMem
-> WriterT Extraction Identity (KernelBody GPUMem)
mapOnSegOpBody = User
-> KernelBody GPUMem
-> WriterT Extraction Identity (KernelBody GPUMem)
onKernelBody User
user'
        }

    onKernelBody :: User
-> KernelBody GPUMem
-> WriterT Extraction Identity (KernelBody GPUMem)
onKernelBody User
user' KernelBody GPUMem
body = do
      let (KernelBody GPUMem
body', Extraction
allocs) = User
-> Names
-> Names
-> KernelBody GPUMem
-> (KernelBody GPUMem, Extraction)
extractKernelBodyAllocations User
user' Names
bound_outside Names
bound_kernel KernelBody GPUMem
body
      Extraction -> WriterT Extraction Identity ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell Extraction
allocs
      KernelBody GPUMem
-> WriterT Extraction Identity (KernelBody GPUMem)
forall (m :: * -> *) a. Monad m => a -> m a
return KernelBody GPUMem
body'

    onLambda :: User
-> Lambda GPUMem -> WriterT Extraction Identity (Lambda GPUMem)
onLambda User
user' Lambda GPUMem
lam = do
      Body GPUMem
body <- User -> Body GPUMem -> WriterT Extraction Identity (Body GPUMem)
onBody User
user' (Body GPUMem -> WriterT Extraction Identity (Body GPUMem))
-> Body GPUMem -> WriterT Extraction Identity (Body GPUMem)
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda GPUMem
lam
      Lambda GPUMem -> WriterT Extraction Identity (Lambda GPUMem)
forall (m :: * -> *) a. Monad m => a -> m a
return Lambda GPUMem
lam {lambdaBody :: Body GPUMem
lambdaBody = Body GPUMem
body}

-- | Expand every invariant allocation so that there is one copy per user.
--
-- For each extracted block, a new allocation of
-- @per_thread_size * product (user shape)@ is bound under the block's
-- original name.  The returned map rebases arrays in that block so that
-- each user only sees its own copy.
genericExpandedInvariantAllocations ::
  (User -> (Shape, [TPrimExp Int64 VName])) -> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations getNumUsers invariant_allocs = do
  (bases, alloc_stms) <-
    runBuilder $ traverse expandOne $ M.toList invariant_allocs
  -- Keys come from 'M.toList', so they are distinct.
  pure (alloc_stms, M.fromList bases)
  where
    -- Allocate room for all users of 'mem' and pair the block with its
    -- rebasing function.
    expandOne (mem, (user, per_thread_size, space)) = do
      let user_dims = map pe64 $ shapeDims $ fst $ getNumUsers user
      total_size <-
        letExp "total_size" =<< toExp (product (pe64 per_thread_size : user_dims))
      letBind (Pat [PatElem mem (MemMem space)]) (Op (Alloc (Var total_size) space))
      pure (mem, newBase user)

    -- Fix the leading (user) dimensions at this user's ids, and keep
    -- every dimension of the original array whole.
    selectUser user_ids old_shape ixfun =
      IxFun.slice ixfun . Slice $
        map DimFix user_ids ++ [DimSlice 0 d 1 | d <- old_shape]

    -- Thread-level users: the user dimensions are the last dimensions of
    -- the underlying iota.  The permutation moves them to the front so
    -- that 'selectUser' can fix them.
    newBase user@(SegThread {}, _) (old_shape, _) =
      let (users_shape, user_ids) = getNumUsers user
          user_dims = map pe64 $ shapeDims users_shape
          k = length old_shape
          n = shapeRank users_shape
          perm = [k .. k + n - 1] ++ [0 .. k - 1]
       in selectUser user_ids old_shape $
            IxFun.permute (IxFun.iota (old_shape ++ user_dims)) perm
    -- Group-level users: the user dimensions are already the first
    -- dimensions of the underlying iota.
    newBase user@(SegGroup {}, _) (old_shape, _) =
      let (users_shape, user_ids) = getNumUsers user
       in selectUser user_ids old_shape $
            IxFun.iota (map pe64 (shapeDims users_shape) ++ old_shape)

-- | Expand invariant allocations for a kernel launched with the given
-- thread count, number of groups and group size.
--
-- The user shape depends on the user's level and the number of ids:
--
-- * a thread with one id uses the total thread count;
-- * a thread with two ids uses @[num_groups, group_size]@;
-- * a group with one id uses the number of groups.
expandedInvariantAllocations ::
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedInvariantAllocations num_threads (Count num_groups) (Count group_size) =
  genericExpandedInvariantAllocations usersOf
  where
    usersOf :: User -> (Shape, [TPrimExp Int64 VName])
    usersOf user@(_, ids) = case user of
      (SegThread {}, [_]) -> (Shape [num_threads], ids)
      (SegThread {}, [_, _]) -> (Shape [num_groups, group_size], ids)
      (SegGroup {}, [_]) -> (Shape [num_groups], ids)
      _ -> error $ "getNumUsers: unhandled " ++ show user

-- | Expand allocations whose size varies between threads.
--
-- If there is nothing to expand, no statements and an empty rebase map
-- are returned.  Otherwise:
--
-- 1. Compute per-size quantities with 'sliceKernelSizes'.
-- 2. Allocate, simplify and recursively expand the statements it produces.
-- 3. Re-allocate every variant block at the total size computed there.
--
-- The returned map gives each thread its own column of the expanded
-- block.
expandedVariantAllocations ::
  SubExp ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedVariantAllocations _ _ _ variant_allocs
  | null variant_allocs = return (mempty, mempty)
expandedVariantAllocations num_threads kspace kstms variant_allocs = do
  -- NOTE(review): 'removeCommonSizes' (defined elsewhere) presumably
  -- groups blocks that share a size expression; confirm.
  let sizes_to_blocks = removeCommonSizes variant_allocs
      variant_sizes = map fst sizes_to_blocks

  -- The second component is used below as the per-thread size (in
  -- bytes) of each block; the third as its total allocation size.
  (slice_stms, per_thread_sizes, size_sums) <-
    sliceKernelSizes num_threads variant_sizes kspace kstms
  -- Note the recursive call to expand allocations inside the newly
  -- produced kernels.
  (_, slice_stms_tmp) <-
    simplifyStms =<< explicitAllocationsInStms slice_stms
  slice_stms' <- transformStms slice_stms_tmp

  let variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
      variant_allocs' =
        concat $
          zipWith
            memInfo
            (map snd sizes_to_blocks)
            (zip per_thread_sizes size_sums)
      memInfo blocks (per_thread_size, total_size) =
        [ (mem, (Var per_thread_size, Var total_size, space))
          | (mem, space) <- blocks
        ]

  -- We expand the variant allocations by re-allocating each block at the
  -- total size required by all threads, and rebasing it so that every
  -- thread addresses only its own part.
  (alloc_stms, rebases) <- unzip <$> mapM expand variant_allocs'

  return (slice_stms' <> stmsFromList alloc_stms, mconcat rebases)
  where
    -- Bind the new allocation for one block and record its rebasing.
    expand (mem, (per_thread_size, total_size, space)) = do
      let allocpat = Pat [PatElem mem $ MemMem space]
      return
        ( Let allocpat (defAux ()) $ Op $ Alloc total_size space,
          M.singleton mem $ newBase per_thread_size
        )

    num_threads', gtid :: TPrimExp Int64 VName
    num_threads' = pe64 num_threads
    gtid = le64 $ segFlat kspace

    -- For the variant allocations, we add an inner dimension of extent
    -- num_threads'.  Each thread fixes that dimension at its flat thread
    -- id and keeps the whole per-thread element dimension.
    newBase size_per_thread (old_shape, pt) =
      let elems_per_thread = pe64 size_per_thread `quot` primByteSize pt
          root_ixfun = IxFun.iota [elems_per_thread, num_threads']
          offset_ixfun =
            IxFun.slice root_ixfun . Slice $
              -- The first dimension of root_ixfun has extent
              -- elems_per_thread, so that is what the slice spans.  It
              -- previously took num_threads' elements, mixing up the two
              -- dimensions.
              [DimSlice 0 elems_per_thread 1, DimFix gtid]
          shapechange =
            if length old_shape == 1
              then map DimCoercion old_shape
              else map DimNew old_shape
       in IxFun.reshape offset_ixfun shapechange

-- | A map from memory block names to new index function bases.
--
-- Each entry takes the shape and element type of an array that lived in
-- the original block, and yields the index function to use for it in the
-- expanded block.
type RebaseMap =
  M.Map
    VName
    (([TPrimExp Int64 VName], PrimType) -> IxFun)

-- | The monad in which memory references are rewritten.  It reads the
-- current type scope (outer reader) and a 'RebaseMap' (inner reader), and
-- may fail with a 'String' error.
newtype OffsetM a
  = OffsetM
      ( ReaderT
          (Scope GPUMem)
          (ReaderT RebaseMap (Either String))
          a
      )
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope GPUMem,
      LocalScope GPUMem,
      MonadError String
    )

-- | Execute an 'OffsetM' action.  The scope feeds the outer reader and the
-- rebase map feeds the inner one.
runOffsetM :: Scope GPUMem -> RebaseMap -> OffsetM a -> Either String a
runOffsetM scope offsets (OffsetM action) =
  flip runReaderT offsets $ runReaderT action scope

-- | Read the 'RebaseMap' held by the inner reader layer.
askRebaseMap :: OffsetM RebaseMap
askRebaseMap = OffsetM (lift ask)

-- | Run an action under a modified 'RebaseMap'.  The current scope is
-- forwarded unchanged to the inner computation.
localRebaseMap :: (RebaseMap -> RebaseMap) -> OffsetM a -> OffsetM a
localRebaseMap modify (OffsetM action) =
  OffsetM $ ask >>= lift . local modify . runReaderT action

-- | Look up the memory block in the 'RebaseMap'.  If present, apply its
-- entry to the given (base shape, element type) pair to obtain the new
-- base; otherwise return 'Nothing'.
lookupNewBase :: VName -> ([TPrimExp Int64 VName], PrimType) -> OffsetM (Maybe IxFun)
lookupNewBase name x =
  fmap ($ x) . M.lookup name <$> askRebaseMap

-- | Rewrite memory references in every statement of a kernel body.  The
-- scope grows statement by statement, so each statement sees the bindings
-- of the ones before it.
offsetMemoryInKernelBody :: KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
offsetMemoryInKernelBody kbody = do
  outer_scope <- askScope
  (_, new_stms) <-
    mapAccumLM step outer_scope . stmsToList $ kernelBodyStms kbody
  pure kbody {kernelBodyStms = stmsFromList new_stms}
  where
    step cur_scope stm = localScope cur_scope $ offsetMemoryInStm stm

-- | Rewrite memory references in every statement of a body.  The body's
-- decoration and result are left untouched.  The scope is threaded through
-- the statements in order.
offsetMemoryInBody :: Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody (Body dec stms res) = do
  outer_scope <- askScope
  (_, new_stms) <- mapAccumLM step outer_scope (stmsToList stms)
  pure $ Body dec (stmsFromList new_stms) res
  where
    step cur_scope stm = localScope cur_scope $ offsetMemoryInStm stm

-- | Rewrite the memory references of a single statement.
--
-- Returns the rewritten statement together with the current scope extended
-- by that statement's bindings, so callers can thread the scope through a
-- sequence of statements.
offsetMemoryInStm :: Stm GPUMem -> OffsetM (Scope GPUMem, Stm GPUMem)
offsetMemoryInStm (Let pat dec e) = do
  e' <- offsetMemoryInExp e
  scope <- askScope
  -- 'expReturns' depends only on e' and the current scope, so it is
  -- computed once here and shared by both uses below.
  rts <- expReturns e'
  pat' <- offsetMemoryInPat pat rts
  -- Try to recompute the index function.  Fall back to creating rebase
  -- operations with the RebaseMap.
  let pat'' = Pat $ zipWith pick (patElems pat') rts
      stm = Let pat'' dec e'
  return (scopeOf stm <> scope, stm)
  where
    -- Replace an array's memory binding with the one implied by the
    -- expression's return summary.  This applies only when that summary
    -- names a concrete block and its index function has no existentials.
    pick ::
      PatElemT (MemInfo SubExp NoUniqueness MemBind) ->
      ExpReturns ->
      PatElemT (MemInfo SubExp NoUniqueness MemBind)
    pick
      (PatElem name (MemArray pt s u _ret))
      (MemArray _ _ _ (Just (ReturnsInBlock m extixfun)))
        | Just ixfun <- instantiateIxFun extixfun =
          PatElem name (MemArray pt s u (ArrayIn m ixfun))
    pick p _ = p

    -- Succeeds only if the index function mentions no existential ('Ext')
    -- variables.
    instantiateIxFun :: ExtIxFun -> Maybe IxFun
    instantiateIxFun = traverse (traverse inst)
      where
        inst :: Ext a -> Maybe a
        inst Ext {} = Nothing
        inst (Free x) = return x

-- | Rewrite the memory bindings of a pattern, element-wise against the
-- expression's return summaries.
--
-- An array returned in a new block takes the summary's index function.
-- Existentials in that index function are resolved to the names of the
-- pattern elements they index.  Every other element has its memory
-- binding rebased via the 'RebaseMap'.
offsetMemoryInPat :: Pat GPUMem -> [ExpReturns] -> OffsetM (Pat GPUMem)
offsetMemoryInPat (Pat pes) rets = Pat <$> zipWithM onPE pes rets
  where
    onPE
      (PatElem name (MemArray pt shape u (ArrayIn mem _)))
      (MemArray _ _ _ (Just (ReturnsNewBlock _ _ ixfun))) =
        pure $ PatElem name $ MemArray pt shape u $ ArrayIn mem $ fmap (fmap unExt) ixfun
    onPE pe _ = do
      dec' <- offsetMemoryInMemBound (patElemDec pe)
      pure pe {patElemDec = dec'}

    unExt (Free v) = v
    -- NOTE(review): assumes every existential index is a valid position in
    -- the pattern; an out-of-range index would crash via '!!'.
    unExt (Ext i) = patElemName (pes !! i)

-- | Rebase the memory binding stored in a parameter's decoration.
offsetMemoryInParam :: Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam fparam =
  withDec <$> offsetMemoryInMemBound (paramDec fparam)
  where
    withDec dec' = fparam {paramDec = dec'}

-- | If an array's memory block has an entry in the 'RebaseMap', rebase its
-- index function onto the new base.  Everything else, including arrays in
-- blocks with no entry, is returned unchanged.
offsetMemoryInMemBound :: MemBound u -> OffsetM (MemBound u)
offsetMemoryInMemBound summary@(MemArray pt shape u (ArrayIn mem ixfun)) =
  maybe summary rebased <$> lookupNewBase mem (IxFun.base ixfun, pt)
  where
    rebased new_base =
      MemArray pt shape u $ ArrayIn mem $ IxFun.rebase new_base ixfun
offsetMemoryInMemBound summary = pure summary

-- | Rebase a body return summary whose memory block has an entry in the
-- 'RebaseMap'.  This applies only when the summary's index function is
-- static (as decided by 'isStaticIxFun').  The new base is lifted into
-- the existential index-function type with 'Free'.  All other summaries
-- are returned unchanged.
offsetMemoryInBodyReturns :: BodyReturns -> OffsetM BodyReturns
offsetMemoryInBodyReturns br@(MemArray pt shape u (ReturnsInBlock mem ixfun))
  | Just ixfun' <- isStaticIxFun ixfun =
    maybe br rebased <$> lookupNewBase mem (IxFun.base ixfun', pt)
  where
    rebased new_base =
      MemArray pt shape u . ReturnsInBlock mem $
        IxFun.rebase (fmap (fmap Free) new_base) ixfun
offsetMemoryInBodyReturns br = pure br

-- | Rewrite memory references in a lambda's body.  The lambda's own
-- bindings are brought into scope first.
offsetMemoryInLambda :: Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda lam =
  inScopeOf lam $
    withBody <$> offsetMemoryInBody (lambdaBody lam)
  where
    withBody body' = lam {lambdaBody = body'}

-- A loop may have memory parameters, and those memory blocks may
-- be expanded.  We assume (but do not check - FIXME) that if the
-- initial value of a loop parameter is an expanded memory block,
-- then so will the result be.
--
-- Each loop parameter whose initial argument is a variable with a
-- 'RebaseMap' entry inherits that entry.  The parameters are then rebased
-- under the extended map, and the continuation is called with the
-- rewritten merge list (arguments unchanged).
offsetMemoryInLoopParams ::
  [(FParam GPUMem, SubExp)] ->
  ([(FParam GPUMem, SubExp)] -> OffsetM a) ->
  OffsetM a
offsetMemoryInLoopParams merge f =
  localRebaseMap extend $ do
    params' <- mapM offsetMemoryInParam params
    f $ zip params' args
  where
    (params, args) = unzip merge

    extend rm = foldl' onParamArg rm merge

    onParamArg rm (param, Var arg)
      | Just x <- M.lookup arg rm = M.insert (paramName param) x rm
    onParamArg rm _ = rm

offsetMemoryInExp :: Exp GPUMem -> OffsetM (Exp GPUMem)
offsetMemoryInExp :: ExpT GPUMem -> OffsetM (ExpT GPUMem)
offsetMemoryInExp (DoLoop [(FParam GPUMem, SubExp)]
merge LoopForm GPUMem
form Body GPUMem
body) = do
  [(FParam GPUMem, SubExp)]
-> ([(FParam GPUMem, SubExp)] -> OffsetM (ExpT GPUMem))
-> OffsetM (ExpT GPUMem)
forall a.
[(FParam GPUMem, SubExp)]
-> ([(FParam GPUMem, SubExp)] -> OffsetM a) -> OffsetM a
offsetMemoryInLoopParams [(FParam GPUMem, SubExp)]
merge (([(FParam GPUMem, SubExp)] -> OffsetM (ExpT GPUMem))
 -> OffsetM (ExpT GPUMem))
-> ([(FParam GPUMem, SubExp)] -> OffsetM (ExpT GPUMem))
-> OffsetM (ExpT GPUMem)
forall a b. (a -> b) -> a -> b
$ \[(FParam GPUMem, SubExp)]
merge' -> do
    Body GPUMem
body' <-
      Scope GPUMem -> OffsetM (Body GPUMem) -> OffsetM (Body GPUMem)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope
        ([Param (MemBound Uniqueness)] -> Scope GPUMem
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams (((Param (MemBound Uniqueness), SubExp)
 -> Param (MemBound Uniqueness))
-> [(Param (MemBound Uniqueness), SubExp)]
-> [Param (MemBound Uniqueness)]
forall a b. (a -> b) -> [a] -> [b]
map (Param (MemBound Uniqueness), SubExp)
-> Param (MemBound Uniqueness)
forall a b. (a, b) -> a
fst [(FParam GPUMem, SubExp)]
[(Param (MemBound Uniqueness), SubExp)]
merge') Scope GPUMem -> Scope GPUMem -> Scope GPUMem
forall a. Semigroup a => a -> a -> a
<> LoopForm GPUMem -> Scope GPUMem
forall rep a. Scoped rep a => a -> Scope rep
scopeOf LoopForm GPUMem
form)
        (Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody Body GPUMem
body)
    ExpT GPUMem -> OffsetM (ExpT GPUMem)
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpT GPUMem -> OffsetM (ExpT GPUMem))
-> ExpT GPUMem -> OffsetM (ExpT GPUMem)
forall a b. (a -> b) -> a -> b
$ [(FParam GPUMem, SubExp)]
-> LoopForm GPUMem -> Body GPUMem -> ExpT GPUMem
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop [(FParam GPUMem, SubExp)]
merge' LoopForm GPUMem
form Body GPUMem
body'
offsetMemoryInExp ExpT GPUMem
e = Mapper GPUMem GPUMem OffsetM
-> ExpT GPUMem -> OffsetM (ExpT GPUMem)
forall (m :: * -> *) frep trep.
(Applicative m, Monad m) =>
Mapper frep trep m -> Exp frep -> m (Exp trep)
mapExpM Mapper GPUMem GPUMem OffsetM
recurse ExpT GPUMem
e
  where
    recurse :: Mapper GPUMem GPUMem OffsetM
recurse =
      Mapper GPUMem GPUMem OffsetM
forall (m :: * -> *) rep. Monad m => Mapper rep rep m
identityMapper
        { mapOnBody :: Scope GPUMem -> Body GPUMem -> OffsetM (Body GPUMem)
mapOnBody = \Scope GPUMem
bscope -> Scope GPUMem -> OffsetM (Body GPUMem) -> OffsetM (Body GPUMem)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope GPUMem
bscope (OffsetM (Body GPUMem) -> OffsetM (Body GPUMem))
-> (Body GPUMem -> OffsetM (Body GPUMem))
-> Body GPUMem
-> OffsetM (Body GPUMem)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody,
          mapOnBranchType :: BranchType GPUMem -> OffsetM (BranchType GPUMem)
mapOnBranchType = BranchType GPUMem -> OffsetM (BranchType GPUMem)
BranchTypeMem -> OffsetM BranchTypeMem
offsetMemoryInBodyReturns,
          mapOnOp :: Op GPUMem -> OffsetM (Op GPUMem)
mapOnOp = Op GPUMem -> OffsetM (Op GPUMem)
forall op.
MemOp (HostOp GPUMem op) -> OffsetM (MemOp (HostOp GPUMem op))
onOp
        }
    onOp :: MemOp (HostOp GPUMem op) -> OffsetM (MemOp (HostOp GPUMem op))
onOp (Inner (SegOp SegOp SegLevel GPUMem
op)) =
      HostOp GPUMem op -> MemOp (HostOp GPUMem op)
forall inner. inner -> MemOp inner
Inner (HostOp GPUMem op -> MemOp (HostOp GPUMem op))
-> (SegOp SegLevel GPUMem -> HostOp GPUMem op)
-> SegOp SegLevel GPUMem
-> MemOp (HostOp GPUMem op)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegOp SegLevel GPUMem -> HostOp GPUMem op
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp
        (SegOp SegLevel GPUMem -> MemOp (HostOp GPUMem op))
-> OffsetM (SegOp SegLevel GPUMem)
-> OffsetM (MemOp (HostOp GPUMem op))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Scope GPUMem
-> OffsetM (SegOp SegLevel GPUMem)
-> OffsetM (SegOp SegLevel GPUMem)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope GPUMem
forall rep. SegSpace -> Scope rep
scopeOfSegSpace (SegOp SegLevel GPUMem -> SegSpace
forall lvl rep. SegOp lvl rep -> SegSpace
segSpace SegOp SegLevel GPUMem
op)) (SegOpMapper SegLevel GPUMem GPUMem OffsetM
-> SegOp SegLevel GPUMem -> OffsetM (SegOp SegLevel GPUMem)
forall (m :: * -> *) lvl frep trep.
(Applicative m, Monad m) =>
SegOpMapper lvl frep trep m -> SegOp lvl frep -> m (SegOp lvl trep)
mapSegOpM SegOpMapper SegLevel GPUMem GPUMem OffsetM
forall lvl. SegOpMapper lvl GPUMem GPUMem OffsetM
segOpMapper SegOp SegLevel GPUMem
op)
      where
        segOpMapper :: SegOpMapper lvl GPUMem GPUMem OffsetM
segOpMapper =
          SegOpMapper lvl Any Any OffsetM
forall (m :: * -> *) lvl rep. Monad m => SegOpMapper lvl rep rep m
identitySegOpMapper
            { mapOnSegOpBody :: KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
mapOnSegOpBody = KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
offsetMemoryInKernelBody,
              mapOnSegOpLambda :: Lambda GPUMem -> OffsetM (Lambda GPUMem)
mapOnSegOpLambda = Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda
            }
    onOp MemOp (HostOp GPUMem op)
op = MemOp (HostOp GPUMem op) -> OffsetM (MemOp (HostOp GPUMem op))
forall (m :: * -> *) a. Monad m => a -> m a
return MemOp (HostOp GPUMem op)
op

---- Slicing allocation sizes out of a kernel.

-- | Convert kernel statements from 'GPUMem' back to 'GPU.GPU' by erasing
-- all memory information.
--
-- 'Alloc' statements at the outermost level are discarded. An 'Alloc'
-- found inside any nested body (expression body, kernel body or lambda
-- body) is reported as an error. So is any pattern, function/loop
-- parameter, or type that has no memory-free counterpart. Lambda
-- parameters without such a counterpart are silently dropped.
unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU.GPU)
unAllocGPUStms = stripStms False
  where
    -- Nested bodies: allocations are no longer permitted inside them,
    -- hence the 'True' flag passed to 'stripStms'.
    stripBody (Body dec stms res) = do
      stms' <- stripStms True stms
      pure $ Body dec stms' res

    stripKernelBody (KernelBody dec stms res) = do
      stms' <- stripStms True stms
      pure $ KernelBody dec stms' res

    -- Convert every statement; those mapped to 'Nothing' (dropped
    -- top-level allocations) are filtered out.
    stripStms nested stms =
      stmsFromList . catMaybes <$> traverse (stripStm nested) (stmsToList stms)

    stripStm nested stm =
      case stm of
        Let _ _ (Op Alloc {})
          | nested ->
              throwError $ "Cannot handle nested allocation: " ++ pretty stm
          | otherwise -> pure Nothing
        Let pat dec e -> do
          -- The pattern is converted before the expression, so its error
          -- (if any) takes precedence.
          pat' <- stripPat pat
          e' <- mapExpM stripMapper e
          pure $ Just $ Let pat' dec e'

    stripLambda (Lambda params body ret) = do
      body' <- stripBody body
      pure $ Lambda (stripParams params) body' ret

    -- Parameters whose type cannot be expressed without memory are dropped.
    stripParams = mapMaybe (traverse unMem)

    stripPat pat@(Pat pes) =
      case mapM (rephrasePatElem unMem) pes of
        Just pes' -> Right $ Pat pes'
        Nothing -> Left $ "Cannot handle memory in pattern " ++ pretty pat

    stripOp op =
      case op of
        Alloc {} -> Left "unAllocOp: unhandled Alloc"
        Inner OtherOp {} -> Left "unAllocOp: unhandled OtherOp"
        Inner (SizeOp sop) -> Right $ SizeOp sop
        Inner (SegOp sop) -> SegOp <$> mapSegOpM segMapper sop
      where
        segMapper =
          identitySegOpMapper
            { mapOnSegOpLambda = stripLambda,
              mapOnSegOpBody = stripKernelBody
            }

    -- Kept as standalone bindings so they are generalised: each is used
    -- at two different types in 'stripMapper'.
    stripParam p =
      case traverse unMem p of
        Just p' -> Right p'
        Nothing ->
          Left $ "Cannot handle memory-typed parameter '" ++ pretty p ++ "'"

    stripType t =
      case unMem t of
        Just t' -> Right t'
        Nothing -> Left $ "Cannot handle memory type '" ++ pretty t ++ "'"

    stripMapper =
      Mapper
        { mapOnSubExp = Right,
          mapOnVName = Right,
          mapOnBody = \_ body -> stripBody body,
          mapOnRetType = stripType,
          mapOnBranchType = stripType,
          mapOnFParam = stripParam,
          mapOnLParam = stripParam,
          mapOnOp = stripOp
        }

-- | Drop the memory annotation from a 'MemInfo', keeping the plain type.
--
-- Arrays lose their memory binding. A memory block itself ('MemMem') has
-- no memory-free counterpart, so it yields 'Nothing'.
unMem :: MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
unMem info =
  case info of
    MemPrim pt -> Just (Prim pt)
    MemArray pt shape u _ -> Just (Array pt shape u)
    MemAcc acc ispace ts u -> Just (Acc acc ispace ts u)
    MemMem {} -> Nothing

-- | Erase memory information from every binding in a scope.
--
-- Names whose type is a memory block have no 'GPU.GPU' counterpart and
-- are removed from the resulting scope; index names are kept unchanged.
unAllocScope :: Scope GPUMem -> Scope GPU.GPU
unAllocScope scope = M.mapMaybe convert scope
  where
    convert info =
      case info of
        LetName dec -> LetName <$> unMem dec
        FParamName dec -> FParamName <$> unMem dec
        LParamName dec -> LParamName <$> unMem dec
        IndexName it -> Just (IndexName it)

-- | Group extracted allocations by their size.
--
-- Each result entry pairs a distinct size with every (memory block, space)
-- allocated with that size. The 'User' component of the extraction is
-- ignored. Entries are built with the same left-to-right insertion order
-- as a strict left fold over the extraction.
removeCommonSizes ::
  Extraction ->
  [(SubExp, [(VName, Space)])]
removeCommonSizes extraction =
  M.toList $
    M.fromListWith
      (++)
      [(size, [(mem, space)]) | (mem, (_, size, space)) <- M.toList extraction]

sliceKernelSizes ::
  SubExp ->
  [SubExp] ->
  SegSpace ->
  Stms GPUMem ->
  ExpandM (Stms GPU.GPU, [VName], [VName])
sliceKernelSizes :: SubExp
-> [SubExp]
-> SegSpace
-> Stms GPUMem
-> ExpandM (Stms GPU, [VName], [VName])
sliceKernelSizes SubExp
num_threads [SubExp]
sizes SegSpace
space Stms GPUMem
kstms = do
  Stms GPU
kstms' <- (String
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU))
-> (Stms GPU
    -> ReaderT
         (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU))
-> Either String (Stms GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either String
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError Stms GPU
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either String (Stms GPU)
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU))
-> Either String (Stms GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall a b. (a -> b) -> a -> b
$ Stms GPUMem -> Either String (Stms GPU)
unAllocGPUStms Stms GPUMem
kstms
  let num_sizes :: Int
num_sizes = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
sizes
      i64s :: [Type]
i64s = Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate Int
num_sizes (Type -> [Type]) -> Type -> [Type]
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64

  Scope GPU
kernels_scope <- (Scope GPUMem -> Scope GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Scope GPU)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Scope GPUMem -> Scope GPU
unAllocScope

  (Lambda GPU
max_lam, Stms GPU
_) <- (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  (Lambda GPU)
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    [Param Type]
xs <- Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Param Type)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [Param Type])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param Type]
forall a b. (a -> b) -> a -> b
$ String
-> Type
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    [Param Type]
ys <- Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Param Type)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [Param Type])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param Type]
forall a b. (a -> b) -> a -> b
$ String
-> Type
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"y" (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms GPU
stms) <- Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param Type] -> Scope GPU
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param Type] -> Scope GPU) -> [Param Type] -> Scope GPU
forall a b. (a -> b) -> a -> b
$ [Param Type]
xs [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
ys) (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Result, Stms GPU)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Result, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
forall a b. (a -> b) -> a -> b
$
      BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   Result
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Result,
       Stms
         (Rep
            (BuilderT
               GPU
               (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall a b. (a -> b) -> a -> b
$
        [(Param Type, Param Type)]
-> ((Param Type, Param Type)
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         SubExpRes)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param Type] -> [Param Type] -> [(Param Type, Param Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
xs [Param Type]
ys) (((Param Type, Param Type)
  -> BuilderT
       GPU
       (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
       SubExpRes)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      Result)
-> ((Param Type, Param Type)
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         SubExpRes)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall a b. (a -> b) -> a -> b
$ \(Param Type
x, Param Type
y) ->
          (SubExp -> SubExpRes)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExpRes
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> SubExpRes
subExpRes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   SubExp
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExpRes)
-> (BasicOp
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         SubExp)
-> BasicOp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExpRes
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"z" (ExpT GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExp)
-> (BasicOp -> ExpT GPU)
-> BasicOp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> ExpT GPU
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExpRes)
-> BasicOp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExpRes
forall a b. (a -> b) -> a -> b
$
            BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMax IntType
Int64) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
x) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
y)
    Lambda GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall (m :: * -> *) a. Monad m => a -> m a
return (Lambda GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Lambda GPU))
-> Lambda GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall a b. (a -> b) -> a -> b
$ [LParam GPU] -> BodyT GPU -> [Type] -> Lambda GPU
forall rep. [LParam rep] -> BodyT rep -> [Type] -> LambdaT rep
Lambda ([Param Type]
xs [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
ys) (Stms GPU -> Result -> BodyT GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody Stms GPU
stms Result
zs) [Type]
i64s

  Param Type
flat_gtid_lparam <- VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param (VName -> Type -> Param Type)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) VName
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Type -> Param Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"flat_gtid" ReaderT
  (Scope GPUMem)
  (StateT VNameSource (Either String))
  (Type -> Param Type)
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) Type
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Param Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (IntType -> PrimType
IntType IntType
Int64))

  (Lambda GPU
size_lam', Stms GPU
_) <- (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  (Lambda GPU)
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    [Param Type]
params <- Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Param Type)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [Param Type])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param Type]
forall a b. (a -> b) -> a -> b
$ String
-> Type
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms GPU
stms) <- Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope
      ([Param Type] -> Scope GPU
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams [Param Type]
params Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> [Param Type] -> Scope GPU
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams [Param Type
flat_gtid_lparam])
      (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Result, Stms GPU)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Result, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
forall a b. (a -> b) -> a -> b
$ BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   Result
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Result,
       Stms
         (Rep
            (BuilderT
               GPU
               (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall a b. (a -> b) -> a -> b
$ do
        -- Even though this SegRed is one-dimensional, we need to
        -- provide indexes corresponding to the original potentially
        -- multi-dimensional construct.
        let ([VName]
kspace_gtids, [SubExp]
kspace_dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
            new_inds :: [TPrimExp Int64 VName]
new_inds =
              [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
                ((SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
kspace_dims)
                (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
flat_gtid_lparam)
        ([VName]
 -> ExpT GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> [[VName]]
-> [ExpT GPU]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ [VName]
-> ExpT GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames ((VName -> [VName]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> [a] -> [b]
map VName -> [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
kspace_gtids) ([ExpT GPU]
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [ExpT GPU]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (TPrimExp Int64 VName
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (ExpT GPU))
-> [TPrimExp Int64 VName]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [ExpT GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TPrimExp Int64 VName
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (ExpT GPU)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp [TPrimExp Int64 VName]
new_inds

        (Stm GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> Stms GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm Stms GPU
kstms'
        Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall (m :: * -> *) a. Monad m => a -> m a
return (Result
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      Result)
-> Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Result
subExpsRes [SubExp]
sizes

    Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope GPU
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Lambda GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall a b. (a -> b) -> a -> b
$
      Lambda GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall (m :: * -> *).
(HasScope GPU m, MonadFreshNames m) =>
Lambda GPU -> m (Lambda GPU)
GPU.simplifyLambda ([LParam GPU] -> BodyT GPU -> [Type] -> Lambda GPU
forall rep. [LParam rep] -> BodyT rep -> [Type] -> LambdaT rep
Lambda [Param Type
LParam GPU
flat_gtid_lparam] (BodyDec GPU -> Stms GPU -> Result -> BodyT GPU
forall rep. BodyDec rep -> Stms rep -> Result -> BodyT rep
Body () Stms GPU
stms Result
zs) [Type]
i64s)

  (([VName]
maxes_per_thread, [VName]
size_sums), Stms GPU
slice_stms) <- (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   ([VName], [VName])
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (([VName], [VName]), Stms GPU))
-> Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  ([VName], [VName])
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   ([VName], [VName])
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (([VName], [VName]), Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    PatT Type
pat <-
      [Ident] -> PatT Type
basicPat ([Ident] -> PatT Type)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Ident]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (PatT Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Ident
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Ident]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (String
-> Type
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"max_per_thread" (Type
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      Ident)
-> Type
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Ident
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)

    SubExp
w <-
      String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"size_slice_w"
        (ExpT GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExp)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (ExpT GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> SubExp
-> [SubExp]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Exp
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (SegSpace -> [SubExp]
segSpaceDims SegSpace
space)

    VName
thread_space_iota <-
      String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"thread_space_iota" (Exp
   (Rep
      (BuilderT
         GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      VName)
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp -> ExpT GPU
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT GPU) -> BasicOp -> ExpT GPU
forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
    let red_op :: SegBinOp GPU
red_op =
          Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp
            Commutativity
Commutative
            Lambda GPU
max_lam
            (Int -> SubExp -> [SubExp]
forall a. Int -> a -> [a]
replicate Int
num_sizes (SubExp -> [SubExp]) -> SubExp -> [SubExp]
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
            Shape
forall a. Monoid a => a
mempty
    SegLevel
lvl <- String
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SegLevel
forall (m :: * -> *) inner.
(MonadBuilder m, Op (Rep m) ~ HostOp (Rep m) inner) =>
String -> m SegLevel
segThread String
"segred"

    Stms GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Stm GPU))
-> Stms GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Stm GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stm GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Stm rep -> m (Stm rep)
renameStm
      (Stms GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat GPU
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat rep
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> m (Stms rep)
nonSegRed SegOpLevel GPU
SegLevel
lvl PatT Type
Pat GPU
pat SubExp
w [SegBinOp GPU
red_op] Lambda GPU
size_lam' [VName
thread_space_iota]

    [VName]
size_sums <- [VName]
-> (VName
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         VName)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (PatT Type -> [VName]
forall dec. PatT dec -> [VName]
patNames PatT Type
pat) ((VName
  -> BuilderT
       GPU
       (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
       VName)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [VName])
-> (VName
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         VName)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [VName]
forall a b. (a -> b) -> a -> b
$ \VName
threads_max ->
      String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"size_sum" (Exp
   (Rep
      (BuilderT
         GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      VName)
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp -> ExpT GPU
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT GPU) -> BasicOp -> ExpT GPU
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (VName -> SubExp
Var VName
threads_max) SubExp
num_threads

    ([VName], [VName])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (PatT Type -> [VName]
forall dec. PatT dec -> [VName]
patNames PatT Type
pat, [VName]
size_sums)

  (Stms GPU, [VName], [VName])
-> ExpandM (Stms GPU, [VName], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms GPU
slice_stms, [VName]
maxes_per_thread, [VName]
size_sums)