{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- | Expand allocations inside of maps when possible.
module Futhark.Pass.ExpandAllocations (expandAllocations) where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (find, foldl')
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.Rephrase
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Error
import Futhark.IR
import qualified Futhark.IR.GPU.Simplify as GPU
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Rep (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToGPU (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Util (mapAccumLM)
import Futhark.Util.IntegralExp
import Prelude hiding (quot)

-- | The memory expansion pass definition.
expandAllocations :: Pass GPUMem GPUMem
expandAllocations =
  Pass "expand allocations" "Expand allocations" $
    \(Prog consts funs) -> do
      -- Top-level constants are transformed first, starting from an
      -- empty scope.  A failure becomes a compiler-limitation error.
      consts' <-
        modifyNameSource $
          limitationOnLeft . runStateT (runReaderT (transformStms consts) mempty)
      -- The (transformed) constants form the outer scope of every function.
      Prog consts' <$> mapM (transformFunDef $ scopeOf consts') funs

-- Cannot use intraproceduralTransformation because it might create
-- duplicate size keys (which the renamer does not fix, and size keys
-- must currently be globally unique).

-- | The monad in which expansion runs: the current scope is available
-- through a reader environment, fresh names are threaded as a
-- 'VNameSource' state, and failure is reported as a 'String' via 'Either'.
type ExpandM =
  ReaderT
    (Scope GPUMem)
    (StateT VNameSource (Either String))

-- | Unwrap the result of an expansion.  A 'Left' (expansion failed) is
-- reported through 'compilerLimitationS'; a 'Right' is returned as-is.
limitationOnLeft :: Either String a -> a
limitationOnLeft = either compilerLimitationS id

-- | Expand allocations in the body of one function, then run copy
-- propagation over the rewritten function.  The given scope contains
-- the top-level constants, which are visible inside every function.
transformFunDef ::
  Scope GPUMem ->
  FunDef GPUMem ->
  PassM (FunDef GPUMem)
transformFunDef scope fundec = do
  body' <-
    modifyNameSource $
      limitationOnLeft . runStateT (runReaderT m mempty)
  copyPropagateInFun
    simpleGPUMem
    (ST.fromScope (addScopeWisdom scope))
    fundec {funDefBody = body'}
  where
    -- Transform the body with both the outer scope and the function's
    -- own parameters in scope.
    m :: ExpandM (Body GPUMem)
    m =
      localScope scope $
        inScopeOf fundec $
          transformBody $ funDefBody fundec

-- | Transform every statement of a body.  The body result is kept
-- unchanged.
transformBody :: Body GPUMem -> ExpandM (Body GPUMem)
transformBody (Body () stms res) = Body () <$> transformStms stms <*> pure res

-- | Transform the body of a lambda with its parameters in scope.  The
-- parameters and return types are kept unchanged.
transformLambda :: Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda (Lambda params body ret) = do
  body' <- localScope (scopeOfLParams params) $ transformBody body
  pure $ Lambda params body' ret

-- | Transform a sequence of statements, concatenating the (possibly
-- several) statements each one turns into.  The bindings of the
-- statements are put in scope for the whole traversal.
transformStms :: Stms GPUMem -> ExpandM (Stms GPUMem)
transformStms stms =
  inScopeOf stms $ mconcat <$> mapM transformStm (stmsToList stms)

-- | Transform a single statement, possibly producing several.
transformStm :: Stm GPUMem -> ExpandM (Stms GPUMem)
-- It is possible that we are unable to expand allocations in some
-- code versions.  If so, we can remove the offending branch.  Only if
-- both versions fail do we propagate the error.
transformStm (Let pat aux (If cond tbranch fbranch (IfDec ts IfEquiv))) = do
  tbranch' <- attempt tbranch
  fbranch' <- attempt fbranch
  case (tbranch', fbranch') of
    (Right tbranch'', Right fbranch'') ->
      pure $ oneStm $ Let pat aux $ If cond tbranch'' fbranch'' (IfDec ts IfEquiv)
    (Right tbranch'', Left _) ->
      pure $ useBranch tbranch''
    (Left _, Right fbranch'') ->
      pure $ useBranch fbranch''
    -- Both branches failed: report the error from the true branch.
    (Left e, Left _) ->
      throwError e
  where
    -- Transform a branch, capturing failure as a value instead of
    -- aborting the whole statement.
    attempt :: Body GPUMem -> ExpandM (Either String (Body GPUMem))
    attempt b = (Right <$> transformBody b) `catchError` (pure . Left)

    -- Bind one pattern element directly to one branch result.
    bindRes :: PatElemT (LetDec GPUMem) -> SubExp -> Stm GPUMem
    bindRes pe se = Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se

    -- Replace the whole 'If' by the statements of a single branch,
    -- followed by bindings of the pattern to that branch's results.
    useBranch :: Body GPUMem -> Stms GPUMem
    useBranch b =
      bodyStms b
        <> stmsFromList (zipWith bindRes (patternElements pat) (bodyResult b))
transformStm (Let pat aux e) = do
  (alloc_stms, e') <- transformExp =<< mapExpM transform e
  pure $ alloc_stms <> oneStm (Let pat aux e')
  where
    -- Recurse into every nested body, extending the scope as 'mapExpM'
    -- requests.
    transform :: Mapper GPUMem GPUMem ExpandM
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }

transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp :: ExpT GPUMem -> ExpandM (Stms GPUMem, ExpT GPUMem)
transformExp (Op (Inner (SegOp (SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody)))) = do
  (Stms GPUMem
alloc_stms, ([Lambda GPUMem]
_, KernelBody GPUMem
kbody')) <- SegLevel
-> SegSpace
-> [Lambda GPUMem]
-> KernelBody GPUMem
-> ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed SegLevel
lvl SegSpace
space [] KernelBody GPUMem
kbody
  (Stms GPUMem, ExpT GPUMem) -> ExpandM (Stms GPUMem, ExpT GPUMem)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( Stms GPUMem
alloc_stms,
      Op GPUMem -> ExpT GPUMem
forall rep. Op rep -> ExpT rep
Op (Op GPUMem -> ExpT GPUMem) -> Op GPUMem -> ExpT GPUMem
forall a b. (a -> b) -> a -> b
$ HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall inner. inner -> MemOp inner
Inner (HostOp GPUMem () -> MemOp (HostOp GPUMem ()))
-> HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPUMem -> HostOp GPUMem ())
-> SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody GPUMem
-> SegOp SegLevel GPUMem
forall lvl rep.
lvl
-> SegSpace
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
reds [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody)))) = do
  (Stms GPUMem
alloc_stms, ([Lambda GPUMem]
lams, KernelBody GPUMem
kbody')) <-
    SegLevel
-> SegSpace
-> [Lambda GPUMem]
-> KernelBody GPUMem
-> ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed SegLevel
lvl SegSpace
space ((SegBinOp GPUMem -> Lambda GPUMem)
-> [SegBinOp GPUMem] -> [Lambda GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda [SegBinOp GPUMem]
reds) KernelBody GPUMem
kbody
  let reds' :: [SegBinOp GPUMem]
reds' = (SegBinOp GPUMem -> Lambda GPUMem -> SegBinOp GPUMem)
-> [SegBinOp GPUMem] -> [Lambda GPUMem] -> [SegBinOp GPUMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\SegBinOp GPUMem
red Lambda GPUMem
lam -> SegBinOp GPUMem
red {segBinOpLambda :: Lambda GPUMem
segBinOpLambda = Lambda GPUMem
lam}) [SegBinOp GPUMem]
reds [Lambda GPUMem]
lams
  (Stms GPUMem, ExpT GPUMem) -> ExpandM (Stms GPUMem, ExpT GPUMem)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( Stms GPUMem
alloc_stms,
      Op GPUMem -> ExpT GPUMem
forall rep. Op rep -> ExpT rep
Op (Op GPUMem -> ExpT GPUMem) -> Op GPUMem -> ExpT GPUMem
forall a b. (a -> b) -> a -> b
$ HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall inner. inner -> MemOp inner
Inner (HostOp GPUMem () -> MemOp (HostOp GPUMem ()))
-> HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPUMem -> HostOp GPUMem ())
-> SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [SegBinOp GPUMem]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody GPUMem
-> SegOp SegLevel GPUMem
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegRed SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
reds' [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegScan SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
scans [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody)))) = do
  (Stms GPUMem
alloc_stms, ([Lambda GPUMem]
lams, KernelBody GPUMem
kbody')) <-
    SegLevel
-> SegSpace
-> [Lambda GPUMem]
-> KernelBody GPUMem
-> ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed SegLevel
lvl SegSpace
space ((SegBinOp GPUMem -> Lambda GPUMem)
-> [SegBinOp GPUMem] -> [Lambda GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda [SegBinOp GPUMem]
scans) KernelBody GPUMem
kbody
  let scans' :: [SegBinOp GPUMem]
scans' = (SegBinOp GPUMem -> Lambda GPUMem -> SegBinOp GPUMem)
-> [SegBinOp GPUMem] -> [Lambda GPUMem] -> [SegBinOp GPUMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\SegBinOp GPUMem
red Lambda GPUMem
lam -> SegBinOp GPUMem
red {segBinOpLambda :: Lambda GPUMem
segBinOpLambda = Lambda GPUMem
lam}) [SegBinOp GPUMem]
scans [Lambda GPUMem]
lams
  (Stms GPUMem, ExpT GPUMem) -> ExpandM (Stms GPUMem, ExpT GPUMem)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( Stms GPUMem
alloc_stms,
      Op GPUMem -> ExpT GPUMem
forall rep. Op rep -> ExpT rep
Op (Op GPUMem -> ExpT GPUMem) -> Op GPUMem -> ExpT GPUMem
forall a b. (a -> b) -> a -> b
$ HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall inner. inner -> MemOp inner
Inner (HostOp GPUMem () -> MemOp (HostOp GPUMem ()))
-> HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPUMem -> HostOp GPUMem ())
-> SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [SegBinOp GPUMem]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody GPUMem
-> SegOp SegLevel GPUMem
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegScan SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
scans' [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegHist lvl space ops ts kbody)))) = do
  -- Treat the histogram operators like scan/reduce operators: expand the
  -- allocations of the body and of every operator lambda, then put the
  -- rewritten lambdas back into their respective 'HistOp's.
  (alloc_stms, (hist_lams, kbody')) <-
    transformScanRed lvl space (map histOp ops) kbody
  let ops' = zipWith (\op hist_lam -> op {histOp = hist_lam}) ops hist_lams
  pure (alloc_stms, Op $ Inner $ SegOp $ SegHist lvl space ops' ts kbody')
transformExp (WithAcc inputs lam) = do
  lam' <- transformLambda lam
  (input_alloc_stms, inputs') <- unzip <$> traverse expandInput inputs
  pure (mconcat input_alloc_stms, WithAcc inputs' lam')
  where
    -- An accumulator input without an update operator has nothing to
    -- expand.
    expandInput (shape, arrs, Nothing) =
      pure (mempty, (shape, arrs, Nothing))
    -- Otherwise, pull the allocations out of the update operator, expand
    -- the invariant ones, and rewrite the operator to use the expanded
    -- memory.
    expandInput (shape, arrs, Just (op_lam, nes)) = do
      outer_names <- asks $ namesFromList . M.keys
      let -- XXX: no SegLevel is available here, so one is fabricated.
          -- It is not used for anything, because variant (irregular)
          -- allocations inside the update operator are rejected below.
          fake_lvl =
            SegThread (Count $ intConst Int64 0) (Count $ intConst Int64 0) SegNoVirt
          (op_lam', lam_allocs) =
            extractLambdaAllocations (fake_lvl, [0]) outer_names mempty op_lam
          -- An allocation is variant when its size is a variable that
          -- is not in scope outside the operator.
          isVariant (_, Var v, _) = not $ v `nameIn` outer_names
          isVariant _ = False
          (variant_allocs, invariant_allocs) = M.partition isVariant lam_allocs

      case M.elems variant_allocs of
        [] ->
          pure ()
        (_, bad_size, _) : _ ->
          throwError $
            "Cannot handle un-sliceable allocation size: "
              ++ pretty bad_size
              ++ "\nLikely cause: irregular nested operations inside accumulator update operator."

      -- The first (rank of the accumulator shape) operator parameters
      -- serve as the indices into the expanded allocations.
      let idx_names = map paramName $ take (shapeRank shape) $ lambdaParams op_lam
      (alloc_stms, alloc_offsets) <-
        genericExpandedInvariantAllocations
          (const (shape, map le64 idx_names))
          invariant_allocs

      outer_scope <- askScope
      either throwError pure $
        runOffsetM (scopeOf op_lam <> outer_scope) alloc_offsets $ do
          op_lam'' <- offsetMemoryInLambda op_lam'
          pure (alloc_stms, (shape, arrs, Just (op_lam'', nes)))
-- Every other expression is left untouched and needs no extra
-- allocation statements.
transformExp e = pure (mempty, e)

-- | Expand the allocations of a segmented operation whose kernel body is
-- accompanied by operator lambdas (as used for SegHist and for the
-- scan/reduce cases).
--
-- Allocations are extracted from the kernel body and from every operator
-- lambda. They are split into invariant allocations (size is a constant
-- or a name bound outside the kernel) and variant ones. Memory for both
-- kinds is computed by 'allocsForBody', and the body and operators are
-- rewritten to index into it.
--
-- Returns the allocation statements together with the rewritten
-- operators and kernel body.
--
-- Fails (via 'throwError') when:
--
-- * a variant allocation's size is bound neither outside the kernel nor
--   in the segment space / top-level kernel body statements; or
-- * the level is 'SegGroup' and there are any variant allocations.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda GPUMem] ->
  KernelBody GPUMem ->
  ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed lvl space ops kbody = do
  bound_outside <- asks $ namesFromList . M.keys
  let user = (lvl, [le64 $ segFlat space])
      (kbody', kbody_allocs) =
        extractKernelBodyAllocations user bound_outside bound_in_kernel kbody
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations user bound_outside mempty) ops
      -- Variant: the size is a variable that is not in scope outside the
      -- kernel.
      variantAlloc (_, Var v, _) = not $ v `nameIn` bound_outside
      variantAlloc _ = False
      (variant_allocs, invariant_allocs) =
        M.partition variantAlloc $ kbody_allocs <> mconcat ops_allocs
      -- A variant size that is not bound by the segment space or the
      -- top-level kernel body statements either cannot be handled.
      badVariant (_, Var v, _) = not $ v `nameIn` bound_in_kernel
      badVariant _ = False

  case find badVariant $ M.elems variant_allocs of
    Just (_, bad_size, _) ->
      -- Report only the offending size, consistent with the WithAcc case,
      -- rather than the whole (user, size, space) entry.
      throwError $
        "Cannot handle un-sliceable allocation size: "
          ++ pretty bad_size
          ++ "\nLikely cause: irregular nested operations inside parallel constructs."
    Nothing ->
      return ()

  case lvl of
    SegGroup {}
      | not $ null variant_allocs ->
        -- The guard checks the *variant* allocations; the message says so.
        throwError "Cannot handle variant allocations in SegGroup."
    _ ->
      return ()

  allocsForBody variant_allocs invariant_allocs lvl space kbody' $ \alloc_stms kbody'' -> do
    ops'' <- forM ops' $ \op' ->
      localScope (scopeOf op') $ offsetMemoryInLambda op'
    return (alloc_stms, (ops'', kbody''))
  where
    -- Names bound by the segment space or by the top-level statements of
    -- the kernel body.
    bound_in_kernel =
      namesFromList (M.keys $ scopeOfSegSpace space)
        <> boundInKernelBody kbody

-- | The set of names brought into scope by the statements of a kernel
-- body (as given by 'scopeOf' on 'kernelBodyStms').
boundInKernelBody :: KernelBody GPUMem -> Names
boundInKernelBody kbody =
  namesFromList $ M.keys $ scopeOf $ kernelBodyStms kbody

-- | Compute the memory needed for the extracted variant and invariant
-- allocations of a kernel body, then rewrite the body to use that memory.
--
-- The continuation receives the generated allocation statements and the
-- rewritten body. It runs in 'OffsetM' with the computed offsets and with
-- the segment space added to the current scope. Any 'OffsetM' failure is
-- rethrown in 'ExpandM'.
allocsForBody ::
  Extraction ->
  Extraction ->
  SegLevel ->
  SegSpace ->
  KernelBody GPUMem ->
  (Stms GPUMem -> KernelBody GPUMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs lvl space kbody' m = do
  (alloc_offsets, alloc_stms) <-
    memoryRequirements
      lvl
      space
      (kernelBodyStms kbody')
      variant_allocs
      invariant_allocs

  outer_scope <- askScope
  let rewriteBody = do
        kbody'' <- offsetMemoryInKernelBody kbody'
        m alloc_stms kbody''
  either throwError pure $
    runOffsetM (scopeOfSegSpace space <> outer_scope) alloc_offsets rewriteBody

-- | Generate the statements that allocate expanded memory for a kernel.
--
-- First binds @num_threads@ as the product of the level's group count and
-- group size. Then expands the invariant allocations (via
-- 'expandedInvariantAllocations') and the variant allocations (via
-- 'expandedVariantAllocations'), both in the scope of that binding.
--
-- Returns the combined rebase map and all generated statements, with the
-- @num_threads@ binding first.
memoryRequirements ::
  SegLevel ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms GPUMem)
memoryRequirements lvl space kstms variant_allocs invariant_allocs = do
  let num_groups = segNumGroups lvl
      group_size = segGroupSize lvl

  (num_threads, num_threads_stms) <-
    runBinder . letSubExp "num_threads" . BasicOp $
      BinOp (Mul Int64 OverflowUndef) (unCount num_groups) (unCount group_size)

  (invariant_alloc_stms, invariant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedInvariantAllocations num_threads num_groups group_size invariant_allocs

  (variant_alloc_stms, variant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedVariantAllocations num_threads space kstms variant_allocs

  pure
    ( invariant_alloc_offsets <> variant_alloc_offsets,
      mconcat [num_threads_stms, invariant_alloc_stms, variant_alloc_stms]
    )

-- | Identifies where an allocation occurs: the segment level of the
-- enclosing operation, paired with the index expression(s) that identify
-- the thread performing the allocation.
type User = (SegLevel, [TPrimExp Int64 VName])

-- | A description of allocations that have been extracted, keyed by a
-- 'VName' (presumably the name of the allocated memory block; confirm
-- against the extraction code). Each entry records where the allocation
-- occurs ('User'), how much memory it needs, and in which memory space.
type Extraction = M.Map VName (User, SubExp, Space)

-- | Remove the extractable allocations from the statements of a
-- 'KernelBody'. Returns the rewritten body together with a
-- description of the allocations that were taken out.
extractKernelBodyAllocations ::
  User ->
  Names ->
  Names ->
  KernelBody GPUMem ->
  ( KernelBody GPUMem,
    Extraction
  )
extractKernelBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel kernelBodyStms replaceStms
  where
    replaceStms :: Stms GPUMem -> KernelBody GPUMem -> KernelBody GPUMem
    replaceStms new_stms kbody = kbody {kernelBodyStms = new_stms}

-- | Remove the extractable allocations from the statements of an
-- ordinary 'Body'. Returns the rewritten body together with a
-- description of the allocations that were taken out.
extractBodyAllocations ::
  User ->
  Names ->
  Names ->
  Body GPUMem ->
  (Body GPUMem, Extraction)
extractBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel bodyStms replaceStms
  where
    replaceStms :: Stms GPUMem -> Body GPUMem -> Body GPUMem
    replaceStms new_stms b = b {bodyStms = new_stms}

-- | Remove the extractable allocations from the body of a 'Lambda'.
-- The lambda's parameters and return types are left untouched; only
-- its body is replaced.
extractLambdaAllocations ::
  User ->
  Names ->
  Names ->
  Lambda GPUMem ->
  (Lambda GPUMem, Extraction)
extractLambdaAllocations user bound_outside bound_kernel lam =
  let (new_body, extracted) =
        extractBodyAllocations user bound_outside bound_kernel (lambdaBody lam)
   in (lam {lambdaBody = new_body}, extracted)

-- | Shared worker behind the body-level extraction functions. The
-- body type is abstracted by a getter and a setter for its statements.
--
-- Every statement is handed to 'extractStmAllocations'. Statements for
-- which that yields 'Nothing' (the extracted allocations) are dropped;
-- the rest are put back in their original order. Before any statement
-- is examined, the names bound by this body's own statements are added
-- to the set of names considered bound inside the kernel.
extractGenericBodyAllocations ::
  User ->
  Names ->
  Names ->
  (body -> Stms GPUMem) ->
  (Stms GPUMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations user bound_outside bound_kernel get_stms set_stms body =
  (set_stms (stmsFromList kept) body, extraction)
  where
    original_stms = get_stms body
    -- Everything bound anywhere in this body counts as kernel-bound.
    bound_inside = bound_kernel <> boundByStms original_stms
    (kept, extraction) =
      runWriter $
        catMaybes
          <$> traverse
            (extractStmAllocations user bound_outside bound_inside)
            (stmsToList original_stms)

-- | Predicates on memory spaces, used when deciding whether an
-- allocation may be extracted. 'expandable' is 'False' for the
-- @"local"@ space and for scalar spaces, and 'True' otherwise.
-- 'notScalar' is 'False' only for scalar spaces.
expandable, notScalar :: Space -> Bool
expandable sp =
  case sp of
    Space "local" -> False
    ScalarSpace {} -> False
    _ -> True
notScalar sp =
  case sp of
    ScalarSpace {} -> False
    _ -> True

-- | Examine one statement found inside a kernel.
--
-- A statement binding exactly one memory block via @Alloc size space@
-- is dropped (the result is 'Nothing') and recorded in the
-- 'Extraction' output when either:
--
--   * the space is 'expandable', and the size is a constant or a
--     variable in @bound_outside@ or @bound_kernel@; or
--
--   * the size is a variable in @bound_kernel@ and the space is not a
--     scalar space. This case does not consult 'expandable'.
--
-- Every other statement is kept ('Just'). Its expression is first
-- traversed so that allocations in nested bodies, lambdas and 'SegOp'
-- kernel bodies are extracted too. On entering a nested 'SegOp', the
-- user becomes that operation's 'SegLevel', paired with the enclosing
-- IDs extended by the operation's 'segFlat' index.
extractStmAllocations ::
  User ->
  Names ->
  Names ->
  Stm GPUMem ->
  Writer Extraction (Maybe (Stm GPUMem))
extractStmAllocations user bound_outside bound_kernel (Let (Pattern [] [pe]) _ (Op (Alloc alloc_size alloc_space)))
  | shouldExtract = do
      tell $ M.singleton (patElemName pe) (user, alloc_size, alloc_space)
      pure Nothing
  where
    shouldExtract =
      (expandable alloc_space && sizeAvailable alloc_size)
        -- FIXME: the '&& notScalar space' part is a hack because we
        -- don't otherwise hoist the sizes out far enough, and we
        -- promise to be super-duper-careful about not having variant
        -- scalar allocations.
        || (sizeFromKernel alloc_size && notScalar alloc_space)

    -- Constants are always usable; a variable must be bound either
    -- outside the kernel or somewhere inside it.
    sizeAvailable (Var v) = v `nameIn` bound_outside || v `nameIn` bound_kernel
    sizeAvailable Constant {} = True

    -- Is the size a variable bound inside the kernel?
    sizeFromKernel (Var v) = v `nameIn` bound_kernel
    sizeFromKernel Constant {} = False
extractStmAllocations user bound_outside bound_kernel stm =
  withExp <$> mapExpM (expMapper user) (stmExp stm)
  where
    withExp e = Just (stm {stmExp = e})

    -- Traverse nested bodies and operations, extracting as we go.
    expMapper :: User -> Mapper GPUMem GPUMem (Writer Extraction)
    expMapper u =
      identityMapper
        { mapOnBody = \_ b -> onBody u b,
          mapOnOp = onOp u
        }

    onBody :: User -> Body GPUMem -> Writer Extraction (Body GPUMem)
    onBody u = writer . extractBodyAllocations u bound_outside bound_kernel

    -- Only 'SegOp's are descended into; they change the current user.
    onOp :: User -> Op GPUMem -> Writer Extraction (Op GPUMem)
    onOp (_, user_ids) op =
      case op of
        Inner (SegOp segop) ->
          let nested_user =
                (segLevel segop, user_ids ++ [le64 (segFlat (segSpace segop))])
           in Inner . SegOp <$> mapSegOpM (opMapper nested_user) segop
        _ -> pure op

    opMapper :: User -> SegOpMapper SegLevel GPUMem GPUMem (Writer Extraction)
    opMapper u =
      identitySegOpMapper
        { mapOnSegOpLambda = onLambda u,
          mapOnSegOpBody = onKernelBody u
        }

    onKernelBody :: User -> KernelBody GPUMem -> Writer Extraction (KernelBody GPUMem)
    onKernelBody u = writer . extractKernelBodyAllocations u bound_outside bound_kernel

    onLambda :: User -> Lambda GPUMem -> Writer Extraction (Lambda GPUMem)
    onLambda u lam =
      (\new_body -> lam {lambdaBody = new_body}) <$> onBody u (lambdaBody lam)

-- | Replace invariant allocations with single, larger allocations.
--
-- For each extracted allocation of @size@ per user, a new allocation
-- is bound to the same memory name. Its size is @size@ multiplied by
-- every dimension of the users' shape, as reported by the supplied
-- function. The returned 'RebaseMap' maps each memory block to a
-- function that takes the old array shape (the element type is
-- ignored) and builds an index function selecting the current user's
-- private slice.
--
-- The layout depends on the user's level:
--
--   * For thread-level users, the user dimensions are the innermost
--     (last) dimensions of the row-major layout.
--
--   * For group-level users, the user dimensions are the outermost.
genericExpandedInvariantAllocations ::
  (User -> (Shape, [TPrimExp Int64 VName])) -> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations getNumUsers invariant_allocs = do
  (rebase_maps, alloc_stms) <-
    runBinder $ traverse expandOne $ M.toList invariant_allocs
  pure (alloc_stms, mconcat rebase_maps)
  where
    -- Allocate one block large enough for every user, and describe how
    -- each user's part of it is addressed.
    expandOne (mem, (user, per_user_size, space)) = do
      let users_shape = fst (getNumUsers user)
      size_exp <-
        toExp $ product $ pe64 per_user_size : map pe64 (shapeDims users_shape)
      total_size <- letExp "total_size" size_exp
      letBind (Pattern [] [PatElem mem (MemMem space)]) $
        Op (Alloc (Var total_size) space)
      pure $ M.singleton mem (newBase user)

    -- A slice covering an entire dimension.
    wholeDim d = DimSlice 0 d 1

    newBase user@(lvl, _) (old_shape, _) =
      case lvl of
        SegThread {} ->
          -- Row-major with the user dimensions last (innermost),
          -- permuted so that they are indexed first.
          let rank = length old_shape
              perm = [rank .. rank + shapeRank users_shape - 1] ++ [0 .. rank - 1]
              permuted = IxFun.permute (IxFun.iota (old_shape ++ users_dims)) perm
           in IxFun.slice permuted select_own
        SegGroup {} ->
          -- Row-major with the user dimensions first (outermost).
          IxFun.slice (IxFun.iota (users_dims ++ old_shape)) select_own
      where
        (users_shape, user_ids) = getNumUsers user
        users_dims = map pe64 (shapeDims users_shape)
        -- Fix the user dimensions to this user's IDs, and keep every
        -- dimension of the original array whole.
        select_own = map DimFix user_ids ++ map wholeDim old_shape

-- | Expand invariant allocations for a kernel, given its total thread
-- count, number of groups, and group size. Each user gets its own
-- slot according to its shape:
--
--   * A thread-level user with one ID gets one slot per thread.
--
--   * A thread-level user with two IDs (group and local thread) gets a
--     @num_groups * group_size@ grid of slots.
--
--   * A group-level user with one ID gets one slot per group.
--
-- Any other user shape is reported with 'error'.
expandedInvariantAllocations ::
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedInvariantAllocations num_threads (Count num_groups) (Count group_size) =
  genericExpandedInvariantAllocations usersOf
  where
    usersOf user@(lvl, ids) =
      case (lvl, ids) of
        (SegThread {}, [_]) -> (Shape [num_threads], ids)
        (SegThread {}, [_, _]) -> (Shape [num_groups, group_size], ids)
        (SegGroup {}, [_]) -> (Shape [num_groups], ids)
        _ -> error $ "getNumUsers: unhandled " ++ show user

-- | Expand the variant allocations extracted from a kernel body, i.e. the
-- allocations whose sizes (presumably) may differ between threads.
--
-- The distinct sizes in the 'Extraction' are passed to 'sliceKernelSizes',
-- whose resulting statements are given explicit allocations, simplified and
-- recursively transformed (so allocations inside the newly produced kernels
-- are expanded too).  One 'Alloc' of the computed total size is then emitted
-- per memory block.  Each block also gets a 'RebaseMap' entry that indexes it
-- with an extra dimension of size @num_threads@, fixed at the thread's flat
-- index.
--
-- Returns the new statements and the combined 'RebaseMap'.  An empty
-- 'Extraction' yields no statements and an empty map.
expandedVariantAllocations ::
  SubExp ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedVariantAllocations _ _ _ variant_allocs
  | null variant_allocs = return (mempty, mempty)
expandedVariantAllocations num_threads kspace kstms variant_allocs = do
  let sizes_to_blocks = removeCommonSizes variant_allocs
      variant_sizes = map fst sizes_to_blocks

  (slice_stms, offsets, size_sums) <-
    sliceKernelSizes num_threads variant_sizes kspace kstms
  -- Note the recursive call to expand allocations inside the newly
  -- produced kernels.
  (_, slice_stms_tmp) <-
    simplifyStms =<< explicitAllocationsInStms slice_stms
  slice_stms' <- transformStms slice_stms_tmp

  let variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
      variant_allocs' =
        concat $
          zipWith
            memInfo
            (map snd sizes_to_blocks)
            (zip offsets size_sums)
      -- Every memory block sharing a size gets that size's pair of
      -- (per-thread value, total size).
      memInfo blocks (offset, total_size) =
        [(mem, (Var offset, Var total_size, space)) | (mem, space) <- blocks]

      -- We expand the variant allocations by adding an inner dimension
      -- equal to the sum of the sizes required by different threads.
      (alloc_stms, rebases) = unzip $ map expand variant_allocs'

  return (slice_stms' <> stmsFromList alloc_stms, mconcat rebases)
  where
    -- Build the allocation statement for one memory block together with
    -- the rebase function for arrays stored in it.
    -- NOTE(review): the value taken from 'offsets' is used by 'newBase' as
    -- a per-thread size in bytes (it is divided by the element byte size);
    -- confirm the naming against 'sliceKernelSizes'.
    expand :: (VName, (SubExp, SubExp, Space)) -> (Stm GPUMem, RebaseMap)
    expand (mem, (offset, total_size, space)) =
      ( Let (Pattern [] [PatElem mem $ MemMem space]) (defAux ()) $
          Op $ Alloc total_size space,
        M.singleton mem $ newBase offset
      )

    num_threads' = pe64 num_threads
    gtid = le64 $ segFlat kspace

    -- For the variant allocations, we add an inner dimension,
    -- which is then offset by a thread-specific amount.
    newBase :: SubExp -> ([TPrimExp Int64 VName], PrimType) -> IxFun
    newBase size_per_thread (old_shape, pt) =
      let elems_per_thread = pe64 size_per_thread `quot` primByteSize pt
          root_ixfun = IxFun.iota [elems_per_thread, num_threads']
          offset_ixfun =
            IxFun.slice
              root_ixfun
              [ DimSlice 0 num_threads' 1,
                DimFix gtid
              ]
          -- A rank-1 base is reshaped with a coercion; any other rank
          -- uses fresh dimensions.
          shapechange = case old_shape of
            [_] -> map DimCoercion old_shape
            _ -> map DimNew old_shape
       in IxFun.reshape offset_ixfun shapechange

-- | A map from memory block names to new index function bases.
--
-- Each entry is a function rather than a fixed index function.  Given the
-- shape of an array's existing index-function base and the array's element
-- type, it produces the index function that should replace that base.  The
-- replacement is applied with 'IxFun.rebase'.
type RebaseMap = M.Map VName (([TPrimExp Int64 VName], PrimType) -> IxFun)

-- | The monad in which memory annotations are rewritten.
--
-- The outer reader layer holds the current type scope (giving 'HasScope' and
-- 'LocalScope').  The inner reader layer holds the read-only 'RebaseMap'.
-- Failures are reported as 'String' errors through 'MonadError'.
newtype OffsetM a
  = OffsetM (ReaderT (Scope GPUMem) (ReaderT RebaseMap (Either String)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadError String,
      HasScope GPUMem,
      LocalScope GPUMem
    )

-- | Execute an 'OffsetM' action under the given type scope and rebase map.
runOffsetM :: Scope GPUMem -> RebaseMap -> OffsetM a -> Either String a
runOffsetM scope rebases (OffsetM action) =
  flip runReaderT rebases $ runReaderT action scope

-- | Read the 'RebaseMap' held by the inner reader layer of 'OffsetM'.
askRebaseMap :: OffsetM RebaseMap
askRebaseMap = OffsetM (lift ask)

-- | Find the rebase function registered for a memory block and apply it to
-- the given base shape and element type.  Yields 'Nothing' when the block
-- has no entry in the 'RebaseMap'.
lookupNewBase :: VName -> ([TPrimExp Int64 VName], PrimType) -> OffsetM (Maybe IxFun)
lookupNewBase name x =
  fmap ($ x) . M.lookup name <$> askRebaseMap

-- | Rewrite the memory annotations of every statement in a kernel body.
--
-- The scope is threaded from left to right: each statement is processed in
-- the scope extended by the statements before it.  The final accumulated
-- scope is discarded.
offsetMemoryInKernelBody :: KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
offsetMemoryInKernelBody kbody = do
  outer_scope <- askScope
  (_, new_stms) <-
    mapAccumLM step outer_scope $ stmsToList $ kernelBodyStms kbody
  return kbody {kernelBodyStms = stmsFromList new_stms}
  where
    step scope stm = localScope scope $ offsetMemoryInStm stm

-- | Rewrite the memory annotations of every statement in a body.  The body
-- decoration and result are left as they are.
--
-- Each statement is processed in the scope extended by the statements
-- before it.
offsetMemoryInBody :: Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody (Body dec stms res) = do
  outer_scope <- askScope
  (_, new_stms) <- mapAccumLM step outer_scope $ stmsToList stms
  return $ Body dec (stmsFromList new_stms) res
  where
    step scope stm = localScope scope $ offsetMemoryInStm stm

-- | Rewrite the memory annotations of a single statement.
--
-- The steps are:
--
-- 1. Rewrite the pattern with 'offsetMemoryInPattern'.
-- 2. Process the expression with the new pattern in scope.
-- 3. For each value element, try to take the index function reported by
--    'expReturns'.  Where that index function contains existential
--    components, the (already rebased) pattern annotation is kept.
--
-- Returns the current scope extended with the rewritten statement, together
-- with that statement.
offsetMemoryInStm :: Stm GPUMem -> OffsetM (Scope GPUMem, Stm GPUMem)
offsetMemoryInStm (Let pat dec e) = do
  pat' <- offsetMemoryInPattern pat
  e' <- localScope (scopeOfPattern pat') $ offsetMemoryInExp e
  scope <- askScope
  -- Try to recompute the index function.  Fall back to creating rebase
  -- operations with the RebaseMap.
  rts <- runReaderT (expReturns e') scope
  let new_vals = zipWith pick (patternValueElements pat') rts
      stm = Let (Pattern (patternContextElements pat') new_vals) dec e'
  return (scopeOf stm <> scope, stm)
  where
    -- Replace an array element's memory binding with the one reported by
    -- 'expReturns', but only when its index function is fully known.
    pick ::
      PatElemT (MemInfo SubExp NoUniqueness MemBind) ->
      ExpReturns ->
      PatElemT (MemInfo SubExp NoUniqueness MemBind)
    pick
      (PatElem name (MemArray pt s u _))
      (MemArray _ _ _ (Just (ReturnsInBlock m extixfun)))
        | Just ixfun <- instantiateIxFun extixfun =
          PatElem name $ MemArray pt s u $ ArrayIn m ixfun
    pick p _ = p

    -- Succeeds only if the index function mentions no existentials.
    instantiateIxFun :: ExtIxFun -> Maybe IxFun
    instantiateIxFun = traverse (traverse inst)
      where
        inst Ext {} = Nothing
        inst (Free x) = Just x

-- | Rebase the memory annotations of a pattern's value elements.
--
-- Context elements are left unchanged, but they are checked first.  A
-- context element of memory type in an 'expandable' space is an existential
-- memory block, which this pass cannot handle, so it raises an error.
offsetMemoryInPattern :: Pattern GPUMem -> OffsetM (Pattern GPUMem)
offsetMemoryInPattern (Pattern ctx vals) = do
  mapM_ checkCtx ctx
  vals' <- traverse rebaseVal vals
  return $ Pattern ctx vals'
  where
    rebaseVal pe = do
      dec' <- offsetMemoryInMemBound (patElemDec pe)
      return pe {patElemDec = dec'}

    checkCtx pe =
      case patElemType pe of
        Mem space
          | expandable space ->
            throwError $
              unwords
                [ "Cannot deal with existential memory block",
                  pretty (patElemName pe),
                  "when expanding inside kernels."
                ]
        _ -> return ()

-- | Rebase the memory annotation of a parameter.  The annotation changes
-- only if its memory block has an entry in the 'RebaseMap'.
offsetMemoryInParam :: Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam fparam =
  setDec <$> offsetMemoryInMemBound (paramDec fparam)
  where
    setDec dec' = fparam {paramDec = dec'}

-- | Rebase the index function of an array annotation stored in a memory
-- block that has an entry in the 'RebaseMap'.
--
-- Non-array annotations, and arrays in blocks without an entry, are
-- returned unchanged.
offsetMemoryInMemBound :: MemBound u -> OffsetM (MemBound u)
offsetMemoryInMemBound summary@(MemArray pt shape u (ArrayIn mem ixfun)) =
  maybe summary rebased <$> lookupNewBase mem (IxFun.base ixfun, pt)
  where
    rebased new_base =
      MemArray pt shape u $ ArrayIn mem $ IxFun.rebase new_base ixfun
offsetMemoryInMemBound summary = return summary

-- | Rebase the index function of a branch return whose memory block may
-- have been expanded.
--
-- Only array returns placed in a known block ('ReturnsInBlock') whose
-- index function is static ('isStaticIxFun') are candidates.  For those,
-- 'lookupNewBase' is asked for a replacement base; if it yields one, the
-- return's existential index function is rebased onto it, otherwise the
-- return is left untouched.  All other returns pass through unchanged.
offsetMemoryInBodyReturns :: BodyReturns -> OffsetM BodyReturns
offsetMemoryInBodyReturns br@(MemArray pt shape u (ReturnsInBlock mem ixfun))
  | Just ixfun' <- isStaticIxFun ixfun =
      maybe br rebase <$> lookupNewBase mem (IxFun.base ixfun', pt)
  where
    -- The new base is an ordinary (non-existential) index function, so its
    -- free variables are lifted with 'Free' before combining it with the
    -- existential 'ixfun'.
    rebase new_base =
      MemArray pt shape u . ReturnsInBlock mem $
        IxFun.rebase (fmap (fmap Free) new_base) ixfun
offsetMemoryInBodyReturns br = return br

-- | Apply the memory-offsetting transformation to the body of a lambda.
--
-- The body is processed under 'inScopeOf' the lambda, so whatever scope
-- the lambda introduces is visible while rewriting it; parameters and
-- return types are kept as they are.
offsetMemoryInLambda :: Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda lam = inScopeOf lam $ do
  body <- offsetMemoryInBody $ lambdaBody lam
  return $ lam {lambdaBody = body}

-- | Apply the memory-offsetting transformation to an expression.
--
-- * Loops: context and value parameters are rewritten with
--   'offsetMemoryInParam', and the loop body is rewritten in a scope that
--   contains the rewritten parameters and the loop form.  Initial values
--   are kept as they are.
--
-- * Everything else: nested bodies (in their own scope), branch return
--   types ('offsetMemoryInBodyReturns') and the kernel bodies and lambdas
--   of 'SegOp's are rewritten; any other operation is returned unchanged.
offsetMemoryInExp :: Exp GPUMem -> OffsetM (Exp GPUMem)
offsetMemoryInExp (DoLoop ctx val form body) = do
  let (ctxparams, ctxinit) = unzip ctx
      (valparams, valinit) = unzip val
  ctxparams' <- mapM offsetMemoryInParam ctxparams
  valparams' <- mapM offsetMemoryInParam valparams
  -- The body must see the *rewritten* parameters, not the originals.
  let loop_scope =
        scopeOfFParams ctxparams'
          <> scopeOfFParams valparams'
          <> scopeOf form
  body' <- localScope loop_scope $ offsetMemoryInBody body
  return $ DoLoop (zip ctxparams' ctxinit) (zip valparams' valinit) form body'
offsetMemoryInExp e = mapExpM recurse e
  where
    recurse =
      identityMapper
        { mapOnBody = \bscope -> localScope bscope . offsetMemoryInBody,
          mapOnBranchType = offsetMemoryInBodyReturns,
          mapOnOp = onOp
        }

    -- Only segmented operations contain bodies that need rewriting; their
    -- segment space is brought into scope first.
    onOp (Inner (SegOp op)) =
      Inner . SegOp
        <$> localScope (scopeOfSegSpace (segSpace op)) (mapSegOpM segOpMapper op)
      where
        segOpMapper =
          identitySegOpMapper
            { mapOnSegOpBody = offsetMemoryInKernelBody,
              mapOnSegOpLambda = offsetMemoryInLambda
            }
    onOp op = return op

---- Slicing allocation sizes out of a kernel.

-- | Strip memory annotations from 'GPUMem' statements, yielding plain
-- 'GPU.GPU' statements.
--
-- Top-level 'Alloc' statements are discarded.  The conversion fails with a
-- 'Left' error message when it meets:
--
-- * an 'Alloc' nested inside a body;
-- * an 'OtherOp';
-- * a pattern element, return type, or function/loop parameter whose
--   annotation 'unMem' cannot turn into a plain type.
--
-- Lambda parameters that 'unMem' rejects are silently dropped instead of
-- being reported (see @unParams@).
unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU.GPU)
unAllocGPUStms = unAllocStms False
  where
    -- Turn a 'Maybe' into an 'Either', using the given message for 'Nothing'.
    orFail :: String -> Maybe a -> Either String a
    orFail msg = maybe (Left msg) Right

    unAllocBody (Body dec stms res) =
      Body dec <$> unAllocStms True stms <*> pure res

    unAllocKernelBody (KernelBody dec stms res) =
      KernelBody dec <$> unAllocStms True stms <*> pure res

    -- The flag says whether we are inside a nested body, where an
    -- allocation cannot simply be dropped.
    unAllocStms nested =
      fmap (stmsFromList . catMaybes) . traverse (unAllocStm nested) . stmsToList

    unAllocStm nested stm@(Let _ _ (Op Alloc {}))
      | nested = Left $ "Cannot handle nested allocation: " ++ pretty stm
      | otherwise = Right Nothing
    unAllocStm _ (Let pat dec e) =
      Just <$> (Let <$> unAllocPattern pat <*> pure dec <*> mapExpM unAlloc' e)

    unAllocLambda (Lambda params body ret) =
      Lambda (unParams params) <$> unAllocBody body <*> pure ret

    -- Parameters that 'unMem' rejects are dropped rather than reported.
    unParams = mapMaybe $ traverse unMem

    unAllocPattern pat@(Pattern ctx val) =
      Pattern <$> unMemElems ctx <*> unMemElems val
      where
        unMemElems =
          orFail ("Cannot handle memory in pattern " ++ pretty pat)
            . mapM (rephrasePatElem unMem)

    unAllocOp Alloc {} = Left "unAllocOp: unhandled Alloc"
    unAllocOp (Inner OtherOp {}) = Left "unAllocOp: unhandled OtherOp"
    unAllocOp (Inner (SizeOp op)) = Right $ SizeOp op
    unAllocOp (Inner (SegOp op)) = SegOp <$> mapSegOpM mapper op
      where
        mapper =
          identitySegOpMapper
            { mapOnSegOpLambda = unAllocLambda,
              mapOnSegOpBody = unAllocKernelBody
            }

    unParam p =
      orFail ("Cannot handle memory-typed parameter '" ++ pretty p ++ "'") $
        traverse unMem p

    unT t =
      orFail ("Cannot handle memory type '" ++ pretty t ++ "'") $ unMem t

    unAlloc' =
      Mapper
        { mapOnBody = const unAllocBody,
          mapOnRetType = unT,
          mapOnBranchType = unT,
          mapOnFParam = unParam,
          mapOnLParam = unParam,
          mapOnOp = unAllocOp,
          mapOnSubExp = Right,
          mapOnVName = Right
        }

unMem :: MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
unMem :: forall d u ret. MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
unMem (MemPrim PrimType
pt) = TypeBase (ShapeBase d) u -> Maybe (TypeBase (ShapeBase d) u)
forall a. a -> Maybe a
Just (TypeBase (ShapeBase d) u -> Maybe (TypeBase (ShapeBase d) u))
-> TypeBase (ShapeBase d) u -> Maybe (TypeBase (ShapeBase d) u)
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase (ShapeBase d) u
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
pt
unMem (MemArray PrimType
pt ShapeBase d
shape u
u ret
_) = TypeBase (ShapeBase d) u -> Maybe (TypeBase (ShapeBase d) u)
forall a. a -> Maybe a
Just (TypeBase (ShapeBase d) u -> Maybe (TypeBase (ShapeBase d) u))
-> TypeBase (ShapeBase d) u -> Maybe (TypeBase (ShapeBase d) u)
forall a b. (a -> b) -> a -> b
$ PrimType -> ShapeBase d -> u -> TypeBase (ShapeBase d) u
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
pt ShapeBase d
shape u
u
unMem (MemAcc VName
acc ShapeBase SubExp
ispace [TypeBase (ShapeBase SubExp) NoUniqueness]
ts u
u) = TypeBase (ShapeBase d) u -> Maybe (TypeBase (ShapeBase d) u)
forall a. a -> Maybe a
Just (TypeBase (ShapeBase d) u -> Maybe (TypeBase (ShapeBase d) u))
-> TypeBase (ShapeBase d) u -> Maybe (TypeBase (ShapeBase d) u)
forall a b. (a -> b) -> a -> b
$ VName
-> ShapeBase SubExp
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> u
-> TypeBase (ShapeBase d) u
forall shape u.
VName
-> ShapeBase SubExp
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> u
-> TypeBase shape u
Acc VName
acc ShapeBase SubExp
ispace [TypeBase (ShapeBase SubExp) NoUniqueness]
ts u
u
unMem MemMem {} = Maybe (TypeBase (ShapeBase d) u)
forall a. Maybe a
Nothing

unAllocScope :: Scope GPUMem -> Scope GPU.GPU
unAllocScope :: Scope GPUMem -> Scope GPU
unAllocScope = (NameInfo GPUMem -> Maybe (NameInfo GPU))
-> Scope GPUMem -> Scope GPU
forall a b k. (a -> Maybe b) -> Map k a -> Map k b
M.mapMaybe NameInfo GPUMem -> Maybe (NameInfo GPU)
forall {rep} {d} {u} {d} {u} {d} {u} {rep} {ret} {ret} {ret}.
(LParamInfo rep ~ TypeBase (ShapeBase d) u,
 FParamInfo rep ~ TypeBase (ShapeBase d) u,
 LetDec rep ~ TypeBase (ShapeBase d) u,
 LetDec rep ~ MemInfo d u ret, FParamInfo rep ~ MemInfo d u ret,
 LParamInfo rep ~ MemInfo d u ret) =>
NameInfo rep -> Maybe (NameInfo rep)
unInfo
  where
    unInfo :: NameInfo rep -> Maybe (NameInfo rep)
unInfo (LetName LetDec rep
dec) = TypeBase (ShapeBase d) u -> NameInfo rep
forall rep. LetDec rep -> NameInfo rep
LetName (TypeBase (ShapeBase d) u -> NameInfo rep)
-> Maybe (TypeBase (ShapeBase d) u) -> Maybe (NameInfo rep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
forall d u ret. MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
unMem LetDec rep
MemInfo d u ret
dec
    unInfo (FParamName FParamInfo rep
dec) = TypeBase (ShapeBase d) u -> NameInfo rep
forall rep. FParamInfo rep -> NameInfo rep
FParamName (TypeBase (ShapeBase d) u -> NameInfo rep)
-> Maybe (TypeBase (ShapeBase d) u) -> Maybe (NameInfo rep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
forall d u ret. MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
unMem FParamInfo rep
MemInfo d u ret
dec
    unInfo (LParamName LParamInfo rep
dec) = TypeBase (ShapeBase d) u -> NameInfo rep
forall rep. LParamInfo rep -> NameInfo rep
LParamName (TypeBase (ShapeBase d) u -> NameInfo rep)
-> Maybe (TypeBase (ShapeBase d) u) -> Maybe (NameInfo rep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
forall d u ret. MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
unMem LParamInfo rep
MemInfo d u ret
dec
    unInfo (IndexName IntType
it) = NameInfo rep -> Maybe (NameInfo rep)
forall a. a -> Maybe a
Just (NameInfo rep -> Maybe (NameInfo rep))
-> NameInfo rep -> Maybe (NameInfo rep)
forall a b. (a -> b) -> a -> b
$ IntType -> NameInfo rep
forall rep. IntType -> NameInfo rep
IndexName IntType
it

removeCommonSizes ::
  Extraction ->
  [(SubExp, [(VName, Space)])]
removeCommonSizes :: Extraction -> [(SubExp, [(VName, Space)])]
removeCommonSizes = Map SubExp [(VName, Space)] -> [(SubExp, [(VName, Space)])]
forall k a. Map k a -> [(k, a)]
M.toList (Map SubExp [(VName, Space)] -> [(SubExp, [(VName, Space)])])
-> (Extraction -> Map SubExp [(VName, Space)])
-> Extraction
-> [(SubExp, [(VName, Space)])]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Map SubExp [(VName, Space)]
 -> (VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))
 -> Map SubExp [(VName, Space)])
-> Map SubExp [(VName, Space)]
-> [(VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))]
-> Map SubExp [(VName, Space)]
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' Map SubExp [(VName, Space)]
-> (VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))
-> Map SubExp [(VName, Space)]
forall {k} {a} {b} {a}.
Ord k =>
Map k [(a, b)] -> (a, (a, k, b)) -> Map k [(a, b)]
comb Map SubExp [(VName, Space)]
forall a. Monoid a => a
mempty ([(VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))]
 -> Map SubExp [(VName, Space)])
-> (Extraction
    -> [(VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))])
-> Extraction
-> Map SubExp [(VName, Space)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Extraction
-> [(VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))]
forall k a. Map k a -> [(k, a)]
M.toList
  where
    comb :: Map k [(a, b)] -> (a, (a, k, b)) -> Map k [(a, b)]
comb Map k [(a, b)]
m (a
mem, (a
_, k
size, b
space)) = ([(a, b)] -> [(a, b)] -> [(a, b)])
-> k -> [(a, b)] -> Map k [(a, b)] -> Map k [(a, b)]
forall k a. Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
M.insertWith [(a, b)] -> [(a, b)] -> [(a, b)]
forall a. [a] -> [a] -> [a]
(++) k
size [(a
mem, b
space)] Map k [(a, b)]
m

sliceKernelSizes ::
  SubExp ->
  [SubExp] ->
  SegSpace ->
  Stms GPUMem ->
  ExpandM (Stms GPU.GPU, [VName], [VName])
sliceKernelSizes :: SubExp
-> Result
-> SegSpace
-> Stms GPUMem
-> ExpandM (Stms GPU, [VName], [VName])
sliceKernelSizes SubExp
num_threads Result
sizes SegSpace
space Stms GPUMem
kstms = do
  Stms GPU
kstms' <- ([Char]
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either [Char])) (Stms GPU))
-> (Stms GPU
    -> ReaderT
         (Scope GPUMem) (StateT VNameSource (Either [Char])) (Stms GPU))
-> Either [Char] (Stms GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either [Char])) (Stms GPU)
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either [Char]
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either [Char])) (Stms GPU)
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError Stms GPU
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either [Char])) (Stms GPU)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either [Char] (Stms GPU)
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either [Char])) (Stms GPU))
-> Either [Char] (Stms GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either [Char])) (Stms GPU)
forall a b. (a -> b) -> a -> b
$ Stms GPUMem -> Either [Char] (Stms GPU)
unAllocGPUStms Stms GPUMem
kstms
  let num_sizes :: Int
num_sizes = Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
sizes
      i64s :: [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s = Int
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. Int -> a -> [a]
replicate Int
num_sizes (TypeBase (ShapeBase SubExp) NoUniqueness
 -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64

  Scope GPU
kernels_scope <- (Scope GPUMem -> Scope GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either [Char])) (Scope GPU)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Scope GPUMem -> Scope GPU
unAllocScope

  (Lambda GPU
max_lam, Stms GPU
_) <- (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   (Lambda GPU)
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either [Char]))
      (Lambda GPU, Stms GPU))
-> Scope GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (Lambda GPU, Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
  (Lambda GPU)
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (Lambda GPU, Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BinderT rep m a -> Scope rep -> m (a, Stms rep)
runBinderT Scope GPU
kernels_scope (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   (Lambda GPU)
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either [Char]))
      (Lambda GPU, Stms GPU))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (Lambda GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs <- Int
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ [Char]
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"x" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys <- Int
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ [Char]
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"y" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms GPU
stms) <- Scope GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result, Stms GPU)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result, Stms GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU)
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU
forall a b. (a -> b) -> a -> b
$ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   (Result, Stms GPU)
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      (Result, Stms GPU))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result, Stms GPU)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result, Stms GPU)
forall a b. (a -> b) -> a -> b
$
      BinderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
  Result
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result,
      Stms
        (Rep
           (BinderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char]))))))
