{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Building blocks for defining representations where every array
-- is given information about which memory block it is based in, and
-- how array elements map to memory block offsets.
--
-- There are two primary concepts you will need to understand:
--
--  1. Memory blocks, which are Futhark values of type v'Mem'
--     (parametrized with their size).  These correspond to arbitrary
--     blocks of memory, and are created using the 'Alloc' operation.
--
--  2. Index functions, which describe a mapping from the index space
--     of an array (e.g. a two-dimensional space for an array of type
--     @[[int]]@) to a one-dimensional offset into a memory block.
--     Thus, index functions describe how arbitrary-dimensional arrays
--     are mapped to the single-dimensional world of memory.
--
-- At a conceptual level, imagine that we have a two-dimensional array
-- @a@ of 32-bit integers, consisting of @n@ rows of @m@ elements
-- each.  This array could be represented in classic row-major format
-- with an index function like the following:
--
-- @
--   f(i,j) = i * m + j
-- @
--
-- When we want to know the location of element @a[2,3]@, we simply
-- call the index function as @f(2,3)@ and obtain @2*m+3@.  We could
-- also have chosen another index function, one that represents the
-- array in column-major (or "transposed") format:
--
-- @
--   f(i,j) = j * n + i
-- @
--
-- Index functions are not Futhark-level functions, but a special
-- construct that the final code generator will eventually use to
-- generate concrete access code.  By modifying the index functions we
-- can change how an array is represented in memory, which can permit
-- memory access pattern optimisations.
--
-- Every time we bind an array, whether in a @let@-binding, @loop@
-- merge parameter, or @lambda@ parameter, we have an annotation
-- specifying a memory block and an index function.  In some cases,
-- such as @let@-bindings for many expressions, we are free to specify
-- an arbitrary index function and memory block - for example, we get
-- to decide where 'Copy' stores its result - but in other cases the
-- type rules of the expression choose for us.  For example, 'Index'
-- always produces an array in the same memory block as its input, and
-- with the same index function, except with some indices fixed.
module Futhark.IR.Mem
  ( LetDecMem,
    FParamMem,
    LParamMem,
    RetTypeMem,
    BranchTypeMem,
    MemOp (..),
    traverseMemOpStms,
    MemInfo (..),
    MemBound,
    MemBind (..),
    MemReturn (..),
    IxFun,
    ExtIxFun,
    isStaticIxFun,
    ExpReturns,
    BodyReturns,
    FunReturns,
    noUniquenessReturns,
    bodyReturnsToExpReturns,
    Mem,
    HasLetDecMem (..),
    OpReturns (..),
    varReturns,
    expReturns,
    extReturns,
    lookupMemInfo,
    subExpMemInfo,
    lookupArraySummary,
    lookupMemSpace,
    existentialiseIxFun,

    -- * Type checking parts
    matchBranchReturnType,
    matchPatToExp,
    matchFunctionReturnType,
    matchLoopResultMem,
    bodyReturnsFromPat,
    checkMemInfo,

    -- * Module re-exports
    module Futhark.IR.Prop,
    module Futhark.IR.Traversals,
    module Futhark.IR.Pretty,
    module Futhark.IR.Syntax,
    module Futhark.Analysis.PrimExp.Convert,
  )
where

import Control.Category
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Foldable (traverse_)
import Data.Function ((&))
import Data.List (elemIndex, find)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.PrimExp.Simplify
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.IR.Aliases
  ( Aliases,
    removeExpAliases,
    removePatAliases,
    removeScopeAliases,
  )
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Pretty
import Futhark.IR.Prop
import Futhark.IR.Prop.Aliases
import Futhark.IR.Syntax
import Futhark.IR.Traversals
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util
import Futhark.Util.Pretty (docText, indent, ppTuple', pretty, (<+>), (</>))
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))

-- | Memory information attached to let-bound names.  It carries no
-- uniqueness and places arrays with a 'MemBind'.
type LetDecMem = MemInfo SubExp NoUniqueness MemBind

-- | Memory information for function parameters.  It is the same as
-- 'LetDecMem', except that it also records 'Uniqueness'.
type FParamMem = MemInfo SubExp Uniqueness MemBind

-- | Memory information for lambda parameters (same shape as 'LetDecMem').
type LParamMem = MemInfo SubExp NoUniqueness MemBind

-- | Return-type decoration of functions in memory representations.
type RetTypeMem = FunReturns

-- | Return-type decoration of branches in memory representations.
type BranchTypeMem = BodyReturns

-- | Pattern-element decorations from which the memory information
-- ('LetDecMem') can be projected.
class HasLetDecMem dec where
  -- | Extract the memory information carried by a decoration.
  letDecMem :: dec -> LetDecMem

-- | A 'LetDecMem' is its own memory information.
instance HasLetDecMem LetDecMem where
  letDecMem = id

-- | A pair takes its memory information from its second component.
-- The first component is ignored.
instance HasLetDecMem b => HasLetDecMem (a, b) where
  letDecMem = letDecMem . snd

-- | The constraints that make @rep@ a memory representation whose
-- operations are 'MemOp's wrapping @inner@:
--
--  * parameter, return and branch decorations are the memory-annotated
--    ones defined in this module;
--  * let-decorations expose their memory information via 'HasLetDecMem';
--  * the inner operation can report its memory returns ('OpReturns').
type Mem rep inner =
  ( ASTRep rep,
    HasLetDecMem (LetDec rep),
    FParamInfo rep ~ FParamMem,
    LParamInfo rep ~ LParamMem,
    RetType rep ~ RetTypeMem,
    BranchType rep ~ BranchTypeMem,
    Op rep ~ MemOp inner,
    OpReturns inner
  )

-- | Function return types for memory representations.
instance IsRetType FunReturns where
  -- A primitive return carries no memory information.
  primRetType = MemPrim

  -- Instantiating a return type against actual arguments is delegated
  -- to 'applyFunReturns'.
  applyRetType = applyFunReturns

-- | Branch return types for memory representations.
instance IsBodyType BodyReturns where
  -- A primitive result carries no memory information.
  primBodyType = MemPrim

-- | Operations of a memory representation: a memory allocation, or an
-- operation of the wrapped @inner@ representation.
data MemOp inner
  = -- | Allocate a memory block of the given size in the given space.
    Alloc SubExp Space
  | -- | An operation of the wrapped representation.
    Inner inner
  deriving (Eq, Ord, Show)

-- | A helper for defining 'TraverseOpStms'.
--
-- An 'Alloc' contains no statements, so it is returned unchanged.
-- An 'Inner' operation is traversed with the supplied traverser for
-- the wrapped operation type.
traverseMemOpStms ::
  Monad m =>
  OpStmsTraverser m inner rep ->
  OpStmsTraverser m (MemOp inner) rep
traverseMemOpStms _ _ op@Alloc {} = pure op
traverseMemOpStms onInner f (Inner inner) = Inner <$> onInner f inner

-- | Free variables of a memory operation.
instance FreeIn inner => FreeIn (MemOp inner) where
  -- Only the allocation size is inspected; the 'Space' is not.
  -- NOTE(review): the 'MemInfo' instance does traverse 'Space' for
  -- 'MemMem'; confirm that an allocation's space can never mention a
  -- variable.
  freeIn' (Alloc size _) = freeIn' size
  freeIn' (Inner k) = freeIn' k

-- | Result types of memory operations.
instance TypedOp inner => TypedOp (MemOp inner) where
  -- An allocation yields exactly one memory block in its space.
  opType (Alloc _ space) = pure [Mem space]
  opType (Inner k) = opType k

-- | Aliasing of memory operations.
instance AliasedOp inner => AliasedOp (MemOp inner) where
  -- An allocation has a single result, and that result aliases nothing.
  opAliases Alloc {} = [mempty]
  opAliases (Inner k) = opAliases k

  -- An allocation consumes nothing.
  consumedInOp Alloc {} = mempty
  consumedInOp (Inner k) = consumedInOp k

-- | Adding and removing alias information on memory operations.  Only
-- the inner operation carries alias information.
instance CanBeAliased inner => CanBeAliased (MemOp inner) where
  type OpWithAliases (MemOp inner) = MemOp (OpWithAliases inner)

  -- 'Alloc' is rebuilt rather than passed through, because the input
  -- and result have different types (different @inner@ parameter).
  removeOpAliases (Alloc se space) = Alloc se space
  removeOpAliases (Inner k) = Inner $ removeOpAliases k

  addOpAliases _ (Alloc se space) = Alloc se space
  addOpAliases aliases (Inner k) = Inner $ addOpAliases aliases k

-- | Renaming of memory operations.
instance Rename inner => Rename (MemOp inner) where
  -- Only the allocation size is renamed; the space is kept as-is.
  rename (Alloc size space) = Alloc <$> rename size <*> pure space
  rename (Inner k) = Inner <$> rename k

