{-# LANGUAGE GeneralizedNewtypeDeriving, TypeFamilies, FlexibleContexts, TupleSections, FlexibleInstances, MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE DefaultSignatures #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Futhark.Pass.ExplicitAllocations
       ( explicitAllocations
       , explicitAllocationsInStms
       , simplifiable

       , arraySizeInBytesExp
       )
where

import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Reader
import Control.Monad.RWS.Strict
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.List (zip4, partition, sort)

import Futhark.Representation.Kernels
import Futhark.Optimise.Simplify.Lore
  (mkWiseBody, mkWiseLetStm, removeExpWisdom, removeScopeWisdom)
import Futhark.MonadFreshNames
import Futhark.Representation.ExplicitMemory
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Tools
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Pass
import Futhark.Util (splitFromEnd, takeLast)

-- | A statement that must be emitted to support allocation: a size
-- computation, a memory allocation, or an array copy.  See
-- 'bindAllocStm' for how each is turned into a statement.
data AllocStm
  = SizeComputation VName (PrimExp VName)
    -- ^ Bind the name to the given size expression (coerced to 'Int64').
  | Allocation VName SubExp Space
    -- ^ Bind the name to a memory block of the given size in the
    -- given space.
  | ArrayCopy VName VName
    -- ^ Bind the first name to a copy of the array named by the second.
  deriving (Eq, Ord, Show)

-- | Emit the statement described by an 'AllocStm' in any binder monad
-- whose lore uses 'MemOp' as its operation type.
bindAllocStm :: (MonadBinder m, Op (Lore m) ~ MemOp inner) =>
                AllocStm -> m ()
bindAllocStm (SizeComputation name pe) =
  -- Size expressions are always bound at 'Int64'.
  letBindNames_ [name] =<< toExp (coerceIntPrimExp Int64 pe)
bindAllocStm (Allocation name size space) =
  letBindNames_ [name] $ Op $ Alloc size space
bindAllocStm (ArrayCopy name src) =
  letBindNames_ [name] $ BasicOp $ Copy src

-- | The hints used when nothing more specific is known: one 'NoHint'
-- per element of the expression's extended type, as counted by
-- 'expExtTypeSize'.
defaultExpHints :: (Monad m, Attributes lore) => Exp lore -> m [ExpHint]
defaultExpHints e = return $ replicate (expExtTypeSize e) NoHint

-- | Monads in which allocation statements can be emitted while
-- constructing code in @lore@.
class (MonadFreshNames m, HasScope lore m, ExplicitMemorish lore) =>
      Allocator lore m where
  -- | Emit the given allocation statement.
  addAllocStm :: AllocStm -> m ()

  -- | The memory space that allocations should use by default.
  askDefaultSpace :: m Space

  -- In 'AllocM' the statement is simply bound in the binder.  This
  -- delegates to 'bindAllocStm' instead of repeating its clauses, so
  -- the two implementations cannot drift apart.
  default addAllocStm :: (Allocable fromlore lore,
                          Op lore ~ MemOp inner,
                          m ~ AllocM fromlore lore)
                      => AllocStm -> m ()
  addAllocStm = bindAllocStm

  -- | The subexpression giving the number of elements we should
  -- allocate space for.  See 'ChunkMap' comment.
  dimAllocationSize :: SubExp -> m SubExp

  default dimAllocationSize :: m ~ AllocM fromlore lore
                               => SubExp -> m SubExp
  dimAllocationSize (Var v) =
    -- It is important to recurse here, as the substitution may itself
    -- be a chunk size.
    maybe (return $ Var v) dimAllocationSize =<< asks (M.lookup v . chunkMap)
  dimAllocationSize size =
    return size

  -- | Allocation hints for the results of an expression.
  expHints :: Exp lore -> m [ExpHint]
  expHints = defaultExpHints

-- | Allocate a memory block of the given size in the given space.
-- Returns the fresh name (derived from the description) that the
-- block is bound to.
allocateMemory :: Allocator lore m =>
                  String -> SubExp -> Space -> m VName
allocateMemory desc size space = do
  v <- newVName desc
  addAllocStm $ Allocation v size space
  return v

-- | Bind a size expression to a fresh name (derived from the
-- description) and return that name as a 'SubExp'.
computeSize :: Allocator lore m =>
               String -> PrimExp VName -> m SubExp
computeSize desc se = do
  v <- newVName desc
  addAllocStm $ SizeComputation v se
  return $ Var v

-- | Constraints that the allocation machinery in this module requires
-- of the source lore @fromlore@ and the target lore @tolore@.
type Allocable fromlore tolore =
  ( -- Requirements on the source lore.
    PrettyLore fromlore
  , SameScope fromlore Kernels
  , RetType fromlore ~ RetType Kernels
  , BranchType fromlore ~ BranchType Kernels
  , BodyAttr fromlore ~ ()
    -- Requirements on the target lore.
  , PrettyLore tolore
  , ExplicitMemorish tolore
  , BodyAttr tolore ~ ()
  , ExpAttr tolore ~ ()
  , SizeSubst (Op tolore)
  , BinderOps tolore
  )

-- | A mapping from chunk names to their maximum size.
--
-- XXX FIXME HACK: this is part of a hack to add loop-invariant
-- allocations to reduce kernels, because memory expansion does not
-- use range analysis yet (it should).  The default
-- 'dimAllocationSize' looks names up here (via 'chunkMap') to find
-- the size to allocate for.
type ChunkMap = M.Map VName SubExp

-- | The read-only environment of 'AllocM'.
data AllocEnv fromlore tolore  =
  AllocEnv { chunkMap :: ChunkMap
             -- ^ Maximum sizes of chunk names; see 'ChunkMap'.
           , aggressiveReuse :: Bool
             -- ^ Aggressively try to reuse memory in do-loops -
             -- should be True inside kernels, False outside.
           , allocSpace :: Space
             -- ^ When allocating memory, put it in this memory space.
             -- This is primarily used to ensure that group-wide
             -- statements store their results in local memory.
           , allocInOp :: Op fromlore -> AllocM fromlore tolore (Op tolore)
             -- ^ Translates an operation of the source lore into an
             -- operation of the target lore.
           , envExpHints :: Exp tolore -> AllocM fromlore tolore [ExpHint]
             -- ^ Computes allocation hints for an expression.
           }

-- | Extend the environment's 'ChunkMap' with new entries.  Because
-- 'Data.Map' union is left-biased, the new entries take precedence
-- over existing ones for the same name.
boundDims :: ChunkMap -> AllocEnv fromlore tolore
          -> AllocEnv fromlore tolore
boundDims m env = env { chunkMap = m <> chunkMap env }

-- | Monad for adding allocations to an entire program.  It is a binder
-- for the target lore, with read-only access to an 'AllocEnv' and
-- a source of fresh names.
newtype AllocM fromlore tolore a =
  AllocM (BinderT tolore (ReaderT (AllocEnv fromlore tolore) (State VNameSource)) a)
  deriving (Applicative, Functor, Monad,
            MonadFreshNames,
            HasScope tolore,
            LocalScope tolore,
            MonadReader (AllocEnv fromlore tolore))

-- | 'AllocM' is a binder for the target lore.  Statements are
-- collected by the underlying 'BinderT'.
instance (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
         MonadBinder (AllocM fromlore tolore) where
  type Lore (AllocM fromlore tolore) = tolore

  mkExpAttrM _ _ = return ()

  -- Patterns are built by 'patternWithAllocations' rather than
  -- directly from the names.
  mkLetNamesM names e = do
    pat <- patternWithAllocations names e
    return $ Let pat (defAux ()) e

  mkBodyM bnds res = return $ Body () bnds res

  addStms binding = AllocM $ addBinderStms binding
  collectStms (AllocM m) = AllocM $ collectBinderStms m
  certifying cs (AllocM m) = AllocM $ certifyingBinder cs m

-- | Allocating into 'ExplicitMemory' reads both the expression hints
-- and the default memory space from the 'AllocEnv'.
instance Allocable fromlore ExplicitMemory =>
         Allocator ExplicitMemory (AllocM fromlore ExplicitMemory) where
  expHints e = do
    f <- asks envExpHints
    f e
  askDefaultSpace = asks allocSpace

-- | Run an 'AllocM' computation in an empty scope.
--
-- The environment starts with an empty 'ChunkMap', no aggressive
-- reuse, and 'DefaultSpace'.  The given functions translate operations
-- and compute expression hints.  Any statements the computation leaves
-- unbound at the top level are discarded; only its result is returned.
runAllocM :: MonadFreshNames m =>
             (Op fromlore -> AllocM fromlore tolore (Op tolore))
          -> (Exp tolore -> AllocM fromlore tolore [ExpHint])
          -> AllocM fromlore tolore a -> m a
runAllocM handleOp hints (AllocM m) =
  fmap fst $ modifyNameSource $ runState $ runReaderT (runBinderT m mempty) env
  where env = AllocEnv { chunkMap = mempty
                       , aggressiveReuse = False
                       , allocSpace = DefaultSpace
                       , allocInOp = handleOp
                       , envExpHints = hints
                       }

-- | Monad for adding allocations to a single pattern.
newtype PatAllocM lore a = PatAllocM (RWS
                                      (Scope lore)
                                      [AllocStm]
                                      VNameSource
                                      a)
                    deriving (Functor (PatAllocM lore)
a -> PatAllocM lore a
Functor (PatAllocM lore)
-> (forall a. a -> PatAllocM lore a)
-> (forall a b.
    PatAllocM lore (a -> b) -> PatAllocM lore a -> PatAllocM lore b)
-> (forall a b c.
    (a -> b -> c)
    -> PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore c)
-> (forall a b.
    PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b)
-> (forall a b.
    PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore a)
-> Applicative (PatAllocM lore)
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore a
PatAllocM lore (a -> b) -> PatAllocM lore a -> PatAllocM lore b
(a -> b -> c)
-> PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore c
forall lore. Functor (PatAllocM lore)
forall a. a -> PatAllocM lore a
forall lore a. a -> PatAllocM lore a
forall a b.
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore a
forall a b.
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b
forall a b.
PatAllocM lore (a -> b) -> PatAllocM lore a -> PatAllocM lore b
forall lore a b.
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore a
forall lore a b.
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b
forall lore a b.
PatAllocM lore (a -> b) -> PatAllocM lore a -> PatAllocM lore b
forall a b c.
(a -> b -> c)
-> PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore c
forall lore a b c.
(a -> b -> c)
-> PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore c
forall (f :: * -> *).
Functor f
-> (forall a. a -> f a)
-> (forall a b. f (a -> b) -> f a -> f b)
-> (forall a b c. (a -> b -> c) -> f a -> f b -> f c)
-> (forall a b. f a -> f b -> f b)
-> (forall a b. f a -> f b -> f a)
-> Applicative f
<* :: PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore a
$c<* :: forall lore a b.
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore a
*> :: PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b
$c*> :: forall lore a b.
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b
liftA2 :: (a -> b -> c)
-> PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore c
$cliftA2 :: forall lore a b c.
(a -> b -> c)
-> PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore c
<*> :: PatAllocM lore (a -> b) -> PatAllocM lore a -> PatAllocM lore b
$c<*> :: forall lore a b.
PatAllocM lore (a -> b) -> PatAllocM lore a -> PatAllocM lore b
pure :: a -> PatAllocM lore a
$cpure :: forall lore a. a -> PatAllocM lore a
$cp1Applicative :: forall lore. Functor (PatAllocM lore)
Applicative, a -> PatAllocM lore b -> PatAllocM lore a
(a -> b) -> PatAllocM lore a -> PatAllocM lore b
(forall a b. (a -> b) -> PatAllocM lore a -> PatAllocM lore b)
-> (forall a b. a -> PatAllocM lore b -> PatAllocM lore a)
-> Functor (PatAllocM lore)
forall a b. a -> PatAllocM lore b -> PatAllocM lore a
forall a b. (a -> b) -> PatAllocM lore a -> PatAllocM lore b
forall lore a b. a -> PatAllocM lore b -> PatAllocM lore a
forall lore a b. (a -> b) -> PatAllocM lore a -> PatAllocM lore b
forall (f :: * -> *).
(forall a b. (a -> b) -> f a -> f b)
-> (forall a b. a -> f b -> f a) -> Functor f
<$ :: a -> PatAllocM lore b -> PatAllocM lore a
$c<$ :: forall lore a b. a -> PatAllocM lore b -> PatAllocM lore a
fmap :: (a -> b) -> PatAllocM lore a -> PatAllocM lore b
$cfmap :: forall lore a b. (a -> b) -> PatAllocM lore a -> PatAllocM lore b
Functor, Applicative (PatAllocM lore)
a -> PatAllocM lore a
Applicative (PatAllocM lore)
-> (forall a b.
    PatAllocM lore a -> (a -> PatAllocM lore b) -> PatAllocM lore b)
-> (forall a b.
    PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b)
-> (forall a. a -> PatAllocM lore a)
-> Monad (PatAllocM lore)
PatAllocM lore a -> (a -> PatAllocM lore b) -> PatAllocM lore b
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b
forall lore. Applicative (PatAllocM lore)
forall a. a -> PatAllocM lore a
forall lore a. a -> PatAllocM lore a
forall a b.
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b
forall a b.
PatAllocM lore a -> (a -> PatAllocM lore b) -> PatAllocM lore b
forall lore a b.
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b
forall lore a b.
PatAllocM lore a -> (a -> PatAllocM lore b) -> PatAllocM lore b
forall (m :: * -> *).
Applicative m
-> (forall a b. m a -> (a -> m b) -> m b)
-> (forall a b. m a -> m b -> m b)
-> (forall a. a -> m a)
-> Monad m
return :: a -> PatAllocM lore a
$creturn :: forall lore a. a -> PatAllocM lore a
>> :: PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b
$c>> :: forall lore a b.
PatAllocM lore a -> PatAllocM lore b -> PatAllocM lore b
>>= :: PatAllocM lore a -> (a -> PatAllocM lore b) -> PatAllocM lore b
$c>>= :: forall lore a b.
PatAllocM lore a -> (a -> PatAllocM lore b) -> PatAllocM lore b
$cp1Monad :: forall lore. Applicative (PatAllocM lore)
Monad,
                              HasScope lore,
                              MonadWriter [AllocStm],
                              Monad (PatAllocM lore)
Applicative (PatAllocM lore)
PatAllocM lore VNameSource
Applicative (PatAllocM lore)
-> Monad (PatAllocM lore)
-> PatAllocM lore VNameSource
-> (VNameSource -> PatAllocM lore ())
-> MonadFreshNames (PatAllocM lore)
VNameSource -> PatAllocM lore ()
forall lore. Monad (PatAllocM lore)
forall lore. Applicative (PatAllocM lore)
forall lore. PatAllocM lore VNameSource
forall lore. VNameSource -> PatAllocM lore ()
forall (m :: * -> *).
Applicative m
-> Monad m
-> m VNameSource
-> (VNameSource -> m ())
-> MonadFreshNames m
putNameSource :: VNameSource -> PatAllocM lore ()
$cputNameSource :: forall lore. VNameSource -> PatAllocM lore ()
getNameSource :: PatAllocM lore VNameSource
$cgetNameSource :: forall lore. PatAllocM lore VNameSource
$cp2MonadFreshNames :: forall lore. Monad (PatAllocM lore)
$cp1MonadFreshNames :: forall lore. Applicative (PatAllocM lore)
MonadFreshNames)

-- | Allocation inside pattern construction.  Allocation statements are
-- simply appended to the writer log, dimensions are used as-is for
-- allocation sizes, and the default memory space is 'DefaultSpace'.
instance Allocator ExplicitMemory (PatAllocM ExplicitMemory) where
  -- Record a single allocation statement in the writer log.
  addAllocStm stm = tell [stm]
  -- No rounding or padding of dimensions: the size is the dimension.
  dimAllocationSize = pure
  askDefaultSpace = pure DefaultSpace

-- | Run a 'PatAllocM' computation with the given scope as its reader
-- environment.  The name source is threaded through the surrounding
-- 'MonadFreshNames' monad, and the result is returned together with all
-- allocation statements that were emitted.
runPatAllocM :: MonadFreshNames m =>
                PatAllocM lore a -> Scope lore
             -> m (a, [AllocStm])
runPatAllocM (PatAllocM m) mems = modifyNameSource step
  where
    -- Reorder runRWS's (result, state, log) triple into the
    -- ((result, log), state) shape that modifyNameSource expects.
    step src = case runRWS m mems src of
                 (x, src', allocs) -> ((x, allocs), src')

-- | The size in bytes of an array of the given type, as a 64-bit integer
-- expression: the product of its dimensions times the byte size of its
-- element type.
--
-- Each dimension is sign-extended from 'Int32' to 'Int64' /before/ the
-- product is taken.  Multiplying the dimensions in 32 bits first (and
-- widening only the result) would overflow for arrays with more than
-- 2^31-1 elements.  This matches the computation in
-- 'arraySizeInBytesExpM'.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  product
    [ product $ map (toInt64 . primExpFromSubExp int32) (arrayDims t)
    , ValueExp $ IntValue $ Int64Value $ primByteSize $ elemType t ]
  where toInt64 = ConvOpExp $ SExt Int32 Int64

-- | Like 'arraySizeInBytesExp', but each dimension is first passed
-- through 'dimAllocationSize' of the current allocator.  Every dimension
-- is widened to 64 bits before the product is formed, and the product is
-- then multiplied by the byte size of the element type.
arraySizeInBytesExpM :: Allocator lore m => Type -> m (PrimExp VName)
arraySizeInBytesExpM t = do
  alloc_dims <- mapM dimAllocationSize $ arrayDims t
  let widen = ConvOpExp (SExt Int32 Int64)
      dims_i64 = map (widen . primExpFromSubExp int32) alloc_dims
      elem_bytes = ValueExp $ IntValue $ Int64Value $ primByteSize $ elemType t
  return $ product [product dims_i64, elem_bytes]

-- | Emit the computation of the byte size of an array of the given type
-- (see 'arraySizeInBytesExpM') and return it as a 'SubExp' bound to a
-- name derived from @"bytes"@.
arraySizeInBytes :: Allocator lore m => Type -> m SubExp
arraySizeInBytes t = arraySizeInBytesExpM t >>= computeSize "bytes"

-- | Allocate a memory block in the given space that is large enough to
-- hold an array of the given type, and return the memory block's name.
allocForArray :: Allocator lore m =>
                 Type -> Space -> m VName
allocForArray t space =
  arraySizeInBytes t >>= \bytes -> allocateMemory "mem" bytes space

-- | Build a memory-annotated 'Let' statement binding the given context
-- (size) and value identifiers to the expression.
--
-- The pattern elements come from 'allocsForPattern', which is driven by
-- the expression's 'expReturns' and 'expHints'.  The allocation
-- statements it produces are returned alongside the new statement,
-- rather than being emitted here.
allocsForStm :: (Allocator lore m, ExpAttr lore ~ ()) =>
                [Ident] -> [Ident] -> Exp lore
             -> m (Stm lore, [AllocStm])
allocsForStm sizeidents validents e = do
  rets <- expReturns e
  hs <- expHints e
  (ctx, vals, post) <- allocsForPattern sizeidents validents rets hs
  let pat = Pattern ctx vals
      stm = Let pat (defAux ()) e
  pure (stm, post)

-- | Construct a memory-annotated pattern binding the given names to the
-- result of the expression.
--
-- Existential result sizes are instantiated with fresh size identifiers.
-- The pattern is then built via 'allocsForStm'.  It is an error (via
-- 'error') if constructing the pattern would require extra allocation
-- statements, because this function has no way of emitting them.
patternWithAllocations :: (Allocator lore m, ExpAttr lore ~ ()) =>
                          [VName]
                       -> Exp lore
                       -> m (Pattern lore)
patternWithAllocations names e = do
  (ts', sizes) <- instantiateShapes' =<< expExtType e
  -- Pairing names with types is pure; no monadic sequencing is needed.
  let vals = zipWith Ident names ts'
  (Let pat _ _, extrabnds) <- allocsForStm sizes vals e
  case extrabnds of
    [] -> return pat
    _  -> error $ "Cannot make allocations for pattern of " ++ pretty e

allocsForPattern :: Allocator lore m =>
                    [Ident] -> [Ident] -> [ExpReturns] -> [ExpHint]
                 -> m ([PatElem ExplicitMemory],
                       [PatElem ExplicitMemory],
                       [AllocStm])
allocsForPattern :: [Ident]
-> [Ident]
-> [ExpReturns]
-> [ExpHint]
-> m ([PatElem ExplicitMemory], [PatElem ExplicitMemory],
      [AllocStm])
allocsForPattern [Ident]
sizeidents [Ident]
validents [ExpReturns]
rts [ExpHint]
hints = do
  let sizes' :: [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
sizes' = [ VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall attr. VName -> attr -> PatElemT attr
PatElem VName
size (MemInfo SubExp NoUniqueness MemBind
 -> PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$ PrimType -> MemInfo SubExp NoUniqueness MemBind
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
int32 | VName
size <- (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
sizeidents ]
  ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
vals, ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
exts, [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
mems, [AllocStm]
postbnds)) <-
    WriterT
  ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
   [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
  m
  [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> m ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm]))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
   ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
    [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
   m
   [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
 -> m ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
        [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])))
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> m ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm]))
forall a b. (a -> b) -> a -> b
$ [(Ident, ExpReturns, ExpHint)]
-> ((Ident, ExpReturns, ExpHint)
    -> WriterT
         ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
          [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
         m
         (PatElemT (MemInfo SubExp NoUniqueness MemBind)))
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Ident]
-> [ExpReturns] -> [ExpHint] -> [(Ident, ExpReturns, ExpHint)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Ident]
validents [ExpReturns]
rts [ExpHint]
hints) (((Ident, ExpReturns, ExpHint)
  -> WriterT
       ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
        [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
       m
       (PatElemT (MemInfo SubExp NoUniqueness MemBind)))
 -> WriterT
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
      m
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)])
-> ((Ident, ExpReturns, ExpHint)
    -> WriterT
         ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
          [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
         m
         (PatElemT (MemInfo SubExp NoUniqueness MemBind)))
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ \(Ident
ident, ExpReturns
rt, ExpHint
hint) -> do
      let shape :: Shape
shape = Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape (Type -> Shape) -> Type -> Shape
forall a b. (a -> b) -> a -> b
$ Ident -> Type
identType Ident
ident
      case ExpReturns
rt of
        MemPrim PrimType
_ -> do
          MemInfo SubExp NoUniqueness MemBind
summary <- m (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (MemInfo SubExp NoUniqueness MemBind)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (MemInfo SubExp NoUniqueness MemBind)
 -> WriterT
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
      m
      (MemInfo SubExp NoUniqueness MemBind))
-> m (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$ Type -> ExpHint -> m (MemInfo SubExp NoUniqueness MemBind)
forall lore (m :: * -> *).
Allocator lore m =>
Type -> ExpHint -> m (MemInfo SubExp NoUniqueness MemBind)
summaryForBindage (Ident -> Type
identType Ident
ident) ExpHint
hint
          PatElemT (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall (m :: * -> *) a. Monad m => a -> m a
return (PatElemT (MemInfo SubExp NoUniqueness MemBind)
 -> WriterT
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
      m
      (PatElemT (MemInfo SubExp NoUniqueness MemBind)))
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall a b. (a -> b) -> a -> b
$ VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall attr. VName -> attr -> PatElemT attr
PatElem (Ident -> VName
identName Ident
ident) MemInfo SubExp NoUniqueness MemBind
summary

        MemMem Space
space ->
          PatElemT (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall (m :: * -> *) a. Monad m => a -> m a
return (PatElemT (MemInfo SubExp NoUniqueness MemBind)
 -> WriterT
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
      m
      (PatElemT (MemInfo SubExp NoUniqueness MemBind)))
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall a b. (a -> b) -> a -> b
$ VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall attr. VName -> attr -> PatElemT attr
PatElem (Ident -> VName
identName Ident
ident) (MemInfo SubExp NoUniqueness MemBind
 -> PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$
          Space -> MemInfo SubExp NoUniqueness MemBind
forall d u ret. Space -> MemInfo d u ret
MemMem Space
space

        MemArray PrimType
bt ExtShape
_ NoUniqueness
u (Just (ReturnsInBlock VName
mem ExtIxFun
extixfun)) -> do
          ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
patels, IxFun
ixfn) <- Ident
-> ExtIxFun
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)], IxFun)
forall (m :: * -> *) d u ret.
MonadFreshNames m =>
Ident -> ExtIxFun -> m ([PatElemT (MemInfo d u ret)], IxFun)
instantiateExtIxFun Ident
ident ExtIxFun
extixfun
          ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
 [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
patels, [], [])

          PatElemT (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall (m :: * -> *) a. Monad m => a -> m a
return (PatElemT (MemInfo SubExp NoUniqueness MemBind)
 -> WriterT
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
      m
      (PatElemT (MemInfo SubExp NoUniqueness MemBind)))
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall a b. (a -> b) -> a -> b
$ VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall attr. VName -> attr -> PatElemT attr
PatElem (Ident -> VName
identName Ident
ident) (MemInfo SubExp NoUniqueness MemBind
 -> PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$
            PrimType
-> Shape
-> NoUniqueness
-> MemBind
-> MemInfo SubExp NoUniqueness MemBind
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
bt Shape
shape NoUniqueness
u (MemBind -> MemInfo SubExp NoUniqueness MemBind)
-> MemBind -> MemInfo SubExp NoUniqueness MemBind
forall a b. (a -> b) -> a -> b
$
            VName -> IxFun -> MemBind
ArrayIn VName
mem IxFun
ixfn

        MemArray PrimType
_ ExtShape
extshape NoUniqueness
_ Maybe MemReturn
Nothing
          | Just Result
_ <- ExtShape -> Maybe Result
forall b. ShapeBase (Ext b) -> Maybe [b]
knownShape ExtShape
extshape -> do
            MemInfo SubExp NoUniqueness MemBind
summary <- m (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (MemInfo SubExp NoUniqueness MemBind)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (MemInfo SubExp NoUniqueness MemBind)
 -> WriterT
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
      m
      (MemInfo SubExp NoUniqueness MemBind))
-> m (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$ Type -> ExpHint -> m (MemInfo SubExp NoUniqueness MemBind)
forall lore (m :: * -> *).
Allocator lore m =>
Type -> ExpHint -> m (MemInfo SubExp NoUniqueness MemBind)
summaryForBindage (Ident -> Type
identType Ident
ident) ExpHint
hint
            PatElemT (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall (m :: * -> *) a. Monad m => a -> m a
return (PatElemT (MemInfo SubExp NoUniqueness MemBind)
 -> WriterT
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
      m
      (PatElemT (MemInfo SubExp NoUniqueness MemBind)))
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall a b. (a -> b) -> a -> b
$ VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall attr. VName -> attr -> PatElemT attr
PatElem (Ident -> VName
identName Ident
ident) MemInfo SubExp NoUniqueness MemBind
summary

        MemArray PrimType
bt ExtShape
_ NoUniqueness
u (Just (ReturnsNewBlock Space
space Int
_ ExtIxFun
extixfn)) -> do
          -- treat existential index function first
          ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
patels, IxFun
ixfn) <- Ident
-> ExtIxFun
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)], IxFun)
forall (m :: * -> *) d u ret.
MonadFreshNames m =>
Ident -> ExtIxFun -> m ([PatElemT (MemInfo d u ret)], IxFun)
instantiateExtIxFun Ident
ident ExtIxFun
extixfn
          ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
 [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
patels, [], [])

          Ident
memid <- m Ident
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     Ident
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m Ident
 -> WriterT
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
      m
      Ident)
