{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Building blocks for defining representations where every array
-- is given information about which memory block it is based in, and
-- how array elements map to memory block offsets.
--
-- There are two primary concepts you will need to understand:
--
--  1. Memory blocks, which are Futhark values of type v'Mem'
--     (parametrized with their size).  These correspond to arbitrary
--     blocks of memory, and are created using the 'Alloc' operation.
--
--  2. Index functions, which describe a mapping from the index space
--     of an array (e.g. a two-dimensional space for an array of type
--     @[[int]]@) to a one-dimensional offset into a memory block.
--     Thus, index functions describe how arbitrary-dimensional arrays
--     are mapped to the single-dimensional world of memory.
--
-- At a conceptual level, imagine that we have a two-dimensional array
-- @a@ of 32-bit integers, consisting of @n@ rows of @m@ elements
-- each.  This array could be represented in classic row-major format
-- with an index function like the following:
--
-- @
--   f(i,j) = i * m + j
-- @
--
-- When we want to know the location of element @a[2,3]@, we simply
-- call the index function as @f(2,3)@ and obtain @2*m+3@.  We could
-- also have chosen another index function, one that represents the
-- array in column-major (or "transposed") format:
--
-- @
--   f(i,j) = j * n + i
-- @
--
-- Index functions are not Futhark-level functions, but a special
-- construct that the final code generator will eventually use to
-- generate concrete access code.  By modifying the index functions we
-- can change how an array is represented in memory, which can permit
-- memory access pattern optimisations.
--
-- Every time we bind an array, whether in a @let@-binding, @loop@
-- merge parameter, or @lambda@ parameter, we have an annotation
-- specifying a memory block and an index function.  In some cases,
-- such as @let@-bindings for many expressions, we are free to specify
-- an arbitrary index function and memory block - for example, we get
-- to decide where 'Copy' stores its result - but in other cases the
-- type rules of the expression choose for us.  For example, 'Index'
-- always produces an array in the same memory block as its input, and
-- with the same index function, except with some indices fixed.
module Futhark.IR.Mem
  ( LetDecMem,
    FParamMem,
    LParamMem,
    RetTypeMem,
    BranchTypeMem,
    MemOp (..),
    traverseMemOpStms,
    MemInfo (..),
    MemBound,
    MemBind (..),
    MemReturn (..),
    LMAD,
    ExtLMAD,
    isStaticLMAD,
    ExpReturns,
    BodyReturns,
    FunReturns,
    noUniquenessReturns,
    bodyReturnsToExpReturns,
    Mem,
    HasLetDecMem (..),
    OpReturns (..),
    varReturns,
    expReturns,
    extReturns,
    lookupMemInfo,
    subExpMemInfo,
    lookupArraySummary,
    lookupMemSpace,
    existentialiseLMAD,

    -- * Type checking parts
    matchBranchReturnType,
    matchPatToExp,
    matchFunctionReturnType,
    matchLoopResultMem,
    bodyReturnsFromPat,
    checkMemInfo,

    -- * Module re-exports
    module Futhark.IR.Prop,
    module Futhark.IR.Traversals,
    module Futhark.IR.Pretty,
    module Futhark.IR.Syntax,
    module Futhark.Analysis.PrimExp.Convert,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Foldable (traverse_)
import Data.Function ((&))
import Data.Kind qualified
import Data.List (elemIndex, find)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.PrimExp.Simplify
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.IR.Aliases
  ( Aliases,
    CanBeAliased (..),
    removeExpAliases,
    removePatAliases,
    removeScopeAliases,
  )
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.Pretty
import Futhark.IR.Prop
import Futhark.IR.Prop.Aliases
import Futhark.IR.Syntax
import Futhark.IR.Traversals
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util
import Futhark.Util.Pretty (docText, indent, ppTupleLines', pretty, (<+>), (</>))
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))

-- | The memory information carried by let-bound names (extracted via
-- 'HasLetDecMem').
type LetDecMem = MemInfo SubExp NoUniqueness MemBind

-- | The 'FParamInfo' of a memory representation.  Unlike 'LetDecMem'
-- it records 'Uniqueness'.
type FParamMem = MemInfo SubExp Uniqueness MemBind

-- | The 'LParamInfo' of a memory representation; identical to
-- 'LetDecMem'.
type LParamMem = LetDecMem

-- | The 'RetType' of a memory representation.
type RetTypeMem = FunReturns

-- | The 'BranchType' of a memory representation.
type BranchTypeMem = BodyReturns

-- | The class of pattern element decorators that contain memory
-- information.
class HasLetDecMem t where
  -- | Project out the memory information of a decorator.
  letDecMem :: t -> LetDecMem

-- | A bare 'LetDecMem' is its own memory information.
instance HasLetDecMem LetDecMem where
  letDecMem dec = dec

-- | For a pair, the memory information is found by recursing into the
-- second component.
instance (HasLetDecMem b) => HasLetDecMem (a, b) where
  letDecMem pair = letDecMem (snd pair)

-- | The requirements for @rep@ to be a memory representation whose
-- operations are 'MemOp's over the inner operation type @inner@: every
-- parameter, return and branch decoration is one of the memory-aware
-- types defined in this module, and let decorations contain memory
-- information.
type Mem rep inner =
  ( ASTRep rep,
    -- Decorations.
    FParamInfo rep ~ FParamMem,
    LParamInfo rep ~ LParamMem,
    HasLetDecMem (LetDec rep),
    RetType rep ~ RetTypeMem,
    BranchType rep ~ BranchTypeMem,
    -- Operations.
    Op rep ~ MemOp inner rep,
    OpReturns (inner rep),
    RephraseOp inner
  )

-- | Function return types of a memory representation.  A primitive
-- return type is wrapped in 'MemPrim', and applying a return type to
-- arguments is delegated to 'applyFunReturns'.
instance IsRetType FunReturns where
  primRetType pt = MemPrim pt
  applyRetType rettype params args = applyFunReturns rettype params args

-- | Body result types of a memory representation; a primitive type is
-- simply wrapped in 'MemPrim'.
instance IsBodyType BodyReturns where
  primBodyType pt = MemPrim pt

-- | An operation in a memory representation: either an allocation
-- ('Alloc'), or an operation of the underlying @inner@ type ('Inner').
data MemOp (inner :: Data.Kind.Type -> Data.Kind.Type) (rep :: Data.Kind.Type)
  = -- | Allocate a memory block of the given size in the given space.
    Alloc SubExp Space
  | -- | An operation of the wrapped operation type.
    Inner (inner rep)
  deriving
    ( Eq,
      Ord,
      Show
    )

-- | A helper for defining 'TraverseOpStms'.  Allocations are returned
-- unchanged; an inner operation is handed to the given traverser
-- (together with the statement-traversal function) and re-wrapped in
-- 'Inner'.
traverseMemOpStms ::
  (Monad m) =>
  OpStmsTraverser m (inner rep) rep ->
  OpStmsTraverser m (MemOp inner rep) rep
traverseMemOpStms onInner f op =
  case op of
    Alloc {} -> pure op
    Inner inner -> Inner <$> onInner f inner

-- | Allocations contain nothing representation-specific, so they are
-- rebuilt as-is at the new representation; inner operations are
-- rephrased by the inner instance.
instance (RephraseOp inner) => RephraseOp (MemOp inner) where
  rephraseInOp r op =
    case op of
      Alloc size space -> pure $ Alloc size space
      Inner x -> Inner <$> rephraseInOp r x

-- | The free variables of an allocation are those of its size; the
-- space is not inspected.
instance (FreeIn (inner rep)) => FreeIn (MemOp inner rep) where
  freeIn' op =
    case op of
      Alloc size _ -> freeIn' size
      Inner k -> freeIn' k

-- | An allocation produces a single memory block in its space.
instance (TypedOp (inner rep)) => TypedOp (MemOp inner rep) where
  opType op =
    case op of
      Alloc _ space -> pure [Mem space]
      Inner k -> opType k

-- | An allocation's single result aliases nothing, and it consumes
-- nothing; inner operations defer to their own instance.
instance (AliasedOp (inner rep)) => AliasedOp (MemOp inner rep) where
  opAliases (Inner k) = opAliases k
  opAliases Alloc {} = [mempty]

  consumedInOp (Inner k) = consumedInOp k
  consumedInOp Alloc {} = mempty

-- | Allocations are rebuilt unchanged; inner operations receive alias
-- information from the inner instance.
instance (CanBeAliased inner) => CanBeAliased (MemOp inner) where
  addOpAliases aliases op =
    case op of
      Alloc se space -> Alloc se space
      Inner k -> Inner (addOpAliases aliases k)

-- | Renaming an allocation renames its size; the space is kept.
instance (Rename (inner rep)) => Rename (MemOp inner rep) where
  rename op =
    case op of
      Alloc size space -> do
        size' <- rename size
        pure $ Alloc size' space
      Inner k -> Inner <$> rename k

-- | Substitution in an allocation affects only its size.
instance (Substitute (inner rep)) => Substitute (MemOp inner rep) where
  substituteNames subst op =
    case op of
      Alloc size space -> Alloc (substituteNames subst size) space
      Inner k -> Inner $ substituteNames subst k

-- | An allocation is printed as @alloc@ applied (via 'PP.apply') to its
-- size, followed by its space only when that space is not
-- 'DefaultSpace'.
instance (PP.Pretty (inner rep)) => PP.Pretty (MemOp inner rep) where
  pretty (Inner k) = PP.pretty k
  pretty (Alloc e space) = "alloc" <> PP.apply (PP.pretty e : spaceArgs)
    where
      spaceArgs = case space of
        DefaultSpace -> []
        _ -> [PP.pretty space]

-- | Every allocation is recorded under the label @Alloc@.
instance (OpMetrics (inner rep)) => OpMetrics (MemOp inner rep) where
  opMetrics op =
    case op of
      Alloc {} -> seen "Alloc"
      Inner k -> opMetrics k

instance (IsOp (inner rep)) => IsOp (MemOp inner rep) where
  -- An allocation counts as safe only when its size is a
  -- non-negative 'Int64' constant; any other size is unsafe.
  safeOp (Inner k) = safeOp k
  safeOp (Alloc size _) =
    case size of
      Constant (IntValue (Int64Value k)) -> k >= 0
      _ -> False

  -- Allocations are always considered cheap.
  cheapOp (Inner k) = cheapOp k
  cheapOp Alloc {} = True

  -- An allocation has one dependency set: everything free in it.
  opDependencies (Inner k) = opDependencies k
  opDependencies op@Alloc {} = [freeIn op]

-- | Allocations carry no wisdom and are rebuilt unchanged; inner
-- operations are annotated by the inner instance.
instance (CanBeWise inner) => CanBeWise (MemOp inner) where
  addOpWisdom op =
    case op of
      Alloc size space -> Alloc size space
      Inner k -> Inner (addOpWisdom k)

-- | Only inner operations can be indexed symbolically; an allocation
-- always yields 'Nothing'.
instance (ST.IndexOp (inner rep)) => ST.IndexOp (MemOp inner rep) where
  indexOp vtable k op is =
    case op of
      Inner inner -> ST.indexOp vtable k inner is
      Alloc {} -> Nothing

-- | The LMAD representation used for memory annotations, whose
-- expressions are 'Int64'-typed 'TPrimExp's over variable names.
type LMAD = LMAD.LMAD (TPrimExp Int64 VName)

-- | Like 'LMAD', but its expressions may also refer to existential
-- variables (through 'Ext').
type ExtLMAD = LMAD.LMAD (TPrimExp Int64 (Ext VName))

-- | A summary of the memory information for every let-bound
-- identifier, function parameter, and return value.  Parameterised
-- over the dimension type @d@, the uniqueness type @u@, and the
-- auxiliary array information @ret@.
data MemInfo d u ret
  = -- | A primitive value.
    MemPrim PrimType
  | -- | A memory block.
    MemMem Space
  | -- | The array is stored in the named memory block, and with the
    -- given index function.  The index function maps indices in the
    -- array to /element/ offsets, /not/ byte offsets!  To translate to
    -- byte offsets, multiply the offset with the size of the array
    -- element type.
    MemArray PrimType (ShapeBase d) u ret
  | -- | An accumulator, which is not stored anywhere.
    MemAcc VName Shape [Type] u
  deriving
    ( Eq,
      Show,
      -- XXX: it is unclear whether 'Ord' is actually needed.
      Ord
    )

-- | Memory information with 'SubExp' dimensions and a 'MemBind' for
-- arrays, generic in the uniqueness type @u@.
type MemBound u = MemInfo SubExp u MemBind

-- | The declared (uniqueness-carrying) type of memory information with
-- existential dimensions; the array's @ret@ part is discarded.
instance (FixExt ret) => DeclExtTyped (MemInfo ExtSize Uniqueness ret) where
  declExtTypeOf info =
    case info of
      MemPrim pt -> Prim pt
      MemMem space -> Mem space
      MemArray pt shape u _ -> Array pt shape u
      MemAcc acc ispace ts u -> Acc acc ispace ts u

-- | The existential type is the declared type with uniqueness erased.
instance (FixExt ret) => ExtTyped (MemInfo ExtSize Uniqueness ret) where
  extTypeOf info = fromDecl (declExtTypeOf info)

-- | The existential type of memory information without uniqueness; the
-- array's @ret@ part is discarded.
instance (FixExt ret) => ExtTyped (MemInfo ExtSize NoUniqueness ret) where
  extTypeOf info =
    case info of
      MemPrim pt -> Prim pt
      MemMem space -> Mem space
      MemArray pt shape u _ -> Array pt shape u
      MemAcc acc ispace ts u -> Acc acc ispace ts u

-- | Only arrays mention existentials (in their shape and in @ret@);
-- every other constructor is rebuilt unchanged.
instance (FixExt ret) => FixExt (MemInfo ExtSize u ret) where
  fixExt i se info =
    case info of
      MemArray pt shape u ret ->
        MemArray pt (fixExt i se shape) u (fixExt i se ret)
      MemPrim pt -> MemPrim pt
      MemMem space -> MemMem space
      MemAcc acc ispace ts u -> MemAcc acc ispace ts u

  mapExt f info =
    case info of
      MemArray pt shape u ret ->
        MemArray pt (mapExt f shape) u (mapExt f ret)
      MemPrim pt -> MemPrim pt
      MemMem space -> MemMem space
      MemAcc acc ispace ts u -> MemAcc acc ispace ts u

-- | The type is the declared type with uniqueness erased.
instance Typed (MemInfo SubExp Uniqueness ret) where
  typeOf info = fromDecl (declTypeOf info)

-- | The type of memory information without uniqueness; the array's
-- @ret@ part is discarded.
instance Typed (MemInfo SubExp NoUniqueness ret) where
  typeOf info =
    case info of
      MemPrim pt -> Prim pt
      MemMem space -> Mem space
      MemArray bt shape u _ -> Array bt shape u
      MemAcc acc ispace ts u -> Acc acc ispace ts u

-- | The declared (uniqueness-carrying) type; the array's @ret@ part is
-- discarded.
instance DeclTyped (MemInfo SubExp Uniqueness ret) where
  declTypeOf info =
    case info of
      MemPrim bt -> Prim bt
      MemMem space -> Mem space
      MemArray bt shape u _ -> Array bt shape u
      MemAcc acc ispace ts u -> Acc acc ispace ts u

-- | Free variables of a memory annotation.
--
-- * Primitive annotations contribute nothing.
-- * Arrays contribute their shape and their return/location component.
-- * Memory blocks contribute their space.
-- * Accumulators contribute their name, index space and element types.
--
-- Uniqueness is never inspected.
instance (FreeIn d, FreeIn ret) => FreeIn (MemInfo d u ret) where
  freeIn' info =
    case info of
      MemArray _ shape _ ret -> freeIn' shape <> freeIn' ret
      MemMem space -> freeIn' space
      MemPrim {} -> mempty
      MemAcc acc ispace ts _ -> freeIn' (acc, ispace, ts)

-- | Apply a name substitution to a memory annotation.  Array shapes and
-- returns, and accumulator names, index spaces and element types are
-- rewritten.  Primitive types, memory spaces and uniqueness are returned
-- unchanged.
instance (Substitute d, Substitute ret) => Substitute (MemInfo d u ret) where
  substituteNames substs info =
    case info of
      MemArray bt shape u ret ->
        MemArray bt (sub shape) u (sub ret)
      MemAcc acc ispace ts u ->
        MemAcc (sub acc) (sub ispace) (sub ts) u
      MemMem space ->
        MemMem space
      MemPrim bt ->
        MemPrim bt
    where
      -- Explicit signature so the helper can be used at several types.
      sub :: (Substitute a) => a -> a
      sub = substituteNames substs

-- | Renaming is delegated to 'substituteRename', which builds on the
-- 'Substitute' instance for 'MemInfo'.
instance (Substitute d, Substitute ret) => Rename (MemInfo d u ret) where
  rename info = substituteRename info

-- | Simplify every scalar expression stored in an 'LMAD' with
-- 'simplifyPrimExp'.  The simplified expressions are re-tagged as
-- 'Int64' expressions with 'isInt64'.
simplifyLMAD ::
  (Engine.SimplifiableRep rep) =>
  LMAD ->
  Engine.SimpleM rep LMAD
simplifyLMAD lmad = traverse simplifyScalar lmad
  where
    simplifyScalar e = isInt64 <$> simplifyPrimExp (untyped e)

-- | Existential counterpart of 'simplifyLMAD'.  Every scalar expression
-- of an 'ExtLMAD' is simplified with 'simplifyExtPrimExp' and re-tagged
-- as 'Int64'.
simplifyExtLMAD ::
  (Engine.SimplifiableRep rep) =>
  ExtLMAD ->
  Engine.SimpleM rep ExtLMAD
simplifyExtLMAD lmad = traverse simplifyScalar lmad
  where
    simplifyScalar e = isInt64 <$> simplifyExtPrimExp (untyped e)

-- | Turn an 'ExtLMAD' into an ordinary 'LMAD' by unwrapping every 'Free'
-- leaf.  Returns 'Nothing' if any leaf refers to an existential ('Ext').
isStaticLMAD :: ExtLMAD -> Maybe LMAD
isStaticLMAD lmad = mapM (mapM onlyFree) lmad
  where
    onlyFree (Free v) = Just v
    onlyFree Ext {} = Nothing

-- | Simplify the parts of a memory annotation that can be simplified:
-- array shapes and return annotations, and the name, index space and
-- element types of accumulators.  Primitive types, memory spaces and
-- uniqueness are kept as they are.
instance
  (Engine.Simplifiable d, Engine.Simplifiable ret) =>
  Engine.Simplifiable (MemInfo d u ret)
  where
  simplify (MemPrim bt) = pure (MemPrim bt)
  simplify (MemMem space) = pure (MemMem space)
  simplify (MemArray bt shape u ret) = do
    shape' <- Engine.simplify shape
    ret' <- Engine.simplify ret
    pure (MemArray bt shape' u ret')
  simplify (MemAcc acc ispace ts u) = do
    acc' <- Engine.simplify acc
    ispace' <- Engine.simplify ispace
    ts' <- Engine.simplify ts
    pure (MemAcc acc' ispace' ts' u)

-- | Pretty-print a memory annotation:
--
-- * primitive annotations print as their element type;
-- * memory in the default space prints as @mem@, and other spaces
--   append the space;
-- * arrays print their type, an at-sign and the return annotation;
-- * accumulators print their uniqueness followed by the accumulator type.
instance
  ( PP.Pretty (ShapeBase d),
    PP.Pretty (TypeBase (ShapeBase d) u),
    PP.Pretty d,
    PP.Pretty u,
    PP.Pretty ret
  ) =>
  PP.Pretty (MemInfo d u ret)
  where
  pretty info =
    case info of
      MemPrim bt -> PP.pretty bt
      MemMem DefaultSpace -> "mem"
      MemMem s -> "mem" <> PP.pretty s
      MemArray bt shape u ret ->
        PP.pretty (Array bt shape u) <+> "@" <+> PP.pretty ret
      MemAcc acc ispace ts u ->
        PP.pretty u <> PP.pretty (Acc acc ispace ts NoUniqueness :: Type)

-- | Memory information for an array bound somewhere in the program: the
-- memory block holding the array, and the index function (an 'LMAD')
-- mapping array indices to offsets within that block.
data MemBind
  = -- | Located in this memory block with this index
    -- function.
    ArrayIn VName LMAD
  deriving (Show)

-- | Every 'MemBind' compares equal to every other one, so memory placement
-- never distinguishes two annotations under '=='.
-- NOTE(review): presumably deliberate, so that memory information does not
-- affect comparisons of the values containing it; confirm.
instance Eq MemBind where
  (==) _ _ = True

-- | Consistent with the 'Eq' instance: every pair of 'MemBind's compares
-- as 'EQ'.
instance Ord MemBind where
  compare _ _ = EQ

-- | Renaming is delegated to 'substituteRename', which builds on the
-- 'Substitute' instance for 'MemBind'.
instance Rename MemBind where
  rename mb = substituteRename mb

-- | Substitute names in both the memory block and its index function.
instance Substitute MemBind where
  substituteNames substs (ArrayIn block lmad) =
    ArrayIn block' lmad'
    where
      block' = substituteNames substs block
      lmad' = substituteNames substs lmad

-- | Rendered as the memory block name, an arrow, and the index function.
instance PP.Pretty MemBind where
  pretty (ArrayIn block lmad) =
    PP.pretty block <+> "->" PP.</> PP.pretty lmad

-- | The memory block and every variable mentioned by the index function
-- are free.
instance FreeIn MemBind where
  freeIn' (ArrayIn block lmad) =
    freeIn' block <> freeIn' lmad

-- | A description of the memory properties of an array being returned
-- by an operation.  In both cases the array's layout is given by an
-- 'ExtLMAD', which may refer to existential values.
data MemReturn
  = -- | The array is located in a memory block that is
    -- already in scope.
    ReturnsInBlock VName ExtLMAD
  | -- | The operation returns a new (existential) memory
    -- block in the given 'Space'.  The 'Int' is the existential
    -- index of that block.
    ReturnsNewBlock Space Int ExtLMAD
  deriving (Show)

-- | Every 'MemReturn' compares equal to every other one, so memory
-- placement never distinguishes two return annotations under '=='.
-- NOTE(review): presumably deliberate, mirroring 'MemBind'; confirm.
instance Eq MemReturn where
  (==) _ _ = True

-- | Consistent with the 'Eq' instance: every pair of 'MemReturn's
-- compares as 'EQ'.
instance Ord MemReturn where
  compare _ _ = EQ

-- | Renaming is delegated to 'substituteRename', which builds on the
-- 'Substitute' instance for 'MemReturn'.
instance Rename MemReturn where
  rename mr = substituteRename mr

-- | Substitute names in a return annotation.  For 'ReturnsInBlock' both
-- the block name and the LMAD are rewritten.  For 'ReturnsNewBlock' only
-- the LMAD is rewritten; the space and existential index are kept.
instance Substitute MemReturn where
  substituteNames substs ret =
    case ret of
      ReturnsInBlock block lmad ->
        ReturnsInBlock
          (substituteNames substs block)
          (substituteNames substs lmad)
      ReturnsNewBlock space i lmad ->
        ReturnsNewBlock space i (substituteNames substs lmad)

-- | Existential indices occur in two places in a 'MemReturn': the index
-- of a freshly returned block ('ReturnsNewBlock'), and any 'Ext' leaves
-- inside the 'ExtLMAD' of either constructor.  Both methods below treat
-- both places.
instance FixExt MemReturn where
  -- Fixing existential @i@ to variable @v@, when @i@ is the index of the
  -- new block itself, means the array now lives in the in-scope block @v@.
  fixExt i (Var v) (ReturnsNewBlock _ j lmad)
    | j == i =
        ReturnsInBlock v $
          fixExtLMAD i (primExpFromSubExp int64 (Var v)) lmad
  -- Otherwise the block stays existential.  Existential @i@ disappears,
  -- so a block index above it shifts down by one.
  fixExt i se (ReturnsNewBlock space j lmad) =
    ReturnsNewBlock space j' (fixExtLMAD i (primExpFromSubExp int64 se) lmad)
    where
      j'
        | i < j = j - 1
        | otherwise = j
  fixExt i se (ReturnsInBlock mem lmad) =
    ReturnsInBlock mem (fixExtLMAD i (primExpFromSubExp int64 se) lmad)

  -- Remap every existential index with @f@: the block index of a
  -- 'ReturnsNewBlock', and the 'Ext' leaves of the LMAD in both
  -- constructors.  'fixExt' rewrites the LMAD of both constructors, so
  -- 'mapExt' must too; otherwise the LMAD of a new block would keep
  -- stale existential indices after renumbering.
  mapExt f ret =
    case ret of
      ReturnsNewBlock space i lmad ->
        ReturnsNewBlock space (f i) (remapLMAD lmad)
      ReturnsInBlock mem lmad ->
        ReturnsInBlock mem (remapLMAD lmad)
    where
      remapLMAD :: ExtLMAD -> ExtLMAD
      remapLMAD = fmap (fmap remapExt)
      remapExt (Ext j) = Ext (f j)
      remapExt (Free v) = Free v

-- | Fix existential @i@ of an 'ExtLMAD' to the expression @e@:
--
-- * leaves @Ext i@ are replaced by @e@, with its variables lifted into 'Free';
-- * leaves @Ext j@ with @j > i@ become @Ext (j - 1)@, since existential @i@
--   no longer exists;
-- * all other leaves are kept.
fixExtLMAD :: Int -> PrimExp VName -> ExtLMAD -> ExtLMAD
fixExtLMAD i e lmad = fmap fixTerm lmad
  where
    fixTerm = isInt64 . replaceInPrimExp update . untyped
    update (Free x) t = LeafExp (Free x) t
    update (Ext j) t
      | j == i = fmap Free e
      | j > i = LeafExp (Ext (j - 1)) t
      | otherwise = LeafExp (Ext j) t

-- | An 'Int64'-typed leaf expression referring to existential number @i@.
leafExp :: Int -> TPrimExp Int64 (Ext a)
leafExp i = isInt64 (LeafExp (Ext i) int64)

-- | Existentialise an 'LMAD' with respect to a context of names.  Every
-- variable is first wrapped in 'Free'.  Then, via 'LMAD.substitute', each
-- name in @ctx@ is mapped to the existential whose index is its position
-- in @ctx@.  If a name occurs more than once in @ctx@, 'M.fromList'
-- keeps its last position.
existentialiseLMAD :: [VName] -> LMAD -> ExtLMAD
existentialiseLMAD ctx lmad =
  LMAD.substitute ctx' (fmap (fmap Free) lmad)
  where
    ctx' = M.fromList [(Free v, leafExp k) | (v, k) <- zip ctx [0 ..]]

-- | 'ReturnsInBlock' is rendered as the block name, an arrow and the LMAD.
-- 'ReturnsNewBlock' is rendered as @?@, immediately followed by the
-- existential index and the memory space, then an arrow and the LMAD.
--
-- Uses the qualified 'PP.pretty', like every other 'PP.Pretty' instance
-- in this module.
instance PP.Pretty MemReturn where
  pretty (ReturnsInBlock v lmad) =
    PP.pretty v <+> "->" PP.</> PP.pretty lmad
  pretty (ReturnsNewBlock space i lmad) =
    "?" <> PP.pretty i <> PP.pretty space <+> "->" PP.</> PP.pretty lmad

-- | Free variables of a return annotation.  For 'ReturnsInBlock' these are
-- the block and the LMAD's variables.  For 'ReturnsNewBlock' they are the
-- space and the LMAD's variables; the existential index contributes
-- nothing.
instance FreeIn MemReturn where
  freeIn' ret =
    case ret of
      ReturnsInBlock block lmad -> freeIn' block <> freeIn' lmad
      ReturnsNewBlock space _ lmad -> freeIn' space <> freeIn' lmad

-- | Simplify a return annotation.  The LMAD is always simplified with
-- 'simplifyExtLMAD'.  The block name of a 'ReturnsInBlock' is also
-- simplified; the space and index of a 'ReturnsNewBlock' are kept.
instance Engine.Simplifiable MemReturn where
  simplify (ReturnsNewBlock space i lmad) = do
    lmad' <- simplifyExtLMAD lmad
    pure (ReturnsNewBlock space i lmad')
  simplify (ReturnsInBlock block lmad) = do
    block' <- Engine.simplify block
    lmad' <- simplifyExtLMAD lmad
    pure (ReturnsInBlock block' lmad')

-- | Simplify the memory block name, then the index function via
-- 'simplifyLMAD'.
instance Engine.Simplifiable MemBind where
  simplify (ArrayIn block lmad) = do
    block' <- Engine.simplify block
    lmad' <- simplifyLMAD lmad
    pure (ArrayIn block' lmad')

-- | Simplify each function return annotation in order.
instance Engine.Simplifiable [FunReturns] where
  simplify rets = traverse Engine.simplify rets

-- | The memory return of an expression.  An array is annotated with
-- @Maybe MemReturn@, which can be interpreted as the expression
-- either dictating exactly where the array is located when it is
-- returned (if 'Just'), or being able to put it wherever the binding
-- prefers (if 'Nothing').
--
-- This is necessary to capture the difference between an expression
-- that is just an array-typed variable, in which the array being
-- "returned" is located where it already is, and a @copy@ expression,
-- whose entire purpose is to store an existing array in some
-- arbitrary location.  This is a consequence of the design decision
-- never to have implicit memory copies.
--
-- Sizes are 'ExtSize's, and no uniqueness information is carried
-- ('NoUniqueness').
type ExpReturns = MemInfo ExtSize NoUniqueness (Maybe MemReturn)

-- | The return of a body, which must always indicate where
-- returned arrays are located: the 'MemReturn' is mandatory rather than
-- wrapped in 'Maybe' as in 'ExpReturns'.  No uniqueness information is
-- carried ('NoUniqueness').
type BodyReturns = MemInfo ExtSize NoUniqueness MemReturn

-- | The memory return of a function, which must always indicate where
-- returned arrays are located.  Unlike 'BodyReturns', this carries
-- 'Uniqueness' information for the returned values.
type FunReturns = MemInfo ExtSize Uniqueness MemReturn

-- | Make the return annotation of an array optional by wrapping it in
-- 'Just'.  The non-array variants carry no return annotation and are
-- rebuilt unchanged.
maybeReturns :: MemInfo d u r -> MemInfo d u (Maybe r)
maybeReturns info =
  case info of
    MemPrim pt -> MemPrim pt
    MemMem sp -> MemMem sp
    MemAcc acc_v acc_shape acc_ts uniq -> MemAcc acc_v acc_shape acc_ts uniq
    MemArray pt shp uniq ret -> MemArray pt shp uniq (Just ret)

-- | Discard uniqueness information.  Arrays and accumulators get
-- 'NoUniqueness'; every other component is kept as-is.
noUniquenessReturns :: MemInfo d u r -> MemInfo d NoUniqueness r
noUniquenessReturns info =
  case info of
    MemPrim pt -> MemPrim pt
    MemMem sp -> MemMem sp
    MemAcc acc_v acc_shape acc_ts _ -> MemAcc acc_v acc_shape acc_ts NoUniqueness
    MemArray pt shp _ ret -> MemArray pt shp NoUniqueness ret

-- | View a function return as an expression return.  The array memory
-- return becomes a 'Just', and uniqueness is dropped.
funReturnsToExpReturns :: FunReturns -> ExpReturns
funReturnsToExpReturns ret = noUniquenessReturns (maybeReturns ret)

-- | View a body return as an expression return by wrapping the array
-- memory return in 'Just'.
bodyReturnsToExpReturns :: BodyReturns -> ExpReturns
bodyReturnsToExpReturns ret = noUniquenessReturns (maybeReturns ret)

-- | The 'ExpReturns' of an expression that merely refers to a variable
-- with the given memory information.
--
-- An array is reported as residing exactly where the variable already
-- lives ('ReturnsInBlock' on its memory block).  Its shape dimensions
-- are wrapped in 'Free', and its LMAD is existentialised over no
-- names.
varInfoToExpReturns :: MemInfo SubExp NoUniqueness MemBind -> ExpReturns
varInfoToExpReturns info =
  case info of
    MemPrim pt -> MemPrim pt
    MemMem sp -> MemMem sp
    MemAcc acc_v acc_shape acc_ts uniq -> MemAcc acc_v acc_shape acc_ts uniq
    MemArray et shp uniq (ArrayIn mem lmad) ->
      let ret = ReturnsInBlock mem (existentialiseLMAD [] lmad)
       in MemArray et (Free <$> shp) uniq (Just ret)

-- | Check a result against a list of function return types.  The
-- memory information of each result 'SubExp' is looked up in the
-- current scope (with aliasing information removed) and handed to
-- 'matchReturnType'.
matchRetTypeToResult ::
  (Mem rep inner, TC.Checkable rep) =>
  [FunReturns] ->
  Result ->
  TC.TypeM rep ()
matchRetTypeToResult rettype result = do
  let result_ses = map resSubExp result
  scope <- removeScopeAliases <$> askScope
  result_ts <- runReaderT (mapM subExpMemInfo result_ses) scope
  matchReturnType rettype result_ses result_ts

-- | Check a function body result against the declared return types.
--
-- Beyond the checks of 'matchRetTypeToResult', every array returned
-- through a variable must have an LMAD for which 'LMAD.isDirect'
-- holds.  Otherwise a type error is raised.
matchFunctionReturnType ::
  (Mem rep inner, TC.Checkable rep) =>
  [FunReturns] ->
  Result ->
  TC.TypeM rep ()
matchFunctionReturnType rettype result = do
  matchRetTypeToResult rettype result
  mapM_ checkResultSubExp (map resSubExp result)
  where
    -- Constants are never arrays, so there is nothing to check.
    checkResultSubExp Constant {} = pure ()
    checkResultSubExp (Var v) = do
      info <- varMemInfo v
      case info of
        MemArray _ _ _ (ArrayIn _ lmad)
          | not (LMAD.isDirect lmad) ->
              TC.bad . TC.TypeError $
                "Array "
                  <> prettyText v
                  <> " returned by function, but has nontrivial index function:\n"
                  <> prettyText lmad
        -- Non-arrays, and arrays with a direct LMAD, are accepted.
        _ -> pure ()

-- | Check the result of a loop body against the memory information of
-- the loop's merge parameters.
--
-- A return type is synthesised from the parameters, so the body can be
-- checked as if it returned from a function.  References to other
-- merge parameters become 'Ext' indices.  In particular, an array whose
-- memory block is itself a merge parameter (at position @i@) is
-- described as 'ReturnsNewBlock' with index @i@.
matchLoopResultMem ::
  (Mem rep inner, TC.Checkable rep) =>
  [FParam (Aliases rep)] ->
  Result ->
  TC.TypeM rep ()
matchLoopResultMem params = matchRetTypeToResult rettype
  where
    param_names = map paramName params

    -- Invent a return type so we can pretend that the loop body is
    -- actually returning from a function.
    rettype = [toRet (paramDec p) | p <- params]

    -- A variable naming a merge parameter becomes an existential
    -- referring to that parameter's position; anything else is free.
    toExtSE (Constant v) = Free (Constant v)
    toExtSE (Var v) = Var <$> maybe (Free v) Ext (elemIndex v param_names)

    toRet info =
      case info of
        MemPrim t -> MemPrim t
        MemMem space -> MemMem space
        MemAcc acc ispace ts u -> MemAcc acc ispace ts u
        MemArray pt shape u (ArrayIn mem lmad) ->
          MemArray pt (fmap toExtSE shape) u $
            memReturn mem (existentialiseLMAD param_names lmad)

    -- An array stored in a memory block that is itself a merge
    -- parameter lives in a "new" block identified by that parameter's
    -- index; otherwise it stays in its existing block.
    memReturn mem lmad'
      | Just i <- mem `elemIndex` param_names,
        Param _ _ (MemMem space) : _ <- drop i params =
          ReturnsNewBlock space i lmad'
      | otherwise =
          ReturnsInBlock mem lmad'

-- | Check the result of a branch body against the branch's declared
-- return types.  Result memory information is looked up in the current
-- scope extended with the body's own statements, with aliasing
-- information removed.
matchBranchReturnType ::
  (Mem rep inner, TC.Checkable rep) =>
  [BodyReturns] ->
  Body (Aliases rep) ->
  TC.TypeM rep ()
matchBranchReturnType rettype (Body _ stms res) = do
  outer <- askScope
  let ses = map resSubExp res
      scope = removeScopeAliases (outer <> scopeOf stms)
  ts <- runReaderT (mapM subExpMemInfo ses) scope
  matchReturnType rettype ses ts

-- | Helper function for index function unification.
--
-- Both maps are built from pairs associating a 'VName' with an 'Int'.
-- When a 'VName' occurs more than once, the Int of its /first/
-- occurrence is the one used.
--
-- The first map takes each 'VName' (wrapped in 'Free') to 'leafExp'
-- applied to its first Int.
--
-- The second map takes each Int (wrapped in 'Ext') to an @int64@
-- 'LeafExp' of 'Ext' applied to the first Int at which the associated
-- 'VName' occurs.
getExtMaps ::
  [(VName, Int)] ->
  ( M.Map (Ext VName) (TPrimExp Int64 (Ext VName)),
    M.Map (Ext VName) (TPrimExp Int64 (Ext VName))
  )
getExtMaps ctx_lst_ids = (free_map, ext_map)
  where
    -- 'M.fromListWith' calls its combinator as @f new old@, so keeping
    -- the second argument retains the earliest Int for each name.
    first_idx = M.fromListWith (\_ old -> old) ctx_lst_ids

    free_map = M.mapKeys Free $ M.map leafExp first_idx

    ext_map =
      M.fromList
        [ (Ext i, isInt64 $ LeafExp (Ext j) int64)
          | (v, i) <- ctx_lst_ids,
            Just j <- [M.lookup v first_idx]
        ]

matchReturnType ::
  (PP.Pretty u) =>
  [MemInfo ExtSize u MemReturn] ->
  [SubExp] ->
  [MemInfo SubExp NoUniqueness MemBind] ->
  TC.TypeM rep ()
matchReturnType :: forall u rep.
Pretty u =>
[MemInfo (Ext SubExp) u MemReturn]
-> [SubExp] -> [LetDecMem] -> TypeM rep ()
matchReturnType [MemInfo (Ext SubExp) u MemReturn]
rettype [SubExp]
res [LetDecMem]
ts = do
  let existentialiseLMAD0 :: LMAD -> ExtLMAD
      existentialiseLMAD0 :: LMAD -> ExtLMAD
existentialiseLMAD0 = (TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName))
-> LMAD -> ExtLMAD
forall a b. (a -> b) -> LMAD a -> LMAD b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName))
 -> LMAD -> ExtLMAD)
