{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | A generic transformation for adding memory allocations to a
-- Futhark program.  Specialised by specific representations in
-- submodules.
module Futhark.Pass.ExplicitAllocations
  ( explicitAllocationsGeneric,
    explicitAllocationsInStmsGeneric,
    ExpHint (..),
    defaultExpHints,
    askDefaultSpace,
    Allocable,
    AllocM,
    AllocEnv (..),
    SizeSubst (..),
    allocInStms,
    allocForArray,
    simplifiable,
    arraySizeInBytesExp,
    mkLetNamesB',
    mkLetNamesB'',

    -- * Module re-exports

    --
    -- These are highly likely to be needed by any downstream
    -- users.
    module Control.Monad.Reader,
    module Futhark.MonadFreshNames,
    module Futhark.Pass,
    module Futhark.Tools,
  )
where

import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Either (partitionEithers)
import Data.Foldable (toList)
import Data.List (foldl', transpose, zip4)
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.SymbolTable (IndexOp)
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Prop.Aliases (AliasedOp)
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep (mkWiseBody)
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (maybeNth, splitAt3)

-- | The constraints a pair of representations must satisfy for this
-- pass to translate programs in @fromrep@ into programs in @torep@
-- with explicit memory annotations.  The @inner@ parameter is the
-- operation type that 'Mem' associates with @torep@.
type Allocable fromrep torep inner =
  ( -- Requirements on the source representation.
    PrettyRep fromrep,
    FParamInfo fromrep ~ DeclType,
    LParamInfo fromrep ~ Type,
    BranchType fromrep ~ ExtType,
    RetType fromrep ~ DeclExtType,
    BodyDec fromrep ~ (),
    -- Requirements on the target representation.
    PrettyRep torep,
    Mem torep inner,
    BuilderOps torep,
    LetDec torep ~ LetDecMem,
    BodyDec torep ~ (),
    ExpDec torep ~ (),
    SizeSubst (inner torep)
  )

-- | The read-only environment available inside 'AllocM'.
data AllocEnv fromrep torep = AllocEnv
  { -- | Aggressively try to reuse memory in do-loops -
    -- should be True inside kernels, False outside.
    aggressiveReuse :: Bool,
    -- | When allocating memory, put it in this memory space.
    -- This is primarily used to ensure that group-wide
    -- statements store their results in local memory.
    allocSpace :: Space,
    -- | The set of names that are known to be constants at
    -- kernel compile time.
    envConsts :: S.Set VName,
    -- | Translate an operation of the source representation into an
    -- operation of the target representation, running in 'AllocM'.
    allocInOp :: Op fromrep -> AllocM fromrep torep (Op torep),
    -- | Compute memory hints for the results of an expression.  The
    -- hints are consumed positionally, alongside the pattern elements
    -- that bind those results.
    envExpHints :: Exp torep -> AllocM fromrep torep [ExpHint]
  }

-- | Monad for adding allocations to an entire program.
--
-- This is a 'BuilderT' for the target representation, layered over a
-- read-only 'AllocEnv' and a 'State' holding the name source.  Every
-- instance below is obtained by newtype deriving from that stack.
newtype AllocM fromrep torep a
  = AllocM (BuilderT torep (ReaderT (AllocEnv fromrep torep) (State VNameSource)) a)
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadFreshNames,
      HasScope torep,
      LocalScope torep,
      MonadReader (AllocEnv fromrep torep)
    )

-- | Building statements in 'AllocM' attaches memory information to
-- every pattern: 'mkLetNamesM' goes through 'patWithAllocations', which
-- may itself emit allocation statements before the binding.
instance (Allocable fromrep torep inner) => MonadBuilder (AllocM fromrep torep) where
  type Rep (AllocM fromrep torep) = torep

  -- The target representation has no expression decorations.
  mkExpDecM _ _ = pure ()

  mkLetNamesM names e = do
    def_space <- askDefaultSpace
    hints <- expHints e
    pat <- patWithAllocations def_space names e hints
    pure $ Let pat (defAux ()) e

  mkBodyM stms res = pure $ Body () stms res

  -- Statement accumulation is delegated to the underlying 'BuilderT'.
  addStms = AllocM . addStms
  collectStms (AllocM m) = AllocM $ collectStms m

-- | Ask the environment's 'envExpHints' function for the memory hints
-- of an expression.
expHints :: Exp torep -> AllocM fromrep torep [ExpHint]
expHints e = do
  f <- asks envExpHints
  f e

-- | The space in which we allocate memory if we have no other
-- preferences or constraints.  This is the 'allocSpace' field of the
-- current environment.
askDefaultSpace :: AllocM fromrep torep Space
askDefaultSpace = asks allocSpace

-- | Run an 'AllocM' computation in any monad with a name source.
--
-- The environment uses the given default space, operation handler and
-- hint function, with 'aggressiveReuse' off and no known constants.
-- The builder starts from an empty scope.
--
-- Statements added at the outermost level (i.e. not captured by
-- 'collectStms') are discarded; only the result value is returned.
runAllocM ::
  MonadFreshNames m =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  AllocM fromrep torep a ->
  m a
runAllocM space handleOp hints (AllocM m) =
  -- 'fst' drops the top-level statements produced by 'runBuilderT'.
  fmap fst $ modifyNameSource $ runState $ runReaderT (runBuilderT m mempty) env
  where
    env =
      AllocEnv
        { aggressiveReuse = False,
          allocSpace = space,
          envConsts = mempty,
          allocInOp = handleOp,
          envExpHints = hints
        }

-- | The size in bytes of the primitive element type of a type.
elemSize :: Num a => Type -> a
elemSize = primByteSize . elemType

-- | The size in bytes of an array of the given type, computed as the
-- element size multiplied by every dimension.  Unlike
-- 'arraySizeInBytesExpM', the result is not clamped at zero.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  untyped $ foldl' (*) (elemSize t) $ map pe64 (arrayDims t)

-- | The size in bytes of an array of the given type: the product of
-- its dimensions times its element size, clamped from below at zero
-- with a signed 64-bit maximum.
arraySizeInBytesExpM :: MonadBuilder m => Type -> m (PrimExp VName)
arraySizeInBytesExpM t = do
  let dim_prod_i64 = product $ map pe64 (arrayDims t)
      elm_size_i64 = elemSize t
  pure $
    BinOpExp (SMax Int64) (ValueExp $ IntValue $ Int64Value 0) $
      untyped $
        dim_prod_i64 * elm_size_i64

-- | Bind the byte size of an array of the given type (as computed by
-- 'arraySizeInBytesExpM') to a new variable named @bytes@.
arraySizeInBytes :: MonadBuilder m => Type -> m SubExp
arraySizeInBytes = letSubExp "bytes" <=< toExp <=< arraySizeInBytesExpM

-- | Emit an 'Alloc' of exactly enough bytes for an array of the given
-- type in the given space, and return the name of the new memory
-- block (bound as @mem@).
allocForArray' ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Type ->
  Space ->
  m VName
allocForArray' t space = do
  size <- arraySizeInBytes t
  letExp "mem" $ Op $ Alloc size space

-- | Allocate memory for a value of the given type in the given space,
-- returning the name of the memory block.  This is 'allocForArray''
-- specialised to 'AllocM'.
allocForArray ::
  Allocable fromrep torep inner =>
  Type ->
  Space ->
  AllocM fromrep torep VName
allocForArray t space = allocForArray' t space

-- | Construct a statement binding the given identifiers to an
-- expression, computing memory information for each pattern element
-- (and emitting any needed allocations) via 'allocsForPat'.
allocsForStm ::
  (Allocable fromrep torep inner) =>
  [Ident] ->
  Exp torep ->
  AllocM fromrep torep (Stm torep)
allocsForStm idents e = do
  def_space <- askDefaultSpace
  hints <- expHints e
  rts <- expReturns e
  pes <- allocsForPat def_space idents rts hints
  dec <- mkExpDecM (Pat pes) e
  pure $ Let (Pat pes) (defAux dec) e

-- | Build a memory-annotated pattern for binding the given names to an
-- expression.  The names' types are obtained by instantiating the
-- expression's existential result shapes with the names themselves.
patWithAllocations ::
  (MonadBuilder m, Mem (Rep m) inner) =>
  Space ->
  [VName] ->
  Exp (Rep m) ->
  [ExpHint] ->
  m (Pat LetDecMem)
patWithAllocations def_space names e hints = do
  ts' <- instantiateShapes' names <$> expExtType e
  let idents = zipWith Ident names ts'
  rts <- expReturns e
  Pat <$> allocsForPat def_space idents rts hints

-- | Line up identifiers with return descriptions from the right.  When
-- there are fewer identifiers than returns, the leading returns get
-- fresh identifiers:
--
-- * @ext_mem@ of memory type, for 'MemMem' returns;
-- * @ext@ of type @i64@, otherwise.
mkMissingIdents :: MonadFreshNames m => [Ident] -> [ExpReturns] -> m [Ident]
mkMissingIdents idents rts =
  -- Working on reversed lists aligns the given identifiers with the
  -- trailing returns; the infinite 'Nothing' tail covers the rest.
  reverse <$> zipWithM f (reverse rts) (map Just (reverse idents) ++ repeat Nothing)
  where
    f _ (Just ident) = pure ident
    f (MemMem space) Nothing = newIdent "ext_mem" $ Mem space
    f _ Nothing = newIdent "ext" $ Prim int64

-- | Compute a memory-annotated pattern element for every result of an
-- expression, given its return descriptions and allocation hints.
--
-- Scalars, and arrays whose shape is fully known but whose memory is
-- not specified, get their summary from 'summaryForBindage' (which may
-- allocate).  Arrays returned in existing or new memory blocks have
-- their existential index functions instantiated using the pattern
-- identifiers.  Any other combination is an internal error.
allocsForPat ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Space ->
  [Ident] ->
  [ExpReturns] ->
  [ExpHint] ->
  m [PatElem LetDecMem]
allocsForPat def_space some_idents rts hints = do
  idents <- mkMissingIdents some_idents rts

  forM (zip3 idents rts hints) $ \(ident, rt, hint) -> do
    let ident_shape = arrayShape $ identType ident
    case rt of
      MemPrim _ -> do
        summary <- summaryForBindage def_space (identType ident) hint
        pure $ PatElem (identName ident) summary
      MemMem space ->
        pure $ PatElem (identName ident) $ MemMem space
      MemArray bt _ u (Just (ReturnsInBlock mem extixfun)) -> do
        let ixfn = instantiateExtIxFun idents extixfun
        pure . PatElem (identName ident) . MemArray bt ident_shape u $ ArrayIn mem ixfn
      MemArray _ extshape _ Nothing
        | Just _ <- knownShape extshape -> do
            summary <- summaryForBindage def_space (identType ident) hint
            pure $ PatElem (identName ident) summary
      MemArray bt _ u (Just (ReturnsNewBlock _ i extixfn)) -> do
        let ixfn = instantiateExtIxFun idents extixfn
        pure . PatElem (identName ident) . MemArray bt ident_shape u $
          -- The memory block is the i'th identifier of the pattern.
          ArrayIn (getIdent idents i) ixfn
      MemAcc acc ispace ts u ->
        pure $ PatElem (identName ident) $ MemAcc acc ispace ts u
      _ -> error "Impossible case reached in allocsForPat!"
  where
    -- A shape is known when it contains no existential dimensions.
    knownShape = mapM known . shapeDims
    known (Free v) = Just v
    known Ext {} = Nothing

    getIdent idents i =
      case maybeNth i idents of
        Just ident -> identName ident
        Nothing ->
          error $
            "getIdent: Ext "
              <> show i
              <> " but pattern has "
              <> show (length idents)
              <> " elements: "
              <> prettyString idents

    -- Replace existential variables inside an index function by the
    -- corresponding pattern identifiers.
    instantiateExtIxFun idents = fmap $ fmap inst
      where
        inst (Free v) = v
        inst (Ext i) = getIdent idents i
i

-- | Turn an existential index function into an ordinary one.  Only
-- index functions without existential ('Ext') variables are supported;
-- encountering one is a runtime error.
instantiateIxFun :: Monad m => ExtIxFun -> m IxFun
instantiateIxFun = traverse $ traverse inst
  where
    inst Ext {} = error "instantiateIxFun: not yet"
    inst (Free x) = pure x

-- | Compute the memory summary for binding a value of the given type.
--
-- Scalars, memory blocks and accumulators need no allocation.  For
-- arrays without a hint, a fresh block sized for the whole array is
-- allocated in the default space with an 'IxFun.iota' index function.
-- For arrays with a 'Hint', a block of @product (IxFun.base ixfun) *
-- element size@ bytes is allocated in the hinted space and the hinted
-- index function is used.
summaryForBindage ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Space ->
  Type ->
  ExpHint ->
  m (MemBound NoUniqueness)
summaryForBindage _ (Prim bt) _ =
  pure $ MemPrim bt
summaryForBindage _ (Mem space) _ =
  pure $ MemMem space
summaryForBindage _ (Acc acc ispace ts u) _ =
  pure $ MemAcc acc ispace ts u
summaryForBindage def_space t@(Array pt shape u) NoHint = do
  m <- allocForArray' t def_space
  pure $ MemArray pt shape u $ ArrayIn m $ IxFun.iota $ map pe64 $ arrayDims t
summaryForBindage _ t@(Array pt _ _) (Hint ixfun space) = do
  bytes <-
    letSubExp "bytes" <=< toExp . untyped $
      product
        [ product $ IxFun.base ixfun,
          fromIntegral (primByteSize pt :: Int64)
        ]
  m <- letExp "mem" $ Op $ Alloc bytes space
  pure $ MemArray pt (arrayShape t) NoUniqueness $ ArrayIn m ixfun

-- | Give memory annotations to a list of function parameters and run a
-- continuation with the rewritten parameters in scope.
--
-- Each input parameter is paired with the memory space its array (if it
-- is one) should live in.  The list handed to the continuation is the
-- concatenation of the memory block parameters, then any context
-- parameters, then the rewritten value parameters.
allocInFParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, Space)] ->
  ([FParam torep] -> AllocM fromrep torep a) ->
  AllocM fromrep torep a