forall (m :: * -> *) a. MonadBinder m => m a -> m (a, Stms (Rep m))
collectStms (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   Result
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      (Result,
       Stms
         (Rep
            (BinderT
               GPU
               (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     Result
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result,
      Stms
        (Rep
           (BinderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char]))))))
forall a b. (a -> b) -> a -> b
$
        [(Param (TypeBase (ShapeBase SubExp) NoUniqueness),
  Param (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> ((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))
    -> BinderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
         SubExp)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
   Param (TypeBase (ShapeBase SubExp) NoUniqueness))
  -> BinderT
       GPU
       (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
       SubExp)
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      Result)
-> ((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))
    -> BinderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
         SubExp)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     Result
forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase (ShapeBase SubExp) NoUniqueness)
x, Param (TypeBase (ShapeBase SubExp) NoUniqueness)
y) ->
          [Char]
-> Exp
     (Rep
        (BinderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"z" (Exp
   (Rep
      (BinderT
         GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      SubExp)
-> Exp
     (Rep
        (BinderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT GPU
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT GPU) -> BasicOp -> ExpT GPU
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMax IntType
Int64) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
x) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
y)
    Lambda GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
forall (m :: * -> *) a. Monad m => a -> m a
return (Lambda GPU
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      (Lambda GPU))
-> Lambda GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
forall a b. (a -> b) -> a -> b
$ [LParam GPU]
-> BodyT GPU
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda GPU
forall rep.
[LParam rep]
-> BodyT rep
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> LambdaT rep
Lambda ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (Stms GPU -> Result -> BodyT GPU
forall rep. Bindable rep => Stms rep -> Result -> Body rep
mkBody Stms GPU
stms Result
zs) [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s

  Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam <- VName
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> Param (TypeBase (ShapeBase SubExp) NoUniqueness)
forall dec. VName -> dec -> Param dec
Param (VName
 -> TypeBase (ShapeBase SubExp) NoUniqueness
 -> Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either [Char])) VName
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (TypeBase (ShapeBase SubExp) NoUniqueness
      -> Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char]
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either [Char])) VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"flat_gtid" ReaderT
  (Scope GPUMem)
  (StateT VNameSource (Either [Char]))
  (TypeBase (ShapeBase SubExp) NoUniqueness
   -> Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (TypeBase (ShapeBase SubExp) NoUniqueness)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> TypeBase (ShapeBase SubExp) NoUniqueness
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (TypeBase (ShapeBase SubExp) NoUniqueness)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (IntType -> PrimType
IntType IntType
Int64))

  (Lambda GPU
size_lam', Stms GPU
_) <- (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   (Lambda GPU)
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either [Char]))
      (Lambda GPU, Stms GPU))