-> (TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName))
-> LMAD
-> ExtLMAD
forall a b. (a -> b) -> a -> b
$ (VName -> Ext VName)
-> TPrimExp Int64 VName -> TPrimExp Int64 (Ext VName)
forall a b. (a -> b) -> TPrimExp Int64 a -> TPrimExp Int64 b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> Ext VName
forall a. a -> Ext a
Free

      fetchCtx :: Int -> ExceptT Text (TypeM rep) (SubExp, LetDecMem)
fetchCtx Int
i = case Int -> [(SubExp, LetDecMem)] -> Maybe (SubExp, LetDecMem)
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int
i ([(SubExp, LetDecMem)] -> Maybe (SubExp, LetDecMem))
-> [(SubExp, LetDecMem)] -> Maybe (SubExp, LetDecMem)
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [LetDecMem] -> [(SubExp, LetDecMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
res [LetDecMem]
ts of
        Maybe (SubExp, LetDecMem)
Nothing ->
          Text -> ExceptT Text (TypeM rep) (SubExp, LetDecMem)
forall a. Text -> ExceptT Text (TypeM rep) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Text -> ExceptT Text (TypeM rep) (SubExp, LetDecMem))
-> Text -> ExceptT Text (TypeM rep) (SubExp, LetDecMem)
forall a b. (a -> b) -> a -> b
$ Text
"Cannot find variable #" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Int -> Text
forall a. Pretty a => a -> Text
prettyText Int
i Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
" in results: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> [SubExp] -> Text
forall a. Pretty a => a -> Text
prettyText [SubExp]
res
        Just (SubExp
se, LetDecMem
t) -> (SubExp, LetDecMem) -> ExceptT Text (TypeM rep) (SubExp, LetDecMem)
forall a. a -> ExceptT Text (TypeM rep) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
se, LetDecMem
t)

      checkReturn :: MemInfo (Ext SubExp) u MemReturn