-- | Name substitution in memory operations.
instance Substitute inner => Substitute (MemOp inner) where
  -- Only the allocation size is substituted; the space is kept as-is.
  substituteNames subst (Alloc size space) =
    Alloc (substituteNames subst size) space
  substituteNames subst (Inner k) =
    Inner $ substituteNames subst k

-- | Pretty-printing of memory operations.  An allocation in the
-- default space omits the space argument.
instance PP.Pretty inner => PP.Pretty (MemOp inner) where
  pretty (Alloc e DefaultSpace) = "alloc" <> PP.apply [PP.pretty e]
  pretty (Alloc e s) = "alloc" <> PP.apply [PP.pretty e, PP.pretty s]
  pretty (Inner k) = PP.pretty k

-- | Metrics for memory operations.  Every allocation is counted under
-- the label @Alloc@.
instance OpMetrics inner => OpMetrics (MemOp inner) where
  opMetrics Alloc {} = seen "Alloc"
  opMetrics (Inner k) = opMetrics k

-- | Safety and cost classification of memory operations.
instance IsOp inner => IsOp (MemOp inner) where
  -- An allocation is considered safe only when its size is a known
  -- 64-bit constant that is non-negative.  Any other size (a variable,
  -- or a constant of another type) is treated as unsafe.
  safeOp (Alloc (Constant (IntValue (Int64Value k))) _) = k >= 0
  safeOp Alloc {} = False
  safeOp (Inner k) = safeOp k

  -- Allocations are always classified as cheap.
  cheapOp (Inner k) = cheapOp k
  cheapOp Alloc {} = True

-- | Adding and removing simplifier wisdom on memory operations.  Only
-- the inner operation carries wisdom.
instance CanBeWise inner => CanBeWise (MemOp inner) where
  type OpWithWisdom (MemOp inner) = MemOp (OpWithWisdom inner)

  -- 'Alloc' is rebuilt because the input and result types differ.
  removeOpWisdom (Alloc size space) = Alloc size space
  removeOpWisdom (Inner k) = Inner $ removeOpWisdom k

  addOpWisdom (Alloc size space) = Alloc size space
  addOpWisdom (Inner k) = Inner $ addOpWisdom k

-- | Symbol-table indexing into the results of memory operations.
instance ST.IndexOp inner => ST.IndexOp (MemOp inner) where
  -- Delegate to the wrapped operation.  Allocations provide no
  -- indexing information.
  indexOp vtable k (Inner op) is = ST.indexOp vtable k op is
  indexOp _ _ _ _ = Nothing

-- | The index functions used in memory annotations.  Their scalar
-- parts are 64-bit integer expressions over ordinary variable names.
type IxFun = IxFun.IxFun (TPrimExp Int64 VName)

-- | An index function whose expressions range over 'Ext' variables.
-- Such variables may be existential ('Ext') as well as ordinary
-- names ('Free').
type ExtIxFun = IxFun.IxFun (TPrimExp Int64 (Ext VName))

-- | A summary of the memory information for every let-bound
-- identifier, function parameter, and return value.  Parameterised
-- over the dimension type @d@, the uniqueness annotation @u@, and the
-- auxiliary array information @ret@ (e.g. 'MemBind').
data MemInfo d u ret
  = -- | A primitive value.
    MemPrim PrimType
  | -- | A memory block.
    MemMem Space
  | -- | The array is stored in the named memory block, and with the
    -- given index function.  The index function maps indices in the
    -- array to /element/ offsets, /not/ byte offsets!  To translate to
    -- byte offsets, multiply the offset with the size of the array
    -- element type.
    MemArray PrimType (ShapeBase d) u ret
  | -- | An accumulator, which is not stored anywhere.
    MemAcc VName Shape [Type] u
  deriving (Eq, Show, Ord) -- XXX Ord?

-- | Memory information for a bound name.  Sizes are 'SubExp's, arrays
-- are placed with a 'MemBind', and the uniqueness annotation @u@ is
-- left open.
type MemBound u = MemInfo SubExp u MemBind

-- | The (possibly existential) declared type of memory information.
-- The array placement (@ret@) does not contribute to the type.
instance FixExt ret => DeclExtTyped (MemInfo ExtSize Uniqueness ret) where
  declExtTypeOf (MemPrim pt) = Prim pt
  declExtTypeOf (MemMem space) = Mem space
  declExtTypeOf (MemArray pt shape u _) = Array pt shape u
  declExtTypeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

-- | The (possibly existential) type of memory information.  The array
-- placement (@ret@) does not contribute to the type.
instance FixExt ret => ExtTyped (MemInfo ExtSize NoUniqueness ret) where
  extTypeOf (MemPrim pt) = Prim pt
  extTypeOf (MemMem space) = Mem space
  extTypeOf (MemArray pt shape u _) = Array pt shape u
  extTypeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

-- | Replace existential index @i@ with a concrete 'SubExp'.
instance FixExt ret => FixExt (MemInfo ExtSize u ret) where
  fixExt _ _ (MemPrim pt) = MemPrim pt
  fixExt _ _ (MemMem space) = MemMem space
  -- Only arrays can mention existentials, both in their shape and in
  -- their placement information.
  fixExt i se (MemArray pt shape u ret) =
    MemArray pt (fixExt i se shape) u (fixExt i se ret)
  fixExt _ _ (MemAcc acc ispace ts u) = MemAcc acc ispace ts u

-- | The type of unique-annotated memory information, obtained by
-- dropping uniqueness from its declared type.
instance Typed (MemInfo SubExp Uniqueness ret) where
  typeOf = fromDecl . declTypeOf

-- | The type of memory information.  The array placement (@ret@) does
-- not contribute to the type.
instance Typed (MemInfo SubExp NoUniqueness ret) where
  typeOf (MemPrim pt) = Prim pt
  typeOf (MemMem space) = Mem space
  typeOf (MemArray bt shape u _) = Array bt shape u
  typeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

-- | The declared (uniqueness-annotated) type of memory information.
-- The array placement (@ret@) does not contribute to the type.
instance DeclTyped (MemInfo SubExp Uniqueness ret) where
  declTypeOf (MemPrim bt) = Prim bt
  declTypeOf (MemMem space) = Mem space
  declTypeOf (MemArray bt shape u _) = Array bt shape u
  declTypeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

-- | Free variables of memory information.  The uniqueness annotation
-- contributes none.
instance (FreeIn d, FreeIn ret) => FreeIn (MemInfo d u ret) where
  freeIn' (MemArray _ shape _ ret) = freeIn' shape <> freeIn' ret
  freeIn' (MemMem s) = freeIn' s
  freeIn' MemPrim {} = mempty
  freeIn' (MemAcc acc ispace ts _) = freeIn' (acc, ispace, ts)

-- | Name substitution in memory information.  The uniqueness annotation
-- is left untouched.
instance (Substitute d, Substitute ret) => Substitute (MemInfo d u ret) where
  substituteNames substs (MemArray bt shape u ret) =
    MemArray
      bt
      (substituteNames substs shape)
      u
      (substituteNames substs ret)
  substituteNames substs (MemAcc acc ispace ts u) =
    MemAcc
      (substituteNames substs acc)
      (substituteNames substs ispace)
      (substituteNames substs ts)
      u
  -- NOTE(review): the 'FreeIn' instance reports free variables of the
  -- 'Space', but the space is not substituted here.  Confirm that this
  -- asymmetry is intended.
  substituteNames _ (MemMem space) =
    MemMem space
  substituteNames _ (MemPrim bt) =
    MemPrim bt

-- | Renaming of memory information, implemented in terms of its
-- 'Substitute' instance.
instance (Substitute d, Substitute ret) => Rename (MemInfo d u ret) where
  rename = substituteRename

-- | Simplify every scalar expression inside an index function using
-- 'simplifyPrimExp'.
simplifyIxFun ::
  Engine.SimplifiableRep rep =>
  IxFun ->
  Engine.SimpleM rep IxFun
simplifyIxFun =
  -- 'simplifyPrimExp' works on untyped expressions, so strip the Int64
  -- tag with 'untyped' and re-attach it with 'isInt64' afterwards.
  traverse $ fmap isInt64 . simplifyPrimExp . untyped

-- | Simplify every scalar expression inside an index function that may
-- mention existentials, using 'simplifyExtPrimExp'.
simplifyExtIxFun ::
  Engine.SimplifiableRep rep =>
  ExtIxFun ->
  Engine.SimpleM rep ExtIxFun
simplifyExtIxFun =
  -- Strip the Int64 tag for the untyped simplifier, then re-attach it.
  traverse $ fmap isInt64 . simplifyExtPrimExp . untyped