-> Scope GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (Lambda GPU, Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
  (Lambda GPU)
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (Lambda GPU, Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BinderT rep m a -> Scope rep -> m (a, Stms rep)
runBinderT Scope GPU
kernels_scope (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   (Lambda GPU)
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either [Char]))
      (Lambda GPU, Stms GPU))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (Lambda GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
params <- Int
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ [Char]
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"x" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms GPU
stms) <- Scope GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result, Stms GPU)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result, Stms GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope
      ( [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
params
          Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams [Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam]
      )
      (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   (Result, Stms GPU)
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      (Result, Stms GPU))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result, Stms GPU)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result, Stms GPU)
forall a b. (a -> b) -> a -> b
$ BinderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
  Result
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result,
      Stms
        (Rep
           (BinderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char]))))))
forall (m :: * -> *) a. MonadBinder m => m a -> m (a, Stms (Rep m))
collectStms (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   Result
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      (Result,
       Stms
         (Rep
            (BinderT
               GPU
               (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     Result
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Result,
      Stms
        (Rep
           (BinderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char]))))))
forall a b. (a -> b) -> a -> b
$ do
        -- Even though this SegRed is one-dimensional, we need to
        -- provide indexes corresponding to the original potentially
        -- multi-dimensional construct.
        let ([VName]
kspace_gtids, Result
kspace_dims) = [(VName, SubExp)] -> ([VName], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], Result))
-> [(VName, SubExp)] -> ([VName], Result)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
            new_inds :: [TPrimExp Int64 VName]