allocInFParams params cont = do
  (valparams, (memparams, ctxparams)) <-
    runWriterT $ traverse (\(p, space) -> allocInFParam p space) params
  let allparams = concat [memparams, ctxparams, valparams]
  localScope (scopeOfFParams allparams) (cont allparams)

-- | Give a single function parameter a memory-annotated type.
--
-- For an array parameter, a fresh memory block parameter is created in
-- the given space.  It is named after the array with a @_mem@ suffix and
-- emitted through the first component of the writer output.  The array
-- itself is given an iota index function over its shape in that block.
-- Scalar, memory and accumulator parameters are translated directly and
-- emit nothing.
allocInFParam ::
  (Allocable fromrep torep inner) =>
  FParam fromrep ->
  Space ->
  WriterT
    ([FParam torep], [FParam torep])
    (AllocM fromrep torep)
    (FParam torep)
allocInFParam param pspace =
  (\dec -> param {paramDec = dec}) <$> memInfoFor (paramDeclType param)
  where
    memInfoFor (Array pt shape u) = do
      mem <- lift $ newVName $ baseString (paramName param) <> "_mem"
      tell ([Param (paramAttrs param) mem $ MemMem pspace], [])
      pure $ MemArray pt shape u $ ArrayIn mem $ IxFun.iota $ map pe64 $ shapeDims shape
    memInfoFor (Prim pt) = pure $ MemPrim pt
    memInfoFor (Mem space) = pure $ MemMem space
    memInfoFor (Acc acc ispace ts u) = pure $ MemAcc acc ispace ts u