-- | Turn an 'ExtIxFun' into an ordinary 'IxFun' when it mentions no
-- existentials.  Returns 'Nothing' as soon as any 'Ext' variable is
-- found in any of its expressions; otherwise every 'Free' name is
-- unwrapped.
isStaticIxFun :: ExtIxFun -> Maybe IxFun
isStaticIxFun = traverse $ traverse inst
  where
    inst :: Ext a -> Maybe a
    inst Ext {} = Nothing
    inst (Free x) = Just x

-- | Simplification of memory information.  Primitive and memory-block
-- annotations are returned unchanged.  For arrays and accumulators
-- every component except the uniqueness annotation is simplified.
instance
  (Engine.Simplifiable d, Engine.Simplifiable ret) =>
  Engine.Simplifiable (MemInfo d u ret)
  where
  simplify (MemPrim bt) =
    pure $ MemPrim bt
  simplify (MemMem space) =
    pure $ MemMem space
  simplify (MemArray bt shape u ret) =
    MemArray bt
      <$> Engine.simplify shape
      <*> pure u
      <*> Engine.simplify ret
  simplify (MemAcc acc ispace ts u) =
    MemAcc
      <$> Engine.simplify acc
      <*> Engine.simplify ispace
      <*> Engine.simplify ts
      <*> pure u

-- | Pretty-printing of memory information.  Arrays are printed as their
-- type followed by @\@@ and their placement information.  Memory
-- blocks in the default space omit the space.
instance
  ( PP.Pretty (ShapeBase d),
    PP.Pretty (TypeBase (ShapeBase d) u),
    PP.Pretty d,
    PP.Pretty u,
    PP.Pretty ret
  ) =>
  PP.Pretty (MemInfo d u ret)
  where
  pretty (MemPrim bt) = PP.pretty bt
  pretty (MemMem DefaultSpace) = "mem"
  pretty (MemMem s) = "mem" <> PP.pretty s
  pretty (MemArray bt shape u ret) =
    PP.pretty (Array bt shape u) <+> "@" <+> PP.pretty ret
  -- The uniqueness annotation is printed in front, and the accumulator
  -- type itself is then printed without uniqueness.
  pretty (MemAcc acc ispace ts u) =
    PP.pretty u <> PP.pretty (Acc acc ispace ts NoUniqueness :: Type)

-- | Memory information for an array bound somewhere in the program.
data MemBind
  = -- | Located in this memory block with this index
    -- function.
    ArrayIn VName IxFun
  deriving (Show)

-- | All 'MemBind's compare equal, so placement information never
-- distinguishes two otherwise-equal 'MemInfo' values.
-- NOTE(review): this looks intentional (memory placement ignored in
-- equality); confirm before relying on structural equality here.
instance Eq MemBind where
  _ == _ = True

-- | Consistent with the 'Eq' instance: all 'MemBind's are 'EQ'.
instance Ord MemBind where
  _ `compare` _ = EQ

-- | Renaming of a 'MemBind', implemented via its 'Substitute' instance.
instance Rename MemBind where
  rename = substituteRename

-- | Substitution applies to both the memory block name and the index
-- function.
instance Substitute MemBind where
  substituteNames substs (ArrayIn ident ixfun) =
    ArrayIn (substituteNames substs ident) (substituteNames substs ixfun)

-- | Printed as @mem -> ixfun@.
instance PP.Pretty MemBind where
  pretty membind = case membind of
    ArrayIn mem ixfun -> PP.pretty mem <+> "->" PP.</> PP.pretty ixfun

-- | The free variables are those of the memory block name and of the
-- index function.
instance FreeIn MemBind where
  freeIn' (ArrayIn mem ixfun) = mconcat [freeIn' mem, freeIn' ixfun]

-- | A description of the memory properties of an array being returned
-- by an operation.  Index functions are existential ('ExtIxFun'), so
-- they may refer to values that are only known once the operation has
-- produced its result.
data MemReturn
  = -- | The array is located in a memory block that is
    -- already in scope.
    ReturnsInBlock VName ExtIxFun
  | -- | The operation returns a new (existential) memory
    -- block.  The 'Int' is the position of the existential (context)
    -- value that holds the block, and 'Space' is its memory space.
    ReturnsNewBlock Space Int ExtIxFun
  deriving (Show)

-- | Every pair of 'MemReturn's compares equal, so comparing return
-- annotations that contain them ignores the memory information.
instance Eq MemReturn where
  (==) _ _ = True

-- | Consistent with the 'Eq' instance: all 'MemReturn's are 'EQ'.
instance Ord MemReturn where
  compare _ _ = EQ

-- | Renaming is implemented via the 'Substitute' instance.
instance Rename MemReturn where
  rename = substituteRename

-- | Substitution applies to the memory block name (when present) and to
-- the index function; the memory space and existential index of a new
-- block are left untouched.
instance Substitute MemReturn where
  substituteNames substs ret =
    case ret of
      ReturnsInBlock mem ixfun ->
        ReturnsInBlock (subst mem) (subst ixfun)
      ReturnsNewBlock space i ixfun ->
        ReturnsNewBlock space i (subst ixfun)
    where
      subst :: Substitute a => a -> a
      subst = substituteNames substs

-- | Fix existential number @i@ to the value @se@.
--
-- If the return is a new block whose existential index is exactly @i@
-- and @se@ is a variable, the array is now known to live in that
-- variable's block.  Otherwise a new block keeps its space, and its
-- existential index is shifted down by one if it was above @i@.  In
-- every case @i@ is also fixed to @se@ in the index function.
instance FixExt MemReturn where
  fixExt i se ret = case (se, ret) of
    (Var v, ReturnsNewBlock _ j ixfun)
      | j == i ->
          ReturnsInBlock v (fixIxFun ixfun)
    (_, ReturnsNewBlock space j ixfun) ->
      ReturnsNewBlock space (if i < j then j - 1 else j) (fixIxFun ixfun)
    (_, ReturnsInBlock mem ixfun) ->
      ReturnsInBlock mem (fixIxFun ixfun)
    where
      -- When the first alternative matches, @se@ is @Var v@, so this is
      -- the same expression the variable case needs.
      fixIxFun = fixExtIxFun i (primExpFromSubExp int64 se)

-- | Fix existential number @i@ of an index function to the expression
-- @e@.  Every leaf @'Ext' i@ is replaced by @e@ (its names wrapped in
-- 'Free'), and existentials numbered above @i@ are renumbered down by
-- one to close the gap.  Free leaves and lower existentials are kept.
fixExtIxFun :: Int -> PrimExp VName -> ExtIxFun -> ExtIxFun
fixExtIxFun i e = fmap (isInt64 . replaceInPrimExp substLeaf . untyped)
  where
    substLeaf :: Ext VName -> PrimType -> PrimExp (Ext VName)
    substLeaf (Free x) t = LeafExp (Free x) t
    substLeaf (Ext j) t =
      case compare j i of
        GT -> LeafExp (Ext (j - 1)) t
        EQ -> Free <$> e
        LT -> LeafExp (Ext j) t

-- | A 64-bit integer leaf expression referring to existential number @i@.
leafExp :: Int -> TPrimExp Int64 (Ext a)
leafExp i = isInt64 (LeafExp (Ext i) int64)

-- | Turn an index function over ordinary names into an existential
-- index function.  Each name in @ctx@ is substituted by @'Ext' k@,
-- where @k@ is its position in @ctx@.  All other names are wrapped in
-- 'Free'.
--
-- If a name occurs more than once in @ctx@, its /first/ position is
-- used.  This agrees with 'elemIndex', which 'matchLoopResultMem' uses
-- to number the same names, and with 'getExtMaps'.
existentialiseIxFun :: [VName] -> IxFun -> ExtIxFun
existentialiseIxFun ctx = IxFun.substituteInIxFun ctx' . fmap (fmap Free)
  where
    ctx' :: M.Map (Ext VName) (TPrimExp Int64 (Ext VName))
    ctx' =
      M.map leafExp $
        -- 'M.fromListWith' calls its function as @f new old@; returning
        -- @old@ keeps the first index at which each name occurs.
        M.fromListWith (\_ old -> old) $
          zip (map Free ctx) [0 ..]

-- | An array in an existing block is printed as @(mem -> ixfun)@; an
-- array in a new block is printed as @?i\<space\> -> ixfun@.
instance PP.Pretty MemReturn where
  pretty ret = case ret of
    ReturnsInBlock v ixfun ->
      PP.parens $ pretty v <+> "->" PP.</> PP.pretty ixfun
    ReturnsNewBlock space i ixfun ->
      "?" <> pretty i <> PP.pretty space <+> "->" PP.</> PP.pretty ixfun

