{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | A generic transformation for adding memory allocations to a
-- Futhark program.  Specialised by specific representations in
-- submodules.
module Futhark.Pass.ExplicitAllocations
  ( explicitAllocationsGeneric,
    explicitAllocationsInStmsGeneric,
    ExpHint (..),
    defaultExpHints,
    Allocable,
    AllocM,
    AllocEnv (..),
    SizeSubst (..),
    allocInStms,
    allocForArray,
    simplifiable,
    arraySizeInBytesExp,
    mkLetNamesB',
    mkLetNamesB'',
    dimAllocationSize,
    ChunkMap,

    -- * Module re-exports

    --
    -- These are highly likely to be needed by any downstream
    -- users.
    module Control.Monad.Reader,
    module Futhark.MonadFreshNames,
    module Futhark.Pass,
    module Futhark.Tools,
  )
where

import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (foldl', partition, zip4)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rep (mkWiseBody)
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (maybeNth, splitAt3, splitFromEnd, takeLast)

-- | The subexpression giving the number of elements we should
-- allocate space for.  See 'ChunkMap' comment.
--
-- A variable that has an entry in the chunk map is replaced by the
-- mapped size, which is itself resolved again; any other subexpression
-- is returned unchanged.
--
-- NOTE(review): the lookup recurses without a visited set, so a cycle
-- in the chunk map would not terminate.
dimAllocationSize :: ChunkMap -> SubExp -> SubExp
dimAllocationSize chunkmap (Var v) =
  -- It is important to recurse here, as the substitution may itself
  -- be a chunk size.
  maybe (Var v) (dimAllocationSize chunkmap) $ M.lookup v chunkmap
dimAllocationSize _ size =
  size

-- | A mapping from chunk names to their maximum size.  XXX FIXME
-- HACK: This is part of a hack to add loop-invariant allocations to
-- reduce kernels, because memory expansion does not use range
-- analysis yet (it should).
type ChunkMap = M.Map VName SubExp

-- | The constraints required to insert allocations while translating a
-- program from representation @fromrep@ into the memory-annotated
-- representation @torep@, whose operations are built on @inner@.
type Allocable fromrep torep inner =
  ( -- Requirements on the source representation.
    PrettyRep fromrep,
    FParamInfo fromrep ~ DeclType,
    LParamInfo fromrep ~ Type,
    BranchType fromrep ~ ExtType,
    RetType fromrep ~ DeclExtType,
    BodyDec fromrep ~ (),
    -- Requirements on the target representation.
    PrettyRep torep,
    Mem torep inner,
    LetDec torep ~ LetDecMem,
    BodyDec torep ~ (),
    ExpDec torep ~ (),
    BuilderOps torep,
    -- Requirements on the inner operation type.
    SizeSubst inner
  )

-- | The read-only environment of 'AllocM'.
data AllocEnv fromrep torep = AllocEnv
  { -- | Maximum sizes for chunk names.  Array sizes are computed
    -- through 'dimAllocationSize' using this map.
    chunkMap :: ChunkMap,
    -- | Aggressively try to reuse memory in do-loops -
    -- should be True inside kernels, False outside.
    aggressiveReuse :: Bool,
    -- | When allocating memory, put it in this memory space.
    -- This is primarily used to ensure that group-wide
    -- statements store their results in local memory.
    allocSpace :: Space,
    -- | The set of names that are known to be constants at
    -- kernel compile time.
    envConsts :: S.Set VName,
    -- | How to translate an operation of the source representation
    -- into one of the target representation.
    allocInOp :: Op fromrep -> AllocM fromrep torep (Op torep),
    -- | Hints for the results of an expression, used when allocating
    -- memory for the pattern that binds it.
    envExpHints :: Exp torep -> AllocM fromrep torep [ExpHint]
  }

-- | Monad for adding allocations to an entire program.
--
-- A 'BuilderT' for the target representation, stacked on a reader of
-- the 'AllocEnv' and a state holding the 'VNameSource'.  The instances
-- in the deriving clause are lifted from this wrapped stack.
newtype AllocM fromrep torep a
  = AllocM (BuilderT torep (ReaderT (AllocEnv fromrep torep) (State VNameSource)) a)
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadFreshNames,
      HasScope torep,
      LocalScope torep,
      MonadReader (AllocEnv fromrep torep)
    )

instance (Allocable fromrep torep inner) => MonadBuilder (AllocM fromrep torep) where
  type Rep (AllocM fromrep torep) = torep

  -- Expression decorations carry no information here: 'Allocable'
  -- fixes the target representation's 'ExpDec' to ().
  mkExpDecM _ _ = pure ()

  -- Every pattern created through the builder interface gets memory
  -- annotations, derived from the default space, the chunk map, and
  -- the hints the environment supplies for the expression.
  mkLetNamesM names e = do
    def_space <- askDefaultSpace
    chunkmap <- asks chunkMap
    hints <- expHints e
    pat <- patWithAllocations def_space chunkmap names e hints
    return $ Let pat (defAux ()) e

  mkBodyM stms res = return $ Body () stms res

  -- Statement handling is delegated to the wrapped 'BuilderT'.
  addStms = AllocM . addStms
  collectStms (AllocM m) = AllocM $ collectStms m

-- | Compute allocation hints for an expression, using the hint
-- function stored in the environment ('envExpHints').
expHints :: Exp torep -> AllocM fromrep torep [ExpHint]
expHints e = do
  f <- asks envExpHints
  f e

-- | The memory space new allocations should go in ('allocSpace').
askDefaultSpace :: AllocM fromrep torep Space
askDefaultSpace = asks allocSpace

-- | Run an 'AllocM' computation in any monad that has a name source.
--
-- The first argument translates operations of the source
-- representation; the second supplies allocation hints for
-- expressions.  The computation starts with an empty scope and an
-- environment with an empty chunk map, aggressive reuse disabled,
-- 'DefaultSpace' as the allocation space, and no known constants.
-- Statements added at the outermost level are discarded; only the
-- result is returned.
runAllocM ::
  MonadFreshNames m =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  AllocM fromrep torep a ->
  m a
runAllocM handleOp hints (AllocM m) =
  fmap fst $ modifyNameSource $ runState $ runReaderT (runBuilderT m mempty) env
  where
    env =
      AllocEnv
        { chunkMap = mempty,
          aggressiveReuse = False,
          allocSpace = DefaultSpace,
          envConsts = mempty,
          allocInOp = handleOp,
          envExpHints = hints
        }

-- | The size in bytes of one element of the given type, via
-- 'primByteSize' of its element type.
elemSize :: Num a => Type -> a
elemSize = primByteSize . elemType

-- | The size in bytes of an array of the given type, as a 64-bit
-- integer expression: the element size multiplied by every dimension.
--
-- Unlike the chunk-map-aware 'arraySizeInBytesExpM', this neither
-- resolves dimensions through a 'ChunkMap' nor clamps the result at
-- zero.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  untyped $ foldl' (*) (elemSize t) $ map pe64 (arrayDims t)

-- | The size in bytes of an array of the given type, with each
-- dimension first resolved through the chunk map
-- ('dimAllocationSize').  The result is wrapped in a signed 64-bit max
-- with zero, so the computed size is never negative.
arraySizeInBytesExpM :: MonadBuilder m => ChunkMap -> Type -> m (PrimExp VName)
arraySizeInBytesExpM chunkmap t = do
  let dim_prod_i64 = product $ map (pe64 . dimAllocationSize chunkmap) (arrayDims t)
      elm_size_i64 = elemSize t
  return $
    BinOpExp (SMax Int64) (ValueExp $ IntValue $ Int64Value 0) $
      untyped $ dim_prod_i64 * elm_size_i64

-- | Emit statements computing the byte size of an array of the given
-- type (see 'arraySizeInBytesExpM'), binding it to a name based on
-- @"bytes"@.
arraySizeInBytes :: MonadBuilder m => ChunkMap -> Type -> m SubExp
arraySizeInBytes chunkmap =
  letSubExp "bytes" <=< toExp <=< arraySizeInBytesExpM chunkmap

-- | Emit an 'Alloc' of a memory block big enough for an array of the
-- given type in the given space.  Returns the name of the new block,
-- which is based on @"mem"@.
allocForArray' ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner) =>
  ChunkMap ->
  Type ->
  Space ->
  m VName
allocForArray' chunkmap t space = do
  size <- arraySizeInBytes chunkmap t
  letExp "mem" $ Op $ Alloc size space

-- | Allocate memory for a value of the given type.
--
-- This is 'allocForArray'' with the chunk map taken from the
-- environment.
allocForArray ::
  Allocable fromrep torep inner =>
  Type ->
  Space ->
  AllocM fromrep torep VName
allocForArray t space = do
  chunkmap <- asks chunkMap
  allocForArray' chunkmap t space

-- | Build a statement binding the given identifiers to an expression,
-- with memory annotations for each pattern element.
--
-- Unlike 'mkLetNamesM', the identifiers' types are used as given rather
-- than recomputed from the expression.  The expression decoration is
-- obtained through 'mkExpDecM'.
allocsForStm ::
  (Allocable fromrep torep inner) =>
  [Ident] ->
  Exp torep ->
  AllocM fromrep torep (Stm torep)
allocsForStm idents e = do
  def_space <- askDefaultSpace
  chunkmap <- asks chunkMap
  hints <- expHints e
  rts <- expReturns e
  pes <- allocsForPat def_space chunkmap idents rts hints
  dec <- mkExpDecM (Pat pes) e
  pure $ Let (Pat pes) (defAux dec) e

-- | Build a memory-annotated pattern binding the given names to an
-- expression.
--
-- The names' types come from the expression's extended type, with its
-- shapes instantiated by @instantiateShapes'@ against the names.
-- Memory information is then assigned by 'allocsForPat', using the
-- expression's 'expReturns' and the given hints.
patWithAllocations ::
  (MonadBuilder m, Mem (Rep m) inner) =>
  Space ->
  ChunkMap ->
  [VName] ->
  Exp (Rep m) ->
  [ExpHint] ->
  m (PatT LetDecMem)
patWithAllocations def_space chunkmap names e hints = do
  ts' <- instantiateShapes' names <$> expExtType e
  let idents = zipWith Ident names ts'
  rts <- expReturns e
  Pat <$> allocsForPat def_space chunkmap idents rts hints

-- | Extend the given identifiers so there is exactly one per
-- 'ExpReturns'.
--
-- The identifiers are aligned with the /end/ of the returns list, so any
-- missing identifiers are created for the leading returns.  If there are
-- more identifiers than returns, only the last ones are kept.  A missing
-- identifier for a memory return gets a fresh name based on
-- @"ext_mem"@ with that memory type.  Any other missing identifier gets
-- a fresh name based on @"ext"@ of type @i64@.
mkMissingIdents :: MonadFreshNames m => [Ident] -> [ExpReturns] -> m [Ident]
mkMissingIdents idents rts =
  reverse <$> zipWithM f (reverse rts) (map Just (reverse idents) ++ repeat Nothing)
  where
    f _ (Just ident) = pure ident
    f (MemMem space) Nothing = newIdent "ext_mem" $ Mem space
    f _ Nothing = newIdent "ext" $ Prim int64

-- | Build the memory-annotated pattern elements for a binding.
--
-- The supplied identifiers are first padded with fresh ones via
-- 'mkMissingIdents', so that every return has an identifier.
-- Existential references (@Ext i@) in index functions and in
-- 'ReturnsNewBlock' are resolved by position in that padded list.
-- Scalars, and arrays whose shape is fully known and that carry no
-- memory return, get their annotation from 'summaryForBindage' using
-- the matching hint.  The other supported returns are translated
-- directly.  Any other return is reported with 'error'.
allocsForPat ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner) =>
  Space ->
  ChunkMap ->
  [Ident] ->
  [ExpReturns] ->
  [ExpHint] ->
  m [PatElemT LetDecMem]
allocsForPat def_space chunkmap some_idents rts hints = do
  all_idents <- mkMissingIdents some_idents rts
  let patElemFor (ident, rt, hint) =
        PatElem (identName ident) <$> decFor all_idents ident rt hint
  mapM patElemFor $ zip3 all_idents rts hints
  where
    -- The memory annotation for a single identifier.
    decFor all_idents ident rt hint =
      case rt of
        MemPrim _ ->
          fromBindage
        MemMem space ->
          pure $ MemMem space
        MemArray bt _ u (Just (ReturnsInBlock mem extixfun)) ->
          pure . MemArray bt ident_shape u $
            ArrayIn mem (instantiateExtIxFun all_idents extixfun)
        MemArray _ extshape _ Nothing
          | isJust (knownShape extshape) ->
              fromBindage
        MemArray bt _ u (Just (ReturnsNewBlock _ i extixfn)) ->
          pure . MemArray bt ident_shape u $
            ArrayIn
              (getIdent all_idents i)
              (instantiateExtIxFun all_idents extixfn)
        MemAcc acc ispace ts u ->
          pure $ MemAcc acc ispace ts u
        _ -> error "Impossible case reached in allocsForPat!"
      where
        ident_shape = arrayShape $ identType ident
        fromBindage =
          summaryForBindage def_space chunkmap (identType ident) hint

    -- All dimensions, provided none of them is existential.
    knownShape = traverse known . shapeDims

    known (Free v) = Just v
    known Ext {} = Nothing

    -- Name of the identifier at position i of the (padded) pattern.
    getIdent names i =
      maybe
        ( error $
            "getIdent: Ext "
              <> show i
              <> " but pattern has "
              <> show (length names)
              <> " elements: "
              <> pretty names
        )
        identName
        (maybeNth i names)

    -- Replace every existential in an index function by the
    -- corresponding pattern name.
    instantiateExtIxFun names = fmap (fmap resolve)
      where
        resolve (Free v) = v
        resolve (Ext i) = getIdent names i

-- | Turn an existential index function into a concrete one, provided
-- it mentions only free variables.  An index function that still
-- contains an existential is not supported yet and triggers 'error'.
instantiateIxFun :: Monad m => ExtIxFun -> m IxFun
instantiateIxFun = traverse (traverse inst)
  where
    inst (Free x) = return x
    inst Ext {} = error "instantiateIxFun: not yet"

