{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- | Expand allocations inside of maps when possible.
module Futhark.Pass.ExpandAllocations (expandAllocations) where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (foldl')
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.Rephrase
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Error
import Futhark.IR
import qualified Futhark.IR.Kernels.Simplify as Kernels
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Lore (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.Kernels (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToKernels (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Util (mapAccumLM)
import Futhark.Util.IntegralExp
import Prelude hiding (quot)

-- | The memory expansion pass definition.
--
-- The program constants are transformed first, starting from an empty
-- scope.  Every function is then transformed with the scope of the
-- already transformed constants in effect.  A 'Left' produced while
-- transforming is turned into a compiler limitation by
-- 'limitationOnLeft'.
expandAllocations :: Pass KernelsMem KernelsMem
expandAllocations =
  Pass "expand allocations" "Expand allocations" $
    \(Prog consts funs) -> do
      consts' <-
        modifyNameSource $
          limitationOnLeft
            . runStateT (runReaderT (transformStms consts) mempty)
      Prog consts' <$> mapM (transformFunDef $ scopeOf consts') funs

-- Cannot use intraproceduralTransformation because it might create
-- duplicate size keys (which are not fixed by renamer, and size
-- keys must currently be globally unique).

-- | The monad in which the transformation runs: a reader over the
-- current 'Scope', layered on a 'VNameSource' state, layered on
-- @'Either' 'String'@ for failures.
type ExpandM = ReaderT (Scope KernelsMem) (StateT VNameSource (Either String))

-- | Unwrap a 'Right'; pass the message of a 'Left' to
-- 'compilerLimitationS'.
limitationOnLeft :: Either String a -> a
limitationOnLeft = either compilerLimitationS id

-- | Transform the body of a single function definition and then run
-- copy propagation over the resulting function.
--
-- The first argument is the scope in which the function is
-- transformed (the caller passes the scope of the program constants).
-- It is also used to build the symbol table handed to
-- 'copyPropagateInFun'.
transformFunDef ::
  Scope KernelsMem ->
  FunDef KernelsMem ->
  PassM (FunDef KernelsMem)
transformFunDef scope fundec = do
  body' <-
    modifyNameSource $
      limitationOnLeft . runStateT (runReaderT m mempty)
  copyPropagateInFun
    simpleKernelsMem
    (ST.fromScope (addScopeWisdom scope))
    fundec {funDefBody = body'}
  where
    -- The reader starts out empty; the outer scope and the scope of
    -- the function itself are both added before the body is
    -- transformed.
    m =
      localScope scope $
        inScopeOf fundec $
          transformBody $ funDefBody fundec

-- | Transform the statements of a body; the result is left untouched.
transformBody :: Body KernelsMem -> ExpandM (Body KernelsMem)
transformBody (Body () stms res) =
  Body () <$> transformStms stms <*> pure res

-- | Transform every statement in turn, with the names bound by the
-- whole statement sequence in scope, and concatenate the statements
-- produced for each one.
transformStms :: Stms KernelsMem -> ExpandM (Stms KernelsMem)
transformStms stms =
  inScopeOf stms $ mconcat <$> mapM transformStm (stmsToList stms)

-- | Transform a single statement, possibly producing additional
-- statements that go in front of it.
transformStm :: Stm KernelsMem -> ExpandM (Stms KernelsMem)
-- It is possible that we are unable to expand allocations in some
-- code versions.  If so, we can remove the offending branch.  Only if
-- both versions fail do we propagate the error.
transformStm (Let pat aux (If cond tbranch fbranch (IfDec ts IfEquiv))) = do
  -- Capture a failure in either branch as a 'Left' instead of
  -- aborting, so that the other branch can still be tried.
  tbranch' <- (Right <$> transformBody tbranch) `catchError` (return . Left)
  fbranch' <- (Right <$> transformBody fbranch) `catchError` (return . Left)
  case (tbranch', fbranch') of
    (Left _, Right fbranch'') ->
      return $ useBranch fbranch''
    (Right tbranch'', Left _) ->
      return $ useBranch tbranch''
    (Right tbranch'', Right fbranch'') ->
      return $
        oneStm $
          Let pat aux $
            If cond tbranch'' fbranch'' (IfDec ts IfEquiv)
    -- Only reached when both branches failed; the error of the true
    -- branch is the one that is reported.
    (Left e, _) ->
      throwError e
  where
    -- Bind a single pattern element to a branch result.
    bindRes pe se = Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se

    -- Replace the entire 'If' with the statements of one branch,
    -- followed by bindings of the pattern elements to the results of
    -- that branch.
    useBranch b =
      bodyStms b
        <> stmsFromList (zipWith bindRes (patternElements pat) (bodyResult b))
transformStm (Let pat aux e) = do
  -- Nested bodies are transformed first; 'transformExp' is then
  -- applied to the resulting expression.  The statements it returns
  -- are placed before the rebuilt statement.
  (bnds, e') <- transformExp =<< mapExpM transform e
  return $ bnds <> oneStm (Let pat aux e')
  where
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }

-- | Rebuild a 'NameInfo' constructor by constructor, carrying the
-- payload over unchanged.
nameInfoConv :: NameInfo KernelsMem -> NameInfo KernelsMem
nameInfoConv (LetName mem_info) = LetName mem_info
nameInfoConv (FParamName mem_info) = FParamName mem_info
nameInfoConv (LParamName mem_info) = LParamName mem_info
nameInfoConv (IndexName it) = IndexName it

transformExp :: Exp KernelsMem -> ExpandM (Stms KernelsMem, Exp KernelsMem)
transformExp :: ExpT KernelsMem -> ExpandM (Stms KernelsMem, ExpT KernelsMem)
transformExp (Op (Inner (SegOp (SegMap lvl space ts kbody)))) = do
  (Stms KernelsMem
alloc_stms, ([Lambda KernelsMem]
_, KernelBody KernelsMem
kbody')) <- SegLevel
-> SegSpace
-> [Lambda KernelsMem]
-> KernelBody KernelsMem
-> ExpandM
     (Stms KernelsMem, ([Lambda KernelsMem], KernelBody KernelsMem))
transformScanRed SegLevel
lvl SegSpace
space [] KernelBody KernelsMem
kbody
  (Stms KernelsMem, ExpT KernelsMem)
-> ExpandM (Stms KernelsMem, ExpT KernelsMem)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( Stms KernelsMem
alloc_stms,
      Op KernelsMem -> ExpT KernelsMem
forall lore. Op lore -> ExpT lore
Op (Op KernelsMem -> ExpT KernelsMem)
-> Op KernelsMem -> ExpT KernelsMem
forall a b. (a -> b) -> a -> b
$ HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall inner. inner -> MemOp inner
Inner (HostOp KernelsMem () -> MemOp (HostOp KernelsMem ()))
-> HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel KernelsMem -> HostOp KernelsMem ())
-> SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [Type]
-> KernelBody KernelsMem
-> SegOp SegLevel KernelsMem
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap SegLevel
lvl SegSpace
space [Type]
ts KernelBody KernelsMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegRed lvl space reds ts kbody)))) = do
  (Stms KernelsMem
alloc_stms, ([Lambda KernelsMem]
lams, KernelBody KernelsMem
kbody')) <-
    SegLevel
-> SegSpace
-> [Lambda KernelsMem]
-> KernelBody KernelsMem
-> ExpandM
     (Stms KernelsMem, ([Lambda KernelsMem], KernelBody KernelsMem))
transformScanRed SegLevel
lvl SegSpace
space ((SegBinOp KernelsMem -> Lambda KernelsMem)
-> [SegBinOp KernelsMem] -> [Lambda KernelsMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda [SegBinOp KernelsMem]
reds) KernelBody KernelsMem
kbody
  let reds' :: [SegBinOp KernelsMem]
reds' = (SegBinOp KernelsMem -> Lambda KernelsMem -> SegBinOp KernelsMem)
-> [SegBinOp KernelsMem]
-> [Lambda KernelsMem]
-> [SegBinOp KernelsMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\SegBinOp KernelsMem
red Lambda KernelsMem
lam -> SegBinOp KernelsMem
red {segBinOpLambda :: Lambda KernelsMem
segBinOpLambda = Lambda KernelsMem
lam}) [SegBinOp KernelsMem]
reds [Lambda KernelsMem]
lams
  (Stms KernelsMem, ExpT KernelsMem)
-> ExpandM (Stms KernelsMem, ExpT KernelsMem)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( Stms KernelsMem
alloc_stms,
      Op KernelsMem -> ExpT KernelsMem
forall lore. Op lore -> ExpT lore
Op (Op KernelsMem -> ExpT KernelsMem)
-> Op KernelsMem -> ExpT KernelsMem
forall a b. (a -> b) -> a -> b
$ HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall inner. inner -> MemOp inner
Inner (HostOp KernelsMem () -> MemOp (HostOp KernelsMem ()))
-> HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel KernelsMem -> HostOp KernelsMem ())
-> SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [SegBinOp KernelsMem]
-> [Type]
-> KernelBody KernelsMem
-> SegOp SegLevel KernelsMem
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lvl lore
SegRed SegLevel
lvl SegSpace
space [SegBinOp KernelsMem]
reds' [Type]
ts KernelBody KernelsMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegScan lvl space scans ts kbody)))) = do
  (Stms KernelsMem
alloc_stms, ([Lambda KernelsMem]
lams, KernelBody KernelsMem
kbody')) <-
    SegLevel
-> SegSpace
-> [Lambda KernelsMem]
-> KernelBody KernelsMem
-> ExpandM
     (Stms KernelsMem, ([Lambda KernelsMem], KernelBody KernelsMem))
transformScanRed SegLevel
lvl SegSpace
space ((SegBinOp KernelsMem -> Lambda KernelsMem)
-> [SegBinOp KernelsMem] -> [Lambda KernelsMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda [SegBinOp KernelsMem]
scans) KernelBody KernelsMem
kbody
  let scans' :: [SegBinOp KernelsMem]
scans' = (SegBinOp KernelsMem -> Lambda KernelsMem -> SegBinOp KernelsMem)
-> [SegBinOp KernelsMem]
-> [Lambda KernelsMem]
-> [SegBinOp KernelsMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\SegBinOp KernelsMem
red Lambda KernelsMem
lam -> SegBinOp KernelsMem
red {segBinOpLambda :: Lambda KernelsMem
segBinOpLambda = Lambda KernelsMem
lam}) [SegBinOp KernelsMem]
scans [Lambda KernelsMem]
lams
  (Stms KernelsMem, ExpT KernelsMem)
-> ExpandM (Stms KernelsMem, ExpT KernelsMem)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( Stms KernelsMem
alloc_stms,
      Op KernelsMem -> ExpT KernelsMem
forall lore. Op lore -> ExpT lore
Op (Op KernelsMem -> ExpT KernelsMem)
-> Op KernelsMem -> ExpT KernelsMem
forall a b. (a -> b) -> a -> b
$ HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall inner. inner -> MemOp inner
Inner (HostOp KernelsMem () -> MemOp (HostOp KernelsMem ()))
-> HostOp KernelsMem () -> MemOp (HostOp KernelsMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel KernelsMem -> HostOp KernelsMem ())
-> SegOp SegLevel KernelsMem -> HostOp KernelsMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [SegBinOp KernelsMem]
-> [Type]
-> KernelBody KernelsMem
-> SegOp SegLevel KernelsMem
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [Type]
-> KernelBody lore
-> SegOp lvl lore
SegScan SegLevel
lvl SegSpace
space [SegBinOp KernelsMem]
scans' [Type]
ts KernelBody KernelsMem
kbody'
    )
-- Segmented histogram: the operator lambdas of all histogram operations
-- are transformed together with the kernel body by 'transformScanRed',
-- and each transformed lambda is then put back into its 'HistOp'.
transformExp (Op (Inner (SegOp (SegHist lvl space ops ts kbody)))) = do
  (alloc_stms, (lams', kbody')) <- transformScanRed lvl space lams kbody
  -- 'lams' was built by mapping over 'ops', so the two lists line up
  -- positionally.
  let ops' = zipWith onOp ops lams'
  return
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegHist lvl space ops' ts kbody'
    )
  where
    -- The operator lambda of every histogram operation, in order.
    lams = map histOp ops
    -- Replace the operator lambda of a histogram operation.
    onOp op lam = op {histOp = lam}
-- Any other expression is returned unchanged, with no extra statements.
transformExp e =
  return (mempty, e)

-- | Extract the allocations found in a kernel body and in a list of
-- operator lambdas, and rewrite both to use the memory produced by
-- 'allocsForBody'.
--
-- The result is the statements handed to the continuation by
-- 'allocsForBody', paired with the rewritten lambdas and kernel body.
--
-- Fails with an error if the level is 'SegGroup' and any extracted
-- allocation has a size that is bound inside the kernel.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda KernelsMem] ->
  KernelBody KernelsMem ->
  ExpandM (Stms KernelsMem, ([Lambda KernelsMem], KernelBody KernelsMem))
transformScanRed lvl space ops kbody = do
  -- Every name in the ambient scope counts as bound outside the kernel.
  bound_outside <- asks $ namesFromList . M.keys
  let (kbody', kbody_allocs) =
        extractKernelBodyAllocations lvl bound_outside bound_in_kernel kbody
      -- The lambdas are extracted with an empty set of kernel-bound names.
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations lvl bound_outside mempty) ops
      -- An allocation is variant when its size is a variable bound
      -- inside the kernel.
      variantAlloc (_, Var v, _) = v `nameIn` bound_in_kernel
      variantAlloc _ = False
      allocs = kbody_allocs <> mconcat ops_allocs
      (variant_allocs, invariant_allocs) = M.partition variantAlloc allocs

  case lvl of
    SegGroup {}
      | not $ null variant_allocs ->
        -- The guard checks the *variant* allocations, so that is what
        -- the message must name.
        throwError "Cannot handle variant allocations in SegGroup."
    _ ->
      return ()

  allocsForBody variant_allocs invariant_allocs lvl space kbody' $ \alloc_stms kbody'' -> do
    -- Each lambda is rewritten with its own scope in effect.
    ops'' <- forM ops' $ \op' ->
      localScope (scopeOf op') $ offsetMemoryInLambda op'
    return (alloc_stms, (ops'', kbody''))
  where
    -- Names bound by the segment space or by the statements of the
    -- original kernel body.
    bound_in_kernel =
      namesFromList $
        M.keys $
          scopeOfSegSpace space
            <> scopeOf (kernelBodyStms kbody)

-- | Compute the memory needed for the given variant and invariant
-- allocations via 'memoryRequirements', rewrite the kernel body with
-- 'offsetMemoryInKernelBody', and pass the allocation statements and
-- the rewritten body to the continuation.
--
-- The continuation runs in 'OffsetM' under a scope made of the segment
-- space's scope plus the ambient scope mapped through 'nameInfoConv'.
-- A 'Left' produced by 'runOffsetM' is rethrown as an 'ExpandM' error.
allocsForBody ::
  Extraction ->
  Extraction ->
  SegLevel ->
  SegSpace ->
  KernelBody KernelsMem ->
  (Stms KernelsMem -> KernelBody KernelsMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs lvl space kbody' m = do
  (alloc_offsets, alloc_stms) <-
    memoryRequirements
      lvl
      space
      (kernelBodyStms kbody')
      variant_allocs
      invariant_allocs

  scope <- askScope
  let scope' = scopeOfSegSpace space <> M.map nameInfoConv scope
  either throwError pure $
    runOffsetM scope' alloc_offsets $ do
      kbody'' <- offsetMemoryInKernelBody kbody'
      m alloc_stms kbody''

-- | Produce the statements and the 'RebaseMap' needed to satisfy the
-- given variant and invariant allocations for a kernel at the given
-- level.
--
-- The returned statements are, in order: the statements computing the
-- thread and group counts, the statements from
-- 'expandedInvariantAllocations', and the statements from
-- 'expandedVariantAllocations'.  The returned map is the invariant
-- rebase map combined with the variant one.
memoryRequirements ::
  SegLevel ->
  SegSpace ->
  Stms KernelsMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms KernelsMem)
memoryRequirements lvl space kstms variant_allocs invariant_allocs = do
  ((num_threads, num_groups64, num_threads64), num_threads_stms) <- runBinder $ do
    -- num_threads = num_groups * group_size, as a 32-bit multiplication.
    -- NOTE(review): the product is formed with 'OverflowUndef' before it
    -- is widened below; presumably it is assumed to fit in 'Int32' --
    -- confirm.
    num_threads <-
      letSubExp "num_threads" $
        BasicOp $
          BinOp
            (Mul Int32 OverflowUndef)
            (unCount $ segNumGroups lvl)
            (unCount $ segGroupSize lvl)
    -- Sign-extended 64-bit copies of the group count and thread count.
    num_groups64 <-
      letSubExp "num_groups64" $
        BasicOp $ ConvOp (SExt Int32 Int64) (unCount $ segNumGroups lvl)
    num_threads64 <-
      letSubExp "num_threads64" $
        BasicOp $ ConvOp (SExt Int32 Int64) num_threads
    return (num_threads, num_groups64, num_threads64)

  -- Both expansions refer to the counts bound above, so they run with
  -- those statements in scope.
  (invariant_alloc_stms, invariant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedInvariantAllocations
        (num_threads64, num_groups64, segNumGroups lvl, segGroupSize lvl)
        space
        invariant_allocs

  (variant_alloc_stms, variant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedVariantAllocations
        num_threads
        space
        kstms
        variant_allocs

  return
    ( invariant_alloc_offsets <> variant_alloc_offsets,
      num_threads_stms <> invariant_alloc_stms <> variant_alloc_stms
    )

-- | A description of allocations that have been extracted from a
-- body, keyed by the name bound by the allocation statement.  Each
-- entry records the 'SegLevel' the extraction was performed at, the
-- size of the allocation, and the 'Space' the memory lives in.
type Extraction = M.Map VName (SegLevel, SubExp, Space)

-- | Extract allocations from the statements of a kernel body.
--
-- This is 'extractGenericBodyAllocations' specialised to 'KernelBody':
-- statements are read with 'kernelBodyStms' and written back by
-- updating that same field.
extractKernelBodyAllocations ::
  SegLevel ->
  Names ->
  Names ->
  KernelBody KernelsMem ->
  ( KernelBody KernelsMem,
    Extraction
  )
extractKernelBodyAllocations lvl bound_outside bound_kernel =
  extractGenericBodyAllocations lvl bound_outside bound_kernel kernelBodyStms $
    \stms kbody -> kbody {kernelBodyStms = stms}

-- | Extract allocations from the statements of an ordinary body.
--
-- This is 'extractGenericBodyAllocations' specialised to 'Body':
-- statements are read with 'bodyStms' and written back by updating
-- that same field.
extractBodyAllocations ::
  SegLevel ->
  Names ->
  Names ->
  Body KernelsMem ->
  (Body KernelsMem, Extraction)
extractBodyAllocations lvl bound_outside bound_kernel =
  extractGenericBodyAllocations lvl bound_outside bound_kernel bodyStms $
    \stms body -> body {bodyStms = stms}

-- | Extract allocations from the body of a lambda, returning the
-- lambda with its body replaced by the stripped body, together with
-- the extracted allocations.
extractLambdaAllocations ::
  SegLevel ->
  Names ->
  Names ->
  Lambda KernelsMem ->
  (Lambda KernelsMem, Extraction)
extractLambdaAllocations lvl bound_outside bound_kernel lam =
  (lam {lambdaBody = body'}, allocs)
  where
    (body', allocs) =
      extractBodyAllocations lvl bound_outside bound_kernel $ lambdaBody lam

-- | Extract allocations from any body-like value, given a way to read
-- its statements and a way to replace them.
--
-- 'extractStmAllocations' is run on every statement in order.
-- Statements for which it yields 'Nothing' are dropped from the body;
-- the allocations it reports through the writer are collected into the
-- returned 'Extraction'.
extractGenericBodyAllocations ::
  SegLevel ->
  Names ->
  Names ->
  (body -> Stms KernelsMem) ->
  (Stms KernelsMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations lvl bound_outside bound_kernel get_stms set_stms body =
  let (stms, allocs) =
        runWriter $
          fmap catMaybes $
            mapM (extractStmAllocations lvl bound_outside bound_kernel) $
              stmsToList $ get_stms body
   in (set_stms (stmsFromList stms) body, allocs)

-- | Whether allocations in the given memory space are candidates for
-- extraction.  The space named @"local"@ and any 'ScalarSpace' are
-- not; every other space is.
expandable :: Space -> Bool
expandable (Space "local") = False
expandable ScalarSpace {} = False
expandable _ = True

-- | Process a single statement during allocation extraction.
--
-- When the result is 'Nothing' the statement is to be removed from its
-- body; the allocation it performed has then been recorded through
-- the writer.
extractStmAllocations ::
  SegLevel ->
  Names ->
  Names ->
  Stm KernelsMem ->
  Writer Extraction (Maybe (Stm KernelsMem))
-- A statement binding exactly one name to an 'Alloc' is extracted when
-- either its space is 'expandable' and its size is expandable, or its
-- size is a variable bound inside the kernel (regardless of space).
-- If the guard fails, matching falls through to the next equation.
extractStmAllocations lvl bound_outside bound_kernel (Let (Pattern [] [patElem]) _ (Op (Alloc size space)))
  | (expandable space && expandableSize size)
      || boundInKernel size = do
    tell $ M.singleton (patElemName patElem) (lvl, size, space)
    return Nothing
  where
    -- A size is expandable if it is a constant, or a variable bound
    -- either outside the kernel or inside it.
    expandableSize (Var v) = v `nameIn` bound_outside || v `nameIn` bound_kernel
    expandableSize Constant {} = True
    -- Only variables can be bound in the kernel.
    boundInKernel (Var v) = v `nameIn` bound_kernel
    boundInKernel Constant {} = False
-- General case: the statement is kept (the result is always 'Just'),
-- but its expression is traversed so that allocations found in nested
-- bodies, SegOp lambdas and SegOp kernel bodies are extracted and
-- reported through the writer.
extractStmAllocations lvl bound_outside bound_kernel stm = do
  e <- mapExpM (expMapper lvl) (stmExp stm)
  return (Just stm {stmExp = e})
  where
    -- Expression traversal: only bodies and ops are rewritten, every
    -- other component is left to 'identityMapper'.
    expMapper lvl' =
      identityMapper
        { mapOnBody = \_scope -> onBody lvl',
          mapOnOp = onOp
        }

    -- Extract from an ordinary body and record what was extracted.
    onBody lvl' body =
      let (body', allocs) =
            extractBodyAllocations lvl' bound_outside bound_kernel body
       in body' <$ tell allocs

    -- A nested SegOp is traversed at its own level ('segLevel op'), not
    -- at the level of the enclosing statement.  Other ops are untouched.
    onOp (Inner (SegOp op)) =
      fmap (Inner . SegOp) (mapSegOpM (opMapper (segLevel op)) op)
    onOp op = return op

    -- SegOp traversal: rewrite lambdas and kernel bodies only.
    opMapper lvl' =
      identitySegOpMapper
        { mapOnSegOpLambda = onLambda lvl',
          mapOnSegOpBody = onKernelBody lvl'
        }

    -- Extract from a kernel body and record what was extracted.
    onKernelBody lvl' body =
      let (body', allocs) =
            extractKernelBodyAllocations lvl' bound_outside bound_kernel body
       in body' <$ tell allocs

    -- A lambda is handled by processing its body like any other body.
    onLambda lvl' lam =
      (\body -> lam {lambdaBody = body}) <$> onBody lvl' (lambdaBody lam)

-- | Expand the invariant allocations of a kernel.
--
-- For every extracted memory block this produces two statements: one
-- multiplying the per-user size by the number of users
-- (@num_threads64@ for thread-level allocations, @num_groups64@ for
-- group-level ones), and one allocating a block of that total size in
-- the original space under the original name.  It also produces a
-- 'RebaseMap' entry mapping the block to a function that builds the
-- index function base selecting the slice belonging to the current
-- thread or group, identified by the flat index of the 'SegSpace'.
expandedInvariantAllocations ::
  ( SubExp,
    SubExp,
    Count NumGroups SubExp,
    Count GroupSize SubExp
  ) ->
  SegSpace ->
  Extraction ->
  ExpandM (Stms KernelsMem, RebaseMap)
expandedInvariantAllocations
  (num_threads64, num_groups64, Count num_groups, Count group_size)
  segspace
  invariant_allocs =
    -- Each allocation yields a (statements, rebase map) pair; the pairs
    -- are combined componentwise through the Monoid instance of tuples.
    mconcat <$> mapM expand (M.toList invariant_allocs)
    where
      expand (mem, (lvl, per_thread_size, space)) = do
        total_size <- newVName "total_size"
        let num_users = case lvl of
              SegThread {} -> num_threads64
              SegGroup {} -> num_groups64
            -- total_size = num_users * per_thread_size (64-bit multiply).
            size_stm =
              Let (Pattern [] [PatElem total_size (MemPrim int64)]) (defAux ()) $
                BasicOp (BinOp (Mul Int64 OverflowUndef) num_users per_thread_size)
            -- The expanded block keeps the name and space of the original.
            alloc_stm =
              Let (Pattern [] [PatElem mem (MemMem space)]) (defAux ()) $
                Op (Alloc (Var total_size) space)
        return
          ( stmsFromList [size_stm, alloc_stm],
            M.singleton mem (newBase lvl)
          )

      -- Fix the leading dimension to the flat index of the kernel space
      -- and take every original dimension in full.
      ownSlice shape =
        DimFix (le32 (segFlat segspace)) : map (\d -> DimSlice 0 d 1) shape

      -- Thread level: the user dimension (num_groups * group_size) is
      -- appended last in the underlying layout, then permuted to the
      -- front before being fixed.
      newBase SegThread {} (old_shape, _) =
        let rank = length old_shape
            users = pe32 num_groups * pe32 group_size
            flat = IxFun.iota (old_shape ++ [users])
            users_first = IxFun.permute flat (rank : [0 .. rank - 1])
         in IxFun.slice users_first (ownSlice old_shape)
      -- Group level: the group dimension is outermost in the underlying
      -- layout, so no permutation is applied.
      newBase SegGroup {} (old_shape, _) =
        IxFun.slice
          (IxFun.iota (pe32 num_groups : old_shape))
          (ownSlice old_shape)

-- | Expand the variant allocations of a kernel.
--
-- When there are no variant allocations, nothing is produced.
-- Otherwise the distinct sizes are handed to 'sliceKernelSizes', whose
-- statements are given explicit allocations, simplified and then run
-- through 'transformStms' again.  Each memory block is then allocated
-- with the size sum computed for its size, and the returned
-- 'RebaseMap' maps the block to a base index function built from the
-- offset computed for its size.
expandedVariantAllocations ::
  SubExp ->
  SegSpace ->
  Stms KernelsMem ->
  Extraction ->
  ExpandM (Stms KernelsMem, RebaseMap)
expandedVariantAllocations _ _ _ variant_allocs
  | null variant_allocs = return (mempty, mempty)
expandedVariantAllocations num_threads kspace kstms variant_allocs = do
  let sizes_to_blocks = removeCommonSizes variant_allocs
      (variant_sizes, blocks_per_size) = unzip sizes_to_blocks

  (slice_stms, offsets, size_sums) <-
    sliceKernelSizes num_threads variant_sizes kspace kstms
  -- The freshly produced kernels may themselves contain allocations,
  -- hence the recursive use of 'transformStms' on them.
  slice_stms_mem <- explicitAllocationsInStms slice_stms
  (_, slice_stms_tmp) <- simplifyStms slice_stms_mem
  slice_stms' <- transformStms slice_stms_tmp

  -- Pair every memory block with the offset and size sum computed for
  -- the size it was grouped under.
  let variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
      variant_allocs' =
        [ (mem, (Var offset, Var total_size, space))
          | (blocks, offset, total_size) <-
              zip3 blocks_per_size offsets size_sums,
            (mem, space) <- blocks
        ]
      -- We expand the variant allocations by adding an inner dimension
      -- equal to the sum of the sizes required by different threads.
      (alloc_bnds, rebases) = unzip (map expand variant_allocs')

  return (slice_stms' <> stmsFromList alloc_bnds, mconcat rebases)
  where
    -- One allocation statement plus one rebase entry per memory block.
    expand (mem, (offset, total_size, space)) =
      ( Let (Pattern [] [PatElem mem (MemMem space)]) (defAux ()) $
          Op (Alloc total_size space),
        M.singleton mem (newBase offset)
      )

    num_threads' = pe32 num_threads
    gtid = isInt32 (LeafExp (segFlat kspace) int32)

    -- For the variant allocations, we add an inner dimension,
    -- which is then offset by a thread-specific amount.
    newBase size_per_thread (old_shape, pt) =
      let -- The size is divided by the byte size of the element type.
          elems_per_thread =
            isInt32 (sExt Int32 (primExpFromSubExp int64 size_per_thread))
              `quot` primByteSize pt
          flat = IxFun.iota [elems_per_thread, num_threads']
          -- NOTE(review): the first slice spans num_threads' elements of
          -- a dimension declared with elems_per_thread elements; kept
          -- exactly as found -- confirm this is intended.
          own = IxFun.slice flat [DimSlice 0 num_threads' 1, DimFix gtid]
          shapechange
            | length old_shape == 1 = map DimCoercion old_shape
            | otherwise = map DimNew old_shape
       in IxFun.reshape own shapechange

-- | A map from memory block names to new index function bases.
--
-- Each value is a function rather than a fixed index function: it
-- receives a list of 32-bit index expressions together with a
-- primitive type (the functions stored in this map by the allocation
-- expanders name these @old_shape@ and @pt@) and returns the index
-- function to use as the new base for the expanded memory block.
type RebaseMap = M.Map VName (([TPrimExp Int32 VName], PrimType) -> IxFun)

-- | The monad in which memory offsetting is performed.  The outer
-- reader environment holds the current 'Scope', the inner reader
-- environment holds the 'RebaseMap', and failures are reported as a
-- 'String' through 'Either'.
newtype OffsetM a
  = OffsetM
      ( ReaderT
          (Scope KernelsMem)
          (ReaderT RebaseMap (Either String))
          a
      )
  deriving
    ( Applicative,
      Functor,
      Monad,
      HasScope KernelsMem,
      LocalScope KernelsMem,
      MonadError String
    )

-- | Run an 'OffsetM' action with the given initial scope and rebase
-- map.  Produces either an error message or the result of the action.
runOffsetM :: Scope KernelsMem -> RebaseMap -> OffsetM a -> Either String a
runOffsetM scope offsets (OffsetM m) =
  runReaderT (runReaderT m scope) offsets

-- | Retrieve the 'RebaseMap' from the inner reader environment.
askRebaseMap :: OffsetM RebaseMap
askRebaseMap = OffsetM $ lift ask

-- | Look up the memory block @name@ in the 'RebaseMap'.  If it has an
-- entry, apply the stored function to @x@ (a base shape paired with an
-- element type) to obtain the new base index function; otherwise
-- produce 'Nothing'.
lookupNewBase :: VName -> ([TPrimExp Int32 VName], PrimType) -> OffsetM (Maybe IxFun)
lookupNewBase name x = do
  offsets <- askRebaseMap
  return $ ($ x) <$> M.lookup name offsets

-- | Apply 'offsetMemoryInStm' to every statement of a kernel body, in
-- order.  The scope returned for one statement is the scope in which
-- the next statement is processed.  Only the statements of the kernel
-- body are replaced; its other fields are kept as they were.
offsetMemoryInKernelBody :: KernelBody KernelsMem -> OffsetM (KernelBody KernelsMem)
offsetMemoryInKernelBody kbody = do
  scope <- askScope
  stms' <-
    stmsFromList . snd
      <$> mapAccumLM
        (\scope' -> localScope scope' . offsetMemoryInStm)
        scope
        (stmsToList $ kernelBodyStms kbody)
  return kbody {kernelBodyStms = stms'}

-- | Apply 'offsetMemoryInStm' to every statement of a body, in order.
-- The scope returned for one statement is the scope in which the next
-- statement is processed.  The body decoration and result are kept
-- unchanged.
offsetMemoryInBody :: Body KernelsMem -> OffsetM (Body KernelsMem)
offsetMemoryInBody (Body dec stms res) = do
  scope <- askScope
  stms' <-
    stmsFromList . snd
      <$> mapAccumLM
        (\scope' -> localScope scope' . offsetMemoryInStm)
        scope
        (stmsToList stms)
  return $ Body dec stms' res

-- | Offset the memory annotations of a single statement.  The pattern
-- is rewritten first, then the expression is rewritten with the new
-- pattern in scope.  Returns the rewritten statement together with the
-- current scope extended by the names the statement binds.
offsetMemoryInStm :: Stm KernelsMem -> OffsetM (Scope KernelsMem, Stm KernelsMem)
offsetMemoryInStm (Let pat dec e) = do
  pat' <- offsetMemoryInPattern pat
  e' <- localScope (scopeOfPattern pat') $ offsetMemoryInExp e
  scope <- askScope
  -- Try to recompute the index function.  Fall back to creating rebase
  -- operations with the RebaseMap.
  rts <- runReaderT (expReturns e') scope
  let pat'' =
        Pattern
          (patternContextElements pat')
          (zipWith pick (patternValueElements pat') rts)
      stm = Let pat'' dec e'
  let scope' = scopeOf stm <> scope
  return (scope', stm)
  where
    -- Prefer the memory block and index function reported for the
    -- rewritten expression, provided that index function contains no
    -- existential parts.  In every other case the pattern element is
    -- kept exactly as 'offsetMemoryInPattern' produced it.
    pick ::
      PatElemT (MemInfo SubExp NoUniqueness MemBind) ->
      ExpReturns ->
      PatElemT (MemInfo SubExp NoUniqueness MemBind)
    pick
      (PatElem name (MemArray pt s u _ret))
      (MemArray _ _ _ (Just (ReturnsInBlock m extixfun)))
        | Just ixfun <- instantiateIxFun extixfun =
          PatElem name (MemArray pt s u (ArrayIn m ixfun))
    pick p _ = p

    -- Convert an existential index function to an ordinary one.  This
    -- succeeds only when every variable in it is 'Free'.
    instantiateIxFun :: ExtIxFun -> Maybe IxFun
    instantiateIxFun = traverse (traverse inst)
      where
        inst Ext {} = Nothing
        inst (Free x) = return x

-- | Offset the memory annotations of the value elements of a pattern.
-- The context elements are returned unchanged, but they are checked
-- first: an error is thrown if any of them is a memory block in an
-- expandable space.
offsetMemoryInPattern :: Pattern KernelsMem -> OffsetM (Pattern KernelsMem)
offsetMemoryInPattern (Pattern ctx vals) = do
  mapM_ inspectCtx ctx
  Pattern ctx <$> mapM inspectVal vals
  where
    -- Rewrite the memory annotation of one value element.
    inspectVal patElem = do
      new_dec <- offsetMemoryInMemBound $ patElemDec patElem
      return patElem {patElemDec = new_dec}
    -- Reject context elements that bind an expandable memory block.
    inspectCtx patElem
      | Mem space <- patElemType patElem,
        expandable space =
        throwError $
          unwords
            [ "Cannot deal with existential memory block",
              pretty (patElemName patElem),
              "when expanding inside kernels."
            ]
      | otherwise = return ()

-- | Offset the memory annotation of a parameter using
-- 'offsetMemoryInMemBound'.  Every other field of the parameter is
-- kept unchanged.
offsetMemoryInParam :: Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam fparam = do
  fparam' <- offsetMemoryInMemBound $ paramDec fparam
  return fparam {paramDec = fparam'}

-- | Offset a memory annotation.  For an array whose memory block has
-- an entry in the 'RebaseMap', the index function is rebased onto the
-- new base computed from the base shape of the old index function and
-- the element type.  Arrays in other memory blocks, and all non-array
-- annotations, are returned unchanged.
offsetMemoryInMemBound :: MemBound u -> OffsetM (MemBound u)
offsetMemoryInMemBound summary@(MemArray pt shape u (ArrayIn mem ixfun)) = do
  new_base <- lookupNewBase mem (IxFun.base ixfun, pt)
  return $
    fromMaybe summary $ do
      new_base' <- new_base
      return $ MemArray pt shape u $ ArrayIn mem $ IxFun.rebase new_base' ixfun
offsetMemoryInMemBound summary = return summary

-- | Offset the memory annotation of a body return.  This applies only
-- to arrays returned in a known memory block whose index function is
-- static according to 'isStaticIxFun'.  If that block has an entry in
-- the 'RebaseMap', the new base is lifted into the existential setting
-- with 'Free' and the returned index function is rebased onto it.
-- Every other return is left unchanged.
offsetMemoryInBodyReturns :: BodyReturns -> OffsetM BodyReturns
offsetMemoryInBodyReturns br@(MemArray pt shape u (ReturnsInBlock mem ixfun))
  | Just ixfun' <- isStaticIxFun ixfun = do
    new_base <- lookupNewBase mem (IxFun.base ixfun', pt)
    return $
      fromMaybe br $ do
        new_base' <- new_base
        return $
          MemArray pt shape u $
            ReturnsInBlock mem $
              IxFun.rebase (fmap (fmap Free) new_base') ixfun
offsetMemoryInBodyReturns br = return br

-- | Offset the memory annotations in the body of a lambda.  The body
-- is processed with the names bound by the lambda in scope; the other
-- fields of the lambda are kept unchanged.
offsetMemoryInLambda :: Lambda KernelsMem -> OffsetM (Lambda KernelsMem)
offsetMemoryInLambda lam = inScopeOf lam $ do
  body <- offsetMemoryInBody $ lambdaBody lam
  return $ lam {lambdaBody = body}

-- | Offset the memory annotations inside an expression.
--
-- For a loop, the context and value merge parameters are rewritten
-- with 'offsetMemoryInParam'.  The loop body is then processed with
-- the rewritten parameters and the names bound by the loop form in
-- scope.  The initial values and the loop form itself are kept as
-- they were.
--
-- Any other expression is traversed with 'mapExpM', rewriting nested
-- bodies, branch types and operations.
offsetMemoryInExp :: Exp KernelsMem -> OffsetM (Exp KernelsMem)
offsetMemoryInExp (DoLoop ctx val form body) = do
  let (ctxparams, ctxinit) = unzip ctx
      (valparams, valinit) = unzip val
  ctxparams' <- mapM offsetMemoryInParam ctxparams
  valparams' <- mapM offsetMemoryInParam valparams
  body' <-
    localScope
      ( scopeOfFParams ctxparams'
          <> scopeOfFParams valparams'
          <> scopeOf form
      )
      (offsetMemoryInBody body)
  return $ DoLoop (zip ctxparams' ctxinit) (zip valparams' valinit) form body'
offsetMemoryInExp e = mapExpM recurse e
  where
    recurse =
      identityMapper
        { mapOnBody = \bscope -> localScope bscope . offsetMemoryInBody,
          mapOnBranchType = offsetMemoryInBodyReturns,
          mapOnOp = onOp
        }
    -- A segmented operation has its kernel body and lambdas rewritten
    -- with the names bound by its segment space in scope.  All other
    -- operations are returned unchanged.
    onOp (Inner (SegOp op)) =
      Inner . SegOp
        <$> localScope (scopeOfSegSpace (segSpace op)) (mapSegOpM segOpMapper op)
      where
        segOpMapper =
          identitySegOpMapper
            { mapOnSegOpBody = offsetMemoryInKernelBody,
              mapOnSegOpLambda = offsetMemoryInLambda
            }
    onOp op = return op

---- Slicing allocation sizes out of a kernel.

-- | Convert statements from the memory-annotated 'KernelsMem'
-- representation to the plain 'Kernels.Kernels' representation.
--
-- Allocation statements directly in the given statements are silently
-- removed.  The conversion fails with a 'Left' error message if:
--
-- * an allocation occurs inside a nested body, kernel body, or lambda;
--
-- * a pattern binds a memory block;
--
-- * a function/lambda parameter visited via the expression mapper, or a
--   return/branch type, is a memory block;
--
-- * an 'Alloc' or 'OtherOp' operation is reached via the expression mapper.
unAllocKernelsStms :: Stms KernelsMem -> Either String (Stms Kernels.Kernels)
unAllocKernelsStms = unAllocStms False
  where
    -- Bodies and kernel bodies are always converted in "nested" mode,
    -- where an allocation is an error rather than being dropped.
    unAllocBody (Body dec stms res) =
      Body dec <$> unAllocStms True stms <*> pure res

    unAllocKernelBody (KernelBody dec stms res) =
      KernelBody dec <$> unAllocStms True stms <*> pure res

    -- Convert every statement, discarding those for which 'unAllocStm'
    -- produced 'Nothing'.
    unAllocStms nested =
      fmap (stmsFromList . catMaybes) . mapM (unAllocStm nested) . stmsToList

    -- An allocation is removed when not nested, and rejected otherwise.
    -- Any other statement has its pattern and expression converted.
    unAllocStm nested stm@(Let _ _ (Op Alloc {}))
      | nested = throwError $ "Cannot handle nested allocation: " ++ pretty stm
      | otherwise = return Nothing
    unAllocStm _ (Let pat dec e) =
      Just <$> (Let <$> unAllocPattern pat <*> pure dec <*> mapExpM unAlloc' e)

    unAllocLambda (Lambda params body ret) =
      Lambda (unParams params) <$> unAllocBody body <*> pure ret

    -- Unlike 'unParam', this drops memory-typed parameters instead of
    -- failing on them.
    unParams = mapMaybe $ traverse unMem

    -- Fails if any context or value element of the pattern is a memory
    -- block.
    unAllocPattern pat@(Pattern ctx val) =
      Pattern <$> maybe bad return (mapM (rephrasePatElem unMem) ctx)
        <*> maybe bad return (mapM (rephrasePatElem unMem) val)
      where
        bad = Left $ "Cannot handle memory in pattern " ++ pretty pat

    unAllocOp Alloc {} = Left "unAllocOp: unhandled Alloc"
    unAllocOp (Inner OtherOp {}) = Left "unAllocOp: unhandled OtherOp"
    unAllocOp (Inner (SizeOp op)) =
      return $ SizeOp op
    unAllocOp (Inner (SegOp op)) = SegOp <$> mapSegOpM mapper op
      where
        mapper =
          identitySegOpMapper
            { mapOnSegOpLambda = unAllocLambda,
              mapOnSegOpBody = unAllocKernelBody
            }

    unParam p = maybe bad return $ traverse unMem p
      where
        bad = Left $ "Cannot handle memory-typed parameter '" ++ pretty p ++ "'"

    unT t = maybe bad return $ unMem t
      where
        bad = Left $ "Cannot handle memory type '" ++ pretty t ++ "'"

    -- Expression mapper tying the helpers above together; subexpressions
    -- and names are passed through unchanged.
    unAlloc' =
      Mapper
        { mapOnBody = const unAllocBody,
          mapOnRetType = unT,
          mapOnBranchType = unT,
          mapOnFParam = unParam,
          mapOnLParam = unParam,
          mapOnOp = unAllocOp,
          mapOnSubExp = Right,
          mapOnVName = Right
        }

-- | Strip the memory annotation from a 'MemInfo', leaving the plain type.
--
-- Primitives and arrays map to the corresponding 'Prim' and 'Array' types
-- (the fourth field of 'MemArray' is discarded).  A memory block
-- ('MemMem') has no counterpart and yields 'Nothing'.
unMem :: MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
unMem (MemPrim pt) = Just $ Prim pt
unMem (MemArray pt shape u _) = Just $ Array pt shape u
unMem MemMem {} = Nothing

-- | Convert a 'KernelsMem' scope to a 'Kernels.Kernels' scope.
--
-- Each let-bound name and function/lambda parameter has its memory
-- annotation removed with 'unMem'; names for which 'unMem' returns
-- 'Nothing' (memory blocks) are dropped from the resulting scope.  Index
-- names are kept unchanged.
unAllocScope :: Scope KernelsMem -> Scope Kernels.Kernels
unAllocScope = M.mapMaybe unInfo
  where
    unInfo (LetName dec) = LetName <$> unMem dec
    unInfo (FParamName dec) = FParamName <$> unMem dec
    unInfo (LParamName dec) = LParamName <$> unMem dec
    unInfo (IndexName it) = Just $ IndexName it

-- | Regroup an 'Extraction' so that it is keyed by the 'SubExp' component
-- of each entry (bound as @size@ below) instead of by name.
--
-- Every entry @(mem, (_, size, space))@ contributes the pair
-- @(mem, space)@ to the list stored under @size@; the 'SegLevel'
-- component is ignored.  The result is the association list of the
-- resulting map, so each distinct @size@ occurs exactly once.
removeCommonSizes ::
  Extraction ->
  [(SubExp, [(VName, Space)])]
removeCommonSizes = M.toList . foldl' comb mempty . M.toList
  where
    -- 'M.insertWith' applies @(++)@ as @new ++ old@, so entries seen later
    -- in the fold end up earlier in each list.
    comb m (mem, (_, size, space)) = M.insertWith (++) size [(mem, space)] m

sliceKernelSizes ::
  SubExp ->
  [SubExp] ->
  SegSpace ->
  Stms KernelsMem ->
  ExpandM (Stms Kernels.Kernels, [VName], [VName])
sliceKernelSizes :: SubExp
-> Result
-> SegSpace
-> Stms KernelsMem
-> ExpandM (Stms Kernels, [VName], [VName])
sliceKernelSizes SubExp
num_threads Result
sizes SegSpace
space Stms KernelsMem
kstms = do
  Stms Kernels
kstms' <- (String
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either String))
      (Stms Kernels))
-> (Stms Kernels
    -> ReaderT
         (Scope KernelsMem)
         (StateT VNameSource (Either String))
         (Stms Kernels))
-> Either String (Stms Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Stms Kernels)
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either String
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Stms Kernels)
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError Stms Kernels
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Stms Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either String (Stms Kernels)
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either String))
      (Stms Kernels))
-> Either String (Stms Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stms KernelsMem -> Either String (Stms Kernels)
unAllocKernelsStms Stms KernelsMem
kstms
  let num_sizes :: Int
num_sizes = Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
sizes
      i64s :: [Type]
i64s = Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate Int
num_sizes (Type -> [Type]) -> Type -> [Type]
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64

  Scope Kernels
kernels_scope <- (Scope KernelsMem -> Scope Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Scope Kernels)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Scope KernelsMem -> Scope Kernels
unAllocScope

  (Lambda Kernels
max_lam, Stms Kernels
_) <- (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   (Lambda Kernels)
 -> Scope Kernels
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either String))
      (Lambda Kernels, Stms Kernels))
-> Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Lambda Kernels, Stms Kernels)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  Kernels
  (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
  (Lambda Kernels)
-> Scope Kernels
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Lambda Kernels, Stms Kernels)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT Scope Kernels
kernels_scope (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   (Lambda Kernels)
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either String))
      (Lambda Kernels, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Lambda Kernels, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
    [Param Type]
xs <- Int
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   (Param Type)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      [Param Type])
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [Param Type]
forall a b. (a -> b) -> a -> b
$ String
-> Type
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    [Param Type]
ys <- Int
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   (Param Type)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      [Param Type])
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [Param Type]
forall a b. (a -> b) -> a -> b
$ String
-> Type
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"y" (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms Kernels
stms) <- Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result, Stms Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param Type] -> Scope Kernels
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams ([Param Type] -> Scope Kernels) -> [Param Type] -> Scope Kernels
forall a b. (a -> b) -> a -> b
$ [Param Type]
xs [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
ys) (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   (Result, Stms Kernels)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      (Result, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result, Stms Kernels)
forall a b. (a -> b) -> a -> b
$
      BinderT
  Kernels
  (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
  Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels
              (ReaderT
                 (Scope KernelsMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   Result
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      (Result,
       Stms
         (Lore
            (BinderT
               Kernels
               (ReaderT
                  (Scope KernelsMem) (StateT VNameSource (Either String)))))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels
              (ReaderT
                 (Scope KernelsMem) (StateT VNameSource (Either String))))))
forall a b. (a -> b) -> a -> b
$
        [(Param Type, Param Type)]
-> ((Param Type, Param Type)
    -> BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
         SubExp)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param Type] -> [Param Type] -> [(Param Type, Param Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
xs [Param Type]
ys) (((Param Type, Param Type)
  -> BinderT
       Kernels
       (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
       SubExp)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      Result)
-> ((Param Type, Param Type)
    -> BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
         SubExp)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     Result
forall a b. (a -> b) -> a -> b
$ \(Param Type
x, Param Type
y) ->
          String
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"z" (Exp
   (Lore
      (BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      SubExp)
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMax IntType
Int64) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
x) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
y)
    Lambda Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Lambda Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      (Lambda Kernels))
-> Lambda Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
forall a b. (a -> b) -> a -> b
$ [LParam Kernels] -> BodyT Kernels -> [Type] -> Lambda Kernels
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda ([Param Type]
xs [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
ys) (Stms Kernels -> Result -> BodyT Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody Stms Kernels
stms Result
zs) [Type]
i64s

  Param Type
flat_gtid_lparam <- VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param (VName -> Type -> Param Type)
-> ReaderT
     (Scope KernelsMem) (StateT VNameSource (Either String)) VName
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Type -> Param Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String
-> ReaderT
     (Scope KernelsMem) (StateT VNameSource (Either String)) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"flat_gtid" ReaderT
  (Scope KernelsMem)
  (StateT VNameSource (Either String))
  (Type -> Param Type)
-> ReaderT
     (Scope KernelsMem) (StateT VNameSource (Either String)) Type
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Param Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type
-> ReaderT
     (Scope KernelsMem) (StateT VNameSource (Either String)) Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (IntType -> PrimType
IntType IntType
Int32))

  (Lambda Kernels
size_lam', Stms Kernels
_) <- (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   (Lambda Kernels)
 -> Scope Kernels
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either String))
      (Lambda Kernels, Stms Kernels))
-> Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Lambda Kernels, Stms Kernels)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  Kernels
  (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
  (Lambda Kernels)
-> Scope Kernels
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Lambda Kernels, Stms Kernels)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT Scope Kernels
kernels_scope (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   (Lambda Kernels)
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either String))
      (Lambda Kernels, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (Lambda Kernels, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
    [Param Type]
params <- Int
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   (Param Type)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      [Param Type])
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [Param Type]
forall a b. (a -> b) -> a -> b
$ String
-> Type
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms Kernels
stms) <- Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result, Stms Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope
      ( [Param Type] -> Scope Kernels
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams [Param Type]
params
          Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> [Param Type] -> Scope Kernels
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams [Param Type
flat_gtid_lparam]
      )
      (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   (Result, Stms Kernels)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      (Result, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ BinderT
  Kernels
  (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
  Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels
              (ReaderT
                 (Scope KernelsMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   Result
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      (Result,
       Stms
         (Lore
            (BinderT
               Kernels
               (ReaderT
                  (Scope KernelsMem) (StateT VNameSource (Either String)))))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels
              (ReaderT
                 (Scope KernelsMem) (StateT VNameSource (Either String))))))
forall a b. (a -> b) -> a -> b
$ do
        -- Even though this SegRed is one-dimensional, we need to
        -- provide indexes corresponding to the original potentially
        -- multi-dimensional construct.
        let ([VName]
kspace_gtids, Result
kspace_dims) = [(VName, SubExp)] -> ([VName], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], Result))
-> [(VName, SubExp)] -> ([VName], Result)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
            new_inds :: [TPrimExp Int32 VName]
new_inds =
              [TPrimExp Int32 VName]
-> TPrimExp Int32 VName -> [TPrimExp Int32 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
                ((SubExp -> TPrimExp Int32 VName)
-> Result -> [TPrimExp Int32 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int32 VName
pe32 Result
kspace_dims)
                (SubExp -> TPrimExp Int32 VName
pe32 (SubExp -> TPrimExp Int32 VName) -> SubExp -> TPrimExp Int32 VName
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
flat_gtid_lparam)
        ([VName]
 -> ExpT Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      ())
-> [[VName]]
-> [ExpT Kernels]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ [VName]
-> ExpT Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames ((VName -> [VName]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> [a] -> [b]
map VName -> [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
kspace_gtids) ([ExpT Kernels]
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      ())
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [ExpT Kernels]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (TPrimExp Int32 VName
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      (ExpT Kernels))
-> [TPrimExp Int32 VName]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [ExpT Kernels]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TPrimExp Int32 VName
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (ExpT Kernels)
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp [TPrimExp Int32 VName]
new_inds

        (Stm Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      ())
-> Stms Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm Stms Kernels
kstms'
        Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     Result
forall (m :: * -> *) a. Monad m => a -> m a
return Result
sizes

    Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope Kernels
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   (Lambda Kernels)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      (Lambda Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
forall a b. (a -> b) -> a -> b
$
      Lambda Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Lambda Kernels)
forall (m :: * -> *).
(HasScope Kernels m, MonadFreshNames m) =>
Lambda Kernels -> m (Lambda Kernels)
Kernels.simplifyLambda ([LParam Kernels] -> BodyT Kernels -> [Type] -> Lambda Kernels
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda [Param Type
LParam Kernels
flat_gtid_lparam] (BodyDec Kernels -> Stms Kernels -> Result -> BodyT Kernels
forall lore. BodyDec lore -> Stms lore -> Result -> BodyT lore
Body () Stms Kernels
stms Result
zs) [Type]
i64s)

  (([VName]
maxes_per_thread, [VName]
size_sums), Stms Kernels
slice_stms) <- (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   ([VName], [VName])
 -> Scope Kernels
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either String))
      (([VName], [VName]), Stms Kernels))
-> Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms Kernels)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  Kernels
  (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
  ([VName], [VName])
-> Scope Kernels
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms Kernels)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT Scope Kernels
kernels_scope (BinderT
   Kernels
   (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
   ([VName], [VName])
 -> ReaderT
      (Scope KernelsMem)
      (StateT VNameSource (Either String))
      (([VName], [VName]), Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
-> ReaderT
     (Scope KernelsMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
    SubExp
num_threads_64 <-
      String
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"num_threads" (Exp
   (Lore
      (BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      SubExp)
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     SubExp
forall a b. (a -> b) -> a -> b
$
        BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ ConvOp -> SubExp -> BasicOp
ConvOp (IntType -> IntType -> ConvOp
SExt IntType
Int32 IntType
Int64) SubExp
num_threads

    PatternT Type
pat <-
      [Ident] -> [Ident] -> PatternT Type
basicPattern []
        ([Ident] -> PatternT Type)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [Ident]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (PatternT Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     Ident
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [Ident]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM
          Int
num_sizes
          (String
-> Type
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"max_per_thread" (Type
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      Ident)
-> Type
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     Ident
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)

    SubExp
w <-
      String
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"size_slice_w"
        (ExpT Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      SubExp)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (ExpT Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> SubExp
-> Result
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Exp
        (Lore
           (BinderT
              Kernels
              (ReaderT
                 (Scope KernelsMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *).
MonadBinder m =>
BinOp -> SubExp -> Result -> m (Exp (Lore m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int32 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1) (SegSpace -> Result
segSpaceDims SegSpace
space)

    VName
thread_space_iota <-
      String
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"thread_space_iota" (Exp
   (Lore
      (BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      VName)
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1) IntType
Int32
    let red_op :: SegBinOp Kernels
red_op =
          Commutativity
-> Lambda Kernels -> Result -> ShapeBase SubExp -> SegBinOp Kernels
forall lore.
Commutativity
-> Lambda lore -> Result -> ShapeBase SubExp -> SegBinOp lore
SegBinOp
            Commutativity
Commutative
            Lambda Kernels
max_lam
            (Int -> SubExp -> Result
forall a. Int -> a -> [a]
replicate Int
num_sizes (SubExp -> Result) -> SubExp -> Result
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
            ShapeBase SubExp
forall a. Monoid a => a
mempty
    SegLevel
lvl <- String
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     SegLevel
forall (m :: * -> *) inner.
(MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
String -> m SegLevel
segThread String
"segred"

    Stms Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      ())
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      (Stm Kernels))
-> Stms Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Stms Kernels)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Stm Kernels
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Stm Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Stm lore -> m (Stm lore)
renameStm
      (Stms Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      (Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Stms Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel Kernels
-> Pattern Kernels
-> SubExp
-> [SegBinOp Kernels]
-> Lambda Kernels
-> [VName]
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     (Stms Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> m (Stms lore)
nonSegRed SegOpLevel Kernels
SegLevel
lvl PatternT Type
Pattern Kernels
pat SubExp
w [SegBinOp Kernels
red_op] Lambda Kernels
size_lam' [VName
thread_space_iota]

    [VName]
size_sums <- [VName]
-> (VName
    -> BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
         VName)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT Type
pat) ((VName
  -> BinderT
       Kernels
       (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
       VName)
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      [VName])
-> (VName
    -> BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
         VName)
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     [VName]
forall a b. (a -> b) -> a -> b
$ \VName
threads_max ->
      String
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"size_sum" (Exp
   (Lore
      (BinderT
         Kernels
         (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
 -> BinderT
      Kernels
      (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
      VName)
-> Exp
     (Lore
        (BinderT
           Kernels
           (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))))
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp -> ExpT Kernels
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT Kernels) -> BasicOp -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (VName -> SubExp
Var VName
threads_max) SubExp
num_threads_64

    ([VName], [VName])
-> BinderT
     Kernels
     (ReaderT (Scope KernelsMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT Type
pat, [VName]
size_sums)

  (Stms Kernels, [VName], [VName])
-> ExpandM (Stms Kernels, [VName], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels
slice_stms, [VName]
maxes_per_thread, [VName]
size_sums)