-- | Free variables of the block name or memory space, plus those of the
-- index function.  The existential index of a new block is not a name.
instance FreeIn MemReturn where
  freeIn' ret = case ret of
    ReturnsInBlock v ixfun -> freeIn' v <> freeIn' ixfun
    ReturnsNewBlock space _ ixfun -> freeIn' space <> freeIn' ixfun

-- | Simplify the index function.  For an existing block, also simplify
-- the block name, before the index function.
instance Engine.Simplifiable MemReturn where
  simplify ret = case ret of
    ReturnsNewBlock space i ixfun ->
      fmap (ReturnsNewBlock space i) (simplifyExtIxFun ixfun)
    ReturnsInBlock v ixfun -> do
      v' <- Engine.simplify v
      ReturnsInBlock v' <$> simplifyExtIxFun ixfun

-- | Simplify the memory block name, then the index function.
instance Engine.Simplifiable MemBind where
  simplify (ArrayIn mem ixfun) = do
    mem' <- Engine.simplify mem
    ArrayIn mem' <$> simplifyIxFun ixfun

-- | Simplify each function return annotation in order.
instance Engine.Simplifiable [FunReturns] where
  simplify = traverse Engine.simplify

-- | The memory return of an expression.  An array is annotated with
-- @Maybe MemReturn@.  'Just' means the expression dictates exactly
-- where the array is located when it is returned.  'Nothing' means the
-- expression can put it wherever the binding prefers.
--
-- This is necessary to capture the difference between an expression
-- that is just an array-typed variable, in which the array being
-- "returned" is located where it already is, and a @copy@ expression,
-- whose entire purpose is to store an existing array in some
-- arbitrary location.  This is a consequence of the design decision
-- never to have implicit memory copies.
type ExpReturns = MemInfo ExtSize NoUniqueness (Maybe MemReturn)

-- | The return of a body, which must always indicate where
-- returned arrays are located.
type BodyReturns = MemInfo ExtSize NoUniqueness MemReturn

-- | The memory return of a function, which must always indicate where
-- returned arrays are located.
type FunReturns = MemInfo ExtSize Uniqueness MemReturn

-- | Wrap the return annotation of an array in 'Just'.  The other
-- 'MemInfo' constructors carry no return annotation and are rebuilt
-- unchanged.
maybeReturns :: MemInfo d u r -> MemInfo d u (Maybe r)
maybeReturns info = case info of
  MemArray bt shape u ret -> MemArray bt shape u (Just ret)
  MemPrim bt -> MemPrim bt
  MemMem space -> MemMem space
  MemAcc acc ispace ts u -> MemAcc acc ispace ts u

-- | Discard uniqueness information by replacing it with 'NoUniqueness'
-- wherever a 'MemInfo' carries it (arrays and accumulators).
noUniquenessReturns :: MemInfo d u r -> MemInfo d NoUniqueness r
noUniquenessReturns info = case info of
  MemArray bt shape _ r -> MemArray bt shape NoUniqueness r
  MemPrim bt -> MemPrim bt
  MemMem space -> MemMem space
  MemAcc acc ispace ts _ -> MemAcc acc ispace ts NoUniqueness

-- | View a function's return annotation as an expression return.  The
-- uniqueness is dropped and the memory return becomes 'Just'.
funReturnsToExpReturns :: FunReturns -> ExpReturns
funReturnsToExpReturns ret = noUniquenessReturns (maybeReturns ret)

-- | View a body's return annotation as an expression return.  The
-- memory return becomes 'Just'.
bodyReturnsToExpReturns :: BodyReturns -> ExpReturns
bodyReturnsToExpReturns ret = noUniquenessReturns (maybeReturns ret)

-- | The expression return of a variable with the given memory
-- information.  An array is returned exactly where it already lives:
-- in its current block, with its index function.  Its shape and index
-- function are lifted to the existential form using only 'Free' names.
varInfoToExpReturns :: MemInfo SubExp NoUniqueness MemBind -> ExpReturns
varInfoToExpReturns info = case info of
  MemArray et shape u (ArrayIn mem ixfun) ->
    let ret = ReturnsInBlock mem (existentialiseIxFun [] ixfun)
     in MemArray et (Free <$> shape) u (Just ret)
  MemPrim pt -> MemPrim pt
  MemAcc acc ispace ts u -> MemAcc acc ispace ts u
  MemMem space -> MemMem space

-- | Check a body result against a list of return types.  The memory
-- information of each result is looked up in the current scope (with
-- aliasing information removed) and compared using 'matchReturnType'.
matchRetTypeToResult ::
  (Mem rep inner, TC.Checkable rep) =>
  [FunReturns] ->
  Result ->
  TC.TypeM rep ()
matchRetTypeToResult rettype result = do
  let ses = map resSubExp result
  scope <- removeScopeAliases <$> askScope
  result_ts <- runReaderT (traverse subExpMemInfo ses) scope
  matchReturnType rettype ses result_ts

-- | Check a function body's result against the declared return type.
-- Additionally, every array returned by the function must have a
-- linear index function; otherwise a type error is reported.
matchFunctionReturnType ::
  (Mem rep inner, TC.Checkable rep) =>
  [FunReturns] ->
  Result ->
  TC.TypeM rep ()
matchFunctionReturnType rettype result = do
  matchRetTypeToResult rettype result
  mapM_ (checkLinear . resSubExp) result
  where
    -- Constants and non-array variables are always acceptable; only
    -- arrays with a non-linear index function are rejected.
    checkLinear Constant {} = pure ()
    checkLinear (Var v) = do
      dec <- varMemInfo v
      case dec of
        MemArray _ _ _ (ArrayIn _ ixfun)
          | not (IxFun.isLinear ixfun) ->
              TC.bad . TC.TypeError $
                "Array "
                  <> prettyText v
                  <> " returned by function, but has nontrivial index function "
                  <> prettyText ixfun
        _ -> pure ()

-- | Check that a loop body's result matches the memory annotations of
-- the loop parameters.  The parameters are treated as a function return
-- type.  Any shape dimension or memory block that is itself one of the
-- parameters becomes existential, numbered by its parameter position.
-- An array whose block is a memory-typed parameter becomes a new
-- (existential) block.
matchLoopResultMem ::
  (Mem rep inner, TC.Checkable rep) =>
  [FParam (Aliases rep)] ->
  Result ->
  TC.TypeM rep ()
matchLoopResultMem params = matchRetTypeToResult rettype
  where
    param_names = map paramName params

    -- Invent a ReturnType so we can pretend that the loop body is
    -- actually returning from a function.
    rettype = map (toRet . paramDec) params

    toExtV v = maybe (Free v) Ext (elemIndex v param_names)

    toExtSE (Var v) = Var <$> toExtV v
    toExtSE (Constant v) = Free (Constant v)

    -- The parameter position and memory space of @mem@, if @mem@ names a
    -- memory-typed loop parameter.
    newBlockOf mem = do
      i <- elemIndex mem param_names
      Param _ _ (MemMem space) : _ <- Just (drop i params)
      Just (i, space)

    toRet (MemPrim t) = MemPrim t
    toRet (MemMem space) = MemMem space
    toRet (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
    toRet (MemArray pt shape u (ArrayIn mem ixfun)) =
      MemArray pt (fmap toExtSE shape) u $
        case newBlockOf mem of
          Just (i, space) -> ReturnsNewBlock space i ixfun'
          Nothing -> ReturnsInBlock mem ixfun'
      where
        ixfun' = existentialiseIxFun param_names ixfun

-- | Check a branch body's result against the branch return type.  The
-- statements of the body are brought into scope so that results bound
-- inside the body can have their memory information looked up.
matchBranchReturnType ::
  (Mem rep inner, TC.Checkable rep) =>
  [BodyReturns] ->
  Body (Aliases rep) ->
  TC.TypeM rep ()
matchBranchReturnType rettype (Body _ stms res) = do
  outer <- askScope
  let scope = removeScopeAliases (outer <> scopeOf stms)
      ses = map resSubExp res
  ts <- runReaderT (traverse subExpMemInfo ses) scope
  matchReturnType rettype ses ts

-- | Helper function for index function unification.
--
-- The first return value maps a VName (wrapped in 'Free') to its Int
-- (wrapped in 'Ext').  In case of duplicates, it is mapped to the
-- *first* Int that occurs.
--
-- The second return value maps each Int (wrapped in an 'Ext') to a
-- 'LeafExp' 'Ext' with the Int at which its associated VName first
-- occurs.
getExtMaps ::
  [(VName, Int)] ->
  ( M.Map (Ext VName) (TPrimExp Int64 (Ext VName)),
    M.Map (Ext VName) (TPrimExp Int64 (Ext VName))
  )
getExtMaps ctx_lst_ids = (byName, byExt)
  where
    -- For every name, the first Int it is paired with.  'M.fromListWith'
    -- calls its function as @f new old@, so returning @old@ keeps the
    -- first occurrence.
    firstIds = M.fromListWith (\_ old -> old) ctx_lst_ids

    byName = M.map leafExp (M.mapKeys Free firstIds)

    byExt =
      M.fromList
        [ (Ext i, leafExp j)
          | (v, i) <- ctx_lst_ids,
            Just j <- [M.lookup v firstIds]
        ]

matchReturnType ::
  PP.Pretty u =>
  [MemInfo ExtSize u MemReturn] ->
  [SubExp] ->
  [MemInfo SubExp NoUniqueness MemBind] ->
  TC.TypeM rep ()
matchReturnType :: forall {k} u (rep :: k).
Pretty u =>
[MemInfo (Ext SubExp) u MemReturn]
-> [SubExp] -> [LetDecMem] -> TypeM rep ()
matchReturnType [MemInfo (Ext SubExp) u MemReturn]
rettype [SubExp]
res [LetDecMem]
ts = do
  let existentialiseIxFun0 :: IxFun -> ExtIxFun
      existentialiseIxFun0 :: IxFun -> ExtIxFun
existentialiseIxFun0 = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a. a -> Ext a
Free

      fetchCtx :: Int -> ExceptT Text (TypeM rep) (SubExp, LetDecMem)
fetchCtx Int
i = case forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int
i forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
res [LetDecMem]
ts of
        Maybe (SubExp, LetDecMem)
Nothing ->
          forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError forall a b. (a -> b) -> a -> b
$ Text
"Cannot find variable #" forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText Int
i forall a. Semigroup a => a -> a -> a
<> Text
" in results: " forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText [SubExp]
res
        Just (SubExp
se, LetDecMem
t) -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
se, LetDecMem
t)

      checkReturn :: MemInfo (Ext SubExp) u MemReturn