-> LetDecMem -> ExceptT Text (TypeM rep) ()
checkReturn (MemPrim PrimType
x) (MemPrim PrimType
y)
        | PrimType
x PrimType -> PrimType -> Bool
forall a. Eq a => a -> a -> Bool
== PrimType
y = () -> ExceptT Text (TypeM rep) ()
forall a. a -> ExceptT Text (TypeM rep) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
      checkReturn (MemMem Space
x) (MemMem Space
y)
        | Space
x Space -> Space -> Bool
forall a. Eq a => a -> a -> Bool
== Space
y = () -> ExceptT Text (TypeM rep) ()
forall a. a -> ExceptT Text (TypeM rep) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
      checkReturn (MemAcc VName
xacc Shape
xispace [Type]
xts u
_) (MemAcc VName
yacc Shape
yispace [Type]
yts NoUniqueness
_)
        | (VName
xacc, Shape
xispace, [Type]
xts) (VName, Shape, [Type]) -> (VName, Shape, [Type]) -> Bool
forall a. Eq a => a -> a -> Bool
== (VName
yacc, Shape
yispace, [Type]
yts) =
            () -> ExceptT Text (TypeM rep) ()
forall a. a -> ExceptT Text (TypeM rep) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
      checkReturn
        (MemArray PrimType
x_pt ShapeBase (Ext SubExp)
x_shape u
_ MemReturn
x_ret)
        (MemArray PrimType
y_pt Shape
y_shape NoUniqueness
_ MemBind
y_ret)
          | PrimType