-- | Compute the memory annotation for a binding of the given type.
--
-- Scalars, memory blocks and accumulators are annotated directly.
-- Arrays get freshly allocated memory:
--
-- * with 'NoHint', the block comes from 'allocForArray'' in the
--   default space, and the index function is built by 'directIxFun';
--
-- * with a 'Hint', a block of (product of the hinted index function's
--   base shape) * (element size) bytes is allocated in the hinted
--   space, and the hinted index function is used.
summaryForBindage ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner) =>
  Space ->
  ChunkMap ->
  Type ->
  ExpHint ->
  m (MemBound NoUniqueness)
summaryForBindage def_space chunkmap t hint =
  case t of
    Prim bt -> pure $ MemPrim bt
    Mem space -> pure $ MemMem space
    Acc acc ispace ts u -> pure $ MemAcc acc ispace ts u
    Array pt shape u -> arraySummary pt shape u hint
  where
    arraySummary pt shape u NoHint = do
      mem <- allocForArray' chunkmap t def_space
      pure $ directIxFun pt shape u mem t
    arraySummary pt _ _ (Hint ixfun space) = do
      let elem_size = fromIntegral (primByteSize pt :: Int64)
          size_exp = product [product (IxFun.base ixfun), elem_size]
      bytes <- letSubExp "bytes" =<< toExp (untyped size_exp)
      mem <- letExp "mem" . Op $ Alloc bytes space
      pure . MemArray pt (arrayShape t) NoUniqueness $ ArrayIn mem ixfun

-- | The space of the memory block bound to the given name.  Calls
-- 'error' if the name is bound, but not to a memory block.
lookupMemSpace :: (HasScope rep m, Monad m) => VName -> m Space
lookupMemSpace v = lookupType v >>= spaceOf
  where
    spaceOf (Mem space) = return space
    spaceOf _ =
      error $ "lookupMemSpace: " ++ pretty v ++ " is not a memory block."

-- | An array annotation placing the array in the given memory block,
-- with an 'IxFun.iota' index function over the dimensions of the
-- given type.
directIxFun :: PrimType -> Shape -> u -> VName -> Type -> MemBound u
directIxFun bt shape u mem t = MemArray bt shape u (ArrayIn mem ixf)
  where
    ixf = IxFun.iota [pe64 d | d <- arrayDims t]

-- | Add memory annotations to function parameters, each paired with
-- the space its memory should live in, then run the continuation
-- with the new parameters in scope.
--
-- The continuation receives the context parameters, then the memory
-- parameters generated by 'allocInFParam', then the annotated value
-- parameters, in that order.
allocInFParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, Space)] ->
  ([FParam torep] -> AllocM fromrep torep a) ->
  AllocM fromrep torep a
allocInFParams params m = do
  (valparams, (ctxparams, memparams)) <-
    runWriterT $ traverse (uncurry allocInFParam) params
  let params' = mconcat [ctxparams, memparams, valparams]
  localScope (scopeOfFParams params') (m params')

-- | Add a memory annotation to a single function parameter.
--
-- An array parameter is placed in a fresh memory parameter in the
-- given space.  That parameter is named after the array with a
-- @_mem@ suffix, carries the array parameter's attributes, and is
-- emitted in the second component of the writer output.  The array
-- gets an 'IxFun.iota' index function over its shape.  Scalar,
-- memory and accumulator parameters are annotated directly.
allocInFParam ::
  (Allocable fromrep torep inner) =>
  FParam fromrep ->
  Space ->
  WriterT
    ([FParam torep], [FParam torep])
    (AllocM fromrep torep)
    (FParam torep)
allocInFParam param pspace =
  case paramDeclType param of
    Prim pt -> withDec $ MemPrim pt
    Mem space -> withDec $ MemMem space
    Acc acc ispace ts u -> withDec $ MemAcc acc ispace ts u
    Array pt shape u -> do
      mem <- lift . newVName $ baseString (paramName param) <> "_mem"
      tell ([], [Param (paramAttrs param) mem (MemMem pspace)])
      withDec . MemArray pt shape u . ArrayIn mem . IxFun.iota $
        map pe64 (shapeDims shape)
  where
    withDec dec = pure param {paramDec = dec}

allocInMergeParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, SubExp)] ->
  ( [(FParam torep, SubExp)] ->
    ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])) ->
    AllocM fromrep torep a
  ) ->
  AllocM fromrep torep a
allocInMergeParams :: [(FParam fromrep, SubExp)]
-> ([(FParam torep, SubExp)]
    -> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
    -> AllocM fromrep torep a)
-> AllocM fromrep torep a
allocInMergeParams [(FParam fromrep, SubExp)]
merge [(FParam torep, SubExp)]
-> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> AllocM fromrep torep a
m = do
  (([Param FParamMem]
valparams, [SubExp]
valargs, [SubExp
 -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]
handle_loop_subexps), ([Param FParamMem]
ctx_params, [Param FParamMem]
mem_params)) <-
    WriterT
  ([Param FParamMem], [Param FParamMem])
  (AllocM fromrep torep)
  ([Param FParamMem], [SubExp],
   [SubExp
    -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp])