-> LetDecMem -> ExceptT Text (TypeM rep) ()
checkReturn (MemPrim PrimType
x) (MemPrim PrimType
y)
        | PrimType
x forall a. Eq a => a -> a -> Bool
== PrimType
y = forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
      checkReturn (MemMem Space
x) (MemMem Space
y)
        | Space
x forall a. Eq a => a -> a -> Bool
== Space
y = forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
      checkReturn (MemAcc VName
xacc Shape
xispace [Type]
xts u
_) (MemAcc VName
yacc Shape
yispace [Type]
yts NoUniqueness
_)
        | (VName
xacc, Shape
xispace, [Type]
xts) forall a. Eq a => a -> a -> Bool
== (VName
yacc, Shape
yispace, [Type]
yts) =
            forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
      checkReturn
        (MemArray PrimType
x_pt ShapeBase (Ext SubExp)
x_shape u
_ MemReturn
x_ret)
        (MemArray PrimType
y_pt Shape
y_shape NoUniqueness
_ MemBind
y_ret)
          | PrimType
x_pt forall a. Eq a => a -> a -> Bool
== PrimType
y_pt,
            forall a. ArrayShape a => a -> Int
shapeRank ShapeBase (Ext SubExp)
x_shape forall a. Eq a => a -> a -> Bool
== forall a. ArrayShape a => a -> Int
shapeRank Shape
y_shape = do
              forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Ext SubExp -> SubExp -> ExceptT Text (TypeM rep) ()
checkDim (forall d. ShapeBase d -> [d]
shapeDims ShapeBase (Ext SubExp)
x_shape) (forall d. ShapeBase d -> [d]
shapeDims Shape
y_shape)
              MemReturn -> MemBind -> ExceptT Text (TypeM rep) ()
checkMemReturn MemReturn
x_ret MemBind
y_ret
      checkReturn MemInfo (Ext SubExp) u MemReturn
x LetDecMem
y =
        forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError forall a b. (a -> b) -> a -> b
$ [Text] -> Text
T.unwords [Text
"Expected", forall a. Pretty a => a -> Text
prettyText MemInfo (Ext SubExp) u MemReturn
x, Text
"but got", forall a. Pretty a => a -> Text
prettyText LetDecMem
y]

      checkDim :: Ext SubExp -> SubExp -> ExceptT Text (TypeM rep) ()
checkDim (Free SubExp
x) SubExp
y
        | SubExp
x forall a. Eq a => a -> a -> Bool
== SubExp
y = forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
        | Bool
otherwise =
            forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError forall a b. (a -> b) -> a -> b
$ [Text] -> Text
T.unwords [Text
"Expected dim", forall a. Pretty a => a -> Text
prettyText SubExp
x, Text
"but got", forall a. Pretty a => a -> Text
prettyText SubExp
y]
      checkDim (Ext Int
i) SubExp
y = do
        (SubExp
x, LetDecMem
_) <- Int -> ExceptT Text (TypeM rep) (SubExp, LetDecMem)
fetchCtx Int
i
        forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (SubExp
x forall a. Eq a => a -> a -> Bool
== SubExp
y) forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Text] -> Text
T.unwords forall a b. (a -> b) -> a -> b
$
          [Text
"Expected ext dim", forall a. Pretty a => a -> Text
prettyText Int
i, Text
"=>", forall a. Pretty a => a -> Text
prettyText SubExp
x, Text
"but got", forall a. Pretty a => a -> Text
prettyText SubExp
y]

      checkMemReturn :: MemReturn -> MemBind -> ExceptT Text (TypeM rep) ()
checkMemReturn (ReturnsInBlock VName
x_mem ExtIxFun
x_ixfun) (ArrayIn VName
y_mem IxFun
y_ixfun)
        | VName
x_mem forall a. Eq a => a -> a -> Bool
== VName
y_mem =
            forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (forall num. IxFun num -> IxFun num -> Bool
IxFun.closeEnough ExtIxFun
x_ixfun forall a b. (a -> b) -> a -> b
$ IxFun -> ExtIxFun
existentialiseIxFun0 IxFun
y_ixfun) forall a b. (a -> b) -> a -> b
$
              forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Text] -> Text
T.unwords forall a b. (a -> b) -> a -> b
$
                [ Text
"Index function unification failed (ReturnsInBlock)",
                  Text
"\nixfun of body result: ",
                  forall a. Pretty a => a -> Text
prettyText IxFun
y_ixfun,
                  Text
"\nixfun of return type: ",
                  forall a. Pretty a => a -> Text
prettyText ExtIxFun
x_ixfun
                ]
      checkMemReturn
        (ReturnsNewBlock Space
x_space Int
x_ext ExtIxFun
x_ixfun)
        (ArrayIn VName
y_mem IxFun
y_ixfun) = do
          (SubExp
x_mem, LetDecMem
x_mem_type) <- Int -> ExceptT Text (TypeM rep) (SubExp, LetDecMem)
fetchCtx Int
x_ext
          forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (forall num. IxFun num -> IxFun num -> Bool
IxFun.closeEnough ExtIxFun
x_ixfun forall a b. (a -> b) -> a -> b
$ IxFun -> ExtIxFun
existentialiseIxFun0 IxFun
y_ixfun) forall a b. (a -> b) -> a -> b
$
            forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall a. Doc a -> Text
docText forall a b. (a -> b) -> a -> b
$
              Doc Any
"Index function unification failed (ReturnsNewBlock)"
                forall ann. Doc ann -> Doc ann -> Doc ann
</> Doc Any
"Ixfun of body result:"
                forall ann. Doc ann -> Doc ann -> Doc ann
</> forall ann. Int -> Doc ann -> Doc ann
indent Int
2 (forall a ann. Pretty a => a -> Doc ann
pretty IxFun
y_ixfun)
                forall ann. Doc ann -> Doc ann -> Doc ann
</> Doc Any
"Ixfun of return type:"
                forall ann. Doc ann -> Doc ann -> Doc ann
</> forall ann. Int -> Doc ann -> Doc ann
indent Int
2 (forall a ann. Pretty a => a -> Doc ann
pretty ExtIxFun
x_ixfun)
          case LetDecMem
x_mem_type of
            MemMem Space
y_space ->
              forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Space
x_space forall a. Eq a => a -> a -> Bool
== Space
y_space) forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Text] -> Text
T.unwords forall a b. (a -> b) -> a -> b
$
                [ Text
"Expected memory",
                  forall a. Pretty a => a -> Text
prettyText VName
y_mem,
                  Text
"in space",
                  forall a. Pretty a => a -> Text
prettyText Space
x_space,
                  Text
"but actually in space",
                  forall a. Pretty a => a -> Text
prettyText Space
y_space
                ]
            LetDecMem
t ->
              forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Text] -> Text
