{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | A generic transformation for adding memory allocations to a
-- Futhark program.  Specialised by specific representations in
-- submodules.
module Futhark.Pass.ExplicitAllocations
  ( explicitAllocationsGeneric,
    explicitAllocationsInStmsGeneric,
    ExpHint (..),
    defaultExpHints,
    Allocable,
    AllocM,
    AllocEnv (..),
    SizeSubst (..),
    allocInStms,
    allocForArray,
    simplifiable,
    arraySizeInBytesExp,
    mkLetNamesB',
    mkLetNamesB'',

    -- * Module re-exports

    --
    -- These are highly likely to be needed by any downstream
    -- users.
    module Control.Monad.Reader,
    module Futhark.MonadFreshNames,
    module Futhark.Pass,
    module Futhark.Tools,
  )
where

import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Either (partitionEithers)
import Data.Foldable (toList)
import Data.List (foldl', transpose, zip4)
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.SymbolTable (IndexOp)
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Prop.Aliases (AliasedOp)
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep (mkWiseBody)
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (maybeNth, splitAt3)

-- | The constraints that a source representation @fromrep@ and a
-- memory-annotated target representation @torep@ (whose 'Op' wraps
-- @inner@) must satisfy for the generic allocation pass in this
-- module to translate between them.
type Allocable fromrep torep inner =
  ( -- Requirements on the source representation.
    PrettyRep fromrep,
    FParamInfo fromrep ~ DeclType,
    LParamInfo fromrep ~ Type,
    BranchType fromrep ~ ExtType,
    RetType fromrep ~ DeclExtType,
    BodyDec fromrep ~ (),
    -- Requirements on the target (memory) representation.
    PrettyRep torep,
    Mem torep inner,
    LetDec torep ~ LetDecMem,
    BodyDec torep ~ (),
    ExpDec torep ~ (),
    BuilderOps torep,
    -- Requirements on the representation-specific operations.
    SizeSubst inner
  )

-- | The read-only environment of the 'AllocM' monad.
data AllocEnv fromrep torep = AllocEnv
  { -- | Aggressively try to reuse memory in do-loops -
    -- should be True inside kernels, False outside.
    aggressiveReuse :: Bool,
    -- | When allocating memory, put it in this memory space.
    -- This is primarily used to ensure that group-wide
    -- statements store their results in local memory.
    allocSpace :: Space,
    -- | The set of names that are known to be constants at
    -- kernel compile time.
    envConsts :: S.Set VName,
    -- | Translate a representation-specific operation of the source
    -- representation into one of the target representation.
    allocInOp :: Op fromrep -> AllocM fromrep torep (Op torep),
    -- | Compute allocation hints for the results of an expression.
    envExpHints :: Exp torep -> AllocM fromrep torep [ExpHint]
  }

-- | Monad for adding allocations to an entire program.
--
-- A statement builder over the target representation, with read access
-- to an 'AllocEnv' and a name source for fresh names.  All instances
-- are inherited from the underlying transformer stack.
newtype AllocM fromrep torep a
  = AllocM (BuilderT torep (ReaderT (AllocEnv fromrep torep) (State VNameSource)) a)
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadFreshNames,
      HasScope torep,
      LocalScope torep,
      MonadReader (AllocEnv fromrep torep)
    )

-- | Building statements in 'AllocM' produces memory-annotated code.
-- Every pattern created through 'mkLetNamesM' gets its memory
-- information, and any allocations it needs, from 'patWithAllocations'.
instance (Allocable fromrep torep inner) => MonadBuilder (AllocM fromrep torep) where
  type Rep (AllocM fromrep torep) = torep

  -- Expressions of the target representation carry a unit decoration
  -- (ExpDec torep ~ () is part of 'Allocable').
  mkExpDecM _ _ = pure ()

  -- Give the bound names memory in the default space, guided by the
  -- environment's hints for this expression.
  mkLetNamesM names e = do
    def_space <- askDefaultSpace
    hints <- expHints e
    pat <- patWithAllocations def_space names e hints
    pure $ Let pat (defAux ()) e

  mkBodyM stms res = pure $ Body () stms res

  -- Statement accumulation is delegated to the underlying 'BuilderT'.
  addStms = AllocM . addStms
  collectStms (AllocM m) = AllocM $ collectStms m

-- | Allocation hints for the results of an expression, obtained by
-- applying the 'envExpHints' function stored in the environment.
expHints :: Exp torep -> AllocM fromrep torep [ExpHint]
expHints e = do
  f <- asks envExpHints
  f e

-- | The memory space new allocations go into unless a hint says
-- otherwise (the environment's 'allocSpace').
askDefaultSpace :: AllocM fromrep torep Space
askDefaultSpace = asks allocSpace

-- | Run an 'AllocM' computation in an empty scope.
--
-- The initial environment does not aggressively reuse memory,
-- allocates in 'DefaultSpace', and knows of no kernel constants.  It
-- uses the given functions for allocating inside operations and for
-- computing expression hints.  Any statements still pending in the
-- builder when the computation finishes are discarded; only the result
-- is returned.
runAllocM ::
  MonadFreshNames m =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  AllocM fromrep torep a ->
  m a
runAllocM handleOp hints (AllocM m) =
  fmap fst $ modifyNameSource $ runState $ runReaderT (runBuilderT m mempty) env
  where
    env =
      AllocEnv
        { aggressiveReuse = False,
          allocSpace = DefaultSpace,
          envConsts = mempty,
          allocInOp = handleOp,
          envExpHints = hints
        }

-- | The byte size of the primitive element type of a type.
elemSize :: Num a => Type -> a
elemSize = primByteSize . elemType

-- | An expression for the number of bytes occupied by a value of the
-- given type: the element size multiplied by every array dimension.
-- Unlike 'arraySizeInBytesExpM', the result is not clamped at zero.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  untyped $ foldl' (*) (elemSize t) $ map pe64 (arrayDims t)

-- | An expression for the number of bytes occupied by a value of the
-- given type: the product of the array dimensions times the element
-- size, clamped from below at zero with a signed 64-bit maximum.
arraySizeInBytesExpM :: MonadBuilder m => Type -> m (PrimExp VName)
arraySizeInBytesExpM t = do
  let dim_prod_i64 = product $ map pe64 (arrayDims t)
      elm_size_i64 = elemSize t
  pure $
    BinOpExp (SMax Int64) (ValueExp $ IntValue $ Int64Value 0) $
      untyped $
        dim_prod_i64 * elm_size_i64

-- | Bind the byte size of a value of the given type (as computed by
-- 'arraySizeInBytesExpM') to a fresh name based on @bytes@.
arraySizeInBytes :: MonadBuilder m => Type -> m SubExp
arraySizeInBytes = letSubExp "bytes" <=< toExp <=< arraySizeInBytesExpM

-- | Emit an 'Alloc' of enough bytes to hold a value of type @t@ in the
-- given space, and return the name (based on @mem@) of the new block.
allocForArray' ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner) =>
  Type ->
  Space ->
  m VName
allocForArray' t space = do
  size <- arraySizeInBytes t
  letExp "mem" $ Op $ Alloc size space

-- | Allocate memory for a value of the given type in the given space,
-- returning the name of the memory block.  This is 'allocForArray''
-- specialised to 'AllocM'.
allocForArray ::
  Allocable fromrep torep inner =>
  Type ->
  Space ->
  AllocM fromrep torep VName
allocForArray = allocForArray'

-- | Build a let-statement binding the given identifiers to an
-- expression.
--
-- The pattern elements receive memory annotations (and any needed
-- allocations) from 'allocsForPat', based on the expression's return
-- information and the environment's hints.
allocsForStm ::
  (Allocable fromrep torep inner) =>
  [Ident] ->
  Exp torep ->
  AllocM fromrep torep (Stm torep)
allocsForStm idents e = do
  def_space <- askDefaultSpace
  hints <- expHints e
  rts <- expReturns e
  pes <- allocsForPat def_space idents rts hints
  dec <- mkExpDecM (Pat pes) e
  pure $ Let (Pat pes) (defAux dec) e

-- | Construct a memory-annotated pattern binding the given names to the
-- results of an expression.
--
-- Existential shapes in the expression's type are instantiated with the
-- given names, and 'allocsForPat' decides where each result lives.
patWithAllocations ::
  (MonadBuilder m, Mem (Rep m) inner) =>
  Space ->
  [VName] ->
  Exp (Rep m) ->
  [ExpHint] ->
  m (Pat LetDecMem)
patWithAllocations def_space names e hints = do
  ts' <- instantiateShapes' names <$> expExtType e
  let idents = zipWith Ident names ts'
  rts <- expReturns e
  Pat <$> allocsForPat def_space idents rts hints

-- | Produce one identifier per return value.
--
-- The given identifiers are aligned with the /last/ return values.  Any
-- leading return values without an identifier get a fresh one: named
-- @ext_mem@ with a memory type for 'MemMem' returns, otherwise @ext@
-- with type @i64@.  The result has exactly as many elements as the list
-- of return values.
mkMissingIdents :: MonadFreshNames m => [Ident] -> [ExpReturns] -> m [Ident]
mkMissingIdents idents rts =
  reverse <$> zipWithM f (reverse rts) (map Just (reverse idents) ++ repeat Nothing)
  where
    f _ (Just ident) = pure ident
    f (MemMem space) Nothing = newIdent "ext_mem" $ Mem space
    f _ Nothing = newIdent "ext" $ Prim int64

-- | Compute memory-annotated pattern elements for the results of an
-- expression.
--
-- Missing leading identifiers are created with 'mkMissingIdents'.  Each
-- result is then handled according to its 'ExpReturns':
--
-- * primitives, and arrays with no memory return information but a
--   fully known shape, go through 'summaryForBindage' (which may
--   allocate, depending on the hint);
-- * memory blocks and accumulators are passed through unchanged;
-- * arrays returned in an existing block keep that block, with the
--   index function's existentials instantiated from the pattern;
-- * arrays returned in a new block use the pattern identifier at the
--   existential position @i@ as their memory.
--
-- Any other combination is an internal error.
allocsForPat ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner) =>
  Space ->
  [Ident] ->
  [ExpReturns] ->
  [ExpHint] ->
  m [PatElem LetDecMem]
allocsForPat def_space some_idents rts hints = do
  idents <- mkMissingIdents some_idents rts

  forM (zip3 idents rts hints) $ \(ident, rt, hint) -> do
    let ident_shape = arrayShape $ identType ident
    case rt of
      MemPrim _ -> do
        summary <- summaryForBindage def_space (identType ident) hint
        pure $ PatElem (identName ident) summary
      MemMem space ->
        pure $ PatElem (identName ident) $ MemMem space
      MemArray bt _ u (Just (ReturnsInBlock mem extixfun)) -> do
        let ixfn = instantiateExtIxFun idents extixfun
        pure . PatElem (identName ident) . MemArray bt ident_shape u $ ArrayIn mem ixfn
      MemArray _ extshape _ Nothing
        | Just _ <- knownShape extshape -> do
            summary <- summaryForBindage def_space (identType ident) hint
            pure $ PatElem (identName ident) summary
      MemArray bt _ u (Just (ReturnsNewBlock _ i extixfn)) -> do
        let ixfn = instantiateExtIxFun idents extixfn
        pure . PatElem (identName ident) . MemArray bt ident_shape u $
          ArrayIn (getIdent idents i) ixfn
      MemAcc acc ispace ts u ->
        pure $ PatElem (identName ident) $ MemAcc acc ispace ts u
      _ -> error "Impossible case reached in allocsForPat!"
  where
    -- The dimensions of a shape, provided none of them is existential.
    knownShape :: ShapeBase (Ext b) -> Maybe [b]
    knownShape = mapM known . shapeDims

    known :: Ext a -> Maybe a
    known (Free v) = Just v
    known Ext {} = Nothing

    -- The name of the pattern identifier at existential position i.
    getIdent :: [Ident] -> Int -> VName
    getIdent idents i =
      case maybeNth i idents of
        Just ident -> identName ident
        Nothing ->
          error $ "getIdent: Ext " <> show i <> " but pattern has " <> show (length idents) <> " elements: " <> prettyString idents

    -- Replace every existential in an index function by the
    -- corresponding pattern identifier.
    instantiateExtIxFun idents = fmap $ fmap inst
      where
        inst (Free v) = v
        inst (Ext i) = getIdent idents i

-- | Convert an existential index function into a concrete one.
--
-- Only index functions without existentials are supported; any 'Ext'
-- leaf raises an error.
instantiateIxFun :: Monad m => ExtIxFun -> m IxFun
instantiateIxFun = traverse $ traverse inst
  where
    inst Ext {} = error "instantiateIxFun: not yet"
    inst (Free x) = pure x

-- | Memory information for a binding of the given type, emitting
-- allocations when needed.
--
-- * Primitives, memory blocks and accumulators need no allocation.
-- * An array with 'NoHint' gets a fresh block in the default space, with
--   an 'IxFun.iota' index function over the array's dimensions.
-- * An array with a 'Hint' gets a fresh block in the hint's space (not
--   the default space), sized as the product of the hint's base shape
--   and the element byte size, and uses the hint's index function.
summaryForBindage ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner) =>
  Space ->
  Type ->
  ExpHint ->
  m (MemBound NoUniqueness)