-> AllocM
     fromrep
     torep
     (([Param FParamMem], [SubExp],
       [SubExp
        -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]),
      ([Param FParamMem], [Param FParamMem]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
   ([Param FParamMem], [Param FParamMem])
   (AllocM fromrep torep)
   ([Param FParamMem], [SubExp],
    [SubExp
     -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp])
 -> AllocM
      fromrep
      torep
      (([Param FParamMem], [SubExp],
        [SubExp
         -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]),
       ([Param FParamMem], [Param FParamMem])))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     ([Param FParamMem], [SubExp],
      [SubExp
       -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp])
-> AllocM
     fromrep
     torep
     (([Param FParamMem], [SubExp],
       [SubExp
        -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]),
      ([Param FParamMem], [Param FParamMem]))
forall a b. (a -> b) -> a -> b
$ [(Param FParamMem, SubExp,
  SubExp
  -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)]
-> ([Param FParamMem], [SubExp],
    [SubExp
     -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Param FParamMem, SubExp,
   SubExp
   -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)]
 -> ([Param FParamMem], [SubExp],
     [SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     [(Param FParamMem, SubExp,
       SubExp
       -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     ([Param FParamMem], [SubExp],
      [SubExp
       -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Param DeclType, SubExp)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (Param FParamMem, SubExp,
       SubExp
       -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp))
-> [(Param DeclType, SubExp)]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     [(Param FParamMem, SubExp,
       SubExp
       -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Param DeclType, SubExp)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall fromrep torep inner.
Allocable fromrep torep inner =>
(Param DeclType, SubExp)
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
allocInMergeParam [(Param DeclType, SubExp)]
[(FParam fromrep, SubExp)]
merge
  let mergeparams' :: [Param FParamMem]
mergeparams' = [Param FParamMem]
ctx_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
mem_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
valparams
      summary :: Scope torep
summary = [Param FParamMem] -> Scope torep
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param FParamMem]
mergeparams'

      mk_loop_res :: [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_res [SubExp]
ses = do
        ([SubExp]
ses', ([SubExp]
ctxargs, [SubExp]
memargs)) <-
          WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) [SubExp]
-> AllocM fromrep torep ([SubExp], ([SubExp], [SubExp]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) [SubExp]
 -> AllocM fromrep torep ([SubExp], ([SubExp], [SubExp])))
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) [SubExp]
-> AllocM fromrep torep ([SubExp], ([SubExp], [SubExp]))
forall a b. (a -> b) -> a -> b
$ ((SubExp
  -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
 -> SubExp
 -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> [SubExp
    -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]
-> [SubExp]
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) [SubExp]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (SubExp
 -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall a b. (a -> b) -> a -> b
($) [SubExp
 -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp]
handle_loop_subexps [SubExp]
ses
        ([SubExp], [SubExp]) -> AllocM fromrep torep ([SubExp], [SubExp])
forall (m :: * -> *) a. Monad m => a -> m a
return ([SubExp]
ctxargs [SubExp] -> [SubExp] -> [SubExp]
forall a. Semigroup a => a -> a -> a
<> [SubExp]
memargs, [SubExp]
ses')

  ([SubExp]
valctx_args, [SubExp]
valargs') <- [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_res [SubExp]
valargs
  let merge' :: [(Param FParamMem, SubExp)]
merge' =
        [Param FParamMem] -> [SubExp] -> [(Param FParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip
          ([Param FParamMem]
ctx_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
mem_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. Semigroup a => a -> a -> a
<> [Param FParamMem]
valparams)
          ([SubExp]
valctx_args [SubExp] -> [SubExp] -> [SubExp]
forall a. Semigroup a => a -> a -> a
<> [SubExp]
valargs')
  Scope torep -> AllocM fromrep torep a -> AllocM fromrep torep a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope Scope torep
summary (AllocM fromrep torep a -> AllocM fromrep torep a)
-> AllocM fromrep torep a -> AllocM fromrep torep a
forall a b. (a -> b) -> a -> b
$ [(FParam torep, SubExp)]
-> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> AllocM fromrep torep a
m [(FParam torep, SubExp)]
[(Param FParamMem, SubExp)]
merge' [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_res
  where
    param_names :: Names
param_names = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> VName)
-> [(Param DeclType, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType -> VName
forall dec. Param dec -> VName
paramName (Param DeclType -> VName)
-> ((Param DeclType, SubExp) -> Param DeclType)
-> (Param DeclType, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst) [(Param DeclType, SubExp)]
[(FParam fromrep, SubExp)]
merge
    anyIsLoopParam :: Names -> Bool
anyIsLoopParam Names
names = Names
names Names -> Names -> Bool
`namesIntersect` Names
param_names

    scalarRes :: DeclType
-> Space -> IxFun (TPrimExp Int64 VName) -> SubExp -> t m SubExp
scalarRes DeclType
param_t Space
v_mem_space IxFun (TPrimExp Int64 VName)
v_ixfun (Var VName
res) = do
      -- Try really hard to avoid copying needlessly, but the result
      -- _must_ be in ScalarSpace and have the right index function.
      (VName
res_mem, IxFun (TPrimExp Int64 VName)
res_ixfun) <- m (VName, IxFun (TPrimExp Int64 VName))
-> t m (VName, IxFun (TPrimExp Int64 VName))
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (VName, IxFun (TPrimExp Int64 VName))
 -> t m (VName, IxFun (TPrimExp Int64 VName)))
-> m (VName, IxFun (TPrimExp Int64 VName))
-> t m (VName, IxFun (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$ VName -> m (VName, IxFun (TPrimExp Int64 VName))
forall rep inner (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun (TPrimExp Int64 VName))
lookupArraySummary VName
res
      Space
res_mem_space <- m Space -> t m Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Space -> t m Space) -> m Space -> t m Space
forall a b. (a -> b) -> a -> b
$ VName -> m Space
forall rep (m :: * -> *).
(HasScope rep m, Monad m) =>
VName -> m Space
lookupMemSpace VName
res_mem
      ChunkMap
chunkmap <- (AllocEnv fromrep torep -> ChunkMap) -> t m ChunkMap
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks AllocEnv fromrep torep -> ChunkMap
forall fromrep torep. AllocEnv fromrep torep -> ChunkMap
chunkMap
      (VName
res_mem', VName
res') <-
        if (Space
res_mem_space, IxFun (TPrimExp Int64 VName)
res_ixfun) (Space, IxFun (TPrimExp Int64 VName))
-> (Space, IxFun (TPrimExp Int64 VName)) -> Bool
forall a. Eq a => a -> a -> Bool
== (Space
v_mem_space, IxFun (TPrimExp Int64 VName)
v_ixfun)
          then (VName, VName) -> t m (VName, VName)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName
res_mem, VName
res)
          else m (VName, VName) -> t m (VName, VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (VName, VName) -> t m (VName, VName))
-> m (VName, VName) -> t m (VName, VName)
forall a b. (a -> b) -> a -> b
$ ChunkMap
-> Space
-> IxFun (TPrimExp Int64 VName)
-> Type
-> VName
-> m (VName, VName)
forall (m :: * -> *) inner.
(MonadBuilder m, Op (Rep m) ~ MemOp inner,
 LetDec (Rep m) ~ LetDecMem) =>
ChunkMap
-> Space
-> IxFun (TPrimExp Int64 VName)
-> Type
-> VName
-> m (VName, VName)
arrayWithIxFun ChunkMap
chunkmap Space
v_mem_space IxFun (TPrimExp Int64 VName)
v_ixfun (DeclType -> Type
forall shape.
TypeBase shape Uniqueness -> TypeBase shape NoUniqueness
fromDecl DeclType
param_t) VName
res
      ([a], [SubExp]) -> t m ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([], [VName -> SubExp
Var VName
res_mem'])
      SubExp -> t m SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> t m SubExp) -> SubExp -> t m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
res'
    scalarRes DeclType
_ Space
_ IxFun (TPrimExp Int64 VName)
_ SubExp
se = SubExp -> t m SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
se

    allocInMergeParam ::
      (Allocable fromrep torep inner) =>
      (Param DeclType, SubExp) ->
      WriterT
        ([FParam torep], [FParam torep])
        (AllocM fromrep torep)
        ( FParam torep,
          SubExp,
          SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
        )
    allocInMergeParam :: (Param DeclType, SubExp)
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
allocInMergeParam (Param DeclType
mergeparam, Var VName
v)
      | param_t :: DeclType
param_t@(Array PrimType
pt Shape
shape Uniqueness
u) <- Param DeclType -> DeclType
forall dec. DeclTyped dec => Param dec -> DeclType
paramDeclType Param DeclType
mergeparam = do
        (VName
v_mem, IxFun (TPrimExp Int64 VName)
v_ixfun) <- AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, IxFun (TPrimExp Int64 VName))
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (VName, IxFun (TPrimExp Int64 VName)))
-> AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, IxFun (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep (VName, IxFun (TPrimExp Int64 VName))
forall rep inner (m :: * -> *).
(Mem rep inner, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun (TPrimExp Int64 VName))
lookupArraySummary VName
v
        Space
v_mem_space <- AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep Space
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      Space)
-> AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep Space
forall rep (m :: * -> *).
(HasScope rep m, Monad m) =>
VName -> m Space
lookupMemSpace VName
v_mem

        -- Loop-invariant array parameters that are in scalar space
        -- are special - we do not wish to existentialise their index
        -- function at all (but the memory block is still existential).
        case Space
v_mem_space of
          ScalarSpace {} ->
            if Names -> Bool
anyIsLoopParam (Shape -> Names
forall a. FreeIn a => a -> Names
freeIn Shape
shape)
              then do
                -- Arrays with loop-variant shape cannot be in scalar
                -- space, so copy them elsewhere and try again.
                (VName
_, VName
v') <- AllocM fromrep torep (VName, VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep (VName, VName)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (VName, VName))
-> AllocM fromrep torep (VName, VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (VName, VName)
forall a b. (a -> b) -> a -> b
$ Space -> String -> VName -> AllocM fromrep torep (VName, VName)
forall fromrep torep inner.
Allocable fromrep torep inner =>
Space -> String -> VName -> AllocM fromrep torep (VName, VName)
allocLinearArray Space
DefaultSpace (VName -> String
baseString VName
v) VName
v
                (Param DeclType, SubExp)
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall fromrep torep inner.
Allocable fromrep torep inner =>
(Param DeclType, SubExp)
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
allocInMergeParam (Param DeclType
mergeparam, VName -> SubExp
Var VName
v')
              else do
                Param FParamMem
p <- String
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"mem_param" (FParamMem
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (Param FParamMem))
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall a b. (a -> b) -> a -> b
$ Space -> FParamMem
forall d u ret. Space -> MemInfo d u ret
MemMem Space
v_mem_space
                ([Param FParamMem], [Param FParamMem])
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([], [Param FParamMem
p])

                (Param FParamMem, SubExp,
 SubExp
 -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
                  ( Param DeclType
mergeparam {paramDec :: FParamMem
paramDec = PrimType -> Shape -> Uniqueness -> MemBind -> FParamMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape Uniqueness
u (MemBind -> FParamMem) -> MemBind -> FParamMem
forall a b. (a -> b) -> a -> b
$ VName -> IxFun (TPrimExp Int64 VName) -> MemBind
ArrayIn (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
p) IxFun (TPrimExp Int64 VName)
v_ixfun},
                    VName -> SubExp
Var VName
v,
                    DeclType
-> Space
-> IxFun (TPrimExp Int64 VName)
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall (t :: (* -> *) -> * -> *) (m :: * -> *) fromrep torep a
       inner.
(MonadTrans t, MonadReader (AllocEnv fromrep torep) (t m),
 MonadBuilder m, MonadWriter ([a], [SubExp]) (t m),
 HasLetDecMem (LetDec (Rep m)), OpReturns inner,
 LParamInfo (Rep m) ~ LetDecMem, RetType (Rep m) ~ RetTypeMem,
 LetDec (Rep m) ~ LetDecMem, Op (Rep m) ~ MemOp inner,
 FParamInfo (Rep m) ~ FParamMem,
 BranchType (Rep m) ~ BranchTypeMem) =>
DeclType
-> Space -> IxFun (TPrimExp Int64 VName) -> SubExp -> t m SubExp
scalarRes DeclType
param_t Space
v_mem_space IxFun (TPrimExp Int64 VName)
v_ixfun
                  )
          Space
_ -> do
            (SubExp
v', ExtIxFun
ext_ixfun, [TPrimExp Int64 VName]
substs, VName
v_mem') <-
              AllocM
  fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM
   fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName))
-> AllocM
     fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
forall a b. (a -> b) -> a -> b
$ Space
-> VName
-> AllocM
     fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
forall fromrep torep inner.
Allocable fromrep torep inner =>
Space
-> VName
-> AllocM
     fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
existentializeArray Space
v_mem_space VName
v
            Space
v_mem_space' <- AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromrep torep Space
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      Space)
-> AllocM fromrep torep Space
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) Space
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromrep torep Space
forall rep (m :: * -> *).
(HasScope rep m, Monad m) =>
VName -> m Space
lookupMemSpace VName
v_mem'

            ([Param FParamMem]
ctx_params, [TPrimExp Int64 (Ext VName)]
param_ixfun_substs) <-
              ([(Param FParamMem, TPrimExp Int64 (Ext VName))]
 -> ([Param FParamMem], [TPrimExp Int64 (Ext VName)]))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     [(Param FParamMem, TPrimExp Int64 (Ext VName))]
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     ([Param FParamMem], [TPrimExp Int64 (Ext VName)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(Param FParamMem, TPrimExp Int64 (Ext VName))]
-> ([Param FParamMem], [TPrimExp Int64 (Ext VName)])
forall a b. [(a, b)] -> ([a], [b])
unzip (WriterT
   ([Param FParamMem], [Param FParamMem])
   (AllocM fromrep torep)
   [(Param FParamMem, TPrimExp Int64 (Ext VName))]
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      ([Param FParamMem], [TPrimExp Int64 (Ext VName)]))
-> ((TPrimExp Int64 VName
     -> WriterT
          ([Param FParamMem], [Param FParamMem])
          (AllocM fromrep torep)
          (Param FParamMem, TPrimExp Int64 (Ext VName)))
    -> WriterT
         ([Param FParamMem], [Param FParamMem])
         (AllocM fromrep torep)
         [(Param FParamMem, TPrimExp Int64 (Ext VName))])
-> (TPrimExp Int64 VName
    -> WriterT
         ([Param FParamMem], [Param FParamMem])
         (AllocM fromrep torep)
         (Param FParamMem, TPrimExp Int64 (Ext VName)))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     ([Param FParamMem], [TPrimExp Int64 (Ext VName)])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [TPrimExp Int64 VName]
-> (TPrimExp Int64 VName
    -> WriterT
         ([Param FParamMem], [Param FParamMem])
         (AllocM fromrep torep)
         (Param FParamMem, TPrimExp Int64 (Ext VName)))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     [(Param FParamMem, TPrimExp Int64 (Ext VName))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [TPrimExp Int64 VName]
substs ((TPrimExp Int64 VName
  -> WriterT
       ([Param FParamMem], [Param FParamMem])
       (AllocM fromrep torep)
       (Param FParamMem, TPrimExp Int64 (Ext VName)))
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      ([Param FParamMem], [TPrimExp Int64 (Ext VName)]))
-> (TPrimExp Int64 VName
    -> WriterT
         ([Param FParamMem], [Param FParamMem])
         (AllocM fromrep torep)
         (Param FParamMem, TPrimExp Int64 (Ext VName)))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     ([Param FParamMem], [TPrimExp Int64 (Ext VName)])
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
e -> do
                Param FParamMem
p <- String
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"ctx_param_ext" (FParamMem
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (Param FParamMem))
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall a b. (a -> b) -> a -> b
$ PrimType -> FParamMem
forall d u ret. PrimType -> MemInfo d u ret
MemPrim (PrimType -> FParamMem) -> PrimType -> FParamMem
forall a b. (a -> b) -> a -> b
$ PrimExp VName -> PrimType
forall v. PrimExp v -> PrimType
primExpType (PrimExp VName -> PrimType) -> PrimExp VName -> PrimType
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
e
                (Param FParamMem, TPrimExp Int64 (Ext VName))
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem, TPrimExp Int64 (Ext VName))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param FParamMem
p, (VName -> Ext VName)
-> TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> Ext VName
forall a. a -> Ext a
Free (TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName))
-> TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName)
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (VName -> TPrimExp Int64 VName) -> VName -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
p)

            ([Param FParamMem], [Param FParamMem])
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([Param FParamMem]
ctx_params, [])

            IxFun (TPrimExp Int64 VName)
param_ixfun <-
              ExtIxFun
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (IxFun (TPrimExp Int64 VName))
forall (m :: * -> *).
Monad m =>
ExtIxFun -> m (IxFun (TPrimExp Int64 VName))
instantiateIxFun (ExtIxFun
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (IxFun (TPrimExp Int64 VName)))
-> ExtIxFun
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (IxFun (TPrimExp Int64 VName))
forall a b. (a -> b) -> a -> b
$
                Map (Ext VName) (TPrimExp Int64 (Ext VName))
-> ExtIxFun -> ExtIxFun
forall a t.
Ord a =>
Map a (TPrimExp t a)
-> IxFun (TPrimExp t a) -> IxFun (TPrimExp t a)
IxFun.substituteInIxFun
                  ([(Ext VName, TPrimExp Int64 (Ext VName))]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(Ext VName, TPrimExp Int64 (Ext VName))]
 -> Map (Ext VName) (TPrimExp Int64 (Ext VName)))
-> [(Ext VName, TPrimExp Int64 (Ext VName))]
-> Map (Ext VName) (TPrimExp Int64 (Ext VName))
forall a b. (a -> b) -> a -> b
$ [Ext VName]
-> [TPrimExp Int64 (Ext VName)]
-> [(Ext VName, TPrimExp Int64 (Ext VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Int -> Ext VName) -> [Int] -> [Ext VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Int -> Ext VName
forall a. Int -> Ext a
Ext [Int
0 ..]) [TPrimExp Int64 (Ext VName)]
param_ixfun_substs)
                  ExtIxFun
ext_ixfun

            Param FParamMem
mem_param <- String
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"mem_param" (FParamMem
 -> WriterT
      ([Param FParamMem], [Param FParamMem])
      (AllocM fromrep torep)
      (Param FParamMem))
-> FParamMem
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem)
forall a b. (a -> b) -> a -> b
$ Space -> FParamMem
forall d u ret. Space -> MemInfo d u ret
MemMem Space
v_mem_space'
            ([Param FParamMem], [Param FParamMem])
-> WriterT
     ([Param FParamMem], [Param FParamMem]) (AllocM fromrep torep) ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([], [Param FParamMem
mem_param])
            (Param FParamMem, SubExp,
 SubExp
 -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> WriterT
     ([Param FParamMem], [Param FParamMem])
     (AllocM fromrep torep)
     (Param FParamMem, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
              ( Param DeclType
mergeparam {paramDec :: FParamMem
paramDec = PrimType -> Shape -> Uniqueness -> MemBind -> FParamMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt Shape
shape Uniqueness
u (MemBind -> FParamMem) -> MemBind -> FParamMem
forall a b. (a -> b) -> a -> b
$ VName -> IxFun (TPrimExp Int64 VName) -> MemBind
ArrayIn (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
mem_param) IxFun (TPrimExp Int64 VName)
param_ixfun},
                SubExp
v',
                Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall fromrep torep inner.
Allocable fromrep torep inner =>
Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn Space
v_mem_space'
              )
    allocInMergeParam (Param DeclType
mergeparam, SubExp
se) = Param DeclType
-> SubExp
-> Space
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall torep fromrep fromrep torep inner inner b.
(PrettyRep fromrep, PrettyRep fromrep, HasLetDecMem (LetDec torep),
 HasLetDecMem (LetDec torep), OpReturns inner, OpReturns inner,
 SizeSubst inner, SizeSubst inner, BuilderOps torep,
 BuilderOps torep, FParamInfo torep ~ FParamMem,
 LetDec torep ~ LetDecMem, BodyDec fromrep ~ (),
 LParamInfo fromrep ~ Type, RetType torep ~ RetTypeMem,
 BodyDec fromrep ~ (), BranchType torep ~ BranchTypeMem,
 BranchType fromrep ~ TypeBase ExtShape NoUniqueness,
 FParamInfo torep ~ FParamMem, FParamInfo fromrep ~ DeclType,
 ExpDec torep ~ (), LParamInfo fromrep ~ Type,
 RetType torep ~ RetTypeMem, BodyDec torep ~ (),
 BranchType torep ~ BranchTypeMem,
 BranchType fromrep ~ TypeBase ExtShape NoUniqueness,
 RetType fromrep ~ DeclExtType, Op torep ~ MemOp inner,
 ExpDec torep ~ (), LParamInfo torep ~ LetDecMem,
 FParamInfo fromrep ~ DeclType, BodyDec torep ~ (),
 LetDec torep ~ LetDecMem, LParamInfo torep ~ LetDecMem,
 RetType fromrep ~ DeclExtType, Op torep ~ MemOp inner) =>
Param DeclType
-> b
-> Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep), b,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
doDefault Param DeclType
mergeparam SubExp
se (Space
 -> WriterT
      ([FParam torep], [FParam torep])
      (AllocM fromrep torep)
      (Param FParamMem, SubExp,
       SubExp
       -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp))
-> WriterT
     ([FParam torep], [FParam torep]) (AllocM fromrep torep) Space
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (Param FParamMem, SubExp,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< AllocM fromrep torep Space
-> WriterT
     ([FParam torep], [FParam torep]) (AllocM fromrep torep) Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift AllocM fromrep torep Space
forall fromrep torep. AllocM fromrep torep Space
askDefaultSpace

    doDefault :: Param DeclType
-> b
-> Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep), b,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
doDefault Param DeclType
mergeparam b
se Space
space = do
      Param (FParamInfo torep)
mergeparam' <- FParam fromrep
-> Space
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep))
forall fromrep torep inner.
Allocable fromrep torep inner =>
FParam fromrep
-> Space
-> WriterT
     ([FParam torep], [FParam torep])
     (AllocM fromrep torep)
     (FParam torep)
allocInFParam Param DeclType
FParam fromrep
mergeparam Space
space
      (Param (FParamInfo torep), b,
 SubExp
 -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
-> WriterT
     ([Param (FParamInfo torep)], [Param (FParamInfo torep)])
     (AllocM fromrep torep)
     (Param (FParamInfo torep), b,
      SubExp
      -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (Param (FParamInfo torep)
mergeparam', b
se, Type
-> Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
forall fromrep torep inner.
Allocable fromrep torep inner =>
Type
-> Space
-> SubExp
-> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg (Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param DeclType
mergeparam) Space
space)

-- | Put the index function of array @v@ into existential form.
--
-- Returns, in order:
--
-- * the array to use: either @v@ itself or a fresh linear copy of it;
-- * the existentialised index function;
-- * the values accumulated in the state of 'IxFun.existentialize', i.e.
--   the concrete expressions that were abstracted out of the index
--   function;
-- * the memory block holding the returned array.
--
-- If @v@ already lives in @space@ and its index function can be
-- existentialised, no copy is made.  Otherwise @v@ is copied into a new
-- array in @space@ via 'allocLinearArray', and the copy's index function
-- is existentialised instead.
existentializeArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  VName ->
  AllocM fromrep torep (SubExp, ExtIxFun, [TPrimExp Int64 VName], VName)
existentializeArray space v = do
  (mem', ixfun) <- lookupArraySummary v
  sp <- lookupMemSpace mem'

  let (ext_ixfun', substs') = runState (IxFun.existentialize ixfun) []

  case (ext_ixfun', sp == space) of
    (Just x, True) -> return (Var v, x, substs', mem')
    _ -> do
      -- Either the array is in the wrong space or its index function
      -- could not be existentialised as-is; fall back to a copy.
      (mem, v') <- allocLinearArray space (baseString v) v
      ixfun' <-
        fromMaybe
          (error $ "existentializeArray: no index function for " <> pretty v')
          <$> lookupIxFun v'
      let (ext_ixfun, substs) = runState (IxFun.existentialize ixfun') []
          -- NOTE(review): presumably the copy's index function always
          -- existentialises; failing loudly documents that assumption.
          copy_ext_ixfun =
            fromMaybe
              ( error $
                  "existentializeArray: cannot existentialise index function of "
                    <> pretty v'
              )
              ext_ixfun
      return (Var v', copy_ext_ixfun, substs, mem)

-- | Copy array @v@ (of type @v_t@) into a newly allocated memory block in
-- @space@, laid out according to @ixfun@.
--
-- Allocates the block with 'allocForArray'', binds a fresh name (the
-- base name of @v@ suffixed with @_scalcopy@) to a 'Copy' of @v@ in that
-- block, and returns @(memory block, name of the copy)@.
--
-- @v_t@ must be an array type; anything else is an internal error.
arrayWithIxFun ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner, LetDec (Rep m) ~ LetDecMem) =>
  ChunkMap ->
  Space ->
  IxFun ->
  Type ->
  VName ->
  m (VName, VName)
arrayWithIxFun chunkmap space ixfun v_t v =
  case v_t of
    Array pt shape u -> do
      mem <- allocForArray' chunkmap v_t space
      v_copy <- newVName $ baseString v <> "_scalcopy"
      letBind (Pat [PatElem v_copy $ MemArray pt shape u $ ArrayIn mem ixfun]) $
        BasicOp $ Copy v
      pure (mem, v_copy)
    -- Previously an irrefutable lazy pattern; make the failure explicit.
    _ -> error $ "arrayWithIxFun: expected an array type, got " <> pretty v_t

-- | Prepare an array argument for passing across a call boundary in the
-- given space.
--
-- The array is existentialized with 'existentializeArray'.  Every
-- substitution expression that call yields is bound to a fresh
-- @ctx_val@ variable.  Those variables are emitted as the first writer
-- component, and the array's memory block as the second.  The returned
-- 'SubExp' is the array as produced by 'existentializeArray'.
--
-- Calling this on a constant is an error: constants are never arrays.
ensureArrayIn ::
  (Allocable fromrep torep inner) =>
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn _ (Constant v) =
  error $ "ensureArrayIn: " ++ pretty v ++ " cannot be an array."
ensureArrayIn space (Var v) = do
  (sub_exp, _, substs, mem) <- lift $ existentializeArray space v
  -- Bind each substitution to a variable so that it can be passed as a
  -- context argument.  Only the variables themselves are needed, so no
  -- further expression is built from them.
  ctx_vals <- lift $ mapM (\s -> Var <$> (letExp "ctx_val" =<< toExp s)) substs
  tell (ctx_vals, [Var mem])
  pure sub_exp

-- | Return a memory block and an array name for @v@ such that the array
-- has a direct index function ('IxFun.isDirect').  If a space is given,
-- the array must also reside in that space.
--
-- If @v@ already satisfies this, its own memory block and @v@ itself are
-- returned.  Otherwise a linear copy is made with 'allocLinearArray'.
-- The copy goes into the requested space, or into the default space
-- when none was requested.
ensureDirectArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureDirectArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let -- With no space requirement, any space is acceptable.
      in_wanted_space = all (== mem_space) space_ok
      target_space = fromMaybe default_space space_ok
  if not (IxFun.isDirect ixfun && in_wanted_space)
    then -- Allocate fresh memory and copy 'v' into it.
      allocLinearArray target_space (baseString v) v
    else pure (mem, v)

-- | Copy array @v@ into freshly allocated memory in @space@.
--
-- The copy is bound to a new name derived from @s@ with the suffix
-- @_linear@.  Its layout is the one produced by 'directIxFun'.  Returns
-- the new memory block and the name of the copy.  Fails with an error if
-- @v@ does not have an array type.
allocLinearArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocLinearArray space s v = do
  t <- lookupType v
  case t of
    Array pt shape u -> do
      mem <- allocForArray t space
      copy_name <- newVName (s <> "_linear")
      let copy_dec = directIxFun pt shape u mem t
      addStm . Let (Pat [PatElem copy_name copy_dec]) (defAux ()) . BasicOp $
        Copy v
      pure (mem, copy_name)
    _ -> error ("allocLinearArray: " ++ pretty t)

-- | Prepare the arguments of a function call.
--
-- Every array variable argument is made direct in the default space (see
-- 'linearFuncallArg'), copying it if necessary.  The result consists of:
--
-- 1. the collected context arguments, then
-- 2. the memory blocks of the array arguments, both passed with 'Observe';
-- 3. the (possibly replaced) value arguments with their original diets.
funcallArgs ::
  (Allocable fromrep torep inner) =>
  [(SubExp, Diet)] ->
  AllocM fromrep torep [(SubExp, Diet)]
funcallArgs args = do
  -- The default space is loop-invariant, so look it up once rather than
  -- once per argument.
  -- NOTE(review): assumes askDefaultSpace is a plain environment read.
  space <- askDefaultSpace
  (valargs, (ctx_args, mem_and_size_args)) <-
    runWriterT . forM args $ \(arg, d) -> do
      t <- lift $ subExpType arg
      (,d) <$> linearFuncallArg t space arg
  pure $ map (,Observe) (ctx_args <> mem_and_size_args) <> valargs

-- | Prepare a single call argument of the given type.
--
-- An array-typed variable is made direct in @space@ via
-- 'ensureDirectArray'.  Its memory block is recorded in the second writer
-- component, and the possibly copied array is returned.  Any other
-- argument is returned unchanged and records nothing.
linearFuncallArg ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg t space arg =
  case (t, arg) of
    (Array {}, Var v) -> do
      (mem, v_direct) <- lift $ ensureDirectArray (Just space) v
      tell (mempty, [Var mem])
      pure (Var v_direct)
    _ -> pure arg

-- | Construct the explicit-allocations pass.
--
-- The first argument translates representation-specific operations.  The
-- second supplies allocation hints for expressions.  Top-level constants
-- and function definitions are transformed separately.  Function
-- parameters and return values are placed in 'DefaultSpace'.
explicitAllocationsGeneric ::
  (Allocable fromrep torep inner) =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Pass fromrep torep
explicitAllocationsGeneric handleOp hints =
  Pass pass_name pass_desc $
    intraproceduralTransformationWithConsts allocConsts allocFun
  where
    pass_name = "explicit allocations"
    pass_desc = "Transform program to explicit memory representation"

    -- Allocate in a sequence of top-level statements.
    allocConsts const_stms =
      runAllocM handleOp hints . collectStms_ $
        allocInStms const_stms (pure ())

    -- Allocate in one function, with the given statements in scope.
    allocFun consts (FunDef entry attrs fname rettype params fbody) =
      runAllocM handleOp hints $
        inScopeOf consts $
          allocInFParams [(p, DefaultSpace) | p <- params] $ \params' -> do
            let result_spaces = Just DefaultSpace <$ rettype
            (fbody', mem_rets) <- allocInFunBody result_spaces fbody
            -- The memory results come first, so existential indices in
            -- the declared return types are offset by their number.
            let num_mem_rets = length mem_rets
                rettype' = mem_rets ++ memoryInDeclExtType num_mem_rets rettype
            pure $ FunDef entry attrs fname rettype' params' fbody'

-- | Run the allocation transformation on a sequence of statements.
--
-- The caller's scope (obtained with 'askScope') is made available to the
-- allocator.  Returns the resulting statements in the target
-- representation.
explicitAllocationsInStmsGeneric ::
  ( MonadFreshNames m,
    HasScope torep m,
    Allocable fromrep torep inner
  ) =>
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Stms fromrep ->
  m (Stms torep)
explicitAllocationsInStmsGeneric handleOp hints stms =
  askScope >>= \outer_scope ->
    runAllocM handleOp hints . localScope outer_scope . collectStms_ $
      allocInStms stms (pure ())

-- | Attach memory information to declared return types.
--
-- Each array type is given its own new memory block in 'DefaultSpace'.
-- Blocks are numbered consecutively from 0 in the order the arrays
-- appear.  Each array also gets an 'IxFun.iota' index function over its
-- shape.  Existential size indices are shifted by @k@, and free sizes are
-- kept.  Scalars and accumulators carry over directly.  A 'Mem' type is
-- an error.
memoryInDeclExtType :: Int -> [DeclExtType] -> [FunReturns]
memoryInDeclExtType k dets = evalState (traverse withMem dets) 0
  where
    withMem (Prim pt) = pure $ MemPrim pt
    withMem Mem {} = error "memoryInDeclExtType: too much memory"
    withMem (Acc acc ispace ts u) = pure $ MemAcc acc ispace ts u
    withMem (Array pt shape u) = do
      -- Take the next memory block number.
      mem_idx <- state $ \next -> (next, next + 1)
      let shape' = shiftExt <$> shape
          ixfun = IxFun.iota $ map toPrimExp $ shapeDims shape'
      pure $ MemArray pt shape' u $ ReturnsNewBlock DefaultSpace mem_idx ixfun

    toPrimExp (Ext i) = le64 (Ext i)
    toPrimExp (Free se) = Free <$> pe64 se

    shiftExt (Ext i) = Ext (k + i)
    shiftExt (Free se) = Free se

-- | Compute the memory context contributed by one body result.
--
-- If the result is an array variable, this returns its memory block,
-- paired with a 'MemMem' annotation for that block's space.  Constants,
-- scalars, accumulators and memory blocks contribute nothing.  Fails with
-- an error if the array's memory is not bound as a memory block.
bodyReturnMemCtx ::
  (Allocable fromrep torep inner) =>
  SubExpRes ->
  AllocM fromrep torep [(SubExpRes, MemInfo ExtSize u MemReturn)]
bodyReturnMemCtx (SubExpRes _ se) =
  case se of
    Constant {} -> pure []
    Var v -> do
      info <- lookupMemInfo v
      case info of
        MemArray _ _ _ (ArrayIn mem _) -> do
          mem_info <- lookupMemInfo mem
          case mem_info of
            MemMem space -> pure [(subExpRes $ Var mem, MemMem space)]
            _ -> error $ "bodyReturnMemCtx: not a memory block: " ++ pretty mem
        -- Scalars, accumulators, and memory blocks (the last of which
        -- should not occur here) have no memory context.
        _ -> pure []

-- | Translate the body of a function into the memory-annotated
-- representation.
--
-- The statements are allocated first.  Then every result is passed
-- through 'ensureDirect', and the memory-block context produced by
-- 'bodyReturnMemCtx' for those results is placed in front of them.
--
-- The given space requirements line up with the *last*
-- @length space_oks@ results.  Any leading results are paired with
-- 'Nothing' (presumably meaning "no particular space" — confirm
-- against 'ensureDirectArray').
--
-- Returns the new body together with the return types of the
-- context results that were prepended.
allocInFunBody ::
  (Allocable fromrep torep inner) =>
  [Maybe Space] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [FunReturns])
allocInFunBody space_oks (Body _ stms res) =
  buildBody $ allocInStms stms $ do
    direct_res <- zipWithM ensureDirect padded_spaces res
    ctx <- concat <$> mapM bodyReturnMemCtx direct_res
    let (ctx_res, ctx_rets) = unzip ctx
    pure (ctx_res <> direct_res, ctx_rets)
  where
    -- Results with no corresponding entry in space_oks.
    padding = length res - length space_oks
    padded_spaces = replicate padding Nothing ++ space_oks

-- | Make a single body result suitable for returning.
--
-- If the result is a variable whose memory info is an array
-- ('MemArray'), it is handed to 'ensureDirectArray' with the
-- requested space.  The variable that comes back (the second
-- component of the pair) replaces the original.  Every other result
-- is returned unchanged.  Certificates are always preserved.
ensureDirect ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  SubExpRes ->
  AllocM fromrep torep SubExpRes
ensureDirect space_ok (SubExpRes cs se) = do
  info <- subExpMemInfo se
  se' <- case se of
    Var v
      | MemArray {} <- info ->
          Var . snd <$> ensureDirectArray space_ok v
    _ -> pure se
  pure $ SubExpRes cs se'

-- | Translate a sequence of statements, then run the continuation
-- @m@ in their scope.
--
-- Statements are handled left to right.  Each one is translated with
-- 'allocInStm' under its original auxiliary information ('auxing'),
-- and the resulting statements are added to the builder.  The
-- environment is then extended for everything that follows:
--
-- * their 'sizeSubst' entries are prepended to 'chunkMap';
-- * their 'stmConsts' are added to 'envConsts'.
--
-- The continuation therefore sees the environment extensions from
-- every translated statement.
allocInStms ::
  (Allocable fromrep torep inner) =>
  Stms fromrep ->
  AllocM fromrep torep a ->
  AllocM fromrep torep a
allocInStms origstms m =
  foldr
    ( \stm rest -> do
        new_stms <- collectStms_ . auxing (stmAux stm) $ allocInStm stm
        addStms new_stms
        let new_chunks = foldMap sizeSubst new_stms
            new_consts = foldMap stmConsts new_stms
            extend env =
              env
                { chunkMap = new_chunks <> chunkMap env,
                  envConsts = new_consts <> envConsts env
                }
        local extend rest
    )
    m
    (stmsToList origstms)

-- | Translate one statement.
--
-- The expression is translated with 'allocInExp' first.  'allocsForStm'
-- then builds a statement binding the identifiers of the original
-- pattern, and that statement is added to the builder.  The auxiliary
-- information is ignored here; 'allocInStms' re-applies it via
-- 'auxing'.
allocInStm ::
  (Allocable fromrep torep inner) =>
  Stm fromrep ->
  AllocM fromrep torep ()
allocInStm (Let (Pat pes) _ e) = do
  e' <- allocInExp e
  stm' <- allocsForStm (map patElemIdent pes) e'
  addStm stm'

-- | Build a lambda in the memory-annotated representation.
--
-- The lambda takes the given (already translated) parameters.  Its
-- body consists of the original body's statements after 'allocInStms',
-- and it returns the original body's result unchanged.
allocInLambda ::
  Allocable fromrep torep inner =>
  [LParam torep] ->
  Body fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInLambda params (Body _ stms res) =
  mkLambda params $ allocInStms stms (pure res)

allocInExp ::
  (Allocable fromrep torep inner) =>
  Exp fromrep ->
  AllocM fromrep torep (Exp torep)
allocInExp :: Exp fromrep -> AllocM fromrep torep (Exp torep)
allocInExp (DoLoop [(FParam fromrep, SubExp)]
merge LoopForm fromrep
form (Body () Stms fromrep
bodystms Result
bodyres)) =
  [(FParam fromrep, SubExp)]
-> ([(FParam torep, SubExp)]
    -> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
    -> AllocM fromrep torep (Exp torep))
-> AllocM fromrep torep (Exp torep)
forall fromrep torep inner a.
Allocable fromrep torep inner =>
[(FParam fromrep, SubExp)]
-> ([(FParam torep, SubExp)]
    -> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
    -> AllocM fromrep torep a)
-> AllocM fromrep torep a
allocInMergeParams [(FParam fromrep, SubExp)]
merge (([(FParam torep, SubExp)]
  -> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
  -> AllocM fromrep torep (Exp torep))
 -> AllocM fromrep torep (Exp torep))
-> ([(FParam torep, SubExp)]
    -> ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
    -> AllocM fromrep torep (Exp torep))
-> AllocM fromrep torep (Exp torep)
forall a b. (a -> b) -> a -> b
$ \[(FParam torep, SubExp)]
merge' [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_val -> do
    LoopForm torep
form' <- LoopForm fromrep -> AllocM fromrep torep (LoopForm torep)
forall fromrep torep inner.
Allocable fromrep torep inner =>
LoopForm fromrep -> AllocM fromrep torep (LoopForm torep)
allocInLoopForm LoopForm fromrep
form
    Scope torep
-> AllocM fromrep torep (Exp torep)
-> AllocM fromrep torep (Exp torep)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (LoopForm torep -> Scope torep
forall rep a. Scoped rep a => a -> Scope rep
scopeOf LoopForm torep
form') (AllocM fromrep torep (Exp torep)
 -> AllocM fromrep torep (Exp torep))
-> AllocM fromrep torep (Exp torep)
-> AllocM fromrep torep (Exp torep)
forall a b. (a -> b) -> a -> b
$ do
      BodyT torep
body' <-
        AllocM fromrep torep Result -> AllocM fromrep torep (BodyT torep)
forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ (AllocM fromrep torep Result -> AllocM fromrep torep (BodyT torep))
-> (AllocM fromrep torep Result -> AllocM fromrep torep Result)
-> AllocM fromrep torep Result
-> AllocM fromrep torep (BodyT torep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms fromrep
-> AllocM fromrep torep Result -> AllocM fromrep torep Result
forall fromrep torep inner a.
Allocable fromrep torep inner =>
Stms fromrep -> AllocM fromrep torep a -> AllocM fromrep torep a
allocInStms Stms fromrep
bodystms (AllocM fromrep torep Result -> AllocM fromrep torep (BodyT torep))
-> AllocM fromrep torep Result
-> AllocM fromrep torep (BodyT torep)
forall a b. (a -> b) -> a -> b
$ do
          ([SubExp]
val_ses, [SubExp]
valres') <- [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
mk_loop_val ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp]))
-> [SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
bodyres
          Result -> AllocM fromrep torep Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> AllocM fromrep torep Result)
-> Result -> AllocM fromrep torep Result
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Result
subExpsRes [SubExp]
val_ses Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> (Certs -> SubExp -> SubExpRes) -> [Certs] -> [SubExp] -> Result
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Certs -> SubExp -> SubExpRes
SubExpRes ((SubExpRes -> Certs) -> Result -> [Certs]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> Certs
resCerts Result
bodyres) [SubExp]
valres'
      Exp torep -> AllocM fromrep torep (Exp torep)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp torep -> AllocM fromrep torep (Exp torep))
-> Exp torep -> AllocM fromrep torep (Exp torep)
forall a b. (a -> b) -> a -> b
$ [(FParam torep, SubExp)]
-> LoopForm torep -> BodyT torep -> Exp torep
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop [(FParam torep, SubExp)]
merge' LoopForm torep
form' BodyT torep
body'
allocInExp (Apply Name
fname [(SubExp, Diet)]
args [RetType fromrep]
rettype (Safety, SrcLoc, [SrcLoc])
loc) = do
  [(SubExp, Diet)]
args' <- [(SubExp, Diet)] -> AllocM fromrep torep [(SubExp, Diet)]
forall fromrep torep inner.
Allocable fromrep torep inner =>
[(SubExp, Diet)] -> AllocM fromrep torep [(SubExp, Diet)]
funcallArgs [(SubExp, Diet)]
args
  -- We assume that every array is going to be in its own memory.
  Exp torep -> AllocM fromrep torep (Exp torep)
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp torep -> AllocM fromrep torep (Exp torep))
-> Exp torep -> AllocM fromrep torep (Exp torep)
forall a b. (a -> b) -> a -> b
$ Name
-> [(SubExp, Diet)]
-> [RetType torep]
-> (Safety, SrcLoc, [SrcLoc])
-> Exp torep
forall rep.
Name
-> [(SubExp, Diet)]
-> [RetType rep]
-> (Safety, SrcLoc, [SrcLoc])
-> ExpT rep
Apply Name
fname [(SubExp, Diet)]
args' ([RetTypeMem]
mems [RetTypeMem] -> [RetTypeMem] -> [RetTypeMem]
forall a. [a] -> [a] -> [a]
++ Int -> [DeclExtType] -> [RetTypeMem]
memoryInDeclExtType Int
0 [DeclExtType]
[RetType fromrep]
rettype) (Safety, SrcLoc, [SrcLoc])
loc
  where
    mems :: [RetTypeMem]
mems = Int -> RetTypeMem -> [RetTypeMem]
forall a. Int -> a -> [a]
replicate Int
num_arrays (Space -> RetTypeMem
forall d u ret. Space -> MemInfo d u ret
MemMem Space
DefaultSpace)
    num_arrays :: Int
num_arrays = [DeclExtType] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([DeclExtType] -> Int) -> [DeclExtType] -> Int
forall a b. (a -> b) -> a -> b
$ (DeclExtType -> Bool) -> [DeclExtType] -> [DeclExtType]
forall a. (a -> Bool) -> [a] -> [a]
filter ((Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0) (Int -> Bool) -> (DeclExtType -> Int) -> DeclExtType -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DeclExtType -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (DeclExtType -> Int)
-> (DeclExtType -> DeclExtType) -> DeclExtType -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DeclExtType -> DeclExtType
forall t. DeclExtTyped t => t -> DeclExtType
declExtTypeOf) [DeclExtType]
[RetType fromrep]
rettype
allocInExp (If SubExp
cond BodyT fromrep
tbranch0 BodyT fromrep
fbranch0 (IfDec [BranchType fromrep]
rets IfSort
ifsort)) = do
  let num_rets :: Int
num_rets = [TypeBase ExtShape NoUniqueness] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TypeBase ExtShape NoUniqueness]
[BranchType fromrep]
rets
  -- switch to the explicit-mem rep, but do nothing about results
  (BodyT torep
tbranch, [Maybe (IxFun (TPrimExp Int64 VName))]
tm_ixfs) <- Int
-> BodyT fromrep
-> AllocM
     fromrep torep (BodyT torep, [Maybe (IxFun (TPrimExp Int64 VName))])
forall fromrep torep inner.
Allocable fromrep torep inner =>
Int
-> Body fromrep
-> AllocM
     fromrep torep (Body torep, [Maybe (IxFun (TPrimExp Int64 VName))])
allocInIfBody Int
num_rets BodyT fromrep
tbranch0
  (BodyT torep
fbranch, [Maybe (IxFun (TPrimExp Int64 VName))]
fm_ixfs) <- Int
-> BodyT fromrep
-> AllocM
     fromrep torep (BodyT torep, [Maybe (IxFun (TPrimExp Int64 VName))])
forall fromrep torep inner.
Allocable fromrep torep inner =>
Int
-> Body fromrep
-> AllocM
     fromrep torep (Body torep, [Maybe (IxFun (TPrimExp Int64 VName))])
allocInIfBody Int
num_rets BodyT fromrep
fbranch0
  [Maybe Space]
tspaces <- Int -> BodyT torep -> AllocM fromrep torep [Maybe Space]
forall torep inner (m :: * -> *).
(Mem torep inner, LocalScope torep m) =>
Int -> Body torep -> m [Maybe Space]
mkSpaceOks Int
num_rets BodyT torep
tbranch
  [Maybe Space]
fspaces <- Int -> BodyT torep -> AllocM fromrep torep [Maybe Space]
forall torep inner (m :: * -> *).
(Mem torep inner, LocalScope torep m) =>
Int -> Body torep -> m [Maybe Space]
mkSpaceOks Int
num_rets BodyT torep
fbranch
  -- try to generalize (antiunify) the index functions of the then and else bodies
  let sp_substs :: [(Maybe Space,
  Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
sp_substs = ((Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))
 -> (Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))
 -> (Maybe Space,
     Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])))
-> [(Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))]
-> [(Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))]
-> [(Maybe Space,
     Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))
-> (Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))
-> (Maybe Space,
    Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))
generalize ([Maybe Space]
-> [Maybe (IxFun (TPrimExp Int64 VName))]
-> [(Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Maybe Space]
tspaces [Maybe (IxFun (TPrimExp Int64 VName))]
tm_ixfs) ([Maybe Space]
-> [Maybe (IxFun (TPrimExp Int64 VName))]
-> [(Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Maybe Space]
fspaces [Maybe (IxFun (TPrimExp Int64 VName))]
fm_ixfs)
      ([Maybe Space]
spaces, [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
subs) = [(Maybe Space,
  Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
-> ([Maybe Space],
    [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Maybe Space,
  Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))]
sp_substs
      tsubs :: [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
tsubs = (Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
 -> Maybe (ExtIxFun, [TPrimExp Int64 VName]))
-> [Maybe
      (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
forall a b. (a -> b) -> [a] -> [b]
map (((TPrimExp Int64 VName, TPrimExp Int64 VName)
 -> TPrimExp Int64 VName)
-> Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
-> Maybe (ExtIxFun, [TPrimExp Int64 VName])
forall a.
((a, a) -> a)
-> Maybe (ExtIxFun, [(a, a)]) -> Maybe (ExtIxFun, [a])
selectSub (TPrimExp Int64 VName, TPrimExp Int64 VName)
-> TPrimExp Int64 VName
forall a b. (a, b) -> a
fst) [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
subs
      fsubs :: [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
fsubs = (Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
 -> Maybe (ExtIxFun, [TPrimExp Int64 VName]))
-> [Maybe
      (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
forall a b. (a -> b) -> [a] -> [b]
map (((TPrimExp Int64 VName, TPrimExp Int64 VName)
 -> TPrimExp Int64 VName)
-> Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
-> Maybe (ExtIxFun, [TPrimExp Int64 VName])
forall a.
((a, a) -> a)
-> Maybe (ExtIxFun, [(a, a)]) -> Maybe (ExtIxFun, [a])
selectSub (TPrimExp Int64 VName, TPrimExp Int64 VName)
-> TPrimExp Int64 VName
forall a b. (a, b) -> b
snd) [Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])]
subs
  (BodyT torep
tbranch', [BranchTypeMem]
trets) <- [TypeBase ExtShape NoUniqueness]
-> BodyT torep
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromrep torep (BodyT torep, [BranchTypeMem])
forall fromrep torep inner.
Allocable fromrep torep inner =>
[TypeBase ExtShape NoUniqueness]
-> Body torep
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromrep torep (Body torep, [BranchTypeMem])
addResCtxInIfBody [TypeBase ExtShape NoUniqueness]
[BranchType fromrep]
rets BodyT torep
tbranch [Maybe Space]
spaces [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
tsubs
  (BodyT torep
fbranch', [BranchTypeMem]
frets) <- [TypeBase ExtShape NoUniqueness]
-> BodyT torep
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromrep torep (BodyT torep, [BranchTypeMem])
forall fromrep torep inner.
Allocable fromrep torep inner =>
[TypeBase ExtShape NoUniqueness]
-> Body torep
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromrep torep (Body torep, [BranchTypeMem])
addResCtxInIfBody [TypeBase ExtShape NoUniqueness]
[BranchType fromrep]
rets BodyT torep
fbranch [Maybe Space]
spaces [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
fsubs
  if [BranchTypeMem]
frets [BranchTypeMem] -> [BranchTypeMem] -> Bool
forall a. Eq a => a -> a -> Bool
/= [BranchTypeMem]
trets
    then String -> AllocM fromrep torep (Exp torep)
forall a. HasCallStack => String -> a
error String
"In allocInExp, IF case: antiunification of then/else produce different ExtInFn!"
    else do
      -- above is a sanity check; implementation continues on else branch
      let res_then :: Result
res_then = BodyT torep -> Result
forall rep. BodyT rep -> Result
bodyResult BodyT torep
tbranch'
          res_else :: Result
res_else = BodyT torep -> Result
forall rep. BodyT rep -> Result
bodyResult BodyT torep
fbranch'
          size_ext :: Int
size_ext = Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
res_then Int -> Int -> Int
forall a. Num a => a -> a -> a
- [BranchTypeMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchTypeMem]
trets
          ([(SubExpRes, SubExpRes, Int)]
ind_ses0, [(SubExpRes, SubExpRes, Int)]
r_then_else) =
            ((SubExpRes, SubExpRes, Int) -> Bool)
-> [(SubExpRes, SubExpRes, Int)]
-> ([(SubExpRes, SubExpRes, Int)], [(SubExpRes, SubExpRes, Int)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (\(SubExpRes
r_then, SubExpRes
r_else, Int
_) -> SubExpRes
r_then SubExpRes -> SubExpRes -> Bool
forall a. Eq a => a -> a -> Bool
== SubExpRes
r_else) ([(SubExpRes, SubExpRes, Int)]
 -> ([(SubExpRes, SubExpRes, Int)], [(SubExpRes, SubExpRes, Int)]))
-> [(SubExpRes, SubExpRes, Int)]
-> ([(SubExpRes, SubExpRes, Int)], [(SubExpRes, SubExpRes, Int)])
forall a b. (a -> b) -> a -> b
$
              Result -> Result -> [Int] -> [(SubExpRes, SubExpRes, Int)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Result
res_then Result
res_else [Int
0 .. Int
size_ext Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1]
          (Result
r_then_ext, Result
r_else_ext, [Int]
_) = [(SubExpRes, SubExpRes, Int)] -> (Result, Result, [Int])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(SubExpRes, SubExpRes, Int)]
r_then_else
          ind_ses :: [(Int, SubExpRes)]
ind_ses =
            ((SubExpRes, SubExpRes, Int) -> Int -> (Int, SubExpRes))
-> [(SubExpRes, SubExpRes, Int)] -> [Int] -> [(Int, SubExpRes)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
              (\(SubExpRes
se, SubExpRes
_, Int
i) Int
k -> (Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
k, SubExpRes
se))
              [(SubExpRes, SubExpRes, Int)]
ind_ses0
              [Int
0 .. [(SubExpRes, SubExpRes, Int)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(SubExpRes, SubExpRes, Int)]
ind_ses0 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1]
          rets'' :: [BranchTypeMem]
rets'' = ([BranchTypeMem] -> (Int, SubExpRes) -> [BranchTypeMem])
-> [BranchTypeMem] -> [(Int, SubExpRes)] -> [BranchTypeMem]
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\[BranchTypeMem]
acc (Int
i, SubExpRes Certs
_ SubExp
se) -> Int -> SubExp -> [BranchTypeMem] -> [BranchTypeMem]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt Int
i SubExp
se [BranchTypeMem]
acc) [BranchTypeMem]
trets [(Int, SubExpRes)]
ind_ses
          tbranch'' :: BodyT torep
tbranch'' = BodyT torep
tbranch' {bodyResult :: Result
bodyResult = Result
r_then_ext Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
size_ext Result
res_then}
          fbranch'' :: BodyT torep
fbranch'' = BodyT torep
fbranch' {bodyResult :: Result
bodyResult = Result
r_else_ext Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop Int
size_ext Result
res_else}
          res_if_expr :: Exp torep
res_if_expr = SubExp
-> BodyT torep
-> BodyT torep
-> IfDec (BranchType torep)
-> Exp torep
forall rep.
SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
If SubExp
cond BodyT torep
tbranch'' BodyT torep
fbranch'' (IfDec (BranchType torep) -> Exp torep)
-> IfDec (BranchType torep) -> Exp torep
forall a b. (a -> b) -> a -> b
$ [BranchTypeMem] -> IfSort -> IfDec BranchTypeMem
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchTypeMem]
rets'' IfSort
ifsort
      Exp torep -> AllocM fromrep torep (Exp torep)
forall (m :: * -> *) a. Monad m => a -> m a
return Exp torep
res_if_expr
  where
    generalize ::
      (Maybe Space, Maybe IxFun) ->
      (Maybe Space, Maybe IxFun) ->
      (Maybe Space, Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))
    generalize :: (Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))
-> (Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))
-> (Maybe Space,
    Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)]))
generalize (Just Space
sp1, Just IxFun (TPrimExp Int64 VName)
ixf1) (Just Space
sp2, Just IxFun (TPrimExp Int64 VName)
ixf2) =
      if Space
sp1 Space -> Space -> Bool
forall a. Eq a => a -> a -> Bool
/= Space
sp2
        then (Space -> Maybe Space
forall a. a -> Maybe a
Just Space
sp1, Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
forall a. Maybe a
Nothing)
        else case IxFun (PrimExp VName)
-> IxFun (PrimExp VName)
-> Maybe
     (IxFun (PrimExp (Ext VName)), [(PrimExp VName, PrimExp VName)])
forall v.
Eq v =>
IxFun (PrimExp v)
-> IxFun (PrimExp v)
-> Maybe (IxFun (PrimExp (Ext v)), [(PrimExp v, PrimExp v)])
IxFun.leastGeneralGeneralization ((TPrimExp Int64 VName -> PrimExp VName)
-> IxFun (TPrimExp Int64 VName) -> IxFun (PrimExp VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped IxFun (TPrimExp Int64 VName)
ixf1) ((TPrimExp Int64 VName -> PrimExp VName)
-> IxFun (TPrimExp Int64 VName) -> IxFun (PrimExp VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped IxFun (TPrimExp Int64 VName)
ixf2) of
          Just (IxFun (PrimExp (Ext VName))
ixf, [(PrimExp VName, PrimExp VName)]
m) ->
            ( Space -> Maybe Space
forall a. a -> Maybe a
Just Space
sp1,
              (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
-> Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
forall a. a -> Maybe a
Just
                ( (PrimExp (Ext VName) -> TPrimExp Int64 (Ext VName))
-> IxFun (PrimExp (Ext VName)) -> ExtIxFun
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap PrimExp (Ext VName) -> TPrimExp Int64 (Ext VName)
forall t v. PrimExp v -> TPrimExp t v
TPrimExp IxFun (PrimExp (Ext VName))
ixf,
                  [TPrimExp Int64 VName]
-> [TPrimExp Int64 VName]
-> [(TPrimExp Int64 VName, TPrimExp Int64 VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((PrimExp VName, PrimExp VName) -> TPrimExp Int64 VName)
-> [(PrimExp VName, PrimExp VName)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (PrimExp VName -> TPrimExp Int64 VName
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (PrimExp VName -> TPrimExp Int64 VName)
-> ((PrimExp VName, PrimExp VName) -> PrimExp VName)
-> (PrimExp VName, PrimExp VName)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PrimExp VName, PrimExp VName) -> PrimExp VName
forall a b. (a, b) -> a
fst) [(PrimExp VName, PrimExp VName)]
m) (((PrimExp VName, PrimExp VName) -> TPrimExp Int64 VName)
-> [(PrimExp VName, PrimExp VName)] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map (PrimExp VName -> TPrimExp Int64 VName
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (PrimExp VName -> TPrimExp Int64 VName)
-> ((PrimExp VName, PrimExp VName) -> PrimExp VName)
-> (PrimExp VName, PrimExp VName)
-> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PrimExp VName, PrimExp VName) -> PrimExp VName
forall a b. (a, b) -> b
snd) [(PrimExp VName, PrimExp VName)]
m)
                )
            )
          Maybe
  (IxFun (PrimExp (Ext VName)), [(PrimExp VName, PrimExp VName)])
Nothing -> (Space -> Maybe Space
forall a. a -> Maybe a
Just Space
sp1, Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
forall a. Maybe a
Nothing)
    generalize (Maybe Space
mbsp1, Maybe (IxFun (TPrimExp Int64 VName))
_) (Maybe Space, Maybe (IxFun (TPrimExp Int64 VName)))
_ = (Maybe Space
mbsp1, Maybe (ExtIxFun, [(TPrimExp Int64 VName, TPrimExp Int64 VName)])
forall a. Maybe a
Nothing)

    selectSub ::
      ((a, a) -> a) ->
      Maybe (ExtIxFun, [(a, a)]) ->
      Maybe (ExtIxFun, [a])
    selectSub :: ((a, a) -> a)
-> Maybe (ExtIxFun, [(a, a)]) -> Maybe (ExtIxFun, [a])
selectSub (a, a) -> a
f (Just (ExtIxFun
ixfn, [(a, a)]
m)) = (ExtIxFun, [a]) -> Maybe (ExtIxFun, [a])
forall a. a -> Maybe a
Just (ExtIxFun
ixfn, ((a, a) -> a) -> [(a, a)] -> [a]
forall a b. (a -> b) -> [a] -> [b]
map (a, a) -> a
f [(a, a)]
m)
    selectSub (a, a) -> a
_ Maybe (ExtIxFun, [(a, a)])
Nothing = Maybe (ExtIxFun, [a])
forall a. Maybe a
Nothing
    allocInIfBody ::
      (Allocable fromrep torep inner) =>
      Int ->
      Body fromrep ->
      AllocM fromrep torep (Body torep, [Maybe IxFun])
    allocInIfBody :: Int
-> Body fromrep
-> AllocM
     fromrep torep (Body torep, [Maybe (IxFun (TPrimExp Int64 VName))])
allocInIfBody Int
num_vals (Body BodyDec fromrep
_ Stms fromrep
stms Result
res) =
      AllocM
  fromrep torep (Result, [Maybe (IxFun (TPrimExp Int64 VName))])
-> AllocM
     fromrep torep (Body torep, [Maybe (IxFun (TPrimExp Int64 VName))])
forall (m :: * -> *) a.
MonadBuilder m =>
m (Result, a) -> m (Body (Rep m), a)
buildBody (AllocM
   fromrep torep (Result, [Maybe (IxFun (TPrimExp Int64 VName))])
 -> AllocM
      fromrep torep (Body torep, [Maybe (IxFun (TPrimExp Int64 VName))]))
-> (AllocM
      fromrep torep (Result, [Maybe (IxFun (TPrimExp Int64 VName))])
    -> AllocM
         fromrep torep (Result, [Maybe (IxFun (TPrimExp Int64 VName))]))
-> AllocM
     fromrep torep (Result, [Maybe (IxFun (TPrimExp Int64 VName))])
-> AllocM
     fromrep torep (Body torep, [Maybe (IxFun (TPrimExp Int64 VName))])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms fromrep
-> AllocM
     fromrep torep (Result, [Maybe (IxFun (TPrimExp Int64 VName))])
-> AllocM
     fromrep torep (Result, [Maybe (IxFun (TPrimExp Int64 VName))])
forall fromrep torep inner a.
Allocable fromrep torep inner =>
Stms fromrep -> AllocM fromrep torep a -> AllocM fromrep torep a
allocInStms Stms fromrep
stms (AllocM
   fromrep torep (Result, [Maybe (IxFun (TPrimExp Int64 VName))])
 -> AllocM
      fromrep torep (Body torep, [Maybe (IxFun (TPrimExp Int64 VName))]))
-> AllocM
     fromrep torep (Result, [Maybe (IxFun (TPrimExp Int64 VName))])
-> AllocM
     fromrep torep (Body torep, [Maybe (IxFun (TPrimExp Int64 VName))])
forall a b. (a -> b) -> a -> b
$ do
        let (Result
_, Result
val_res) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_vals Result
res
        [Maybe (IxFun (TPrimExp Int64 VName))]
mem_ixfs <- (SubExpRes
 -> AllocM fromrep torep (Maybe (IxFun (TPrimExp Int64 VName))))
-> Result
-> AllocM fromrep torep [Maybe (IxFun (TPrimExp Int64 VName))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExp
-> AllocM fromrep torep (Maybe (IxFun (TPrimExp Int64 VName)))
forall fromrep torep inner.
Allocable fromrep torep inner =>
SubExp
-> AllocM fromrep torep (Maybe (IxFun (TPrimExp Int64 VName)))
subExpIxFun (SubExp
 -> AllocM fromrep torep (Maybe (IxFun (TPrimExp Int64 VName))))
-> (SubExpRes -> SubExp)
-> SubExpRes
-> AllocM fromrep torep (Maybe (IxFun (TPrimExp Int64 VName)))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
val_res
        (Result, [Maybe (IxFun (TPrimExp Int64 VName))])
-> AllocM
     fromrep torep (Result, [Maybe (IxFun (TPrimExp Int64 VName))])
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result
res, [Maybe (IxFun (TPrimExp Int64 VName))]
mem_ixfs)
-- 'WithAcc': attach memory information to the operator lambda (if any)
-- of every accumulator input, then to the body lambda.  The inputs are
-- processed before the body lambda.
allocInExp (WithAcc inputs bodylam) =
  WithAcc <$> traverse allocInput inputs <*> allocBodyLambda bodylam
  where
    -- The body lambda's parameters get 'MemPrim' (for unit) or 'MemAcc'
    -- (for accumulators) information; any other parameter type is
    -- rejected.
    allocBodyLambda lam = do
      params' <- mapM toMemParam (lambdaParams lam)
      allocInLambda params' (lambdaBody lam)

    toMemParam (Param attrs pv t) =
      case t of
        Prim Unit -> pure $ Param attrs pv $ MemPrim Unit
        Acc acc ispace ts u -> pure $ Param attrs pv $ MemAcc acc ispace ts u
        _ -> error $ "Unexpected WithAcc lambda param: " ++ pretty (Param attrs pv t)

    -- Only the optional operator of an input needs work; the shape and
    -- the array names are passed through untouched.
    allocInput (accshape, arrs, maybe_op) = do
      op' <- traverse (allocOp accshape arrs) maybe_op
      pure (accshape, arrs, op')

    -- The operator lambda's parameters are split into index parameters
    -- (one per dimension of the accumulator shape), then x and y
    -- parameters (one each per value returned by the lambda).  The
    -- index parameters become int64 scalars and are used to slice the
    -- x/y parameters' index functions.
    allocOp accshape arrs (lam, nes) = do
      let (i_params, x_params, y_params) =
            splitAt3 (shapeRank accshape) (length $ lambdaReturnType lam) $
              lambdaParams lam
          i_params' = [Param attrs v $ MemPrim int64 | Param attrs v _ <- i_params]
          is = [DimFix $ Var $ paramName p | p <- i_params']
      x_params' <- zipWithM (xParam is) x_params arrs
      y_params' <- zipWithM (yParam is) y_params arrs
      lam' <- allocInLambda (concat [i_params', x_params', y_params']) (lambdaBody lam)
      pure (lam', nes)

    -- Array parameter located in 'mem', whose index function is 'ixfun'
    -- sliced by the fixed indices 'is' followed by a full slice of every
    -- dimension of the parameter's own shape.
    mkP attrs p pt shape u mem ixfun is =
      let idx_slice = fmap pe64 $ Slice $ is ++ map sliceDim (shapeDims shape)
       in Param attrs p $ MemArray pt shape u $ ArrayIn mem $ IxFun.slice ixfun idx_slice

    -- x parameters reuse the memory and index function of the
    -- corresponding accumulator array.
    xParam _ (Param attrs p (Prim t)) _ =
      pure $ Param attrs p $ MemPrim t
    xParam is (Param attrs p (Array pt shape u)) arr =
      (\(mem, ixfun) -> mkP attrs p pt shape u mem ixfun is) <$> lookupArraySummary arr
    xParam _ p _ =
      error $ "Cannot handle MkAcc param: " ++ pretty p

    -- y parameters get a fresh allocation in 'DefaultSpace' sized like
    -- the corresponding array, with an iota index function over its
    -- dimensions.
    yParam _ (Param attrs p (Prim t)) _ =
      pure $ Param attrs p $ MemPrim t
    yParam is (Param attrs p (Array pt shape u)) arr = do
      arr_t <- lookupType arr
      mem <- allocForArray arr_t DefaultSpace
      pure $ mkP attrs p pt shape u mem (IxFun.iota $ map pe64 $ arrayDims arr_t) is
    yParam _ p _ =
      error $ "Cannot handle MkAcc param: " ++ pretty p
-- Any other expression is rebuilt with a generic traversal.  Ops are
-- delegated to the 'allocInOp' handler found in the 'AllocEnv'.
-- Reaching a body, parameter, return type, or branch type through this
-- traversal is an error.
allocInExp e = mapExpM allocMapper e
  where
    allocMapper =
      identityMapper
        { mapOnOp = \op -> asks allocInOp >>= ($ op),
          mapOnFParam = error "Unhandled FParam in ExplicitAllocations",
          mapOnLParam = error "Unhandled LParam in ExplicitAllocations",
          mapOnRetType = error "Unhandled RetType in ExplicitAllocations",
          mapOnBranchType = error "Unhandled BranchType in ExplicitAllocations",
          mapOnBody = error "Unhandled Body in ExplicitAllocations"
        }

-- | The index function of the named variable, if its memory
-- information is an array placed in memory ('MemArray' with 'ArrayIn');
-- 'Nothing' for any other kind of memory information.
lookupIxFun ::
  (Allocable fromrep torep inner) =>
  VName ->
  AllocM fromrep torep (Maybe IxFun)
lookupIxFun v = arrayIxFun <$> lookupMemInfo v
  where
    arrayIxFun (MemArray _ _ _ (ArrayIn _ ixf)) = Just ixf
    arrayIxFun _ = Nothing

-- | 'lookupIxFun' lifted to subexpressions.  Constants never have an
-- index function.
subExpIxFun ::
  (Allocable fromrep torep inner) =>
  SubExp ->
  AllocM fromrep torep (Maybe IxFun)
subExpIxFun se =
  case se of
    Var v -> lookupIxFun v
    Constant {} -> pure Nothing

-- | Add @k@ to every existential index ('Ext') in the shape of an array
-- 'MemInfo'.  Free dimensions, and non-array memory information, are
-- returned unchanged.
shiftShapeExts :: Int -> MemInfo ExtSize u r -> MemInfo ExtSize u r
shiftShapeExts k info =
  case info of
    MemArray pt shape u returns -> MemArray pt (bump <$> shape) u returns
    _ -> info
  where
    bump (Ext i) = Ext (i + k)
    bump (Free se) = Free se

addResCtxInIfBody ::
  (Allocable fromrep torep inner) =>
  [ExtType] ->
  Body torep ->
  [Maybe Space] ->
  [Maybe (ExtIxFun, [TPrimExp Int64 VName])] ->
  AllocM fromrep torep (Body torep, [BodyReturns])
addResCtxInIfBody :: [TypeBase ExtShape NoUniqueness]
-> Body torep
-> [Maybe Space]
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> AllocM fromrep torep (Body torep, [BranchTypeMem])
addResCtxInIfBody [TypeBase ExtShape NoUniqueness]
ifrets (Body BodyDec torep
_ Stms torep
stms Result
res) [Maybe Space]
spaces [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
substs = AllocM fromrep torep (Result, [BranchTypeMem])
-> AllocM
     fromrep torep (Body (Rep (AllocM fromrep torep)), [BranchTypeMem])
forall (m :: * -> *) a.
MonadBuilder m =>
m (Result, a) -> m (Body (Rep m), a)
buildBody (AllocM fromrep torep (Result, [BranchTypeMem])
 -> AllocM
      fromrep torep (Body (Rep (AllocM fromrep torep)), [BranchTypeMem]))
-> AllocM fromrep torep (Result, [BranchTypeMem])
-> AllocM
     fromrep torep (Body (Rep (AllocM fromrep torep)), [BranchTypeMem])
forall a b. (a -> b) -> a -> b
$ do
  (Stm torep -> AllocM fromrep torep ())
-> Stms torep -> AllocM fromrep torep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm torep -> AllocM fromrep torep ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm Stms torep
stms
  (Result
ctx, [BranchTypeMem]
ctx_rets, Result
res', [BranchTypeMem]
res_rets, Int
total_existentials) <-
    ((Result, [BranchTypeMem], Result, [BranchTypeMem], Int)
 -> (TypeBase ExtShape NoUniqueness, SubExpRes,
     Maybe (ExtIxFun, [TPrimExp Int64 VName]), Maybe Space)
 -> AllocM
      fromrep
      torep
      (Result, [BranchTypeMem], Result, [BranchTypeMem], Int))
-> (Result, [BranchTypeMem], Result, [BranchTypeMem], Int)
-> [(TypeBase ExtShape NoUniqueness, SubExpRes,
     Maybe (ExtIxFun, [TPrimExp Int64 VName]), Maybe Space)]
-> AllocM
     fromrep
     torep
     (Result, [BranchTypeMem], Result, [BranchTypeMem], Int)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Result, [BranchTypeMem], Result, [BranchTypeMem], Int)
-> (TypeBase ExtShape NoUniqueness, SubExpRes,
    Maybe (ExtIxFun, [TPrimExp Int64 VName]), Maybe Space)
-> AllocM
     fromrep
     torep
     (Result, [BranchTypeMem], Result, [BranchTypeMem], Int)
forall torep fromrep a inner u u.
(PrettyRep fromrep, HasLetDecMem (LetDec torep), BuilderOps torep,
 ToExp a, OpReturns inner, SizeSubst inner,
 FParamInfo torep ~ FParamMem, LetDec torep ~ LetDecMem,
 ExpDec torep ~ (), LParamInfo torep ~ LetDecMem,
 FParamInfo fromrep ~ DeclType, BodyDec torep ~ (),
 RetType torep ~ RetTypeMem, LParamInfo fromrep ~ Type,
 BodyDec fromrep ~ (), BranchType torep ~ BranchTypeMem,
 BranchType fromrep ~ TypeBase ExtShape NoUniqueness,
 RetType fromrep ~ DeclExtType, Op torep ~ MemOp inner) =>
(Result, [MemInfo (Ext SubExp) u MemReturn], Result,
 [MemInfo (Ext SubExp) u MemReturn], Int)
-> (TypeBase ExtShape u, SubExpRes, Maybe (ExtIxFun, [a]),
    Maybe Space)
-> AllocM
     fromrep
     torep
     (Result, [MemInfo (Ext SubExp) u MemReturn], Result,
      [MemInfo (Ext SubExp) u MemReturn], Int)
helper ([], [], [], [], Int
0) ([TypeBase ExtShape NoUniqueness]
-> Result
-> [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
-> [Maybe Space]
-> [(TypeBase ExtShape NoUniqueness, SubExpRes,
     Maybe (ExtIxFun, [TPrimExp Int64 VName]), Maybe Space)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [TypeBase ExtShape NoUniqueness]
ifrets Result
res [Maybe (ExtIxFun, [TPrimExp Int64 VName])]
substs [Maybe Space]
spaces)
  (Result, [BranchTypeMem])
-> AllocM fromrep torep (Result, [BranchTypeMem])
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( Result
ctx Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Result
res',
      -- We need to adjust the existentials in shapes corresponding
      -- to the previous type, because we added more existentials in
      -- front.
      [BranchTypeMem]
ctx_rets [BranchTypeMem] -> [BranchTypeMem] -> [BranchTypeMem]
forall a. [a] -> [a] -> [a]
++ (BranchTypeMem -> BranchTypeMem)
-> [BranchTypeMem] -> [BranchTypeMem]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> BranchTypeMem -> BranchTypeMem
forall u r.
Int -> MemInfo (Ext SubExp) u r -> MemInfo (Ext SubExp) u r
shiftShapeExts Int
total_existentials) [BranchTypeMem]
res_rets
    )
  where
    helper :: (Result, [MemInfo (Ext SubExp) u MemReturn], Result,
 [MemInfo (Ext SubExp) u MemReturn], Int)
-> (TypeBase ExtShape u, SubExpRes, Maybe (ExtIxFun, [a]),
    Maybe Space)
-> AllocM
     fromrep
     torep
     (Result, [MemInfo (Ext SubExp) u MemReturn], Result,
      [MemInfo (Ext SubExp) u MemReturn], Int)
helper (Result
ctx_acc, [MemInfo (Ext SubExp) u MemReturn]
ctx_rets_acc, Result
res_acc, [MemInfo (Ext SubExp) u MemReturn]
res_rets_acc, Int
k) (TypeBase ExtShape u
ifr, SubExpRes
r, Maybe (ExtIxFun, [a])
mbixfsub, Maybe Space
sp) =
      case Maybe (ExtIxFun, [a])
mbixfsub of
        Maybe (ExtIxFun, [a])
Nothing -> do
          -- does NOT generalize/antiunify; ensure direct
          SubExpRes
r' <- Maybe Space -> SubExpRes -> AllocM fromrep torep SubExpRes
forall fromrep torep inner.
Allocable fromrep torep inner =>
Maybe Space -> SubExpRes -> AllocM fromrep torep SubExpRes
ensureDirect Maybe Space
sp SubExpRes
r
          (Result
mem_ctx_ses, [MemInfo (Ext SubExp) u MemReturn]
mem_ctx_rets) <- [(SubExpRes, MemInfo (Ext SubExp) u MemReturn)]
-> (Result, [MemInfo (Ext SubExp) u MemReturn])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SubExpRes, MemInfo (Ext SubExp) u MemReturn)]
 -> (Result, [MemInfo (Ext SubExp) u MemReturn]))
-> AllocM
     fromrep torep [(SubExpRes, MemInfo (Ext SubExp) u MemReturn)]
-> AllocM
     fromrep torep (Result, [MemInfo (Ext SubExp) u MemReturn])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExpRes
-> AllocM
     fromrep torep [(SubExpRes, MemInfo (Ext SubExp) u MemReturn)]
forall fromrep torep inner u.
Allocable fromrep torep inner =>
SubExpRes
-> AllocM
     fromrep torep [(SubExpRes, MemInfo (Ext SubExp) u MemReturn)]
bodyReturnMemCtx SubExpRes
r'
          let body_ret :: MemInfo (Ext SubExp) u MemReturn
body_ret = Int
-> TypeBase ExtShape u
-> Maybe Space
-> MemInfo (Ext SubExp) u MemReturn
forall u.
Int
-> TypeBase ExtShape u
-> Maybe Space
-> MemInfo (Ext SubExp) u MemReturn
inspect Int
k TypeBase ExtShape u
ifr Maybe Space
sp
          (Result, [MemInfo (Ext SubExp) u MemReturn], Result,
 [MemInfo (Ext SubExp) u MemReturn], Int)
-> AllocM
     fromrep
     torep
     (Result, [MemInfo (Ext SubExp) u MemReturn], Result,
      [MemInfo (Ext SubExp) u MemReturn], Int)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
            ( Result
ctx_acc Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
mem_ctx_ses,
              [MemInfo (Ext SubExp) u MemReturn]
ctx_rets_acc [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
forall a. [a] -> [a] -> [a]
++ [MemInfo (Ext SubExp) u MemReturn]
mem_ctx_rets,
              Result
res_acc Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ [SubExpRes
r'],
              [MemInfo (Ext SubExp) u MemReturn]
res_rets_acc [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
forall a. [a] -> [a] -> [a]
++ [MemInfo (Ext SubExp) u MemReturn
body_ret],
              Int
k Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
mem_ctx_ses
            )
        Just (ExtIxFun
ixfn, [a]
m) -> do
          -- generalizes
          let i :: Int
i = [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [a]
m
          [SubExp]
ext_ses <- (a -> AllocM fromrep torep SubExp)
-> [a] -> AllocM fromrep torep [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> a -> AllocM fromrep torep SubExp
forall (m :: * -> *) a.
(MonadBuilder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"ixfn_exist") [a]
m
          (Result
mem_ctx_ses, [MemInfo (Ext SubExp) u MemReturn]
mem_ctx_rets) <- [(SubExpRes, MemInfo (Ext SubExp) u MemReturn)]
-> (Result, [MemInfo (Ext SubExp) u MemReturn])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SubExpRes, MemInfo (Ext SubExp) u MemReturn)]
 -> (Result, [MemInfo (Ext SubExp) u MemReturn]))
-> AllocM
     fromrep torep [(SubExpRes, MemInfo (Ext SubExp) u MemReturn)]
-> AllocM
     fromrep torep (Result, [MemInfo (Ext SubExp) u MemReturn])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExpRes
-> AllocM
     fromrep torep [(SubExpRes, MemInfo (Ext SubExp) u MemReturn)]
forall fromrep torep inner u.
Allocable fromrep torep inner =>
SubExpRes
-> AllocM
     fromrep torep [(SubExpRes, MemInfo (Ext SubExp) u MemReturn)]
bodyReturnMemCtx SubExpRes
r
          let sp' :: Space
sp' = Space -> Maybe Space -> Space
forall a. a -> Maybe a -> a
fromMaybe Space
DefaultSpace Maybe Space
sp
              ixfn' :: ExtIxFun
ixfn' = (TPrimExp Int64 (Ext VName) -> TPrimExp Int64 (Ext VName))
-> ExtIxFun -> ExtIxFun
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> TPrimExp Int64 (Ext VName) -> TPrimExp Int64 (Ext VName)
forall t. Int -> TPrimExp t (Ext VName) -> TPrimExp t (Ext VName)
adjustExtPE Int
k) ExtIxFun
ixfn
              exttp :: MemInfo (Ext SubExp) u MemReturn
exttp = case TypeBase ExtShape u
ifr of
                Array PrimType
pt ExtShape
shp' u
u ->
                  PrimType
-> ExtShape -> u -> MemReturn -> MemInfo (Ext SubExp) u MemReturn
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ExtShape
shp' u
u (MemReturn -> MemInfo (Ext SubExp) u MemReturn)
-> MemReturn -> MemInfo (Ext SubExp) u MemReturn
forall a b. (a -> b) -> a -> b
$ Space -> Int -> ExtIxFun -> MemReturn
ReturnsNewBlock Space
sp' (Int
k Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i) ExtIxFun
ixfn'
                TypeBase ExtShape u
_ -> String -> MemInfo (Ext SubExp) u MemReturn
forall a. HasCallStack => String -> a
error String
"Impossible case reached in addResCtxInIfBody"
          (Result, [MemInfo (Ext SubExp) u MemReturn], Result,
 [MemInfo (Ext SubExp) u MemReturn], Int)
-> AllocM
     fromrep
     torep
     (Result, [MemInfo (Ext SubExp) u MemReturn], Result,
      [MemInfo (Ext SubExp) u MemReturn], Int)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
            ( Result
ctx_acc Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ [SubExp] -> Result
subExpsRes [SubExp]
ext_ses Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
mem_ctx_ses,
              [MemInfo (Ext SubExp) u MemReturn]
ctx_rets_acc [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
forall a. [a] -> [a] -> [a]
++ (SubExp -> MemInfo (Ext SubExp) u MemReturn)
-> [SubExp] -> [MemInfo (Ext SubExp) u MemReturn]
forall a b. (a -> b) -> [a] -> [b]
map (MemInfo (Ext SubExp) u MemReturn
-> SubExp -> MemInfo (Ext SubExp) u MemReturn
forall a b. a -> b -> a
const (PrimType -> MemInfo (Ext SubExp) u MemReturn
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
int64)) [SubExp]
ext_ses [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
forall a. [a] -> [a] -> [a]
++ [MemInfo (Ext SubExp) u MemReturn]
mem_ctx_rets,
              Result
res_acc Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ [SubExpRes
r],
              [MemInfo (Ext SubExp) u MemReturn]
res_rets_acc [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
forall a. [a] -> [a] -> [a]
++ [MemInfo (Ext SubExp) u MemReturn
exttp],
              Int
k Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1
            )

    inspect :: Int
-> TypeBase ExtShape u
-> Maybe Space
-> MemInfo (Ext SubExp) u MemReturn
inspect Int
k (Array PrimType
pt ExtShape
shape u
u) Maybe Space
space =
      let space' :: Space
space' = Space -> Maybe Space -> Space
forall a. a -> Maybe a -> a
fromMaybe Space
DefaultSpace Maybe Space
space
          bodyret :: MemInfo (Ext SubExp) u MemReturn
bodyret =
            PrimType
-> ExtShape -> u -> MemReturn -> MemInfo (Ext SubExp) u MemReturn
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ExtShape
shape u
u (MemReturn -> MemInfo (Ext SubExp) u MemReturn)
-> MemReturn -> MemInfo (Ext SubExp) u MemReturn
forall a b. (a -> b) -> a -> b
$
              Space -> Int -> ExtIxFun -> MemReturn
ReturnsNewBlock Space
space' Int
k (ExtIxFun -> MemReturn) -> ExtIxFun -> MemReturn
forall a b. (a -> b) -> a -> b
$
                [TPrimExp Int64 (Ext VName)] -> ExtIxFun
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota ([TPrimExp Int64 (Ext VName)] -> ExtIxFun)
-> [TPrimExp Int64 (Ext VName)] -> ExtIxFun
forall a b. (a -> b) -> a -> b
$ (Ext SubExp -> TPrimExp Int64 (Ext VName))
-> [Ext SubExp] -> [TPrimExp Int64 (Ext VName)]
forall a b. (a -> b) -> [a] -> [b]
map Ext SubExp -> TPrimExp Int64 (Ext VName)
convert ([Ext SubExp] -> [TPrimExp Int64 (Ext VName)])
-> [Ext SubExp] -> [TPrimExp Int64 (Ext VName)]
forall a b. (a -> b) -> a -> b
$ ExtShape -> [Ext SubExp]
forall d. ShapeBase d -> [d]
shapeDims ExtShape
shape
       in MemInfo (Ext SubExp) u MemReturn
bodyret
    inspect Int
_ (Acc VName
acc Shape
ispace [Type]
ts u
u) Maybe Space
_ = VName -> Shape -> [Type] -> u -> MemInfo (Ext SubExp) u MemReturn
forall d u ret. VName -> Shape -> [Type] -> u -> MemInfo d u ret
MemAcc VName
acc Shape
ispace [Type]
ts u
u
    inspect Int
_ (Prim PrimType
pt) Maybe Space
_ = PrimType -> MemInfo (Ext SubExp) u MemReturn
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
pt
    inspect Int
_ (Mem Space
space) Maybe Space
_ = Space -> MemInfo (Ext SubExp) u MemReturn
forall d u ret. Space -> MemInfo d u ret
MemMem Space
space

    convert :: Ext SubExp -> TPrimExp Int64 (Ext VName)
convert (Ext Int
i) = Ext VName -> TPrimExp Int64 (Ext VName)
forall a. a -> TPrimExp Int64 a
le64 (Int -> Ext VName
forall a. Int -> Ext a
Ext Int
i)
    convert (Free SubExp
v) = VName -> Ext VName
forall a. a -> Ext a
Free (VName -> Ext VName)
-> TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> TPrimExp Int64 VName
pe64 SubExp
v

    adjustExtV :: Int -> Ext VName -> Ext VName
    adjustExtV :: Int -> Ext VName -> Ext VName
adjustExtV Int
_ (Free VName
v) = VName -> Ext VName
forall a. a -> Ext a
Free VName
v
    adjustExtV Int
k (Ext Int
i) = Int -> Ext VName
forall a. Int -> Ext a
Ext (Int
k Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i)

    adjustExtPE :: Int -> TPrimExp t (Ext VName) -> TPrimExp t (Ext VName)
    adjustExtPE :: Int -> TPrimExp t (Ext VName) -> TPrimExp t (Ext VName)
adjustExtPE Int
k = (Ext VName -> Ext VName)
-> TPrimExp t (Ext VName) -> TPrimExp t (Ext VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> Ext VName -> Ext VName
adjustExtV Int
k)

-- | For each of the last @num_vals@ results of a body, find the memory
-- space of the block that the result is stored in.
--
-- The lookups are performed in a scope extended with the body's own
-- statements.  A result yields @Just space@ only when it is a variable
-- with 'MemArray' info whose backing memory block has 'MemMem' info;
-- constants and every other case yield 'Nothing'.
mkSpaceOks ::
  (Mem torep inner, LocalScope torep m) =>
  Int ->
  Body torep ->
  m [Maybe Space]
mkSpaceOks num_vals (Body _ stms res) =
  inScopeOf stms . mapM (spaceOfResult . resSubExp) $ takeLast num_vals res
  where
    -- Constants have no memory block behind them.
    spaceOfResult (Constant _) = pure Nothing
    spaceOfResult (Var v) = do
      info <- lookupMemInfo v
      case info of
        MemArray _ _ _ (ArrayIn mem _) -> spaceOfMem <$> lookupMemInfo mem
        _ -> pure Nothing

    -- Project the space out of a memory block's info, if it is one.
    spaceOfMem (MemMem space) = Just space
    spaceOfMem _ = Nothing

-- | Translate a loop form into the target representation.
--
-- While-loops carry no parameters and are passed through unchanged.
-- For for-loops, each loop-variable parameter receives memory
-- information derived from the array it iterates over:
--
-- * An array-typed element reuses the array's memory block.  Its index
--   function is the array's index function sliced with the loop index
--   @i@ fixed in the outermost dimension.
-- * Prim, Mem and Acc parameters map directly onto the corresponding
--   'MemInfo' constructor.
allocInLoopForm ::
  (Allocable fromrep torep inner) =>
  LoopForm fromrep ->
  AllocM fromrep torep (LoopForm torep)
allocInLoopForm (WhileLoop v) = pure $ WhileLoop v
allocInLoopForm (ForLoop i it n loopvars) =
  ForLoop i it n <$> traverse allocInLoopVar loopvars
  where
    allocInLoopVar (p, arr) = do
      (mem, ixfun) <- lookupArraySummary arr
      dec <- case paramType p of
        Array pt shape u -> do
          arr_t <- lookupType arr
          let dims = map pe64 $ arrayDims arr_t
              -- Select row @i@ of the iterated array.
              row_ixfun = IxFun.slice ixfun $ fullSliceNum dims [DimFix $ le64 i]
          pure . MemArray pt shape u $ ArrayIn mem row_ixfun
        Prim bt -> pure $ MemPrim bt
        Mem space -> pure $ MemMem space
        Acc acc ispace ts u -> pure $ MemAcc acc ispace ts u
      pure (p {paramDec = dec}, arr)

-- | Operations that can report size substitutions for the pattern they
-- are bound to, and that can declare themselves constant.
class SizeSubst op where
  -- | The 'ChunkMap' contributed by binding this operation to the
  -- given pattern.
  opSizeSubst :: PatT dec -> op -> ChunkMap

  -- | Whether the operation is constant.  Defaults to 'False'.
  opIsConst :: op -> Bool
  opIsConst _ = False

-- | The unit operation contributes no size substitutions.
instance SizeSubst () where
  opSizeSubst _ = const mempty

-- | Memory operations delegate to the wrapped inner operation.  Any
-- other memory operation contributes no size substitutions and is
-- never considered constant.
instance SizeSubst op => SizeSubst (MemOp op) where
  opSizeSubst pat memop =
    case memop of
      Inner op -> opSizeSubst pat op
      _ -> mempty

  opIsConst memop =
    case memop of
      Inner op -> opIsConst op
      _ -> False

-- | The size substitutions introduced by a statement.  Only statements
-- whose expression is an 'Op' can contribute any.
sizeSubst :: SizeSubst (Op rep) => Stm rep -> ChunkMap
sizeSubst (Let pat _ e) =
  case e of
    Op op -> opSizeSubst pat op
    _ -> mempty

-- | The names bound by a statement when its expression is an 'Op' that
-- reports itself constant (via 'opIsConst'), and the empty set otherwise.
stmConsts :: SizeSubst (Op rep) => Stm rep -> S.Set VName
stmConsts (Let pat _ e) =
  case e of
    Op op | opIsConst op -> S.fromList (patNames pat)
    _ -> S.empty

-- | Build a statement binding @names@ to @e@.
--
-- The pattern is computed with 'patWithAllocations' in 'DefaultSpace',
-- with an empty 'ChunkMap' and a 'NoHint' for every name.  The
-- statement is returned, not added to the builder.
-- NOTE(review): patWithAllocations runs in the builder monad and may
-- itself emit allocation statements; confirm against its definition.
mkLetNamesB' ::
  ( LetDec (Rep m) ~ LetDecMem,
    Mem (Rep m) inner,
    MonadBuilder m,
    ExpDec (Rep m) ~ ()
  ) =>
  ExpDec (Rep m) ->
  [VName] ->
  Exp (Rep m) ->
  m (Stm (Rep m))
mkLetNamesB' dec names e =
  bindTo <$> patWithAllocations DefaultSpace mempty names e (NoHint <$ names)
  where
    bindTo pat = Let pat (defAux dec) e

-- | Like 'mkLetNamesB'', but for the simplifier's 'Engine.Wise'
-- representation.
--
-- The pattern produced by 'patWithAllocations' (in 'DefaultSpace', with
-- an empty 'ChunkMap' and no hints) is annotated with wisdom via
-- 'Engine.addWisdomToPat'.  The expression decoration is built by
-- 'Engine.mkWiseExpDec' from that annotated pattern.
mkLetNamesB'' ::
  ( BuilderOps rep,
    Mem rep inner,
    LetDec rep ~ LetDecMem,
    OpReturns (Engine.OpWithWisdom inner),
    ExpDec rep ~ (),
    Rep m ~ Engine.Wise rep,
    HasScope (Engine.Wise rep) m,
    MonadBuilder m,
    Engine.CanBeWise inner
  ) =>
  [VName] ->
  Exp (Engine.Wise rep) ->
  m (Stm (Engine.Wise rep))
mkLetNamesB'' names e = do
  plain_pat <- patWithAllocations DefaultSpace mempty names e (NoHint <$ names)
  let wise_pat = Engine.addWisdomToPat plain_pat e
      wise_dec = Engine.mkWiseExpDec wise_pat () e
  pure $ Let wise_pat (defAux wise_dec) e

-- | Simplifier operations for a memory representation.
--
-- Arguments are the usage function and the simplification function for
-- the inner (non-memory) operation.  Memory-level behavior:
--
-- * 'Alloc' is protected by making the allocated size conditional on
--   the guard: the guard's size when it holds, and @0@ otherwise.
-- * The usage of an 'Alloc' is a size usage of its size when that size
--   is a variable, and empty for constants.
-- * Simplifying an 'Alloc' simplifies its size, keeps the space, and
--   hoists nothing.
simplifiable ::
  ( Engine.SimplifiableRep rep,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    Mem rep inner
  ) =>
  (Engine.OpWithWisdom inner -> UT.UsageTable) ->
  (Engine.OpWithWisdom inner -> Engine.SimpleM rep (Engine.OpWithWisdom inner, Stms (Engine.Wise rep))) ->
  SimpleOps rep
simplifiable innerUsage simplifyInnerOp =
  SimpleOps mkExpDecS' mkBodyS' protectOp opUsage simplifyOp
  where
    -- The symbol table is not needed to build decorations.
    mkExpDecS' _ pat e = pure $ Engine.mkWiseExpDec pat () e

    mkBodyS' _ stms res = pure $ mkWiseBody () stms res

    protectOp taken pat op =
      case op of
        Alloc size space -> Just $ do
          tbody <- resultBodyM [size]
          fbody <- resultBodyM [intConst Int64 0]
          -- Allocate nothing when the guard does not hold.
          let guarded_size = If taken tbody fbody $ IfDec [MemPrim int64] IfFallback
          size' <- letSubExp "hoisted_alloc_size" guarded_size
          letBind pat . Op $ Alloc size' space
        _ -> Nothing

    opUsage op =
      case op of
        Alloc (Var size) _ -> UT.sizeUsage size
        Alloc _ _ -> mempty
        Inner inner -> innerUsage inner

    simplifyOp (Alloc size space) = do
      size' <- Engine.simplify size
      pure (Alloc size' space, mempty)
    simplifyOp (Inner k) = do
      (k', hoisted) <- simplifyInnerOp k
      pure (Inner k', hoisted)

-- | A per-result hint for how an expression's result should be
-- allocated: either no hint, or an index function together with a
-- memory space.
data ExpHint = NoHint | Hint IxFun Space

-- | Produce one 'NoHint' for each result of the expression, as counted
-- by 'expExtTypeSize'.
defaultExpHints :: (Monad m, ASTRep rep) => Exp rep -> m [ExpHint]
defaultExpHints e = pure [NoHint | _ <- [1 .. expExtTypeSize e]]