T.unwords forall a b. (a -> b) -> a -> b
$
                [Text
"Expected memory", forall a. Pretty a => a -> Text
prettyText Int
x_ext, Text
"=>", forall a. Pretty a => a -> Text
prettyText SubExp
x_mem, Text
"but but has type", forall a. Pretty a => a -> Text
prettyText LetDecMem
t]
      checkMemReturn MemReturn
x MemBind
y =
        forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall a. Doc a -> Text
docText forall a b. (a -> b) -> a -> b
$
          Doc Any
"Expected array in"
            forall ann. Doc ann -> Doc ann -> Doc ann
</> forall ann. Int -> Doc ann -> Doc ann
indent Int
2 (forall a ann. Pretty a => a -> Doc ann
pretty MemReturn
x)
            forall ann. Doc ann -> Doc ann -> Doc ann
</> Doc Any
"but array returned in"
            forall ann. Doc ann -> Doc ann -> Doc ann
</> forall ann. Int -> Doc ann -> Doc ann
indent Int
2 (forall a ann. Pretty a => a -> Doc ann
pretty MemBind
y)

      bad :: Text -> TypeM rep ()
bad Text
s =
        forall {k} (rep :: k) a. ErrorCase rep -> TypeM rep a
TC.bad forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall {k} (rep :: k). Text -> ErrorCase rep
TC.TypeError forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall a. Doc a -> Text
docText forall a b. (a -> b) -> a -> b
$
          Doc Any
"Return type"
            forall ann. Doc ann -> Doc ann -> Doc ann
</> forall ann. Int -> Doc ann -> Doc ann
indent Int
2 (forall a. [Doc a] -> Doc a
ppTuple' forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a ann. Pretty a => a -> Doc ann
pretty [MemInfo (Ext SubExp) u MemReturn]
rettype)
            forall ann. Doc ann -> Doc ann -> Doc ann
</> Doc Any
"cannot match returns of results"
            forall ann. Doc ann -> Doc ann -> Doc ann
</> forall ann. Int -> Doc ann -> Doc ann
indent Int
2 (forall a. [Doc a] -> Doc a
ppTuple' forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a ann. Pretty a => a -> Doc ann
pretty [LetDecMem]
ts)
            forall ann. Doc ann -> Doc ann -> Doc ann
</> forall a ann. Pretty a => a -> Doc ann
pretty Text
s

  forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either Text -> TypeM rep ()
bad forall (f :: * -> *) a. Applicative f => a -> f a
pure forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall e (m :: * -> *) a. ExceptT e m a -> m (Either e a)
runExceptT (forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ MemInfo (Ext SubExp) u MemReturn
-> LetDecMem -> ExceptT Text (TypeM rep) ()
checkReturn [MemInfo (Ext SubExp) u MemReturn]
rettype [LetDecMem]
ts)

-- | Check that the memory annotations of a pattern agree with the
-- return information that 'expReturns' computes for the expression
-- bound to it.
--
-- The check fails with a 'TC.TypeError' showing both sides when:
--
--  * the number of pattern elements differs from the number of
--    expression results, or
--  * any pattern element fails to match its corresponding result.
matchPatToExp ::
  (Mem rep inner, LetDec rep ~ LetDecMem, TC.Checkable rep) =>
  Pat (LetDec (Aliases rep)) ->
  Exp (Aliases rep) ->
  TC.TypeM rep ()
