{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A generic transformation for adding memory allocations to a
-- Futhark program.  Specialised by specific representations in
-- submodules.
module Futhark.Pass.ExplicitAllocations
  ( explicitAllocationsGeneric,
    explicitAllocationsInStmsGeneric,
    ExpHint (..),
    defaultExpHints,
    Allocable,
    Allocator (..),
    AllocM,
    AllocEnv (..),
    SizeSubst (..),
    allocInStms,
    allocForArray,
    simplifiable,
    arraySizeInBytesExp,
    mkLetNamesB',
    mkLetNamesB'',

    -- * Module re-exports

    --
    -- These are highly likely to be needed by any downstream
    -- users.
    module Control.Monad.Reader,
    module Futhark.MonadFreshNames,
    module Futhark.Pass,
    module Futhark.Tools,
  )
where

import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (foldl', partition, sort, zip4)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rep (mkWiseBody)
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (splitAt3, splitFromEnd, takeLast)

-- | A statement emitted by an 'Allocator' through 'addAllocStm'.
data AllocStm
  = -- | Bind the name to the value of the size expression (coerced
    -- to 'Int64' when bound).
    SizeComputation VName (PrimExp VName)
  | -- | Bind the name to an allocation of the given size in the given
    -- memory space.
    Allocation VName SubExp Space
  | -- | Bind the first name to a copy of the array named by the second.
    ArrayCopy VName VName
  deriving (Eq, Ord, Show)

-- | Bind the statement described by an 'AllocStm' in any binder whose
-- operations are 'MemOp's, using the 'Alloc' constructor directly.
bindAllocStm ::
  (MonadBinder m, Op (Rep m) ~ MemOp inner) =>
  AllocStm ->
  m ()
bindAllocStm (SizeComputation name pe) =
  -- The size expression is coerced to a 64-bit integer before binding.
  letBindNames [name] =<< toExp (coerceIntPrimExp Int64 pe)
bindAllocStm (Allocation name size space) =
  letBindNames [name] $ Op $ Alloc size space
bindAllocStm (ArrayCopy name src) =
  letBindNames [name] $ BasicOp $ Copy src

-- | Monads in which allocation-related statements can be emitted for
-- a memory-annotated representation @rep@.
class
  (MonadFreshNames m, LocalScope rep m, Mem rep) =>
  Allocator rep m
  where
  -- | Emit the statement described by an 'AllocStm'.  The default
  -- implementation binds it immediately in 'AllocM'.
  addAllocStm :: AllocStm -> m ()

  -- | The memory space in which allocations should be placed.
  askDefaultSpace :: m Space

  default addAllocStm ::
    ( Allocable fromrep rep,
      m ~ AllocM fromrep rep
    ) =>
    AllocStm ->
    m ()
  addAllocStm (SizeComputation name se) =
    letBindNames [name] =<< toExp (coerceIntPrimExp Int64 se)
  addAllocStm (Allocation name size space) =
    letBindNames [name] $ Op $ allocOp size space
  addAllocStm (ArrayCopy name src) =
    letBindNames [name] $ BasicOp $ Copy src

  -- | The subexpression giving the number of elements we should
  -- allocate space for.  See 'ChunkMap' comment.
  dimAllocationSize :: SubExp -> m SubExp
  default dimAllocationSize ::
    m ~ AllocM fromrep rep =>
    SubExp ->
    m SubExp
  dimAllocationSize (Var v) =
    -- It is important to recurse here, as the substitution may itself
    -- be a chunk size.
    maybe (return $ Var v) dimAllocationSize =<< asks (M.lookup v . chunkMap)
  dimAllocationSize size =
    return size

  -- | Get those names that are known to be constants at run-time.
  askConsts :: m (S.Set VName)

  -- | Compute the allocation hints for an expression.  Defaults to
  -- 'defaultExpHints'.
  expHints :: Exp rep -> m [ExpHint]
  expHints = defaultExpHints

-- | Emit an allocation of the given size in the given memory space.
-- The result is bound to a fresh name derived from the description,
-- and that name is returned.
allocateMemory ::
  Allocator rep m =>
  String ->
  SubExp ->
  Space ->
  m VName
allocateMemory desc size space = do
  v <- newVName desc
  addAllocStm $ Allocation v size space
  return v

-- | Emit a computation of the given size expression, bound to a fresh
-- name derived from the description, and return that name as a
-- 'SubExp'.
computeSize ::
  Allocator rep m =>
  String ->
  PrimExp VName ->
  m SubExp
computeSize desc se = do
  v <- newVName desc
  addAllocStm $ SizeComputation v se
  return $ Var v

-- | The constraints a pair of representations must satisfy for
-- allocations to be added while converting from @fromrep@ to @torep@.
type Allocable fromrep torep =
  ( -- Requirements on the source representation.
    PrettyRep fromrep,
    FParamInfo fromrep ~ DeclType,
    LParamInfo fromrep ~ Type,
    BranchType fromrep ~ ExtType,
    RetType fromrep ~ DeclExtType,
    BodyDec fromrep ~ (),
    -- Requirements on the target representation.
    PrettyRep torep,
    Mem torep,
    BodyDec torep ~ (),
    ExpDec torep ~ (),
    SizeSubst (Op torep),
    BinderOps torep
  )

-- | A mapping from chunk names to their maximum size.  XXX FIXME
-- HACK: This is part of a hack to add loop-invariant allocations to
-- reduce kernels, because memory expansion does not use range
-- analysis yet (it should).
--
-- The default 'dimAllocationSize' looks a name up in this map (via
-- 'chunkMap') and recurses on the result, because a recorded maximum
-- may itself be another chunk name.
type ChunkMap = M.Map VName SubExp

-- | The read-only environment of 'AllocM'.
data AllocEnv fromrep torep = AllocEnv
  { -- | Maximum sizes of chunk names; see 'ChunkMap'.
    chunkMap :: ChunkMap,
    -- | Aggressively try to reuse memory in do-loops -
    -- should be True inside kernels, False outside.
    aggressiveReuse :: Bool,
    -- | When allocating memory, put it in this memory space.
    -- This is primarily used to ensure that group-wide
    -- statements store their results in local memory.
    allocSpace :: Space,
    -- | The set of names that are known to be constants at
    -- kernel compile time.
    envConsts :: S.Set VName,
    -- | Converts an operation of the source representation into
    -- an operation of the target representation.
    allocInOp :: Op fromrep -> AllocM fromrep torep (Op torep),
    -- | Computes allocation hints for an expression; this is what
    -- 'expHints' uses in 'AllocM'.
    envExpHints :: Exp torep -> AllocM fromrep torep [ExpHint]
  }

-- | Monad for adding allocations to an entire program.
--
-- A statement builder over a read-only 'AllocEnv', with the name
-- source threaded as state.
newtype AllocM fromrep torep a
  = AllocM (BinderT torep (ReaderT (AllocEnv fromrep torep) (State VNameSource)) a)
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadFreshNames,
      HasScope torep,
      LocalScope torep,
      MonadReader (AllocEnv fromrep torep)
    )

-- | Statements built in 'AllocM' obtain their patterns from
-- 'patternWithAllocations'; expression and body decorations are unit.
instance
  (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
  MonadBinder (AllocM fromrep torep)
  where
  type Rep (AllocM fromrep torep) = torep

  mkExpDecM _ _ = return ()

  mkLetNamesM names e = do
    -- The pattern, including its memory annotations, is computed by
    -- 'patternWithAllocations'.
    pat <- patternWithAllocations names e
    return $ Let pat (defAux ()) e

  mkBodyM bnds res = return $ Body () bnds res

  -- Statement handling is delegated to the wrapped 'BinderT'.
  addStms = AllocM . addStms
  collectStms (AllocM m) = AllocM $ collectStms m

-- | 'AllocM' takes its allocation space, known constants and expression
-- hints from its 'AllocEnv'.  'addAllocStm' and 'dimAllocationSize'
-- use the class defaults.
instance
  (Allocable fromrep torep) =>
  Allocator torep (AllocM fromrep torep)
  where
  expHints e = do
    f <- asks envExpHints
    f e
  askDefaultSpace = asks allocSpace

  askConsts = asks envConsts

-- | Run an 'AllocM' computation with the given operation handler and
-- expression-hint function.
--
-- The computation starts with an empty scope, an empty chunk map, no
-- known constants, no aggressive memory reuse, and 'DefaultSpace' as
-- the allocation space.
--
-- NOTE: statements emitted outside any 'collectStms' are dropped (the
-- 'Stms' component is discarded by @fmap fst@); only the result is
-- returned.
runAllocM ::
  MonadFreshNames m =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  AllocM fromrep torep a ->
  m a
runAllocM handleOp hints (AllocM m) =
  fmap fst $ modifyNameSource $ runState $ runReaderT (runBinderT m mempty) env
  where
    env =
      AllocEnv
        { chunkMap = mempty,
          aggressiveReuse = False,
          allocSpace = DefaultSpace,
          envConsts = mempty,
          allocInOp = handleOp,
          envExpHints = hints
        }

-- | Monad for adding allocations to a single pattern.
--
-- Reads the current scope, accumulates the emitted 'AllocStm's, and
-- threads the name source as state.
newtype PatAllocM rep a
  = PatAllocM
      ( RWS
          (Scope rep)
          [AllocStm]
          VNameSource
          a
      )
  deriving
    ( Applicative,
      Functor,
      Monad,
      HasScope rep,
      LocalScope rep,
      MonadWriter [AllocStm],
      MonadFreshNames
    )

-- | Pattern-level allocator: every allocation statement is appended to
-- the writer output, dimension sizes are used unchanged, memory goes to
-- 'DefaultSpace', and no names are treated as constants.
instance Mem rep => Allocator rep (PatAllocM rep) where
  addAllocStm stm = tell [stm]
  dimAllocationSize = pure
  askDefaultSpace = pure DefaultSpace
  askConsts = pure S.empty

-- | Run a 'PatAllocM' action in the given scope.
--
-- Fresh names are taken from, and written back to, the surrounding
-- name source.  The allocation statements emitted by the action are
-- returned next to its result.
runPatAllocM ::
  MonadFreshNames m =>
  PatAllocM rep a ->
  Scope rep ->
  m (a, [AllocStm])
runPatAllocM (PatAllocM action) scope = modifyNameSource step
  where
    -- 'runRWS' yields (result, state, output); 'modifyNameSource'
    -- wants ((result, output), state).
    step src =
      case runRWS action scope src of
        (res, src', allocs) -> ((res, allocs), src')

-- | Byte size of the element type of the given type, as computed by
-- 'primByteSize'.
elemSize :: Num a => Type -> a
elemSize t = primByteSize (elemType t)

-- | Symbolic byte size of an array type: the element size multiplied by
-- every dimension of the type, as an untyped 64-bit integer expression.
--
-- Unlike 'arraySizeInBytesExpM', the dimensions are used as written and
-- the result is not clamped at zero.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  let elem_bytes = elemSize t
      dims = map pe64 (arrayDims t)
   in untyped (foldl' (*) elem_bytes dims)

-- | Symbolic byte size of an array type within an allocator.
--
-- Each dimension is first passed through 'dimAllocationSize'.  The
-- product of the dimensions and the element size is then wrapped in a
-- signed 64-bit maximum with zero, so the result is never negative.
arraySizeInBytesExpM :: Allocator rep m => Type -> m (PrimExp VName)
arraySizeInBytesExpM t = do
  alloc_dims <- mapM dimAllocationSize (arrayDims t)
  let num_elems = product (map pe64 alloc_dims)
      num_bytes = untyped (num_elems * elemSize t)
      zero = ValueExp (IntValue (Int64Value 0))
  pure (BinOpExp (SMax Int64) zero num_bytes)

-- | Compute the byte size of an array type (see 'arraySizeInBytesExpM')
-- and bind it with 'computeSize', passing @"bytes"@ as the description.
arraySizeInBytes :: Allocator rep m => Type -> m SubExp
arraySizeInBytes t = arraySizeInBytesExpM t >>= computeSize "bytes"

-- | Allocate memory for a value of the given type.
--
-- The number of bytes comes from 'arraySizeInBytes'.  The block is
-- created with 'allocateMemory' in the requested space, passing
-- @"mem"@ as the description.  The name of the new memory block is
-- returned.
allocForArray ::
  Allocator rep m =>
  Type ->
  Space ->
  m VName
allocForArray t space =
  arraySizeInBytes t >>= \size -> allocateMemory "mem" size space

-- | Build a 'Let' statement that binds the given size identifiers and
-- value identifiers to @e@.
--
-- The pattern's memory information comes from 'allocsForPattern', fed
-- with the expression's 'expReturns' and 'expHints'.  The statement
-- gets default auxiliary information.
allocsForStm ::
  (Allocator rep m, ExpDec rep ~ ()) =>
  [Ident] ->
  [Ident] ->
  Exp rep ->
  m (Stm rep)
allocsForStm sizeidents validents e = do
  rts <- expReturns e
  hints <- expHints e
  (ctx, vals) <- allocsForPattern sizeidents validents rts hints
  let pat = Pattern ctx vals
      aux = defAux ()
  pure (Let pat aux e)

-- | Compute a memory-annotated pattern that binds @names@ to the results
-- of @e@.
--
-- Existential shapes in the expression's type are instantiated with
-- 'instantiateShapes'', which also yields the size identifiers.  The
-- names are paired positionally with the instantiated types (extra
-- entries on either side are dropped, as with 'zipWith').
patternWithAllocations ::
  (Allocator rep m, ExpDec rep ~ ()) =>
  [VName] ->
  Exp rep ->
  m (Pattern rep)
patternWithAllocations names e = do
  (ts', sizes) <- expExtType e >>= instantiateShapes'
  let vals = zipWith Ident names ts'
  stmPattern <$> allocsForStm sizes vals e

-- | Attach memory information to a pattern.
--
-- Each value identifier is combined with its 'ExpReturns' and 'ExpHint'.
-- The result is @(context elements, value elements)@, where the context
-- holds, in order:
--
--   1. the size identifiers, each as an @int64@ scalar;
--   2. fresh scalars for existentials in index functions that are not
--      covered by the size identifiers;
--   3. fresh memory blocks for values whose return says
--      'ReturnsNewBlock'.
--
-- Calls 'error' for return/summary combinations not handled below.
allocsForPattern ::
  Allocator rep m =>
  [Ident] ->
  [Ident] ->
  [ExpReturns] ->
  [ExpHint] ->
  m
    ( [PatElem rep],
      [PatElem rep]
    )
allocsForPattern sizeidents validents rts hints = do
  -- Every value yields its own pattern element, plus any extra
  -- existential elements and memory elements it needs.
  allocated <-
    forM (zip3 validents rts hints) $ \(ident, rt, hint) -> do
      let name = identName ident
          ident_shape = arrayShape (identType ident)
          plain summary = (PatElem name summary, [], [])
      case rt of
        MemPrim _ ->
          plain <$> summaryForBindage (identType ident) hint
        MemMem space ->
          pure $ plain $ MemMem space
        MemArray bt _ u (Just (ReturnsInBlock mem extixfun)) -> do
          (patels, ixfn) <- instantiateExtIxFun ident extixfun
          let val_pe =
                PatElem name $
                  MemArray bt ident_shape u $
                    ArrayIn mem ixfn
          pure (val_pe, patels, [])
        MemArray _ extshape _ Nothing
          | Just _ <- knownShape extshape ->
            plain <$> summaryForBindage (identType ident) hint
        MemArray bt _ u (Just (ReturnsNewBlock space _ extixfn)) -> do
          -- The existential index function is handled before the
          -- memory block is named, so fresh names are drawn in that
          -- order.
          (patels, ixfn) <- instantiateExtIxFun ident extixfn
          memid <- mkMemIdent ident space
          let mem_pe = PatElem (identName memid) $ MemMem space
              val_pe =
                PatElem name $
                  MemArray bt ident_shape u $
                    ArrayIn (identName memid) ixfn
          pure (val_pe, patels, [mem_pe])
        MemAcc acc ispace ts u ->
          pure $ plain $ MemAcc acc ispace ts u
        _ -> error "Impossible case reached in allocsForPattern!"

  let (vals, exts, mems) = unzip3 allocated
      sizes' = [PatElem (identName size) $ MemPrim int64 | size <- sizeidents]
  pure (sizes' <> concat exts <> concat mems, vals)
  where
    -- Just the dimensions when none of them is existential.
    knownShape = traverse fromFree . shapeDims
    fromFree (Free v) = Just v
    fromFree Ext {} = Nothing

    -- A fresh memory identifier named after the given identifier.
    mkMemIdent :: (MonadFreshNames m) => Ident -> Space -> m Ident
    mkMemIdent ident space =
      newIdent (baseString (identName ident) <> "_mem") (Mem space)

    -- Replace the existentials of an index function with concrete
    -- names.  Returns the pattern elements for the fresh scalars that
    -- had to be created, and the instantiated index function.
    instantiateExtIxFun ::
      MonadFreshNames m =>
      Ident ->
      ExtIxFun ->
      m ([PatElemT (MemInfo d u ret)], IxFun)
    instantiateExtIxFun idd ext_ixfn = do
      let leaves = foldMap (leafExpTypes . untyped) ext_ixfn
          used_exts = sort $ S.toList $ foldMap onlyExts leaves
          -- Existentials numbered below the number of size identifiers
          -- reuse those identifiers; the rest get new pattern elements.
          -- Assumes that the Exts form a contiguous interval of integers.
          -- NOTE(review): size existentials are paired with sizeidents by
          -- position in sorted order, not by their index i.  This only
          -- matches when the used indices start at 0 with no gaps —
          -- confirm callers guarantee this.
          (size_exts, new_exts) = span ((< length sizeidents) . fst) used_exts
      fresh <- forM new_exts $ \(i, t) -> do
        v <- newVName $ baseString (identName idd) <> "_ixfn"
        pure ((Ext i, LeafExp (Free v) t), PatElem v $ MemPrim t)
      let (new_substs, patels) = unzip fresh
          size_substs =
            [ (Ext i, LeafExp (Free (identName ident)) t)
              | ((i, t), ident) <- zip size_exts sizeidents
            ]
          substs = M.fromList $ new_substs <> size_substs
      ixfn <-
        instantiateIxFun $
          IxFun.substituteInIxFun (isInt64 <$> substs) ext_ixfn
      pure (patels, ixfn)

-- | Keep a leaf only if it is an existential: returns a singleton set
-- with its index and type, or the empty set for a free variable.
onlyExts :: (Ext a, PrimType) -> S.Set (Int, PrimType)
onlyExts (leaf, t) =
  case leaf of
    Ext i -> S.singleton (i, t)
    Free _ -> S.empty

-- | Turn an existential index function into a concrete one by stripping
-- the 'Free' wrapper from every variable it mentions.
--
-- Instantiating existential ('Ext') variables is not implemented.
-- Meeting one is a runtime 'error'.
instantiateIxFun :: Monad m => ExtIxFun -> m IxFun
instantiateIxFun ext_ixfun = traverse (traverse unwrapFree) ext_ixfun
  where
    unwrapFree (Free v) = pure v
    unwrapFree Ext {} = error "instantiateIxFun: not yet"

-- | Compute the memory summary for a binding of the given type.
--
-- Scalars, memory blocks and accumulators need no allocation and are
-- mapped directly to their 'MemInfo' counterparts.  For arrays:
--
-- * With 'NoHint', a block is allocated (via 'allocForArray') in the
--   default space.  The array gets the index function produced by
--   'directIxFun'.
--
-- * With @'Hint' ixfun space@, a block of
--   @product (IxFun.base ixfun) * primByteSize pt@ bytes is allocated in
--   @space@.  The hinted index function is used as-is.
summaryForBindage ::
  Allocator rep m =>
  Type ->
  ExpHint ->
  m (MemBound NoUniqueness)
summaryForBindage t hint =
  case t of
    Prim bt -> pure (MemPrim bt)
    Mem space -> pure (MemMem space)
    Acc acc ispace ts u -> pure (MemAcc acc ispace ts u)
    Array pt shape u ->
      case hint of
        NoHint -> do
          def_space <- askDefaultSpace
          mem <- allocForArray t def_space
          pure (directIxFun pt shape u mem t)
        Hint ixfun space -> do
          let num_elems = product (IxFun.base ixfun)
              elem_bytes = fromIntegral (primByteSize pt :: Int64)
          -- Keep the two-element 'product' so the constructed size
          -- expression is the same as before.
          bytes <- computeSize "bytes" (untyped (product [num_elems, elem_bytes]))
          mem <- allocateMemory "mem" bytes space
          pure (MemArray pt (arrayShape t) NoUniqueness (ArrayIn mem ixfun))

-- | Look up the memory space of a variable that must be bound to a
-- memory block.  Calls 'error' if the variable has any other type.
lookupMemSpace :: (HasScope rep m, Monad m) => VName -> m Space
lookupMemSpace v = lookupType v >>= extractSpace
  where
    extractSpace (Mem space) = pure space
    extractSpace _ =
      error $ "lookupMemSpace: " <> pretty v <> " is not a memory block."

-- | Build an array memory summary placing the array in block @mem@.
-- The index function is 'IxFun.iota' over the dimensions of @t@.
directIxFun :: PrimType -> Shape -> u -> VName -> Type -> MemBound u
directIxFun bt shape u mem t =
  MemArray bt shape u (ArrayIn mem ixf)
  where
    ixf = IxFun.iota (pe64 <$> arrayDims t)

-- | Allocate memory for a list of function parameters.  Each parameter is
-- paired with the space its memory (if any) should live in.
--
-- The continuation receives the parameters ordered as: context
-- parameters, then memory parameters, then the rewritten value
-- parameters.  It runs in a scope extended with all of them.
allocInFParams ::
  (Allocable fromrep torep) =>
  [(FParam fromrep, Space)] ->
  ([FParam torep] -> AllocM fromrep torep a) ->
  AllocM fromrep torep a
allocInFParams params m = do
  (valparams, (ctxparams, memparams)) <-
    runWriterT $ mapM (uncurry allocInFParam) params
  let all_params = concat [ctxparams, memparams, valparams]
  localScope (scopeOfFParams all_params) (m all_params)

-- | Give a single function parameter its memory decoration.
--
-- Scalars, memory blocks and accumulators are decorated directly.  For
-- an array parameter:
--
-- * a fresh memory-block parameter named @\<param\>_mem@ is created in
--   @pspace@ and recorded in the second component of the writer output;
-- * the array is placed in that block with an 'IxFun.iota' index
--   function over its shape.
allocInFParam ::
  (Allocable fromrep torep) =>
  FParam fromrep ->
  Space ->
  WriterT
    ([FParam torep], [FParam torep])
    (AllocM fromrep torep)
    (FParam torep)
allocInFParam param pspace =
  case paramDeclType param of
    Prim pt -> withDec $ MemPrim pt
    Mem space -> withDec $ MemMem space
    Acc acc ispace ts u -> withDec $ MemAcc acc ispace ts u
    Array pt shape u -> do
      mem <- lift $ newVName $ baseString (paramName param) <> "_mem"
      tell ([], [Param mem (MemMem pspace)])
      withDec . MemArray pt shape u . ArrayIn mem $
        IxFun.iota (map pe64 (shapeDims shape))
  where
    -- Replace the parameter's declared type with its memory decoration.
    withDec dec = return param {paramDec = dec}

allocInMergeParams ::
  ( Allocable fromrep torep,
    Allocator torep (AllocM fromrep torep)
  ) =>
  [(FParam fromrep, SubExp)] ->
  ( [FParam torep] ->
    [FParam torep] ->
    ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])) ->
    AllocM fromrep torep a
  ) ->
  AllocM fromrep torep a
allocInMergeParams :: forall fromrep torep a.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
[(FParam fromrep, SubExp)]
-> ([FParam torep]
    -> [FParam torep]
    -> (Result -> AllocM fromrep torep (Result, Result))
    -> AllocM fromrep torep a)
-> AllocM fromrep torep a
allocInMergeParams [(FParam fromrep, SubExp)]
merge [FParam torep]
-> [FParam torep]
-> (Result -> AllocM fromrep torep (Result, Result))
-> AllocM fromrep torep a
m = do
  (([Param FParamMem]
valparams, [SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp]
handle_loop_subexps), ([Param FParamMem]
ctx_params, [Param FParamMem]
mem_params)) <-
    WriterT
  ([Param FParamMem], [Param FParamMem])
  (AllocM fromrep torep)
  ([Param FParamMem],
   [SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp])
-> AllocM
     fromrep
     torep
     (([Param FParamMem],
       [SubExp
        -> WriterT (Result, Result) (AllocM fromrep torep) SubExp]),
      ([Param FParamMem], [Param FParamMem]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
   ([Param FParamMem], [Param FParamMem])
   (AllocM fromrep torep)
   ([Param FParamMem],
    [SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp])
 -> AllocM
      fromrep
      torep
      (([Param FParamMem],
        [SubExp
         -> WriterT (Result, Result) (AllocM fromrep torep) SubExp]),
       ([Param FParamMem], [Param FParamMem])))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     ([Param FParamMem],
      [SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp])
-> AllocM
     fromrep
     torep
     (([Param FParamMem],
       [SubExp
        -> WriterT (Result, Result) (AllocM fromrep torep) SubExp]),
      ([Param FParamMem], [Param FParamMem]))
forall a b. (a -> b) -> a -> b
$ [(Param FParamMem,
  SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)]
-> ([Param FParamMem],
    [SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Param FParamMem,
   SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)]
 -> ([Param FParamMem],
     [SubExp
      -> WriterT (Result, Result) (AllocM fromrep torep) SubExp]))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     [(Param FParamMem,
       SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     ([Param FParamMem],
      [SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Param DeclType, SubExp)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (Param FParamMem,
       SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp))
-> [(Param DeclType, SubExp)]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     [(Param FParamMem,
       SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Param DeclType, SubExp)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem,
      SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
(Param DeclType, SubExp)
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep,
      SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
allocInMergeParam [(Param DeclType, SubExp)]
[(FParam fromrep, SubExp)]
merge
  let mergeparams' :: [Param FParamMem]
mergeparams' = [Param FParamMem]
ctx_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
mem_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
valparams
      summary :: Scope torep
summary = [Param FParamMem] -> Scope torep
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param FParamMem]
mergeparams'

      mk_loop_res :: Result -> AllocM fromrep torep (Result, Result)
mk_loop_res Result
ses = do
        (Result
valargs, (Result
ctxargs, Result
memargs)) <-
          WriterT (Result, Result) (AllocM fromrep torep) Result
-> AllocM fromrep torep (Result, (Result, Result))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT (Result, Result) (AllocM fromrep torep) Result
 -> AllocM fromrep torep (Result, (Result, Result)))
-> WriterT (Result, Result) (AllocM fromrep torep) Result
-> AllocM fromrep torep (Result, (Result, Result))
forall a b. (a -> b) -> a -> b
$ ((SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
 -> SubExp
 -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
-> [SubExp
    -> WriterT (Result, Result) (AllocM fromrep torep) SubExp]
-> Result
-> WriterT (Result, Result) (AllocM fromrep torep) Result
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
-> SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp
forall a b. (a -> b) -> a -> b
($) [SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp]
handle_loop_subexps Result
ses
        (Result, Result) -> AllocM fromrep torep (Result, Result)
forall (m :: * -> *) a. Monad m => a -> m a
return (Result
ctxargs Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Result
memargs, Result
valargs)

  Scope torep -> AllocM fromrep torep a -> AllocM fromrep torep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope torep
summary (AllocM fromrep torep a -> AllocM fromrep torep a)
-> AllocM fromrep torep a -> AllocM fromrep torep a
forall a b. (a -> b) -> a -> b
$ [FParam torep]
-> [FParam torep]
-> (Result -> AllocM fromrep torep (Result, Result))
-> AllocM fromrep torep a
m ([Param FParamMem]
ctx_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
mem_params) [FParam torep]
[Param FParamMem]
valparams Result -> AllocM fromrep torep (Result, Result)
mk_loop_res
  where
    allocInMergeParam ::
      (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
      (Param DeclType, SubExp) ->
      WriterT
        ([FParam torep], [FParam torep])
        (AllocM fromrep torep)
        (FParam torep, SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
    allocInMergeParam :: forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
(Param DeclType, SubExp)
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep,
      SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
allocInMergeParam (Param DeclType
mergeparam, Var VName
v)
      | Array PrimType
pt Shape
shape Uniqueness
u <- Param DeclType -> DeclType
forall dec. DeclTyped dec => Param dec -> DeclType
paramDeclType Param DeclType
mergeparam = do
        (VName
mem', IxFun
_) <- AllocM fromrep torep (VName, IxFun)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, IxFun)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, IxFun)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (VName, IxFun))
-> AllocM fromrep torep (VName, IxFun)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, IxFun)
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep (VName, IxFun)
forall rep (m :: * -> *).
(Mem rep, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun)
lookupArraySummary VName
v
        Space
mem_space <- AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep Space
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      Space)
-> AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep Space
forall rep (m :: * -> *).
(HasScope rep m, Monad m) =>
VName -> m Space
lookupMemSpace VName
mem'

        (SubExp
_, ExtIxFun
ext_ixfun, [TPrimExp Int64 VName]
substs, VName
_) <- AllocM
  fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM
   fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName))