summaryForBindage _ (Prim bt) _ =
  pure $ MemPrim bt
summaryForBindage _ (Mem space) _ =
  pure $ MemMem space
summaryForBindage _ (Acc acc ispace ts u) _ =
  pure $ MemAcc acc ispace ts u
summaryForBindage def_space t@(Array pt shape u) NoHint = do
  m <- allocForArray' t def_space
  pure $ MemArray pt shape u $ ArrayIn m $ IxFun.iota $ map pe64 $ arrayDims t
summaryForBindage _ t@(Array pt _ _) (Hint ixfun space) = do
  bytes <-
    letSubExp "bytes" <=< toExp . untyped $
      product
        [ product $ IxFun.base ixfun,
          fromIntegral (primByteSize pt :: Int64)
        ]
  m <- letExp "mem" $ Op $ Alloc bytes space
  pure $ MemArray pt (arrayShape t) NoUniqueness $ ArrayIn m ixfun

-- | Allocate memory for a list of function parameters and run the
-- continuation with the resulting parameters brought into scope.
--
-- Each input parameter is paired with the 'Space' that any array it
-- holds should live in.  The list handed to the continuation (and
-- added to the scope) is ordered as: memory-block parameters, then
-- context parameters, then the translated value parameters.
allocInFParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, Space)] ->
  ([FParam torep] -> AllocM fromrep torep a) ->
  AllocM fromrep torep a
allocInFParams params m = do
  (vals, (mems, ctxs)) <- runWriterT $ traverse (uncurry allocInFParam) params
  let allParams = concat [mems, ctxs, vals]
  localScope (scopeOfFParams allParams) (m allParams)

-- | Translate a single function parameter into its memory-annotated
-- form.
--
-- Non-array parameters (primitives, memory blocks, accumulators) only
-- have their decoration rewritten.  An array parameter gets a fresh
-- memory-block parameter named @<param>_mem@ in the given space.  That
-- block is emitted via the first component of the writer output, and
-- the array is placed in it with a row-major (iota) index function
-- over its declared shape.
allocInFParam ::
  (Allocable fromrep torep inner) =>
  FParam fromrep ->
  Space ->
  WriterT
    ([FParam torep], [FParam torep])
    (AllocM fromrep torep)
    (FParam torep)
allocInFParam param pspace = case paramDeclType param of
  Prim pt -> withDec $ MemPrim pt
  Mem space -> withDec $ MemMem space
  Acc acc ispace ts u -> withDec $ MemAcc acc ispace ts u
  Array pt shape u -> do
    mem <- lift . newVName $ baseString (paramName param) <> "_mem"
    tell ([Param (paramAttrs param) mem (MemMem pspace)], [])
    let ixfun = IxFun.iota (pe64 <$> shapeDims shape)
    withDec . MemArray pt shape u $ ArrayIn mem ixfun
  where
    -- Return the original parameter with only its decoration replaced.
    withDec dec = pure param {paramDec = dec}