matchPatToExp pat e = do
  scope <- asksScope removeScopeAliases
  rt <- runReaderT (expReturns $ removeExpAliases e) scope

  let (ctx_ids, val_ts) = unzip $ bodyReturnsFromPat $ removePatAliases pat
      -- NOTE(review): only the first two pattern names are paired with an
      -- index here ([0 .. 1]); confirm this is intended rather than [0 ..].
      (ctx_map_ids, ctx_map_exts) = getExtMaps $ zip ctx_ids [0 .. 1]
      ok =
        length val_ts == length rt
          && and (zipWith (matches ctx_map_ids ctx_map_exts) val_ts rt)

  unless ok . TC.bad . TC.TypeError . docText $
    "Expression type:"
      </> indent 2 (ppTuple' $ map pretty rt)
      </> "cannot match pattern type:"
      </> indent 2 (ppTuple' $ map pretty val_ts)
  where
    -- Compare one pattern-side annotation (first) against one
    -- expression-side annotation (second).
    matches _ _ (MemPrim x) (MemPrim y) = x == y
    matches _ _ (MemMem x_space) (MemMem y_space) =
      x_space == y_space
    matches _ _ (MemAcc x_accs x_ispace x_ts _) (MemAcc y_accs y_ispace y_ts _) =
      (x_accs, x_ispace, x_ts) == (y_accs, y_ispace, y_ts)
    matches ctxids ctxexts (MemArray x_pt x_shape _ x_ret) (MemArray y_pt y_shape _ y_ret) =
      x_pt == y_pt
        && x_shape == y_shape
        && case (x_ret, y_ret) of
          (ReturnsInBlock _ x_ixfun, Just (ReturnsInBlock _ y_ixfun)) ->
            closeUnderCtx x_ixfun y_ixfun
          (ReturnsInBlock _ x_ixfun, Just (ReturnsNewBlock _ _ y_ixfun)) ->
            closeUnderCtx x_ixfun y_ixfun
          (ReturnsNewBlock _ x_i x_ixfun, Just (ReturnsNewBlock _ y_i y_ixfun)) ->
            x_i == y_i && closeUnderCtx x_ixfun y_ixfun
          -- The expression does not say where its result lives: accept.
          (_, Nothing) -> True
          _ -> False
      where
        -- Substitute the pattern-side map into the pattern index function
        -- and the expression-side map into the expression index function,
        -- then compare them with 'IxFun.closeEnough'.
        closeUnderCtx x_ixfun y_ixfun =
          IxFun.closeEnough
            (IxFun.substituteInIxFun ctxids x_ixfun)
            (IxFun.substituteInIxFun ctxexts y_ixfun)
    matches _ _ _ _ = False

-- | Look up the memory annotation of a variable in the type checker's
-- scope.
--
-- How each kind of binding is converted:
--
--  * @let@-bound names: 'letDecMem' is applied to their decoration.
--  * Function parameters: uniqueness is dropped with 'noUniquenessReturns'.
--  * Lambda parameters: returned as-is.
--  * Index names: become a 'MemPrim' of their integer type.
varMemInfo ::
  Mem rep inner =>
  VName ->
  TC.TypeM rep (MemInfo SubExp NoUniqueness MemBind)
varMemInfo name = do
  dec <- TC.lookupVar name
  pure $ case dec of
    LetName (_, summary) -> letDecMem summary
    FParamName summary -> noUniquenessReturns summary
    LParamName summary -> summary
    IndexName it -> MemPrim $ IntType it

-- | Extract the memory annotation from a scope entry.
--
-- Function parameters have their uniqueness dropped; index names become
-- a 'MemPrim' of their integer type.
nameInfoToMemInfo :: Mem rep inner => NameInfo rep -> MemBound NoUniqueness
nameInfoToMemInfo (FParamName summary) = noUniquenessReturns summary
nameInfoToMemInfo (LParamName summary) = summary
nameInfoToMemInfo (LetName summary) = letDecMem summary
nameInfoToMemInfo (IndexName it) = MemPrim $ IntType it

-- | Look up a variable in scope and return its memory annotation
-- (see 'nameInfoToMemInfo').
lookupMemInfo ::
  (HasScope rep m, Mem rep inner) =>
  VName ->
  m (MemInfo SubExp NoUniqueness MemBind)
lookupMemInfo = fmap nameInfoToMemInfo . lookupInfo

-- | The memory annotation of a subexpression.
--
-- Variables are looked up in scope; constants are primitive values of
-- their own type.
subExpMemInfo ::
  (HasScope rep m, Mem rep inner) =>
  SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
subExpMemInfo (Var v) = lookupMemInfo v
subExpMemInfo (Constant v) = pure $ MemPrim $ primValueType v

-- | Look up the memory block and index function of an array variable.
--
-- Calls 'error' if the name is not bound to an array annotation.
lookupArraySummary ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  VName ->
  m (VName, IxFun.IxFun (TPrimExp Int64 VName))
lookupArraySummary name = do
  summary <- lookupMemInfo name
  case summary of
    MemArray _ _ _ (ArrayIn mem ixfun) ->
      pure (mem, ixfun)
    _ ->
      error . T.unpack $
        "Expected "
          <> prettyText name
          <> " to be array but bound to:\n"
          <> prettyText summary

-- | Look up the space of a memory block variable.
--
-- Calls 'error' if the name is not bound to a memory annotation.
lookupMemSpace ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  VName ->
  m Space
lookupMemSpace name = do
  summary <- lookupMemInfo name
  case summary of
    MemMem space ->
      pure space
    _ ->
      error . T.unpack $
        "Expected "
          <> prettyText name
          <> " to be memory but bound to:\n"
          <> prettyText summary

-- | Type-check the memory annotation of the variable @name@.
--
--  * Primitives and non-scalar memory blocks need no checking.
--  * Scalar-space dimensions must be @int64@.
--  * Accumulators are checked as the corresponding 'Acc' type.
--  * For arrays:
--
--      - the referenced memory variable must have a memory type;
--      - every expression inside the index function must be @int64@;
--      - the index function's rank must equal the rank of the array shape.
checkMemInfo ::
  TC.Checkable rep =>
  VName ->
  MemInfo SubExp u MemBind ->
  TC.TypeM rep ()
checkMemInfo _ (MemPrim _) = pure ()
checkMemInfo _ (MemMem (ScalarSpace d _)) = mapM_ (TC.require [Prim int64]) d
checkMemInfo _ (MemMem _) = pure ()
checkMemInfo _ (MemAcc acc ispace ts u) =
  TC.checkType $ Acc acc ispace ts u
checkMemInfo name (MemArray _ shape _ (ArrayIn v ixfun)) = do
  t <- lookupType v
  case t of
    Mem {} ->
      pure ()
    _ ->
      TC.bad . TC.TypeError $
        "Variable "
          <> prettyText v
          <> " used as memory block, but is of type "
          <> prettyText t
          <> "."

  TC.context ("in index function " <> prettyText ixfun) $ do
    traverse_ (TC.requirePrimExp int64 . untyped) ixfun
    let ixfun_rank = IxFun.rank ixfun
        ident_rank = shapeRank shape
    unless (ixfun_rank == ident_rank) $
      TC.bad . TC.TypeError $
        "Arity of index function ("
          <> prettyText ixfun_rank
          <> ") does not match rank of array "
          <> prettyText name
          <> " ("
          <> prettyText ident_rank
          <> ")"

-- | Compute the 'BodyReturns' of every element of a pattern, paired
-- with the element's name.
--
-- Positions are counted from 0 over the whole pattern.
--
--  * Shape dimensions that name another element of the same pattern
--    become @Ext i@, where @i@ is that element's position.
--  * An array whose memory block is itself a pattern element of memory
--    type gets 'ReturnsNewBlock', with its index function existentialised
--    over all pattern names.
--  * Any other array gets 'ReturnsInBlock' on its memory block.
bodyReturnsFromPat ::
  Pat (MemBound NoUniqueness) -> [(VName, BodyReturns)]
bodyReturnsFromPat pat =
  map asReturns ctx
  where
    ctx :: [PatElem (MemBound NoUniqueness)]
    ctx = patElems pat

    -- The first pattern element with the given name, with its position.
    lookupCtx name = find ((== name) . patElemName . snd) $ zip [0 ..] ctx

    ext (Var v)
      | Just (i, _) <- lookupCtx v = Ext i
    ext se = Free se

    asReturns pe =
      ( patElemName pe,
        case patElemDec pe of
          MemPrim pt -> MemPrim pt
          MemMem space -> MemMem space
          MemArray pt shape u (ArrayIn mem ixfun) ->
            MemArray pt (Shape $ map ext $ shapeDims shape) u $
              case lookupCtx mem of
                Just (i, PatElem _ (MemMem space)) ->
                  ReturnsNewBlock space i $
                    existentialiseIxFun (map patElemName ctx) ixfun
                _ -> ReturnsInBlock mem $ existentialiseIxFun [] ixfun
          MemAcc acc ispace ts u -> MemAcc acc ispace ts u
      )

-- | Convert extended types into 'ExpReturns'.
--
-- Each array for which 'existential' holds is assigned the next context
-- index, counting from 0 in order of appearance. It is returned as a
-- 'ReturnsNewBlock' in 'DefaultSpace' with an 'IxFun.iota' index
-- function over its shape. Other arrays carry no return information
-- ('Nothing').
extReturns :: [ExtType] -> [ExpReturns]
extReturns ets =
  evalState (mapM addDec ets) 0
  where
    addDec (Prim bt) =
      pure $ MemPrim bt
    addDec (Mem space) =
      pure $ MemMem space
    addDec t@(Array bt shape u)
      | existential t = do
          -- Take the current counter value and advance it.
          i <- get <* modify (+ 1)
          pure . MemArray bt shape u . Just . ReturnsNewBlock DefaultSpace i $
            IxFun.iota $
              map convert $
                shapeDims shape
      | otherwise =
          pure $ MemArray bt shape u Nothing
    addDec (Acc acc ispace ts u) =
      pure $ MemAcc acc ispace ts u

    -- Turn a possibly-existential dimension into an int64 expression.
    convert (Ext i) = le64 (Ext i)
    convert (Free v) = Free <$> pe64 v

-- | Look up an array variable's element type, shape, memory block and
-- index function.
--
-- Calls 'error' if the variable is not bound to an array annotation.
arrayVarReturns ::
  (HasScope rep m, Monad m, Mem rep inner) =>
  VName ->
  m (PrimType, Shape, VName, IxFun)
arrayVarReturns v = do
  summary <- lookupMemInfo v
  case summary of
    MemArray et shape _ (ArrayIn mem ixfun) ->
      pure (et, shape, mem, ixfun)
    _ ->
      error . T.unpack $ "arrayVarReturns: " <> prettyText v <> " is not an array."

-- | The 'ExpReturns' of a variable, derived from its memory annotation.
--
-- An array is reported as returned in its own memory block
-- ('ReturnsInBlock'). Its shape dimensions are wrapped in 'Free', and
-- its index function is existentialised over no names.
varReturns ::
  (HasScope rep m, Monad m, Mem rep inner) =>
  VName ->
  m ExpReturns
varReturns v = do
  summary <- lookupMemInfo v
  case summary of
    MemPrim bt ->
      pure $ MemPrim bt
    MemArray et shape _ (ArrayIn mem ixfun) ->
      pure . MemArray et (fmap Free shape) NoUniqueness . Just $
        ReturnsInBlock mem $
          existentialiseIxFun [] ixfun
    MemMem space ->
      pure $ MemMem space
    MemAcc acc ispace ts u ->
      pure $ MemAcc acc ispace ts u

-- | The 'ExpReturns' of a subexpression.
--
-- Variables are handled by 'varReturns'; constants are primitive values
-- of their own type.
subExpReturns :: (HasScope rep m, Monad m, Mem rep inner) => SubExp -> m ExpReturns
subExpReturns (Var v) =
  varReturns v
subExpReturns (Constant v) =
  pure $ MemPrim $ primValueType v

-- | The return information of an expression.  This can be seen as the
-- "return type with memory annotations" of the expression.
--
-- 'Reshape' and 'Rearrange' yield an array in the same memory block as
-- their operand, with the operand's index function transformed
-- accordingly.  'Index' and 'FlatIndex' are delegated to 'sliceInfo' and
-- 'flatSliceInfo'.  In-place updates ('Update', 'FlatUpdate') report the
-- information of the updated array.  Any other basic operation falls back
-- to 'extReturns' of its static type.
expReturns ::
  (LocalScope rep m, Mem rep inner) =>
  Exp rep ->
  m [ExpReturns]
expReturns (BasicOp (SubExp se)) =
  (: []) <$> subExpReturns se
expReturns (BasicOp (Opaque _ (Var v))) =
  (: []) <$> varReturns v
expReturns (BasicOp (Reshape k newshape v)) = do
  (et, _, mem, ixfun) <- arrayVarReturns v
  let new_dims = map pe64 $ shapeDims newshape
      ixfun' = existentialiseIxFun [] $ reshapeIxFun ixfun new_dims
  pure [MemArray et (Free <$> newshape) NoUniqueness $ Just $ ReturnsInBlock mem ixfun']
  where
    -- The index function transformation depends on the kind of reshape.
    reshapeIxFun = case k of
      ReshapeArbitrary -> IxFun.reshape
      ReshapeCoerce -> IxFun.coerce
expReturns (BasicOp (Rearrange perm v)) = do
  (et, Shape dims, mem, ixfun) <- arrayVarReturns v
  let shape' = Shape $ map Free $ rearrangeShape perm dims
      ixfun' = existentialiseIxFun [] $ IxFun.permute ixfun perm
  pure [MemArray et shape' NoUniqueness $ Just $ ReturnsInBlock mem ixfun']
expReturns (BasicOp (Index v slice)) =
  (: []) . varInfoToExpReturns <$> sliceInfo v slice
expReturns (BasicOp (Update _ v _ _)) =
  (: []) <$> varReturns v
expReturns (BasicOp (FlatIndex v slice)) =
  (: []) . varInfoToExpReturns <$> flatSliceInfo v slice
expReturns (BasicOp (FlatUpdate v _ _)) =
  (: []) <$> varReturns v
expReturns (BasicOp op) =
  extReturns . staticShapes <$> basicOpType op
expReturns e@(DoLoop merge _ _) = do
  ext_ts <- expExtType e
  zipWithM withParamDec ext_ts loop_params
  where
    loop_params = map fst merge
    loop_param_names = map paramName loop_params

    -- Position and parameter of the loop parameter named @v@, if any.
    loopParamIndex v = lookup v $ zip loop_param_names $ zip [0 ..] loop_params

    -- Combine a result type with the memory annotation of the
    -- corresponding loop parameter.
    withParamDec t p =
      case (t, paramDec p) of
        (Array pt shape u, MemArray _ _ _ (ArrayIn mem ixfun)) ->
          let ixfun' = existentialiseIxFun loop_param_names ixfun
              mem_ret = case loopParamIndex mem of
                -- The array's memory block is itself a loop parameter.
                Just (i, mem_p)
                  | Mem space <- paramType mem_p ->
                      ReturnsNewBlock space i ixfun'
                _ -> ReturnsInBlock mem ixfun'
           in pure $ MemArray pt shape u $ Just mem_ret
        (Array {}, _) ->
          error "expReturns: Array return type but not array merge variable."
        (Acc acc ispace ts u, _) ->
          pure $ MemAcc acc ispace ts u
        (Prim pt, _) ->
          pure $ MemPrim pt
        (Mem space, _) ->
          pure $ MemMem space
expReturns (Apply _ _ ret _) =
  pure $ map funReturnsToExpReturns ret
expReturns (Match _ _ _ (MatchDec ret _)) =
  pure $ map bodyReturnsToExpReturns ret
expReturns (Op op) =
  opReturns op
expReturns (WithAcc inputs lam) = do
  acc_rets <- concat <$> mapM accInputReturns inputs
  -- XXX: this is a bit dubious because it enforces extra copies.  I
  -- think WithAcc should perhaps have a return annotation like If.
  pure $ acc_rets <> extReturns (staticShapes $ drop num_accs $ lambdaReturnType lam)
  where
    accInputReturns (_, arrs, _) = mapM varReturns arrs
    num_accs = length inputs

-- | Memory information for the result of indexing the array @v@ with
-- @slice@.  When the slice leaves no dimensions the result is a scalar of
-- the array's element type.  Otherwise it is an array in the same memory
-- block as @v@, whose index function is @v@'s index function passed
-- through 'IxFun.slice'.
sliceInfo ::
  (Monad m, HasScope rep m, Mem rep inner) =>
  VName ->
  Slice SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
sliceInfo v slice = do
  (et, _, mem, ixfun) <- arrayVarReturns v
  let sliced_ixfun = IxFun.slice ixfun $ pe64 <$> slice
  pure $ case sliceDims slice of
    [] -> MemPrim et
    result_dims ->
      MemArray et (Shape result_dims) NoUniqueness $ ArrayIn mem sliced_ixfun

-- | Memory information for the result of flat-indexing the array @v@ with
-- @slice@.  The result is always an array in @v@'s memory block.  Its
-- shape is 'flatSliceDims' of the slice, and its index function is
-- produced by 'IxFun.flatSlice' from @v@'s index function.
flatSliceInfo ::
  (Monad m, HasScope rep m, Mem rep inner) =>
  VName ->
  FlatSlice SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
flatSliceInfo v slice@(FlatSlice offset idxs) = do
  (et, _, mem, ixfun) <- arrayVarReturns v
  let pe_slice = FlatSlice (pe64 offset) (map (fmap pe64) idxs)
      ixfun' = IxFun.flatSlice ixfun pe_slice
      result_shape = Shape (flatSliceDims slice)
  pure $ MemArray et result_shape NoUniqueness $ ArrayIn mem ixfun'

-- | Operations whose memory-annotated return information can be
-- determined.  By default this is derived from the operation's type
-- ('opType') via 'extReturns'.
class IsOp op => OpReturns op where
  opReturns :: (Mem rep inner, Monad m, HasScope rep m) => op -> m [ExpReturns]
  opReturns = fmap extReturns . opType

-- | An allocation returns a single memory block in its space.  Inner
-- operations are delegated to the inner operation's instance.
instance OpReturns inner => OpReturns (MemOp inner) where
  opReturns op = case op of
    Alloc _ space -> pure [MemMem space]
    Inner inner_op -> opReturns inner_op

-- | The unit operation produces no results.
instance OpReturns () where
  opReturns () = pure mempty

applyFunReturns ::
  Typed dec =>
  [FunReturns] ->
  [Param dec] ->
  [(SubExp, Type)] ->
  Maybe [FunReturns]
applyFunReturns :: forall dec.
Typed dec =>
[FunReturns]
-> [Param dec] -> [(SubExp, Type)] -> Maybe [FunReturns]
applyFunReturns [FunReturns]
rets [Param dec]
params [(SubExp, Type)]
args
  | Just [DeclExtType]
_ <- forall rt dec.
(IsRetType rt, Typed dec) =>
[rt] -> [Param dec] -> [(SubExp, Type)] -> Maybe [rt]
applyRetType [DeclExtType]
rettype [Param dec]
params [(SubExp, Type)]
args =
      forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall {u}.
MemInfo (Ext SubExp) u MemReturn
-> MemInfo (Ext SubExp) u MemReturn
correctDims [FunReturns]
rets
  | Bool
otherwise =
      forall a. Maybe a
Nothing
  where
    rettype :: [DeclExtType]
rettype = forall a b. (a -> b) -> [a] -> [b]
map forall t. DeclExtTyped t => t -> DeclExtType
declExtTypeOf [FunReturns]
rets
    parammap :: M.Map VName (SubExp, Type)
    parammap :: Map VName (SubExp, Type)
parammap =
      forall k a. Ord k => [(k, a)] -> Map k a
M.fromList forall a b. (a -> b) -> a -> b
$
        forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param dec]
params) [(SubExp, Type)]
args

    substSubExp :: SubExp -> SubExp
substSubExp (Var VName
v)
      | Just (SubExp
se, Type
_) <- forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName (SubExp, Type)
parammap = SubExp
se
    substSubExp SubExp
se = SubExp
se

    correctDims :: MemInfo (Ext SubExp) u MemReturn
-> MemInfo (Ext SubExp) u MemReturn
correctDims (MemPrim PrimType
t) =
      forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
t
    correctDims (MemMem Space
space) =
      forall d u ret. Space -> MemInfo d u ret
MemMem Space
space
    correctDims (MemArray PrimType
et ShapeBase (Ext SubExp)
shape u
u MemReturn
memsummary) =
      forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
et (ShapeBase (Ext SubExp) -> ShapeBase (Ext SubExp)
correctShape ShapeBase (Ext SubExp)
shape) u
u forall a b. (a -> b) -> a -> b
$
        MemReturn -> MemReturn
correctSummary MemReturn
memsummary
    correctDims (MemAcc VName
acc Shape
ispace [Type]
ts u
u) =
      forall d u ret. VName -> Shape -> [Type] -> u -> MemInfo d u ret
MemAcc VName
acc Shape
ispace [Type]
ts u
u

    correctShape :: ShapeBase (Ext SubExp) -> ShapeBase (Ext SubExp)
correctShape = forall d. [d] -> ShapeBase d
Shape forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall a b. (a -> b) -> [a] -> [b]
map Ext SubExp -> Ext SubExp
correctDim forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall d. ShapeBase d -> [d]
shapeDims
    correctDim :: Ext SubExp -> Ext SubExp
correctDim (Ext Int
i) = forall a. Int -> Ext a
Ext Int
i
    correctDim (Free SubExp
se) = forall a. a -> Ext a
Free forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp
substSubExp SubExp
se

    correctSummary :: MemReturn -> MemReturn
correctSummary (ReturnsNewBlock Space
space Int
i ExtIxFun
ixfun) =
      Space -> Int -> ExtIxFun -> MemReturn
ReturnsNewBlock Space
space Int
i ExtIxFun
ixfun
    correctSummary (ReturnsInBlock VName
mem ExtIxFun
ixfun) =
      -- FIXME: we should also do a replacement in ixfun here.
      VName -> ExtIxFun -> MemReturn
ReturnsInBlock VName
mem' ExtIxFun
ixfun
      where
        mem' :: VName
mem' = case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
mem Map VName (SubExp, Type)
parammap of
          Just (Var VName
v, Type
_) -> VName
v
          Maybe (SubExp, Type)
_ -> VName
mem