x_pt PrimType -> PrimType -> Bool
forall a. Eq a => a -> a -> Bool
== PrimType
y_pt,
            ShapeBase (Ext SubExp) -> Int
forall a. ArrayShape a => a -> Int
shapeRank ShapeBase (Ext SubExp)
x_shape Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank Shape
y_shape = do
              (Ext SubExp -> SubExp -> ExceptT Text (TypeM rep) ())
-> [Ext SubExp] -> [SubExp] -> ExceptT Text (TypeM rep) ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Ext SubExp -> SubExp -> ExceptT Text (TypeM rep) ()
checkDim (ShapeBase (Ext SubExp) -> [Ext SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase (Ext SubExp)
x_shape) (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
y_shape)
              MemReturn -> MemBind -> ExceptT Text (TypeM rep) ()
checkMemReturn MemReturn
x_ret MemBind
y_ret
      checkReturn MemInfo (Ext SubExp) u MemReturn
x LetDecMem
y =
        Text -> ExceptT Text (TypeM rep) ()
forall a. Text -> ExceptT Text (TypeM rep) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Text -> ExceptT Text (TypeM rep) ())
-> Text -> ExceptT Text (TypeM rep) ()
forall a b. (a -> b) -> a -> b
$ [Text] -> Text
T.unwords [Text
"Expected", MemInfo (Ext SubExp) u MemReturn -> Text
forall a. Pretty a => a -> Text
prettyText MemInfo (Ext SubExp) u MemReturn
x, Text
"but got", LetDecMem -> Text
forall a. Pretty a => a -> Text
prettyText LetDecMem
y]

      checkDim :: Ext SubExp -> SubExp -> ExceptT Text (TypeM rep) ()
checkDim (Free SubExp
x) SubExp
y
        | SubExp
x SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
y = () -> ExceptT Text (TypeM rep) ()
forall a. a -> ExceptT Text (TypeM rep) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
        | Bool
otherwise =
            Text -> ExceptT Text (TypeM rep) ()
forall a. Text -> ExceptT Text (TypeM rep) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Text -> ExceptT Text (TypeM rep) ())
-> Text -> ExceptT Text (TypeM rep) ()
forall a b. (a -> b) -> a -> b
$ [Text] -> Text
T.unwords [Text
"Expected dim", SubExp -> Text
forall a. Pretty a => a -> Text
prettyText SubExp
x, Text
"but got", SubExp -> Text
forall a. Pretty a => a -> Text
prettyText SubExp
y]
      checkDim (Ext Int
i) SubExp
y = do
        (SubExp
x, LetDecMem
_) <- Int -> ExceptT Text (TypeM rep) (SubExp, LetDecMem)
fetchCtx Int
i
        Bool -> ExceptT Text (TypeM rep) () -> ExceptT Text (TypeM rep) ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (SubExp
x SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
y) (ExceptT Text (TypeM rep) () -> ExceptT Text (TypeM rep) ())
-> ([Text] -> ExceptT Text (TypeM rep) ())
-> [Text]
-> ExceptT Text (TypeM rep) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ExceptT Text (TypeM rep) ()
forall a. Text -> ExceptT Text (TypeM rep) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Text -> ExceptT Text (TypeM rep) ())
-> ([Text] -> Text) -> [Text] -> ExceptT Text (TypeM rep) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Text] -> Text
T.unwords ([Text] -> ExceptT Text (TypeM rep) ())
-> [Text] -> ExceptT Text (TypeM rep) ()
forall a b. (a -> b) -> a -> b
$
          [Text
"Expected ext dim", Int -> Text
forall a. Pretty a => a -> Text
prettyText Int
i, Text
"=>", SubExp -> Text
forall a. Pretty a => a -> Text
prettyText SubExp
x, Text
"but got", SubExp -> Text
forall a. Pretty a => a -> Text
prettyText SubExp
y]

      checkMemReturn :: MemReturn -> MemBind -> ExceptT Text (TypeM rep) ()