-- | Make sure the array @v@ has a row-major layout.
--
-- Row-major here means all of the following hold:
--
-- * a single LMAD;
-- * an identity permutation;
-- * a base with exactly one dimension per rank;
-- * contiguity.
--
-- If a space is requested, the array must also already live in that
-- space.  Returns the memory block and the array name.  When any check
-- fails, @v@ is copied into a fresh linear array in the requested space,
-- or the default space if none was requested.
ensureRowMajorArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureRowMajorArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let ixfun_rank = IxFun.rank ixfun
      usable_as_is =
        and
          [ numLMADs ixfun == 1,
            ixFunPerm ixfun == [0 .. ixfun_rank - 1],
            length (IxFun.base ixfun) == ixfun_rank,
            maybe True (== mem_space) space_ok,
            IxFun.contiguous ixfun
          ]
  if usable_as_is
    then pure (mem, v)
    else allocLinearArray (fromMaybe default_space space_ok) (baseString v) v

-- | Adapt a loop result so that it matches an array merge parameter
-- whose index function was existentialised.
--
-- The array is made row-major in the given space.  Every leaf of its
-- resulting index function is then bound to a fresh @ixfun_arg@ variable.
-- The memory block is emitted as the first writer component and the
-- index-function values as the second.  Constants cannot be arrays and
-- are rejected with 'error'.
ensureArrayIn ::
  (Allocable fromrep torep inner) =>
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn _ (Constant v) =
  error $ "ensureArrayIn: " ++ prettyString v ++ " cannot be an array."
ensureArrayIn space (Var v) = do
  (mem, arr, ixfun_args) <- lift $ do
    (mem', arr') <- ensureRowMajorArray (Just space) v
    (_, ixfun) <- lookupArraySummary arr'
    args <- forM (toList ixfun) $ \e -> letSubExp "ixfun_arg" =<< toExp e
    pure (mem', arr', args)
  tell ([Var mem], ixfun_args)
  pure $ Var arr

-- | Allocate memory for the merge parameters of a loop and run the
-- continuation with the rewritten parameters in scope.
--
-- The continuation receives two things:
--
-- * The new merge parameters, paired with their initial values.
--   Memory block parameters come first, then index-function context
--   parameters, then the memory-annotated original parameters.
--
-- * A function that gives a list of loop results (one per original
--   merge parameter) the same treatment.  It returns the extra memory
--   and context results separately from the value results.
allocInMergeParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, SubExp)] ->
  ( [(FParam torep, SubExp)] ->
    ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])) ->
    AllocM fromrep torep a
  ) ->
  AllocM fromrep torep a