new_inds =
              [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
                ((SubExp -> TPrimExp Int64 VName)
-> Result -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 Result
kspace_dims)
                (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam)
        ([VName]
 -> ExpT GPU
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      ())
-> [[VName]]
-> [ExpT GPU]
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ [VName]
-> ExpT GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames ((VName -> [VName]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> [a] -> [b]
map VName -> [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
kspace_gtids) ([ExpT GPU]
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      ())
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [ExpT GPU]
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (TPrimExp Int64 VName
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      (ExpT GPU))
-> [TPrimExp Int64 VName]
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [ExpT GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TPrimExp Int64 VName
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (ExpT GPU)
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Rep m))
toExp [TPrimExp Int64 VName]
new_inds

        (Stm GPU
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      ())
-> Stms GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *). MonadBinder m => Stm (Rep m) -> m ()
addStm Stms GPU
kstms'
        Result
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     Result
forall (m :: * -> *) a. Monad m => a -> m a
return Result
sizes

    Scope GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope GPU
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   (Lambda GPU)
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      (Lambda GPU))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
forall a b. (a -> b) -> a -> b
$
      Lambda GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Lambda GPU)
forall (m :: * -> *).
(HasScope GPU m, MonadFreshNames m) =>
Lambda GPU -> m (Lambda GPU)
GPU.simplifyLambda ([LParam GPU]
-> BodyT GPU
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda GPU
forall rep.
[LParam rep]
-> BodyT rep
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> LambdaT rep
Lambda [Param (TypeBase (ShapeBase SubExp) NoUniqueness)
LParam GPU
flat_gtid_lparam] (BodyDec GPU -> Stms GPU -> Result -> BodyT GPU
forall rep. BodyDec rep -> Stms rep -> Result -> BodyT rep
Body () Stms GPU
stms Result
zs) [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s)

  (([VName]
maxes_per_thread, [VName]
size_sums), Stms GPU
slice_stms) <- (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   ([VName], [VName])
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either [Char]))
      (([VName], [VName]), Stms GPU))