-> AllocM
     fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
forall a b. (a -> b) -> a -> b
$ Space
-> VName
-> AllocM
     fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
Space
-> VName
-> AllocM
     fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
existentializeArray Space
mem_space VName
v

        ([Param FParamMem]
ctx_params, [TPrimExp Int64 (Ext VName)]
param_ixfun_substs) <-
          [(Param FParamMem, TPrimExp Int64 (Ext VName))]
-> ([Param FParamMem], [TPrimExp Int64 (Ext VName)])
forall a b. [(a, b)] -> ([a], [b])
unzip
            ([(Param FParamMem, TPrimExp Int64 (Ext VName))]
 -> ([Param FParamMem], [TPrimExp Int64 (Ext VName)]))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     [(Param FParamMem, TPrimExp Int64 (Ext VName))]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     ([Param FParamMem], [TPrimExp Int64 (Ext VName)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (TPrimExp Int64 VName
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (Param FParamMem, TPrimExp Int64 (Ext VName)))
-> [TPrimExp Int64 VName]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     [(Param FParamMem, TPrimExp Int64 (Ext VName))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM
              ( \TPrimExp Int64 VName
e -> do
                  let e_t :: PrimType
e_t = PrimExp VName -> PrimType
forall v. PrimExp v -> PrimType
primExpType (PrimExp VName -> PrimType) -> PrimExp VName -> PrimType
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
e
                  VName
vname <- AllocM fromrep torep VName
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) VName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep VName
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      VName)
-> AllocM fromrep torep VName
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) VName
forall a b. (a -> b) -> a -> b
$ String -> AllocM fromrep torep VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"ctx_param_ext"
                  (Param FParamMem, TPrimExp Int64 (Ext VName))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem, TPrimExp Int64 (Ext VName))
