{-# LANGUAGE TypeFamilies, FlexibleContexts, GeneralizedNewtypeDeriving #-}
-- | Expand allocations inside of maps when possible.
module Futhark.Pass.ExpandAllocations
       ( expandAllocations )
where

import Control.Monad.Identity
import Control.Monad.Except
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Writer
import qualified Data.Map.Strict as M
import Data.Maybe
import Data.List (foldl')

import Prelude hiding (quot)

import Futhark.Analysis.Rephrase
import Futhark.Error
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Pass
import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory
import qualified Futhark.Representation.ExplicitMemory.Simplify as ExplicitMemory
import qualified Futhark.Representation.Kernels as Kernels
import Futhark.Representation.Kernels.Simplify as Kernels
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Pass.ExtractKernels.BlockedKernel (segThread, nonSegRed)
import Futhark.Pass.ExplicitAllocations (explicitAllocationsInStms)
import Futhark.Transform.Rename (renameStm)
import Futhark.Util.IntegralExp
import Futhark.Util (mapAccumLM)


-- | The allocation-expansion pass.
--
-- The top-level constants are transformed first.  Every function
-- definition is then transformed with the (transformed) constants'
-- scope visible.
expandAllocations :: Pass ExplicitMemory ExplicitMemory
expandAllocations =
  Pass "expand allocations" "Expand allocations" $
  \(Prog consts funs) -> do
    consts' <-
      modifyNameSource $ runState $ runReaderT (transformStms consts) mempty
    Prog consts' <$> mapM (transformFunDef $ scopeOf consts') funs
  -- Cannot use intraproceduralTransformation because it might create
  -- duplicate size keys (which are not fixed by renamer, and size
  -- keys must currently be globally unique).

-- | The monad the expansion runs in.  It is a reader over the current
-- 'Scope' of 'ExplicitMemory' names, layered on a 'State' over a
-- 'VNameSource'.
type ExpandM = ReaderT (Scope ExplicitMemory) (State VNameSource)

-- | Expand allocations in the body of a single function definition.
--
-- The given scope (the top-level constants, as passed by
-- 'expandAllocations') is made visible, and so is whatever 'inScopeOf'
-- brings in for the function itself.  Fresh names are drawn from the
-- pass's name source.
transformFunDef :: Scope ExplicitMemory -> FunDef ExplicitMemory
                -> PassM (FunDef ExplicitMemory)
transformFunDef scope fundec = do
  body' <- modifyNameSource $ runState $ runReaderT m mempty
  return fundec { funDefBody = body' }
  where m = localScope scope $ inScopeOf fundec $
            transformBody $ funDefBody fundec

-- | Expand allocations in the statements of a body.  The result
-- (the body's returned values) is kept unchanged.
transformBody :: Body ExplicitMemory -> ExpandM (Body ExplicitMemory)
transformBody (Body () stms res) = Body () <$> transformStms stms <*> pure res

-- | Transform each statement in order and concatenate the resulting
-- statement sequences.  The bindings of all the input statements are
-- in scope during the transformation.
transformStms :: Stms ExplicitMemory -> ExpandM (Stms ExplicitMemory)
transformStms stms =
  inScopeOf stms $ mconcat <$> mapM transformStm (stmsToList stms)

-- | Transform a single statement.
--
-- Bodies nested inside the expression are processed first.  This uses
-- 'mapExpM' with a mapper whose 'mapOnBody' extends the scope and
-- recurses via 'transformBody'.  The resulting expression is then
-- handed to 'transformExp'.  Any statements that 'transformExp'
-- produces are emitted before the rewritten statement itself.
transformStm :: Stm ExplicitMemory -> ExpandM (Stms ExplicitMemory)
transformStm (Let pat aux e) = do
  (bnds, e') <- transformExp =<< mapExpM transform e
  return $ bnds <> oneStm (Let pat aux e')
  where transform = identityMapper { mapOnBody = \scope -> localScope scope . transformBody
                                   }

-- | Rebuild a 'NameInfo' constructor by constructor, keeping each
-- payload as-is.  For 'ExplicitMemory' to 'ExplicitMemory' this is the
-- identity.
--
-- NOTE(review): presumably kept as an explicit per-constructor
-- conversion point in case the source and target lores diverge again;
-- confirm before replacing with 'id'.
nameInfoConv :: NameInfo ExplicitMemory -> NameInfo ExplicitMemory
nameInfoConv (LetInfo mem_info) = LetInfo mem_info
nameInfoConv (FParamInfo mem_info) = FParamInfo mem_info
nameInfoConv (LParamInfo mem_info) = LParamInfo mem_info
nameInfoConv (IndexInfo it) = IndexInfo it

-- | Expand allocations in a single expression.
--
-- This applies to the four 'SegOp' forms: 'SegMap', 'SegRed',
-- 'SegScan' and 'SegHist'.  For these, 'transformScanRed' pulls
-- allocations out of the kernel body and the operator lambdas.  The
-- hoisted statements are returned alongside the rewritten operation.
-- Any other expression is returned unchanged, with no extra statements.
transformExp :: Exp ExplicitMemory -> ExpandM (Stms ExplicitMemory, Exp ExplicitMemory)

transformExp (Op (Inner (SegOp (SegMap lvl space ts kbody)))) = do
  -- A SegMap has no operator lambdas; only the kernel body is processed.
  (alloc_stms, (_, kbody')) <- transformScanRed lvl space [] kbody
  return (alloc_stms,
          Op $ Inner $ SegOp $ SegMap lvl space ts kbody')

transformExp (Op (Inner (SegOp (SegRed lvl space reds ts kbody)))) = do
  (alloc_stms, (lams, kbody')) <-
    transformScanRed lvl space (map segRedLambda reds) kbody
  -- Put each rewritten lambda back into its reduction operator.
  let reds' = zipWith (\red lam -> red { segRedLambda = lam }) reds lams
  return (alloc_stms,
          Op $ Inner $ SegOp $ SegRed lvl space reds' ts kbody')

transformExp (Op (Inner (SegOp (SegScan lvl space scan_op nes ts kbody)))) = do
  (alloc_stms, (scan_ops', kbody')) <- transformScanRed lvl space [scan_op] kbody
  -- transformScanRed returns one lambda per input lambda, so exactly
  -- one comes back.  Match it explicitly instead of using the partial 'head'.
  let scan_op' = case scan_ops' of
        [op] -> op
        _    -> error "ExpandAllocations.transformExp: expected exactly one scan operator"
  return (alloc_stms,
          Op $ Inner $ SegOp $ SegScan lvl space scan_op' nes ts kbody')

transformExp (Op (Inner (SegOp (SegHist lvl space ops ts kbody)))) = do
  (alloc_stms, (lams', kbody')) <- transformScanRed lvl space lams kbody
  -- Put each rewritten lambda back into its histogram operator.
  let ops' = zipWith onOp ops lams'
  return (alloc_stms,
          Op $ Inner $ SegOp $ SegHist lvl space ops' ts kbody')
  where lams = map histOp ops
        onOp op lam = op { histOp = lam }

transformExp e =
  return (mempty, e)

-- | Hoist allocations out of a kernel body and its operator lambdas.
--
-- Steps:
--
-- 1. Allocations are extracted from the kernel body and from each
--    lambda.  The kernel body sees every name bound outside or inside
--    the kernel; the lambdas see only names bound outside.
-- 2. The allocations are split into /variant/ ones and /invariant/
--    ones.  An allocation is variant when its size is a variable bound
--    inside the kernel (the segment space or the body's statements).
-- 3. Both groups are passed to 'allocsForBody'.  Memory in the lambdas
--    is then offset within its continuation.
--
-- Returns the hoisted allocation statements together with the
-- rewritten lambdas (one per input lambda, in order) and kernel body.
transformScanRed :: SegLevel -> SegSpace
                 -> [Lambda ExplicitMemory]
                 -> KernelBody ExplicitMemory
                 -> ExpandM (Stms ExplicitMemory, ([Lambda ExplicitMemory], KernelBody ExplicitMemory))
transformScanRed lvl space ops kbody = do
  bound_outside <- asks $ namesFromList . M.keys
  let (kbody', kbody_allocs) =
        extractKernelBodyAllocations (bound_outside <> bound_in_kernel) kbody
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations bound_outside) ops
      variantAlloc (Var v) = v `nameIn` bound_in_kernel
      variantAlloc _ = False
      allocs = kbody_allocs <> mconcat ops_allocs
      (variant_allocs, invariant_allocs) =
        M.partition (variantAlloc . fst) allocs

  allocsForBody variant_allocs invariant_allocs lvl space kbody' $ \alloc_stms kbody'' -> do
    ops'' <- forM ops' $ \op' ->
      localScope (scopeOf op') $ offsetMemoryInLambda op'
    return (alloc_stms, (ops'', kbody''))

  where bound_in_kernel =
          namesFromList $ M.keys $
          scopeOfSegSpace space <> scopeOf (kernelBodyStms kbody)

-- | Compute the memory requirements of hoisted allocations, then run a
-- continuation over the offset kernel body.
--
-- 'memoryRequirements' produces a rebase map and the allocation
-- statements.  The kernel body's memory accesses are then offset in
-- 'OffsetM'.  That runs in a scope made of the segment space plus the
-- current scope (passed through 'nameInfoConv').  The continuation
-- receives the allocation statements and the offset body.  A 'Left'
-- from 'runOffsetM' is reported via 'compilerLimitationS'.
allocsForBody :: M.Map VName (SubExp, Space)
              -> M.Map VName (SubExp, Space)
              -> SegLevel -> SegSpace
              -> KernelBody ExplicitMemory
              -> (Stms ExplicitMemory -> KernelBody ExplicitMemory -> OffsetM b)
              -> ExpandM b
allocsForBody variant_allocs invariant_allocs lvl space kbody' m = do
  (alloc_offsets, alloc_stms) <-
    memoryRequirements lvl space
      (kernelBodyStms kbody') variant_allocs invariant_allocs

  scope <- askScope
  let scope' = scopeOfSegSpace space <> M.map nameInfoConv scope
  either compilerLimitationS pure $ runOffsetM scope' alloc_offsets $ do
    kbody'' <- offsetMemoryInKernelBody kbody'
    m alloc_stms kbody''

-- | Produce the statements and the 'RebaseMap' needed to expand the
-- given allocations of a kernel.
--
-- The function first binds two values:
--
-- * @num_threads@: the product of the level's group count and group
--   size, as an i32.
-- * @num_threads64@: the i64 sign extension of @num_threads@.
--
-- With those names in scope, it expands the invariant allocations
-- ('expandedInvariantAllocations') and the variant allocations
-- ('expandedVariantAllocations').
--
-- The returned statements are, in order: the thread-count statements,
-- the invariant-allocation statements, and the variant-allocation
-- statements. The two rebase maps are combined with '<>', with the
-- invariant map on the left.
memoryRequirements :: SegLevel -> SegSpace
                   -> Stms ExplicitMemory
                   -> M.Map VName (SubExp, Space)
                   -> M.Map VName (SubExp, Space)
                   -> ExpandM (RebaseMap, Stms ExplicitMemory)
memoryRequirements lvl space kstms variant_allocs invariant_allocs = do
  ((num_threads, num_threads64), num_threads_stms) <- runBinder $ do
    nt <- letSubExp "num_threads" $ BasicOp $
          BinOp (Mul Int32) (unCount $ segNumGroups lvl)
                            (unCount $ segGroupSize lvl)
    nt64 <- letSubExp "num_threads64" $ BasicOp $
            ConvOp (SExt Int32 Int64) nt
    pure (nt, nt64)

  -- Both expansions refer to the thread-count names bound above.
  (invariant_alloc_stms, invariant_alloc_offsets) <-
    inScopeOf num_threads_stms $
    expandedInvariantAllocations
      (num_threads64, segNumGroups lvl, segGroupSize lvl)
      space invariant_allocs

  (variant_alloc_stms, variant_alloc_offsets) <-
    inScopeOf num_threads_stms $
    expandedVariantAllocations num_threads space kstms variant_allocs

  let all_offsets = invariant_alloc_offsets <> variant_alloc_offsets
      all_stms = num_threads_stms <> invariant_alloc_stms <> variant_alloc_stms
  pure (all_offsets, all_stms)

-- | Allocations pulled out of a body. Each entry maps the name of the
-- memory block that was allocated to the allocation's size and the
-- memory space it lives in.
type Extraction = M.Map VName (SubExp, Space)

-- | Extract allocations from the statements of a 'KernelBody', using
-- 'extractGenericBodyAllocations'.
--
-- The statements are read with 'kernelBodyStms' and written back with a
-- record update. The result is the body with the extracted allocation
-- statements removed, together with the extracted allocations. Which
-- statements qualify is decided per statement by 'extractStmAllocations',
-- which receives the given set of names bound outside the body.
extractKernelBodyAllocations :: Names -> KernelBody ExplicitMemory
                             -> (KernelBody ExplicitMemory,
                                 Extraction)
extractKernelBodyAllocations bound_outside =
  extractGenericBodyAllocations bound_outside kernelBodyStms $
  \stms kbody -> kbody { kernelBodyStms = stms }

-- | Extract allocations from an ordinary 'Body'.
--
-- This is 'extractGenericBodyAllocations' specialised to 'bodyStms'. It
-- returns the body without the extracted statements, plus the
-- extracted allocations.
extractBodyAllocations :: Names -> Body ExplicitMemory
                       -> (Body ExplicitMemory, Extraction)
extractBodyAllocations bound_outside =
  extractGenericBodyAllocations bound_outside bodyStms replaceStms
  where replaceStms new_stms b = b { bodyStms = new_stms }

-- | Extract allocations from the body of a 'Lambda' with
-- 'extractBodyAllocations'. The lambda's other fields are left as they
-- are.
extractLambdaAllocations :: Names -> Lambda ExplicitMemory
                         -> (Lambda ExplicitMemory, Extraction)
extractLambdaAllocations bound_outside lam =
  let (new_body, extracted) =
        extractBodyAllocations bound_outside (lambdaBody lam)
  in (lam { lambdaBody = new_body }, extracted)

-- | Extract allocations from any body-like value, given an accessor and
-- a setter for its statements.
--
-- Every statement is passed through 'extractStmAllocations'. Statements
-- for which it yields 'Nothing' are dropped, and the rest are kept in
-- their original order. The allocations it records are gathered in a
-- 'Writer' and returned next to the updated body.
extractGenericBodyAllocations :: Names
                              -> (body -> Stms ExplicitMemory)
                              -> (Stms ExplicitMemory -> body -> body)
                              -> body
                              -> (body,
                                  Extraction)
extractGenericBodyAllocations bound_outside get_stms set_stms body =
  (set_stms (stmsFromList kept) body, extracted)
  where
    (kept, extracted) =
      runWriter $
      catMaybes <$> traverse (extractStmAllocations bound_outside)
                             (stmsToList (get_stms body))

-- | Whether an allocation in the given memory space may be extracted and
-- expanded. Allocations in the @"local"@ space and in any 'ScalarSpace'
-- are excluded; every other space is accepted.
expandable :: Space -> Bool
expandable space =
  case space of
    Space "local" -> False
    ScalarSpace{} -> False
    _ -> True

-- | Process a single statement during allocation extraction.
--
-- If the statement allocates exactly one memory block (a pattern with no
-- context elements and one value element), its space is 'expandable',
-- and its size is a constant or a name in @bound_outside@, then:
--
-- * @(size, space)@ is recorded under the block's name, and
-- * the statement is dropped (the result is 'Nothing').
--
-- Otherwise the statement is kept ('Just'). Its expression is traversed,
-- and allocations are extracted from nested bodies, SegOp lambdas and
-- SegOp kernel bodies, all using the same @bound_outside@.
extractStmAllocations :: Names -> Stm ExplicitMemory
                      -> Writer Extraction (Maybe (Stm ExplicitMemory))
extractStmAllocations bound_outside stm
  | Let (Pattern [] [pe]) _ (Op (Alloc size space)) <- stm,
    expandable space,
    sizeKnownOutside size = do
      tell $ M.singleton (patElemName pe) (size, space)
      pure Nothing
  | otherwise = do
      e <- mapExpM expMapper $ stmExp stm
      pure $ Just $ stm { stmExp = e }
  where
    -- The size must be computable before the kernel, i.e. it is either
    -- a literal or a variable bound outside.
    sizeKnownOutside (Var v) = v `nameIn` bound_outside
    sizeKnownOutside Constant{} = True

    -- Run a pure extractor and log what it extracted.
    hoistOut :: (a -> (a, Extraction)) -> a -> Writer Extraction a
    hoistOut extract x = do
      let (x', extracted) = extract x
      tell extracted
      pure x'

    onBody = hoistOut (extractBodyAllocations bound_outside)
    onKernelBody = hoistOut (extractKernelBodyAllocations bound_outside)

    onLambda lam = do
      new_body <- onBody $ lambdaBody lam
      pure lam { lambdaBody = new_body }

    onOp (Inner (SegOp op)) = Inner . SegOp <$> mapSegOpM opMapper op
    onOp op = pure op

    expMapper = identityMapper { mapOnBody = const onBody
                               , mapOnOp = onOp
                               }

    opMapper = identitySegOpMapper { mapOnSegOpLambda = onLambda
                                   , mapOnSegOpBody = onKernelBody
                                   }

-- | Expand invariant allocations so that every thread gets its own slot.
--
-- Statements emitted for each extracted @mem -> (per_thread_size, space)@:
--
-- 1. @total_size = num_threads64 * per_thread_size@ (i64 multiplication).
-- 2. @mem = alloc(total_size, space)@, re-using the original block name.
--
-- The 'RebaseMap' entry for @mem@ builds a new index function from an
-- array's old shape, in three steps:
--
-- 1. Take an iota over @old_shape ++ [num_groups * group_size]@.
-- 2. Permute that last (thread) dimension to the front.
-- 3. Fix the thread dimension to the segment's flat index ('segFlat') and
--    take every other dimension whole.
--
-- NOTE(review): the allocation size uses @num_threads64@ while the index
-- function uses @num_groups * group_size@ (i32). The caller in this file
-- derives both from the same 'SegLevel'.
expandedInvariantAllocations :: (SubExp, Count NumGroups SubExp, Count GroupSize SubExp)
                             -> SegSpace
                             -> Extraction
                             -> ExpandM (Stms ExplicitMemory, RebaseMap)
expandedInvariantAllocations (num_threads64, Count num_groups, Count group_size)
                             segspace
                             invariant_allocs = do
  -- We expand the invariant allocations by adding an inner dimension
  -- equal to the number of kernel threads.
  expanded <- mapM expandOne $ M.toList invariant_allocs
  pure (mconcat (map fst expanded), mconcat (map snd expanded))
  where
    expandOne (mem, (per_thread_size, space)) = do
      total_size <- newVName "total_size"
      let size_stm =
            Let (Pattern [] [PatElem total_size $ MemPrim int64]) (defAux ()) $
            BasicOp $ BinOp (Mul Int64) num_threads64 per_thread_size
          alloc_stm =
            Let (Pattern [] [PatElem mem $ MemMem space]) (defAux ()) $
            Op $ Alloc (Var total_size) space
      pure (stmsFromList [size_stm, alloc_stm], M.singleton mem rebaseIxFun)

    rebaseIxFun (old_shape, _) =
      IxFun.slice permuted $ DimFix this_thread : map wholeDim old_shape
      where
        rank = length old_shape
        thread_count = primExpFromSubExp int32 num_groups *
                       primExpFromSubExp int32 group_size
        permuted = IxFun.permute (IxFun.iota (old_shape ++ [thread_count]))
                                 (rank : [0 .. rank - 1])
        this_thread = LeafExp (segFlat segspace) int32
        wholeDim d = DimSlice (fromInt32 0) d (fromInt32 1)

-- | Expand the allocations in @variant_allocs@, whose sizes vary per
-- thread.  'sliceKernelSizes' produces statements that compute, for each
-- distinct size, a per-thread size and a total size.  Each memory block
-- is then bound to a fresh allocation of the total size of its size
-- class, and a 'RebaseMap' entry is recorded that gives each thread its
-- own part of the block.
--
-- Returns the size-computing statements followed by the new
-- allocations, together with the combined 'RebaseMap'.  An empty
-- extraction yields no statements and an empty map.
expandedVariantAllocations :: SubExp
                           -> SegSpace -> Stms ExplicitMemory
                           -> Extraction
                           -> ExpandM (Stms ExplicitMemory, RebaseMap)
expandedVariantAllocations _ _ _ variant_allocs
  | null variant_allocs = return (mempty, mempty)
expandedVariantAllocations num_threads kspace kstms variant_allocs = do
  let sizes_to_blocks = removeCommonSizes variant_allocs

  (size_stms, per_thread_sizes, total_sizes) <-
    sliceKernelSizes num_threads (map fst sizes_to_blocks) kspace kstms

  -- The size computations are themselves kernels: give them memory
  -- annotations, simplify them, and then recursively expand the
  -- allocations inside them.
  size_stms_mem <- explicitAllocationsInStms size_stms
  (_, size_stms_simplified) <- ExplicitMemory.simplifyStms size_stms_mem
  size_stms' <- transformStms size_stms_simplified

  -- Pair every memory block with the per-thread size and total size of
  -- the size class it belongs to.
  let expanded_blocks :: [(VName, (SubExp, SubExp, Space))]
      expanded_blocks =
        [ (mem, (Var per_thread, Var total, space))
        | ((_, blocks), per_thread, total) <-
            zip3 sizes_to_blocks per_thread_sizes total_sizes
        , (mem, space) <- blocks ]

  -- Each block is re-allocated with the total size, which is the sum of
  -- the sizes required by the different threads.
  (alloc_stms, rebases) <- unzip <$> mapM expand expanded_blocks

  return (size_stms' <> stmsFromList alloc_stms, mconcat rebases)

  where expand :: (VName, (SubExp, SubExp, Space))
               -> ExpandM (Stm ExplicitMemory, RebaseMap)
        expand (mem, (per_thread, total, space)) =
          return (Let (Pattern [] [PatElem mem $ MemMem space]) (defAux ()) $
                    Op $ Alloc total space,
                  M.singleton mem $ newBase per_thread)

        num_threads' = primExpFromSubExp int32 num_threads
        gtid = LeafExp (segFlat kspace) int32

        -- Build the new base index function for a block: the block is
        -- viewed with an extra dimension over all threads, and this
        -- thread's flat id selects its own part, which is then reshaped
        -- to the shape the original base had.
        newBase size_per_thread (old_shape, pt) =
          IxFun.reshape thread_slice $ map dimChange old_shape
          where elem_bytes = fromInt32 $ primByteSize pt
                elems_per_thread =
                  ConvOpExp (SExt Int64 Int32)
                  (primExpFromSubExp int64 size_per_thread)
                  `quot` elem_bytes
                -- NOTE(review): the first dimension has extent
                -- elems_per_thread but is sliced with extent num_threads';
                -- the reshape below replaces the shape afterwards --
                -- confirm this is intended.
                thread_slice =
                  IxFun.slice (IxFun.iota [elems_per_thread, num_threads'])
                  [DimSlice (fromInt32 0) num_threads' (fromInt32 1),
                   DimFix gtid]
                dimChange
                  | length old_shape == 1 = DimCoercion
                  | otherwise = DimNew

-- | A map from memory block names to new index function bases.
--
-- Each entry is a function that takes the shape and element type of the
-- base of an array's current index function (see @lookupNewBase@) and
-- produces the new base index function for that memory block.
type RebaseMap = M.Map VName (([PrimExp VName], PrimType) -> IxFun)

-- | The monad in which memory annotations are rewritten.  The outer
-- reader layer holds the current scope, the inner reader layer holds the
-- 'RebaseMap', and the computation can fail with an error message.
newtype OffsetM a = OffsetM (ReaderT (Scope ExplicitMemory)
                             (ReaderT RebaseMap (Either String)) a)
  deriving (Functor, Applicative, Monad,
            HasScope ExplicitMemory, LocalScope ExplicitMemory,
            MonadError String)

-- | Run an 'OffsetM' computation in the given scope with the given
-- 'RebaseMap', returning either an error message or the result.
runOffsetM :: Scope ExplicitMemory -> RebaseMap -> OffsetM a -> Either String a
runOffsetM scope rebase_map (OffsetM m) =
  flip runReaderT rebase_map $ runReaderT m scope

-- | Retrieve the 'RebaseMap' held by the inner reader layer of 'OffsetM'.
askRebaseMap :: OffsetM RebaseMap
askRebaseMap = OffsetM (lift ask)

-- | Look up the rebase function registered for a memory block and, if
-- there is one, apply it to the given base shape and element type.
-- Yields 'Nothing' for blocks that have no entry in the 'RebaseMap'.
lookupNewBase :: VName -> ([PrimExp VName], PrimType) -> OffsetM (Maybe IxFun)
lookupNewBase name x =
  fmap ($ x) . M.lookup name <$> askRebaseMap

-- | Rewrite the memory annotations of every statement in a kernel body.
-- All other parts of the kernel body are left untouched.
offsetMemoryInKernelBody :: KernelBody ExplicitMemory -> OffsetM (KernelBody ExplicitMemory)
offsetMemoryInKernelBody kbody = do
  stms' <- offsetMemoryInStms $ kernelBodyStms kbody
  return kbody { kernelBodyStms = stms' }

-- | Rewrite the memory annotations of every statement in a body.  The
-- body attribute and result are left untouched.
offsetMemoryInBody :: Body ExplicitMemory -> OffsetM (Body ExplicitMemory)
offsetMemoryInBody (Body attr stms res) = do
  stms' <- offsetMemoryInStms stms
  return $ Body attr stms' res

-- | Rewrite a sequence of statements with 'offsetMemoryInStm'.  The scope
-- is threaded through, so each statement is processed in a scope that
-- contains the bindings of the statements before it.  This is shared by
-- 'offsetMemoryInKernelBody' and 'offsetMemoryInBody', which previously
-- duplicated this logic.
offsetMemoryInStms :: Stms ExplicitMemory -> OffsetM (Stms ExplicitMemory)
offsetMemoryInStms stms = do
  scope <- askScope
  stmsFromList . snd <$>
    mapAccumLM (\scope' -> localScope scope' . offsetMemoryInStm) scope
    (stmsToList stms)

-- | Rewrite the memory annotations of one statement.
--
-- The pattern is processed with 'offsetMemoryInPattern'.  The expression
-- is processed with 'offsetMemoryInExp', in a scope that includes the new
-- pattern.  Afterwards, each array value element whose index function
-- can be obtained from 'expReturns' without existential parts is given
-- that index function.  All other elements keep the annotation produced
-- by 'offsetMemoryInPattern'.
--
-- Returns the current scope extended with the statement's bindings,
-- together with the rewritten statement.
offsetMemoryInStm :: Stm ExplicitMemory -> OffsetM (Scope ExplicitMemory, Stm ExplicitMemory)
offsetMemoryInStm (Let pat attr e) = do
  pat' <- offsetMemoryInPattern pat
  e' <- localScope (scopeOfPattern pat') $ offsetMemoryInExp e
  scope <- askScope
  -- Prefer an index function recomputed from the expression.  The rebase
  -- operations from the RebaseMap, already applied to pat', are only the
  -- fallback.
  rts <- runReaderT (expReturns e') scope
  let vals' = zipWith pick (patternValueElements pat') rts
      stm = Let (Pattern (patternContextElements pat') vals') attr e'
  return (scopeOf stm <> scope, stm)
  where pick :: PatElemT (MemInfo SubExp NoUniqueness MemBind)
             -> ExpReturns
             -> PatElemT (MemInfo SubExp NoUniqueness MemBind)
        pick pe@(PatElem _ (MemArray pt shape u _)) rt
          | MemArray _ _ _ (Just (ReturnsInBlock mem extixfun)) <- rt,
            Just ixfun <- instantiateIxFun extixfun =
              pe { patElemAttr = MemArray pt shape u (ArrayIn mem ixfun) }
        pick pe _ = pe

        -- Succeeds only when the index function mentions no existential
        -- ('Ext') variables.
        instantiateIxFun :: ExtIxFun -> Maybe IxFun
        instantiateIxFun = traverse (traverse fromFree)

        fromFree :: Ext a -> Maybe a
        fromFree (Free x) = Just x
        fromFree Ext{} = Nothing

-- | Rewrite the memory annotations of a pattern's value elements with
-- 'offsetMemoryInMemBound'.  Context elements are kept as they are.
-- However, a context element whose type is a memory block in an
-- expandable space is rejected with an error, because such existential
-- memory cannot be handled when expanding inside kernels.
offsetMemoryInPattern :: Pattern ExplicitMemory -> OffsetM (Pattern ExplicitMemory)
offsetMemoryInPattern (Pattern ctx vals) = do
  mapM_ rejectExpandableMem ctx
  vals' <- traverse offsetVal vals
  return $ Pattern ctx vals'
  where offsetVal pe = do
          attr' <- offsetMemoryInMemBound $ patElemAttr pe
          return pe { patElemAttr = attr' }

        rejectExpandableMem pe =
          case patElemType pe of
            Mem space
              | expandable space ->
                  throwError $ unwords ["Cannot deal with existential memory block",
                                        pretty (patElemName pe),
                                        "when expanding inside kernels."]
            _ -> return ()

-- | Rewrite the memory annotation of a parameter with
-- 'offsetMemoryInMemBound'.  The parameter's name is left untouched;
-- only its 'paramAttr' is replaced by the rewritten summary.
offsetMemoryInParam :: Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam fparam =
  setAttr <$> offsetMemoryInMemBound (paramAttr fparam)
  where setAttr attr = fparam { paramAttr = attr }

-- | Rebase the index function of an array summary when its memory
-- block has a new base.
--
-- For an array in memory block @mem@, 'lookupNewBase' is consulted with
-- the block name and the base shape and element type of the index
-- function.  If it yields a new base, the index function is rebased
-- onto it with 'IxFun.rebase'.  Otherwise, and for every summary that
-- is not an 'ArrayIn' array, the summary is returned unchanged.
offsetMemoryInMemBound :: MemBound u -> OffsetM (MemBound u)
offsetMemoryInMemBound summary@(MemArray pt shape u (ArrayIn mem ixfun)) =
  maybe summary rebased <$> lookupNewBase mem (IxFun.base ixfun, pt)
  where rebased new_base =
          MemArray pt shape u $ ArrayIn mem $ IxFun.rebase new_base ixfun
offsetMemoryInMemBound summary = return summary

-- | Rebase the index function of a branch return type when its memory
-- block has a new base.
--
-- Only arrays returned in a known block whose index function is static
-- (per 'isStaticIxFun') are considered.  If 'lookupNewBase' yields a new
-- base for the block, that base is lifted to existential form (wrapping
-- every variable in 'Free') and used to rebase the return's index
-- function.  In every other case the return type is left unchanged.
offsetMemoryInBodyReturns :: BodyReturns -> OffsetM BodyReturns
offsetMemoryInBodyReturns br@(MemArray pt shape u (ReturnsInBlock mem ixfun))
  | Just ixfun' <- isStaticIxFun ixfun =
      maybe br rebased <$> lookupNewBase mem (IxFun.base ixfun', pt)
  where rebased new_base =
          MemArray pt shape u $ ReturnsInBlock mem $
            IxFun.rebase (fmap (fmap Free) new_base) ixfun
offsetMemoryInBodyReturns br = return br

-- | Rewrite the memory annotations in the body of a lambda.
-- The body is processed in the scope introduced by the lambda
-- ('inScopeOf').  The lambda's parameters and return types are kept
-- as they are.
offsetMemoryInLambda :: Lambda ExplicitMemory -> OffsetM (Lambda ExplicitMemory)
offsetMemoryInLambda lam =
  inScopeOf lam $ withBody <$> offsetMemoryInBody (lambdaBody lam)
  where withBody body = lam { lambdaBody = body }

-- | Rewrite the memory annotations inside an expression.
--
-- * A 'DoLoop' has its context and value parameters rewritten with
--   'offsetMemoryInParam'.  Its body is processed in a scope extended
--   with the rewritten parameters and the loop form.
-- * Any other expression is traversed with 'mapExpM':
--   * nested bodies are processed in their own scope;
--   * branch return types go through 'offsetMemoryInBodyReturns';
--   * 'SegOp' operations have their kernel bodies and lambdas rewritten
--     inside the scope of their 'SegSpace';
--   * all other operations are returned unchanged.
offsetMemoryInExp :: Exp ExplicitMemory -> OffsetM (Exp ExplicitMemory)
offsetMemoryInExp (DoLoop ctx val form body) = do
  let (ctx_params, ctx_init) = unzip ctx
      (val_params, val_init) = unzip val
  ctx_params' <- traverse offsetMemoryInParam ctx_params
  val_params' <- traverse offsetMemoryInParam val_params
  -- The loop body is processed with the rewritten parameters in scope.
  body' <- localScope (scopeOfFParams ctx_params' <>
                       scopeOfFParams val_params' <>
                       scopeOf form) $
           offsetMemoryInBody body
  return $ DoLoop (zip ctx_params' ctx_init) (zip val_params' val_init) form body'
offsetMemoryInExp e = mapExpM mapper e
  where mapper = identityMapper
          { mapOnBody = \bscope b -> localScope bscope $ offsetMemoryInBody b
          , mapOnBranchType = offsetMemoryInBodyReturns
          , mapOnOp = offsetOp
          }

        offsetOp (Inner (SegOp op)) =
          Inner . SegOp <$>
          localScope (scopeOfSegSpace (segSpace op)) (mapSegOpM segOpMapper op)
          where segOpMapper = identitySegOpMapper
                  { mapOnSegOpBody = offsetMemoryInKernelBody
                  , mapOnSegOpLambda = offsetMemoryInLambda
                  }
        offsetOp op = return op


---- Slicing allocation sizes out of a kernel.

-- | Strip memory information from 'ExplicitMemory' statements, producing
-- the corresponding 'Kernels.Kernels' statements.
--
-- * A top-level 'Alloc' statement is dropped.  An 'Alloc' statement
--   inside a nested body ('Body' or 'KernelBody') yields a 'Left'.
-- * Pattern elements, parameters bound inside expressions, and
--   return/branch types are converted with 'unAttr'.  If that fails
--   (presumably because the entry is memory-typed), the result is a
--   'Left' carrying a diagnostic message.
-- * Lambda parameters that 'unAttr' cannot convert are silently
--   dropped rather than reported.
-- * 'SizeOp's are kept and 'SegOp's are converted recursively.  An
--   'OtherOp' or 'Alloc' reaching operator position yields a 'Left'.
unAllocKernelsStms :: Stms ExplicitMemory -> Either String (Stms Kernels.Kernels)
unAllocKernelsStms = unAllocStms False
  where
    -- Nested bodies are converted with @nested = True@, so an allocation
    -- found in them is reported instead of being dropped.
    unAllocBody (Body attr stms res) = do
      stms' <- unAllocStms True stms
      return $ Body attr stms' res

    unAllocKernelBody (KernelBody attr stms res) = do
      stms' <- unAllocStms True stms
      return $ KernelBody attr stms' res

    -- Statements that convert to 'Nothing' (dropped allocations) are
    -- filtered out of the result.
    unAllocStms nested stms =
      stmsFromList . catMaybes <$> mapM (unAllocStm nested) (stmsToList stms)

    unAllocStm nested stm@(Let _ _ (Op Alloc{}))
      | not nested = Right Nothing
      | otherwise = Left $ "Cannot handle nested allocation: " ++ pretty stm
    unAllocStm _ (Let pat attr e) = do
      pat' <- unAllocPattern pat
      e' <- mapExpM unAlloc' e
      return $ Just $ Let pat' attr e'

    -- Unlike 'unParam', lambda parameters that cannot be converted are
    -- dropped here instead of causing an error.
    unAllocLambda (Lambda params body ret) = do
      body' <- unAllocBody body
      return $ Lambda (unParams params) body' ret

    unParams ps = mapMaybe (traverse unAttr) ps

    unAllocPattern pat@(Pattern ctx val) =
      maybe (Left $ "Cannot handle memory in pattern " ++ pretty pat) Right $
        Pattern <$> mapM (rephrasePatElem unAttr) ctx
                <*> mapM (rephrasePatElem unAttr) val

    unAllocOp op =
      case op of
        Alloc{} -> Left "unAllocOp: unhandled Alloc"
        Inner OtherOp{} -> Left "unAllocOp: unhandled OtherOp"
        Inner (SizeOp size_op) -> Right $ SizeOp size_op
        Inner (SegOp seg_op) -> SegOp <$> mapSegOpM segOpUnAlloc seg_op
      where segOpUnAlloc =
              identitySegOpMapper { mapOnSegOpLambda = unAllocLambda
                                  , mapOnSegOpBody = unAllocKernelBody
                                  }

    unParam p =
      case traverse unAttr p of
        Just p' -> Right p'
        Nothing -> Left $ "Cannot handle memory-typed parameter '" ++ pretty p ++ "'"

    unT t =
      case unAttr t of
        Just t' -> Right t'
        Nothing -> Left $ "Cannot handle memory type '" ++ pretty t ++ "'"

    unAlloc' :: Mapper ExplicitMemory Kernels.Kernels (Either String)
    unAlloc' = Mapper { mapOnSubExp = Right
                      , mapOnVName = Right
                      , mapOnBody = \_ -> unAllocBody
                      , mapOnRetType = unT
                      , mapOnBranchType = unT
                      , mapOnFParam = unParam
                      , mapOnLParam = unParam
                      , mapOnOp = unAllocOp
                      }

-- | Drop the memory annotation from a 'MemInfo', recovering the plain
-- type it describes.
--
-- * Scalars ('MemPrim') become the corresponding 'Prim' type.
-- * Arrays ('MemArray') keep their element type, shape and uniqueness;
--   the memory-location part is discarded.
-- * Memory blocks ('MemMem') have no plain-type counterpart and yield
--   'Nothing'.
unAttr :: MemInfo d u ret -> Maybe (TypeBase (ShapeBase d) u)
unAttr info =
  case info of
    MemPrim pt            -> Just (Prim pt)
    MemArray pt shape u _ -> Just (Array pt shape u)
    MemMem{}              -> Nothing

-- | Turn an 'ExplicitMemory' scope into a 'Kernels' scope by stripping
-- the memory annotation from every binding via 'unAttr'.
--
-- Bindings whose annotation 'unAttr' cannot convert (memory blocks) are
-- removed from the resulting scope. Index bindings carry no memory
-- information and are kept unchanged.
unAllocScope :: Scope ExplicitMemory -> Scope Kernels.Kernels
unAllocScope scope = M.mapMaybe stripInfo scope
  where stripInfo info =
          case info of
            LetInfo attr    -> LetInfo <$> unAttr attr
            FParamInfo attr -> FParamInfo <$> unAttr attr
            LParamInfo attr -> LParamInfo <$> unAttr attr
            IndexInfo it    -> Just (IndexInfo it)

-- | Group the extracted allocations by size.
--
-- Each result pair holds a size and every (memory block, space) pair
-- allocated with that size. Entries are visited in the key order of the
-- 'Extraction' map. Each new entry is prepended to its group, so within
-- a group later entries come first.
removeCommonSizes :: Extraction
                  -> [(SubExp, [(VName, Space)])]
removeCommonSizes extraction =
  M.toList $ foldl' addAlloc M.empty $ M.toList extraction
  where addAlloc groups (mem, (size, space)) =
          M.insertWith (++) size [(mem, space)] groups

sliceKernelSizes :: SubExp -> [SubExp] -> SegSpace -> Stms ExplicitMemory
                 -> ExpandM (Stms Kernels.Kernels, [VName], [VName])
sliceKernelSizes :: SubExp
-> Result
-> SegSpace
-> Stms ExplicitMemory
-> ExpandM (Stms Kernels, [VName], [VName])
sliceKernelSizes SubExp
num_threads Result
sizes SegSpace
space Stms ExplicitMemory
kstms = do
  Stms Kernels
kstms' <- (String
 -> ReaderT
      (Scope ExplicitMemory) (State VNameSource) (Stms Kernels))
-> (Stms Kernels
    -> ReaderT
         (Scope ExplicitMemory) (State VNameSource) (Stms Kernels))
-> Either String (Stms Kernels)
-> ReaderT
     (Scope ExplicitMemory) (State VNameSource) (Stms Kernels)
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either String
-> ReaderT
     (Scope ExplicitMemory) (State VNameSource) (Stms Kernels)
forall a. String -> a
compilerLimitationS Stms Kernels
-> ReaderT
     (Scope ExplicitMemory) (State VNameSource) (Stms Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either String (Stms Kernels)
 -> ReaderT
      (Scope ExplicitMemory) (State VNameSource) (Stms Kernels))
-> Either String (Stms Kernels)
-> ReaderT
     (Scope ExplicitMemory) (State VNameSource) (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stms ExplicitMemory -> Either String (Stms Kernels)
unAllocKernelsStms Stms ExplicitMemory
kstms
  let num_sizes :: Int
num_sizes = Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
sizes
      i64s :: [Type]
i64s = Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate Int
num_sizes (Type -> [Type]) -> Type -> [Type]
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64

  Scope Kernels
kernels_scope <- (Scope ExplicitMemory -> Scope Kernels)
-> ReaderT
     (Scope ExplicitMemory) (State VNameSource) (Scope Kernels)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Scope ExplicitMemory -> Scope Kernels
unAllocScope

  (Lambda Kernels
max_lam, Stms Kernels
_) <- (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   (Lambda Kernels)
 -> Scope Kernels
 -> ReaderT
      (Scope ExplicitMemory)
      (State VNameSource)
      (Lambda Kernels, Stms Kernels))
-> Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
-> ReaderT
     (Scope ExplicitMemory)
     (State VNameSource)
     (Lambda Kernels, Stms Kernels)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  Kernels
  (ReaderT (Scope ExplicitMemory) (State VNameSource))
  (Lambda Kernels)
-> Scope Kernels
-> ReaderT
     (Scope ExplicitMemory)
     (State VNameSource)
     (Lambda Kernels, Stms Kernels)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT Scope Kernels
kernels_scope (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   (Lambda Kernels)
 -> ReaderT
      (Scope ExplicitMemory)
      (State VNameSource)
      (Lambda Kernels, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
-> ReaderT
     (Scope ExplicitMemory)
     (State VNameSource)
     (Lambda Kernels, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
    [Param Type]
xs <- Int
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   (Param Type)
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      [Param Type])
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [Param Type]
forall a b. (a -> b) -> a -> b
$ String
-> Type
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Param Type)
forall (m :: * -> *) attr.
MonadFreshNames m =>
String -> attr -> m (Param attr)
newParam String
"x" (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    [Param Type]
ys <- Int
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   (Param Type)
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      [Param Type])
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [Param Type]
forall a b. (a -> b) -> a -> b
$ String
-> Type
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Param Type)
forall (m :: * -> *) attr.
MonadFreshNames m =>
String -> attr -> m (Param attr)
newParam String
"y" (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms Kernels
stms) <- Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result, Stms Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param Type] -> Scope Kernels
forall lore attr.
(LParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfLParams ([Param Type] -> Scope Kernels) -> [Param Type] -> Scope Kernels
forall a b. (a -> b) -> a -> b
$ [Param Type]
xs [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
ys) (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   (Result, Stms Kernels)
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      (Result, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ BinderT
  Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Result
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)))))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (BinderT
   Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Result
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      (Result,
       Stms
         (Lore
            (BinderT
               Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Result
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)))))
forall a b. (a -> b) -> a -> b
$
                  [(Param Type, Param Type)]
-> ((Param Type, Param Type)
    -> BinderT
         Kernels
         (ReaderT (Scope ExplicitMemory) (State VNameSource))
         SubExp)
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param Type] -> [Param Type] -> [(Param Type, Param Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
xs [Param Type]
ys) (((Param Type, Param Type)
  -> BinderT
       Kernels
       (ReaderT (Scope ExplicitMemory) (State VNameSource))
       SubExp)
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      Result)
-> ((Param Type, Param Type)
    -> BinderT
         Kernels
         (ReaderT (Scope ExplicitMemory) (State VNameSource))
         SubExp)
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Result
forall a b. (a -> b) -> a -> b
$ \(Param Type
x,Param Type
y) ->
      String
-> Exp
     (Lore
        (BinderT
           Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"z" (Exp
   (Lore
      (BinderT
         Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      SubExp)
-> Exp
     (Lore
        (BinderT
           Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp Kernels -> ExpT Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> ExpT Kernels)
-> BasicOp Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp Kernels
forall lore. BinOp -> SubExp -> SubExp -> BasicOp lore
BinOp (IntType -> BinOp
SMax IntType
Int64) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall attr. Param attr -> VName
paramName Param Type
x) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall attr. Param attr -> VName
paramName Param Type
y)
    Lambda Kernels
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Lambda Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      (Lambda Kernels))
-> Lambda Kernels
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
forall a b. (a -> b) -> a -> b
$ [LParam Kernels] -> BodyT Kernels -> [Type] -> Lambda Kernels
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda ([Param Type]
xs [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
ys) (Stms Kernels -> Result -> BodyT Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody Stms Kernels
stms Result
zs) [Type]
i64s

  Param Type
flat_gtid_lparam <- VName -> Type -> Param Type
forall attr. VName -> attr -> Param attr
Param (VName -> Type -> Param Type)
-> ReaderT (Scope ExplicitMemory) (State VNameSource) VName
-> ReaderT
     (Scope ExplicitMemory) (State VNameSource) (Type -> Param Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> ReaderT (Scope ExplicitMemory) (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"flat_gtid" ReaderT
  (Scope ExplicitMemory) (State VNameSource) (Type -> Param Type)
-> ReaderT (Scope ExplicitMemory) (State VNameSource) Type
-> ReaderT (Scope ExplicitMemory) (State VNameSource) (Param Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> ReaderT (Scope ExplicitMemory) (State VNameSource) Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (IntType -> PrimType
IntType IntType
Int32))

  (Lambda Kernels
size_lam', Stms Kernels
_) <- (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   (Lambda Kernels)
 -> Scope Kernels
 -> ReaderT
      (Scope ExplicitMemory)
      (State VNameSource)
      (Lambda Kernels, Stms Kernels))
-> Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
-> ReaderT
     (Scope ExplicitMemory)
     (State VNameSource)
     (Lambda Kernels, Stms Kernels)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  Kernels
  (ReaderT (Scope ExplicitMemory) (State VNameSource))
  (Lambda Kernels)
-> Scope Kernels
-> ReaderT
     (Scope ExplicitMemory)
     (State VNameSource)
     (Lambda Kernels, Stms Kernels)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT Scope Kernels
kernels_scope (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   (Lambda Kernels)
 -> ReaderT
      (Scope ExplicitMemory)
      (State VNameSource)
      (Lambda Kernels, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
-> ReaderT
     (Scope ExplicitMemory)
     (State VNameSource)
     (Lambda Kernels, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
    [Param Type]
params <- Int
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   (Param Type)
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      [Param Type])
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Param Type)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [Param Type]
forall a b. (a -> b) -> a -> b
$ String
-> Type
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Param Type)
forall (m :: * -> *) attr.
MonadFreshNames m =>
String -> attr -> m (Param attr)
newParam String
"x" (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms Kernels
stms) <- Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result, Stms Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param Type] -> Scope Kernels
forall lore attr.
(LParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfLParams [Param Type]
params Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<>
                              [Param Type] -> Scope Kernels
forall lore attr.
(LParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfLParams [Param Type
flat_gtid_lparam]) (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   (Result, Stms Kernels)
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      (Result, Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result, Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result, Stms Kernels)
forall a b. (a -> b) -> a -> b
$ BinderT
  Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Result
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)))))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (BinderT
   Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Result
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      (Result,
       Stms
         (Lore
            (BinderT
               Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Result
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Result,
      Stms
        (Lore
           (BinderT
              Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)))))
forall a b. (a -> b) -> a -> b
$ do

      -- Even though this SegRed is one-dimensional, we need to
      -- provide indexes corresponding to the original potentially
      -- multi-dimensional construct.
      let ([VName]
kspace_gtids, Result
kspace_dims) = [(VName, SubExp)] -> ([VName], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], Result))
-> [(VName, SubExp)] -> ([VName], Result)
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
          new_inds :: [PrimExp VName]
new_inds = [PrimExp VName] -> PrimExp VName -> [PrimExp VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
                     ((SubExp -> PrimExp VName) -> Result -> [PrimExp VName]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
int32) Result
kspace_dims)
                     (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
int32 (SubExp -> PrimExp VName) -> SubExp -> PrimExp VName
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall attr. Param attr -> VName
paramName Param Type
flat_gtid_lparam)
      ([VName]
 -> ExpT Kernels
 -> BinderT
      Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ())
-> [[VName]]
-> [ExpT Kernels]
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ [VName]
-> ExpT Kernels
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames_ ((VName -> [VName]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> [a] -> [b]
map VName -> [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
kspace_gtids) ([ExpT Kernels]
 -> BinderT
      Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ())
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [ExpT Kernels]
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (PrimExp VName
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      (ExpT Kernels))
-> [PrimExp VName]
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [ExpT Kernels]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM PrimExp VName
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (ExpT Kernels)
forall a (m :: * -> *).
(ToExp a, MonadBinder m) =>
a -> m (Exp (Lore m))
toExp [PrimExp VName]
new_inds

      (Stm Kernels
 -> BinderT
      Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ())
-> Stms Kernels
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm Kernels
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm Stms Kernels
kstms'
      Result
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Result
forall (m :: * -> *) a. Monad m => a -> m a
return Result
sizes

    Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope Kernels
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   (Lambda Kernels)
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      (Lambda Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
forall a b. (a -> b) -> a -> b
$
      Lambda Kernels
-> [Maybe VName]
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Lambda Kernels)
forall (m :: * -> *).
(HasScope Kernels m, MonadFreshNames m) =>
Lambda Kernels -> [Maybe VName] -> m (Lambda Kernels)
Kernels.simplifyLambda ([LParam Kernels] -> BodyT Kernels -> [Type] -> Lambda Kernels
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda [Param Type
LParam Kernels
flat_gtid_lparam] (BodyAttr Kernels -> Stms Kernels -> Result -> BodyT Kernels
forall lore. BodyAttr lore -> Stms lore -> Result -> BodyT lore
Body () Stms Kernels
stms Result
zs) [Type]
i64s) []

  (([VName]
maxes_per_thread, [VName]
size_sums), Stms Kernels
slice_stms) <- (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   ([VName], [VName])
 -> Scope Kernels
 -> ReaderT
      (Scope ExplicitMemory)
      (State VNameSource)
      (([VName], [VName]), Stms Kernels))
-> Scope Kernels
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     ([VName], [VName])
-> ReaderT
     (Scope ExplicitMemory)
     (State VNameSource)
     (([VName], [VName]), Stms Kernels)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BinderT
  Kernels
  (ReaderT (Scope ExplicitMemory) (State VNameSource))
  ([VName], [VName])
-> Scope Kernels
-> ReaderT
     (Scope ExplicitMemory)
     (State VNameSource)
     (([VName], [VName]), Stms Kernels)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT Scope Kernels
kernels_scope (BinderT
   Kernels
   (ReaderT (Scope ExplicitMemory) (State VNameSource))
   ([VName], [VName])
 -> ReaderT
      (Scope ExplicitMemory)
      (State VNameSource)
      (([VName], [VName]), Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     ([VName], [VName])
-> ReaderT
     (Scope ExplicitMemory)
     (State VNameSource)
     (([VName], [VName]), Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
    SubExp
num_threads_64 <- String
-> Exp
     (Lore
        (BinderT
           Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"num_threads" (Exp
   (Lore
      (BinderT
         Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      SubExp)
-> Exp
     (Lore
        (BinderT
           Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) SubExp
forall a b. (a -> b) -> a -> b
$
                      BasicOp Kernels -> ExpT Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> ExpT Kernels)
-> BasicOp Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ ConvOp -> SubExp -> BasicOp Kernels
forall lore. ConvOp -> SubExp -> BasicOp lore
ConvOp (IntType -> IntType -> ConvOp
SExt IntType
Int32 IntType
Int64) SubExp
num_threads

    PatternT Type
pat <- [Ident] -> [Ident] -> PatternT Type
basicPattern [] ([Ident] -> PatternT Type)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [Ident]
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (PatternT Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Ident
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [Ident]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes
           (String
-> Type
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"max_per_thread" (Type
 -> BinderT
      Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Ident)
-> Type
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) Ident
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)

    SubExp
w <- String
-> Exp
     (Lore
        (BinderT
           Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"size_slice_w" (ExpT Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      SubExp)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (ExpT Kernels)
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
         BinOp
-> SubExp
-> Result
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Exp
        (Lore
           (BinderT
              Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)))))
forall (m :: * -> *).
MonadBinder m =>
BinOp -> SubExp -> Result -> m (Exp (Lore m))
foldBinOp (IntType -> BinOp
Mul IntType
Int32) (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1) (SegSpace -> Result
segSpaceDims SegSpace
space)

    VName
thread_space_iota <- String
-> Exp
     (Lore
        (BinderT
           Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"thread_space_iota" (Exp
   (Lore
      (BinderT
         Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
 -> BinderT
      Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) VName)
-> Exp
     (Lore
        (BinderT
           Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) VName
forall a b. (a -> b) -> a -> b
$ BasicOp Kernels -> ExpT Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> ExpT Kernels)
-> BasicOp Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$
                         SubExp -> SubExp -> SubExp -> IntType -> BasicOp Kernels
forall lore. SubExp -> SubExp -> SubExp -> IntType -> BasicOp lore
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int32 Integer
1) IntType
Int32
    let red_op :: SegRedOp Kernels
red_op = Commutativity
-> Lambda Kernels -> Result -> ShapeBase SubExp -> SegRedOp Kernels
forall lore.
Commutativity
-> Lambda lore -> Result -> ShapeBase SubExp -> SegRedOp lore
SegRedOp Commutativity
Commutative Lambda Kernels
max_lam
                 (Int -> SubExp -> Result
forall a. Int -> a -> [a]
replicate Int
num_sizes (SubExp -> Result) -> SubExp -> Result
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) ShapeBase SubExp
forall a. Monoid a => a
mempty
    SegLevel
lvl <- String
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     SegLevel
forall (m :: * -> *) inner.
(MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
String -> m SegLevel
segThread String
"segred"

    Stms Kernels
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels
 -> BinderT
      Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ())
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Stms Kernels)
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      (Stm Kernels))
-> Stms Kernels
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Stms Kernels)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Stm Kernels
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Stm Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Stm lore -> m (Stm lore)
renameStm (Stms Kernels
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      (Stms Kernels))
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Stms Kernels)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Stms Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<<
      SegLevel
-> Pattern Kernels
-> SubExp
-> [SegRedOp Kernels]
-> Lambda Kernels
-> [VName]
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
SegLevel
-> Pattern Kernels
-> SubExp
-> [SegRedOp Kernels]
-> Lambda Kernels
-> [VName]
-> m (Stms Kernels)
nonSegRed SegLevel
lvl PatternT Type
Pattern Kernels
pat SubExp
w [SegRedOp Kernels
red_op] Lambda Kernels
size_lam' [VName
thread_space_iota]

    [VName]
size_sums <- [VName]
-> (VName
    -> BinderT
         Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) VName)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (PatternT Type -> [VName]
forall attr. PatternT attr -> [VName]
patternNames PatternT Type
pat) ((VName
  -> BinderT
       Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) VName)
 -> BinderT
      Kernels
      (ReaderT (Scope ExplicitMemory) (State VNameSource))
      [VName])
-> (VName
    -> BinderT
         Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) VName)
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     [VName]
forall a b. (a -> b) -> a -> b
$ \VName
threads_max ->
      String
-> Exp
     (Lore
        (BinderT
           Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"size_sum" (Exp
   (Lore
      (BinderT
         Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
 -> BinderT
      Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) VName)
-> Exp
     (Lore
        (BinderT
           Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource))))
-> BinderT
     Kernels (ReaderT (Scope ExplicitMemory) (State VNameSource)) VName
forall a b. (a -> b) -> a -> b
$
      BasicOp Kernels -> ExpT Kernels
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp Kernels -> ExpT Kernels)
-> BasicOp Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp Kernels
forall lore. BinOp -> SubExp -> SubExp -> BasicOp lore
BinOp (IntType -> BinOp
Mul IntType
Int64) (VName -> SubExp
Var VName
threads_max) SubExp
num_threads_64

    ([VName], [VName])
-> BinderT
     Kernels
     (ReaderT (Scope ExplicitMemory) (State VNameSource))
     ([VName], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (PatternT Type -> [VName]
forall attr. PatternT attr -> [VName]
patternNames PatternT Type
pat, [VName]
size_sums)

  (Stms Kernels, [VName], [VName])
-> ExpandM (Stms Kernels, [VName], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels
slice_stms, [VName]
maxes_per_thread, [VName]
size_sums)