checkMemReturn (ReturnsInBlock VName
x_mem ExtLMAD
x_lmad) (ArrayIn VName
y_mem LMAD
y_lmad)
        | VName
x_mem VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
y_mem =
            Bool -> ExceptT Text (TypeM rep) () -> ExceptT Text (TypeM rep) ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (ExtLMAD -> ExtLMAD -> Bool
forall num. LMAD num -> LMAD num -> Bool
LMAD.closeEnough ExtLMAD
x_lmad (ExtLMAD -> Bool) -> ExtLMAD -> Bool
forall a b. (a -> b) -> a -> b
$ LMAD -> ExtLMAD
existentialiseLMAD0 LMAD
y_lmad) (ExceptT Text (TypeM rep) () -> ExceptT Text (TypeM rep) ())
-> ExceptT Text (TypeM rep) () -> ExceptT Text (TypeM rep) ()
forall a b. (a -> b) -> a -> b
$
              Text -> ExceptT Text (TypeM rep) ()
forall a. Text -> ExceptT Text (TypeM rep) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Text -> ExceptT Text (TypeM rep) ())
-> ([Text] -> Text) -> [Text] -> ExceptT Text (TypeM rep) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Text] -> Text
T.unwords ([Text] -> ExceptT Text (TypeM rep) ())
-> [Text] -> ExceptT Text (TypeM rep) ()
forall a b. (a -> b) -> a -> b
$
                [ Text
"Index function unification failed (ReturnsInBlock)",
                  Text
"\nlmad of body result: ",
                  LMAD -> Text
forall a. Pretty a => a -> Text
prettyText LMAD
y_lmad,
                  Text
"\nlmad of return type: ",
                  ExtLMAD -> Text
forall a. Pretty a => a -> Text
prettyText ExtLMAD
x_lmad
                ]
      checkMemReturn
        (ReturnsNewBlock Space
x_space Int
x_ext ExtLMAD
x_lmad)
        (ArrayIn VName
y_mem LMAD
y_lmad) = do
          (SubExp
x_mem, LetDecMem
x_mem_type) <- Int -> ExceptT Text (TypeM rep) (SubExp, LetDecMem)
fetchCtx Int
x_ext
          Bool -> ExceptT Text (TypeM rep) () -> ExceptT Text (TypeM rep) ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (ExtLMAD -> ExtLMAD -> Bool
forall num. LMAD num -> LMAD num -> Bool
LMAD.closeEnough ExtLMAD
x_lmad (ExtLMAD -> Bool) -> ExtLMAD -> Bool
forall a b. (a -> b) -> a -> b
$ LMAD -> ExtLMAD
existentialiseLMAD0 LMAD
y_lmad) (ExceptT Text (TypeM rep) () -> ExceptT Text (TypeM rep) ())
-> ExceptT Text (TypeM rep) () -> ExceptT Text (TypeM rep) ()
forall a b. (a -> b) -> a -> b
$
            Text -> ExceptT Text (TypeM rep) ()
forall a. Text -> ExceptT Text (TypeM rep) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Text -> ExceptT Text (TypeM rep) ())
-> (Doc Any -> Text) -> Doc Any -> ExceptT Text (TypeM rep) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Doc Any -> Text
forall a. Doc a -> Text
docText (Doc Any -> ExceptT Text (TypeM rep) ())
-> Doc Any -> ExceptT Text (TypeM rep) ()
forall a b. (a -> b) -> a -> b
$
              Doc Any
"Index function unification failed (ReturnsNewBlock)"
                Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Doc Any
"Lmad of body result:"
                Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
indent Int
2 (LMAD -> Doc Any
forall a ann. Pretty a => a -> Doc ann
forall ann. LMAD -> Doc ann
pretty LMAD
y_lmad)
                Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Doc Any
"Lmad of return type:"
                Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
indent Int
2 (ExtLMAD -> Doc Any
forall a ann. Pretty a => a -> Doc ann
forall ann. ExtLMAD -> Doc ann
pretty ExtLMAD
x_lmad)
          case LetDecMem
x_mem_type of
            MemMem Space
y_space ->
              Bool -> ExceptT Text (TypeM rep) () -> ExceptT Text (TypeM rep) ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Space
x_space Space -> Space -> Bool
forall a. Eq a => a -> a -> Bool
== Space
y_space) (ExceptT Text (TypeM rep) () -> ExceptT Text (TypeM rep) ())
-> ([Text] -> ExceptT Text (TypeM rep) ())
-> [Text]
-> ExceptT Text (TypeM rep) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ExceptT Text (TypeM rep) ()
forall a. Text -> ExceptT Text (TypeM rep) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Text -> ExceptT Text (TypeM rep) ())
-> ([Text] -> Text) -> [Text] -> ExceptT Text (TypeM rep) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Text] -> Text
T.unwords ([Text] -> ExceptT Text (TypeM rep) ())
-> [Text] -> ExceptT Text (TypeM rep) ()
forall a b. (a -> b) -> a -> b
$
                [ Text
"Expected memory",
                  VName -> Text
forall a. Pretty a => a -> Text
prettyText VName
y_mem,
                  Text
"in space",
                  Space -> Text
forall a. Pretty a => a -> Text
prettyText Space
x_space,
                  Text
"but actually in space",
                  Space -> Text
forall a. Pretty a => a -> Text
prettyText Space
y_space
                ]
            LetDecMem
t ->
              Text -> ExceptT Text (TypeM rep) ()
forall a. Text -> ExceptT Text (TypeM rep) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Text -> ExceptT Text (TypeM rep) ())
-> ([Text] -> Text) -> [Text] -> ExceptT Text (TypeM rep) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Text] -> Text
T.unwords ([Text] -> ExceptT Text (TypeM rep) ())
-> [Text] -> ExceptT Text (TypeM rep) ()
forall a b. (a -> b) -> a -> b
$
                [Text
"Expected memory", Int -> Text
forall a. Pretty a => a -> Text
prettyText Int
x_ext, Text
"=>", SubExp -> Text
forall a. Pretty a => a -> Text
prettyText SubExp
x_mem, Text
"but but has type", LetDecMem -> Text
forall a. Pretty a => a -> Text
prettyText LetDecMem
t]
      checkMemReturn MemReturn
x MemBind
y =
        Text -> ExceptT Text (TypeM rep) ()
forall a. Text -> ExceptT Text (TypeM rep) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Text -> ExceptT Text (TypeM rep) ())
-> (Doc Any -> Text) -> Doc Any -> ExceptT Text (TypeM rep) ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Doc Any -> Text
forall a. Doc a -> Text
docText (Doc Any -> ExceptT Text (TypeM rep) ())
-> Doc Any -> ExceptT Text (TypeM rep) ()
forall a b. (a -> b) -> a -> b
$
          Doc Any
"Expected array in"
            Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
indent Int
2 (MemReturn -> Doc Any
forall a ann. Pretty a => a -> Doc ann
forall ann. MemReturn -> Doc ann
pretty MemReturn
x)
            Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Doc Any
"but array returned in"
            Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
indent Int
2 (MemBind -> Doc Any
forall a ann. Pretty a => a -> Doc ann
forall ann. MemBind -> Doc ann
pretty MemBind
y)

      bad :: Text -> TypeM rep ()
bad Text
s =
        ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ())
-> (Doc Any -> ErrorCase rep) -> Doc Any -> TypeM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError (Text -> ErrorCase rep)
-> (Doc Any -> Text) -> Doc Any -> ErrorCase rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Doc Any -> Text
forall a. Doc a -> Text
docText (Doc Any -> TypeM rep ()) -> Doc Any -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
          Doc Any
"Return type"
            Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
indent Int
2 ([Doc Any] -> Doc Any
forall a. [Doc a] -> Doc a
ppTupleLines' ([Doc Any] -> Doc Any) -> [Doc Any] -> Doc Any
forall a b. (a -> b) -> a -> b
$ (MemInfo (Ext SubExp) u MemReturn -> Doc Any)
-> [MemInfo (Ext SubExp) u MemReturn] -> [Doc Any]
forall a b. (a -> b) -> [a] -> [b]
map MemInfo (Ext SubExp) u MemReturn -> Doc Any
forall a ann. Pretty a => a -> Doc ann
forall ann. MemInfo (Ext SubExp) u MemReturn -> Doc ann
pretty [MemInfo (Ext SubExp) u MemReturn]
rettype)
            Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Doc Any
"cannot match returns of results"
            Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
indent Int
2 ([Doc Any] -> Doc Any
forall a. [Doc a] -> Doc a
ppTupleLines' ([Doc Any] -> Doc Any) -> [Doc Any] -> Doc Any
forall a b. (a -> b) -> a -> b
$ (LetDecMem -> Doc Any) -> [LetDecMem] -> [Doc Any]
forall a b. (a -> b) -> [a] -> [b]
map LetDecMem -> Doc Any
forall a ann. Pretty a => a -> Doc ann
forall ann. LetDecMem -> Doc ann
pretty [LetDecMem]
ts)
            Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Text -> Doc Any
forall ann. Text -> Doc ann
forall a ann. Pretty a => a -> Doc ann
pretty Text
s

  Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([MemInfo (Ext SubExp) u MemReturn] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [MemInfo (Ext SubExp) u MemReturn]
rettype Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [LetDecMem] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LetDecMem]
ts) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
    ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ())
-> (Doc Any -> ErrorCase rep) -> Doc Any -> TypeM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Text -> ErrorCase rep
forall rep. Text -> ErrorCase rep
TC.TypeError (Text -> ErrorCase rep)
-> (Doc Any -> Text) -> Doc Any -> ErrorCase rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Doc Any -> Text
forall a. Doc a -> Text
docText (Doc Any -> TypeM rep ()) -> Doc Any -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
      Doc Any
"Return type"
        Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
indent Int
2 ([Doc Any] -> Doc Any
forall a. [Doc a] -> Doc a
ppTupleLines' ([Doc Any] -> Doc Any) -> [Doc Any] -> Doc Any
forall a b. (a -> b) -> a -> b
$ (MemInfo (Ext SubExp) u MemReturn -> Doc Any)
-> [MemInfo (Ext SubExp) u MemReturn] -> [Doc Any]
forall a b. (a -> b) -> [a] -> [b]
map MemInfo (Ext SubExp) u MemReturn -> Doc Any
forall a ann. Pretty a => a -> Doc ann
forall ann. MemInfo (Ext SubExp) u MemReturn -> Doc ann
pretty [MemInfo (Ext SubExp) u MemReturn]
rettype)
        Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Doc Any
"does not have same number of elements as results"
        Doc Any -> Doc Any -> Doc Any