forall (m :: * -> *) a. Monad m => a -> m a
return
                    ( VName -> FParamMem -> Param FParamMem
forall dec. VName -> dec -> Param dec
Param VName
vname (FParamMem -> Param FParamMem) -> FParamMem -> Param FParamMem
forall a b. (a -> b) -> a -> b
$ PrimType -> FParamMem
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
e_t,
                      (VName -> Ext VName)
-> TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> Ext VName
forall a. a -> Ext a
Free (TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName))
-> TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName)
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vname
                    )
              )
              [TPrimExp Int64 VName]
substs

        ([Param FParamMem], [Param FParamMem])
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([Param FParamMem]
ctx_params, [])

        IxFun
param_ixfun <-
          ExtIxFun
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) IxFun
forall (m :: * -> *). Monad m => ExtIxFun -> m IxFun
instantiateIxFun (ExtIxFun
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      IxFun)
-> ExtIxFun
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) IxFun
forall a b. (a -> b) -> a -> b
$
            Map (Ext VName) (TPrimExp Int64 (Ext VName))
-> ExtIxFun -> ExtIxFun
forall a t.
Ord a =>
Map a (TPrimExp t a)
-> IxFun (TPrimExp t a) -> IxFun (TPrimExp t a)
IxFun.substituteInIxFun
              ([(Ext VName, TPrimExp Int64 (Ext VName))]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(Ext VName, TPrimExp Int64 (Ext VName))]
 -> Map (Ext VName) (TPrimExp Int64 (Ext VName)))
-> [(Ext VName, TPrimExp Int64 (Ext VName))]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall a b. (a -> b) -> a -> b
$ [Ext VName]
-> [TPrimExp Int64 (Ext VName)]
-> [(Ext VName, TPrimExp Int64 (Ext VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Int -> Ext VName) -> [Int] -> [Ext VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Int -> Ext VName
forall a. Int -> Ext a
Ext [Int
0 ..]) [TPrimExp Int64 (Ext VName)]
param_ixfun_substs)
              ExtIxFun
ext_ixfun

        VName
mem_name <- String
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"mem_param"
        ([Param FParamMem], [Param FParamMem])
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([], [VName -> FParamMem -> Param FParamMem
forall dec. VName -> dec -> Param dec
Param VName
mem_name (FParamMem -> Param FParamMem) -> FParamMem -> Param FParamMem
forall a b. (a -> b) -> a -> b
$ Space -> FParamMem
forall d u ret. Space -> MemInfo d u ret
MemMem Space
mem_space])

        (Param FParamMem,
 SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem,
      SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return
          ( Param DeclType
mergeparam {paramDec :: FParamMem
paramDec = PrimType -> Shape -> Uniqueness -> MemBind -> FParamMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape Uniqueness
u (MemBind -> FParamMem) -> MemBind -> FParamMem
forall a b. (a -> b) -> a -> b
$ VName -> IxFun -> MemBind
ArrayIn VName
mem_name IxFun
param_ixfun},
            Space
-> SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
Space
-> SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp
ensureArrayIn Space
mem_space
          )
    allocInMergeParam (Param DeclType
mergeparam, SubExp
_) = Param DeclType
-> Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep),
      SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
forall {torep} {fromrep} {torep} {fromrep}.
(PrettyRep fromrep, PrettyRep fromrep, AllocOp (Op torep),
 AllocOp (Op torep), OpReturns torep, OpReturns torep,
 SizeSubst (Op torep), SizeSubst (Op torep), BinderOps torep,
 BinderOps torep, LetDec torep ~ LParamMem, BodyDec torep ~ (),
 LParamInfo fromrep ~ Type, LetDec torep ~ LParamMem,
 BodyDec torep ~ (), LParamInfo fromrep ~ Type,
 BranchType fromrep ~ ExtType, ExpDec torep ~ (),
 RetType torep ~ RetTypeMem, BranchType fromrep ~ ExtType,
 ExpDec torep ~ (), RetType torep ~ RetTypeMem,
 LParamInfo torep ~ LParamMem, BodyDec fromrep ~ (),
 FParamInfo fromrep ~ DeclType, LParamInfo torep ~ LParamMem,
 BodyDec fromrep ~ (), FParamInfo fromrep ~ DeclType,
 RetType fromrep ~ DeclExtType, FParamInfo torep ~ FParamMem,
 BranchType torep ~ BranchTypeMem, RetType fromrep ~ DeclExtType,
 FParamInfo torep ~ FParamMem, BranchType torep ~ BranchTypeMem) =>
Param DeclType
-> Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep),
      SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
doDefault Param DeclType
mergeparam (Space
 -> WriterT
      ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
      (AllocM fromrep torep)
      (Param FParamMem,
       SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp))
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param FParamMem,
      SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< AllocM fromrep torep Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift AllocM fromrep torep Space