allocInMergeParams merge m = do
  ((valparams, valargs, res_fixups), (mem_params, ctx_params)) <-
    runWriterT $ unzip3 <$> mapM allocInMergeParam merge

  let all_params = mem_params <> ctx_params <> valparams

      -- Apply each parameter's fixup to its corresponding result.  The
      -- memory and context results they emit are placed in front of
      -- the value results, mirroring the parameter order above.
      mk_loop_res ses = do
        (ses', (mem_res, ctx_res)) <- runWriterT $ zipWithM ($) res_fixups ses
        pure (mem_res <> ctx_res, ses')

  -- Initial values get exactly the same treatment as loop results.
  (extra_args, valargs') <- mk_loop_res valargs
  localScope (scopeOfFParams all_params) $
    m (zip all_params (extra_args <> valargs')) mk_loop_res
  where
    param_names = namesFromList $ map (paramName . fst) merge

    -- Does any of the given names refer to one of the merge parameters?
    anyIsLoopParam names = names `namesIntersect` param_names

    -- Result fixup for an array parameter kept in scalar space.
    scalarRes param_t v_mem_space v_ixfun (Var res) = do
      -- Try really hard to avoid copying needlessly, but the result
      -- _must_ be in ScalarSpace and have the right index function.
      (res_mem, res_ixfun) <- lift $ lookupArraySummary res
      res_mem_space <- lift $ lookupMemSpace res_mem
      let fits = res_mem_space == v_mem_space && res_ixfun == v_ixfun
      (res_mem', res') <-
        if fits
          then pure (res_mem, res)
          else lift $ arrayWithIxFun v_mem_space v_ixfun (fromDecl param_t) res
      tell ([Var res_mem'], [])
      pure $ Var res'
    scalarRes _ _ _ se = pure se

    allocInMergeParam ::
      (Allocable fromrep torep inner) =>
      (Param DeclType, SubExp) ->
      WriterT
        ([FParam torep], [FParam torep])
        (AllocM fromrep torep)
        ( FParam torep,
          SubExp,
          SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
        )
    allocInMergeParam (mergeparam, Var v)
      | param_t@(Array pt shape u) <- paramDeclType mergeparam = do
          (v_mem, v_ixfun) <- lift $ lookupArraySummary v
          v_mem_space <- lift $ lookupMemSpace v_mem
          let withMem mem ixfun =
                mergeparam {paramDec = MemArray pt shape u $ ArrayIn mem ixfun}

          -- Loop-invariant array parameters that are in scalar space
          -- are special - we do not wish to existentialise their index
          -- function at all (but the memory block is still existential).
          case v_mem_space of
            ScalarSpace {}
              | anyIsLoopParam (freeIn shape) -> do
                  -- Arrays with loop-variant shape cannot be in scalar
                  -- space, so copy them elsewhere and try again.
                  space <- lift askDefaultSpace
                  (_, v') <- lift $ allocLinearArray space (baseString v) v
                  allocInMergeParam (mergeparam, Var v')
              | otherwise -> do
                  p <- newParam "mem_param" $ MemMem v_mem_space
                  tell ([p], [])
                  pure
                    ( withMem (paramName p) v_ixfun,
                      Var v,
                      scalarRes param_t v_mem_space v_ixfun
                    )
            _ -> do
              (v_mem', v') <- lift $ ensureRowMajorArray Nothing v
              (_, v_ixfun') <- lift $ lookupArraySummary v'
              v_mem_space' <- lift $ lookupMemSpace v_mem'

              -- One fresh context parameter per leaf of the index function.
              ctx_ps <-
                replicateM (length v_ixfun') $
                  newParam "ctx_param_ext" (MemPrim int64)

              -- Existential slot i is replaced by context parameter i.
              let ext_subst =
                    M.fromList
                      [ (Ext i, le64 (Free (paramName p)))
                        | (i, p) <- zip [0 ..] ctx_ps
                      ]
              param_ixfun <-
                instantiateIxFun $
                  IxFun.substituteInIxFun ext_subst (IxFun.existentialize v_ixfun')

              mem_param <- newParam "mem_param" $ MemMem v_mem_space'
              tell ([mem_param], ctx_ps)
              pure
                ( withMem (paramName mem_param) param_ixfun,
                  Var v',
                  ensureArrayIn v_mem_space'
                )
    allocInMergeParam (mergeparam, se) = do
      space <- lift askDefaultSpace
      doDefault mergeparam se space

    -- Merge parameters that are not arrays, or whose initial value is a
    -- constant, are handled like function parameters in the given space.
    doDefault mergeparam se space = do
      mergeparam' <- allocInFParam mergeparam space
      pure (mergeparam', se, linearFuncallArg (paramType mergeparam) space)

-- | Copy the array @v@, whose type is @v_t@, into a freshly allocated
-- memory block in the given space, laid out by the given index function.
-- The copy is named after @v@ with a @_scalcopy@ suffix.  Returns the new
-- memory block and the name of the copy.
--
-- Calls 'error' if @v_t@ is not an array type.  This is checked up
-- front, instead of through a lazy irrefutable pattern that would fail
-- later with an uninformative pattern-match error.
arrayWithIxFun ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m), LetDec (Rep m) ~ LetDecMem) =>
  Space ->
  IxFun ->
  Type ->
  VName ->
  m (VName, VName)
arrayWithIxFun space ixfun v_t v =
  case v_t of
    Array pt shape u -> do
      mem <- allocForArray' v_t space
      v_copy <- newVName $ baseString v <> "_scalcopy"
      letBind (Pat [PatElem v_copy $ MemArray pt shape u $ ArrayIn mem ixfun]) $
        BasicOp (Copy v)
      pure (mem, v_copy)
    _ ->
      error $ "arrayWithIxFun: not an array type: " ++ prettyString v_t

-- | Make sure the array @v@ has a direct index function ('IxFun.isDirect').
-- If a space is requested, the array must also already live in that space.
--
-- Returns the memory block and the array name.  When either check
-- fails, @v@ is copied into a fresh linear array in the requested
-- space, or the default space if none was requested.
ensureDirectArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureDirectArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let in_requested_space = maybe True (== mem_space) space_ok
  if IxFun.isDirect ixfun && in_requested_space
    then pure (mem, v)
    else
      -- We need a new allocation and a copy of 'v'.
      allocLinearArray (fromMaybe default_space space_ok) (baseString v) v

-- | Copy the array @v@ into a freshly allocated block in the given space.
-- The copy is manifested with the dimension permutation @perm@ and named
-- after @s@ with a @_desired_form@ suffix.  Returns the new memory block
-- and the name of the copy.
--
-- Calls 'error' if @v@ is not an array.
allocPermArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  [Int] ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocPermArray space perm s v = do
  t <- lookupType v
  case t of
    Array pt shape u -> do
      mem <- allocForArray t space
      v' <- newVName $ s <> "_desired_form"
      let ixfun = IxFun.permute (IxFun.iota (map pe64 (arrayDims t))) perm
          pe = PatElem v' $ MemArray pt shape u $ ArrayIn mem ixfun
      addStm $ Let (Pat [pe]) (defAux ()) $ BasicOp $ Manifest perm v
      pure (mem, v')
    _ ->
      error $ "allocPermArray: " ++ prettyString t

-- | Copy the array @v@ into a freshly allocated block in the given space,
-- using the identity permutation (a row-major layout).  See
-- 'allocPermArray' for naming and the returned pair.
allocLinearArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocLinearArray space s v = do
  v_rank <- arrayRank <$> lookupType v
  allocPermArray space [0 .. v_rank - 1] s v

-- | Prepare the arguments of a function call for the memory
-- representation.
--
-- Each argument is passed through 'linearFuncallArg' using the
-- default space.  Every memory block recorded along the way is
-- prepended to the argument list with diet 'Observe'.  The original
-- (possibly renamed) value arguments follow, with their diets
-- unchanged.
funcallArgs ::
  (Allocable fromrep torep inner) =>
  [(SubExp, Diet)] ->
  AllocM fromrep torep [(SubExp, Diet)]
funcallArgs args = do
  -- The default space does not depend on the argument being
  -- processed, so look it up once instead of once per argument.
  space <- askDefaultSpace
  (valargs, (ctx_args, mem_and_size_args)) <- runWriterT $
    forM args $ \(arg, d) -> do
      t <- lift $ subExpType arg
      arg' <- linearFuncallArg t space arg
      pure (arg', d)
  pure $ map (,Observe) (ctx_args <> mem_and_size_args) <> valargs

-- | Process a single function-call argument of the given type.
--
-- If the argument is a variable of array type, it is passed through
-- 'ensureDirectArray' with the given space.  The resulting memory
-- block is recorded in the first component of the writer output, and
-- the (possibly new) array name is returned.  Any other argument is
-- returned unchanged and nothing is recorded.
linearFuncallArg ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg t space arg =
  case (t, arg) of
    (Array {}, Var v) -> do
      (mem, v') <- lift $ ensureDirectArray (Just space) v
      tell ([Var mem], [])
      pure $ Var v'
    _ -> pure arg

-- | The generic explicit-allocations pass, parameterised by:
--
-- * the space in which function parameters and function results are
--   placed;
-- * a handler for representation-specific operations ('Op'); and
-- * a function that computes allocation hints for expressions.
--
-- Top-level constants and each function are processed via
-- 'intraproceduralTransformationWithConsts'.
explicitAllocationsGeneric ::
  (Allocable fromrep torep inner) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Pass fromrep torep
explicitAllocationsGeneric space handleOp hints =
  Pass
    "explicit allocations"
    "Transform program to explicit memory representation"
    (intraproceduralTransformationWithConsts allocConsts allocFun)
  where
    -- Allocate in the top-level constant statements.
    allocConsts stms =
      runAllocM space handleOp hints . collectStms_ $
        allocInStms stms (pure ())

    -- Allocate in a single function, with the already-transformed
    -- constants in scope.  Every parameter is given the pass's space,
    -- and every return value is required to be in that space.
    allocFun consts (FunDef entry attrs fname rettype params fbody) =
      runAllocM space handleOp hints $
        inScopeOf consts $
          allocInFParams [(p, space) | p <- params] $ \params' -> do
            (fbody', mem_rets) <- allocInFunBody (Just space <$ rettype) fbody
            -- The memory context returned by the body comes first.  The
            -- existential indices of the declared return types are
            -- shifted past it.
            let num_mem = length mem_rets
                rettype' = mem_rets ++ memoryInDeclExtType space num_mem rettype
            pure $ FunDef entry attrs fname rettype' params' fbody'

-- | Run allocation on a sequence of statements outside of a full pass.
--
-- The caller's current scope (obtained with 'askScope') is made
-- visible to the allocator, and the statements produced by
-- 'allocInStms' are collected and returned.
explicitAllocationsInStmsGeneric ::
  ( MonadFreshNames m,
    HasScope torep m,
    Allocable fromrep torep inner
  ) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Stms fromrep ->
  m (Stms torep)
explicitAllocationsInStmsGeneric space handleOp hints stms = do
  scope <- askScope
  let alloc = collectStms_ $ allocInStms stms $ pure ()
  runAllocM space handleOp hints $ localScope scope alloc

-- | Annotate declared function return types with memory information.
--
-- Each array return is numbered in order of appearance, starting at 0.
-- The array is given a 'ReturnsNewBlock' in the given space, whose
-- existential memory index is that number.  Its index function is
-- 'IxFun.iota' over the array's shape.  Existential size indices in the
-- shapes are shifted up by @k@.  Scalars and accumulators are carried
-- over unchanged.  A 'Mem' type among the inputs is an error.
memoryInDeclExtType :: Space -> Int -> [DeclExtType] -> [FunReturns]
memoryInDeclExtType space k dets = evalState (traverse addMem dets) 0
  where
    addMem (Prim t) = pure $ MemPrim t
    addMem Mem {} = error "memoryInDeclExtType: too much memory"
    addMem (Array pt shape u) = do
      -- The state counts the array returns seen so far.
      i <- get
      modify (+ 1)
      let shape' = shiftExt <$> shape
          ixfun = IxFun.iota $ map toPrimExp $ shapeDims shape'
      pure $ MemArray pt shape' u $ ReturnsNewBlock space i ixfun
    addMem (Acc acc ispace ts u) = pure $ MemAcc acc ispace ts u

    toPrimExp (Ext i) = le64 $ Ext i
    toPrimExp (Free v) = Free <$> pe64 v

    shiftExt (Ext i) = Ext $ i + k
    shiftExt (Free x) = Free x

-- | Compute the memory context needed to return a single body result.
--
-- For a variable bound to an array in memory block @mem@, the result
-- is @mem@ itself, paired with a 'MemMem' return in @mem@'s space.
-- All other results need no context: constants, scalars, accumulators,
-- and (unexpectedly) memory blocks.  It is an error if an array's
-- memory is not actually a memory block.  Certificates on the result
-- are not propagated to the context.
bodyReturnMemCtx ::
  (Allocable fromrep torep inner) =>
  SubExpRes ->
  AllocM fromrep torep [(SubExpRes, MemInfo ExtSize u MemReturn)]
bodyReturnMemCtx (SubExpRes _ se) =
  case se of
    Constant {} -> pure []
    Var v -> do
      info <- lookupMemInfo v
      case info of
        MemArray _ _ _ (ArrayIn mem _) -> do
          mem_info <- lookupMemInfo mem
          case mem_info of
            MemMem space -> pure [(subExpRes $ Var mem, MemMem space)]
            _ -> error $ "bodyReturnMemCtx: not a memory block: " ++ prettyString mem
        -- MemPrim, MemAcc, and MemMem (the last should not happen).
        _ -> pure []

-- | Allocate in a function body.
--
-- Every array result is passed through 'ensureDirect'.  @space_oks@
-- gives the space requirement for the /last/ @length space_oks@
-- results.  Any leading results beyond those get no requirement
-- ('Nothing').  The memory context of the results (see
-- 'bodyReturnMemCtx') is prepended to the body result and also
-- returned separately as return types.
allocInFunBody ::
  (Allocable fromrep torep inner) =>
  [Maybe Space] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [FunReturns])
allocInFunBody space_oks (Body _ stms res) =
  buildBody $ allocInStms stms $ do
    res' <- zipWithM ensureDirect (padding ++ space_oks) res
    mem_ctx <- concat <$> mapM bodyReturnMemCtx res'
    let (mem_ctx_res, mem_ctx_rets) = unzip mem_ctx
    pure (mem_ctx_res <> res', mem_ctx_rets)
  where
    -- No space requirement for the leading results not covered by
    -- space_oks.
    padding = replicate (length res - length space_oks) Nothing

-- | If a result is an array variable, pass it through
-- 'ensureDirectArray' with the given space requirement and return the
-- resulting array name.  Otherwise return the result unchanged.  The
-- result's certificates are preserved in both cases.
ensureDirect ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  SubExpRes ->
  AllocM fromrep torep SubExpRes
ensureDirect space_ok (SubExpRes cs se) =
  SubExpRes cs <$> do
    info <- subExpMemInfo se
    case (info, se) of
      (MemArray {}, Var v) -> Var . snd <$> ensureDirectArray space_ok v
      _ -> pure se

-- | Allocate in each statement in order, adding the resulting
-- statements to the current builder, and finally run the given
-- continuation.
--
-- Each statement is allocated with its own 'StmAux' attached.  Every
-- later statement, and the continuation, runs with 'envConsts'
-- extended by the 'stmConsts' of the statements produced so far.
allocInStms ::
  (Allocable fromrep torep inner) =>
  Stms fromrep ->
  AllocM fromrep torep a ->
  AllocM fromrep torep a
allocInStms origstms m =
  foldr
    ( \stm rest -> do
        allocstms <- collectStms_ $ auxing (stmAux stm) $ allocInStm stm
        addStms allocstms
        let addConsts env =
              env {envConsts = foldMap stmConsts allocstms <> envConsts env}
        local addConsts rest
    )
    m
    (stmsToList origstms)

-- | Allocate in a single statement.  The expression is allocated first
-- with 'allocInExp'.  Then 'allocsForStm' builds the statement for the
-- pattern's identifiers, and the statement is added to the builder.
allocInStm ::
  (Allocable fromrep torep inner) =>
  Stm fromrep ->
  AllocM fromrep torep ()
allocInStm (Let (Pat pes) _ e) = do
  e' <- allocInExp e
  stm <- allocsForStm (map patElemIdent pes) e'
  addStm stm

-- | Build a lambda with the given (already memory-annotated)
-- parameters.  Its body is the allocated statements of the given body,
-- followed by that body's original result.
allocInLambda ::
  Allocable fromrep torep inner =>
  [LParam torep] ->
  Body fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInLambda params body =
  mkLambda params $
    allocInStms (bodyStms body) (pure (bodyResult body))

-- | The number of LMADs making up an index function.
numLMADs :: IxFun -> Int
numLMADs = NE.length . IxFun.ixfunLMADs

-- | The per-dimension permutation entries ('IxFun.ldPerm') of the
-- /first/ LMAD of an index function.  Any further LMADs are ignored.
ixFunPerm :: IxFun -> [Int]
ixFunPerm ixfun =
  [IxFun.ldPerm dim | dim <- IxFun.lmadDims $ NE.head $ IxFun.ixfunLMADs ixfun]

-- | The per-dimension monotonicities ('IxFun.ldMon') of the /first/
-- LMAD of an index function.  Any further LMADs are ignored.
ixFunMon :: IxFun -> [IxFun.Monotonicity]
ixFunMon ixfun =
  [IxFun.ldMon dim | dim <- IxFun.lmadDims $ NE.head $ IxFun.ixfunLMADs ixfun]

-- | A restriction on how an array result of a @match@ branch is laid
-- out in memory.  It is computed per branch by 'allocInMatchBody' and
-- merged across branches by 'combMemReqs'.
data MemReq
  = -- | The array is in memory in the given space, and its index
    -- function is a single LMAD with the given permutation,
    -- per-dimension monotonicity, base rank, and contiguity.
    MemReq Space [Int] [IxFun.Monotonicity] Rank Bool
  | -- | No common single-LMAD layout is available.  The result is given
    -- an identity-permuted, increasing, contiguous index function in
    -- the given space (see 'mkBranchRet').
    NeedsLinearisation Space
  deriving (Show, Eq)

-- | Merge the restrictions from two branches.
--
-- * A 'NeedsLinearisation' on either side wins; the left one is
--   preferred.
-- * Identical 'MemReq's are kept.
-- * Differing 'MemReq's require linearisation in the left one's space.
combMemReqs :: MemReq -> MemReq -> MemReq
combMemReqs x@NeedsLinearisation {} _ = x
combMemReqs _ y@NeedsLinearisation {} = y
combMemReqs x@(MemReq x_space _ _ _ _) y
  | x == y = x
  | otherwise = NeedsLinearisation x_space

-- | The type of a @match@ result, with arrays annotated by a 'MemReq'
-- restriction instead of concrete memory information.
type MemReqType = MemInfo ExtSize NoUniqueness MemReq

-- | Merge two branch result types.  For two arrays, their restrictions
-- are merged with 'combMemReqs', keeping the left array's element type,
-- shape, and uniqueness.  In every other case the left type is returned
-- unchanged.
combMemReqTypes :: MemReqType -> MemReqType -> MemReqType
combMemReqTypes x y =
  case (x, y) of
    (MemArray pt shape u x_req, MemArray _ _ _ y_req) ->
      MemArray pt shape u $ combMemReqs x_req y_req
    _ -> x

-- | The context returns needed to return an array with the given
-- restriction: a memory block followed by @int64@ index-function
-- components.  Non-array results need no context.
contextRets :: MemReqType -> [MemInfo d u r]
contextRets (MemArray _ shape _ req) =
  case req of
    MemReq space _ _ (Rank base_rank) _ ->
      -- Memory + offset + base_rank + (stride,size)*rank.
      MemMem space : MemPrim int64 : int64s base_rank ++ int64s (2 * shapeRank shape)
    NeedsLinearisation space ->
      -- Memory + offset + (base,stride,size)*rank.
      MemMem space : MemPrim int64 : int64s (3 * shapeRank shape)
  where
    int64s n = replicate n (MemPrim int64)
contextRets _ = []

-- | Add memory information to the body of a @match@ branch, but do not
-- return memory/index-function information for its results.  Instead,
-- return one restriction per result on what the index function should
-- look like.  These restrictions are then (crudely) unified across all
-- branches.
--
-- The expected result types @rets@ are paired positionally with the
-- body's results.
allocInMatchBody ::
  (Allocable fromrep torep inner) =>
  [ExtType] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody rets (Body _ stms res) =
  buildBody . allocInStms stms $
    (,) res <$> zipWithM restriction rets (map resSubExp res)
  where
    restriction t se = do
      v_info <- subExpMemInfo se
      case (t, v_info) of
        (Array pt shape u, MemArray _ _ _ (ArrayIn mem ixfun)) ->
          MemArray pt shape u . arrayReq ixfun <$> lookupMemSpace mem
        (_, MemMem space) -> pure $ MemMem space
        (_, MemPrim pt) -> pure $ MemPrim pt
        (_, MemAcc acc ispace ts u) -> pure $ MemAcc acc ispace ts u
        _ -> error $ "allocInMatchBody: mismatch: " ++ show (t, v_info)

    -- Only an index function made of exactly one LMAD is described by a
    -- 'MemReq'.  Anything else must be linearised.
    arrayReq ixfun space
      | numLMADs ixfun == 1 =
          MemReq
            space
            (ixFunPerm ixfun)
            (ixFunMon ixfun)
            (Rank $ length $ IxFun.base ixfun)
            (IxFun.contiguous ixfun)
      | otherwise = NeedsLinearisation space

-- | Turn the (merged) per-result restrictions of a @match@ into its
-- branch return types.
--
-- The output is all context returns (see 'contextRets'), grouped in the
-- order of the results they belong to, followed by one return type per
-- result.  In each array return, existential size indices are shifted
-- past the new context.  The array gets a 'ReturnsNewBlock' whose memory
-- is the first context entry belonging to that result.  Its existential
-- index function starts at the following entry.
mkBranchRet :: [MemReqType] -> [BranchTypeMem]
mkBranchRet reqs = ctx_rets ++ zipWith inspect offsets reqs
  where
    -- All context returns, in result order.  Built with 'concatMap'
    -- rather than a lazy left fold with repeated '(++)', which was
    -- quadratic in the number of results.
    ctx_rets = concatMap contextRets reqs

    -- Index of the first context return belonging to each result.
    offsets = scanl (+) 0 $ map (length . contextRets) reqs

    -- Total number of new context returns.  This is equal to the last
    -- element of 'offsets', but computed without the partial 'last'.
    num_new_ctx = length ctx_rets

    arrayInfo rank (NeedsLinearisation space) =
      (space, [0 .. rank - 1], repeat IxFun.Inc, rank, True)
    arrayInfo _ (MemReq space perm mon (Rank base_rank) contig) =
      (space, perm, mon, base_rank, contig)

    inspect ctx_offset (MemArray pt shape u req) =
      let shape' = fmap (adjustExt num_new_ctx) shape
          (space, perm, mon, base_rank, contig) = arrayInfo (shapeRank shape) req
       in MemArray pt shape' u . ReturnsNewBlock space ctx_offset $
            convert
              <$> IxFun.mkExistential base_rank (zip perm mon) contig (ctx_offset + 1)
    inspect _ (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
    inspect _ (MemPrim pt) = MemPrim pt
    inspect _ (MemMem space) = MemMem space

    convert (Ext i) = le64 (Ext i)
    convert (Free v) = Free <$> pe64 v

    adjustExt :: Int -> Ext a -> Ext a
    adjustExt _ (Free v) = Free v
    adjustExt k (Ext i) = Ext (k + i)

-- | Bind the statements of a branch body and rebuild its result so that
-- it agrees with the given memory requirements (one per result).
--
-- A variable result whose requirement is 'NeedsLinearisation' is passed
-- through 'ensureRowMajorArray' in the requested space.  Afterwards, for
-- every variable result that lives in an array, its memory block followed
-- by the components of its index function (each bound to a fresh
-- @ixfun_ext@ variable) is prepended to the result list, in result order.
addCtxToMatchBody ::
  (Allocable fromrep torep inner) =>
  [MemReqType] ->
  Body torep ->
  AllocM fromrep torep (Body torep)
addCtxToMatchBody reqs body = buildBody_ $ do
  bound_res <- bodyBind body
  vals <- zipWithM adjust reqs bound_res
  ctx <- mapM memContext vals
  pure $ concat ctx <> vals
  where
    adjust req r =
      case (req, r) of
        (MemArray _ _ _ (NeedsLinearisation space), SubExpRes cs (Var v)) -> do
          v' <- snd <$> ensureRowMajorArray (Just space) v
          pure $ SubExpRes cs (Var v')
        _ -> pure r

    memContext (SubExpRes _ Constant {}) = pure []
    memContext (SubExpRes _ (Var v)) = do
      info <- lookupMemInfo v
      case info of
        MemArray _ _ _ (ArrayIn mem ixfun) -> do
          exts <- forM (toList ixfun) $ \pe ->
            letSubExp "ixfun_ext" =<< toExp pe
          pure $ subExpRes (Var mem) : subExpsRes exts
        -- Scalars, accumulators, and memory blocks (which should not
        -- occur here) contribute no context.
        _ -> pure []

-- Do a simple form of invariance analysis to simplify a Match.  It
-- is unfortunate that we have to do it here, but functions such as
-- scalarRes will look carefully at the index functions before the
-- simplifier has a chance to run.  In a perfect world we would
-- simplify away those copies afterwards. XXX: this should be fixed by
-- a more general copy-removal pass. See
-- Futhark.Optimise.EntryPointMem for a very specialised version of
-- the idea, but which could perhaps be generalised.
simplifyMatch ::
  (Mem rep inner) =>
  [Case (Body rep)] ->
  Body rep ->
  [BranchTypeMem] ->
  ( [Case (Body rep)],
    Body rep,
    [BranchTypeMem]
  )
simplifyMatch cases defbody ts =
  ( zipWith withCaseResult cases (transpose kept_case_res),
    withResult defbody kept_def_res,
    foldr (uncurry fixExt) kept_ts invariant
  )
  where
    -- One entry per result position: its index, the results of every
    -- case at that position, the default body's result, and its type.
    positions =
      zip4
        [0 ..]
        (transpose $ map (bodyResult . caseBody) cases)
        (bodyResult defbody)
        ts

    (invariant, variant) = partitionEithers $ map classify positions
    (kept_case_res, kept_def_res, kept_ts) = unzip3 variant

    -- Every name bound by a statement in any of the branches.
    branch_bound =
      namesFromList $
        concatMap (patNames . stmPat) $
          foldMap (bodyStms . caseBody) cases <> bodyStms defbody

    withResult b res = b {bodyResult = res}
    withCaseResult c res = (`withResult` res) <$> c

    classify (i, case_res, def_res, t)
      -- If even one branch has a variant result, then we give up.
      | branch_bound `namesIntersect` freeIn (def_res : case_res) =
          Right (case_res, def_res, t)
      -- Every branch returns the same value: drop the result and apply
      -- fixExt with its index and value to the remaining types.
      | all (\r -> resSubExp r == resSubExp def_res) case_res =
          Left (i, resSubExp def_res)
      | otherwise =
          Right (case_res, def_res, t)

-- | Add memory information to an expression.
--
-- For a loop, 'allocInMergeParams' supplies memory-annotated merge
-- parameters together with a function that turns the body's result
-- values into a pair of (context, values).  The new loop body returns
-- the context first, followed by the values, which keep the
-- certificates of the original results.
allocInExp ::
  (Allocable fromrep torep inner) =>
  Exp fromrep ->
  AllocM fromrep torep (Exp torep)
allocInExp (DoLoop merge form (Body () bodystms bodyres)) =
  allocInMergeParams merge $ \merge' mk_vals -> do
    form' <- allocInLoopForm form
    localScope (scopeOf form') $ do
      let res_certs = map resCerts bodyres
      body' <-
        buildBody_ $
          allocInStms bodystms $ do
            (ctx_ses, val_ses) <- mk_vals $ map resSubExp bodyres
            pure $ subExpsRes ctx_ses ++ zipWith SubExpRes res_certs val_ses
      pure $ DoLoop merge' form' body'
allocInExp (Apply fname args rettype loc) = do
  args' <- funcallArgs args
  space <- askDefaultSpace
  -- We assume that every array is going to be in its own memory, so one
  -- memory-block return in the default space is prepended per
  -- array-typed result.
  let mem_rets = replicate num_arrays $ MemMem space
  pure $
    Apply fname args' (mem_rets ++ memoryInDeclExtType space num_arrays rettype) loc
  where
    -- Number of results whose rank is greater than zero.
    num_arrays = length [t | t <- rettype, arrayRank (declExtTypeOf t) > 0]
-- Matches: allocate in every branch, merge the per-result memory
-- requirements of all branches, add the resulting memory context to each
-- branch, and finally remove results that are invariant across branches
-- with simplifyMatch.
allocInExp (Match ses cases defbody (MatchDec rets ifsort)) = do
  (defbody', def_reqs) <- allocInMatchBody rets defbody
  (cases', cases_reqs) <- mapAndUnzipM onCase cases
  -- Combine requirements position by position, left to right: the
  -- default body first, then each case in order.  A strict left fold
  -- avoids the thunk build-up of the lazy 'foldl'.
  let reqs = foldl' (zipWith combMemReqTypes) def_reqs cases_reqs
  defbody'' <- addCtxToMatchBody reqs defbody'
  cases'' <- mapM (traverse $ addCtxToMatchBody reqs) cases'
  let (cases''', defbody''', rets') =
        simplifyMatch cases'' defbody'' $ mkBranchRet reqs
  pure $ Match ses cases''' defbody''' $ MatchDec rets' ifsort
  where
    onCase (Case vs body) = first (Case vs) <$> allocInMatchBody rets body
allocInExp (WithAcc inputs bodylam) = do
  inputs' <- mapM allocInput inputs
  bodylam' <- allocAccLambda bodylam
  pure $ WithAcc inputs' bodylam'
  where
    -- The body lambda's parameters must be unit scalars or accumulators.
    -- Each gets the corresponding non-array memory annotation.
    allocAccLambda lam = do
      params <- mapM accParam $ lambdaParams lam
      allocInLambda params $ lambdaBody lam

    accParam (Param attrs pv (Prim Unit)) =
      pure $ Param attrs pv $ MemPrim Unit
    accParam (Param attrs pv (Acc acc ispace ts u)) =
      pure $ Param attrs pv $ MemAcc acc ispace ts u
    accParam p =
      error $ "Unexpected WithAcc lambda param: " ++ prettyString p

    -- Each input keeps its shape and arrays; only the optional operator
    -- is rewritten.
    allocInput (shape, arrs, op) = do
      op' <- traverse (allocOp shape arrs) op
      pure (shape, arrs, op')

    -- The operator's parameters are split into index parameters (as many
    -- as the rank of the accumulator shape), then x parameters (as many
    -- as the lambda has results), with the remaining ones being y
    -- parameters.
    allocOp accshape arrs (lam, nes) = do
      let (i_params, x_params, y_params) =
            splitAt3 (shapeRank accshape) (length (lambdaReturnType lam)) $
              lambdaParams lam
          i_params' = [Param attrs v $ MemPrim int64 | Param attrs v _ <- i_params]
          idxs = [DimFix $ Var $ paramName p | p <- i_params']
      x_params' <- zipWithM (xParam idxs) x_params arrs
      y_params' <- zipWithM (yParam idxs) y_params arrs
      lam' <- allocInLambda (i_params' ++ x_params' ++ y_params') $ lambdaBody lam
      pure (lam', nes)

    -- An array parameter whose index function is the given one sliced by
    -- the index parameters followed by the full parameter shape.
    arrayParam attrs p pt shape u mem ixfun idxs =
      Param attrs p $
        MemArray pt shape u $
          ArrayIn mem $
            IxFun.slice ixfun $
              pe64 <$> Slice (idxs ++ map sliceDim (shapeDims shape))

    -- x parameters reuse the memory and index function of the input array.
    xParam _ (Param attrs p (Prim t)) _ =
      pure $ Param attrs p $ MemPrim t
    xParam idxs (Param attrs p (Array pt shape u)) arr = do
      (mem, ixfun) <- lookupArraySummary arr
      pure $ arrayParam attrs p pt shape u mem ixfun idxs
    xParam _ p _ =
      error $ "Cannot handle MkAcc param: " ++ prettyString p

    -- y parameters get a fresh allocation (in the default space) with the
    -- type of the input array and an iota index function over its dims.
    yParam _ (Param attrs p (Prim t)) _ =
      pure $ Param attrs p $ MemPrim t
    yParam idxs (Param attrs p (Array pt shape u)) arr = do
      arr_t <- lookupType arr
      mem <- allocForArray arr_t =<< askDefaultSpace
      let ixfun = IxFun.iota $ map pe64 $ arrayDims arr_t
      pure $ arrayParam attrs p pt shape u mem ixfun idxs
    yParam _ p _ =
      error $ "Cannot handle MkAcc param: " ++ prettyString p
allocInExp e = mapExpM alloc e
  where
    -- Generic traversal for every remaining expression.  Operations are
    -- handed to the 'allocInOp' handler stored in the environment.  Any
    -- body, return/branch type or parameter reached through this
    -- traversal raises an internal error.
    alloc =
      identityMapper
        { mapOnOp = \op -> asks allocInOp >>= \handler -> handler op,
          mapOnBody = error "Unhandled Body in ExplicitAllocations",
          mapOnRetType = error "Unhandled RetType in ExplicitAllocations",
          mapOnBranchType = error "Unhandled BranchType in ExplicitAllocations",
          mapOnFParam = error "Unhandled FParam in ExplicitAllocations",
          mapOnLParam = error "Unhandled LParam in ExplicitAllocations"
        }

-- | Add memory information to a loop form.  While-loops are unchanged.
-- For a for-loop, each loop variable is annotated using the memory
-- summary of the array it iterates over.  Array-typed loop variables get
-- that array's index function sliced at the loop index.
allocInLoopForm ::
  (Allocable fromrep torep inner) =>
  LoopForm fromrep ->
  AllocM fromrep torep (LoopForm torep)
allocInLoopForm (WhileLoop v) = pure (WhileLoop v)
allocInLoopForm (ForLoop i it n loopvars) =
  ForLoop i it n <$> mapM onLoopVar loopvars
  where
    onLoopVar (p, arr) = do
      (mem, ixfun) <- lookupArraySummary arr
      let withDec dec = pure (p {paramDec = dec}, arr)
      case paramType p of
        Prim bt -> withDec $ MemPrim bt
        Mem space -> withDec $ MemMem space
        Acc acc ispace ts u -> withDec $ MemAcc acc ispace ts u
        Array pt shape u -> do
          arr_dims <- map pe64 . arrayDims <$> lookupType arr
          let row_ixfun = IxFun.slice ixfun $ fullSliceNum arr_dims [DimFix $ le64 i]
          withDec $ MemArray pt shape u $ ArrayIn mem row_ixfun

-- | Operations that can report whether the values they bind are constant.
--
-- 'stmConsts' consults 'opIsConst' to decide whether the names bound by
-- an 'Op' statement should be collected.  Unless an instance says
-- otherwise, no operation is treated as constant.
class SizeSubst op where
  opIsConst :: op -> Bool
  opIsConst _ = False

-- | Never constant (spelled out explicitly; identical to the class default).
instance SizeSubst (NoOp rep) where
  opIsConst _ = False

-- | An inner operation defers to its own 'SizeSubst' instance; any other
-- 'MemOp' (i.e. an allocation) is never considered constant.
instance SizeSubst (op rep) => SizeSubst (MemOp op rep) where
  opIsConst memOp =
    case memOp of
      Inner innerOp -> opIsConst innerOp
      _ -> False

-- | The names bound by a statement, if that statement is an 'Op' for
-- which 'opIsConst' holds; the empty set for every other statement.
stmConsts :: SizeSubst (Op rep) => Stm rep -> S.Set VName
stmConsts stm =
  case stm of
    Let boundPat _ (Op theOp)
      | opIsConst theOp -> S.fromList (patNames boundPat)
    _ -> mempty

-- | Build a statement binding @names@ to @e@.
--
-- The memory-annotated pattern is produced by 'patWithAllocations' in
-- @space@, with 'NoHint' given for every name.  @dec@ is wrapped with
-- 'defAux' to form the statement's auxiliary information.
mkLetNamesB' ::
  ( LetDec (Rep m) ~ LetDecMem,
    Mem (Rep m) inner,
    MonadBuilder m,
    ExpDec (Rep m) ~ ()
  ) =>
  Space ->
  ExpDec (Rep m) ->
  [VName] ->
  Exp (Rep m) ->
  m (Stm (Rep m))
mkLetNamesB' space dec names e =
  toStm <$> patWithAllocations space names e (NoHint <$ names)
  where
    toStm allocPat = Let allocPat (defAux dec) e

-- | Build a statement binding @names@ to @e@ in a 'Engine.Wise'
-- representation.
--
-- The pattern comes from 'patWithAllocations' in @space@ (one 'NoHint'
-- per name).  It is then annotated via 'Engine.addWisdomToPat'.  The
-- expression decoration is computed by 'Engine.mkWiseExpDec' from that
-- annotated pattern and a @()@ base decoration.
mkLetNamesB'' ::
  ( Mem rep inner,
    LetDec rep ~ LetDecMem,
    OpReturns (inner (Engine.Wise rep)),
    ExpDec rep ~ (),
    Rep m ~ Engine.Wise rep,
    HasScope (Engine.Wise rep) m,
    MonadBuilder m,
    AliasedOp (inner (Engine.Wise rep)),
    RephraseOp (MemOp inner),
    Engine.CanBeWise inner
  ) =>
  Space ->
  [VName] ->
  Exp (Engine.Wise rep) ->
  m (Stm (Engine.Wise rep))
mkLetNamesB'' space names e = do
  plainPat <- patWithAllocations space names e (NoHint <$ names)
  let wisePat = Engine.addWisdomToPat plainPat e
      aux = defAux $ Engine.mkWiseExpDec wisePat () e
  pure $ Let wisePat aux e

-- | Simplify a 'MemOp'.
--
-- For an 'Alloc', only the size is simplified; the space is kept and
-- no statements are hoisted.  An 'Inner' operation is handed to the
-- supplied simplifier, and whatever that simplifier hoists is passed
-- through unchanged.
simplifyMemOp ::
  Engine.SimplifiableRep rep =>
  ( inner (Engine.Wise rep) ->
    Engine.SimpleM rep (inner (Engine.Wise rep), Stms (Engine.Wise rep))
  ) ->
  MemOp inner (Engine.Wise rep) ->
  Engine.SimpleM rep (MemOp inner (Engine.Wise rep), Stms (Engine.Wise rep))
simplifyMemOp simplifyInner memOp =
  case memOp of
    Alloc size space -> do
      size' <- Engine.simplify size
      pure (Alloc size' space, mempty)
    Inner innerOp -> do
      (innerOp', hoisted) <- simplifyInner innerOp
      pure (Inner innerOp', hoisted)

-- | Assemble the 'SimpleOps' for a memory-annotated representation whose
-- operations are 'MemOp's wrapping @inner@.
--
-- * Expression and body decorations come from 'Engine.mkWiseExpDec' and
--   'mkWiseBody', each with a @()@ base decoration.
-- * When an 'Alloc' is hoisted, its size is replaced by a 'Match' on the
--   guard condition: the original size when the guard is true, and
--   @0 :: i64@ otherwise.  Inner operations are never protected.
-- * The usage of an 'Alloc' records its size variable, if it has one,
--   through 'UT.sizeUsage'.  Inner operations use @innerUsage@.
-- * Pattern elements are refined from 'expReturns'.
--   - An array element whose result is returned into a block gets that
--     block's index function.  The index function is instantiated
--     against the pattern's own names, and the element's shape and
--     memory are simplified.
--   - Every other element is simplified as-is.
-- * Operations are simplified with 'simplifyMemOp', using
--   @simplifyInnerOp@ for inner operations.
simplifiable ::
  ( Engine.SimplifiableRep rep,
    LetDec rep ~ LetDecMem,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    Mem (Engine.Wise rep) inner,
    Engine.CanBeWise inner,
    RephraseOp inner,
    IsOp (inner rep),
    OpReturns (inner (Engine.Wise rep)),
    AliasedOp (inner (Engine.Wise rep)),
    IndexOp (inner (Engine.Wise rep))
  ) =>
  (inner (Engine.Wise rep) -> UT.UsageTable) ->
  ( inner (Engine.Wise rep) ->
    Engine.SimpleM rep (inner (Engine.Wise rep), Stms (Engine.Wise rep))
  ) ->
  SimpleOps rep
simplifiable innerUsage simplifyInnerOp =
  SimpleOps
    wiseExpDec
    wiseBody
    guardHoistedAlloc
    memOpUsage
    refinePat
    (simplifyMemOp simplifyInnerOp)
  where
    wiseExpDec _ wisePat e = pure $ Engine.mkWiseExpDec wisePat () e

    wiseBody _ stms res = pure $ mkWiseBody () stms res

    -- Only allocations are protected: the hoisted allocation's size
    -- becomes 0 unless the guard condition 'taken' is true.
    guardHoistedAlloc taken boundPat memOp =
      case memOp of
        Alloc size space -> Just $ do
          takenBody <- resultBodyM [size]
          zeroBody <- resultBodyM [intConst Int64 0]
          let guardedSize =
                Match [taken] [Case [Just $ BoolValue True] takenBody] zeroBody $
                  MatchDec [MemPrim int64] MatchFallback
          size' <- letSubExp "hoisted_alloc_size" guardedSize
          letBind boundPat $ Op $ Alloc size' space
        _ -> Nothing

    memOpUsage memOp =
      case memOp of
        Alloc (Var sizeVar) _ -> UT.sizeUsage sizeVar
        Alloc _ _ -> mempty
        Inner innerOp -> innerUsage innerOp

    refinePat (Pat elems) e = do
      returns <- expReturns e
      Pat <$> zipWithM refineElem elems returns
      where
        elemNames = map patElemName elems

        -- Existential references are resolved to the pattern's own
        -- names; free variables stay as they are.
        resolve (Free v) = Just v
        resolve (Ext i) = maybeNth i elemNames

        refineElem patElem ret =
          case (patElem, ret) of
            ( PatElem v (MemArray pt shape u (ArrayIn mem _)),
              MemArray _ _ _ (Just (ReturnsInBlock _ extIxFun))
              )
                | Just ixfun <- traverse (traverse resolve) extIxFun -> do
                    shape' <- Engine.simplify shape
                    mem' <- Engine.simplify mem
                    pure $ PatElem v $ MemArray pt shape' u $ ArrayIn mem' ixfun
            _ -> traverse Engine.simplify patElem

-- | A per-result hint, supplied to 'patWithAllocations' as one entry per
-- bound name.
data ExpHint
  = -- | No hint for this result.
    NoHint
  | -- | An index function and memory space for this result.  Presumably
    -- these steer how the result is allocated; confirm against
    -- 'patWithAllocations'.
    Hint IxFun Space

-- | One 'NoHint' for each result of the expression, as given by its
-- 'expExtType'.
--
-- This stays applicative-only (no do-notation) because the 'HasScope'
-- constraint does not provide a 'Monad' instance.
defaultExpHints :: (ASTRep rep, HasScope rep m) => Exp rep -> m [ExpHint]
defaultExpHints e = fmap (NoHint <$) (expExtType e)