forall ann. Doc ann -> Doc ann -> Doc ann
</> Int -> Doc Any -> Doc Any
forall ann. Int -> Doc ann -> Doc ann
indent Int
2 ([Doc Any] -> Doc Any
forall a. [Doc a] -> Doc a
ppTupleLines' ([Doc Any] -> Doc Any) -> [Doc Any] -> Doc Any
forall a b. (a -> b) -> a -> b
$ (LetDecMem -> Doc Any) -> [LetDecMem] -> [Doc Any]
forall a b. (a -> b) -> [a] -> [b]
map LetDecMem -> Doc Any
forall a ann. Pretty a => a -> Doc ann
forall ann. LetDecMem -> Doc ann
pretty [LetDecMem]
ts)

  (Text -> TypeM rep ())
-> (() -> TypeM rep ()) -> Either Text () -> TypeM rep ()
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either Text -> TypeM rep ()
bad () -> TypeM rep ()
forall a. a -> TypeM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Either Text () -> TypeM rep ())
-> TypeM rep (Either Text ()) -> TypeM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ExceptT Text (TypeM rep) () -> TypeM rep (Either Text ())
forall e (m :: * -> *) a. ExceptT e m a -> m (Either e a)
runExceptT ((MemInfo (Ext SubExp) u MemReturn
 -> LetDecMem -> ExceptT Text (TypeM rep) ())
-> [MemInfo (Ext SubExp) u MemReturn]
-> [LetDecMem]
-> ExceptT Text (TypeM rep) ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ MemInfo (Ext SubExp) u MemReturn
-> LetDecMem -> ExceptT Text (TypeM rep) ()
checkReturn [MemInfo (Ext SubExp) u MemReturn]
rettype [LetDecMem]
ts)

-- | Check that the memory annotations of a pattern agree with the
-- memory information that 'expReturns' derives for the expression
-- bound to it.
--
-- Reports a type error if 'expReturns' gives no information for the
-- expression at all, if the pattern and the expression have a
-- different number of results, or if any pair of results disagrees on
-- primitive type, memory space, accumulator, shape, or index function.
matchPatToExp ::
  (Mem rep inner, LetDec rep ~ LetDecMem, TC.Checkable rep) =>
  Pat (LetDec (Aliases rep)) ->
  Exp (Aliases rep) ->
  TC.TypeM rep ()