forall rep (m :: * -> *). Allocator rep m => m Space
askDefaultSpace

    doDefault :: Param DeclType
-> Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep),
      SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
doDefault Param DeclType
mergeparam Space
space = do
      Param (FParamInfo torep)
mergeparam' <- FParam fromrep
-> Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep))
forall fromrep torep.
Allocable fromrep torep =>
FParam fromrep
-> Space
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep)
allocInFParam Param DeclType
FParam fromrep
mergeparam Space
space
      (Param (FParamInfo torep),
 SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep),
      SubExp -> WriterT (Result, Result) (AllocM fromrep torep) SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (Param (FParamInfo torep)
mergeparam', Type
-> Space
-> SubExp
-> WriterT (Result, Result) (AllocM fromrep torep) SubExp
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
Type
-> Space
-> SubExp
-> WriterT (Result, Result) (AllocM fromrep torep) SubExp
linearFuncallArg (Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param DeclType
mergeparam) Space
space)

-- | Describe array @v@ with an existentialized index function in memory
-- space @space@.
--
-- Returns, in order:
--
--   * the array to use in place of @v@ (either @v@ itself or a fresh copy),
--   * its existentialized index function,
--   * the list of values substituted out of the index function, and
--   * the memory block holding the array.
--
-- For a 'ScalarSpace' the array is returned unchanged, with every variable
-- of its index function marked 'Free' and no substituted values.  Otherwise
-- @v@ is used as-is when it already lives in @space@ and its index function
-- can be existentialized.  If not, it is copied into a fresh linear array in
-- @space@ (via 'allocLinearArray') and that copy is existentialized instead.
existentializeArray ::
  (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
  Space ->
  VName ->
  AllocM fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
existentializeArray ScalarSpace {} v = do
  (mem', ixfun) <- lookupArraySummary v
  return (Var v, fmap (fmap Free) ixfun, mempty, mem')
existentializeArray space v = do
  (mem', ixfun) <- lookupArraySummary v
  sp <- lookupMemSpace mem'

  let (ext_ixfun', substs') = runState (IxFun.existentialize ixfun) []

  case (ext_ixfun', sp == space) of
    (Just x, True) -> return (Var v, x, substs', mem')
    _ -> do
      -- Wrong space, or the index function could not be existentialized:
      -- copy into a fresh linear array and existentialize that instead.
      (mem, subexp) <- allocLinearArray space (baseString v) v
      -- The copy was just bound with a direct index function, so both of
      -- the following are internal invariants; fail with context if not.
      ixfun' <-
        fromMaybe
          (error $ "existentializeArray: no index function for " ++ pretty subexp)
          <$> subExpIxFun subexp
      let (ext_ixfun, substs) = runState (IxFun.existentialize ixfun') []
          ext_ixfun_linear =
            fromMaybe
              ( error $
                  "existentializeArray: cannot existentialize index function of "
                    ++ pretty subexp
              )
              ext_ixfun
      return (subexp, ext_ixfun_linear, substs, mem)

-- | Make the array argument describable by an existentialized index
-- function in memory space @space@, copying it if necessary (see
-- 'existentializeArray').
--
-- Each value substituted out of the index function is bound to a fresh
-- @ctx_val@ variable.  Those variables are emitted as the first writer
-- component and the array's memory block as the second.  Returns the array
-- to use in place of the argument.  A constant is rejected with an internal
-- error, since it cannot be an array.
ensureArrayIn ::
  ( Allocable fromrep torep,
    Allocator torep (AllocM fromrep torep)
  ) =>
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn _ (Constant v) =
  error $ "ensureArrayIn: " ++ pretty v ++ " cannot be an array."
ensureArrayIn space (Var v) = do
  (sub_exp, _, substs, mem) <- lift $ existentializeArray space v
  -- Bind every substituted value to a variable so it can be passed along.
  ctx_vals <-
    lift $
      mapM (\s -> Var <$> (letExp "ctx_val" =<< toExp s)) substs

  tell (ctx_vals, [Var mem])

  return sub_exp

-- | Return the memory block of array @v@ together with an array that has a
-- direct index function and lives in an acceptable space.  @Nothing@
-- accepts any space and @Just s@ requires @s@.
--
-- If @v@ already satisfies both conditions it is returned unchanged.
-- Otherwise it is copied into a fresh linear array (via 'allocLinearArray')
-- in the requested space, or in the default space when none was requested.
ensureDirectArray ::
  ( Allocable fromrep torep,
    Allocator torep (AllocM fromrep torep)
  ) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, SubExp)
ensureDirectArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let usable_as_is = IxFun.isDirect ixfun && all (== mem_space) space_ok
      target_space = fromMaybe default_space space_ok
  if usable_as_is
    then pure (mem, Var v)
    else -- We need to do a new allocation, copy 'v', and make a new
    -- binding for the size of the memory block.
      allocLinearArray target_space (baseString v) v

-- | Copy array @v@ into a freshly allocated memory block in @space@, with a
-- direct index function.
--
-- The copy is bound to a new variable named after @s@ with @_linear@
-- appended.  Returns the new memory block and the copy.  Fails with an
-- internal error if @v@ does not have an array type.
allocLinearArray ::
  (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
  Space ->
  String ->
  VName ->
  AllocM fromrep torep (VName, SubExp)
allocLinearArray space s v = do
  t <- lookupType v
  case t of
    Array pt shape u -> do
      mem <- allocForArray t space
      v_copy <- newIdent (s ++ "_linear") t
      let copy_name = identName v_copy
          copy_dec = directIxFun pt shape u mem t
          copy_pat = Pattern [] [PatElem copy_name copy_dec]
      addStm . Let copy_pat (defAux ()) . BasicOp $ Copy v
      pure (mem, Var copy_name)
    _ -> error $ "allocLinearArray: " ++ pretty t

-- | Prepare the arguments of a function call for the memory representation.
--
-- Each argument is passed through 'linearFuncallArg' in the default space.
-- The extra arguments that collects (for arrays, their memory blocks) are
-- placed in front of the value arguments, all with diet 'Observe'.
funcallArgs ::
  ( Allocable fromrep torep,
    Allocator torep (AllocM fromrep torep)
  ) =>
  [(SubExp, Diet)] ->
  AllocM fromrep torep [(SubExp, Diet)]
funcallArgs args = do
  (valargs, (ctx_args, mem_and_size_args)) <-
    runWriterT $
      mapM
        ( \(arg, d) -> do
            t <- lift $ subExpType arg
            space <- lift askDefaultSpace
            (,d) <$> linearFuncallArg t space arg
        )
        args
  let extra_args = map (,Observe) ctx_args ++ map (,Observe) mem_and_size_args
  pure $ extra_args ++ valargs

-- | Prepare a single function-call argument of type @t@.
--
-- An array variable is made direct in @space@ (copying if needed, via
-- 'ensureDirectArray').  Its memory block is emitted as the second writer
-- component and the possibly-copied array is returned.  Any other argument
-- is returned unchanged and emits nothing.
linearFuncallArg ::
  ( Allocable fromrep torep,
    Allocator torep (AllocM fromrep torep)
  ) =>
  Type ->
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg t space arg =
  case (t, arg) of
    (Array {}, Var v) -> do
      (mem, arg') <- lift $ ensureDirectArray (Just space) v
      tell (mempty, [Var mem])
      pure arg'
    _ -> pure arg

-- | Build the explicit-allocations pass from two inputs:
--
--   * a representation-specific translation of 'Op's, and
--   * a function computing allocation hints for expressions.
--
-- Top-level constants are handled by 'allocInStms'.  In functions, every
-- parameter and every return value is placed in 'DefaultSpace'.
explicitAllocationsGeneric ::
  ( Allocable fromrep torep,
    Allocator torep (AllocM fromrep torep)
  ) =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Pass fromrep torep
explicitAllocationsGeneric handleOp hints =
  Pass pass_name pass_desc $
    intraproceduralTransformationWithConsts onStms allocInFun
  where
    pass_name = "explicit allocations"
    pass_desc = "Transform program to explicit memory representation"

    -- Top-level constants: allocate within the statements and collect them.
    onStms stms =
      runAllocM handleOp hints . collectStms_ $
        allocInStms stms (pure ())

    -- Functions: parameters and results all go in the default space.
    allocInFun consts (FunDef entry attrs fname rettype params fbody) =
      runAllocM handleOp hints . inScopeOf consts $
        allocInFParams (map (,DefaultSpace) params) $ \params' -> do
          fbody' <- allocInFunBody (Just DefaultSpace <$ rettype) fbody
          let rettype' = memoryInDeclExtType rettype
          pure $ FunDef entry attrs fname rettype' params' fbody'

-- | Add explicit allocations to a standalone sequence of statements.
--
-- The statements are processed in the scope that the calling monad
-- reports via 'askScope'.  @handleOp@ allocates inside
-- representation-specific operations, and @hints@ supplies allocation
-- hints for expressions.  The resulting statements are collected and
-- returned rather than added to any enclosing binding context.
explicitAllocationsInStmsGeneric ::
  ( MonadFreshNames m,
    HasScope torep m,
    Allocable fromrep torep
  ) =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Stms fromrep ->
  m (Stms torep)
explicitAllocationsInStmsGeneric handleOp hints stms = do
  outerScope <- askScope
  runAllocM handleOp hints . localScope outerScope . collectStms_ $
    allocInStms stms (pure ())

-- | Attach memory information to a list of declared return types.
--
-- * Scalars become 'MemPrim'.
-- * Accumulators become 'MemAcc'.
-- * Every array gets its own 'ReturnsNewBlock' in 'DefaultSpace', with an
--   'IxFun.iota' index function over its (possibly existential) shape.
--
-- The integer stored in each 'ReturnsNewBlock' is taken from a counter
-- that starts at 'startOfFreeIDRange' of the input types and is
-- incremented once per array, in list order.  A bare memory type in the
-- input is rejected with an error.
memoryInDeclExtType :: [DeclExtType] -> [FunReturns]
memoryInDeclExtType dets =
  evalState (traverse withMem dets) (startOfFreeIDRange dets)
  where
    withMem t = case t of
      Prim pt -> pure $ MemPrim pt
      Mem {} -> error "memoryInDeclExtType: too much memory"
      Array pt shape u -> do
        -- Claim the next free index for this array's memory block.
        blockIdx <- get
        put (blockIdx + 1)
        let ixfun = IxFun.iota (map toPrimExp (shapeDims shape))
        pure $ MemArray pt shape u $ ReturnsNewBlock DefaultSpace blockIdx ixfun
      Acc acc ispace ts u -> pure $ MemAcc acc ispace ts u

    -- Existential dimensions stay existential; known dimensions are
    -- turned into 64-bit primitive expressions over free variables.
    toPrimExp (Ext i) = le64 (Ext i)
    toPrimExp (Free se) = Free <$> pe64 se

-- | The number of distinct elements in the 'shapeContext' of the given
-- types.
--
-- Callers use this as the first index they may allocate for new
-- existentials.  NOTE(review): that presumably assumes the existential
-- size indices in the shapes are numbered densely from 0 — confirm.
startOfFreeIDRange :: [TypeBase ExtShape u] -> Int
startOfFreeIDRange ts = S.size (shapeContext ts)

-- | The memory-block context that must accompany a value returned
-- from a body.
--
-- For a variable bound as an array, this is the memory block the array
-- lives in.  Constants, scalars, accumulators and memory blocks need no
-- extra context, so they yield the empty list.
bodyReturnMemCtx ::
  (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
  SubExp ->
  AllocM fromrep torep [SubExp]
bodyReturnMemCtx se =
  case se of
    Constant {} -> pure []
    Var v ->
      lookupMemInfo v >>= \info ->
        case info of
          MemArray _ _ _ (ArrayIn mem _) -> pure [Var mem]
          MemPrim {} -> pure []
          MemAcc {} -> pure []
          MemMem {} -> pure [] -- not expected for a body result

-- | Allocate memory in a function body.
--
-- @space_oks@ holds one entry per value result, i.e. per trailing result
-- of the body.  Any leading (context) results are given 'Nothing'.
--
-- 1. Every result is passed through 'ensureDirect' with its space entry.
-- 2. The memory context of each value result (from 'bodyReturnMemCtx')
--    is inserted between the context results and the value results.
allocInFunBody ::
  (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
  [Maybe Space] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep)
allocInFunBody space_oks (Body _ bnds res) =
  buildBody_ $
    allocInStms bnds $ do
      directRes <- zipWithM ensureDirect allSpaces res
      let (ctxRes, valRes) = splitFromEnd numVals directRes
      memCtxRes <- concat <$> traverse bodyReturnMemCtx valRes
      pure (ctxRes ++ memCtxRes ++ valRes)
  where
    numVals = length space_oks
    -- Context results carry no space requirement.
    numCtx = length res - numVals
    allSpaces = replicate numCtx Nothing ++ space_oks

-- | Make a single result suitable for returning.
--
-- If the memory information of @se@ is an array and @se@ is a variable,
-- the variable is handed to 'ensureDirectArray' with the given space
-- requirement, and the 'SubExp' that call produces is returned.  Any
-- other result is returned unchanged.
ensureDirect ::
  (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
  Maybe Space ->
  SubExp ->
  AllocM fromrep torep SubExp
ensureDirect space_ok se = do
  info <- subExpMemInfo se
  case info of
    MemArray {}
      | Var v <- se -> do
          (_, directSe) <- ensureDirectArray space_ok v
          pure directSe
    _ -> pure se

-- | Allocate memory in a sequence of statements, then run the given
-- action.
--
-- The statements are processed front to back.  For each one:
--
-- 1. It is allocated with 'allocInStm', under its own auxiliary
--    information (via 'auxing').
-- 2. The resulting statements are added to the current binding context.
-- 3. Their 'sizeSubst' and 'stmConsts' contributions are prepended to
--    the environment's 'chunkMap' and 'envConsts'.
--
-- Every later statement, and finally the continuation @m@, runs in that
-- extended environment.
allocInStms ::
  (Allocable fromrep torep) =>
  Stms fromrep ->
  AllocM fromrep torep a ->
  AllocM fromrep torep a
allocInStms origstms m =
  foldr
    ( \stm rest -> do
        allocstms <- collectStms_ . auxing (stmAux stm) $ allocInStm stm
        addStms allocstms
        let extend env =
              env
                { chunkMap = foldMap sizeSubst allocstms <> chunkMap env,
                  envConsts = foldMap stmConsts allocstms <> envConsts env
                }
        local extend rest
    )
    m
    (stmsToList origstms)

-- | Allocate memory for a single statement.
--
-- 1. The expression is rewritten with 'allocInExp'.
-- 2. 'allocsForStm' builds a new statement binding the identifiers of
--    the original pattern's size and value elements.
-- 3. That statement is added to the current binding context.
--
-- The statement's auxiliary information is ignored here.
allocInStm ::
  (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
  Stm fromrep ->
  AllocM fromrep torep ()
allocInStm (Let (Pattern sizeElems valElems) _ e) =
  allocInExp e
    >>= allocsForStm (map patElemIdent sizeElems) (map patElemIdent valElems)
    >>= addStm

-- | Build a lambda from parameters that are already in the target
-- representation and a body in the source representation.
--
-- Memory is allocated in the body's statements, and the body's result
-- is used unchanged as the lambda's result.
allocInLambda ::
  Allocable fromrep torep =>
  [LParam torep] ->
  Body fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInLambda params body =
  mkLambda params $ allocInStms (bodyStms body) (pure (bodyResult body))

allocInExp ::
  (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
  Exp fromrep ->
  AllocM fromrep torep (Exp torep)
allocInExp :: forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
Exp fromrep -> AllocM fromrep torep (Exp torep)
allocInExp (DoLoop [(FParam fromrep, SubExp)]
ctx [(FParam fromrep, SubExp)]
val LoopForm fromrep
form (Body () Stms fromrep
bodybnds Result
bodyres)) =
  [(FParam fromrep, SubExp)]
-> ([FParam torep]
    -> [FParam torep]
    -> (Result -> AllocM fromrep torep (Result, Result))
    -> AllocM fromrep torep (ExpT torep))
-> AllocM fromrep torep (ExpT torep)
forall fromrep torep a.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
[(FParam fromrep, SubExp)]
-> ([FParam torep]
    -> [FParam torep]
    -> (Result -> AllocM fromrep torep (Result, Result))
    -> AllocM fromrep torep a)
-> AllocM fromrep torep a
allocInMergeParams [(FParam fromrep, SubExp)]
ctx (([FParam torep]
  -> [FParam torep]
  -> (Result -> AllocM fromrep torep (Result, Result))
  -> AllocM fromrep torep (ExpT torep))
 -> AllocM fromrep torep (ExpT torep))
-> ([FParam torep]
    -> [FParam torep]
    -> (Result -> AllocM fromrep torep (Result, Result))
    -> AllocM fromrep torep (ExpT torep))
-> AllocM fromrep torep (ExpT torep)
forall a b. (a -> b) -> a -> b
$ \[FParam torep]
_ [FParam torep]
ctxparams' Result -> AllocM fromrep torep (Result, Result)
_ ->
    [(FParam fromrep, SubExp)]
-> ([FParam torep]
    -> [FParam torep]
    -> (Result -> AllocM fromrep torep (Result, Result))
    -> AllocM fromrep torep (ExpT torep))
-> AllocM fromrep torep (ExpT torep)
forall fromrep torep a.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
[(FParam fromrep, SubExp)]
-> ([FParam torep]
    -> [FParam torep]
    -> (Result -> AllocM fromrep torep (Result, Result))
    -> AllocM fromrep torep a)
-> AllocM fromrep torep a
allocInMergeParams [(FParam fromrep, SubExp)]
val (([FParam torep]
  -> [FParam torep]
  -> (Result -> AllocM fromrep torep (Result, Result))
  -> AllocM fromrep torep (ExpT torep))
 -> AllocM fromrep torep (ExpT torep))
-> ([FParam torep]
    -> [FParam torep]
    -> (Result -> AllocM fromrep torep (Result, Result))
    -> AllocM fromrep torep (ExpT torep))
-> AllocM fromrep torep (ExpT torep)
forall a b. (a -> b) -> a -> b
$
      \[FParam torep]
new_ctx_params [FParam torep]
valparams' Result -> AllocM fromrep torep (Result, Result)
mk_loop_val -> do
        LoopForm torep
form' <- LoopForm fromrep -> AllocM fromrep torep (LoopForm torep)
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
LoopForm fromrep -> AllocM fromrep torep (LoopForm torep)
allocInLoopForm LoopForm fromrep
form
        Scope torep
-> AllocM fromrep torep (ExpT torep)
-> AllocM fromrep torep (ExpT torep)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (LoopForm torep -> Scope torep
forall rep a. Scoped rep a => a -> Scope rep
scopeOf LoopForm torep
form') (AllocM fromrep torep (ExpT torep)
 -> AllocM fromrep torep (ExpT torep))
-> AllocM fromrep torep (ExpT torep)
-> AllocM fromrep torep (ExpT torep)
forall a b. (a -> b) -> a -> b
$ do
          (Result
valinit_ctx, Result
valinit') <- Result -> AllocM fromrep torep (Result, Result)
mk_loop_val Result
valinit
          BodyT torep
body' <-
            AllocM fromrep torep Result -> AllocM fromrep torep (BodyT torep)
forall (m :: * -> *). MonadBinder m => m Result -> m (Body (Rep m))
buildBody_ (AllocM fromrep torep Result -> AllocM fromrep torep (BodyT torep))
-> (AllocM fromrep torep Result -> AllocM fromrep torep Result)
-> AllocM fromrep torep Result
-> AllocM fromrep torep (BodyT torep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms fromrep
-> AllocM fromrep torep Result -> AllocM fromrep torep Result
forall fromrep torep a.
Allocable fromrep torep =>
Stms fromrep -> AllocM fromrep torep a -> AllocM fromrep torep a
allocInStms Stms fromrep
bodybnds (AllocM fromrep torep Result -> AllocM fromrep torep (BodyT torep))
-> AllocM fromrep torep Result
-> AllocM fromrep torep (BodyT torep)
forall a b. (a -> b) -> a -> b
$ do
              (Result
val_ses, Result
valres') <- Result -> AllocM fromrep torep (Result, Result)
mk_loop_val Result
valres
              Result -> AllocM fromrep torep Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> AllocM fromrep torep Result)
-> Result -> AllocM fromrep torep Result
forall a b. (a -> b) -> a -> b
$ Result
ctxres Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
val_ses Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
valres'
          ExpT torep -> AllocM fromrep torep (ExpT torep)
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpT torep -> AllocM fromrep torep (ExpT torep))
-> ExpT torep -> AllocM fromrep torep (ExpT torep)
forall a b. (a -> b) -> a -> b
$
            [(FParam torep, SubExp)]
-> [(FParam torep, SubExp)]
-> LoopForm torep
-> BodyT torep
-> ExpT torep
forall rep.
[(FParam rep, SubExp)]
-> [(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop
              ([Param FParamMem] -> Result -> [(Param FParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip ([FParam torep]
[Param FParamMem]
ctxparams' [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. [a] -> [a] -> [a]
++ [FParam torep]
[Param FParamMem]
new_ctx_params) (Result
ctxinit Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
valinit_ctx))
              ([Param FParamMem] -> Result -> [(Param FParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [FParam torep]
[Param FParamMem]
valparams' Result
valinit')
              LoopForm torep
form'
              BodyT torep
body'
  where
    ([Param DeclType]
_ctxparams, Result
ctxinit) = [(Param DeclType, SubExp)] -> ([Param DeclType], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param DeclType, SubExp)]
[(FParam fromrep, SubExp)]
ctx
    ([Param DeclType]
_valparams, Result
valinit) = [(Param DeclType, SubExp)] -> ([Param DeclType], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param DeclType, SubExp)]
[(FParam fromrep, SubExp)]
val
    (Result
ctxres, Result
valres) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt ([(Param DeclType, SubExp)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Param DeclType, SubExp)]
[(FParam fromrep, SubExp)]
ctx) Result
bodyres
allocInExp (Apply Name
fname [(SubExp, Diet)]
args [RetType fromrep]
rettype (Safety, SrcLoc, [SrcLoc])
loc) = do
  [(SubExp, Diet)]
args' <- [(SubExp, Diet)] -> AllocM fromrep torep [(SubExp, Diet)]
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
[(SubExp, Diet)] -> AllocM fromrep torep [(SubExp, Diet)]
funcallArgs [(SubExp, Diet)]
args
  ExpT torep -> AllocM fromrep torep (ExpT torep)
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpT torep -> AllocM fromrep torep (ExpT torep))
-> ExpT torep -> AllocM fromrep torep (ExpT torep)
forall a b. (a -> b) -> a -> b
$ Name
-> [(SubExp, Diet)]
-> [RetType torep]
-> (Safety, SrcLoc, [SrcLoc])
-> ExpT torep
forall rep.
Name
-> [(SubExp, Diet)]
-> [RetType rep]
-> (Safety, SrcLoc, [SrcLoc])
-> ExpT rep
Apply Name
fname [(SubExp, Diet)]
args' ([DeclExtType] -> [RetTypeMem]
memoryInDeclExtType [DeclExtType]
[RetType fromrep]
rettype) (Safety, SrcLoc, [SrcLoc])
loc
allocInExp (If SubExp
cond BodyT fromrep
tbranch0 BodyT fromrep
fbranch0 (IfDec [BranchType fromrep]
rets IfSort
ifsort)) = do
  let num_rets :: Int
num_rets = [ExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [ExtType]
[BranchType fromrep]
rets
  -- switch to the explicit-mem rep, but do nothing about results
  (BodyT torep
tbranch, [Maybe IxFun]
tm_ixfs) <- Int
-> BodyT fromrep
-> AllocM fromrep torep (BodyT torep, [Maybe IxFun])
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
Int
-> Body fromrep -> AllocM fromrep torep (Body torep, [Maybe IxFun])
allocInIfBody Int
num_rets BodyT fromrep
tbranch0
  (BodyT torep
fbranch, [Maybe IxFun]
fm_ixfs) <- Int
-> BodyT fromrep
-> AllocM fromrep torep (BodyT torep, [Maybe IxFun])
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
Int
-> Body fromrep -> AllocM fromrep torep (Body torep, [Maybe IxFun])
allocInIfBody Int
num_rets BodyT fromrep
fbranch0
  [Maybe Space]
tspaces <- Int -> BodyT torep -> AllocM fromrep torep [Maybe Space]
forall torep (m :: * -> *).
(Mem torep, LocalScope torep m) =>
Int -> Body torep -> m [Maybe Space]
mkSpaceOks Int
num_rets BodyT torep
tbranch
  [Maybe Space]
fspaces <- Int -> BodyT torep -> AllocM fromrep torep [Maybe Space]
forall torep (m :: * -> *).
(Mem torep, LocalScope torep m) =>
Int -> Body torep -> m [Maybe Space]
mkSpaceOks Int
num_rets BodyT torep
fbranch
  -- try to generalize (antiunify) the index functions of the then and else bodies
  let sp_substs :: [(Maybe Space,
  Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
sp_substs = ((Maybe Space, Maybe IxFun)
 -> (Maybe Space, Maybe IxFun)
 -> (Maybe Space,
     Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])))
-> [(Maybe Space, Maybe IxFun)]
-> [(Maybe Space, Maybe IxFun)]
-> [(Maybe Space,
     Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (Maybe Space, Maybe IxFun)
-> (Maybe Space, Maybe IxFun)
-> (Maybe Space,
    Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))
generalize ([Maybe Space] -> [Maybe IxFun] -> [(Maybe Space, Maybe IxFun)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Maybe Space]
tspaces [Maybe IxFun]
tm_ixfs) ([Maybe Space] -> [Maybe IxFun] -> [(Maybe Space, Maybe IxFun)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Maybe Space]
fspaces [Maybe IxFun]
fm_ixfs)
      ([Maybe Space]
spaces, [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
subs) = [(Maybe Space,
  Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
-> ([Maybe Space],
    [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Maybe Space,
  Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
sp_substs
      tsubs :: [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
tsubs = (Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
 -> Maybe (ExtIxFun, [TPrimExp Int64 VName]))
-> [Maybe
      (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
forall a b. (a -> b) -> [a] -> [b]
map (((TPrimExp Int64 VName, TPrimExp Int64 VName)
 -> TPrimExp Int64 VName)
-> Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
-> Maybe (ExtIxFun, [TPrimExp Int64 VName])
forall a.
((a, a) -> a)
-> Maybe (ExtIxFun, [(a, a)]) -> Maybe (ExtIxFun, [a])
selectSub (TPrimExp Int64 VName, TPrimExp Int64 VName)
-> TPrimExp Int64 VName
forall a b. (a, b) -> a
fst) [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
subs
      fsubs :: [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
fsubs = (Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
 -> Maybe (ExtIxFun, [TPrimExp Int64 VName]))
-> [Maybe
      (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
forall a b. (a -> b) -> [a] -> [b]
map (((TPrimExp Int64 VName, TPrimExp Int64 VName)
 -> TPrimExp Int64 VName)
-> Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
-> Maybe (ExtIxFun, [TPrimExp Int64 VName])
forall a.
((a, a) -> a)
-> Maybe (ExtIxFun, [(a, a)]) -> Maybe (ExtIxFun, [a])
selectSub (TPrimExp Int64 VName, TPrimExp Int64 VName)
-> TPrimExp Int64 VName
forall a b. (a, b) -> b
snd) [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
subs
  (BodyT torep
tbranch', [BranchTypeMem]
trets) <- [ExtType]
-> BodyT torep
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromrep torep (BodyT torep, [BranchTypeMem])
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
[ExtType]
-> Body torep
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromrep torep (Body torep, [BranchTypeMem])
addResCtxInIfBody [ExtType]
[BranchType fromrep]
rets BodyT torep
tbranch [Maybe Space]
spaces [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
tsubs
  (BodyT torep
fbranch', [BranchTypeMem]
frets) <- [ExtType]
-> BodyT torep
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromrep torep (BodyT torep, [BranchTypeMem])
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
[ExtType]
-> Body torep
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromrep torep (Body torep, [BranchTypeMem])
addResCtxInIfBody [ExtType]
[BranchType fromrep]
rets BodyT torep
fbranch [Maybe Space]
spaces [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
fsubs
  if [BranchTypeMem]
frets [BranchTypeMem] -> [BranchTypeMem] -> Bool
forall a. Eq a => a -> a -> Bool
/= [BranchTypeMem]
trets
    then String -> AllocM fromrep torep (ExpT torep)
forall a. HasCallStack => String -> a
error String
"In allocInExp, IF case: antiunification of then/else produce different ExtInFn!"
    else do
      -- above is a sanity check; implementation continues on else branch
      let res_then :: Result
res_then = BodyT torep -> Result
forall rep. BodyT rep -> Result
bodyResult BodyT torep
tbranch'
          res_else :: Result
res_else = BodyT torep -> Result
forall rep. BodyT rep -> Result
bodyResult BodyT torep
fbranch'
          size_ext :: Int
size_ext = Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
res_then Int -> Int -> Int
forall a. Num a => a -> a -> a
- [BranchTypeMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchTypeMem]
trets
          ([(SubExp, SubExp, Int)]
ind_ses0, [(SubExp, SubExp, Int)]
r_then_else) =
            ((SubExp, SubExp, Int) -> Bool)
-> [(SubExp, SubExp, Int)]
-> ([(SubExp, SubExp, Int)], [(SubExp, SubExp, Int)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (\(SubExp
r_then, SubExp
r_else, Int
_) -> SubExp
r_then SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
r_else) ([(SubExp, SubExp, Int)]
 -> ([(SubExp, SubExp, Int)], [(SubExp, SubExp, Int)]))
-> [(SubExp, SubExp, Int)]
-> ([(SubExp, SubExp, Int)], [(SubExp, SubExp, Int)])
forall a b. (a -> b) -> a -> b
$
              Result -> Result -> [Int] -> [(SubExp, SubExp, Int)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Result
res_then Result
res_else [Int
0 .. Int
size_ext Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1]
          (Result
r_then_ext, Result
r_else_ext, [Int]
_) = [(SubExp, SubExp, Int)] -> (Result, Result, [Int])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(SubExp, SubExp, Int)]
r_then_else
          ind_ses :: [(Int, SubExp)]
ind_ses =
            ((SubExp, SubExp, Int) -> Int -> (Int, SubExp))
-> [(SubExp, SubExp, Int)] -> [Int] -> [(Int, SubExp)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
              (\(SubExp
se, SubExp
_, Int
i) Int
k -> (Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
k, SubExp
se))
              [(SubExp, SubExp, Int)]
ind_ses0
              [Int
0 .. [(SubExp, SubExp, Int)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(SubExp, SubExp, Int)]
ind_ses0 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1]
          rets'' :: [BranchTypeMem]
rets'' = ([BranchTypeMem] -> (Int, SubExp) -> [BranchTypeMem])
-> [BranchTypeMem] -> [(Int, SubExp)] -> [BranchTypeMem]
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\[BranchTypeMem]
acc (Int
i, SubExp
se) -> Int -> SubExp -> [BranchTypeMem] -> [BranchTypeMem]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt Int
i SubExp
se [BranchTypeMem]
acc) [BranchTypeMem]
trets [(Int, SubExp)]
ind_ses
          tbranch'' :: BodyT torep
tbranch'' = BodyT torep
tbranch' {bodyResult :: Result
bodyResult = Result
r_then_ext Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
size_ext Result
res_then}
          fbranch'' :: BodyT torep
fbranch'' = BodyT torep
fbranch' {bodyResult :: Result
bodyResult = Result
r_else_ext Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
size_ext Result
res_else}
          res_if_expr :: ExpT torep
res_if_expr = SubExp
-> BodyT torep
-> BodyT torep
-> IfDec (BranchType torep)
-> ExpT torep
forall rep.
SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
If SubExp
cond BodyT torep
tbranch'' BodyT torep
fbranch'' (IfDec (BranchType torep) -> ExpT torep)
-> IfDec (BranchType torep) -> ExpT torep
forall a b. (a -> b) -> a -> b
$ [BranchTypeMem] -> IfSort -> IfDec BranchTypeMem
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchTypeMem]
rets'' IfSort
ifsort
      ExpT torep -> AllocM fromrep torep (ExpT torep)
forall (m :: * -> *) a. Monad m => a -> m a
return ExpT torep
res_if_expr
  where
    generalize ::
      (Maybe Space, Maybe IxFun) ->
      (Maybe Space, Maybe IxFun) ->
      (Maybe Space, Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))
    generalize :: (Maybe Space, Maybe IxFun)
-> (Maybe Space, Maybe IxFun)
-> (Maybe Space,
    Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))
generalize (Just Space
sp1, Just IxFun
ixf1) (Just Space
sp2, Just IxFun
ixf2) =
      if Space
sp1 Space -> Space -> Bool
forall a. Eq a => a -> a -> Bool
/= Space
sp2
        then (Space -> Maybe Space
forall a. a -> Maybe a
Just Space
sp1, Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
forall a. Maybe a
Nothing)
        else case IxFun (PrimExp VName)
-> IxFun (PrimExp VName)
-> Maybe
     (IxFun (PrimExp (Ext VName)), [(PrimExp VName, PrimExp VName)])
forall v.
Eq v =>
IxFun (PrimExp v)
-> IxFun (PrimExp v)
-> Maybe (IxFun (PrimExp (Ext v)), [(PrimExp v, PrimExp v)])
IxFun.leastGeneralGeneralization ((TPrimExp Int64 VName -> PrimExp VName)
-> IxFun -> IxFun (PrimExp VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped IxFun
ixf1) ((TPrimExp Int64 VName -> PrimExp VName)
-> IxFun -> IxFun (PrimExp VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped IxFun
ixf2) of
          Just (IxFun (PrimExp (Ext VName))
ixf, [(PrimExp VName, PrimExp VName)]
m) ->
            ( Space -> Maybe Space
forall a. a -> Maybe a
Just Space
sp1,
              (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
-> Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
forall a. a -> Maybe a
Just
                ( (PrimExp (Ext VName) -> TPrimExp Int64 (Ext VName))
-> IxFun (PrimExp (Ext VName)) -> ExtIxFun
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap PrimExp (Ext VName) -> TPrimExp Int64 (Ext VName)
forall t v. PrimExp v -> TPrimExp t v
TPrimExp IxFun (PrimExp (Ext VName))
ixf,
                  [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> [(TPrimExp Int64 VName, TPrimExp Int64 VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((PrimExp VName, PrimExp VName) -> TPrimExp Int64 VName)
-> [(PrimExp VName, PrimExp VName)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (PrimExp VName -> TPrimExp Int64 VName
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (PrimExp VName -> TPrimExp Int64 VName)
-> ((PrimExp VName, PrimExp VName) -> PrimExp VName)
-> (PrimExp VName, PrimExp VName)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PrimExp VName, PrimExp VName) -> PrimExp VName
forall a b. (a, b) -> a
fst) [(PrimExp VName, PrimExp VName)]
m) (((PrimExp VName, PrimExp VName) -> TPrimExp Int64 VName)
-> [(PrimExp VName, PrimExp VName)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (PrimExp VName -> TPrimExp Int64 VName
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (PrimExp VName -> TPrimExp Int64 VName)
-> ((PrimExp VName, PrimExp VName) -> PrimExp VName)
-> (PrimExp VName, PrimExp VName)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PrimExp VName, PrimExp VName) -> PrimExp VName
forall a b. (a, b) -> b
snd) [(PrimExp VName, PrimExp VName)]
m)
                )
            )
          Maybe
  (IxFun (PrimExp (Ext VName)), [(PrimExp VName, PrimExp VName)])
Nothing -> (Space -> Maybe Space
forall a. a -> Maybe a
Just Space
sp1, Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
forall a. Maybe a
Nothing)
    generalize (Maybe Space
mbsp1, Maybe IxFun
_) (Maybe Space, Maybe IxFun)
_ = (Maybe Space
mbsp1, Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
forall a. Maybe a
Nothing)

    selectSub ::
      ((a, a) -> a) ->
      Maybe (ExtIxFun, [(a, a)]) ->
      Maybe (ExtIxFun, [a])
    selectSub :: forall a.
((a, a) -> a)
-> Maybe (ExtIxFun, [(a, a)]) -> Maybe (ExtIxFun, [a])
selectSub (a, a) -> a
f (Just (ExtIxFun
ixfn, [(a, a)]
m)) = (ExtIxFun, [a]) -> Maybe (ExtIxFun, [a])
forall a. a -> Maybe a
Just (ExtIxFun
ixfn, ((a, a) -> a) -> [(a, a)] -> [a]
forall a b. (a -> b) -> [a] -> [b]
map (a, a) -> a
f [(a, a)]
m)
    selectSub (a, a) -> a
_ Maybe (ExtIxFun, [(a, a)])
Nothing = Maybe (ExtIxFun, [a])
forall a. Maybe a
Nothing
    allocInIfBody ::
      (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
      Int ->
      Body fromrep ->
      AllocM fromrep torep (Body torep, [Maybe IxFun])
    allocInIfBody :: forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
Int
-> Body fromrep -> AllocM fromrep torep (Body torep, [Maybe IxFun])
allocInIfBody Int
num_vals (Body BodyDec fromrep
_ Stms fromrep
bnds Result
res) =
      AllocM fromrep torep (Result, [Maybe IxFun])
-> AllocM fromrep torep (BodyT torep, [Maybe IxFun])
forall (m :: * -> *) a.
MonadBinder m =>
m (Result, a) -> m (Body (Rep m), a)
buildBody (AllocM fromrep torep (Result, [Maybe IxFun])
 -> AllocM fromrep torep (BodyT torep, [Maybe IxFun]))
-> (AllocM fromrep torep (Result, [Maybe IxFun])
    -> AllocM fromrep torep (Result, [Maybe IxFun]))
-> AllocM fromrep torep (Result, [Maybe IxFun])
-> AllocM fromrep torep (BodyT torep, [Maybe IxFun])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms fromrep
-> AllocM fromrep torep (Result, [Maybe IxFun])
-> AllocM fromrep torep (Result, [Maybe IxFun])
forall fromrep torep a.
Allocable fromrep torep =>
Stms fromrep -> AllocM fromrep torep a -> AllocM fromrep torep a
allocInStms Stms fromrep
bnds (AllocM fromrep torep (Result, [Maybe IxFun])
 -> AllocM fromrep torep (BodyT torep, [Maybe IxFun]))
-> AllocM fromrep torep (Result, [Maybe IxFun])
-> AllocM fromrep torep (BodyT torep, [Maybe IxFun])
forall a b. (a -> b) -> a -> b
$ do
        let (Result
_, Result
val_res) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_vals Result
res
        [Maybe IxFun]
mem_ixfs <- (SubExp -> AllocM fromrep torep (Maybe IxFun))
-> Result -> AllocM fromrep torep [Maybe IxFun]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> AllocM fromrep torep (Maybe IxFun)
forall fromrep torep.
(Allocable fromrep torep,
 Allocator torep (AllocM fromrep torep)) =>
SubExp -> AllocM fromrep torep (Maybe IxFun)
subExpIxFun Result
val_res
        (Result, [Maybe IxFun])
-> AllocM fromrep torep (Result, [Maybe IxFun])
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result
res, [Maybe IxFun]
mem_ixfs)
-- WithAcc: the accumulated arrays already live in memory, so no new
-- allocation is made for the inputs themselves.  We annotate the
-- parameters of the body lambda and of each optional operator lambda
-- with memory information, then recursively allocate inside those
-- lambda bodies via 'allocInLambda'.
allocInExp (WithAcc inputs bodylam) =
  WithAcc <$> mapM onInput inputs <*> onLambda bodylam
  where
    -- The body lambda may only take unit or accumulator parameters;
    -- both are given memory info without a backing memory block.  Any
    -- other parameter type is an internal error.
    onLambda lam = do
      params <- forM (lambdaParams lam) $ \param@(Param pv t) ->
        case t of
          Prim Unit -> pure $ Param pv $ MemPrim Unit
          Acc acc ispace ts u -> pure $ Param pv $ MemAcc acc ispace ts u
          _ -> error $ "Unexpected WithAcc lambda param: " ++ pretty param
      allocInLambda params (lambdaBody lam)

    -- Each input is (shape, arrays, optional operator).  Only the
    -- operator lambda needs allocation; the second component of the
    -- operator pair (presumably the neutral elements) is passed through.
    onInput (shape, arrs, op) =
      (shape,arrs,) <$> traverse (onOp shape arrs) op

    -- The operator lambda's parameters are split into: one index
    -- parameter per dimension of the accumulator shape, then 'num_vs'
    -- "x" operands, then the "y" operands.  The x/y operands are paired
    -- positionally with the accumulated arrays.
    onOp accshape arrs (lam, nes) = do
      let num_vs = length (lambdaReturnType lam)
          num_is = shapeRank accshape
          (i_params, x_params, y_params) =
            splitAt3 num_is num_vs $ lambdaParams lam
          i_params' = map ((`Param` MemPrim int64) . paramName) i_params
          -- The index parameters fix the leading dimensions of each
          -- operand's index function.
          is = map (DimFix . Var . paramName) i_params'
      x_params' <- zipWithM (onXParam is) x_params arrs
      y_params' <- zipWithM (onYParam is) y_params arrs
      lam' <-
        allocInLambda
          (i_params' <> x_params' <> y_params')
          (lambdaBody lam)
      return (lam', nes)

    -- An array parameter stored in 'mem', whose index function is
    -- 'ixfun' sliced by the fixed indices 'is' followed by full slices
    -- over the parameter's own shape.
    mkP p pt shape u mem ixfun is =
      Param p . MemArray pt shape u . ArrayIn mem . IxFun.slice ixfun $
        fmap (fmap pe64) $ is ++ map sliceDim (shapeDims shape)

    -- x operands reuse the memory block and index function of the
    -- corresponding accumulated array.
    onXParam _ (Param p (Prim t)) _ =
      return $ Param p (MemPrim t)
    onXParam is (Param p (Array pt shape u)) arr = do
      (mem, ixfun) <- lookupArraySummary arr
      return $ mkP p pt shape u mem ixfun is
    onXParam _ p _ =
      error $ "Cannot handle WithAcc operator x-param: " ++ pretty p

    -- y operands get a fresh allocation in the default space, sized by
    -- the type of the corresponding accumulated array, with an
    -- 'IxFun.iota' index function over that array's dimensions.
    onYParam _ (Param p (Prim t)) _ =
      return $ Param p $ MemPrim t
    onYParam is (Param p (Array pt shape u)) arr = do
      arr_t <- lookupType arr
      mem <- allocForArray arr_t DefaultSpace
      let base_dims = map pe64 $ arrayDims arr_t
          ixfun = IxFun.iota base_dims
      pure $ mkP p pt shape u mem ixfun is
    onYParam _ p _ =
      error $ "Cannot handle WithAcc operator y-param: " ++ pretty p
-- Catch-all clause: any expression not matched by an earlier clause is
-- traversed generically with 'mapExpM'.  An 'Op' is delegated to the
-- representation-specific 'allocInOp' handler read from the 'AllocEnv'.
-- Reaching a body, return type, branch type or parameter through this
-- generic traversal is treated as an internal error.
allocInExp e = mapExpM allocMapper e
  where
    allocMapper =
      identityMapper
        { mapOnBody = error "Unhandled Body in ExplicitAllocations",
          mapOnRetType = error "Unhandled RetType in ExplicitAllocations",
          mapOnBranchType = error "Unhandled BranchType in ExplicitAllocations",
          mapOnFParam = error "Unhandled FParam in ExplicitAllocations",
          mapOnLParam = error "Unhandled LParam in ExplicitAllocations",
          mapOnOp = \op -> asks allocInOp >>= \handleOp -> handleOp op
        }

-- | The index function of the array a subexpression refers to, if any.
--
-- Constants, and variables whose memory summary is not an array bound
-- in a memory block ('ArrayIn'), have no index function and yield
-- 'Nothing'.
subExpIxFun ::
  (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
  SubExp ->
  AllocM fromrep torep (Maybe IxFun)
subExpIxFun (Var v) = arrayIxFun <$> lookupMemInfo v
  where
    arrayIxFun (MemArray _ _ _ (ArrayIn _ ixf)) = Just ixf
    arrayIxFun _ = Nothing
subExpIxFun Constant {} = pure Nothing

-- | Extend the results of an @if@ branch body with memory context.
--
-- The last @length ifrets@ results of the body are its value results.
-- Any results before them are pre-existing context and stay at the
-- front.  Each value result is handled according to its entry in
-- @substs@:
--
-- * 'Nothing': the result is passed through 'ensureDirect' (it is not
--   generalised), and its return type comes from 'inspect'.  For arrays
--   that is a 'ReturnsNewBlock' with an @iota@ index function.
--
-- * @'Just' (ixfn, m)@: every expression in @m@ becomes a new
--   existential result (via 'toSubExp').  The existentials of @ixfn@
--   are shifted past the existentials allocated so far.
--
-- The resulting body returns
-- @ctx_res <> ext_ses_res <> mem_ctx_res <> val_res'@, where
-- @mem_ctx_res@ collects 'bodyReturnMemCtx' for each value result.
-- The second component gives the branch return types of the value
-- results.
addResCtxInIfBody ::
  (Allocable fromrep torep, Allocator torep (AllocM fromrep torep)) =>
  [ExtType] ->
  Body torep ->
  [Maybe Space] ->
  [Maybe (ExtIxFun, [TPrimExp Int64 VName])] ->
  AllocM fromrep torep (Body torep, [BodyReturns])
addResCtxInIfBody ifrets (Body _ bnds res) spaces substs = do
  let num_vals = length ifrets
      (ctx_res, val_res) = splitFromEnd num_vals res
  ((res', bodyrets'), all_body_stms) <- collectStms $ do
    mapM_ addStm bnds
    (val_res', ext_ses_res, mem_ctx_res, bodyrets, total_existentials) <-
      foldM helper ([], [], [], [], length ctx_res) (zip4 ifrets val_res substs spaces)
    return
      ( ctx_res <> ext_ses_res <> mem_ctx_res <> val_res',
        -- We need to adjust the ReturnsNewBlock existentials, because they
        -- should always be numbered _after_ all other existentials in the
        -- return values.  A strict left fold is used so that the
        -- (list, counter) accumulator does not build a chain of thunks.
        reverse . fst $
          foldl' adjustNewBlockExistential ([], total_existentials) bodyrets
      )
  body' <- mkBodyM all_body_stms res'
  return (body', bodyrets')
  where
    -- Process one value result.  The accumulator holds, in order: the
    -- value results so far, the new existential size results, the
    -- memory context results, the branch return types, and the index
    -- of the next free existential.
    helper (res_acc, ext_acc, ctx_acc, br_acc, k) (ifr, r, mbixfsub, sp) =
      case mbixfsub of
        Nothing -> do
          -- does NOT generalize/antiunify; ensure direct
          r' <- ensureDirect sp r
          mem_ctx_r <- bodyReturnMemCtx r'
          let body_ret = inspect ifr sp
          return
            ( res_acc ++ [r'],
              ext_acc,
              ctx_acc ++ mem_ctx_r,
              br_acc ++ [body_ret],
              k
            )
        Just (ixfn, m) -> do
          -- generalizes
          let i = length m
          ext_ses <- mapM (toSubExp "ixfn_exist") m
          mem_ctx_r <- bodyReturnMemCtx r
          let sp' = fromMaybe DefaultSpace sp
              -- Existentials of the index function are renumbered to
              -- follow the k existentials already allocated.
              ixfn' = fmap (adjustExtPE k) ixfn
              exttp = case ifr of
                Array pt shp' u ->
                  MemArray pt shp' u $
                    ReturnsNewBlock sp' 0 ixfn'
                _ -> error "Impossible case reached in addResCtxInIfBody"
          return
            ( res_acc ++ [r],
              ext_acc ++ ext_ses,
              ctx_acc ++ mem_ctx_r,
              br_acc ++ [exttp],
              k + i
            )

    -- Give each ReturnsNewBlock the next existential index, starting
    -- from the counter.  Other return types pass through unchanged.
    -- The list is built in reverse; the caller reverses it back.
    adjustNewBlockExistential :: ([BodyReturns], Int) -> BodyReturns -> ([BodyReturns], Int)
    adjustNewBlockExistential (acc, k) (MemArray pt shp u (ReturnsNewBlock space _ ixfun)) =
      (MemArray pt shp u (ReturnsNewBlock space k ixfun) : acc, k + 1)
    adjustNewBlockExistential (acc, k) x = (x : acc, k)

    -- Return type for a result that is not generalised.  Arrays get a
    -- fresh block in the requested space (DefaultSpace if none), with
    -- an iota index function over the array's shape.
    inspect (Array pt shape u) space =
      let space' = fromMaybe DefaultSpace space
          bodyret =
            MemArray pt shape u $
              ReturnsNewBlock space' 0 $
                IxFun.iota $ map convert $ shapeDims shape
       in bodyret
    inspect (Acc acc ispace ts u) _ = MemAcc acc ispace ts u
    inspect (Prim pt) _ = MemPrim pt
    inspect (Mem space) _ = MemMem space

    convert (Ext i) = le64 (Ext i)
    convert (Free v) = Free <$> pe64 v

    adjustExtV :: Int -> Ext VName -> Ext VName
    adjustExtV _ (Free v) = Free v
    adjustExtV k (Ext i) = Ext (k + i)

    adjustExtPE :: Int -> TPrimExp t (Ext VName) -> TPrimExp t (Ext VName)
    adjustExtPE k = fmap (adjustExtV k)

-- | Find the memory space of each of the last @num_vals@ results of a
-- body, looked up in the scope of the body's statements.
--
-- A result yields @'Just' space@ only when it is an array stored in a
-- memory block whose own summary is @'MemMem' space@.  Constants,
-- non-array results, and blocks without such a summary give 'Nothing'.
mkSpaceOks ::
  (Mem torep, LocalScope torep m) =>
  Int ->
  Body torep ->
  m [Maybe Space]
mkSpaceOks num_vals (Body _ stms res) =
  inScopeOf stms . traverse resultSpace $ takeLast num_vals res
  where
    resultSpace (Constant _) = pure Nothing
    resultSpace (Var v) = do
      v_info <- lookupMemInfo v
      case v_info of
        MemArray _ _ _ (ArrayIn mem _) -> blockSpace <$> lookupMemInfo mem
        _ -> pure Nothing

    -- The space recorded in a memory block's own summary, if any.
    blockSpace (MemMem space) = Just space
    blockSpace _ = Nothing

-- | Give memory information to the loop variables of a loop form.
--
-- A @while@ loop has no loop variables and is returned as-is.  For a
-- @for@ loop, each parameter bound to a row of array @arr@ gets a dec
-- derived from its type:
--
-- * arrays live in @arr@'s memory block, with @arr@'s index function
--   sliced at the loop index;
-- * primitives, memory blocks and accumulators get the corresponding
--   'MemPrim', 'MemMem' and 'MemAcc' summaries.
allocInLoopForm ::
  ( Allocable fromrep torep,
    Allocator torep (AllocM fromrep torep)
  ) =>
  LoopForm fromrep ->
  AllocM fromrep torep (LoopForm torep)
allocInLoopForm (WhileLoop cond) = pure $ WhileLoop cond
allocInLoopForm (ForLoop i it n loopvars) =
  ForLoop i it n <$> traverse allocLoopVar loopvars
  where
    allocLoopVar (p, arr) = do
      -- The array summary is looked up for every loop variable,
      -- whatever the parameter's type.
      (mem, ixfun) <- lookupArraySummary arr
      dec <- case paramType p of
        Array pt shape u -> do
          arr_t <- lookupType arr
          let dims = map pe64 $ arrayDims arr_t
              row_ixfun = IxFun.slice ixfun $ fullSliceNum dims [DimFix $ le64 i]
          pure $ MemArray pt shape u $ ArrayIn mem row_ixfun
        Prim bt -> pure $ MemPrim bt
        Mem space -> pure $ MemMem space
        Acc acc ispace ts u -> pure $ MemAcc acc ispace ts u
      pure (p {paramDec = dec}, arr)

-- | Operations that can report size information to the allocator.
class SizeSubst op where
  -- | The 'ChunkMap' contributed when this operation is bound to the
  -- given pattern.
  opSizeSubst :: PatternT dec -> op -> ChunkMap

  -- | Whether the operation counts as a constant.  By default, no
  -- operation does.
  opIsConst :: op -> Bool
  opIsConst _ = False

-- | The trivial operation contributes no size substitutions.
instance SizeSubst () where
  opSizeSubst _ () = M.empty

-- | Memory operations delegate to the wrapped inner operation.  Any
-- other 'MemOp' contributes no substitutions and is not constant.
instance SizeSubst op => SizeSubst (MemOp op) where
  opSizeSubst pat memop = case memop of
    Inner op -> opSizeSubst pat op
    _ -> M.empty

  opIsConst memop = case memop of
    Inner op -> opIsConst op
    _ -> False

-- | The 'ChunkMap' contributed by a statement.  Only 'Op' statements
-- contribute, via 'opSizeSubst' on their pattern.
sizeSubst :: SizeSubst (Op rep) => Stm rep -> ChunkMap
sizeSubst (Let pat _ e) = case e of
  Op op -> opSizeSubst pat op
  _ -> M.empty

-- | The names bound by a statement whose operation 'opIsConst' reports
-- as constant.  Every other statement binds no such names.
stmConsts :: SizeSubst (Op rep) => Stm rep -> S.Set VName
stmConsts (Let pat _ e) = case e of
  Op op | opIsConst op -> S.fromList $ patternNames pat
  _ -> S.empty

-- | Build a statement that binds @names@ to @e@.  The pattern, including
-- any memory allocations it needs, is produced by
-- 'bindPatternWithAllocations' in the current scope.
mkLetNamesB' ::
  ( Op (Rep m) ~ MemOp inner,
    MonadBinder m,
    ExpDec (Rep m) ~ (),
    Allocator (Rep m) (PatAllocM (Rep m))
  ) =>
  ExpDec (Rep m) ->
  [VName] ->
  Exp (Rep m) ->
  m (Stm (Rep m))
mkLetNamesB' dec names e = do
  scope <- askScope
  toStm <$> bindPatternWithAllocations scope names e
  where
    toStm pat = Let pat (defAux dec) e

-- | Build a statement with wisdom that binds @names@ to @e@.
--
-- Steps:
--
-- 1. Compute the pattern, and its allocation statements, on the
--    wisdom-free scope and expression.
-- 2. Emit those allocation statements with 'bindAllocStm'.
-- 3. Add wisdom back to the pattern and build the matching
--    expression decoration.
mkLetNamesB'' ::
  ( Op (Rep m) ~ MemOp inner,
    ExpDec rep ~ (),
    HasScope (Engine.Wise rep) m,
    Allocator rep (PatAllocM rep),
    MonadBinder m,
    Engine.CanBeWise (Op rep)
  ) =>
  [VName] ->
  Exp (Engine.Wise rep) ->
  m (Stm (Engine.Wise rep))
mkLetNamesB'' names e = do
  plain_scope <- Engine.removeScopeWisdom <$> askScope
  let plain_e = Engine.removeExpWisdom e
  (plain_pat, alloc_stms) <-
    runPatAllocM (patternWithAllocations names plain_e) plain_scope
  mapM_ bindAllocStm alloc_stms
  let wise_pat = Engine.addWisdomToPattern plain_pat e
  pure $ Let wise_pat (defAux $ Engine.mkWiseExpDec wise_pat () e) e

-- | Assemble the 'SimpleOps' used to run the simplifier on a memory-annotated
-- representation whose operations are 'MemOp's wrapping an @inner@ op.
--
-- The first argument computes the usage table of an (already simplified)
-- inner op.  The second simplifies an inner op, returning the new op along
-- with any statements it hoisted out.
simplifiable ::
  ( Engine.SimplifiableRep rep,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    Op rep ~ MemOp inner,
    Allocator rep (PatAllocM rep)
  ) =>
  (Engine.OpWithWisdom inner -> UT.UsageTable) ->
  (inner -> Engine.SimpleM rep (Engine.OpWithWisdom inner, Stms (Engine.Wise rep))) ->
  SimpleOps rep
simplifiable innerUsage simplifyInnerOp =
  SimpleOps mkExpDecS' mkBodyS' protectOp opUsage simplifyOp
  where
    -- The underlying representation has unit expression and body
    -- decorations, so only the wisdom needs to be computed.
    mkExpDecS' _ pat e = pure (Engine.mkWiseExpDec pat () e)

    mkBodyS' _ stms res = pure (mkWiseBody () stms res)

    -- An allocation can be protected by making its size conditional:
    -- the requested size when @taken@ holds, and zero bytes otherwise.
    -- Any other operation cannot be protected this way.
    protectOp taken pat (Alloc size space) = Just $ do
      guardedSize <-
        If taken
          <$> resultBodyM [size]
          <*> resultBodyM [intConst Int64 0]
          <*> pure (IfDec [MemPrim int64] IfFallback)
      size' <- letSubExp "hoisted_alloc_size" guardedSize
      letBind pat (Op (Alloc size' space))
    protectOp _ _ _ = Nothing

    -- Inner ops defer to the caller-supplied usage function.  An
    -- allocation only contributes a size usage when its size is a
    -- variable; constant sizes are not recorded.
    opUsage (Inner inner) = innerUsage inner
    opUsage (Alloc size _) =
      case size of
        Var v -> UT.sizeUsage v
        _ -> mempty

    -- Allocations only need their size simplified and never hoist
    -- anything.  Inner ops are delegated to the caller-supplied function.
    simplifyOp (Alloc size space) = do
      size' <- Engine.simplify size
      pure (Alloc size' space, mempty)
    simplifyOp (Inner k) = do
      (k', hoisted) <- simplifyInnerOp k
      pure (Inner k', hoisted)

-- | Build a memory-annotated pattern for binding @names@ to the result of
-- @e@.  The pattern is computed by 'patternWithAllocations' within
-- 'PatAllocM', using @types@ as the scope.  The allocation statements that
-- computation requires are emitted in the enclosing binder before the
-- pattern is returned.
bindPatternWithAllocations ::
  ( MonadBinder m,
    ExpDec rep ~ (),
    Op (Rep m) ~ MemOp inner,
    Allocator rep (PatAllocM rep)
  ) =>
  Scope rep ->
  [VName] ->
  Exp rep ->
  m (Pattern rep)
bindPatternWithAllocations types names e = do
  (pat, prestms) <- runPatAllocM (patternWithAllocations names e) types
  -- Emit the supporting allocation statements, then hand back the pattern.
  pat <$ mapM_ bindAllocStm prestms

-- | A per-result hint attached to an expression.  'defaultExpHints'
-- produces one 'NoHint' per result.  Presumably the hints guide how
-- memory is allocated for each result; confirm against the consumers.
data ExpHint
  = -- | No particular index function or space is requested.
    NoHint
  | -- | Suggests the given index function and memory space for the
    -- corresponding result (NOTE(review): exact treatment depends on the
    -- allocator consuming the hint; confirm).
    Hint IxFun Space

-- | The fallback hint provider: one 'NoHint' for each result of the
-- expression, as counted by 'expExtTypeSize'.
defaultExpHints :: (Monad m, ASTRep rep) => Exp rep -> m [ExpHint]
defaultExpHints e = pure hints
  where
    hints = replicate (expExtTypeSize e) NoHint