-> Scope GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     ([VName], [VName])
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (([VName], [VName]), Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
  ([VName], [VName])
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (([VName], [VName]), Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BinderT rep m a -> Scope rep -> m (a, Stms rep)
runBinderT Scope GPU
kernels_scope (BinderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
   ([VName], [VName])
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either [Char]))
      (([VName], [VName]), Stms GPU))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     ([VName], [VName])
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either [Char]))
     (([VName], [VName]), Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
pat <-
      [Ident]
-> [Ident] -> PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
basicPattern []
        ([Ident] -> PatternT (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [Ident]
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (PatternT (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     Ident
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [Ident]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM
          Int
num_sizes
          ([Char]
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     Ident
forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> TypeBase (ShapeBase SubExp) NoUniqueness -> m Ident
newIdent [Char]
"max_per_thread" (TypeBase (ShapeBase SubExp) NoUniqueness
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      Ident)
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     Ident
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)

    SubExp
w <-
      [Char]
-> Exp
     (Rep
        (BinderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     SubExp
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Rep m) -> m SubExp
letSubExp [Char]
"size_slice_w"
        (ExpT GPU
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      SubExp)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (ExpT GPU)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> SubExp
-> Result
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Exp
        (Rep
           (BinderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char]))))))
forall (m :: * -> *).
MonadBinder m =>
BinOp -> SubExp -> Result -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (SegSpace -> Result
segSpaceDims SegSpace
space)

    VName
thread_space_iota <-
      [Char]
-> Exp
     (Rep
        (BinderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     VName
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"thread_space_iota" (Exp
   (Rep
      (BinderT
         GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      VName)
-> Exp
     (Rep
        (BinderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp -> ExpT GPU
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT GPU) -> BasicOp -> ExpT GPU
forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
    let red_op :: SegBinOp GPU
red_op =
          Commutativity
-> Lambda GPU -> Result -> ShapeBase SubExp -> SegBinOp GPU
forall rep.
Commutativity
-> Lambda rep -> Result -> ShapeBase SubExp -> SegBinOp rep
SegBinOp
            Commutativity
Commutative
            Lambda GPU
max_lam
            (Int -> SubExp -> Result
forall a. Int -> a -> [a]
replicate Int
num_sizes (SubExp -> Result) -> SubExp -> Result
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
            ShapeBase SubExp
forall a. Monoid a => a
mempty
    SegLevel
lvl <- [Char]
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     SegLevel
forall (m :: * -> *) inner.
(MonadBinder m, Op (Rep m) ~ HostOp (Rep m) inner) =>
[Char] -> m SegLevel
segThread [Char]
"segred"

    Stms GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *). MonadBinder m => Stms (Rep m) -> m ()
addStms (Stms GPU
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      ())
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Stms GPU)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm GPU
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      (Stm GPU))
-> Stms GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Stms GPU)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Stm GPU
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Stm GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Stm rep -> m (Stm rep)
renameStm
      (Stms GPU
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      (Stms GPU))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Stms GPU)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Stms GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pattern GPU
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pattern rep
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> m (Stms rep)
nonSegRed SegOpLevel GPU
SegLevel
lvl PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
Pattern GPU
pat SubExp
w [SegBinOp GPU
red_op] Lambda GPU
size_lam' [VName
thread_space_iota]

    [VName]
size_sums <- [VName]
-> (VName
    -> BinderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
         VName)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (PatternT (TypeBase (ShapeBase SubExp) NoUniqueness) -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
pat) ((VName
  -> BinderT
       GPU
       (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
       VName)
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      [VName])
-> (VName
    -> BinderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
         VName)
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     [VName]
forall a b. (a -> b) -> a -> b
$ \VName
threads_max ->
      [Char]
-> Exp
     (Rep
        (BinderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     VName
forall (m :: * -> *).
MonadBinder m =>
[Char] -> Exp (Rep m) -> m VName
letExp [Char]
"size_sum" (Exp
   (Rep
      (BinderT
         GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))
 -> BinderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
      VName)
-> Exp
     (Rep
        (BinderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))))
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp -> ExpT GPU
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT GPU) -> BasicOp -> ExpT GPU
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (VName -> SubExp
Var VName
threads_max) SubExp
num_threads

    ([VName], [VName])
-> BinderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either [Char])))
     ([VName], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (PatternT (TypeBase (ShapeBase SubExp) NoUniqueness) -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
pat, [VName]
size_sums)

  (Stms GPU, [VName], [VName])
-> ExpandM (Stms GPU, [VName], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms GPU
slice_stms, [VName]
maxes_per_thread, [VName]
size_sums)