matchPatToExp pat e = do
  scope <- asksScope removeScopeAliases
  -- A 'Nothing' from 'expReturns' means the expression's results cannot
  -- be described in terms of memory blocks and index functions.
  rt <- case runReader (expReturns (removeExpAliases e)) scope of
    Nothing -> illformed
    Just rets -> pure rets

  let (ctx_ids, val_ts) = unzip (bodyReturnsFromPat (removePatAliases pat))
      -- NOTE(review): only the first two pattern names are paired with
      -- existential indices ([0 .. 1]); confirm this is intentional
      -- rather than meant to be [0 ..].
      (ctx_map_ids, ctx_map_exts) = getExtMaps (zip ctx_ids [0 .. 1])
      sameArity = length val_ts == length rt
      ok = sameArity && and (zipWith (matches ctx_map_ids ctx_map_exts) val_ts rt)

  unless ok $
    TC.bad $
      TC.TypeError $
        docText $
          "Expression type:"
            </> indent 2 (ppTupleLines' (map pretty rt))
            </> "cannot match pattern type:"
            </> indent 2 (ppTupleLines' (map pretty val_ts))
  where
    illformed =
      TC.bad $
        TC.TypeError $
          docText $
            "Expression"
              </> indent 2 (pretty e)
              </> "cannot be assigned an index function."

    -- Compare one pattern result (left) against one expression result
    -- (right).  Index functions are compared after substituting the
    -- context maps into each side, and only up to 'LMAD.closeEnough'.
    matches _ _ (MemPrim x) (MemPrim y) = x == y
    matches _ _ (MemMem x_space) (MemMem y_space) = x_space == y_space
    matches _ _ (MemAcc x_accs x_ispace x_ts _) (MemAcc y_accs y_ispace y_ts _) =
      (x_accs, x_ispace, x_ts) == (y_accs, y_ispace, y_ts)
    matches ctxids ctxexts (MemArray x_pt x_shape _ x_ret) (MemArray y_pt y_shape _ y_ret) =
      x_pt == y_pt && x_shape == y_shape && retsMatch x_ret y_ret
      where
        lmadsMatch x_lmad y_lmad =
          LMAD.closeEnough
            (LMAD.substitute ctxids x_lmad)
            (LMAD.substitute ctxexts y_lmad)

        retsMatch (ReturnsInBlock _ x_lmad) (Just (ReturnsInBlock _ y_lmad)) =
          lmadsMatch x_lmad y_lmad
        retsMatch (ReturnsInBlock _ x_lmad) (Just (ReturnsNewBlock _ _ y_lmad)) =
          lmadsMatch x_lmad y_lmad
        retsMatch (ReturnsNewBlock _ x_i x_lmad) (Just (ReturnsNewBlock _ y_i y_lmad)) =
          x_i == y_i && lmadsMatch x_lmad y_lmad
        -- No return summary on the expression side: nothing to check.
        retsMatch _ Nothing = True
        retsMatch _ _ = False
    matches _ _ _ _ = False

-- | Fetch the memory information of a variable from the type checker's
-- environment.
--
-- Uniqueness is removed from function parameters.  Index variables are
-- reported as primitives of their integer type.
varMemInfo ::
  (Mem rep inner) =>
  VName ->
  TC.TypeM rep (MemInfo SubExp NoUniqueness MemBind)
varMemInfo name = do
  info <- TC.lookupVar name
  case info of
    IndexName it -> pure (MemPrim (IntType it))
    LParamName summary -> pure summary
    FParamName summary -> pure (noUniquenessReturns summary)
    -- Let-bound names carry an alias annotation that is irrelevant here.
    LetName (_, summary) -> pure (letDecMem summary)

-- | Project the memory information out of a scope entry.
--
-- Function parameters lose their uniqueness.  Index variables become
-- primitives of their integer type.
nameInfoToMemInfo :: (Mem rep inner) => NameInfo rep -> MemBound NoUniqueness
nameInfoToMemInfo (LetName summary) = letDecMem summary
nameInfoToMemInfo (LParamName summary) = summary
nameInfoToMemInfo (FParamName summary) = noUniquenessReturns summary
nameInfoToMemInfo (IndexName it) = MemPrim (IntType it)

-- | Look up a name in the current scope and return its memory
-- information, as computed by 'nameInfoToMemInfo'.
lookupMemInfo ::
  (HasScope rep m, Mem rep inner) =>
  VName ->
  m (MemInfo SubExp NoUniqueness MemBind)
lookupMemInfo v = nameInfoToMemInfo <$> lookupInfo v

-- | Memory information for a subexpression.
--
-- Variables are looked up in scope.  Constants are always primitives of
-- the constant's own type.
subExpMemInfo ::
  (HasScope rep m, Mem rep inner) =>
  SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
subExpMemInfo se =
  case se of
    Constant v -> pure (MemPrim (primValueType v))
    Var v -> lookupMemInfo v

-- | Return the memory block and index function of an array variable.
--
-- Calls 'error' if the name is bound to something that is not an array
-- in memory.
lookupArraySummary ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  VName ->
  m (VName, LMAD.LMAD (TPrimExp Int64 VName))
lookupArraySummary name = do
  summary <- lookupMemInfo name
  case summary of
    MemArray _ _ _ (ArrayIn mem lmad) -> pure (mem, lmad)
    _ -> error $ T.unpack $ notAnArray summary
  where
    notAnArray info =
      T.concat
        [ "Expected ",
          prettyText name,
          " to be array but bound to:\n",
          prettyText info
        ]

-- | Return the space of a memory block variable.
--
-- Calls 'error' if the name is not bound to a memory block.
lookupMemSpace ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  VName ->
  m Space
lookupMemSpace name = do
  summary <- lookupMemInfo name
  case summary of
    MemMem space -> pure space
    _ -> error $ T.unpack $ notMemory summary
  where
    notMemory info =
      T.concat
        [ "Expected ",
          prettyText name,
          " to be memory but bound to:\n",
          prettyText info
        ]

-- | Check that the memory annotation attached to the variable @name@ is
-- well formed.
--
-- * Primitive annotations, and memory annotations outside a
--   'ScalarSpace', are accepted as-is.
--
-- * A 'ScalarSpace' memory annotation requires each of its dimensions to
--   be an @i64@.
--
-- * An accumulator annotation is checked as the corresponding 'Acc' type.
--
-- * An array annotation requires that the block it is stored in has a
--   memory type, that every expression in its index function is an @i64@,
--   and that the index function's shape equals the array's shape.
checkMemInfo ::
  (TC.Checkable rep) =>
  VName ->
  MemInfo SubExp u MemBind ->
  TC.TypeM rep ()
checkMemInfo _ (MemPrim _) = pure ()
checkMemInfo _ (MemMem (ScalarSpace dims _)) =
  mapM_ (TC.require [Prim int64]) dims
checkMemInfo _ (MemMem _) = pure ()
checkMemInfo _ (MemAcc acc ispace ts u) =
  TC.checkType (Acc acc ispace ts u)
checkMemInfo name (MemArray _ shape _ (ArrayIn mem lmad)) = do
  -- The variable the array claims to live in must really be a memory block.
  memType <- lookupType mem
  unless (isMem memType) . typeError $
    mconcat
      [ "Variable ",
        prettyText mem,
        " used as memory block, but is of type ",
        prettyText memType,
        "."
      ]

  TC.context ("in index function " <> prettyText lmad) $ do
    traverse_ (TC.requirePrimExp int64 . untyped) lmad
    let lmadShape = LMAD.shape lmad
        arrayShape = map pe64 $ shapeDims shape
    unless (lmadShape == arrayShape) . typeError $
      mconcat
        [ "Shape of index function (",
          prettyText lmadShape,
          ") does not match shape of array ",
          prettyText name,
          " (",
          prettyText shape,
          ")"
        ]
  where
    isMem Mem {} = True
    isMem _ = False

    typeError = TC.bad . TC.TypeError

-- | Compute the 'BodyReturns' of every element of a pattern, paired with
-- the element's name.
--
-- A shape dimension that is a variable bound by the pattern itself becomes
-- an existential 'Ext' reference, indexed by that element's position in the
-- pattern; other dimensions stay 'Free'.  An array whose memory block is
-- bound by the pattern (and has a memory annotation) is reported via
-- 'ReturnsNewBlock', with its index function existentialised over all
-- pattern names.  Otherwise it is reported via 'ReturnsInBlock' on its
-- existing block.
bodyReturnsFromPat ::
  Pat (MemBound NoUniqueness) -> [(VName, BodyReturns)]
bodyReturnsFromPat pat =
  [(patElemName pe, returnsOf (patElemDec pe)) | pe <- pes]
  where
    pes = patElems pat
    names = map patElemName pes
    indexed = zip [0 ..] pes

    -- Position (and element) of the pattern element bound to a name, if any.
    lookupPos v = find ((== v) . patElemName . snd) indexed

    dimReturns (Var v) | Just (i, _) <- lookupPos v = Ext i
    dimReturns se = Free se

    returnsOf (MemPrim pt) = MemPrim pt
    returnsOf (MemMem space) = MemMem space
    returnsOf (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
    returnsOf (MemArray pt shape u (ArrayIn mem lmad)) =
      MemArray pt (Shape (map dimReturns (shapeDims shape))) u memRet
      where
        memRet =
          case lookupPos mem of
            Just (i, PatElem _ (MemMem space)) ->
              ReturnsNewBlock space i (existentialiseLMAD names lmad)
            _ ->
              ReturnsInBlock mem (existentialiseLMAD [] lmad)

-- | Attach memory return information to a list of extended types.
--
-- Primitive, memory-block and accumulator types map directly to their
-- 'MemInfo' counterparts.  An array whose type is not 'existential' gets no
-- memory information ('Nothing').  An existential array is described as
-- returned in a fresh block in 'DefaultSpace', laid out by 'LMAD.iota' with
-- offset 0 over its dimensions.  Fresh blocks are numbered consecutively
-- from 0, in the order such arrays appear in the input list.
extReturns :: [ExtType] -> [ExpReturns]
extReturns ets = evalState (traverse annotate ets) 0
  where
    annotate (Prim pt) = pure (MemPrim pt)
    annotate (Mem space) = pure (MemMem space)
    annotate (Acc acc ispace ts u) = pure (MemAcc acc ispace ts u)
    annotate t@(Array pt shape u)
      | not (existential t) =
          pure (MemArray pt shape u Nothing)
      | otherwise = do
          -- Claim the next fresh block index.
          blockIdx <- get
          modify (+ 1)
          let layout = LMAD.iota 0 $ map dimExp $ shapeDims shape
          pure . MemArray pt shape u . Just $
            ReturnsNewBlock DefaultSpace blockIdx layout

    -- Existential dimensions stay existential (by index); free dimensions
    -- are converted from their 'SubExp'.
    dimExp (Ext i) = le64 (Ext i)
    dimExp (Free se) = Free <$> pe64 se

-- | Look up the array variable @v@ and return its element type, shape,
-- the memory block it is stored in, and its index function.
--
-- Calls 'error' if @v@ is not annotated as an array.
arrayVarReturns ::
  (HasScope rep m, Monad m, Mem rep inner) =>
  VName ->
  m (PrimType, Shape, VName, LMAD)
arrayVarReturns v = do
  info <- lookupMemInfo v
  maybe notAnArray pure (arrayParts info)
  where
    arrayParts (MemArray et shape _ (ArrayIn mem lmad)) =
      Just (et, shape, mem, lmad)
    arrayParts _ = Nothing

    notAnArray =
      error . T.unpack $
        "arrayVarReturns: " <> prettyText v <> " is not an array."

-- | The 'ExpReturns' describing a reference to the variable @v@.
--
-- Primitive, memory-block and accumulator annotations are passed through
-- unchanged.  An array is described as residing in its existing memory
-- block ('ReturnsInBlock'), with every shape dimension 'Free' and its index
-- function converted by 'existentialiseLMAD' with an empty context.
varReturns ::
  (HasScope rep m, Monad m, Mem rep inner) =>
  VName ->
  m ExpReturns
varReturns v = toReturns <$> lookupMemInfo v
  where
    toReturns (MemPrim pt) = MemPrim pt
    toReturns (MemMem space) = MemMem space
    toReturns (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
    toReturns (MemArray et shape _ (ArrayIn mem lmad)) =
      MemArray et (Free <$> shape) NoUniqueness . Just . ReturnsInBlock mem $
        existentialiseLMAD [] lmad

-- | The 'ExpReturns' of a 'SubExp'.  A variable is resolved with
-- 'varReturns'.  A constant is a primitive of its own value's type.
subExpReturns :: (HasScope rep m, Monad m, Mem rep inner) => SubExp -> m ExpReturns
subExpReturns se =
  case se of
    Var v -> varReturns v
    Constant val -> pure . MemPrim $ primValueType val

-- | The return information of an expression.  This can be seen as the
-- "return type with memory annotations" of the expression.
--
-- This can produce Nothing, which signifies that the result is an
-- array layout that is not expressible as an index function.
expReturns ::
  (LocalScope rep m, Mem rep inner) =>
  Exp rep ->
  m (Maybe [ExpReturns])
expReturns :: forall rep (m :: * -> *) (inner :: * -> *).
(LocalScope rep m, Mem rep inner) =>
Exp rep -> m (Maybe [ExpReturns])
expReturns (BasicOp (SubExp SubExp
se)) =
  [ExpReturns] -> Maybe [ExpReturns]
forall a. a -> Maybe a
Just ([ExpReturns] -> Maybe [ExpReturns])
-> (ExpReturns -> [ExpReturns]) -> ExpReturns -> Maybe [ExpReturns]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. ExpReturns -> [ExpReturns]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> Maybe [ExpReturns])
-> m ExpReturns -> m (Maybe [ExpReturns])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> m ExpReturns
forall rep (m :: * -> *) (inner :: * -> *).
(HasScope rep m, Monad m, Mem rep inner) =>
SubExp -> m ExpReturns
subExpReturns SubExp
se
expReturns (BasicOp (Opaque OpaqueOp
_ (Var VName
v))) =
  [ExpReturns] -> Maybe [ExpReturns]
forall a. a -> Maybe a
Just ([ExpReturns] -> Maybe [ExpReturns])
-> (ExpReturns -> [ExpReturns]) -> ExpReturns -> Maybe [ExpReturns]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. ExpReturns -> [ExpReturns]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> Maybe [ExpReturns])
-> m ExpReturns -> m (Maybe [ExpReturns])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m ExpReturns
forall rep (m :: * -> *) (inner :: * -> *).
(HasScope rep m, Monad m, Mem rep inner) =>
VName -> m ExpReturns
varReturns VName
v
expReturns (BasicOp (Reshape ReshapeKind
k Shape
newshape VName
v)) = do
  (PrimType
et, Shape
_, VName
mem, LMAD
lmad) <- VName -> m (PrimType, Shape, VName, LMAD)
forall rep (m :: * -> *) (inner :: * -> *).
(HasScope rep m, Monad m, Mem rep inner) =>
VName -> m (PrimType, Shape, VName, LMAD)
arrayVarReturns VName
v
  case ReshapeKind -> LMAD -> [TPrimExp Int64 VName] -> Maybe LMAD
forall {num}.
(Eq num, IntegralExp num) =>
ReshapeKind -> LMAD num -> Shape num -> Maybe (LMAD num)
reshaper ReshapeKind
k LMAD
lmad ([TPrimExp Int64 VName] -> Maybe LMAD)
-> [TPrimExp Int64 VName] -> Maybe LMAD
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
newshape of
    Just LMAD
lmad' ->
      Maybe [ExpReturns] -> m (Maybe [ExpReturns])
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe [ExpReturns] -> m (Maybe [ExpReturns]))
-> ([ExpReturns] -> Maybe [ExpReturns])
-> [ExpReturns]
-> m (Maybe [ExpReturns])
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [ExpReturns] -> Maybe [ExpReturns]
forall a. a -> Maybe a
Just ([ExpReturns] -> m (Maybe [ExpReturns]))
-> [ExpReturns] -> m (Maybe [ExpReturns])
forall a b. (a -> b) -> a -> b
$
        [ PrimType
-> ShapeBase (Ext SubExp)
-> NoUniqueness
-> Maybe MemReturn
-> ExpReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
et ((SubExp -> Ext SubExp) -> Shape -> ShapeBase (Ext SubExp)
forall a b. (a -> b) -> ShapeBase a -> ShapeBase b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> Ext SubExp
forall a. a -> Ext a
Free Shape
newshape) NoUniqueness
NoUniqueness (Maybe MemReturn -> ExpReturns)
-> (MemReturn -> Maybe MemReturn) -> MemReturn -> ExpReturns
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. MemReturn -> Maybe MemReturn
forall a. a -> Maybe a
Just (MemReturn -> ExpReturns) -> MemReturn -> ExpReturns
forall a b. (a -> b) -> a -> b
$
            VName -> ExtLMAD -> MemReturn
ReturnsInBlock VName
mem ([VName] -> LMAD -> ExtLMAD
existentialiseLMAD [] LMAD
lmad')
        ]
    Maybe LMAD
Nothing -> Maybe [ExpReturns] -> m (Maybe [ExpReturns])
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe [ExpReturns]
forall a. Maybe a
Nothing
  where
    reshaper :: ReshapeKind -> LMAD num -> Shape num -> Maybe (LMAD num)
reshaper ReshapeKind
ReshapeArbitrary LMAD num
lmad =
      LMAD num -> Shape num -> Maybe (LMAD num)
forall num.
(Eq num, IntegralExp num) =>
LMAD num -> Shape num -> Maybe (LMAD num)
LMAD.reshape LMAD num
lmad
    reshaper ReshapeKind
ReshapeCoerce LMAD num
lmad =
      LMAD num -> Maybe (LMAD num)
forall a. a -> Maybe a
Just (LMAD num -> Maybe (LMAD num))
-> (Shape num -> LMAD num) -> Shape num -> Maybe (LMAD num)
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. LMAD num -> Shape num -> LMAD num
forall num. LMAD num -> Shape num -> LMAD num
LMAD.coerce LMAD num
lmad
expReturns (BasicOp (Rearrange [Int]
perm VName
v)) = do
  (PrimType
et, Shape [SubExp]
dims, VName
mem, LMAD
lmad) <- VName -> m (PrimType, Shape, VName, LMAD)
forall rep (m :: * -> *) (inner :: * -> *).
(HasScope rep m, Monad m, Mem rep inner) =>
VName -> m (PrimType, Shape, VName, LMAD)
arrayVarReturns VName
v
  let lmad' :: LMAD
lmad' = LMAD -> [Int] -> LMAD
forall num. LMAD num -> [Int] -> LMAD num
LMAD.permute LMAD
lmad [Int]
perm
      dims' :: [SubExp]
dims' = [Int] -> [SubExp] -> [SubExp]
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm [SubExp]
dims
  Maybe [ExpReturns] -> m (Maybe [ExpReturns])
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe [ExpReturns] -> m (Maybe [ExpReturns]))
-> Maybe [ExpReturns] -> m (Maybe [ExpReturns])
forall a b. (a -> b) -> a -> b
$
    [ExpReturns] -> Maybe [ExpReturns]
forall a. a -> Maybe a
Just
      [ PrimType
-> ShapeBase (Ext SubExp)
-> NoUniqueness
-> Maybe MemReturn
-> ExpReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
et ([Ext SubExp] -> ShapeBase (Ext SubExp)
forall d. [d] -> ShapeBase d
Shape ([Ext SubExp] -> ShapeBase (Ext SubExp))
-> [Ext SubExp] -> ShapeBase (Ext SubExp)
forall a b. (a -> b) -> a -> b
$ (SubExp -> Ext SubExp) -> [SubExp] -> [Ext SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Ext SubExp
forall a. a -> Ext a
Free [SubExp]
dims') NoUniqueness
NoUniqueness (Maybe MemReturn -> ExpReturns) -> Maybe MemReturn -> ExpReturns
forall a b. (a -> b) -> a -> b
$
          MemReturn -> Maybe MemReturn
forall a. a -> Maybe a
Just (MemReturn -> Maybe MemReturn) -> MemReturn -> Maybe MemReturn
forall a b. (a -> b) -> a -> b
$
            VName -> ExtLMAD -> MemReturn
ReturnsInBlock VName
mem (ExtLMAD -> MemReturn) -> ExtLMAD -> MemReturn
forall a b. (a -> b) -> a -> b
$
              [VName] -> LMAD -> ExtLMAD
existentialiseLMAD [] LMAD
lmad'
      ]
expReturns (BasicOp (Index VName
v Slice SubExp
slice)) = do
  [ExpReturns] -> Maybe [ExpReturns]
forall a. a -> Maybe a
Just ([ExpReturns] -> Maybe [ExpReturns])
-> (LetDecMem -> [ExpReturns]) -> LetDecMem -> Maybe [ExpReturns]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. ExpReturns -> [ExpReturns]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> [ExpReturns])
-> (LetDecMem -> ExpReturns) -> LetDecMem -> [ExpReturns]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. LetDecMem -> ExpReturns
varInfoToExpReturns (LetDecMem -> Maybe [ExpReturns])
-> m LetDecMem -> m (Maybe [ExpReturns])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> Slice SubExp -> m LetDecMem
forall (m :: * -> *) rep (inner :: * -> *).
(Monad m, HasScope rep m, Mem rep inner) =>
VName -> Slice SubExp -> m LetDecMem
sliceInfo VName
v Slice SubExp
slice
expReturns (BasicOp (Update Safety
_ VName
v Slice SubExp
_ SubExp
_)) =
  [ExpReturns] -> Maybe [ExpReturns]
forall a. a -> Maybe a
Just ([ExpReturns] -> Maybe [ExpReturns])
-> (ExpReturns -> [ExpReturns]) -> ExpReturns -> Maybe [ExpReturns]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. ExpReturns -> [ExpReturns]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> Maybe [ExpReturns])
-> m ExpReturns -> m (Maybe [ExpReturns])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m ExpReturns
forall rep (m :: * -> *) (inner :: * -> *).
(HasScope rep m, Monad m, Mem rep inner) =>
VName -> m ExpReturns
varReturns VName
v
expReturns (BasicOp (FlatIndex VName
v FlatSlice SubExp
slice)) =
  [ExpReturns] -> Maybe [ExpReturns]
forall a. a -> Maybe a
Just ([ExpReturns] -> Maybe [ExpReturns])
-> (LetDecMem -> [ExpReturns]) -> LetDecMem -> Maybe [ExpReturns]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. ExpReturns -> [ExpReturns]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> [ExpReturns])
-> (LetDecMem -> ExpReturns) -> LetDecMem -> [ExpReturns]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. LetDecMem -> ExpReturns
varInfoToExpReturns (LetDecMem -> Maybe [ExpReturns])
-> m LetDecMem -> m (Maybe [ExpReturns])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> FlatSlice SubExp -> m LetDecMem
forall (m :: * -> *) rep (inner :: * -> *).
(Monad m, HasScope rep m, Mem rep inner) =>
VName -> FlatSlice SubExp -> m LetDecMem
flatSliceInfo VName
v FlatSlice SubExp
slice
expReturns (BasicOp (FlatUpdate VName
v FlatSlice SubExp
_ VName
_)) =
  [ExpReturns] -> Maybe [ExpReturns]