-- | Make sure the array @v@ is stored contiguously in row-major order,
-- copying it if necessary.
--
-- The existing array is reused only when all of the following hold:
--
-- * its index function consists of a single LMAD;
-- * the permutation is the identity;
-- * the base shape has as many dimensions as the rank;
-- * it is contiguous;
-- * it resides in the requested space (when @space_ok@ is 'Just').
--
-- Otherwise a linear copy is made via 'allocLinearArray' in
-- @space_ok@, or in the default space when @space_ok@ is 'Nothing'.
-- Returns the memory block and the name of the (possibly new) array.
ensureRowMajorArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureRowMajorArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let rank = IxFun.rank ixfun
      alreadyRowMajor =
        and
          [ numLMADs ixfun == 1,
            ixFunPerm ixfun == [0 .. rank - 1],
            length (IxFun.base ixfun) == rank,
            all (== mem_space) space_ok,
            IxFun.contiguous ixfun
          ]
  if alreadyRowMajor
    then pure (mem, v)
    else allocLinearArray (fromMaybe default_space space_ok) (baseString v) v

-- | Ensure that a loop/function argument is a row-major array in the
-- given space, and report the extra arguments it needs.
--
-- The memory block of the (possibly copied) array is written to the
-- first writer component.  One 'SubExp' per element of its index
-- function (bound as @ixfun_arg@) is written to the second.  A constant
-- argument can never be an array, so it is rejected with an error.
ensureArrayIn ::
  (Allocable fromrep torep inner) =>
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn space se = case se of
  Constant c ->
    error $ "ensureArrayIn: " ++ prettyString c ++ " cannot be an array."
  Var arr -> do
    (mem, arr') <- lift $ ensureRowMajorArray (Just space) arr
    ctx <- lift $ do
      (_, ixfun) <- lookupArraySummary arr'
      forM (toList ixfun) $ \e -> letSubExp "ixfun_arg" =<< toExp e
    tell ([Var mem], ctx)
    pure (Var arr')

-- | Allocate memory for the merge parameters of a loop, then run the
-- continuation with the new parameters in scope.
--
-- The continuation receives:
--
-- 1. the new merge list, in which parameters are ordered as
--    memory-block parameters, context parameters, then value
--    parameters, each paired with its initial argument;
--
-- 2. a function that takes the value results of the loop body and
--    returns @(memory and context results, adjusted value results)@,
--    so that the body results match the shape of the new merge
--    parameters.
allocInMergeParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, SubExp)] ->
  ( [(FParam torep, SubExp)] ->
    ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])) ->
    AllocM fromrep torep a
  ) ->
  AllocM fromrep torep a