-> m Ident
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     Ident
forall a b. (a -> b) -> a -> b
$ Ident -> Space -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
Ident -> Space -> m Ident
mkMemIdent Ident
ident Space
space
          ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
 [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([], [VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall attr. VName -> attr -> PatElemT attr
PatElem (Ident -> VName
identName Ident
memid) (MemInfo SubExp NoUniqueness MemBind
 -> PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$ Space -> MemInfo SubExp NoUniqueness MemBind
forall d u ret. Space -> MemInfo d u ret
MemMem Space
space], [])
          PatElemT (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall (m :: * -> *) a. Monad m => a -> m a
return (PatElemT (MemInfo SubExp NoUniqueness MemBind)
 -> WriterT
      ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
       [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
      m
      (PatElemT (MemInfo SubExp NoUniqueness MemBind)))
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall a b. (a -> b) -> a -> b
$ VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall attr. VName -> attr -> PatElemT attr
PatElem (Ident -> VName
identName Ident
ident) (MemInfo SubExp NoUniqueness MemBind
 -> PatElemT (MemInfo SubExp NoUniqueness MemBind))
-> MemInfo SubExp NoUniqueness MemBind
-> PatElemT (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$ PrimType
-> Shape
-> NoUniqueness
-> MemBind
-> MemInfo SubExp NoUniqueness MemBind
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
bt Shape
shape NoUniqueness
u (MemBind -> MemInfo SubExp NoUniqueness MemBind)
-> MemBind -> MemInfo SubExp NoUniqueness MemBind
forall a b. (a -> b) -> a -> b
$
            VName -> IxFun -> MemBind
ArrayIn (Ident -> VName
identName Ident
memid) IxFun
ixfn

        ExpReturns
_ -> String
-> WriterT
     ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
     m
     (PatElemT (MemInfo SubExp NoUniqueness MemBind))
forall a. HasCallStack => String -> a
error String
"Impossible case reached in allocsForPattern!"

  ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
 [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
-> m ([PatElemT (MemInfo SubExp NoUniqueness MemBind)],
      [PatElemT (MemInfo SubExp NoUniqueness MemBind)], [AllocStm])
forall (m :: * -> *) a. Monad m => a -> m a
return ([PatElemT (MemInfo SubExp NoUniqueness MemBind)]
sizes' [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
forall a. Semigroup a => a -> a -> a
<> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
exts [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
-> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
forall a. Semigroup a => a -> a -> a
<> [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
mems,
          [PatElemT (MemInfo SubExp NoUniqueness MemBind)]
vals,
          [AllocStm]
postbnds)
  where knownShape :: ShapeBase (Ext b) -> Maybe [b]
knownShape = (Ext b -> Maybe b) -> [Ext b] -> Maybe [b]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Ext b -> Maybe b
forall a. Ext a -> Maybe a
known ([Ext b] -> Maybe [b])
-> (ShapeBase (Ext b) -> [Ext b]) -> ShapeBase (Ext b) -> Maybe [b]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ShapeBase (Ext b) -> [Ext b]
forall d. ShapeBase d -> [d]
shapeDims
        known :: Ext a -> Maybe a
known (Free a
v) = a -> Maybe a
forall a. a -> Maybe a
Just a
v
        known Ext{} = Maybe a
forall a. Maybe a
Nothing

        mkMemIdent :: (MonadFreshNames m) => Ident -> Space -> m Ident
        mkMemIdent :: Ident -> Space -> m Ident
mkMemIdent Ident
ident Space
space = do
          let memname :: String
memname = VName -> String
baseString (Ident -> VName
identName Ident
ident) String -> ShowS
forall a. Semigroup a => a -> a -> a
<> String
"_mem"
          String -> Type -> m Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
memname (Type -> m Ident) -> Type -> m Ident
forall a b. (a -> b) -> a -> b
$ Space -> Type
forall shape u. Space -> TypeBase shape u
Mem Space
space

        instantiateExtIxFun :: MonadFreshNames m =>
                               Ident -> ExtIxFun ->
                               m ([PatElemT (MemInfo d u ret)], IxFun)
        instantiateExtIxFun :: Ident -> ExtIxFun -> m ([PatElemT (MemInfo d u ret)], IxFun)
instantiateExtIxFun Ident
idd ExtIxFun
ext_ixfn = do
          let isAndPtps :: [(Int, PrimType)]
isAndPtps = Set (Int, PrimType) -> [(Int, PrimType)]
forall a. Set a -> [a]
S.toList (Set (Int, PrimType) -> [(Int, PrimType)])
-> Set (Int, PrimType) -> [(Int, PrimType)]
forall a b. (a -> b) -> a -> b
$
                          ((Ext VName, PrimType) -> Set (Int, PrimType))
-> Set (Ext VName, PrimType) -> Set (Int, PrimType)
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (Ext VName, PrimType) -> Set (Int, PrimType)
forall a. (Ext a, PrimType) -> Set (Int, PrimType)
onlyExts (Set (Ext VName, PrimType) -> Set (Int, PrimType))
-> Set (Ext VName, PrimType) -> Set (Int, PrimType)
forall a b. (a -> b) -> a -> b
$
                          (PrimExp (Ext VName) -> Set (Ext VName, PrimType))
-> ExtIxFun -> Set (Ext VName, PrimType)
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap PrimExp (Ext VName) -> Set (Ext VName, PrimType)
forall a. Ord a => PrimExp a -> Set (a, PrimType)
leafExpTypes ExtIxFun
ext_ixfn

          -- Find the existentials that reuse the sizeidents, and
          -- those that need new pattern elements.  Assumes that the
          -- Exts form a contiguous interval of integers.
          let ([(Int, PrimType)]
size_exts, [(Int, PrimType)]
new_exts) =
                ((Int, PrimType) -> Bool)
-> [(Int, PrimType)] -> ([(Int, PrimType)], [(Int, PrimType)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
span ((Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<[Ident] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Ident]
sizeidents) (Int -> Bool)
-> ((Int, PrimType) -> Int) -> (Int, PrimType) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Int, PrimType) -> Int
forall a b. (a, b) -> a
fst) ([(Int, PrimType)] -> ([(Int, PrimType)], [(Int, PrimType)]))
-> [(Int, PrimType)] -> ([(Int, PrimType)], [(Int, PrimType)])
forall a b. (a -> b) -> a -> b
$ [(Int, PrimType)] -> [(Int, PrimType)]
forall a. Ord a => [a] -> [a]
sort [(Int, PrimType)]
isAndPtps
          ([(Ext VName, PrimExp (Ext VName))]
new_substs, [PatElemT (MemInfo d u ret)]
patels) <-
            ([((Ext VName, PrimExp (Ext VName)), PatElemT (MemInfo d u ret))]
 -> ([(Ext VName, PrimExp (Ext VName))],
     [PatElemT (MemInfo d u ret)]))
-> m [((Ext VName, PrimExp (Ext VName)),
       PatElemT (MemInfo d u ret))]
-> m ([(Ext VName, PrimExp (Ext VName))],
      [PatElemT (MemInfo d u ret)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [((Ext VName, PrimExp (Ext VName)), PatElemT (MemInfo d u ret))]
-> ([(Ext VName, PrimExp (Ext VName))],
    [PatElemT (MemInfo d u ret)])
forall a b. [(a, b)] -> ([a], [b])
unzip (m [((Ext VName, PrimExp (Ext VName)), PatElemT (MemInfo d u ret))]
 -> m ([(Ext VName, PrimExp (Ext VName))],
       [PatElemT (MemInfo d u ret)]))
-> m [((Ext VName, PrimExp (Ext VName)),
       PatElemT (MemInfo d u ret))]
-> m ([(Ext VName, PrimExp (Ext VName))],
      [PatElemT (MemInfo d u ret)])
forall a b. (a -> b) -> a -> b
$ [(Int, PrimType)]
-> ((Int, PrimType)
    -> m ((Ext VName, PrimExp (Ext VName)),
          PatElemT (MemInfo d u ret)))
-> m [((Ext VName, PrimExp (Ext VName)),
       PatElemT (MemInfo d u ret))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Int, PrimType)]
new_exts (((Int, PrimType)
  -> m ((Ext VName, PrimExp (Ext VName)),
        PatElemT (MemInfo d u ret)))
 -> m [((Ext VName, PrimExp (Ext VName)),
        PatElemT (MemInfo d u ret))])
-> ((Int, PrimType)
    -> m ((Ext VName, PrimExp (Ext VName)),
          PatElemT (MemInfo d u ret)))
-> m [((Ext VName, PrimExp (Ext VName)),
       PatElemT (MemInfo d u ret))]
forall a b. (a -> b) -> a -> b
$ \(Int
i, PrimType
t) -> do
            VName
v <- String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> m VName) -> String -> m VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString (Ident -> VName
identName Ident
idd) String -> ShowS
forall a. Semigroup a => a -> a -> a
<> String
"_ixfn"
            ((Ext VName, PrimExp (Ext VName)), PatElemT (MemInfo d u ret))
-> m ((Ext VName, PrimExp (Ext VName)), PatElemT (MemInfo d u ret))
forall (m :: * -> *) a. Monad m => a -> m a
return ((Int -> Ext VName
forall a. Int -> Ext a
Ext Int
i, Ext VName -> PrimType -> PrimExp (Ext VName)
forall v. v -> PrimType -> PrimExp v
LeafExp (VName -> Ext VName
forall a. a -> Ext a
Free VName
v) PrimType
t),
                    VName -> MemInfo d u ret -> PatElemT (MemInfo d u ret)
forall attr. VName -> attr -> PatElemT attr
PatElem VName
v (MemInfo d u ret -> PatElemT (MemInfo d u ret))
-> MemInfo d u ret -> PatElemT (MemInfo d u ret)
forall a b. (a -> b) -> a -> b
$ PrimType -> MemInfo d u ret
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
t)
          let size_substs :: [(Ext VName, PrimExp (Ext VName))]
size_substs = ((Int, PrimType) -> Ident -> (Ext VName, PrimExp (Ext VName)))
-> [(Int, PrimType)]
-> [Ident]
-> [(Ext VName, PrimExp (Ext VName))]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(Int
i, PrimType
t) Ident
ident ->
                                    (Int -> Ext VName
forall a. Int -> Ext a
Ext Int
i, Ext VName -> PrimType -> PrimExp (Ext VName)
forall v. v -> PrimType -> PrimExp v
LeafExp (VName -> Ext VName
forall a. a -> Ext a
Free (Ident -> VName
identName Ident
ident)) PrimType
t))
                            [(Int, PrimType)]
size_exts [Ident]
sizeidents
              substs :: Map (Ext VName) (PrimExp (Ext VName))
substs = [(Ext VName, PrimExp (Ext VName))]
-> Map (Ext VName) (PrimExp (Ext VName))
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(Ext VName, PrimExp (Ext VName))]
 -> Map (Ext VName) (PrimExp (Ext VName)))
-> [(Ext VName, PrimExp (Ext VName))]
-> Map (Ext VName) (PrimExp (Ext VName))
forall a b. (a -> b) -> a -> b
$ [(Ext VName, PrimExp (Ext VName))]
new_substs [(Ext VName, PrimExp (Ext VName))]
-> [(Ext VName, PrimExp (Ext VName))]
-> [(Ext VName, PrimExp (Ext VName))]
forall a. Semigroup a => a -> a -> a
<> [(Ext VName, PrimExp (Ext VName))]
size_substs
          IxFun
ixfn <- ExtIxFun -> m IxFun
forall (m :: * -> *). Monad m => ExtIxFun -> m IxFun
instantiateIxFun (ExtIxFun -> m IxFun) -> ExtIxFun -> m IxFun
forall a b. (a -> b) -> a -> b
$ Map (Ext VName) (PrimExp (Ext VName)) -> ExtIxFun -> ExtIxFun
forall a.
Ord a =>
Map a (PrimExp a) -> IxFun (PrimExp a) -> IxFun (PrimExp a)
IxFun.substituteInIxFun Map (Ext VName) (PrimExp (Ext VName))
substs ExtIxFun
ext_ixfn

          ([PatElemT (MemInfo d u ret)], IxFun)
-> m ([PatElemT (MemInfo d u ret)], IxFun)
forall (m :: * -> *) a. Monad m => a -> m a
return ([PatElemT (MemInfo d u ret)]
patels, IxFun
ixfn)

-- | Select the existential leaves of an index-function expression.
-- An 'Ext' leaf yields a singleton set holding its index and primitive
-- type; a 'Free' leaf contributes the empty set.
onlyExts :: (Ext a, PrimType) -> S.Set (Int, PrimType)
onlyExts (leaf, pt) =
  case leaf of
    Ext idx -> S.singleton (idx, pt)
    Free _  -> S.empty


-- | Convert an existentially-quantified index function into a concrete one
-- by replacing every 'Free' leaf with the variable it wraps.
-- Existential ('Ext') leaves are not supported here and trigger 'error'.
instantiateIxFun :: Monad m => ExtIxFun -> m IxFun
instantiateIxFun = traverse (traverse instLeaf)
  where instLeaf leaf =
          case leaf of
            Free x -> return x
            Ext{}  -> error "instantiateIxFun: not yet"

-- | Compute the memory summary for a value of type @t@, allocating memory
-- when @t@ is an array.
--
-- * Scalars and memory blocks are annotated directly; nothing is allocated.
-- * Arrays with 'NoHint' get a fresh block from 'allocForArray' in the
--   default space, annotated via 'directIndexFunction'.
-- * With a @'Hint' ixfun space@, a block of
--   @product (IxFun.base ixfun) * primByteSize (elemType t)@ bytes is
--   allocated in @space@, and the hinted index function is used as-is.
summaryForBindage :: Allocator lore m =>
                     Type -> ExpHint
                  -> m (MemBound NoUniqueness)
summaryForBindage (Prim bt) _ = return (MemPrim bt)
summaryForBindage (Mem space) _ = return (MemMem space)
summaryForBindage t@(Array bt shape u) NoHint = do
  def_space <- askDefaultSpace
  mem <- allocForArray t def_space
  return (directIndexFunction bt shape u mem t)
summaryForBindage t (Hint ixfun space) = do
  let bt = elemType t
      -- Element count of the index function's base shape, sign-extended
      -- from 32 to 64 bits before being scaled by the element size.
      num_elems = ConvOpExp (SExt Int32 Int64) (product (IxFun.base ixfun))
      elem_bytes = fromIntegral (primByteSize bt :: Int64)
  bytes <- computeSize "bytes" $ product [num_elems, elem_bytes]
  mem <- allocateMemory "mem" bytes space
  return (MemArray bt (arrayShape t) NoUniqueness (ArrayIn mem ixfun))

-- | Return the space of the memory block bound to @v@.
-- Calls 'error' if @v@ is bound to anything other than a memory block.
lookupMemSpace :: (HasScope lore m, Monad m) => VName -> m Space
lookupMemSpace v = lookupType v >>= spaceOf
  where spaceOf (Mem space) = return space
        spaceOf _ =
          error $ "lookupMemSpace: " ++ pretty v ++ " is not a memory block."

-- | Annotate an array as residing in @mem@ with an 'IxFun.iota' index
-- function.
--
-- The summary records @shape@ as its shape. The index function is built
-- from the dimensions of @t@, each converted to a 32-bit integer
-- expression.
directIndexFunction :: PrimType -> Shape -> u -> VName -> Type -> MemBound u
directIndexFunction bt shape u mem t =
  MemArray bt shape u (ArrayIn mem ixfun)
  where dims = map (primExpFromSubExp int32) (arrayDims t)
        ixfun = IxFun.iota dims

-- | Annotate function parameters with memory information, then run the
-- continuation @m@ with the resulting parameters in scope.
--
-- Each input parameter is paired with the space its memory should live
-- in. The list passed to @m@ holds the memory-block parameters emitted by
-- 'allocInFParam' first, followed by the annotated value parameters in
-- their original order.
allocInFParams :: (Allocable fromlore tolore) =>
                  [(FParam fromlore, Space)] ->
                  ([FParam tolore] -> AllocM fromlore tolore a)
               -> AllocM fromlore tolore a
allocInFParams params m = do
  (valparams, memparams) <-
    runWriterT $ forM params $ \(param, space) -> allocInFParam param space
  let allparams = memparams <> valparams
  localScope (scopeOfFParams allparams) (m allparams)

-- | Give one function parameter its memory annotation.
--
-- Scalar and memory parameters are annotated directly and emit nothing.
-- For an array parameter:
--
-- * a fresh name, derived from the parameter's name plus @_mem@, is
--   created;
-- * a memory-block parameter in @pspace@ is emitted through the writer;
-- * the array is annotated as residing in that block, with an
--   'IxFun.iota' index function over its declared shape.
allocInFParam :: (Allocable fromlore tolore) =>
                 FParam fromlore
              -> Space
              -> WriterT [FParam tolore]
                 (AllocM fromlore tolore) (FParam tolore)
allocInFParam param pspace =
  case paramDeclType param of
    Prim bt ->
      return param { paramAttr = MemPrim bt }
    Mem space ->
      return param { paramAttr = MemMem space }
    Array bt shape u -> do
      mem <- lift $ newVName $ baseString (paramName param) <> "_mem"
      let dims = map (primExpFromSubExp int32) (shapeDims shape)
      tell [Param mem (MemMem pspace)]
      return param
        { paramAttr = MemArray bt shape u (ArrayIn mem (IxFun.iota dims)) }

allocInMergeParams :: (Allocable fromlore tolore,
                       Allocator tolore (AllocM fromlore tolore)) =>
                      [VName]
                   -> [(FParam fromlore,SubExp)]
                   -> ([FParam tolore]
                       -> [FParam tolore]
                       -> ([SubExp] -> AllocM fromlore tolore ([SubExp], [SubExp]))
                       -> AllocM fromlore tolore a)
                   -> AllocM fromlore tolore a
allocInMergeParams :: [VName]
-> [(FParam fromlore, SubExp)]
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore a)
-> AllocM fromlore tolore a
allocInMergeParams [VName]
variant [(FParam fromlore, SubExp)]
merge [FParam tolore]
-> [FParam tolore]
-> (Result -> AllocM fromlore tolore (Result, Result))
-> AllocM fromlore tolore a
m = do
  (([Param (MemInfo SubExp Uniqueness MemBind)]
valparams, [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp]
handle_loop_subexps), [Param (MemInfo SubExp Uniqueness MemBind)]
mem_params) <-
    WriterT
  [Param (MemInfo SubExp Uniqueness MemBind)]
  (AllocM fromlore tolore)
  ([Param (MemInfo SubExp Uniqueness MemBind)],
   [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp])
-> AllocM
     fromlore
     tolore
     (([Param (MemInfo SubExp Uniqueness MemBind)],
       [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp]),
      [Param (MemInfo SubExp Uniqueness MemBind)])
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
   [Param (MemInfo SubExp Uniqueness MemBind)]
   (AllocM fromlore tolore)
   ([Param (MemInfo SubExp Uniqueness MemBind)],
    [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp])
 -> AllocM
      fromlore
      tolore
      (([Param (MemInfo SubExp Uniqueness MemBind)],
        [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp]),
       [Param (MemInfo SubExp Uniqueness MemBind)]))
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     ([Param (MemInfo SubExp Uniqueness MemBind)],
      [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp])
-> AllocM
     fromlore
     tolore
     (([Param (MemInfo SubExp Uniqueness MemBind)],
       [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp]),
      [Param (MemInfo SubExp Uniqueness MemBind)])
forall a b. (a -> b) -> a -> b
$ [(Param (MemInfo SubExp Uniqueness MemBind),
  SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)]
-> ([Param (MemInfo SubExp Uniqueness MemBind)],
    [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Param (MemInfo SubExp Uniqueness MemBind),
   SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)]
 -> ([Param (MemInfo SubExp Uniqueness MemBind)],
     [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp]))
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     [(Param (MemInfo SubExp Uniqueness MemBind),
       SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)]
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     ([Param (MemInfo SubExp Uniqueness MemBind)],
      [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Param DeclType, SubExp)
 -> WriterT
      [Param (MemInfo SubExp Uniqueness MemBind)]
      (AllocM fromlore tolore)
      (Param (MemInfo SubExp Uniqueness MemBind),
       SubExp -> WriterT Result (AllocM fromlore tolore) SubExp))
-> [(Param DeclType, SubExp)]
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     [(Param (MemInfo SubExp Uniqueness MemBind),
       SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Param DeclType, SubExp)
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     (Param (MemInfo SubExp Uniqueness MemBind),
      SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
allocInMergeParam [(Param DeclType, SubExp)]
[(FParam fromlore, SubExp)]
merge
  let mergeparams' :: [Param (MemInfo SubExp Uniqueness MemBind)]
mergeparams' = [Param (MemInfo SubExp Uniqueness MemBind)]
mem_params [Param (MemInfo SubExp Uniqueness MemBind)]
-> [Param (MemInfo SubExp Uniqueness MemBind)]
-> [Param (MemInfo SubExp Uniqueness MemBind)]
forall a. Semigroup a => a -> a -> a
<> [Param (MemInfo SubExp Uniqueness MemBind)]
valparams
      summary :: Scope tolore
summary = [Param (MemInfo SubExp Uniqueness MemBind)] -> Scope tolore
forall lore attr.
(FParamAttr lore ~ attr) =>
[Param attr] -> Scope lore
scopeOfFParams [Param (MemInfo SubExp Uniqueness MemBind)]
mergeparams'

      mk_loop_res :: Result -> AllocM fromlore tolore (Result, Result)
mk_loop_res Result
ses = do
        (Result
valargs, Result
memargs) <-
          WriterT Result (AllocM fromlore tolore) Result
-> AllocM fromlore tolore (Result, Result)
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT Result (AllocM fromlore tolore) Result
 -> AllocM fromlore tolore (Result, Result))
-> WriterT Result (AllocM fromlore tolore) Result
-> AllocM fromlore tolore (Result, Result)
forall a b. (a -> b) -> a -> b
$ ((SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
 -> SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
-> [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp]
-> Result
-> WriterT Result (AllocM fromlore tolore) Result
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
-> SubExp -> WriterT Result (AllocM fromlore tolore) SubExp
forall a b. (a -> b) -> a -> b
($) [SubExp -> WriterT Result (AllocM fromlore tolore) SubExp]
handle_loop_subexps Result
ses
        (Result, Result) -> AllocM fromlore tolore (Result, Result)
forall (m :: * -> *) a. Monad m => a -> m a
return (Result
memargs, Result
valargs)

  Scope tolore
-> AllocM fromlore tolore a -> AllocM fromlore tolore a
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope Scope tolore
summary (AllocM fromlore tolore a -> AllocM fromlore tolore a)
-> AllocM fromlore tolore a -> AllocM fromlore tolore a
forall a b. (a -> b) -> a -> b
$ [FParam tolore]
-> [FParam tolore]
-> (Result -> AllocM fromlore tolore (Result, Result))
-> AllocM fromlore tolore a
m [FParam tolore]
[Param (MemInfo SubExp Uniqueness MemBind)]
mem_params [FParam tolore]
[Param (MemInfo SubExp Uniqueness MemBind)]
valparams Result -> AllocM fromlore tolore (Result, Result)
mk_loop_res
  where allocInMergeParam :: (Param DeclType, SubExp)
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     (Param (MemInfo SubExp Uniqueness MemBind),
      SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
allocInMergeParam (Param DeclType
mergeparam, Var VName
v)
          | Array PrimType
bt Shape
shape Uniqueness
u <- Param DeclType -> DeclType
forall attr. DeclTyped attr => Param attr -> DeclType
paramDeclType Param DeclType
mergeparam = do
              (VName
mem, IxFun
ixfun) <- AllocM fromlore tolore (VName, IxFun)
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     (VName, IxFun)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromlore tolore (VName, IxFun)
 -> WriterT
      [Param (MemInfo SubExp Uniqueness MemBind)]
      (AllocM fromlore tolore)
      (VName, IxFun))
-> AllocM fromlore tolore (VName, IxFun)
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     (VName, IxFun)
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromlore tolore (VName, IxFun)
forall lore (m :: * -> *).
(ExplicitMemorish lore, HasScope lore m, Monad m) =>
VName -> m (VName, IxFun)
lookupArraySummary VName
v
              Space
space <- AllocM fromlore tolore Space
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromlore tolore Space
 -> WriterT
      [Param (MemInfo SubExp Uniqueness MemBind)]
      (AllocM fromlore tolore)
      Space)
-> AllocM fromlore tolore Space
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     Space
forall a b. (a -> b) -> a -> b
$ VName -> AllocM fromlore tolore Space
forall lore (m :: * -> *).
(HasScope lore m, Monad m) =>
VName -> m Space
lookupMemSpace VName
mem
              Bool
reuse <- (AllocEnv fromlore tolore -> Bool)
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     Bool
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks AllocEnv fromlore tolore -> Bool
forall fromlore tolore. AllocEnv fromlore tolore -> Bool
aggressiveReuse
              if Space
space Space -> Space -> Bool
forall a. Eq a => a -> a -> Bool
/= String -> Space
Space String
"local" Bool -> Bool -> Bool
&&
                 Bool
reuse Bool -> Bool -> Bool
&&
                 Uniqueness
u Uniqueness -> Uniqueness -> Bool
forall a. Eq a => a -> a -> Bool
== Uniqueness
Unique Bool -> Bool -> Bool
&&
                 Param DeclType -> Bool
loopInvariantShape Param DeclType
mergeparam
                then (Param (MemInfo SubExp Uniqueness MemBind),
 SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     (Param (MemInfo SubExp Uniqueness MemBind),
      SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (Param DeclType
mergeparam { paramAttr :: MemInfo SubExp Uniqueness MemBind
paramAttr = PrimType
-> Shape
-> Uniqueness
-> MemBind
-> MemInfo SubExp Uniqueness MemBind
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
bt Shape
shape Uniqueness
Unique (MemBind -> MemInfo SubExp Uniqueness MemBind)
-> MemBind -> MemInfo SubExp Uniqueness MemBind
forall a b. (a -> b) -> a -> b
$ VName -> IxFun -> MemBind
ArrayIn VName
mem IxFun
ixfun },
                             AllocM fromlore tolore SubExp
-> WriterT Result (AllocM fromlore tolore) SubExp
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (AllocM fromlore tolore SubExp
 -> WriterT Result (AllocM fromlore tolore) SubExp)
-> (SubExp -> AllocM fromlore tolore SubExp)
-> SubExp
-> WriterT Result (AllocM fromlore tolore) SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> VName -> IxFun -> SubExp -> AllocM fromlore tolore SubExp
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
Type -> VName -> IxFun -> SubExp -> AllocM fromlore tolore SubExp
ensureArrayIn (Param DeclType -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param DeclType
mergeparam) VName
mem IxFun
ixfun)
                else do Space
def_space <- (AllocEnv fromlore tolore -> Space)
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     Space
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks AllocEnv fromlore tolore -> Space
forall fromlore tolore. AllocEnv fromlore tolore -> Space
allocSpace
                        Param DeclType
-> Space
-> WriterT
     [FParam tolore]
     (AllocM fromlore tolore)
     (FParam tolore,
      SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
forall tolore tolore fromlore fromlore.
(OpReturns tolore, Checkable tolore, PrettyLore fromlore,
 PrettyLore fromlore, SizeSubst (Op tolore), SizeSubst (Op tolore),
 BinderOps tolore, BinderOps tolore,
 Allocator tolore (AllocM fromlore tolore), LetAttr fromlore ~ Type,
 BodyAttr fromlore ~ (), LetAttr fromlore ~ Type,
 BranchType tolore ~ BodyReturns, BodyAttr fromlore ~ (),
 FParamAttr fromlore ~ DeclType, BodyAttr tolore ~ (),
 LetAttr tolore ~ MemInfo SubExp NoUniqueness MemBind,
 LParamAttr fromlore ~ Type, BodyAttr tolore ~ (),
 LetAttr tolore ~ MemInfo SubExp NoUniqueness MemBind,
 FParamAttr fromlore ~ DeclType,
 LParamAttr tolore ~ MemInfo SubExp NoUniqueness MemBind,
 BranchType fromlore ~ TypeBase ExtShape NoUniqueness,
 LParamAttr tolore ~ MemInfo SubExp NoUniqueness MemBind,
 RetType tolore ~ FunReturns,
 BranchType fromlore ~ TypeBase ExtShape NoUniqueness,
 RetType fromlore ~ DeclExtType, ExpAttr tolore ~ (),
 FParamAttr tolore ~ MemInfo SubExp Uniqueness MemBind,
 RetType fromlore ~ DeclExtType, ExpAttr tolore ~ (),
 FParamAttr tolore ~ MemInfo SubExp Uniqueness MemBind,
 LParamAttr fromlore ~ Type) =>
Param DeclType
-> Space
-> WriterT
     [Param (FParamAttr tolore)]
     (AllocM fromlore tolore)
     (Param (FParamAttr tolore),
      SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
doDefault Param DeclType
mergeparam Space
def_space

        allocInMergeParam (Param DeclType
mergeparam, SubExp
_) = Param DeclType
-> Space
-> WriterT
     [FParam tolore]
     (AllocM fromlore tolore)
     (FParam tolore,
      SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
forall tolore tolore fromlore fromlore.
(OpReturns tolore, Checkable tolore, PrettyLore fromlore,
 PrettyLore fromlore, SizeSubst (Op tolore), SizeSubst (Op tolore),
 BinderOps tolore, BinderOps tolore,
 Allocator tolore (AllocM fromlore tolore), LetAttr fromlore ~ Type,
 BodyAttr fromlore ~ (), LetAttr fromlore ~ Type,
 BranchType tolore ~ BodyReturns, BodyAttr fromlore ~ (),
 FParamAttr fromlore ~ DeclType, BodyAttr tolore ~ (),
 LetAttr tolore ~ MemInfo SubExp NoUniqueness MemBind,
 LParamAttr fromlore ~ Type, BodyAttr tolore ~ (),
 LetAttr tolore ~ MemInfo SubExp NoUniqueness MemBind,
 FParamAttr fromlore ~ DeclType,
 LParamAttr tolore ~ MemInfo SubExp NoUniqueness MemBind,
 BranchType fromlore ~ TypeBase ExtShape NoUniqueness,
 LParamAttr tolore ~ MemInfo SubExp NoUniqueness MemBind,
 RetType tolore ~ FunReturns,
 BranchType fromlore ~ TypeBase ExtShape NoUniqueness,
 RetType fromlore ~ DeclExtType, ExpAttr tolore ~ (),
 FParamAttr tolore ~ MemInfo SubExp Uniqueness MemBind,
 RetType fromlore ~ DeclExtType, ExpAttr tolore ~ (),
 FParamAttr tolore ~ MemInfo SubExp Uniqueness MemBind,
 LParamAttr fromlore ~ Type) =>
Param DeclType
-> Space
-> WriterT
     [Param (FParamAttr tolore)]
     (AllocM fromlore tolore)
     (Param (FParamAttr tolore),
      SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
doDefault Param DeclType
mergeparam (Space
 -> WriterT
      [Param (MemInfo SubExp Uniqueness MemBind)]
      (AllocM fromlore tolore)
      (Param (MemInfo SubExp Uniqueness MemBind),
       SubExp -> WriterT Result (AllocM fromlore tolore) SubExp))
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     Space
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     (Param (MemInfo SubExp Uniqueness MemBind),
      SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< AllocM fromlore tolore Space
-> WriterT
     [Param (MemInfo SubExp Uniqueness MemBind)]
     (AllocM fromlore tolore)
     Space
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift AllocM fromlore tolore Space
forall lore (m :: * -> *). Allocator lore m => m Space
askDefaultSpace

        doDefault :: Param DeclType
-> Space
-> WriterT
     [Param (FParamAttr tolore)]
     (AllocM fromlore tolore)
     (Param (FParamAttr tolore),
      SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
doDefault Param DeclType
mergeparam Space
space = do
          Param (FParamAttr tolore)
mergeparam' <- FParam fromlore
-> Space
-> WriterT
     [Param (FParamAttr tolore)]
     (AllocM fromlore tolore)
     (Param (FParamAttr tolore))
forall fromlore tolore.
Allocable fromlore tolore =>
FParam fromlore
-> Space
-> WriterT [FParam tolore] (AllocM fromlore tolore) (FParam tolore)
allocInFParam Param DeclType
FParam fromlore
mergeparam Space
space
          (Param (FParamAttr tolore),
 SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
-> WriterT
     [Param (FParamAttr tolore)]
     (AllocM fromlore tolore)
     (Param (FParamAttr tolore),
      SubExp -> WriterT Result (AllocM fromlore tolore) SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (Param (FParamAttr tolore)
mergeparam', Type
-> Space
-> SubExp
-> WriterT Result (AllocM fromlore tolore) SubExp
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
Type
-> Space
-> SubExp
-> WriterT Result (AllocM fromlore tolore) SubExp
linearFuncallArg (Param DeclType -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param DeclType
mergeparam) Space
space)

        variant_names :: [VName]
variant_names = [VName]
variant [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ ((Param DeclType, SubExp) -> VName)
-> [(Param DeclType, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType -> VName
forall attr. Param attr -> VName
paramName (Param DeclType -> VName)
-> ((Param DeclType, SubExp) -> Param DeclType)
-> (Param DeclType, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst) [(Param DeclType, SubExp)]
[(FParam fromlore, SubExp)]
merge
        loopInvariantShape :: Param DeclType -> Bool
loopInvariantShape =
          Bool -> Bool
not (Bool -> Bool)
-> (Param DeclType -> Bool) -> Param DeclType -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
variant_names) ([VName] -> Bool)
-> (Param DeclType -> [VName]) -> Param DeclType -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Result -> [VName]
subExpVars (Result -> [VName])
-> (Param DeclType -> Result) -> Param DeclType -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims (Type -> Result)
-> (Param DeclType -> Type) -> Param DeclType -> Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> Type
forall attr. Typed attr => Param attr -> Type
paramType

-- | Ensure that the array denoted by the given 'SubExp' is located in the
-- memory block @mem@ with index function @ixfun@.
--
-- If the variable already lives exactly there, it is returned unchanged.
-- Otherwise a 'Copy' of it is bound to a fresh name of type @t@ whose
-- memory annotation is @ArrayIn mem ixfun@, and that fresh name is
-- returned.  A 'Constant' can never be an array, so it is a fatal error.
ensureArrayIn :: (Allocable fromlore tolore,
                  Allocator tolore (AllocM fromlore tolore)) =>
                 Type -> VName -> IxFun -> SubExp
              -> AllocM fromlore tolore SubExp
ensureArrayIn _ _ _ (Constant v) =
  error $ "ensureArrayIn: " ++ pretty v ++ " cannot be an array."
ensureArrayIn t mem ixfun (Var v) = do
  (src_mem, src_ixfun) <- lookupArraySummary v
  let already_in_place = src_mem == mem && src_ixfun == ixfun
  if already_in_place
    then pure $ Var v
    else Var <$> placeCopy
  where
    -- Bind a copy of @v@ whose result is placed at @mem@ with @ixfun@,
    -- returning the name of the copy.
    placeCopy = do
      dest <- newIdent (baseString v ++ "_ensure_copy") t
      let info = MemArray (elemType t) (arrayShape t) NoUniqueness $
                 ArrayIn mem ixfun
          dest_pat = Pattern [] [PatElem (identName dest) info]
      letBind_ dest_pat $ BasicOp $ Copy v
      pure $ identName dest

-- | Return the memory block of array @v@ together with an argument that is
-- guaranteed to use a direct index function.
--
-- When @v@ already has a direct index function and, if @space_ok@ is
-- 'Just', lives in that memory space, @v@ itself is returned along with
-- its current memory block.  Otherwise @v@ is copied via
-- 'allocLinearArray' into a new allocation in @space_ok@, or in the
-- default space when @space_ok@ is 'Nothing'.
ensureDirectArray :: (Allocable fromlore tolore,
                      Allocator tolore (AllocM fromlore tolore)) =>
                     Maybe Space -> VName -> AllocM fromlore tolore (VName, SubExp)
ensureDirectArray space_ok v = do
  (mem, ixfun) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let space_acceptable = maybe True (== mem_space) space_ok
      usable_as_is = IxFun.isDirect ixfun && space_acceptable
      copy_space = fromMaybe default_space space_ok
  if usable_as_is
    then return (mem, Var v)
    -- Not usable directly: allocate new memory and copy 'v' into it.
    else allocLinearArray copy_space (baseString v) v

-- | Allocate a new memory block in @space@ sized for the array @v@, and
-- bind a 'Copy' of @v@ into it with a direct index function.
--
-- The copy gets a fresh name derived from @s@ with a @_linear@ suffix.
-- Returns the new memory block and the copied array.
allocLinearArray :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
                    Space -> String -> VName
                 -> AllocM fromlore tolore (VName, SubExp)
allocLinearArray space s v = do
  v_t <- lookupType v
  block <- allocForArray v_t space
  copy_name <- identName <$> newIdent (s ++ "_linear") v_t
  let copy_attr = directIndexFunction (elemType v_t) (arrayShape v_t)
                  NoUniqueness block v_t
      copy_pat = Pattern [] [PatElem copy_name copy_attr]
  addStm $ Let copy_pat (defAux ()) (BasicOp (Copy v))
  return (block, Var copy_name)

-- | Convert the arguments of a function call to the memory-explicit
-- representation.
--
-- Each array variable argument is made direct in the default memory space
-- (copying it if necessary, see 'linearFuncallArg').  The memory blocks of
-- those arrays are collected and prepended to the argument list as
-- 'Observe'd arguments, ahead of the (possibly copied) original arguments,
-- which keep their original 'Diet'.
funcallArgs :: (Allocable fromlore tolore,
                Allocator tolore (AllocM fromlore tolore)) =>
               [(SubExp,Diet)] -> AllocM fromlore tolore [(SubExp,Diet)]
funcallArgs args = do
  -- The default space does not depend on the argument being processed,
  -- so query it once rather than once per argument.
  space <- askDefaultSpace
  (valargs, mem_args) <- runWriterT $ forM args $ \(arg, d) -> do
    t <- lift $ subExpType arg
    arg' <- linearFuncallArg t space arg
    return (arg', d)
  return $ map (,Observe) mem_args <> valargs

-- | Make a single function-call argument memory-explicit.
--
-- For an array-typed variable, ensure it has a direct index function in
-- @space@ (copying it if required).  Its memory block is emitted to the
-- writer output, and the possibly-copied argument is returned.  Any other
-- argument is returned unchanged and nothing is written.
linearFuncallArg :: (Allocable fromlore tolore,
                     Allocator tolore (AllocM fromlore tolore)) =>
                    Type -> Space -> SubExp
                 -> WriterT [SubExp] (AllocM fromlore tolore) SubExp
linearFuncallArg t space arg =
  case (t, arg) of
    (Array{}, Var v) -> do
      (mem, arg') <- lift $ ensureDirectArray (Just space) v
      tell [Var mem]
      pure arg'
    _ -> pure arg

-- | The compiler pass that turns a 'Kernels' program into the
-- 'ExplicitMemory' representation.
--
-- Top-level constant statements go through 'allocInStms', and each
-- function definition goes through 'allocInFun'.  Both use
-- 'handleHostOp' and 'kernelExpHints' inside 'runAllocM'.
explicitAllocations :: Pass Kernels ExplicitMemory
explicitAllocations =
  Pass "explicit allocations"
       "Transform program to explicit memory representation" $
    intraproceduralTransformationWithConsts allocConsts allocInFun
  where
    -- Allocation for the statements that define top-level constants.
    allocConsts consts =
      runAllocM handleHostOp kernelExpHints $
        allocInStms consts pure

-- | Convert a sequence of 'Kernels' statements into 'ExplicitMemory'
-- statements with explicit allocations.
--
-- The 'ExplicitMemory' scope provided by the calling monad is installed as
-- the local scope, so that the statements can refer to variables bound
-- outside them.
explicitAllocationsInStms :: (MonadFreshNames m, HasScope ExplicitMemory m) =>
                             Stms Kernels -> m (Stms ExplicitMemory)
explicitAllocationsInStms stms =
  askScope >>= \outer_scope ->
    runAllocM handleHostOp kernelExpHints $
      localScope outer_scope $
        allocInStms stms return

-- | Annotate function return types with memory information.
--
-- Scalars become 'MemPrim'.  Every array return gets its own fresh
-- existential memory block in 'DefaultSpace', described by a row-major
-- ('IxFun.iota') index function over its dimensions.  The existential
-- indices for these blocks are handed out consecutively, starting at the
-- value computed by 'startOfFreeIDRange' for the same return types.
-- A memory return type is rejected with an error.
memoryInRetType :: [RetType Kernels] -> [RetType ExplicitMemory]
memoryInRetType ts =
  flip evalState (startOfFreeIDRange ts) $ traverse withMem ts
  where
    withMem (Prim pt) = pure $ MemPrim pt
    withMem Mem{} = error "memoryInRetType: too much memory"
    withMem (Array pt shp uniq) = do
      -- Claim the next free existential index for this array's memory.
      mem_idx <- get
      modify (+ 1)
      let ixfun = IxFun.iota $ map dimToPrimExp $ shapeDims shp
      pure $ MemArray pt shp uniq $ ReturnsNewBlock DefaultSpace mem_idx ixfun

    -- Existential dimensions stay existential leaves; known dimensions
    -- are converted and wrapped in 'Free'.  Dimensions are 'int32'.
    dimToPrimExp (Ext k) = LeafExp (Ext k) int32
    dimToPrimExp (Free se) = Free <$> primExpFromSubExp int32 se

-- | The number of distinct existential dimensions in the shape context of
-- the given types.  'memoryInRetType' numbers the fresh existential memory
-- blocks it introduces starting from this value.
startOfFreeIDRange :: [TypeBase ExtShape u] -> Int
startOfFreeIDRange ts = S.size $ shapeContext ts

-- | Introduce explicit allocations in a single function definition.
--
-- The given top-level constants are brought into scope.  Every parameter
-- is allocated in 'DefaultSpace', every result is requested in
-- 'DefaultSpace', and the return types are annotated by 'memoryInRetType'.
allocInFun :: MonadFreshNames m =>
              Stms ExplicitMemory -> FunDef Kernels -> m (FunDef ExplicitMemory)
allocInFun consts (FunDef entry fname rettype params fbody) =
  runAllocM handleHostOp kernelExpHints . inScopeOf consts $
    allocInFParams [ (p, DefaultSpace) | p <- params ] $ \params' -> do
      -- One requested space per declared return value.
      let result_spaces = Just DefaultSpace <$ rettype
      fbody' <- insertStmsM $ allocInFunBody result_spaces fbody
      pure $ FunDef entry fname (memoryInRetType rettype) params' fbody'

-- | Allocate memory inside a host-level operation.
--
-- Size operations are passed through unchanged, and segmented operations
-- are handled by 'handleSegOp'.  Any remaining SOAC is rejected with an
-- error.
handleHostOp :: HostOp Kernels (SOAC Kernels)
             -> AllocM Kernels ExplicitMemory (MemOp (HostOp ExplicitMemory ()))
handleHostOp hostop =
  case hostop of
    SizeOp op -> pure $ Inner $ SizeOp op
    SegOp op -> do
      op' <- handleSegOp op
      pure $ Inner $ SegOp op'
    OtherOp op -> error $ "Cannot allocate memory in SOAC: " ++ pretty op

-- | Allocate memory inside a segmented operation.
--
-- The whole traversal runs under 'allocAtLevel' for the operation's own
-- 'SegLevel'.  Kernel bodies are processed by 'allocInKernelBody' with the
-- scope from 'scopeOfSegSpace' added.  Lambdas are processed by
-- 'allocInBinOpLambda', which receives the level and the segment space.
handleSegOp :: SegOp Kernels
            -> AllocM Kernels ExplicitMemory (SegOp ExplicitMemory)
handleSegOp op = allocAtLevel lvl $ mapSegOpM mapper op
  where
    lvl = segLevel op
    segspace = segSpace op
    mapper =
      identitySegOpMapper
        { mapOnSegOpBody = \kbody ->
            localScope (scopeOfSegSpace segspace) $ allocInKernelBody lvl kbody
        , mapOnSegOpLambda = allocInBinOpLambda lvl segspace
        }

-- | Run an allocation action with an environment adjusted for a segment
-- level.
--
-- Sets 'allocSpace' to 'DefaultSpace' for 'SegThread' and to
-- @Space "local"@ for 'SegGroup', and enables 'aggressiveReuse'.
allocAtLevel :: SegLevel -> AllocM fromlore tlore a -> AllocM fromlore tlore a
allocAtLevel lvl = local atLevel
  where
    atLevel env = env { allocSpace = levelSpace lvl, aggressiveReuse = True }

    levelSpace SegThread{} = DefaultSpace
    levelSpace SegGroup{} = Space "local"

-- | The memory context to return alongside one body result.
--
-- An array variable contributes the memory block it is stored in.
-- Constants, scalar variables and memory variables contribute nothing.
bodyReturnMemCtx :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
                    SubExp -> AllocM fromlore tolore [SubExp]
bodyReturnMemCtx Constant{} = pure []
bodyReturnMemCtx (Var v) = memCtx <$> lookupMemInfo v
  where
    memCtx (MemArray _ _ _ (ArrayIn mem _)) = [Var mem]
    memCtx MemPrim{} = []
    memCtx MemMem{} = [] -- should not happen

-- | Allocate memory in the body of a function.
--
-- @space_oks@ holds one entry per value result; these are the trailing
-- @length space_oks@ results, and the leading ones (the context) get
-- 'Nothing'.  Every result is passed through 'ensureDirect'.  The memory
-- context of the value results (see 'bodyReturnMemCtx') is inserted
-- between the context results and the value results.  Statements produced
-- while doing this are appended after the body's own statements.
allocInFunBody :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
                  [Maybe Space] -> Body fromlore -> AllocM fromlore tolore (Body tolore)
allocInFunBody space_oks (Body _ stms res) =
  allocInStms stms $ \stms' -> do
    (res', res_stms) <- collectStms $ do
      direct <- zipWithM ensureDirect result_spaces res
      let (ctx_res, val_res) = splitFromEnd num_vals direct
      mem_ctx <- concat <$> traverse bodyReturnMemCtx val_res
      pure $ mconcat [ctx_res, mem_ctx, val_res]
    pure $ Body () (stms' <> res_stms) res'
  where
    num_vals = length space_oks
    result_spaces = replicate (length res - num_vals) Nothing <> space_oks

-- | Make a single result directly usable.
--
-- Constants and variables of primitive type are returned unchanged.  For a
-- non-primitive variable, 'ensureDirectArray' is called with the requested
-- space, and the 'SubExp' it returns as its second component is used.
ensureDirect :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
                Maybe Space -> SubExp -> AllocM fromlore tolore SubExp
ensureDirect space_ok se =
  case se of
    Constant{} -> pure se
    Var v -> do
      is_prim <- primType <$> lookupType v
      if not is_prim
        then do (_, v') <- ensureDirectArray space_ok v
                pure v'
        else pure se

-- | Allocate memory in a sequence of statements, then pass all resulting
-- statements to the continuation.
--
-- Statements are processed in order, each under its own certificates.
-- The statements produced for one input statement are brought into scope,
-- and their size substitutions are registered with 'boundDims', before the
-- following statements are processed.
allocInStms :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
               Stms fromlore -> (Stms tolore -> AllocM fromlore tolore a)
            -> AllocM fromlore tolore a
allocInStms origstms k = go mempty (stmsToList origstms)
  where
    go done [] = k done
    go done (stm : rest) = do
      ((), new_stms) <-
        collectStms $ certifying (stmCerts stm) $ allocInStm stm
      let dims = mconcat $ map sizeSubst $ stmsToList new_stms
      localScope (scopeOf new_stms) . local (boundDims dims) $
        go (done <> new_stms) rest

-- | Allocate memory for a single statement.
--
-- The expression is translated with 'allocInExp'.  'allocsForStm' then
-- builds the new statement from the identifiers of the pattern's two
-- element lists.  That statement is added first, followed by each
-- 'AllocStm' returned alongside it, in order.
allocInStm :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
              Stm fromlore -> AllocM fromlore tolore ()
allocInStm (Let (Pattern ctx_pes val_pes) _ e) = do
  e' <- allocInExp e
  (stm, extra) <-
    allocsForStm (map patElemIdent ctx_pes) (map patElemIdent val_pes) e'
  addStm stm
  forM_ extra addAllocStm

allocInExp :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
              Exp fromlore -> AllocM fromlore tolore (Exp tolore)
allocInExp :: Exp fromlore -> AllocM fromlore tolore (Exp tolore)
allocInExp (DoLoop [(FParam fromlore, SubExp)]
ctx [(FParam fromlore, SubExp)]
val LoopForm fromlore
form (Body () Stms fromlore
bodybnds Result
bodyres)) =
  [VName]
-> [(FParam fromlore, SubExp)]
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore (Exp tolore))
-> AllocM fromlore tolore (Exp tolore)
forall fromlore tolore a.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
[VName]
-> [(FParam fromlore, SubExp)]
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore a)
-> AllocM fromlore tolore a
allocInMergeParams [VName]
forall a. Monoid a => a
mempty [(FParam fromlore, SubExp)]
ctx (([FParam tolore]
  -> [FParam tolore]
  -> (Result -> AllocM fromlore tolore (Result, Result))
  -> AllocM fromlore tolore (Exp tolore))
 -> AllocM fromlore tolore (Exp tolore))
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore (Exp tolore))
-> AllocM fromlore tolore (Exp tolore)
forall a b. (a -> b) -> a -> b
$ \[FParam tolore]
_ [FParam tolore]
ctxparams' Result -> AllocM fromlore tolore (Result, Result)
_ ->
  [VName]
-> [(FParam fromlore, SubExp)]
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore (Exp tolore))
-> AllocM fromlore tolore (Exp tolore)
forall fromlore tolore a.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
[VName]
-> [(FParam fromlore, SubExp)]
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore a)
-> AllocM fromlore tolore a
allocInMergeParams ((Param (MemInfo SubExp Uniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp Uniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp Uniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName [FParam tolore]
[Param (MemInfo SubExp Uniqueness MemBind)]
ctxparams') [(FParam fromlore, SubExp)]
val (([FParam tolore]
  -> [FParam tolore]
  -> (Result -> AllocM fromlore tolore (Result, Result))
  -> AllocM fromlore tolore (Exp tolore))
 -> AllocM fromlore tolore (Exp tolore))
-> ([FParam tolore]
    -> [FParam tolore]
    -> (Result -> AllocM fromlore tolore (Result, Result))
    -> AllocM fromlore tolore (Exp tolore))
-> AllocM fromlore tolore (Exp tolore)
forall a b. (a -> b) -> a -> b
$
  \[FParam tolore]
new_ctx_params [FParam tolore]
valparams' Result -> AllocM fromlore tolore (Result, Result)
mk_loop_val -> do
  LoopForm tolore
form' <- LoopForm fromlore -> AllocM fromlore tolore (LoopForm tolore)
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
LoopForm fromlore -> AllocM fromlore tolore (LoopForm tolore)
allocInLoopForm LoopForm fromlore
form
  Scope tolore
-> AllocM fromlore tolore (Exp tolore)
-> AllocM fromlore tolore (Exp tolore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (LoopForm tolore -> Scope tolore
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm tolore
form') (AllocM fromlore tolore (Exp tolore)
 -> AllocM fromlore tolore (Exp tolore))
-> AllocM fromlore tolore (Exp tolore)
-> AllocM fromlore tolore (Exp tolore)
forall a b. (a -> b) -> a -> b
$ do
    (Result
valinit_ctx, Result
valinit') <- Result -> AllocM fromlore tolore (Result, Result)
mk_loop_val Result
valinit
    BodyT tolore
body' <- AllocM fromlore tolore (Body (Lore (AllocM fromlore tolore)))
-> AllocM fromlore tolore (Body (Lore (AllocM fromlore tolore)))
forall (m :: * -> *).
MonadBinder m =>
m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM (AllocM fromlore tolore (Body (Lore (AllocM fromlore tolore)))
 -> AllocM fromlore tolore (Body (Lore (AllocM fromlore tolore))))
-> AllocM fromlore tolore (Body (Lore (AllocM fromlore tolore)))
-> AllocM fromlore tolore (Body (Lore (AllocM fromlore tolore)))
forall a b. (a -> b) -> a -> b
$ Stms fromlore
-> (Stms tolore -> AllocM fromlore tolore (BodyT tolore))
-> AllocM fromlore tolore (BodyT tolore)
forall fromlore tolore a.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
Stms fromlore
-> (Stms tolore -> AllocM fromlore tolore a)
-> AllocM fromlore tolore a
allocInStms Stms fromlore
bodybnds ((Stms tolore -> AllocM fromlore tolore (BodyT tolore))
 -> AllocM fromlore tolore (BodyT tolore))
-> (Stms tolore -> AllocM fromlore tolore (BodyT tolore))
-> AllocM fromlore tolore (BodyT tolore)
forall a b. (a -> b) -> a -> b
$ \Stms tolore
bodybnds' -> do
      ((Result
val_ses,Result
valres'),Stms tolore
val_retbnds) <- AllocM fromlore tolore (Result, Result)
-> AllocM
     fromlore
     tolore
     ((Result, Result), Stms (Lore (AllocM fromlore tolore)))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (AllocM fromlore tolore (Result, Result)
 -> AllocM
      fromlore
      tolore
      ((Result, Result), Stms (Lore (AllocM fromlore tolore))))
-> AllocM fromlore tolore (Result, Result)
-> AllocM
     fromlore
     tolore
     ((Result, Result), Stms (Lore (AllocM fromlore tolore)))
forall a b. (a -> b) -> a -> b
$ Result -> AllocM fromlore tolore (Result, Result)
mk_loop_val Result
valres
      BodyT tolore -> AllocM fromlore tolore (BodyT tolore)
forall (m :: * -> *) a. Monad m => a -> m a
return (BodyT tolore -> AllocM fromlore tolore (BodyT tolore))
-> BodyT tolore -> AllocM fromlore tolore (BodyT tolore)
forall a b. (a -> b) -> a -> b
$ BodyAttr tolore -> Stms tolore -> Result -> BodyT tolore
forall lore. BodyAttr lore -> Stms lore -> Result -> BodyT lore
Body () (Stms tolore
bodybnds'Stms tolore -> Stms tolore -> Stms tolore
forall a. Semigroup a => a -> a -> a
<>Stms tolore
val_retbnds) (Result
ctxresResult -> Result -> Result
forall a. [a] -> [a] -> [a]
++Result
val_sesResult -> Result -> Result
forall a. [a] -> [a] -> [a]
++Result
valres')
    Exp tolore -> AllocM fromlore tolore (Exp tolore)
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp tolore -> AllocM fromlore tolore (Exp tolore))
-> Exp tolore -> AllocM fromlore tolore (Exp tolore)
forall a b. (a -> b) -> a -> b
$
      [(FParam tolore, SubExp)]
-> [(FParam tolore, SubExp)]
-> LoopForm tolore
-> BodyT tolore
-> Exp tolore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop
      ([Param (MemInfo SubExp Uniqueness MemBind)]
-> Result -> [(Param (MemInfo SubExp Uniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip ([FParam tolore]
[Param (MemInfo SubExp Uniqueness MemBind)]
ctxparams'[Param (MemInfo SubExp Uniqueness MemBind)]
-> [Param (MemInfo SubExp Uniqueness MemBind)]
-> [Param (MemInfo SubExp Uniqueness MemBind)]
forall a. [a] -> [a] -> [a]
++[FParam tolore]
[Param (MemInfo SubExp Uniqueness MemBind)]
new_ctx_params) (Result
ctxinitResult -> Result -> Result
forall a. [a] -> [a] -> [a]
++Result
valinit_ctx))
      ([Param (MemInfo SubExp Uniqueness MemBind)]
-> Result -> [(Param (MemInfo SubExp Uniqueness MemBind), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [FParam tolore]
[Param (MemInfo SubExp Uniqueness MemBind)]
valparams' Result
valinit')
      LoopForm tolore
form' BodyT tolore
body'
  where ([Param DeclType]
_ctxparams, Result
ctxinit) = [(Param DeclType, SubExp)] -> ([Param DeclType], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param DeclType, SubExp)]
[(FParam fromlore, SubExp)]
ctx
        ([Param DeclType]
_valparams, Result
valinit) = [(Param DeclType, SubExp)] -> ([Param DeclType], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param DeclType, SubExp)]
[(FParam fromlore, SubExp)]
val
        (Result
ctxres, Result
valres) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitAt ([(Param DeclType, SubExp)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Param DeclType, SubExp)]
[(FParam fromlore, SubExp)]
ctx) Result
bodyres
-- Function calls: the arguments are passed through 'funcallArgs' and the
-- declared return types are translated with 'memoryInRetType'; the callee
-- name and location information are kept unchanged.
allocInExp (Apply fname args rettype loc) =
  fmap (\args' -> Apply fname args' (memoryInRetType rettype) loc)
       (funcallArgs args)
-- Conditionals.  Each branch is first moved to the explicit-memory
-- representation without touching its results (allocInIfBody), which also
-- reports the memory binding of every returned array.  For every returned
-- value the index functions of the two branches are then anti-unified with
-- IxFun.leastGeneralGeneralization (only when both branches return an array
-- in the same space); the per-branch substitutions are handed to
-- addResCtxInIfBody, which makes the result direct when no generalisation
-- is available.  Finally, existential context results that are equal in
-- both branches are substituted into the return types (fixExt) and dropped
-- from the branch results.
allocInExp (If cond tbranch0 fbranch0 (IfAttr rets ifsort)) = do
  let num_rets = length rets

  -- Move both branches to the explicit-memory representation; their
  -- results are left as they are for now.
  (tbranch, tm_ixfs) <- allocInIfBody num_rets tbranch0
  (fbranch, fm_ixfs) <- allocInIfBody num_rets fbranch0
  tspaces <- mkSpaceOks num_rets tbranch
  fspaces <- mkSpaceOks num_rets fbranch

  -- Attempt to generalise (anti-unify) the index functions of the two
  -- branches, result by result.
  let (spaces, subs) =
        unzip $ zipWith generalize (zip tspaces tm_ixfs) (zip fspaces fm_ixfs)
      tsubs = map (selectSub fst) subs
      fsubs = map (selectSub snd) subs

  (tbranch', trets) <- addResCtxInIfBody rets tbranch spaces tsubs
  (fbranch', frets) <- addResCtxInIfBody rets fbranch spaces fsubs

  -- Sanity check: both branches must end up with identical return types.
  if frets /= trets
    then error "In allocInExp, IF case: antiunification of then/else produce different ExtInFn!"
    else do
      let res_then = bodyResult tbranch'
          res_else = bodyResult fbranch'
          -- Results preceding the value returns are context results.
          size_ext = length res_then - length trets
          -- Context results of both branches, paired with their index.
          ctx_triples = zip3 res_then res_else [0 .. size_ext - 1]
          (same_ctx, diff_ctx) = partition (\(t, f, _) -> t == f) ctx_triples
          (then_ext, else_ext, _) = unzip3 diff_ctx
          -- Removing an existential shifts every later index down by one,
          -- hence the running offset k.
          fixes = [ (i - k, se) | ((se, _, i), k) <- zip same_ctx [0 ..] ]
          rets_fixed = foldl (\acc (i, se) -> fixExt i se acc) trets fixes
          tbranch'' = tbranch' { bodyResult = then_ext ++ drop size_ext res_then }
          fbranch'' = fbranch' { bodyResult = else_ext ++ drop size_ext res_else }
      return $ If cond tbranch'' fbranch'' $ IfAttr rets_fixed ifsort
  where
    -- Anti-unify the memory bindings of one result from the two branches.
    -- The then-branch space is always kept; generalisation is attempted only
    -- when both branches return an array in the same space.
    generalize :: (Maybe Space, Maybe MemBind) -> (Maybe Space, Maybe MemBind)
               -> (Maybe Space, Maybe (ExtIxFun, [(PrimExp VName, PrimExp VName)]))
    generalize (Just sp1, Just (ArrayIn _ ixf1)) (Just sp2, Just (ArrayIn _ ixf2))
      | sp1 == sp2 = (Just sp1, IxFun.leastGeneralGeneralization ixf1 ixf2)
      | otherwise = (Just sp1, Nothing)
    generalize (mbsp1, _) _ = (mbsp1, Nothing)

    -- Keep the shared generalised index function and project one side of
    -- every substitution pair.
    selectSub :: ((a, a) -> a) -> Maybe (ExtIxFun, [(a, a)]) -> Maybe (ExtIxFun, [a])
    selectSub pick = fmap (\(ixfn, pairs) -> (ixfn, map pick pairs))

    -- Introduce the new representation (index functions) for a branch
    -- without unifying its results (e.g. without ensuring they are direct);
    -- like allocInBodyNoDirect, but additionally returns the MemBind of each
    -- of the last num_vals results (Nothing for non-array results).
    allocInIfBody :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
                     Int -> Body fromlore -> AllocM fromlore tolore (Body tolore, [Maybe MemBind])
    allocInIfBody num_vals (Body _ stms res) =
      allocInStms stms $ \stms' -> do
        let val_res = snd $ splitFromEnd num_vals res
        mem_binds <- mapM resultMemBind val_res
        return (Body () stms' res, mem_binds)

    -- The MemBind of a returned value, if it is an array.
    resultMemBind Constant{} = return Nothing
    resultMemBind (Var v) = do
      info <- lookupMemInfo v
      return $ case info of
        MemArray _ _ _ mem_bind -> Just mem_bind
        _ -> Nothing
-- Fallback for all remaining expression forms: traverse the expression with
-- 'mapExpM', handing every 'Op' to the 'allocInOp' handler stored in the
-- 'AllocEnv'.  Bodies, return/branch types and parameters reaching this
-- mapper are reported as internal errors (presumably because every
-- expression form containing them is handled by an earlier equation).
allocInExp e = mapExpM allocMapper e
  where allocMapper =
          identityMapper
            { mapOnBody = error "Unhandled Body in ExplicitAllocations"
            , mapOnRetType = error "Unhandled RetType in ExplicitAllocations"
            , mapOnBranchType = error "Unhandled BranchType in ExplicitAllocations"
            , mapOnFParam = error "Unhandled FParam in ExplicitAllocations"
            , mapOnLParam = error "Unhandled LParam in ExplicitAllocations"
            , mapOnOp = \op -> asks allocInOp >>= \handler -> handler op
            }

addResCtxInIfBody :: (Allocable fromlore tolore, Allocator tolore (AllocM fromlore tolore)) =>
                     [ExtType] -> Body tolore -> [Maybe Space] ->
                     [Maybe (ExtIxFun, [PrimExp VName])] ->
                     AllocM fromlore tolore (Body tolore, [BodyReturns])
addResCtxInIfBody :: [TypeBase ExtShape NoUniqueness]
-> Body tolore
-> [Maybe Space]
-> [Maybe (ExtIxFun, [PrimExp VName])]
-> AllocM fromlore tolore (Body tolore, [BodyReturns])
addResCtxInIfBody [TypeBase ExtShape NoUniqueness]
ifrets (Body BodyAttr tolore
_ Stms tolore
bnds Result
res) [Maybe Space]
spaces [Maybe (ExtIxFun, [PrimExp VName])]
substs = do
  let num_vals :: Int
num_vals = [TypeBase ExtShape NoUniqueness] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TypeBase ExtShape NoUniqueness]
ifrets
      (Result
ctx_res, Result
val_res) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_vals Result
res
  ((Result
res', [BodyReturns]
bodyrets'), Stms tolore
all_body_stms) <- AllocM fromlore tolore (Result, [BodyReturns])
-> AllocM
     fromlore
     tolore
     ((Result, [BodyReturns]), Stms (Lore (AllocM fromlore tolore)))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (AllocM fromlore tolore (Result, [BodyReturns])
 -> AllocM
      fromlore
      tolore
      ((Result, [BodyReturns]), Stms (Lore (AllocM fromlore tolore))))
-> AllocM fromlore tolore (Result, [BodyReturns])
-> AllocM
     fromlore
     tolore
     ((Result, [BodyReturns]), Stms (Lore (AllocM fromlore tolore)))
forall a b. (a -> b) -> a -> b
$ do
    (Stm tolore -> AllocM fromlore tolore ())
-> Stms tolore -> AllocM fromlore tolore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm tolore -> AllocM fromlore tolore ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm Stms tolore
bnds
    (Result
val_res', Result
ext_ses_res, Result
mem_ctx_res, [BodyReturns]
bodyrets, Int
total_existentials) <-
      ((Result, Result, Result, [BodyReturns], Int)
 -> (TypeBase ExtShape NoUniqueness, SubExp,
     Maybe (ExtIxFun, [PrimExp VName]), Maybe Space)
 -> AllocM
      fromlore tolore (Result, Result, Result, [BodyReturns], Int))
-> (Result, Result, Result, [BodyReturns], Int)
-> [(TypeBase ExtShape NoUniqueness, SubExp,
     Maybe (ExtIxFun, [PrimExp VName]), Maybe Space)]
-> AllocM
     fromlore tolore (Result, Result, Result, [BodyReturns], Int)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Result, Result, Result, [BodyReturns], Int)
-> (TypeBase ExtShape NoUniqueness, SubExp,
    Maybe (ExtIxFun, [PrimExp VName]), Maybe Space)
-> AllocM
     fromlore tolore (Result, Result, Result, [BodyReturns], Int)
forall tolore fromlore u.
(PrettyLore fromlore, SizeSubst (Op tolore), BinderOps tolore,
 Allocator tolore (AllocM fromlore tolore), LetAttr fromlore ~ Type,
 BodyAttr fromlore ~ (),
 BranchType fromlore ~ TypeBase ExtShape NoUniqueness,
 LParamAttr tolore ~ MemInfo SubExp NoUniqueness MemBind,
 FParamAttr fromlore ~ DeclType,
 LetAttr tolore ~ MemInfo SubExp NoUniqueness MemBind,
 BodyAttr tolore ~ (), RetType fromlore ~ DeclExtType,
 ExpAttr tolore ~ (),
 FParamAttr tolore ~ MemInfo SubExp Uniqueness MemBind,
 LParamAttr fromlore ~ Type) =>
(Result, Result, Result, [MemInfo (Ext SubExp) u MemReturn], Int)
-> (TypeBase ExtShape u, SubExp, Maybe (ExtIxFun, [PrimExp VName]),
    Maybe Space)
-> AllocM
     fromlore
     tolore
     (Result, Result, Result, [MemInfo (Ext SubExp) u MemReturn], Int)
helper ([], [], [], [], Result -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Result
ctx_res) ([TypeBase ExtShape NoUniqueness]
-> Result
-> [Maybe (ExtIxFun, [PrimExp VName])]
-> [Maybe Space]
-> [(TypeBase ExtShape NoUniqueness, SubExp,
     Maybe (ExtIxFun, [PrimExp VName]), Maybe Space)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4 [TypeBase ExtShape NoUniqueness]
ifrets Result
val_res [Maybe (ExtIxFun, [PrimExp VName])]
substs [Maybe Space]
spaces)
    (Result, [BodyReturns])
-> AllocM fromlore tolore (Result, [BodyReturns])
forall (m :: * -> *) a. Monad m => a -> m a
return (Result
ctx_res Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Result
ext_ses_res Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Result
mem_ctx_res Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Result
val_res',
             -- We need to adjust the ReturnsNewBlock existentials, because they
             -- should always be numbered _after_ all other existentials in the
             -- return values.
            [BodyReturns] -> [BodyReturns]
forall a. [a] -> [a]
reverse ([BodyReturns] -> [BodyReturns]) -> [BodyReturns] -> [BodyReturns]
forall a b. (a -> b) -> a -> b
$ ([BodyReturns], Int) -> [BodyReturns]
forall a b. (a, b) -> a
fst (([BodyReturns], Int) -> [BodyReturns])
-> ([BodyReturns], Int) -> [BodyReturns]
forall a b. (a -> b) -> a -> b
$ (([BodyReturns], Int) -> BodyReturns -> ([BodyReturns], Int))
-> ([BodyReturns], Int) -> [BodyReturns] -> ([BodyReturns], Int)
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ([BodyReturns], Int) -> BodyReturns -> ([BodyReturns], Int)
adjustNewBlockExistential ([], Int
total_existentials) [BodyReturns]
bodyrets)
  Body tolore
body' <- Stms (Lore (AllocM fromlore tolore))
-> Result
-> AllocM fromlore tolore (Body (Lore (AllocM fromlore tolore)))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> Result -> m (Body (Lore m))
mkBodyM Stms tolore
Stms (Lore (AllocM fromlore tolore))
all_body_stms Result
res'
  (Body tolore, [BodyReturns])
-> AllocM fromlore tolore (Body tolore, [BodyReturns])
forall (m :: * -> *) a. Monad m => a -> m a
return (Body tolore
body', [BodyReturns]
bodyrets')
    where
      helper :: (Result, Result, Result, [MemInfo (Ext SubExp) u MemReturn], Int)
-> (TypeBase ExtShape u, SubExp, Maybe (ExtIxFun, [PrimExp VName]),
    Maybe Space)
-> AllocM
     fromlore
     tolore
     (Result, Result, Result, [MemInfo (Ext SubExp) u MemReturn], Int)
helper (Result
res_acc, Result
ext_acc, Result
ctx_acc, [MemInfo (Ext SubExp) u MemReturn]
br_acc, Int
k) (TypeBase ExtShape u
ifr, SubExp
r, Maybe (ExtIxFun, [PrimExp VName])
mbixfsub, Maybe Space
sp) =
        case Maybe (ExtIxFun, [PrimExp VName])
mbixfsub of
          Maybe (ExtIxFun, [PrimExp VName])
Nothing -> do
            -- does NOT generalize/antiunify; ensure direct
            SubExp
r' <- Maybe Space -> SubExp -> AllocM fromlore tolore SubExp
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
Maybe Space -> SubExp -> AllocM fromlore tolore SubExp
ensureDirect Maybe Space
sp SubExp
r
            Result
mem_ctx_r <- SubExp -> AllocM fromlore tolore Result
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
SubExp -> AllocM fromlore tolore Result
bodyReturnMemCtx SubExp
r'
            let body_ret :: MemInfo (Ext SubExp) u MemReturn
body_ret = TypeBase ExtShape u
-> Maybe Space -> MemInfo (Ext SubExp) u MemReturn
forall u.
TypeBase ExtShape u
-> Maybe Space -> MemInfo (Ext SubExp) u MemReturn
inspect TypeBase ExtShape u
ifr Maybe Space
sp
            (Result, Result, Result, [MemInfo (Ext SubExp) u MemReturn], Int)
-> AllocM
     fromlore
     tolore
     (Result, Result, Result, [MemInfo (Ext SubExp) u MemReturn], Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (Result
res_acc Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ [SubExp
r'],
                    Result
ext_acc,
                    Result
ctx_acc Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
mem_ctx_r,
                    [MemInfo (Ext SubExp) u MemReturn]
br_acc [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
forall a. [a] -> [a] -> [a]
++ [MemInfo (Ext SubExp) u MemReturn
body_ret],
                    Int
k)
          Just (ExtIxFun
ixfn, [PrimExp VName]
m) -> do -- generalizes
            let i :: Int
i = [PrimExp VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimExp VName]
m
            Result
ext_ses <- (PrimExp VName -> AllocM fromlore tolore SubExp)
-> [PrimExp VName] -> AllocM fromlore tolore Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String
-> (VName
    -> AllocM fromlore tolore (Exp (Lore (AllocM fromlore tolore))))
-> PrimExp VName
-> AllocM fromlore tolore SubExp
forall (m :: * -> *) v.
MonadBinder m =>
String -> (v -> m (Exp (Lore m))) -> PrimExp v -> m SubExp
primExpToSubExp String
"ixfn_exist"
                             (ExpT tolore -> AllocM fromlore tolore (ExpT tolore)
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpT tolore -> AllocM fromlore tolore (ExpT tolore))
-> (VName -> ExpT tolore)
-> VName
-> AllocM fromlore tolore (ExpT tolore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp tolore -> ExpT tolore
forall lore. BasicOp lore -> ExpT lore
BasicOp (BasicOp tolore -> ExpT tolore)
-> (VName -> BasicOp tolore) -> VName -> ExpT tolore
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp tolore
forall lore. SubExp -> BasicOp lore
SubExp (SubExp -> BasicOp tolore)
-> (VName -> SubExp) -> VName -> BasicOp tolore
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var))
                       [PrimExp VName]
m
            Result
mem_ctx_r <- SubExp -> AllocM fromlore tolore Result
forall fromlore tolore.
(Allocable fromlore tolore,
 Allocator tolore (AllocM fromlore tolore)) =>
SubExp -> AllocM fromlore tolore Result
bodyReturnMemCtx SubExp
r
            let sp' :: Space
sp' = Space -> Maybe Space -> Space
forall a. a -> Maybe a -> a
fromMaybe Space
DefaultSpace Maybe Space
sp
                ixfn' :: ExtIxFun
ixfn' = (PrimExp (Ext VName) -> PrimExp (Ext VName))
-> ExtIxFun -> ExtIxFun
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> PrimExp (Ext VName) -> PrimExp (Ext VName)
adjustExtPE Int
k) ExtIxFun
ixfn
                exttp :: MemInfo (Ext SubExp) u MemReturn
exttp = case TypeBase ExtShape u
ifr of
                          Array PrimType
pt ExtShape
shp' u
u ->
                            PrimType
-> ExtShape -> u -> MemReturn -> MemInfo (Ext SubExp) u MemReturn
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ExtShape
shp' u
u (MemReturn -> MemInfo (Ext SubExp) u MemReturn)
-> MemReturn -> MemInfo (Ext SubExp) u MemReturn
forall a b. (a -> b) -> a -> b
$
                            Space -> Int -> ExtIxFun -> MemReturn
ReturnsNewBlock Space
sp' Int
0 ExtIxFun
ixfn'
                          TypeBase ExtShape u
_ -> String -> MemInfo (Ext SubExp) u MemReturn
forall a. HasCallStack => String -> a
error String
"Impossible case reached in addResCtxInIfBody"
            (Result, Result, Result, [MemInfo (Ext SubExp) u MemReturn], Int)
-> AllocM
     fromlore
     tolore
     (Result, Result, Result, [MemInfo (Ext SubExp) u MemReturn], Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (Result
res_acc Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ [SubExp
r],
                    Result
ext_acc Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
ext_ses,
                    Result
ctx_acc Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
mem_ctx_r,
                    [MemInfo (Ext SubExp) u MemReturn]
br_acc [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
-> [MemInfo (Ext SubExp) u MemReturn]
forall a. [a] -> [a] -> [a]
++ [MemInfo (Ext SubExp) u MemReturn
exttp],
                    Int
k Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i)

      adjustNewBlockExistential :: ([BodyReturns], Int) -> BodyReturns -> ([BodyReturns], Int)
      adjustNewBlockExistential :: ([BodyReturns], Int) -> BodyReturns -> ([BodyReturns], Int)
adjustNewBlockExistential ([BodyReturns]
acc, Int
k) (MemArray PrimType
pt ExtShape
shp NoUniqueness
u (ReturnsNewBlock Space
space Int
_ ExtIxFun
ixfun)) =
        (PrimType -> ExtShape -> NoUniqueness -> MemReturn -> BodyReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ExtShape
shp NoUniqueness
u (Space -> Int -> ExtIxFun -> MemReturn
ReturnsNewBlock Space
space Int
k ExtIxFun
ixfun) BodyReturns -> [BodyReturns] -> [BodyReturns]
forall a. a -> [a] -> [a]
: [BodyReturns]
acc, Int
k Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
      adjustNewBlockExistential ([BodyReturns]
acc, Int
k) BodyReturns
x = (BodyReturns
x BodyReturns -> [BodyReturns] -> [BodyReturns]
forall a. a -> [a] -> [a]
: [BodyReturns]
acc, Int
k)

      inspect :: TypeBase ExtShape u
-> Maybe Space -> MemInfo (Ext SubExp) u MemReturn
inspect (Array PrimType
pt ExtShape
shape u
u) Maybe Space
space =
        let space' :: Space
space' = Space -> Maybe Space -> Space
forall a. a -> Maybe a -> a
fromMaybe Space
DefaultSpace Maybe Space
space
            bodyret :: MemInfo (Ext SubExp) u MemReturn
bodyret = PrimType
-> ExtShape -> u -> MemReturn -> MemInfo (Ext SubExp) u MemReturn
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ExtShape
shape u
u (MemReturn -> MemInfo (Ext SubExp) u MemReturn)
-> MemReturn -> MemInfo (Ext SubExp) u MemReturn
forall a b. (a -> b) -> a -> b
$ Space -> Int -> ExtIxFun -> MemReturn
ReturnsNewBlock Space
space' Int
0 (ExtIxFun -> MemReturn) -> ExtIxFun -> MemReturn
forall a b. (a -> b) -> a -> b
$
              Shape (PrimExp (Ext VName)) -> ExtIxFun
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota (Shape (PrimExp (Ext VName)) -> ExtIxFun)
-> Shape (PrimExp (Ext VName)) -> ExtIxFun
forall a b. (a -> b) -> a -> b
$ (Ext SubExp -> PrimExp (Ext VName))
-> [Ext SubExp] -> Shape (PrimExp (Ext VName))
forall a b. (a -> b) -> [a] -> [b]
map Ext SubExp -> PrimExp (Ext VName)
convert ([Ext SubExp] -> Shape (PrimExp (Ext VName)))
-> [Ext SubExp] -> Shape (PrimExp (Ext VName))
forall a b. (a -> b) -> a -> b
$ ExtShape -> [Ext SubExp]
forall d. ShapeBase d -> [d]
shapeDims ExtShape
shape
        in MemInfo (Ext SubExp) u MemReturn
bodyret
      inspect (Prim PrimType
pt) Maybe Space
_ = PrimType -> MemInfo (Ext SubExp) u MemReturn
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
pt
      inspect (Mem Space
space) Maybe Space
_ = Space -> MemInfo (Ext SubExp) u MemReturn
forall d u ret. Space -> MemInfo d u ret
MemMem Space
space

      convert :: Ext SubExp -> PrimExp (Ext VName)
convert (Ext Int
i) = Ext VName -> PrimType -> PrimExp (Ext VName)
forall v. v -> PrimType -> PrimExp v
LeafExp (Int -> Ext VName
forall a. Int -> Ext a
Ext Int
i) PrimType
int32
      convert (Free SubExp
v) = VName -> Ext VName
forall a. a -> Ext a
Free (VName -> Ext VName) -> PrimExp VName -> PrimExp (Ext VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
int32 SubExp
v

      adjustExtV :: Int -> Ext VName -> Ext VName
      adjustExtV :: Int -> Ext VName -> Ext VName
adjustExtV Int
_ (Free VName
v) = VName -> Ext VName
forall a. a -> Ext a
Free VName
v
      adjustExtV Int
k (Ext Int
i) = Int -> Ext VName
forall a. Int -> Ext a
Ext (Int
k Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i)

      adjustExtPE :: Int -> PrimExp (Ext VName) -> PrimExp (Ext VName)
      adjustExtPE :: Int -> PrimExp (Ext VName) -> PrimExp (Ext VName)
adjustExtPE Int
k = (Ext VName -> Ext VName)
-> PrimExp (Ext VName) -> PrimExp (Ext VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> Ext VName -> Ext VName
adjustExtV Int
k)

-- | For each of the last @num_vals@ results of a body, find the memory
-- space that the result lives in, if it can be determined.
--
-- The lookups are performed in the scope of the body's own statements, so
-- results bound inside the body can be inspected.  A result yields
-- @Just space@ only when it is a variable bound to an array whose memory
-- block is itself bound with 'MemMem' info.  Constants, non-array
-- variables, and arrays whose block has any other info yield 'Nothing'.
mkSpaceOks :: (ExplicitMemorish tolore, LocalScope tolore m) =>
              Int -> Body tolore -> m [Maybe Space]
mkSpaceOks num_vals (Body _ stms res) =
  inScopeOf stms $ traverse resultSpace $ takeLast num_vals res
  where -- Memory space of a single body result, if any.
        resultSpace (Var v) = do
          v_info <- lookupMemInfo v
          case v_info of
            MemArray _ _ _ (ArrayIn mem _) -> blockSpace <$> lookupMemInfo mem
            _ -> return Nothing
        resultSpace _ = return Nothing

        -- Only a binding with memory info carries a space.
        blockSpace (MemMem space) = Just space
        blockSpace _ = Nothing

-- | Translate a loop form into the target lore.
--
-- A @while@ loop carries only its condition variable and is rebuilt
-- unchanged.  In a @for@ loop, each loop-variable parameter iterating over
-- an array @a@ gets memory information:
--
-- * an array-typed parameter is placed in @a@'s memory block, with @a@'s
--   index function sliced by a 'DimFix' on the loop index @i@;
-- * primitive and memory parameters get 'MemPrim' / 'MemMem' info.
--
-- NOTE(review): the loop index is always treated as an @int32@ PrimExp
-- here, independent of the loop's declared integer type @it@. Confirm
-- this is intended.
allocInLoopForm :: (Allocable fromlore tolore,
                    Allocator tolore (AllocM fromlore tolore)) =>
                   LoopForm fromlore -> AllocM fromlore tolore (LoopForm tolore)
allocInLoopForm (WhileLoop v) = return $ WhileLoop v
allocInLoopForm (ForLoop i it n loopvars) =
  ForLoop i it n <$> traverse allocInLoopVar loopvars
  where allocInLoopVar (p, a) = do
          -- Looked up unconditionally, before the parameter type is inspected.
          (mem, ixfun) <- lookupArraySummary a
          let withAttr attr = return (p { paramAttr = attr }, a)
          case paramType p of
            Prim bt -> withAttr $ MemPrim bt
            Mem space -> withAttr $ MemMem space
            Array bt shape u -> do
              a_dims <- arrayDims <$> lookupType a
              let dims = map (primExpFromSubExp int32) a_dims
                  -- Index the outermost dimension of 'a' with the loop index.
                  row_slice = fullSliceNum dims [DimFix $ LeafExp i int32]
              withAttr $ MemArray bt shape u $ ArrayIn mem $
                IxFun.slice ixfun row_slice

-- | Allocate memory for the parameters of a binary-operator lambda used at
-- segmented level @lvl@, then allocate in the lambda's body.
--
-- The first half of the lambda's parameters are treated as one operand and
-- the second half as the other.  The product @num_groups * group_size@ is
-- bound as @num_threads@.  The first operand is indexed by the flat
-- segment index @flat@ and the second by @flat + num_threads@.  The body is
-- allocated with the in-thread expression hints installed in the
-- environment.
allocInBinOpLambda :: SegLevel -> SegSpace -> Lambda Kernels
                   -> AllocM Kernels ExplicitMemory (Lambda ExplicitMemory)
allocInBinOpLambda lvl (SegSpace flat _) lam = do
  let groups = unCount $ segNumGroups lvl
      group_size = unCount $ segGroupSize lvl
  num_threads <- letSubExp "num_threads" $
                 BasicOp $ BinOp (Mul Int32) groups group_size

  let params = lambdaParams lam
      (acc_params, arr_params) = splitAt (length params `div` 2) params
      my_index = LeafExp flat int32
      other_index = my_index + primExpFromSubExp int32 num_threads

  (acc_params', arr_params') <-
    allocInBinOpParams num_threads my_index other_index acc_params arr_params

  local (\env -> env { envExpHints = inThreadExpHints }) $
    allocInLambda (acc_params' ++ arr_params') (lambdaBody lam) (lambdaReturnType lam)

-- | Give memory information to paired binary-operator parameters.
--
-- Parameters are paired positionally with 'zipWithM', so any surplus
-- elements of the longer list are ignored.  The type of each pair is taken
-- from its @x@ parameter.
--
-- * For an array pair, a fresh block in 'DefaultSpace' is allocated for an
--   array of @num_threads * 2@ rows of @x@'s type.  @x@ is placed at row
--   @my_id@ and @y@ at row @other_id@ of that block.
-- * Primitive and memory pairs get 'MemPrim' / 'MemMem' info.
allocInBinOpParams :: SubExp
                   -> PrimExp VName -> PrimExp VName
                   -> [LParam Kernels]
                   -> [LParam Kernels]
                   -> AllocM Kernels ExplicitMemory ([LParam ExplicitMemory], [LParam ExplicitMemory])
allocInBinOpParams num_threads my_id other_id xs ys =
  unzip <$> zipWithM allocPair xs ys
  where allocPair x y =
          case paramType x of
            Prim bt -> setAttrs (MemPrim bt) (MemPrim bt)
            Mem space -> setAttrs (MemMem space) (MemMem space)
            Array bt shape u -> do
              -- Bound anew for every array parameter pair.
              twice_num_threads <-
                letSubExp "twice_num_threads" $ BasicOp $
                BinOp (Mul Int32) num_threads $ intConst Int32 2
              let t = paramType x `arrayOfRow` twice_num_threads
              mem <- allocForArray t DefaultSpace
              -- XXX: this iota ixfun is a bit inefficient; leading to uncoalesced access.
              let base_dims = map (primExpFromSubExp int32) $ arrayDims t
                  ixfun_base = IxFun.iota base_dims
                  -- Memory info for the row of the block selected by 'idx'.
                  rowAt idx = MemArray bt shape u $ ArrayIn mem $
                              IxFun.slice ixfun_base $
                              fullSliceNum base_dims [DimFix idx]
              setAttrs (rowAt my_id) (rowAt other_id)
          where -- Attach the given attributes to x and y respectively.
                setAttrs x_attr y_attr =
                  return (x { paramAttr = x_attr }, y { paramAttr = y_attr })

-- | Build a lambda in the target representation from already-annotated
-- parameters, a source body, and the return types.
--
-- The body's statements are allocated with the parameters in scope.  The
-- body's result list and the lambda's return types are carried over
-- unchanged.
allocInLambda :: [LParam ExplicitMemory] -> Body Kernels -> [Type]
              -> AllocM Kernels ExplicitMemory (Lambda ExplicitMemory)
allocInLambda params body rettype =
  mkLambda <$> localScope (scopeOfLParams params)
                          (allocInStms (bodyStms body) rebuildBody)
  where -- Reattach the original result to the allocated statements.
        rebuildBody stms' = return $ Body () stms' $ bodyResult body
        mkLambda body' = Lambda params body' rettype

-- | Allocate memory for the statements of a kernel body.
--
-- The expression hints in effect while allocating depend on the level
-- of the enclosing segmented operation: thread-level bodies use
-- 'inThreadExpHints', group-level bodies use 'inGroupExpHints'.  The
-- kernel results are passed through unchanged.
allocInKernelBody :: SegLevel -> KernelBody Kernels
                  -> AllocM Kernels ExplicitMemory (KernelBody ExplicitMemory)
allocInKernelBody lvl (KernelBody () stms res) =
  local withLevelHints $
    allocInStms stms $ \stms' ->
      return $ KernelBody () stms' res
  where withLevelHints env =
          case lvl of
            SegThread{} -> env { envExpHints = inThreadExpHints }
            SegGroup{}  -> env { envExpHints = inGroupExpHints }

-- | Operations that can report size substitutions for names bound by
-- the pattern they appear in.
class SizeSubst op where
  -- | A map from names bound in the pattern to the sizes they may be
  -- replaced by.
  opSizeSubst :: PatternT attr -> op -> ChunkMap

-- | A 'SplitSpace' size operation binding a single name maps that name
-- to its elements-per-thread size.  No other host operation yields a
-- substitution.
instance SizeSubst (HostOp lore op) where
  opSizeSubst pat op =
    case (pat, op) of
      (Pattern _ [size], SizeOp (SplitSpace _ _ _ elems_per_thread)) ->
        M.singleton (patElemName size) elems_per_thread
      _ ->
        mempty

-- | Inner operations delegate to their own instance; anything else
-- yields no substitution.
instance SizeSubst op => SizeSubst (MemOp op) where
  opSizeSubst pat op =
    case op of
      Inner inner -> opSizeSubst pat inner
      _           -> mempty

-- | The size substitutions induced by a statement.  Only statements
-- whose expression is an 'Op' can produce any; all others yield an
-- empty map.
sizeSubst :: SizeSubst (Op lore) => Stm lore -> ChunkMap
sizeSubst (Let pat _ e) =
  case e of
    Op op -> opSizeSubst pat op
    _     -> mempty

-- | Build a let-statement binding @names@ to @e@ with the given
-- expression annotation.
--
-- The memory-annotated pattern is computed in the current scope by
-- 'bindPatternWithAllocations', which also adds any allocation
-- statements the pattern needs to the current binder.  The resulting
-- statement itself is returned, not added.
mkLetNamesB' :: (Op (Lore m) ~ MemOp inner,
                 MonadBinder m, ExpAttr (Lore m) ~ (),
                 Allocator (Lore m) (PatAllocM (Lore m))) =>
                ExpAttr (Lore m) -> [VName] -> Exp (Lore m) -> m (Stm (Lore m))
mkLetNamesB' attr names e = do
  pat <- askScope >>= \scope -> bindPatternWithAllocations scope names e
  pure $ Let pat (defAux attr) e

-- | Like 'mkLetNamesB'', but for a lore annotated with simplifier
-- wisdom.
--
-- Wisdom is stripped from the scope and the expression before the
-- memory-annotated pattern is computed.  'bindPatternWithAllocations'
-- computes that pattern and adds any allocation statements it needs to
-- the current binder.  Wisdom is then re-attached to the pattern and
-- to the expression annotation of the returned statement.  The
-- statement itself is returned, not added.
mkLetNamesB'' :: (Op (Lore m) ~ MemOp inner, ExpAttr lore ~ (),
                   HasScope (Engine.Wise lore) m, Allocator lore (PatAllocM lore),
                   MonadBinder m, Engine.CanBeWise (Op lore)) =>
                 [VName] -> Exp (Engine.Wise lore)
              -> m (Stm (Engine.Wise lore))
mkLetNamesB'' names e = do
  scope <- Engine.removeScopeWisdom <$> askScope
  -- Reuse the shared helper rather than repeating its
  -- runPatAllocM/bindAllocStm sequence here.
  pat <- bindPatternWithAllocations scope names $ Engine.removeExpWisdom e
  let pat' = Engine.addWisdomToPattern pat e
      attr = Engine.mkWiseExpAttr pat' () e
  return $ Let pat' (defAux attr) e

-- | Binder operations for 'ExplicitMemory'.  Expression and body
-- annotations are unit.  Let-statements are built by 'mkLetNamesB'',
-- which also binds the allocations their patterns require.
instance BinderOps ExplicitMemory where
  mkExpAttrB _ _ = pure ()
  mkBodyB stms = pure . Body () stms
  mkLetNamesB = mkLetNamesB' ()

-- | Binder operations for 'ExplicitMemory' annotated with simplifier
-- wisdom.  The underlying expression and body annotations are unit and
-- are wrapped with wisdom.  Let-statements are built by
-- 'mkLetNamesB'''.
instance BinderOps (Engine.Wise ExplicitMemory) where
  mkExpAttrB pat = pure . Engine.mkWiseExpAttr pat ()
  mkBodyB stms = pure . Engine.mkWiseBody () stms
  mkLetNamesB = mkLetNamesB''

-- | Simplifier operations for a lore whose 'Op' is 'MemOp', given a
-- simplifier for the inner operation.
--
-- * Expression and body annotations are the unit annotation wrapped
--   with wisdom.
-- * Let-statements produced by the simplifier get a memory-annotated
--   pattern from 'bindPatternWithAllocations'.  The allocation
--   statements that requires are returned alongside the statement.
-- * An 'Alloc' protected under a condition @taken@ has its size
--   replaced by @if taken then size else 0@ (an @i64@).  Other
--   operations cannot be protected.
-- * Simplifying an 'Alloc' simplifies only its size and hoists
--   nothing.  'Inner' operations go to the provided simplifier.
simplifiable :: (Engine.SimplifiableLore lore,
                 ExpAttr lore ~ (),
                 BodyAttr lore ~ (),
                 Op lore ~ MemOp inner,
                 Allocator lore (PatAllocM lore)) =>
                (inner -> Engine.SimpleM lore (Engine.OpWithWisdom inner, Stms (Engine.Wise lore)))
             -> SimpleOps lore
simplifiable simplifyInnerOp =
  SimpleOps mkExpAttrS' mkBodyS' mkLetNamesS' protectOp simplifyOp
  where mkExpAttrS' _ pat e =
          pure $ Engine.mkWiseExpAttr pat () e

        mkBodyS' _ stms res =
          pure $ mkWiseBody () stms res

        mkLetNamesS' vtable names e = do
          let scope = removeScopeWisdom $ ST.toScope vtable
              plain_e = removeExpWisdom e
          (pat, alloc_stms) <-
            runBinder $ bindPatternWithAllocations scope names plain_e
          pure (mkWiseLetStm pat (defAux ()) e, alloc_stms)

        protectOp taken pat (Alloc size space) = Just $ do
          size_if_taken <- resultBodyM [size]
          size_otherwise <- resultBodyM [intConst Int64 0]
          guarded_size <-
            letSubExp "hoisted_alloc_size" $
            If taken size_if_taken size_otherwise $
            IfAttr [MemPrim int64] IfFallback
          letBind_ pat $ Op $ Alloc guarded_size space
        protectOp _ _ _ = Nothing

        simplifyOp op =
          case op of
            Alloc size space -> do
              size' <- Engine.simplify size
              pure (Alloc size' space, mempty)
            Inner k -> do
              (k', hoisted) <- simplifyInnerOp k
              pure (Inner k', hoisted)

-- | Compute a memory-annotated pattern for binding @names@ to @e@,
-- evaluated in the given scope.
--
-- The allocation statements the pattern requires are added to the
-- current binder, in order, via 'bindAllocStm'.  The pattern is then
-- returned.
bindPatternWithAllocations :: (MonadBinder m,
                               ExpAttr lore ~ (),
                               Op (Lore m) ~ MemOp inner,
                               Allocator lore (PatAllocM lore)) =>
                              Scope lore -> [VName] -> Exp lore
                           -> m (Pattern lore)
bindPatternWithAllocations types names e = do
  (pat, alloc_stms) <- runPatAllocM (patternWithAllocations names e) types
  pat <$ mapM_ bindAllocStm alloc_stms

-- | A hint for how the result of an expression should be placed in
-- memory.
data ExpHint
  = NoHint
    -- ^ No preference is expressed.
  | Hint IxFun Space
    -- ^ Prefer the given index function in the given memory space.

-- | Memory hints for the results of an expression inside a kernel.
--
-- * A 'Manifest' of @v@ with permutation @perm@ is hinted with an iota
--   index function over the permuted dimensions of @v@, then permuted
--   by the inverse of @perm@, in 'DefaultSpace'.
-- * For thread-level 'SegMap' and 'SegRed', each mapped result is
--   hinted by 'mapResultHint'.  The reduction results of a 'SegRed'
--   get 'NoHint'.
-- * Every other expression gets 'NoHint' for each of its results.
kernelExpHints :: Allocator ExplicitMemory m => Exp ExplicitMemory -> m [ExpHint]
kernelExpHints (BasicOp (Manifest perm v)) = do
  v_t <- lookupType v
  let permuted_dims = rearrangeShape perm $ arrayDims v_t
      iota_ixfun = IxFun.iota $ map (primExpFromSubExp int32) permuted_dims
      ixfun = IxFun.permute iota_ixfun $ rearrangeInverse perm
  return [Hint ixfun DefaultSpace]

kernelExpHints (Op (Inner (SegOp (SegMap lvl@SegThread{} space ts body)))) =
  zipWithM (mapResultHint lvl space) ts $ kernelBodyResult body

kernelExpHints (Op (Inner (SegOp (SegRed lvl@SegThread{} space reds ts body)))) = do
  map_hints <- zipWithM (mapResultHint lvl space) map_ts map_res
  return $ map (const NoHint) red_res ++ map_hints
  where num_reds = segRedResults reds
        map_ts = drop num_reds ts
        (red_res, map_res) = splitAt num_reds $ kernelBodyResult body

kernelExpHints e =
  return $ replicate (expExtTypeSize e) NoHint

-- | Compute a memory-layout hint for one result of a segmented kernel
-- operation, given its segment level @lvl@ and iteration space @space@.
--
-- * A 'Returns' result whose shape passes the size heuristic
--   (@coalesceReturnOfShape@) gets an index function from 'innermost',
--   with the segment-space dimensions innermost in memory.
-- * A 'ConcatReturns' result with 'SplitStrided' gets an index function
--   from 'innermost' with the split width @w@ innermost.
-- * A primitive-typed 'ConcatReturns' result with 'SplitContiguous' gets
--   a transposed @[num_threads][elems_per_thread]@ iota, reshaped to @[w]@.
-- * Every other result (including 'Returns' results rejected by the
--   heuristic) gets 'NoHint'.
mapResultHint :: Allocator lore m =>
                 SegLevel -> SegSpace -> Type -> KernelResult -> m ExpHint
mapResultHint lvl space = hint
  where num_threads = primExpFromSubExp int32 (unCount $ segNumGroups lvl) *
                      primExpFromSubExp int32 (unCount $ segGroupSize lvl)

        -- Heuristic: do not rearrange for returned arrays that are
        -- sufficiently small.  A one-dimensional array of constant size
        -- @d@ whose elements are @bs@ bytes is rearranged only if it
        -- occupies more than 4 bytes.  The product is computed in
        -- 'Integer': in 'Int32', @bs * d@ wraps around for large constant
        -- dimensions, which would misclassify a huge array as small.
        coalesceReturnOfShape :: Integer -> [SubExp] -> Bool
        coalesceReturnOfShape _ [] = False
        coalesceReturnOfShape bs [Constant (IntValue (Int32Value d))] =
          bs * toInteger d > 4
        coalesceReturnOfShape _ _ = True

        hint t Returns{}
          | coalesceReturnOfShape (primByteSize (elemType t)) $ arrayDims t = do
              let space_dims = segSpaceDims space
              t_dims <- mapM dimAllocationSize $ arrayDims t
              return $ Hint (innermost space_dims t_dims) DefaultSpace

        hint t (ConcatReturns SplitStrided{} w _ _) = do
          t_dims <- mapM dimAllocationSize $ arrayDims t
          return $ Hint (innermost [w] t_dims) DefaultSpace

        hint Prim{} (ConcatReturns SplitContiguous w elems_per_thread _) = do
          let ixfun_base = IxFun.iota [num_threads,
                                       primExpFromSubExp int32 elems_per_thread]
              ixfun_tr = IxFun.permute ixfun_base [1,0]
              ixfun = IxFun.reshape ixfun_tr $
                      map (DimNew . primExpFromSubExp int32) [w]
          return $ Hint ixfun DefaultSpace

        hint _ _ = return NoHint

-- | An index function for an array of logical shape
-- @space_dims ++ t_dims@.  The underlying memory is a row-major iota over
-- the rearranged shape @t_dims ++ space_dims@, so the @space_dims@ are the
-- innermost (fastest-varying) dimensions in memory.  The inverse
-- permutation restores the logical dimension order.
innermost :: [SubExp] -> [SubExp] -> IxFun
innermost space_dims t_dims =
  IxFun.permute (IxFun.iota physical_shape) (rearrangeInverse perm)
  where n_space = length space_dims
        n_t = length t_dims
        -- Brings the t_dims (positions n_space .. n_space+n_t-1) to the
        -- front, followed by the space_dims (positions 0 .. n_space-1).
        perm = map (+ n_space) [0 .. n_t - 1] ++ [0 .. n_space - 1]
        physical_shape = map (primExpFromSubExp int32) $
                         rearrangeShape perm (space_dims ++ t_dims)

-- | Memory hints for expressions allocated inside a group.
--
-- For a 'SegMap' with at least one 'ResultPrivate' result:
--
-- * each private result gets an iota index function over the segment-space
--   dimensions followed by the result's own dimensions, placed in
--   'ScalarSpace' (keyed by the result's dimensions and element type);
-- * every other result gets 'NoHint'.
--
-- Any other expression gets one 'NoHint' per result.
inGroupExpHints :: Exp ExplicitMemory -> AllocM Kernels ExplicitMemory [ExpHint]
inGroupExpHints (Op (Inner (SegOp (SegMap _ space ts body))))
  | any isPrivate results = pure $ zipWith hintFor ts results
  where
    results = kernelBodyResult body

    isPrivate (Returns ResultPrivate _) = True
    isPrivate _ = False

    hintFor t r
      | isPrivate r =
          let full_dims = map (primExpFromSubExp int32) $
                          segSpaceDims space ++ arrayDims t
              scalar_space = ScalarSpace (arrayDims t) (elemType t)
          in Hint (IxFun.iota full_dims) scalar_space
      | otherwise = NoHint
inGroupExpHints e = pure $ replicate (expExtTypeSize e) NoHint

-- | Memory hints for expressions allocated inside a single thread.
--
-- A result whose type is an array with a fully static shape, where every
-- dimension is a 'Constant', gets an iota index function over that shape,
-- placed in 'ScalarSpace' (keyed by the shape and element type).  All
-- other results get 'NoHint'.
inThreadExpHints :: Allocator ExplicitMemory m => Exp ExplicitMemory -> m [ExpHint]
inThreadExpHints e = map privateHint <$> expExtType e
  where
    privateHint t
      | Just (Array pt shape _) <- hasStaticShape t
      , let dims = shapeDims shape
      , all isConstant dims
      = Hint (IxFun.iota $ map (primExpFromSubExp int32) dims) $
        ScalarSpace dims pt
      | otherwise = NoHint

    isConstant Constant{} = True
    isConstant _ = False