forall a. a -> Maybe a
Just ([ExpReturns] -> Maybe [ExpReturns])
-> (ExpReturns -> [ExpReturns]) -> ExpReturns -> Maybe [ExpReturns]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. ExpReturns -> [ExpReturns]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> Maybe [ExpReturns])
-> m ExpReturns -> m (Maybe [ExpReturns])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m ExpReturns
forall rep (m :: * -> *) (inner :: * -> *).
(HasScope rep m, Monad m, Mem rep inner) =>
VName -> m ExpReturns
varReturns VName
v
expReturns (BasicOp BasicOp
op) =
  [ExpReturns] -> Maybe [ExpReturns]
forall a. a -> Maybe a
Just ([ExpReturns] -> Maybe [ExpReturns])
-> ([Type] -> [ExpReturns]) -> [Type] -> Maybe [ExpReturns]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [ExtType] -> [ExpReturns]
extReturns ([ExtType] -> [ExpReturns])
-> ([Type] -> [ExtType]) -> [Type] -> [ExpReturns]
forall b c a. (b -> c) -> (a -> b) -> a -> c
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Type] -> [ExtType]
forall u.
[TypeBase Shape u] -> [TypeBase (ShapeBase (Ext SubExp)) u]
staticShapes ([Type] -> Maybe [ExpReturns])
-> m [Type] -> m (Maybe [ExpReturns])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> BasicOp -> m [Type]
forall rep (m :: * -> *). HasScope rep m => BasicOp -> m [Type]
basicOpType BasicOp
op
expReturns e@(Loop merge _ _) = do
  -- Pair each result type of the loop with the merge parameter in the
  -- same position, and derive the return information from that
  -- parameter's memory annotation.
  ts <- expExtType e
  Just <$> zipWithM returnsFor ts mergeParams
  where
    mergeParams = map fst merge
    mergeNames = map paramName mergeParams

    -- The position and parameter of the merge parameter called @v@, if any.
    isLoopVar v = find ((== v) . paramName . snd) $ zip [0 ..] mergeParams

    returnsFor t p = case (t, paramDec p) of
      (Array pt shape u, MemArray _ _ _ (ArrayIn mem lmad)) ->
        -- The LMAD is passed through 'existentialiseLMAD' with the names
        -- of all merge parameters.  If @mem@ is itself a merge parameter
        -- of memory type, the result is a 'ReturnsNewBlock' at that
        -- parameter's position; otherwise it is 'ReturnsInBlock' @mem@.
        let lmad' = existentialiseLMAD mergeNames lmad
            summary = case isLoopVar mem of
              Just (i, mem_p)
                | Mem space <- paramType mem_p ->
                    ReturnsNewBlock space i lmad'
              _ -> ReturnsInBlock mem lmad'
         in pure $ MemArray pt shape u $ Just summary
      (Array {}, _) ->
        error "expReturns: Array return type but not array merge variable."
      (Acc acc ispace acc_ts u, _) ->
        pure $ MemAcc acc ispace acc_ts u
      (Prim pt, _) ->
        pure $ MemPrim pt
      (Mem space, _) ->
        pure $ MemMem space
expReturns (Apply _ _ ret _) =
  -- Convert each return annotation of the call with
  -- 'funReturnsToExpReturns', ignoring its 'RetAls' component.
  pure . Just $ [funReturnsToExpReturns r | (r, _) <- ret]
expReturns (Match _ _ _ (MatchDec ret _)) =
  -- The result is described by the branch return annotations stored in
  -- the 'MatchDec'.
  pure . Just $ fmap bodyReturnsToExpReturns ret
expReturns (Op op) = do
  -- Delegate to the 'OpReturns' instance of the operation.
  rets <- opReturns op
  pure $ Just rets
expReturns (WithAcc inputs lam) = do
  -- Results for the arrays of every accumulator input come first; each
  -- is looked up with 'varReturns'.
  input_rets <- concat <$> mapM (\(_, arrs, _) -> mapM varReturns arrs) inputs
  -- The first @length inputs@ lambda results (one per accumulator) are
  -- dropped; the rest are given static shapes.
  --
  -- XXX: this is a bit dubious because it enforces extra copies.  I
  -- think WithAcc should perhaps have a return annotation like If.
  let body_rets =
        extReturns . staticShapes . drop (length inputs) $ lambdaReturnType lam
  pure . Just $ input_rets <> body_rets

-- | The memory information of the result of indexing the array @v@
-- with @slice@.  If the slice leaves no dimensions, the result is a
-- scalar of the array's element type.  Otherwise it is an array in the
-- memory block reported for @v@ by 'arrayVarReturns', with the
-- LMAD sliced by 'LMAD.slice'.
sliceInfo ::
  (Monad m, HasScope rep m, Mem rep inner) =>
  VName ->
  Slice SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
sliceInfo v slice = do
  (et, _, mem, lmad) <- arrayVarReturns v
  pure $ case sliceDims slice of
    [] -> MemPrim et
    dims ->
      let lmad' = LMAD.slice lmad $ pe64 <$> slice
       in MemArray et (Shape dims) NoUniqueness (ArrayIn mem lmad')

-- | The flat-slice counterpart of 'sliceInfo': the memory information
-- of the array obtained by flat-slicing @v@.  The result is always an
-- array in the memory block reported for @v@ by 'arrayVarReturns'.
-- Its shape is given by 'flatSliceDims', and its LMAD is transformed
-- by 'LMAD.flatSlice'.
flatSliceInfo ::
  (Monad m, HasScope rep m, Mem rep inner) =>
  VName ->
  FlatSlice SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
flatSliceInfo v slice@(FlatSlice offset idxs) = do
  (et, _, mem, lmad) <- arrayVarReturns v
  let slice_pe = FlatSlice (pe64 offset) [fmap pe64 idx | idx <- idxs]
      shape = Shape $ flatSliceDims slice
  LMAD.flatSlice lmad slice_pe
    & ArrayIn mem
    & MemArray et shape NoUniqueness
    & pure

-- | Operations that can describe their results with memory-annotated
-- return information.  The default implementation converts the
-- operation's 'opType' with 'extReturns'.
class (IsOp op) => OpReturns op where
  opReturns :: (Mem rep inner, Monad m, HasScope rep m) => op -> m [ExpReturns]
  opReturns = fmap extReturns . opType

instance (OpReturns (inner rep)) => OpReturns (MemOp inner rep) where
  -- An allocation yields a single memory block in its space; an inner
  -- operation is handled by its own 'OpReturns' instance.
  opReturns mem_op = case mem_op of
    Alloc _ space -> pure [MemMem space]
    Inner inner_op -> opReturns inner_op

instance OpReturns (NoOp rep) where
  -- 'NoOp' produces no results.
  opReturns NoOp = pure mempty

-- | Instantiate the return annotations @rets@ of a function for a
-- particular call, given the function's parameters and the actual
-- arguments (with their types).  Returns 'Nothing' if 'applyRetType'
-- rejects the call.
--
-- Otherwise, three things are instantiated:
--
--  * Free variables in array shapes that name a parameter are
--    replaced by the corresponding argument.
--  * The memory block of a 'ReturnsInBlock' annotation is replaced by
--    the corresponding argument, when that argument is a variable.
--  * LMADs are left unchanged.
applyFunReturns ::
  (Typed dec) =>
  [FunReturns] ->
  [Param dec] ->
  [(SubExp, Type)] ->
  Maybe [FunReturns]
applyFunReturns rets params args =
  -- 'applyRetType' only serves as a validity check; its result is
  -- discarded.
  map correctDims rets <$ applyRetType (map declExtTypeOf rets) params args
  where
    parammap :: M.Map VName (SubExp, Type)
    parammap = M.fromList $ zip (map paramName params) args

    -- Replace a variable that names a parameter by its argument.
    substSubExp se@(Var v) = maybe se fst $ M.lookup v parammap
    substSubExp se = se

    -- Only array annotations mention parameters; all others are kept.
    correctDims (MemArray et shape u summary) =
      MemArray et (correctShape shape) u (correctSummary summary)
    correctDims info = info

    correctShape shape = Shape [correctDim d | d <- shapeDims shape]

    correctDim (Free se) = Free $ substSubExp se
    correctDim d = d

    -- FIXME: we should also do a replacement in lmad here.
    correctSummary (ReturnsInBlock mem lmad) =
      ReturnsInBlock (substMem mem) lmad
    correctSummary summary = summary

    -- A memory block bound to a parameter is renamed to the argument,
    -- but only if that argument is a variable.
    substMem mem
      | Just (Var v, _) <- M.lookup mem parammap = v
      | otherwise = mem