allocInMergeParams merge m = do
  ((valparams, valargs, handlers), (mem_params, ctx_params)) <-
    runWriterT $ unzip3 <$> traverse allocInMergeParam merge
  let all_params = mem_params <> ctx_params <> valparams
      -- Run each parameter's handler on the corresponding result,
      -- collecting the extra memory/context results it produces.
      mk_loop_res ses = do
        (ses', (memargs, ctxargs)) <-
          runWriterT $ zipWithM ($) handlers ses
        pure (memargs <> ctxargs, ses')
  -- The initial arguments are processed exactly like loop results.
  (valctx_args, valargs') <- mk_loop_res valargs
  localScope (scopeOfFParams all_params) $
    m (zip all_params (valctx_args <> valargs')) mk_loop_res
  where
    loop_param_names = namesFromList $ map (paramName . fst) merge

    -- Does any of these names refer to one of the loop's own merge
    -- parameters?
    anyIsLoopParam names = names `namesIntersect` loop_param_names

    -- Handler for array results that must stay in scalar space: the
    -- result is copied (via 'arrayWithIxFun') unless it already has the
    -- expected memory space and index function.
    scalarRes param_t v_mem_space v_ixfun (Var res) = do
      -- Try really hard to avoid copying needlessly, but the result
      -- _must_ be in ScalarSpace and have the right index function.
      (res_mem, res_ixfun) <- lift $ lookupArraySummary res
      res_mem_space <- lift $ lookupMemSpace res_mem
      let already_ok = res_mem_space == v_mem_space && res_ixfun == v_ixfun
      (res_mem', res') <-
        if already_ok
          then pure (res_mem, res)
          else lift $ arrayWithIxFun v_mem_space v_ixfun (fromDecl param_t) res
      tell ([Var res_mem'], [])
      pure (Var res')
    scalarRes _ _ _ se = pure se

    allocInMergeParam ::
      (Allocable fromrep torep inner) =>
      (Param DeclType, SubExp) ->
      WriterT
        ([FParam torep], [FParam torep])
        (AllocM fromrep torep)
        ( FParam torep,
          SubExp,
          SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
        )
    allocInMergeParam (mergeparam, Var v)
      | param_t@(Array pt shape u) <- paramDeclType mergeparam = do
          (v_mem, v_ixfun) <- lift $ lookupArraySummary v
          v_mem_space <- lift $ lookupMemSpace v_mem

          -- Loop-invariant array parameters that are in scalar space
          -- are special - we do not wish to existentialise their index
          -- function at all (but the memory block is still existential).
          case v_mem_space of
            ScalarSpace {}
              | anyIsLoopParam (freeIn shape) -> do
                  -- Arrays with loop-variant shape cannot be in scalar
                  -- space, so copy them elsewhere and try again.
                  (_, v') <- lift $ allocLinearArray DefaultSpace (baseString v) v
                  allocInMergeParam (mergeparam, Var v')
              | otherwise -> do
                  mem_p <- newParam "mem_param" $ MemMem v_mem_space
                  tell ([mem_p], [])
                  let dec = MemArray pt shape u $ ArrayIn (paramName mem_p) v_ixfun
                  pure
                    ( mergeparam {paramDec = dec},
                      Var v,
                      scalarRes param_t v_mem_space v_ixfun
                    )
            _ -> do
              (v_mem', v') <- lift $ ensureRowMajorArray Nothing v
              (_, v_ixfun') <- lift $ lookupArraySummary v'
              v_mem_space' <- lift $ lookupMemSpace v_mem'

              -- One fresh existential context parameter per element of
              -- the index function.
              ctx_params <-
                replicateM (length v_ixfun') $
                  newParam "ctx_param_ext" (MemPrim int64)

              param_ixfun <-
                instantiateIxFun $
                  IxFun.substituteInIxFun
                    ( M.fromList . zip (fmap Ext [0 ..]) $
                        map (le64 . Free . paramName) ctx_params
                    )
                    (IxFun.existentialize v_ixfun')

              mem_param <- newParam "mem_param" $ MemMem v_mem_space'
              tell ([mem_param], ctx_params)
              let dec = MemArray pt shape u $ ArrayIn (paramName mem_param) param_ixfun
              pure
                ( mergeparam {paramDec = dec},
                  Var v',
                  ensureArrayIn v_mem_space'
                )
    allocInMergeParam (mergeparam, se) = do
      -- Non-array (or constant) merge parameters: allocate in the
      -- default space and linearise the corresponding results.
      space <- lift askDefaultSpace
      mergeparam' <- allocInFParam mergeparam space
      pure (mergeparam', se, linearFuncallArg (paramType mergeparam) space)

-- | Copy the array @v@ (of type @v_t@) into freshly allocated memory
-- in @space@, laid out according to @ixfun@.
--
-- A new memory block is allocated with 'allocForArray'', and a new
-- array named @<v>_scalcopy@ is bound to a 'Copy' of @v@ with that
-- memory and index function.  Returns @(memory block, copied array)@.
--
-- @v_t@ must be an array type.  Anything else is reported immediately
-- with a descriptive error, rather than failing later through a lazy
-- incomplete pattern.
arrayWithIxFun ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner, LetDec (Rep m) ~ LetDecMem) =>
  Space ->
  IxFun ->
  Type ->
  VName ->
  m (VName, VName)
arrayWithIxFun space ixfun v_t v =
  case v_t of
    Array pt shape u -> do
      mem <- allocForArray' v_t space
      v_copy <- newVName $ baseString v <> "_scalcopy"
      let info = MemArray pt shape u $ ArrayIn mem ixfun
      letBind (Pat [PatElem v_copy info]) $ BasicOp $ Copy v
      pure (mem, v_copy)
    _ ->
      error $ "arrayWithIxFun: not an array type: " ++ prettyString v_t

-- | Make sure the array @v@ has a direct index function (as judged by
-- 'IxFun.isDirect') and, when @space_ok@ is 'Just', lives in that space.
--
-- If both conditions hold, the existing memory block and array are
-- returned unchanged.  Otherwise a new allocation is made and @v@ is
-- copied into it via 'allocLinearArray'.  The copy goes in @space_ok@,
-- or in the default space when @space_ok@ is 'Nothing'.
ensureDirectArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureDirectArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let usable = IxFun.isDirect ixfun && all (== mem_space) space_ok
  if not usable
    then allocLinearArray (fromMaybe default_space space_ok) (baseString v) v
    else pure (mem, v)

-- | Allocate fresh memory in @space@ for the array @v@ and bind a new
-- array named @<s>_desired_form@ to a 'Manifest' of @v@ with
-- permutation @perm@.
--
-- The new array's index function is the row-major (iota) index
-- function over @v@'s dimensions, permuted by @perm@.  Returns
-- @(memory block, new array)@.  Calling this on a non-array is an
-- error.
allocPermArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  [Int] ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocPermArray space perm s v = lookupType v >>= go
  where
    go t@(Array pt shape u) = do
      mem <- allocForArray t space
      v' <- newVName (s <> "_desired_form")
      let ixfun = IxFun.permute (IxFun.iota (pe64 <$> arrayDims t)) perm
          pe = PatElem v' (MemArray pt shape u (ArrayIn mem ixfun))
      addStm . Let (Pat [pe]) (defAux ()) . BasicOp $ Manifest perm v
      pure (mem, v')
    go t = error $ "allocPermArray: " ++ prettyString t

-- | Copy the array @v@ into freshly allocated memory in @space@ with a
-- plain row-major layout.
--
-- Equivalent to 'allocPermArray' with the identity permutation over
-- @v@'s rank.  @s@ is the base name for the new array.
allocLinearArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocLinearArray space s v = do
  rank <- arrayRank <$> lookupType v
  allocPermArray space (take rank [0 ..]) s v

-- | Prepare the arguments of a function call for the memory
-- representation.  Every argument is passed through 'linearFuncallArg'
-- in the default space; the extra arguments that it records (memory
-- blocks of array arguments) are prepended to the value arguments and
-- passed with diet 'Observe'.
funcallArgs ::
  (Allocable fromrep torep inner) =>
  [(SubExp, Diet)] ->
  AllocM fromrep torep [(SubExp, Diet)]
funcallArgs args = do
  (valargs, (ctx_args, mem_and_size_args)) <-
    runWriterT . forM args $ \(arg, d) -> do
      t <- lift $ subExpType arg
      space <- lift askDefaultSpace
      (,d) <$> linearFuncallArg t space arg
  let observed = [(se, Observe) | se <- ctx_args ++ mem_and_size_args]
  pure $ observed ++ valargs

-- | Turn a single call argument into one usable by the callee.  An
-- array-typed variable is replaced by the result of
-- 'ensureDirectArray' in @space@, and the memory block it lives in is
-- recorded in the first component of the writer output.  Any other
-- argument is returned unchanged and records nothing.
linearFuncallArg ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg t space arg =
  case (t, arg) of
    (Array {}, Var v) -> do
      (mem, v') <- lift $ ensureDirectArray (Just space) v
      tell ([Var mem], [])
      pure (Var v')
    _ -> pure arg

-- | The generic explicit-allocation pass.  @handleOp@ translates the
-- representation-specific operations and @hints@ supplies allocation
-- hints for expressions.  Top-level constants are processed with
-- 'allocInStms'; each function has its parameters and value results
-- placed in 'DefaultSpace', and the memory return types produced by
-- 'allocInFunBody' are prepended to the function's return types.
explicitAllocationsGeneric ::
  (Allocable fromrep torep inner) =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Pass fromrep torep
explicitAllocationsGeneric handleOp hints =
  Pass
    "explicit allocations"
    "Transform program to explicit memory representation"
    (intraproceduralTransformationWithConsts onStms allocInFun)
  where
    onStms stms =
      runAllocM handleOp hints . collectStms_ . allocInStms stms $ pure ()

    allocInFun consts (FunDef entry attrs fname rettype params fbody) =
      runAllocM handleOp hints . inScopeOf consts $ do
        let param_spaces = map (,DefaultSpace) params
            result_spaces = Just DefaultSpace <$ rettype
        allocInFParams param_spaces $ \params' -> do
          (fbody', mem_rets) <- allocInFunBody result_spaces fbody
          -- Existential indices in the value returns are shifted past
          -- the memory returns that precede them.
          let val_rets = memoryInDeclExtType (length mem_rets) rettype
          pure $ FunDef entry attrs fname (mem_rets ++ val_rets) params' fbody'

-- | Run the allocation transformation on a sequence of statements,
-- inside the scope provided by the calling monad, and return the
-- resulting statements in the target representation.
explicitAllocationsInStmsGeneric ::
  ( MonadFreshNames m,
    HasScope torep m,
    Allocable fromrep torep inner
  ) =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Stms fromrep ->
  m (Stms torep)
explicitAllocationsInStmsGeneric handleOp hints stms =
  askScope >>= \scope ->
    runAllocM handleOp hints . localScope scope . collectStms_ $
      allocInStms stms (pure ())

-- | Attach memory information to declared return types.  Every array
-- return becomes a 'ReturnsNewBlock' in 'DefaultSpace', numbered
-- consecutively from 0 in order of appearance, with an 'IxFun.iota'
-- index function over its (shifted) shape.  Every existential index in
-- an array shape is increased by @k@.  Memory types are not expected
-- in the input and cause an error.
memoryInDeclExtType :: Int -> [DeclExtType] -> [FunReturns]
memoryInDeclExtType k dets = evalState (traverse addMem dets) 0
  where
    addMem det = case det of
      Prim pt -> pure $ MemPrim pt
      Mem {} -> error "memoryInDeclExtType: too much memory"
      Array pt shape u -> do
        -- Take the current array counter and bump it.
        i <- state $ \n -> (n, n + 1)
        let shape' = shift <$> shape
            ixfun = IxFun.iota . map convert $ shapeDims shape'
        pure $ MemArray pt shape' u $ ReturnsNewBlock DefaultSpace i ixfun
      Acc acc ispace ts u -> pure $ MemAcc acc ispace ts u

    convert (Ext i) = le64 $ Ext i
    convert (Free v) = Free <$> pe64 v

    shift (Ext i) = Ext (i + k)
    shift (Free x) = Free x

-- | The memory context that must be returned alongside a body result.
-- For an array variable this is the memory block it is stored in,
-- returned as an extra result with 'MemMem' information; constants,
-- primitives and accumulators need no context.  It is an error if the
-- array's memory is not a memory block.
bodyReturnMemCtx ::
  (Allocable fromrep torep inner) =>
  SubExpRes ->
  AllocM fromrep torep [(SubExpRes, MemInfo ExtSize u MemReturn)]
bodyReturnMemCtx (SubExpRes _ se) =
  case se of
    Constant {} -> pure []
    Var v -> do
      info <- lookupMemInfo v
      case info of
        MemArray _ _ _ (ArrayIn mem _) -> do
          mem_info <- lookupMemInfo mem
          case mem_info of
            MemMem space -> pure [(subExpRes $ Var mem, MemMem space)]
            _ -> error $ "bodyReturnMemCtx: not a memory block: " ++ prettyString mem
        -- Primitives and accumulators carry no memory context; a
        -- memory block itself should not occur here.
        _ -> pure []

-- | Allocate memory in a function body.  The last @length space_oks@
-- results are made direct (see 'ensureDirect') in the corresponding
-- space; any earlier results get no space requirement.  The memory
-- blocks of array results are then prepended to the result, and their
-- memory return types are returned alongside the new body.
allocInFunBody ::
  (Allocable fromrep torep inner) =>
  [Maybe Space] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [FunReturns])
allocInFunBody space_oks (Body _ stms res) =
  buildBody $ allocInStms stms $ do
    let padded = replicate (length res - length space_oks) Nothing ++ space_oks
    res' <- zipWithM ensureDirect padded res
    (mem_ctx_res, mem_ctx_rets) <-
      unzip . mconcat <$> traverse bodyReturnMemCtx res'
    pure (mem_ctx_res ++ res', mem_ctx_rets)

-- | If the result is an array variable, replace it with the array
-- produced by 'ensureDirectArray' for the given space requirement.
-- Non-array results are returned unchanged.  Certificates are kept.
ensureDirect ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  SubExpRes ->
  AllocM fromrep torep SubExpRes
ensureDirect space_ok (SubExpRes cs se) = do
  se_info <- subExpMemInfo se
  se' <- case se_info of
    MemArray {}
      | Var v <- se -> do
          (_, v') <- ensureDirectArray space_ok v
          pure $ Var v'
    _ -> pure se
  pure $ SubExpRes cs se'

-- | Allocate memory for each statement in order, adding the resulting
-- statements to the current builder, then run @m@.  After each
-- statement, the constants it defines (per 'stmConsts') are added to
-- 'envConsts' for everything that follows, including @m@.
allocInStms ::
  (Allocable fromrep torep inner) =>
  Stms fromrep ->
  AllocM fromrep torep a ->
  AllocM fromrep torep a
allocInStms origstms m =
  foldr
    ( \stm continue -> do
        allocstms <- collectStms_ . auxing (stmAux stm) $ allocInStm stm
        addStms allocstms
        let new_consts = foldMap stmConsts allocstms
        local (\env -> env {envConsts = new_consts <> envConsts env}) continue
    )
    m
    (stmsToList origstms)

-- | Allocate memory for a single statement: translate its expression
-- with 'allocInExp', build the statement with allocations for its
-- pattern via 'allocsForStm', and add it to the builder.
allocInStm ::
  (Allocable fromrep torep inner) =>
  Stm fromrep ->
  AllocM fromrep torep ()
allocInStm (Let (Pat pes) _ e) = do
  e' <- allocInExp e
  stm <- allocsForStm (map patElemIdent pes) e'
  addStm stm

-- | Build a lambda in the target representation with the given
-- (already allocated) parameters, allocating memory for the body's
-- statements and keeping its result.
allocInLambda ::
  Allocable fromrep torep inner =>
  [LParam torep] ->
  Body fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInLambda params (Body _ stms res) =
  mkLambda params $ allocInStms stms (pure res)

-- | The number of LMADs that make up an index function.
numLMADs :: IxFun -> Int
numLMADs ixfun = NE.length (IxFun.ixfunLMADs ixfun)

-- | The per-dimension permutation of the first LMAD of an index
-- function.
ixFunPerm :: IxFun -> [Int]
ixFunPerm ixfun =
  let lmad = NE.head (IxFun.ixfunLMADs ixfun)
   in [IxFun.ldPerm dim | dim <- IxFun.lmadDims lmad]

-- | The per-dimension monotonicity of the first LMAD of an index
-- function.
ixFunMon :: IxFun -> [IxFun.Monotonicity]
ixFunMon ixfun =
  let lmad = NE.head (IxFun.ixfunLMADs ixfun)
   in [IxFun.ldMon dim | dim <- IxFun.lmadDims lmad]

-- | A requirement on the layout of an array returned from a branch.
--
-- * 'MemReq' describes a single-LMAD layout: its memory space, the
--   per-dimension permutation and monotonicity, the rank of the base
--   shape, and whether the index function is contiguous.
--
-- * 'NeedsLinearisation' says that no single such layout fits, so the
--   array must be linearised in the given space.
data MemReq
  = MemReq Space [Int] [IxFun.Monotonicity] Rank Bool
  | NeedsLinearisation Space
  deriving (Eq, Show)

-- | Unify two layout requirements.  A 'NeedsLinearisation' on either
-- side wins (the left one if both are).  Two identical 'MemReq's are
-- kept; differing ones become 'NeedsLinearisation' in the space of the
-- left requirement.
combMemReqs :: MemReq -> MemReq -> MemReq
combMemReqs x y =
  case (x, y) of
    (NeedsLinearisation {}, _) -> x
    (_, NeedsLinearisation {}) -> y
    (MemReq x_space _ _ _ _, MemReq {})
      | x == y -> x
      | otherwise -> NeedsLinearisation x_space

-- | Memory information for a branch return whose array layout is
-- described by a 'MemReq' rather than a concrete index function.
type MemReqType = MemInfo (Ext SubExp) NoUniqueness MemReq

-- | Unify two branch return types.  For two arrays, the layout
-- requirements are combined with 'combMemReqs', and the element type,
-- shape and uniqueness of the left one are kept.  In every other case
-- the left type is returned unchanged.
combMemReqTypes :: MemReqType -> MemReqType -> MemReqType
combMemReqTypes x y =
  case (x, y) of
    (MemArray pt shape u x_req, MemArray _ _ _ y_req) ->
      MemArray pt shape u (combMemReqs x_req y_req)
    _ -> x

-- | The context values that must be returned alongside an array result
-- with the given requirement: its memory block followed by 'int64'
-- values describing the index function.  Non-array returns need no
-- context.
contextRets :: MemReqType -> [MemInfo d u r]
contextRets (MemArray _ shape _ req) =
  case req of
    MemReq space _ _ (Rank base_rank) _ ->
      -- Memory + offset + base_rank + (stride,size)*rank.
      MemMem space : offset : ints base_rank ++ ints (2 * rank)
    NeedsLinearisation space ->
      -- Memory + offset + (base,stride,size)*rank.
      MemMem space : offset : ints (3 * rank)
  where
    rank = shapeRank shape
    offset = MemPrim int64
    ints n = replicate n (MemPrim int64)
contextRets _ = []

-- Add memory information to the body, but do not return memory/ixfun
-- information.  Instead, return restrictions on what the index
-- function should look like.  We will then (crudely) unify these
-- restrictions across all bodies.
--
-- An array result whose index function consists of exactly one LMAD
-- yields a 'MemReq' describing that LMAD; any other array yields
-- 'NeedsLinearisation'.  Memory, primitive and accumulator results are
-- passed through with their existing information.
allocInMatchBody ::
  (Allocable fromrep torep inner) =>
  [ExtType] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody rets (Body _ stms res) =
  buildBody . allocInStms stms $ do
    restrictions <- forM (zip rets res) $ \(t, r) -> do
      let se = resSubExp r
      v_info <- subExpMemInfo se
      case (t, v_info) of
        (Array pt shape u, MemArray _ _ _ (ArrayIn mem ixfun)) -> do
          space <- lookupMemSpace mem
          let req
                | numLMADs ixfun /= 1 = NeedsLinearisation space
                | otherwise =
                    MemReq
                      space
                      (ixFunPerm ixfun)
                      (ixFunMon ixfun)
                      (Rank (length (IxFun.base ixfun)))
                      (IxFun.contiguous ixfun)
          pure $ MemArray pt shape u req
        (_, MemMem space) -> pure $ MemMem space
        (_, MemPrim pt) -> pure $ MemPrim pt
        (_, MemAcc acc ispace ts u) -> pure $ MemAcc acc ispace ts u
        _ -> error $ "allocInMatchBody: mismatch: " ++ show (t, v_info)
    pure (res, restrictions)

mkBranchRet :: [MemReqType] -> [BranchTypeMem]
mkBranchRet :: [MemReqType] -> [BranchTypeMem]
mkBranchRet [MemReqType]
reqs =
  let ([BranchTypeMem]
ctx_rets, [BranchTypeMem]
res_rets) = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ([BranchTypeMem], [BranchTypeMem])
-> (MemReqType, Int) -> ([BranchTypeMem], [BranchTypeMem])
helper ([], []) forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [MemReqType]
reqs [Int]
offsets
   in [BranchTypeMem]
ctx_rets forall a. [a] -> [a] -> [a]
++ [BranchTypeMem]
res_rets
  where
    numCtxNeeded :: MemReqType -> Int
numCtxNeeded = forall (t :: * -> *) a. Foldable t => t a -> Int
length forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall d u r. MemReqType -> [MemInfo d u r]
contextRets

    offsets :: [Int]
offsets = forall b a. (b -> a -> b) -> b -> [a] -> [b]
scanl forall a. Num a => a -> a -> a
(+) Int
0 forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map MemReqType -> Int
numCtxNeeded [MemReqType]
reqs
    num_new_ctx :: Int
num_new_ctx = forall a. [a] -> a
last [Int]
offsets

    helper :: ([BranchTypeMem], [BranchTypeMem])
-> (MemReqType, Int) -> ([BranchTypeMem], [BranchTypeMem])
helper ([BranchTypeMem]
ctx_rets_acc, [BranchTypeMem]
res_rets_acc) (MemReqType
req, Int
ctx_offset) =
      ( [BranchTypeMem]
ctx_rets_acc forall a. [a] -> [a] -> [a]
++ forall d u r. MemReqType -> [MemInfo d u r]
contextRets MemReqType
req,
        [BranchTypeMem]
res_rets_acc forall a. [a] -> [a] -> [a]
++ [Int -> MemReqType -> BranchTypeMem
inspect Int
ctx_offset MemReqType
req]
      )

    arrayInfo :: Int -> MemReq -> (Space, [Int], [Monotonicity], Int, Bool)
arrayInfo Int
rank (NeedsLinearisation Space
space) =
      (Space
space, [Int
0 .. Int
rank forall a. Num a => a -> a -> a
- Int
1], forall a. a -> [a]
repeat Monotonicity
IxFun.Inc, Int
rank, Bool
True)
    arrayInfo Int
_ (MemReq Space
space [Int]
perm [Monotonicity]
mon (Rank Int
base_rank) Bool
contig) =
      (Space
space, [Int]
perm, [Monotonicity]
mon, Int
base_rank, Bool
contig)

    inspect :: Int -> MemReqType -> BranchTypeMem
inspect Int
ctx_offset (MemArray PrimType
pt ShapeBase (Ext SubExp)
shape NoUniqueness
u MemReq
req) =
      let shape' :: ShapeBase (Ext SubExp)
shape' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall a. Int -> Ext a -> Ext a
adjustExt Int
num_new_ctx) ShapeBase (Ext SubExp)
shape
          (Space
space, [Int]
perm, [Monotonicity]
mon, Int
base_rank, Bool
contig) = Int -> MemReq -> (Space, [Int], [Monotonicity], Int, Bool)
arrayInfo (forall a. ArrayShape a => a -> Int
shapeRank ShapeBase (Ext SubExp)
shape) MemReq
req
       in forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ShapeBase (Ext SubExp)
shape' NoUniqueness
u forall b c a. (b -> c) -> (a -> b) -> a -> c
. Space -> Int -> ExtIxFun -> MemReturn
ReturnsNewBlock Space
space Int
ctx_offset forall a b. (a -> b) -> a -> b
$
            Ext SubExp -> TPrimExp Int64 (Ext VName)
convert
              forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall a.
Int -> [(Int, Monotonicity)] -> Bool -> Int -> IxFun (Ext a)
IxFun.mkExistential Int
base_rank (forall a b. [a] -> [b] -> [(a, b)]
zip [Int]
perm [Monotonicity]
mon) Bool
contig (Int
ctx_offset forall a. Num a => a -> a -> a
+ Int
1)
    inspect Int
_ (MemAcc VName
acc Shape
ispace [Type]
ts NoUniqueness
u) = forall d u ret. VName -> Shape -> [Type] -> u -> MemInfo d u ret
MemAcc VName
acc Shape
ispace [Type]
ts NoUniqueness
u
    inspect Int
_ (MemPrim PrimType
pt) = forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
pt
    inspect Int
_ (MemMem Space
space) = forall d u ret. Space -> MemInfo d u ret
MemMem Space
space

    convert :: Ext SubExp -> TPrimExp Int64 (Ext VName)
convert (Ext Int
i) = forall a. a -> TPrimExp Int64 a
le64 (forall a. Int -> Ext a
Ext Int
i)
    convert (Free SubExp
v) = forall a. a -> Ext a
Free forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> TPrimExp Int64 VName
pe64 SubExp
v

    adjustExt :: Int -> Ext a -> Ext a
    adjustExt :: forall a. Int -> Ext a -> Ext a
adjustExt Int
_ (Free a
v) = forall a. a -> Ext a
Free a
v
    adjustExt Int
k (Ext Int
i) = forall a. Int -> Ext a
Ext (Int
k forall a. Num a => a -> a -> a
+ Int
i)

-- | Rewrite the result of one branch body of a @Match@ so that it
-- agrees with the memory requirements @reqs@ that were computed for
-- all branches jointly.
--
-- The statements of the body are bound first.  Each result is then
-- paired with its requirement:
--
-- * a variable result whose requirement is 'NeedsLinearisation' is
--   passed through 'ensureRowMajorArray' and replaced by the array it
--   returns (keeping the original certificates);
--
-- * every other result is kept unchanged.
--
-- Finally, for every array result we prepend, as context, the memory
-- block it lives in followed by the scalar components of its index
-- function (each bound to a fresh @ixfun_ext@ variable).
addCtxToMatchBody ::
  (Allocable fromrep torep inner) =>
  [MemReqType] ->
  Body torep ->
  AllocM fromrep torep (Body torep)
addCtxToMatchBody reqs body = buildBody_ $ do
  res <- zipWithM linearIfNeeded reqs =<< bodyBind body
  ctx <- concat <$> mapM resCtx res
  pure $ ctx ++ res
  where
    linearIfNeeded (MemArray _ _ _ (NeedsLinearisation space)) (SubExpRes cs (Var v)) =
      SubExpRes cs . Var . snd <$> ensureRowMajorArray (Just space) v
    linearIfNeeded _ res =
      pure res

    -- The context contributed by a single result: nothing for
    -- constants, scalars, accumulators and memory blocks; the memory
    -- block plus the index function components for arrays.
    resCtx (SubExpRes _ Constant {}) =
      pure []
    resCtx (SubExpRes _ (Var v)) = do
      info <- lookupMemInfo v
      case info of
        MemPrim {} -> pure []
        MemAcc {} -> pure []
        MemMem {} -> pure [] -- should not happen
        MemArray _ _ _ (ArrayIn mem ixfun) -> do
          ixfun_exts <- mapM (letSubExp "ixfun_ext" <=< toExp) $ toList ixfun
          pure $ subExpRes (Var mem) : subExpsRes ixfun_exts

-- Do a simple form of invariance analysis to simplify a Match.  It
-- is unfortunate that we have to do it here, but functions such as
-- scalarRes will look carefully at the index functions before the
-- simplifier has a chance to run.  In a perfect world we would
-- simplify away those copies afterwards. XXX: this should be fixed by
-- a more general copy-removal pass. See
-- Futhark.Optimise.EntryPointMem for a very specialised version of
-- the idea, but which could perhaps be generalised.
-- | Returns the cases, the default body and the branch types with every
-- branch-invariant result removed.  A result is invariant when it does
-- not refer to any name bound inside the branches and every branch
-- returns the same 'SubExp' for it.  The removed results are
-- substituted into the remaining branch types with 'fixExt'.
simplifyMatch ::
  (Mem rep inner) =>
  [Case (Body rep)] ->
  Body rep ->
  [BranchTypeMem] ->
  ( [Case (Body rep)],
    Body rep,
    [BranchTypeMem]
  )
simplifyMatch cases defbody ts =
  let case_reses = map (bodyResult . caseBody) cases
      defbody_res = bodyResult defbody
      (ctx_fixes, variant) =
        partitionEithers . map branchInvariant $
          zip4 [0 ..] (transposeN (length defbody_res) case_reses) defbody_res ts
      (cases_reses, defbody_reses, ts') = unzip3 variant
   in ( zipWith onCase cases (transposeN (length cases) cases_reses),
        onBody defbody defbody_reses,
        foldr (uncurry fixExt) ts' ctx_fixes
      )
  where
    -- Every name bound by a statement in any case or the default body.
    bound_in_branches =
      namesFromList . concatMap (patNames . stmPat) $
        foldMap (bodyStms . caseBody) cases <> bodyStms defbody

    onCase c res = fmap (`onBody` res) c
    onBody body res = body {bodyResult = res}

    -- Transpose a matrix that is known to have @n@ columns.  Plain
    -- 'transpose' returns @[]@ for an empty matrix, which would make
    -- 'zip4' discard every result of a Match with no cases, and make
    -- 'zipWith onCase' discard every case when no result is variant.
    -- For a non-empty matrix this is exactly 'transpose'.
    transposeN n [] = replicate n []
    transposeN _ xss = transpose xss

    branchInvariant (i, case_reses, defres, t)
      -- If even one branch has a variant result, then we give up.
      | namesIntersect bound_in_branches $ freeIn $ defres : case_reses =
          Right (case_reses, defres, t)
      -- Do all branches return the same value?
      | all ((== resSubExp defres) . resSubExp) case_reses =
          Left (i, resSubExp defres)
      | otherwise =
          Right (case_reses, defres, t)

-- | Annotate an expression with memory information.
--
-- Loops, function applications, @Match@ and @WithAcc@ are handled by
-- dedicated equations; every other expression goes through 'mapExpM',
-- with operations delegated to the 'allocInOp' handler found in the
-- environment.  Bodies, parameters and return types reaching the
-- generic mapper are considered an internal error.
allocInExp ::
  (Allocable fromrep torep inner) =>
  Exp fromrep ->
  AllocM fromrep torep (Exp torep)
-- Loops: merge parameters are annotated by 'allocInMergeParams', and the
-- body result is replaced by the context and values that 'mk_loop_val'
-- produces from the original result values.
allocInExp (DoLoop merge form (Body () bodystms bodyres)) =
  allocInMergeParams merge $ \merge' mk_loop_val -> do
    form' <- allocInLoopForm form
    localScope (scopeOf form') $ do
      body' <-
        buildBody_ . allocInStms bodystms $ do
          (valctx, valres') <- mk_loop_val $ map resSubExp bodyres
          -- Value results keep the certificates of the original results.
          pure $
            subExpsRes valctx
              <> zipWith SubExpRes (map resCerts bodyres) valres'
      pure $ DoLoop merge' form' body'
-- Function calls: one 'MemMem' return in 'DefaultSpace' is added per
-- array-typed return value of the callee.
allocInExp (Apply fname args rettype loc) = do
  args' <- funcallArgs args
  -- We assume that every array is going to be in its own memory.
  pure $ Apply fname args' (mems ++ memoryInDeclExtType num_arrays rettype) loc
  where
    mems = replicate num_arrays (MemMem DefaultSpace)
    num_arrays = length $ filter ((> 0) . arrayRank . declExtTypeOf) rettype
-- Match: allocate in every branch, combine the per-branch memory
-- requirements, make every branch return the required context, and
-- finally strip branch-invariant results with 'simplifyMatch'.
allocInExp (Match ses cases defbody (MatchDec rets ifsort)) = do
  (defbody', def_reqs) <- allocInMatchBody rets defbody
  (cases', cases_reqs) <- unzip <$> mapM onCase cases
  -- Combine the requirements position-wise, starting from those of the
  -- default body.  Unlike 'zipWith' over 'transpose cases_reqs', this
  -- keeps the default body's requirements when there are no cases
  -- (where 'transpose' would yield [] and truncate every result).
  let reqs = foldl' (zipWith combMemReqTypes) def_reqs cases_reqs
  defbody'' <- addCtxToMatchBody reqs defbody'
  cases'' <- mapM (traverse $ addCtxToMatchBody reqs) cases'
  let (cases''', defbody''', rets') =
        simplifyMatch cases'' defbody'' $ mkBranchRet reqs
  pure $ Match ses cases''' defbody''' $ MatchDec rets' ifsort
  where
    onCase (Case vs body) = first (Case vs) <$> allocInMatchBody rets body
-- WithAcc: annotate the accumulator inputs' operators and the body
-- lambda, whose parameters may only be unit or accumulators.
allocInExp (WithAcc inputs bodylam) =
  WithAcc <$> mapM onInput inputs <*> onLambda bodylam
  where
    onLambda lam = do
      params <- forM (lambdaParams lam) $ \(Param attrs pv t) ->
        case t of
          Prim Unit -> pure $ Param attrs pv $ MemPrim Unit
          Acc acc ispace ts u -> pure $ Param attrs pv $ MemAcc acc ispace ts u
          _ -> error $ "Unexpected WithAcc lambda param: " ++ prettyString (Param attrs pv t)
      allocInLambda params (lambdaBody lam)

    onInput (shape, arrs, op) =
      (shape,arrs,) <$> traverse (onOp shape arrs) op

    -- The operator's parameters are split into index parameters (one
    -- per dimension of the accumulator shape, annotated as int64),
    -- followed by x and y parameters (one of each per lambda result).
    onOp accshape arrs (lam, nes) = do
      let num_vs = length (lambdaReturnType lam)
          num_is = shapeRank accshape
          (i_params, x_params, y_params) =
            splitAt3 num_is num_vs $ lambdaParams lam
          i_params' = map (\(Param attrs v _) -> Param attrs v $ MemPrim int64) i_params
          is = map (DimFix . Var . paramName) i_params'
      x_params' <- zipWithM (onXParam is) x_params arrs
      y_params' <- zipWithM (onYParam is) y_params arrs
      lam' <-
        allocInLambda
          (i_params' <> x_params' <> y_params')
          (lambdaBody lam)
      pure (lam', nes)

    -- An array parameter located in @mem@, with @ixfun@ sliced by the
    -- index parameters followed by the full extent of @shape@.
    mkP attrs p pt shape u mem ixfun is =
      Param attrs p . MemArray pt shape u . ArrayIn mem . IxFun.slice ixfun $
        fmap pe64 $
          Slice $
            is ++ map sliceDim (shapeDims shape)

    -- x parameters reuse the memory and index function of the array.
    onXParam _ (Param attrs p (Prim t)) _ =
      pure $ Param attrs p (MemPrim t)
    onXParam is (Param attrs p (Array pt shape u)) arr = do
      (mem, ixfun) <- lookupArraySummary arr
      pure $ mkP attrs p pt shape u mem ixfun is
    onXParam _ p _ =
      error $ "Cannot handle MkAcc param: " ++ prettyString p

    -- y parameters get a fresh allocation with an iota index function
    -- over the dimensions of the array.
    onYParam _ (Param attrs p (Prim t)) _ =
      pure $ Param attrs p $ MemPrim t
    onYParam is (Param attrs p (Array pt shape u)) arr = do
      arr_t <- lookupType arr
      mem <- allocForArray arr_t DefaultSpace
      let base_dims = map pe64 $ arrayDims arr_t
          ixfun = IxFun.iota base_dims
      pure $ mkP attrs p pt shape u mem ixfun is
    onYParam _ p _ =
      error $ "Cannot handle MkAcc param: " ++ prettyString p
allocInExp e = mapExpM alloc e
  where
    alloc =
      identityMapper
        { mapOnBody = error "Unhandled Body in ExplicitAllocations",
          mapOnRetType = error "Unhandled RetType in ExplicitAllocations",
          mapOnBranchType = error "Unhandled BranchType in ExplicitAllocations",
          mapOnFParam = error "Unhandled FParam in ExplicitAllocations",
          mapOnLParam = error "Unhandled LParam in ExplicitAllocations",
          mapOnOp = \op -> do
            handle <- asks allocInOp
            handle op
        }

-- | Annotate a loop form with memory information.
--
-- A while-loop form is unchanged.  For a for-loop, each loop variable
-- @p@ iterating over array @a@ is annotated according to its type:
-- an array parameter lives in the memory of @a@, with @a@'s index
-- function sliced at the loop counter @i@ in the outermost dimension;
-- scalar, memory and accumulator parameters get the corresponding
-- non-array annotation.
allocInLoopForm ::
  (Allocable fromrep torep inner) =>
  LoopForm fromrep ->
  AllocM fromrep torep (LoopForm torep)
allocInLoopForm (WhileLoop v) = pure $ WhileLoop v
allocInLoopForm (ForLoop i it n loopvars) =
  ForLoop i it n <$> mapM allocInLoopVar loopvars
  where
    allocInLoopVar (p, a) = do
      (mem, ixfun) <- lookupArraySummary a
      case paramType p of
        Array pt shape u -> do
          dims <- map pe64 . arrayDims <$> lookupType a
          let ixfun' = IxFun.slice ixfun $ fullSliceNum dims [DimFix $ le64 i]
          pure (p {paramDec = MemArray pt shape u $ ArrayIn mem ixfun'}, a)
        Prim bt ->
          pure (p {paramDec = MemPrim bt}, a)
        Mem space ->
          pure (p {paramDec = MemMem space}, a)
        Acc acc ispace ts u ->
          pure (p {paramDec = MemAcc acc ispace ts u}, a)

-- | Representation-specific operations that can report whether they are
-- constant.  'stmConsts' uses this to collect the names bound by such
-- operations.
class SizeSubst op where
  -- | Does this operation count as constant?  Unless an instance says
  -- otherwise, no operation does.
  opIsConst :: op -> Bool
  opIsConst _ = False

-- | The trivial operation type relies on the default: never constant.
instance SizeSubst ()

-- | A memory operation is constant exactly when it wraps an inner
-- operation that is itself constant.  Every other constructor (such as
-- 'Alloc') is treated as non-constant.
instance SizeSubst op => SizeSubst (MemOp op) where
  opIsConst memOp =
    case memOp of
      Inner op -> opIsConst op
      _ -> False

-- | The names bound by a statement whose expression is an 'Op' for which
-- 'opIsConst' holds.  Every other statement contributes the empty set.
stmConsts :: SizeSubst (Op rep) => Stm rep -> S.Set VName
stmConsts stm =
  case stm of
    Let pat _ (Op op) | opIsConst op -> S.fromList (patNames pat)
    _ -> S.empty

-- | Build (but do not bind) a 'Let' statement that binds @names@ to @e@.
--
-- The pattern, including its memory annotations, comes from
-- 'patWithAllocations' in 'DefaultSpace', with 'NoHint' for every name.
-- The given decoration becomes the statement's auxiliary information via
-- 'defAux'.  The statement is returned to the caller; it is not passed to
-- 'letBind' here.
mkLetNamesB' ::
  ( LetDec (Rep m) ~ LetDecMem,
    Mem (Rep m) inner,
    MonadBuilder m,
    ExpDec (Rep m) ~ ()
  ) =>
  ExpDec (Rep m) ->
  [VName] ->
  Exp (Rep m) ->
  m (Stm (Rep m))
mkLetNamesB' dec names e =
  mkStm <$> patWithAllocations DefaultSpace names e (NoHint <$ names)
  where
    -- Wrap the allocated pattern into the final statement.
    mkStm pat = Let pat (defAux dec) e

-- | Like 'mkLetNamesB'', but for a simplifier ('Engine.Wise')
-- representation.
--
-- The pattern from 'patWithAllocations' (in 'DefaultSpace', with
-- 'NoHint' for every name) is annotated with wisdom via
-- 'Engine.addWisdomToPat'.  The expression decoration is derived from
-- that annotated pattern via 'Engine.mkWiseExpDec'.  The statement is
-- returned, not bound.
mkLetNamesB'' ::
  ( Mem rep inner,
    LetDec rep ~ LetDecMem,
    OpReturns (Engine.OpWithWisdom inner),
    ExpDec rep ~ (),
    Rep m ~ Engine.Wise rep,
    HasScope (Engine.Wise rep) m,
    MonadBuilder m,
    Engine.CanBeWise inner
  ) =>
  [VName] ->
  Exp (Engine.Wise rep) ->
  m (Stm (Engine.Wise rep))
mkLetNamesB'' names e = do
  pat <- patWithAllocations DefaultSpace names e (NoHint <$ names)
  let wisePat = Engine.addWisdomToPat pat e
  pure $ Let wisePat (defAux $ Engine.mkWiseExpDec wisePat () e) e

-- | Simplifier operations for a memory-annotated representation.
--
-- The caller supplies how to compute usage for, and how to simplify, the
-- representation-specific inner operations.  'Alloc' is handled here
-- directly.
simplifiable ::
  ( Engine.SimplifiableRep rep,
    LetDec rep ~ LetDecMem,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    OpReturns (Engine.OpWithWisdom inner),
    AliasedOp (Engine.OpWithWisdom inner),
    IndexOp (Engine.OpWithWisdom inner),
    Mem rep inner
  ) =>
  -- | Usage information for an inner operation.
  (Engine.OpWithWisdom inner -> UT.UsageTable) ->
  -- | Simplification of an inner operation, also returning any hoisted
  -- statements.
  (Engine.OpWithWisdom inner -> Engine.SimpleM rep (Engine.OpWithWisdom inner, Stms (Engine.Wise rep))) ->
  SimpleOps rep
simplifiable innerUsage simplifyInnerOp =
  SimpleOps wiseExpDec wiseBody guardAlloc usageOf refinePat simplifyMemOp
  where
    -- Expression decorations are () here, so only wisdom is attached.
    wiseExpDec _ pat e = pure $ Engine.mkWiseExpDec pat () e

    -- Body decorations are () here as well.
    wiseBody _ stms res = pure $ mkWiseBody () stms res

    -- Protection for an 'Alloc' moved out of a guarded context (the
    -- "hoisted_alloc_size" name suggests hoisting).  The requested size
    -- becomes @sz@ when @takenFlag@ is true and 0 otherwise, presumably so
    -- that memory is only requested on the path where the original
    -- allocation ran.  Other operations are not protected.
    guardAlloc takenFlag pat (Alloc sz space) = Just $ do
      ifTaken <- resultBodyM [sz]
      ifSkipped <- resultBodyM [intConst Int64 0]
      let alts = [Case [Just $ BoolValue True] ifTaken]
          branchDec = MatchDec [MemPrim int64] MatchFallback
      sz' <-
        letSubExp "hoisted_alloc_size" $
          Match [takenFlag] alts ifSkipped branchDec
      letBind pat $ Op $ Alloc sz' space
    guardAlloc _ _ _ = Nothing

    -- An allocation whose size is a variable records that variable as a
    -- size usage.  Constant sizes contribute nothing.
    usageOf (Inner op) = innerUsage op
    usageOf (Alloc (Var sz) _) = UT.sizeUsage sz
    usageOf (Alloc _ _) = mempty

    -- Allocations only need their size simplified and hoist nothing.
    -- Inner operations are delegated to the supplied simplifier.
    simplifyMemOp (Alloc sz space) = do
      sz' <- Engine.simplify sz
      pure (Alloc sz' space, mempty)
    simplifyMemOp (Inner op) = first Inner <$> simplifyInnerOp op

    -- Refresh pattern elements using the expression's own return
    -- information.  An array element adopts the returned index function
    -- when the expression returns into a block and every existential in
    -- that index function resolves against this pattern's names.  Its
    -- shape and memory block are simplified at the same time.  All other
    -- elements are simply simplified.
    refinePat (Pat pes) e = do
      rets <- expReturns e
      Pat <$> zipWithM refineElem pes rets
      where
        boundNames = map patElemName pes
        -- Existential positions refer to names bound by this pattern.
        resolve (Ext i) = maybeNth i boundNames
        resolve (Free v) = Just v
        refineElem
          (PatElem v (MemArray pt shape u (ArrayIn mem _)))
          (MemArray _ _ _ (Just (ReturnsInBlock _ extIxFun)))
            | Just ixfun' <- traverse (traverse resolve) extIxFun = do
                shape' <- Engine.simplify shape
                mem' <- Engine.simplify mem
                pure $ PatElem v $ MemArray pt shape' u $ ArrayIn mem' ixfun'
        refineElem pe _ = traverse Engine.simplify pe

-- | Per-result hints passed to 'patWithAllocations' when building a
-- pattern for an expression.
data ExpHint
  = -- | No suggestion for this result.
    NoHint
  | -- | Suggests an index function and memory space (presumably for the
    -- corresponding result's memory; confirm against the consumers).
    Hint IxFun Space

-- | The hint-free default: one 'NoHint' for each result of the expression,
-- as counted by 'expExtTypeSize'.
defaultExpHints :: (Monad m, ASTRep rep) => Exp rep -> m [ExpHint]
defaultExpHints e =
  let resultCount = expExtTypeSize e
   in pure (replicate resultCount NoHint)