{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Building blocks for defining representations where every array
-- is given information about which memory block it is based in, and
-- how array elements map to memory block offsets.
--
-- There are two primary concepts you will need to understand:
--
--  1. Memory blocks, which are Futhark values of type v'Mem'
--     (parametrized with their size).  These correspond to arbitrary
--     blocks of memory, and are created using the 'Alloc' operation.
--
--  2. Index functions, which describe a mapping from the index space
--     of an array (e.g. a two-dimensional space for an array of type
--     @[[int]]@) to a one-dimensional offset into a memory block.
--     Thus, index functions describe how arbitrary-dimensional arrays
--     are mapped to the single-dimensional world of memory.
--
-- At a conceptual level, imagine that we have a two-dimensional array
-- @a@ of 32-bit integers, consisting of @n@ rows of @m@ elements
-- each.  This array could be represented in classic row-major format
-- with an index function like the following:
--
-- @
--   f(i,j) = i * m + j
-- @
--
-- When we want to know the location of element @a[2,3]@, we simply
-- call the index function as @f(2,3)@ and obtain @2*m+3@.  We could
-- also have chosen another index function, one that represents the
-- array in column-major (or "transposed") format:
--
-- @
--   f(i,j) = j * n + i
-- @
--
-- Index functions are not Futhark-level functions, but a special
-- construct that the final code generator will eventually use to
-- generate concrete access code.  By modifying the index functions we
-- can change how an array is represented in memory, which can permit
-- memory access pattern optimisations.
--
-- Every time we bind an array, whether in a @let@-binding, @loop@
-- merge parameter, or @lambda@ parameter, we have an annotation
-- specifying a memory block and an index function.  In some cases,
-- such as @let@-bindings for many expressions, we are free to specify
-- an arbitrary index function and memory block - for example, we get
-- to decide where 'Copy' stores its result - but in other cases the
-- type rules of the expression choose for us.  For example, 'Index'
-- always produces an array in the same memory block as its input, and
-- with the same index function, except with some indices fixed.
module Futhark.IR.Mem
  ( LetDecMem,
    FParamMem,
    LParamMem,
    RetTypeMem,
    BranchTypeMem,
    MemOp (..),
    MemInfo (..),
    MemBound,
    MemBind (..),
    MemReturn (..),
    IxFun,
    ExtIxFun,
    isStaticIxFun,
    ExpReturns,
    BodyReturns,
    FunReturns,
    noUniquenessReturns,
    bodyReturnsToExpReturns,
    Mem,
    AllocOp (..),
    OpReturns (..),
    varReturns,
    expReturns,
    extReturns,
    lookupMemInfo,
    subExpMemInfo,
    lookupArraySummary,
    existentialiseIxFun,

    -- * Type checking parts
    matchBranchReturnType,
    matchPatternToExp,
    matchFunctionReturnType,
    matchLoopResultMem,
    bodyReturnsFromPattern,
    checkMemInfo,

    -- * Module re-exports
    module Futhark.IR.Prop,
    module Futhark.IR.Traversals,
    module Futhark.IR.Pretty,
    module Futhark.IR.Syntax,
    module Futhark.Analysis.PrimExp.Convert,
  )
where

import Control.Category
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Foldable (toList, traverse_)
import Data.List (elemIndex, find)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.PrimExp.Simplify
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR.Aliases
  ( Aliases,
    removeExpAliases,
    removePatternAliases,
    removeScopeAliases,
  )
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.IR.Pretty
import Futhark.IR.Prop
import Futhark.IR.Prop.Aliases
import Futhark.IR.Syntax
import Futhark.IR.Traversals
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util
import Futhark.Util.Pretty (indent, ppr, text, (<+>), (</>))
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))

-- | Memory information attached to let-bound names.
type LetDecMem = MemBound NoUniqueness

-- | Memory information attached to function parameters, which (unlike
-- let-bound names) track uniqueness.
type FParamMem = MemBound Uniqueness

-- | Memory information attached to lambda parameters.
type LParamMem = MemBound NoUniqueness

-- | Function return types in a memory-annotated representation.
type RetTypeMem = FunReturns

-- | Branch (body) return types in a memory-annotated representation.
type BranchTypeMem = BodyReturns

-- | The class of ops that have memory allocation.
class AllocOp op where
  -- | Build an allocation operation from a size and the 'Space' to
  -- allocate in.  For 'MemOp' this is just the 'Alloc' constructor.
  allocOp :: SubExp -> Space -> op

-- | The constraints a representation must satisfy to be memory-annotated.
-- Its operations must support allocation and report their return
-- information.  Its let, parameter, return and branch decorations must
-- be the memory-aware types from this module.
type Mem rep =
  ( ASTRep rep,
    OpReturns rep,
    AllocOp (Op rep),
    LetDec rep ~ LetDecMem,
    FParamInfo rep ~ FParamMem,
    LParamInfo rep ~ LParamMem,
    RetType rep ~ RetTypeMem,
    BranchType rep ~ BranchTypeMem
  )

-- | Function return types: a primitive return is 'MemPrim', and
-- applying a return type to concrete arguments is delegated to
-- 'applyFunReturns'.
instance IsRetType FunReturns where
  primRetType = MemPrim
  applyRetType = applyFunReturns

-- | Body return types: a primitive result is described by 'MemPrim'.
instance IsBodyType BodyReturns where
  primBodyType = MemPrim

-- | Operations of a memory-annotated representation.  Each one is
-- either a memory allocation or an operation of the wrapped @inner@
-- type.
data MemOp inner
  = -- | Allocate a memory block.  This really should not be an
    -- expression, but what are you gonna do...
    Alloc SubExp Space
  | -- | Any other operation.  The class instances below delegate to
    -- the @inner@ type for it.
    Inner inner
  deriving (Eq, Ord, Show)

-- | Allocation in a 'MemOp' is the 'Alloc' constructor itself.
instance AllocOp (MemOp inner) where
  allocOp = Alloc

-- | The free variables of an allocation are those of its size.  The
-- 'Space' is not inspected.  Inner operations are delegated.
--
-- NOTE(review): 'MemMem' in 'MemInfo' /does/ count its space as free.
-- Confirm that an allocation's space can never mention variables.
instance FreeIn inner => FreeIn (MemOp inner) where
  freeIn' (Alloc size _) = freeIn' size
  freeIn' (Inner k) = freeIn' k

-- | An allocation produces exactly one memory block in its space.
-- Inner operations are delegated.
instance TypedOp inner => TypedOp (MemOp inner) where
  opType (Alloc _ space) = pure [Mem space]
  opType (Inner k) = opType k

-- | An allocation has a single result with no aliases, and consumes
-- nothing.  Inner operations are delegated.
instance AliasedOp inner => AliasedOp (MemOp inner) where
  opAliases Alloc {} = [mempty]
  opAliases (Inner k) = opAliases k

  consumedInOp Alloc {} = mempty
  consumedInOp (Inner k) = consumedInOp k

-- | Allocations carry no alias information and are rebuilt unchanged.
-- Only the inner operation gains or loses aliases.
instance CanBeAliased inner => CanBeAliased (MemOp inner) where
  type OpWithAliases (MemOp inner) = MemOp (OpWithAliases inner)

  removeOpAliases (Alloc se space) = Alloc se space
  removeOpAliases (Inner k) = Inner $ removeOpAliases k

  addOpAliases _ (Alloc se space) = Alloc se space
  addOpAliases aliases (Inner k) = Inner $ addOpAliases aliases k

-- | Renaming an allocation renames its size and keeps the space as is.
-- Inner operations are delegated.
instance Rename inner => Rename (MemOp inner) where
  rename (Alloc size space) = Alloc <$> rename size <*> pure space
  rename (Inner k) = Inner <$> rename k

-- | Substitution in an allocation only affects its size; the space is
-- left untouched.  Inner operations are delegated.
instance Substitute inner => Substitute (MemOp inner) where
  substituteNames subst (Alloc size space) =
    Alloc (substituteNames subst size) space
  substituteNames subst (Inner k) =
    Inner $ substituteNames subst k

-- | Allocations are printed as @alloc@ applied to the size.  The space
-- is printed as an extra argument only when it is not 'DefaultSpace'.
-- Inner operations use their own printer.
instance PP.Pretty inner => PP.Pretty (MemOp inner) where
  ppr (Alloc e DefaultSpace) = PP.text "alloc" <> PP.apply [PP.ppr e]
  ppr (Alloc e s) = PP.text "alloc" <> PP.apply [PP.ppr e, PP.ppr s]
  ppr (Inner k) = PP.ppr k

-- | Allocations are recorded in the metrics under the name @Alloc@.
-- Inner operations are delegated.
instance OpMetrics inner => OpMetrics (MemOp inner) where
  opMetrics Alloc {} = seen "Alloc"
  opMetrics (Inner k) = opMetrics k

-- | An allocation is considered safe only when its size is a
-- non-negative 64-bit integer constant.  Any other allocation is
-- conservatively unsafe.  Allocations are always considered cheap.
-- Inner operations are delegated.
instance IsOp inner => IsOp (MemOp inner) where
  safeOp (Alloc (Constant (IntValue (Int64Value k))) _) = k >= 0
  safeOp Alloc {} = False
  safeOp (Inner k) = safeOp k

  cheapOp (Inner k) = cheapOp k
  cheapOp Alloc {} = True

-- | Allocations carry no wisdom and are rebuilt unchanged.  Only the
-- inner operation has wisdom to remove.
instance CanBeWise inner => CanBeWise (MemOp inner) where
  type OpWithWisdom (MemOp inner) = MemOp (OpWithWisdom inner)

  removeOpWisdom (Alloc size space) = Alloc size space
  removeOpWisdom (Inner k) = Inner $ removeOpWisdom k

-- | Indexing into the result of an inner operation is delegated to the
-- inner type.  The result of an allocation cannot be indexed.
instance ST.IndexOp inner => ST.IndexOp (MemOp inner) where
  indexOp vtable k (Inner op) is = ST.indexOp vtable k op is
  indexOp _ _ _ _ = Nothing

-- | The index function representation used for memory annotations.
-- Its scalar expressions are 'Int64'-typed expressions over variable
-- names.
type IxFun = IxFun.IxFun (TPrimExp Int64 VName)

-- | An index function that may contain existential variables.  Every
-- variable is wrapped in 'Ext': either 'Free' (an ordinary name) or an
-- existential reference.  'isStaticIxFun' converts one to a plain
-- 'IxFun' when no existentials occur.
type ExtIxFun = IxFun.IxFun (TPrimExp Int64 (Ext VName))

-- | A summary of the memory information for every let-bound
-- identifier, function parameter, and return value.  Parameterised
-- over dimensions (@d@), uniqueness (@u@), and auxiliary array
-- information (@ret@).
data MemInfo d u ret
  = -- | A primitive value.
    MemPrim PrimType
  | -- | A memory block.
    MemMem Space
  | -- | The array is stored in the named memory block, and with the
    -- given index function.  The index function maps indices in the
    -- array to /element/ offsets, /not/ byte offsets!  To translate to
    -- byte offsets, multiply the offset with the size of the array
    -- element type.
    MemArray PrimType (ShapeBase d) u ret
  | -- | An accumulator, which is not stored anywhere.
    MemAcc VName Shape [Type] u
  deriving (Eq, Show, Ord) -- XXX: is Ord actually needed?

-- | Memory information with concrete ('SubExp') dimensions and
-- 'MemBind' as the auxiliary array information, parameterised only over
-- uniqueness.
type MemBound u = MemInfo SubExp u MemBind

-- | The (possibly existential) declared type of a memory summary.  For
-- arrays, the memory/index-function information is dropped.
instance FixExt ret => DeclExtTyped (MemInfo ExtSize Uniqueness ret) where
  declExtTypeOf (MemPrim pt) = Prim pt
  declExtTypeOf (MemMem space) = Mem space
  declExtTypeOf (MemArray pt shape u _) = Array pt shape u
  declExtTypeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

-- | The (possibly existential) type of a memory summary without
-- uniqueness.  For arrays, the auxiliary information is dropped.
instance FixExt ret => ExtTyped (MemInfo ExtSize NoUniqueness ret) where
  extTypeOf (MemPrim pt) = Prim pt
  extTypeOf (MemMem space) = Mem space
  extTypeOf (MemArray pt shape u _) = Array pt shape u
  extTypeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

-- | Fixing an existential only affects arrays, whose shape and
-- auxiliary information may mention existentials.  The other
-- constructors are rebuilt unchanged.
instance FixExt ret => FixExt (MemInfo ExtSize u ret) where
  fixExt _ _ (MemPrim pt) = MemPrim pt
  fixExt _ _ (MemMem space) = MemMem space
  fixExt i se (MemArray pt shape u ret) =
    MemArray pt (fixExt i se shape) u (fixExt i se ret)
  fixExt _ _ (MemAcc acc ispace ts u) = MemAcc acc ispace ts u

-- | The type of a uniqueness-annotated summary is its declared type
-- with uniqueness stripped.
instance Typed (MemInfo SubExp Uniqueness ret) where
  typeOf = fromDecl . declTypeOf

-- | The type of a summary without uniqueness.  For arrays, the
-- auxiliary information is dropped.
instance Typed (MemInfo SubExp NoUniqueness ret) where
  typeOf (MemPrim pt) = Prim pt
  typeOf (MemMem space) = Mem space
  typeOf (MemArray bt shape u _) = Array bt shape u
  typeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

-- | The declared (uniqueness-carrying) type of a summary.  For arrays,
-- the auxiliary information is dropped.
instance DeclTyped (MemInfo SubExp Uniqueness ret) where
  declTypeOf (MemPrim bt) = Prim bt
  declTypeOf (MemMem space) = Mem space
  declTypeOf (MemArray bt shape u _) = Array bt shape u
  declTypeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

-- | Free variables of a summary:
--
-- * arrays: their shape and auxiliary information;
-- * memory blocks: their space;
-- * accumulators: their name, index space and element types;
-- * primitives: none.
instance (FreeIn d, FreeIn ret) => FreeIn (MemInfo d u ret) where
  freeIn' (MemArray _ shape _ ret) = freeIn' shape <> freeIn' ret
  freeIn' (MemMem s) = freeIn' s
  freeIn' MemPrim {} = mempty
  freeIn' (MemAcc acc ispace ts _) = freeIn' (acc, ispace, ts)

-- | Substitute names in the shape and auxiliary information of arrays,
-- and in every component of accumulators.  Memory blocks and
-- primitives are rebuilt unchanged.
--
-- NOTE(review): 'FreeIn' counts names in a 'MemMem' space as free, but
-- they are not substituted here.  Confirm that a space can never mention
-- a renamed variable.
instance (Substitute d, Substitute ret) => Substitute (MemInfo d u ret) where
  substituteNames subst (MemArray bt shape u ret) =
    MemArray
      bt
      (substituteNames subst shape)
      u
      (substituteNames subst ret)
  substituteNames substs (MemAcc acc ispace ts u) =
    MemAcc
      (substituteNames substs acc)
      (substituteNames substs ispace)
      (substituteNames substs ts)
      u
  substituteNames _ (MemMem space) =
    MemMem space
  substituteNames _ (MemPrim bt) =
    MemPrim bt

-- | Renaming a summary is implemented through its 'Substitute' instance.
instance (Substitute d, Substitute ret) => Rename (MemInfo d u ret) where
  rename = substituteRename

-- | Simplify every scalar expression of an index function with the
-- simplifier.  Each expression is untyped, simplified with
-- 'simplifyPrimExp', and retagged as 'Int64'.
simplifyIxFun ::
  Engine.SimplifiableRep rep =>
  IxFun ->
  Engine.SimpleM rep IxFun
simplifyIxFun = traverse $ fmap isInt64 . simplifyPrimExp . untyped

-- | Like 'simplifyIxFun', but for index functions that may contain
-- existential variables.  It uses 'simplifyExtPrimExp'.
simplifyExtIxFun ::
  Engine.SimplifiableRep rep =>
  ExtIxFun ->
  Engine.SimpleM rep ExtIxFun
simplifyExtIxFun = traverse $ fmap isInt64 . simplifyExtPrimExp . untyped

-- | Convert an existential index function into an ordinary one.
--
-- This succeeds only if no leaf of any of its expressions is an
-- existential ('Ext') variable.  Returns 'Nothing' as soon as an 'Ext'
-- leaf is found; otherwise every 'Free' leaf is unwrapped.
isStaticIxFun :: ExtIxFun -> Maybe IxFun
isStaticIxFun = traverse $ traverse inst
  where
    inst Ext {} = Nothing
    inst (Free x) = Just x

-- | Simplification of memory information.
--
-- Primitive and memory-block annotations are returned unchanged.  For
-- arrays, the shape and the return annotation are simplified.  For
-- accumulators, the name, index space and element types are
-- simplified.  Uniqueness is always passed through untouched.
instance
  (Engine.Simplifiable d, Engine.Simplifiable ret) =>
  Engine.Simplifiable (MemInfo d u ret)
  where
  simplify (MemPrim bt) =
    pure $ MemPrim bt
  simplify (MemMem space) =
    pure $ MemMem space
  simplify (MemArray bt shape u ret) =
    MemArray bt
      <$> Engine.simplify shape
      <*> pure u
      <*> Engine.simplify ret
  simplify (MemAcc acc ispace ts u) =
    MemAcc
      <$> Engine.simplify acc
      <*> Engine.simplify ispace
      <*> Engine.simplify ts
      <*> pure u

-- | Pretty-printing of memory information.
--
-- Memory blocks in the default space print as plain @mem@; other
-- spaces are appended directly after @mem@.  Arrays print as their
-- type followed by @\@@ and the return annotation.
instance
  ( PP.Pretty (ShapeBase d),
    PP.Pretty (TypeBase (ShapeBase d) u),
    PP.Pretty d,
    PP.Pretty u,
    PP.Pretty ret
  ) =>
  PP.Pretty (MemInfo d u ret)
  where
  ppr (MemPrim bt) = PP.ppr bt
  ppr (MemMem DefaultSpace) = PP.text "mem"
  ppr (MemMem s) = PP.text "mem" <> PP.ppr s
  ppr (MemArray bt shape u ret) =
    PP.ppr (Array bt shape u) <+> PP.text "@" <+> PP.ppr ret
  ppr (MemAcc acc ispace ts u) =
    PP.ppr u <> PP.ppr (Acc acc ispace ts NoUniqueness :: Type)

-- | Memory information for an array bound somewhere in the program.
data MemBind
  = -- | Located in this memory block with this index
    -- function.
    ArrayIn VName IxFun
  deriving (Show)

-- | Every 'MemBind' is considered equal to every other 'MemBind'.
--
-- NOTE(review): presumably this makes memory annotations irrelevant
-- when comparing the types that carry them; confirm before relying on
-- it.
instance Eq MemBind where
  _ == _ = True

-- | Every 'MemBind' compares as 'EQ' to every other 'MemBind'.  This
-- is consistent with an 'Eq' instance that treats all values as equal.
instance Ord MemBind where
  _ `compare` _ = EQ

-- | Renaming is implemented in terms of the 'Substitute' instance.
instance Rename MemBind where
  rename = substituteRename

-- | Substitution applies to both the memory block name and the index
-- function.
instance Substitute MemBind where
  substituteNames substs (ArrayIn ident ixfun) =
    ArrayIn (substituteNames substs ident) (substituteNames substs ixfun)

-- | Printed as @mem -> ixfun@.
instance PP.Pretty MemBind where
  ppr (ArrayIn mem ixfun) = PP.ppr mem <+> "->" PP.</> PP.ppr ixfun

-- | Free variables are those of the memory block name plus those of the
-- index function.
instance FreeIn MemBind where
  freeIn' (ArrayIn mem ixfun) = freeIn' mem <> freeIn' ixfun

-- | A description of the memory properties of an array being returned
-- by an operation.
data MemReturn
  = -- | The array is located in a memory block that is
    -- already in scope.
    ReturnsInBlock VName ExtIxFun
  | -- | The operation returns a new (existential) memory
    -- block.
    ReturnsNewBlock Space Int ExtIxFun
  deriving (Show)

-- | Every 'MemReturn' is considered equal to every other 'MemReturn'.
--
-- NOTE(review): presumably this makes memory annotations irrelevant
-- when comparing the return types that carry them; confirm before
-- relying on it.
instance Eq MemReturn where
  _ == _ = True

-- | Every 'MemReturn' compares as 'EQ' to every other 'MemReturn'.  This
-- is consistent with an 'Eq' instance that treats all values as equal.
instance Ord MemReturn where
  _ `compare` _ = EQ

-- | Renaming is implemented in terms of the 'Substitute' instance.
instance Rename MemReturn where
  rename = substituteRename

-- | Substitution applies to the block name, if any, and to the index
-- function.  The space and existential index of 'ReturnsNewBlock' are
-- left unchanged.
instance Substitute MemReturn where
  substituteNames substs (ReturnsInBlock ident ixfun) =
    ReturnsInBlock (substituteNames substs ident) (substituteNames substs ixfun)
  substituteNames substs (ReturnsNewBlock space i ixfun) =
    ReturnsNewBlock space i (substituteNames substs ixfun)

-- | Fix the @i@'th existential to the given 'SubExp'.
--
-- In every case, @Ext i@ in the index function is replaced by the
-- value via 'fixExtIxFun', which also shifts the higher existentials
-- down.  Beyond that:
--
-- * A new block whose own existential index is @i@, fixed to a
--   variable, becomes a 'ReturnsInBlock' for that variable.
--
-- * Any other new block has its index decremented when it lies above
--   @i@, because the @i@'th existential disappears.
--
-- * 'ReturnsInBlock' keeps its memory block name.
instance FixExt MemReturn where
  fixExt i (Var v) (ReturnsNewBlock _ j ixfun)
    | j == i =
      ReturnsInBlock v $
        fixExtIxFun
          i
          (primExpFromSubExp int64 (Var v))
          ixfun
  fixExt i se (ReturnsNewBlock space j ixfun) =
    ReturnsNewBlock
      space
      j'
      (fixExtIxFun i (primExpFromSubExp int64 se) ixfun)
    where
      j'
        | i < j = j - 1
        | otherwise = j
  fixExt i se (ReturnsInBlock mem ixfun) =
    ReturnsInBlock mem (fixExtIxFun i (primExpFromSubExp int64 se) ixfun)

-- | Replace the existential @Ext i@ in every expression of an index
-- function with the given (non-existential) expression.
--
-- Existentials with an index greater than @i@ are renumbered down by
-- one, since the @i@'th existential is eliminated.  Lower existentials
-- and 'Free' variables are left as they are.
fixExtIxFun :: Int -> PrimExp VName -> ExtIxFun -> ExtIxFun
fixExtIxFun i e = fmap $ isInt64 . replaceInPrimExp update . untyped
  where
    update (Ext j) t
      | j > i = LeafExp (Ext $ j - 1) t
      | j == i = fmap Free e
      | otherwise = LeafExp (Ext j) t
    update (Free x) t = LeafExp (Free x) t

-- | An 'Int64'-typed leaf expression referring to the existential with
-- the given index.
leafExp :: Int -> TPrimExp Int64 (Ext a)
leafExp i = isInt64 $ LeafExp (Ext i) int64

-- | Turn an index function into an existential one.
--
-- Every name is first wrapped in 'Free'.  Each name that occurs in the
-- context list is then replaced by @Ext k@, where @k@ is its position
-- in that list.
--
-- NOTE(review): with duplicate names in the context, 'M.fromList' keeps
-- the last position; callers appear to pass distinct parameter names.
existentialiseIxFun :: [VName] -> IxFun -> ExtIxFun
existentialiseIxFun ctx = IxFun.substituteInIxFun ctx' . fmap (fmap Free)
  where
    ctx' = M.map leafExp $ M.fromList $ zip (map Free ctx) [0 ..]

-- | 'ReturnsInBlock' prints as @(mem -> ixfun)@.  'ReturnsNewBlock'
-- prints as @?i@ followed by the space, then @-> ixfun@.
instance PP.Pretty MemReturn where
  ppr (ReturnsInBlock v ixfun) =
    PP.parens $ ppr v <+> "->" PP.</> PP.ppr ixfun
  ppr (ReturnsNewBlock space i ixfun) =
    "?" <> ppr i <> PP.ppr space <+> "->" PP.</> PP.ppr ixfun

-- | Free variables of a memory return are those of the block name (or
-- space) together with those of the index function.
instance FreeIn MemReturn where
  freeIn' (ReturnsInBlock v ixfun) = freeIn' v <> freeIn' ixfun
  freeIn' (ReturnsNewBlock space _ ixfun) = freeIn' space <> freeIn' ixfun

-- | Simplify the index function.  For 'ReturnsInBlock', the memory
-- block name is simplified as well.
instance Engine.Simplifiable MemReturn where
  simplify (ReturnsNewBlock space i ixfun) =
    ReturnsNewBlock space i <$> simplifyExtIxFun ixfun
  simplify (ReturnsInBlock v ixfun) =
    ReturnsInBlock <$> Engine.simplify v <*> simplifyExtIxFun ixfun

-- | Simplify both the memory block name and the index function.
instance Engine.Simplifiable MemBind where
  simplify (ArrayIn mem ixfun) =
    ArrayIn <$> Engine.simplify mem <*> simplifyIxFun ixfun

-- | Simplify each function return annotation in turn.
instance Engine.Simplifiable [FunReturns] where
  simplify = mapM Engine.simplify

-- | The memory return of an expression.  An array is annotated with
-- @Maybe MemReturn@, which can be interpreted as the expression
-- either dictating exactly where the array is located when it is
-- returned (if 'Just'), or being able to put it wherever the binding
-- prefers (if 'Nothing').
--
-- This is necessary to capture the difference between an expression
-- that is just an array-typed variable, in which the array being
-- "returned" is located where it already is, and a @copy@ expression,
-- whose entire purpose is to store an existing array in some
-- arbitrary location.  This is a consequence of the design decision
-- never to have implicit memory copies.
type ExpReturns = MemInfo ExtSize NoUniqueness (Maybe MemReturn)

-- | The return of a body, which must always indicate where
-- returned arrays are located.  Uniqueness is not tracked here
-- ('NoUniqueness').
type BodyReturns = MemInfo ExtSize NoUniqueness MemReturn

-- | The memory return of a function, which must always indicate where
-- returned arrays are located.  Unlike 'BodyReturns', this also
-- carries 'Uniqueness' information.
type FunReturns = MemInfo ExtSize Uniqueness MemReturn

-- | Wrap the return annotation of an array in 'Just'.
--
-- The other constructors carry no return annotation.  They are rebuilt
-- unchanged, only so that the result gets the new type.
maybeReturns :: MemInfo d u r -> MemInfo d u (Maybe r)
maybeReturns (MemArray bt shape u ret) =
  MemArray bt shape u $ Just ret
maybeReturns (MemPrim bt) =
  MemPrim bt
maybeReturns (MemMem space) =
  MemMem space
maybeReturns (MemAcc acc ispace ts u) =
  MemAcc acc ispace ts u

-- | Discard uniqueness information by replacing it with
-- 'NoUniqueness'.  All other fields are kept.
noUniquenessReturns :: MemInfo d u r -> MemInfo d NoUniqueness r
noUniquenessReturns (MemArray bt shape _ r) =
  MemArray bt shape NoUniqueness r
noUniquenessReturns (MemPrim bt) =
  MemPrim bt
noUniquenessReturns (MemMem space) =
  MemMem space
noUniquenessReturns (MemAcc acc ispace ts _) =
  MemAcc acc ispace ts NoUniqueness

-- | View a function return as an expression return.  Array return
-- annotations become 'Just' and uniqueness is dropped.
funReturnsToExpReturns :: FunReturns -> ExpReturns
funReturnsToExpReturns = noUniquenessReturns . maybeReturns

-- | View a body return as an expression return by wrapping array
-- return annotations in 'Just'.
bodyReturnsToExpReturns :: BodyReturns -> ExpReturns
bodyReturnsToExpReturns = noUniquenessReturns . maybeReturns

-- | Check that a result matches the given return types.
--
-- The memory information of each result 'SubExp' is looked up in the
-- current scope, with aliasing information removed.  The return types,
-- the result and those memory annotations are then passed to
-- 'matchReturnType'.
matchRetTypeToResult ::
  (Mem rep, TC.Checkable rep) =>
  [FunReturns] ->
  Result ->
  TC.TypeM rep ()
matchRetTypeToResult rettype result = do
  scope <- askScope
  result_ts <-
    runReaderT (mapM subExpMemInfo result) $
      removeScopeAliases scope
  matchReturnType rettype result result_ts

-- | Check the result of a function body against its declared return
-- types.
--
-- This performs the check of 'matchRetTypeToResult'.  In addition,
-- every array variable returned must have a linear index function
-- ('IxFun.isLinear'); otherwise a type error is reported.  Constants
-- and non-array variables are always accepted.
matchFunctionReturnType ::
  (Mem rep, TC.Checkable rep) =>
  [FunReturns] ->
  Result ->
  TC.TypeM rep ()
matchFunctionReturnType rettype result = do
  matchRetTypeToResult rettype result
  mapM_ checkResultSubExp result
  where
    -- No type signature: its type is inferred with its own constraints,
    -- independently of the outer 'rep'.
    checkResultSubExp Constant {} =
      return ()
    checkResultSubExp (Var v) = do
      dec <- varMemInfo v
      case dec of
        MemPrim _ -> return ()
        MemMem {} -> return ()
        MemAcc {} -> return ()
        MemArray _ _ _ (ArrayIn _ ixfun)
          | IxFun.isLinear ixfun ->
            return ()
          | otherwise ->
            TC.bad $
              TC.TypeError $
                "Array "
                  ++ pretty v
                  ++ " returned by function, but has nontrivial index function "
                  ++ pretty ixfun

-- | Check the result of a loop body against the loop's merge
-- parameters.
--
-- The value parameters are turned into function-style return types.
-- Any reference to a context parameter becomes an existential 'Ext'
-- at that parameter's position.  An array whose memory block is a
-- context parameter of memory type is described with
-- 'ReturnsNewBlock'; any other array uses 'ReturnsInBlock'.  The
-- resulting return types are checked with 'matchRetTypeToResult'.
matchLoopResultMem ::
  (Mem rep, TC.Checkable rep) =>
  [FParam (Aliases rep)] ->
  [FParam (Aliases rep)] ->
  [SubExp] ->
  TC.TypeM rep ()
matchLoopResultMem ctx val = matchRetTypeToResult rettype
  where
    ctx_names = map paramName ctx

    -- Invent a ReturnType so we can pretend that the loop body is
    -- actually returning from a function.
    rettype = map (toRet . paramDec) val

    -- A name is existential if it is one of the context parameters.
    toExtV v
      | Just i <- v `elemIndex` ctx_names = Ext i
      | otherwise = Free v

    toExtSE (Var v) = Var <$> toExtV v
    toExtSE (Constant v) = Free $ Constant v

    toRet (MemPrim t) =
      MemPrim t
    toRet (MemMem space) =
      MemMem space
    toRet (MemAcc acc ispace ts u) =
      MemAcc acc ispace ts u
    toRet (MemArray pt shape u (ArrayIn mem ixfun))
      | Just i <- mem `elemIndex` ctx_names,
        Param _ (MemMem space) : _ <- drop i ctx =
        MemArray pt shape' u $ ReturnsNewBlock space i ixfun'
      | otherwise =
        MemArray pt shape' u $ ReturnsInBlock mem ixfun'
      where
        shape' = fmap toExtSE shape
        ixfun' = existentialiseIxFun ctx_names ixfun

-- | Check the result of a branch body against the branch return types.
--
-- Memory information for each result is looked up in the current scope
-- extended with the body's own statements, with aliasing removed.  The
-- return types, the result and those memory annotations are then
-- passed to 'matchReturnType'.
matchBranchReturnType ::
  (Mem rep, TC.Checkable rep) =>
  [BodyReturns] ->
  Body (Aliases rep) ->
  TC.TypeM rep ()
matchBranchReturnType rettype (Body _ stms res) = do
  scope <- askScope
  ts <-
    runReaderT (mapM subExpMemInfo res) $
      removeScopeAliases (scope <> scopeOf stms)
  matchReturnType rettype res ts

-- | Helper function for index function unification.
--
-- Given context names paired with their positions, build two
-- substitution maps:
--
--  1. The first maps each name (wrapped in 'Free') to a 'LeafExp' of
--     its position (wrapped in 'Ext').  If a name occurs more than
--     once, the /first/ position at which it occurs is used.
--
--  2. The second maps each position (wrapped in 'Ext') to a 'LeafExp'
--     'Ext' of the position at which its associated name first occurs.
getExtMaps ::
  [(VName, Int)] ->
  ( M.Map (Ext VName) (TPrimExp Int64 (Ext VName)),
    M.Map (Ext VName) (TPrimExp Int64 (Ext VName))
  )
getExtMaps ctx_lst_ids = (by_name, by_ext)
  where
    -- Position of the first occurrence of each name.  'M.fromListWith'
    -- passes the newly inserted value first, so keeping the second
    -- argument keeps the earliest position.
    first_occ = M.fromListWith (\_new old -> old) ctx_lst_ids

    by_name = M.map leafExp $ M.mapKeys Free first_occ

    -- Looking up in 'first_occ' (rather than scanning the association
    -- list for every element) yields the same first-occurrence position
    -- in logarithmic instead of linear time.
    by_ext =
      M.fromList
        [ (Ext i, isInt64 $ LeafExp (Ext j) int64)
          | (v, i) <- ctx_lst_ids,
            Just j <- [M.lookup v first_occ]
        ]

-- | Check that body results can be returned with the given (possibly
-- existential) return type.
--
-- @res@ are the body results and @ts@ their memory-annotated types.
-- The trailing @length rettype@ elements of each (as split off by
-- 'splitFromEnd') are the values, which are matched element-wise
-- against @rettype@; everything before them is the context.  An
-- existential @'Ext' i@ in @rettype@ refers to the @i@th context
-- element.  Fails with a 'TC.TypeError' if the number of distinct
-- existentials in @rettype@ differs from the number of context
-- results, or if any value does not match its return type.
matchReturnType ::
  PP.Pretty u =>
  [MemInfo ExtSize u MemReturn] ->
  [SubExp] ->
  [MemInfo SubExp NoUniqueness MemBind] ->
  TC.TypeM rep ()
matchReturnType rettype res ts = do
  let (ctx_ts, val_ts) = splitFromEnd (length rettype) ts
      (ctx_res, _val_res) = splitFromEnd (length rettype) res

      -- View a concrete index function as one without existentials.
      existentialiseIxFun0 :: IxFun -> ExtIxFun
      existentialiseIxFun0 = fmap $ fmap Free

      -- The i'th context result together with its type.
      fetchCtx i = case maybeNth i $ zip ctx_res ctx_ts of
        Nothing ->
          throwError $
            "Cannot find context variable "
              ++ show i
              ++ " in context results: "
              ++ pretty ctx_res
        Just (se, t) -> return (se, t)

      -- Match one declared return type against one actual value type.
      checkReturn (MemPrim x) (MemPrim y)
        | x == y = return ()
      checkReturn (MemMem x) (MemMem y)
        | x == y = return ()
      checkReturn (MemAcc xacc xispace xts _) (MemAcc yacc yispace yts _)
        | (xacc, xispace, xts) == (yacc, yispace, yts) =
          return ()
      checkReturn
        (MemArray x_pt x_shape _ x_ret)
        (MemArray y_pt y_shape _ y_ret)
          | x_pt == y_pt,
            shapeRank x_shape == shapeRank y_shape = do
              zipWithM_ checkDim (shapeDims x_shape) (shapeDims y_shape)
              checkMemReturn x_ret y_ret
      checkReturn x y =
        throwError $ unwords ["Expected", pretty x, "but got", pretty y]

      -- A free dimension must match exactly; an existential one must
      -- match the context result it refers to.
      checkDim (Free x) y
        | x == y = return ()
        | otherwise =
          throwError $
            unwords
              [ "Expected dim",
                pretty x,
                "but got",
                pretty y
              ]
      checkDim (Ext i) y = do
        (x, _) <- fetchCtx i
        unless (x == y) $
          throwError $
            unwords
              [ "Expected ext dim",
                pretty i,
                "=>",
                pretty x,
                "but got",
                pretty y
              ]

      -- Existential indices mentioned by one return type.
      extsInMemInfo :: MemInfo ExtSize u MemReturn -> S.Set Int
      extsInMemInfo (MemArray _ shp _ ret) =
        extInShape shp <> extInMemReturn ret
      extsInMemInfo _ = S.empty

      -- Match the memory placement of a declared array return against
      -- the placement of the actual array.
      checkMemReturn (ReturnsInBlock x_mem x_ixfun) (ArrayIn y_mem y_ixfun)
        | x_mem == y_mem =
          unless (IxFun.closeEnough x_ixfun $ existentialiseIxFun0 y_ixfun) $
            throwError $
              unwords
                [ "Index function unification failed (ReturnsInBlock)",
                  "\nixfun of body result: ",
                  pretty y_ixfun,
                  "\nixfun of return type: ",
                  pretty x_ixfun,
                  "\nand context elements: ",
                  pretty ctx_res
                ]
      checkMemReturn
        (ReturnsNewBlock x_space x_ext x_ixfun)
        (ArrayIn y_mem y_ixfun) = do
          (x_mem, x_mem_type) <- fetchCtx x_ext
          unless (IxFun.closeEnough x_ixfun $ existentialiseIxFun0 y_ixfun) $
            throwError $
              pretty $
                "Index function unification failed (ReturnsNewBlock)"
                  </> "Ixfun of body result:"
                  </> indent 2 (ppr y_ixfun)
                  </> "Ixfun of return type:"
                  </> indent 2 (ppr x_ixfun)
                  </> "Context elements: "
                  </> indent 2 (ppr ctx_res)
          -- The context result referenced by the existential must be a
          -- memory block living in the declared space.
          case x_mem_type of
            MemMem y_space ->
              unless (x_space == y_space) $
                throwError $
                  unwords
                    [ "Expected memory",
                      pretty y_mem,
                      "in space",
                      pretty x_space,
                      "but actually in space",
                      pretty y_space
                    ]
            t ->
              throwError $
                unwords
                  [ "Expected memory",
                    pretty x_ext,
                    "=>",
                    pretty x_mem,
                    "but has type",
                    pretty t
                  ]
      checkMemReturn x y =
        throwError $
          unwords
            [ "Expected array in",
              pretty x,
              "but array returned in",
              pretty y
            ]

      bad :: String -> TC.TypeM rep a
      bad s =
        TC.bad $
          TC.TypeError $
            pretty $
              "Return type"
                </> indent 2 (ppTuple' rettype)
                </> "cannot match returns of results"
                </> indent 2 (ppTuple' ts)
                </> text s

  -- The number of distinct existentials in the return type must equal
  -- the number of context results (the check fails in both the
  -- too-many and too-few directions, so the message says "mismatch").
  unless (S.size (S.unions $ map extsInMemInfo rettype) == length ctx_res) $
    TC.bad $
      TC.TypeError $
        "Mismatch between the number of context parameters and the number of "
          ++ "existentials in the return type! type:\n  "
          ++ prettyTuple rettype
          ++ "\ncannot match context parameters:\n  "
          ++ prettyTuple ctx_res

  either bad return =<< runExceptT (zipWithM_ checkReturn rettype val_ts)

-- | Check that the pattern @pat@ can be bound to the result of the
-- expression @e@.
--
-- The return types of @e@ (as computed by 'expReturns') must agree
-- element-wise with the value part of the pattern.  Index functions
-- are compared after substituting through the maps built by
-- 'getExtMaps' from the pattern's context names.  In addition, every
-- context position of the pattern must occur as an existential in the
-- expression's return types.  Raises a 'TC.TypeError' otherwise.
matchPatternToExp ::
  (Mem rep, TC.Checkable rep) =>
  Pattern (Aliases rep) ->
  Exp (Aliases rep) ->
  TC.TypeM rep ()
matchPatternToExp pat e = do
  scope <- asksScope removeScopeAliases
  rt <- runReaderT (expReturns $ removeExpAliases e) scope

  let (ctxs, vals) = bodyReturnsFromPattern $ removePatternAliases pat
      ctx_ids = map fst ctxs
      val_ts = map snd vals
      -- 'zip' truncates, so @[0 ..]@ numbers exactly the context names.
      (ctx_map_ids, ctx_map_exts) = getExtMaps $ zip ctx_ids [0 ..]
      rt_exts = foldMap extInExpReturns rt
      types_ok =
        length val_ts == length rt
          && and (zipWith (matches ctx_map_ids ctx_map_exts) val_ts rt)
      exts_ok = M.keysSet ctx_map_exts `S.isSubsetOf` S.map Ext rt_exts

  unless (types_ok && exts_ok) $
    TC.bad $
      TC.TypeError $
        "Expression type:\n  "
          ++ prettyTuple rt
          ++ "\ncannot match pattern type:\n  "
          ++ prettyTuple val_ts
          ++ "\nwith context elements: "
          ++ pretty ctx_ids
  where
    -- Does a pattern element's type match the corresponding expression
    -- return type?
    matches _ _ (MemPrim x) (MemPrim y) = x == y
    matches _ _ (MemMem x_space) (MemMem y_space) =
      x_space == y_space
    matches
      _
      _
      (MemAcc x_accs x_ispace x_ts _)
      (MemAcc y_accs y_ispace y_ts _) =
        (x_accs, x_ispace, x_ts) == (y_accs, y_ispace, y_ts)
    matches
      ctxids
      ctxexts
      (MemArray x_pt x_shape _ x_ret)
      (MemArray y_pt y_shape _ y_ret) =
        x_pt == y_pt
          && x_shape == y_shape
          && case (x_ret, y_ret) of
            -- The expression makes no claim about its memory.
            (_, Nothing) -> True
            (ReturnsInBlock _ x_ixfun, Just (ReturnsInBlock _ y_ixfun)) ->
              sameIxFun x_ixfun y_ixfun
            (ReturnsInBlock _ x_ixfun, Just (ReturnsNewBlock _ _ y_ixfun)) ->
              sameIxFun x_ixfun y_ixfun
            ( ReturnsNewBlock _ x_i x_ixfun,
              Just (ReturnsNewBlock _ y_i y_ixfun)
              ) ->
                x_i == y_i && sameIxFun x_ixfun y_ixfun
            _ -> False
        where
          -- The pattern-side index function is resolved through the
          -- name map and the expression-side one through the
          -- existential map before comparing them.
          sameIxFun x_ixfun y_ixfun =
            IxFun.closeEnough
              (IxFun.substituteInIxFun ctxids x_ixfun)
              (IxFun.substituteInIxFun ctxexts y_ixfun)
    matches _ _ _ _ = False

    -- Existential indices mentioned by one expression return type.
    extInExpReturns :: ExpReturns -> S.Set Int
    extInExpReturns (MemArray _ shape _ mem_return) =
      extInShape shape <> maybe S.empty extInMemReturn mem_return
    extInExpReturns _ = mempty

-- | The existential indices of all 'Ext' dimensions in a shape.
extInShape :: ShapeBase (Ext SubExp) -> S.Set Int
extInShape shape = S.fromList [i | Ext i <- shapeDims shape]

-- | The set of existential positions referenced by a 'MemReturn': those
-- occurring in its index function, plus the block's own position when
-- it is a 'ReturnsNewBlock'.
extInMemReturn :: MemReturn -> S.Set Int
extInMemReturn ret =
  case ret of
    ReturnsInBlock _ extixfn -> extInIxFn extixfn
    ReturnsNewBlock _ i extixfn -> S.insert i $ extInIxFn extixfn

-- | The set of existential positions ('Ext' indices) referenced by any
-- expression inside an existential index function.
extInIxFn :: ExtIxFun -> S.Set Int
extInIxFn = foldMap (S.fromList . mapMaybe isExt . toList)

-- | Look up the memory information of a variable during type checking.
-- Uniqueness is dropped from function parameters, and index variables
-- are reported as primitives of their integer type.
varMemInfo ::
  Mem rep =>
  VName ->
  TC.TypeM rep (MemInfo SubExp NoUniqueness MemBind)
varMemInfo name = do
  info <- TC.lookupVar name
  pure $ case info of
    -- Let-bound names pair aliasing information with their memory
    -- annotation; only the annotation is wanted here.
    LetName (_, summary) -> summary
    FParamName summary -> noUniquenessReturns summary
    LParamName summary -> summary
    IndexName it -> MemPrim $ IntType it

-- | Extract the memory information from a scope entry.  Uniqueness is
-- discarded from function parameters, and index variables are treated
-- as primitives of their integer type.
nameInfoToMemInfo :: Mem rep => NameInfo rep -> MemBound NoUniqueness
nameInfoToMemInfo (FParamName summary) = noUniquenessReturns summary
nameInfoToMemInfo (LParamName summary) = summary
nameInfoToMemInfo (LetName summary) = summary
nameInfoToMemInfo (IndexName it) = MemPrim $ IntType it

-- | Look up the memory information of a variable in the current scope.
lookupMemInfo ::
  (HasScope rep m, Mem rep) =>
  VName ->
  m (MemInfo SubExp NoUniqueness MemBind)
lookupMemInfo name = nameInfoToMemInfo <$> lookupInfo name

-- | Memory information of a 'SubExp'.  Variables are looked up in scope;
-- constants are primitives of their value's type.
subExpMemInfo ::
  (HasScope rep m, Monad m, Mem rep) =>
  SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
subExpMemInfo se =
  case se of
    Var v -> lookupMemInfo v
    Constant v -> pure . MemPrim $ primValueType v

-- | The memory block and index function of an array variable.  Calls
-- 'error' if the variable's memory information is not that of an array.
lookupArraySummary ::
  (Mem rep, HasScope rep m, Monad m) =>
  VName ->
  m (VName, IxFun.IxFun (TPrimExp Int64 VName))
lookupArraySummary name =
  lookupMemInfo name >>= \summary ->
    case summary of
      MemArray _ _ _ (ArrayIn mem ixfun) -> pure (mem, ixfun)
      _ -> error $ "Variable " ++ pretty name ++ " does not look like an array."

-- | Check that a memory annotation is well-formed.
--
-- * Scalar-space memory: every size must be of type @i64@.
-- * Accumulators: the corresponding 'Acc' type is type-checked.
-- * Arrays: the block @v@ must be a variable of memory type, every
--   expression in the index function must be an @i64@, and the rank of
--   the index function must equal the rank of the array's shape.
--
-- Primitive annotations and other memory blocks are accepted as-is.
checkMemInfo ::
  TC.Checkable rep =>
  VName ->
  MemInfo SubExp u MemBind ->
  TC.TypeM rep ()
checkMemInfo _ (MemPrim _) = pure ()
checkMemInfo _ (MemMem (ScalarSpace d _)) =
  traverse_ (TC.require [Prim int64]) d
checkMemInfo _ (MemMem _) = pure ()
checkMemInfo _ (MemAcc acc ispace ts u) =
  TC.checkType (Acc acc ispace ts u)
checkMemInfo name (MemArray _ shape _ (ArrayIn v ixfun)) = do
  mem_t <- lookupType v
  case mem_t of
    Mem {} -> pure ()
    _ ->
      TC.bad . TC.TypeError $
        concat
          [ "Variable ",
            pretty v,
            " used as memory block, but is of type ",
            pretty mem_t,
            "."
          ]

  TC.context ("in index function " ++ pretty ixfun) $ do
    traverse_ (TC.requirePrimExp int64 . untyped) ixfun
    let ixfun_rank = IxFun.rank ixfun
        ident_rank = shapeRank shape
    unless (ixfun_rank == ident_rank) $
      TC.bad . TC.TypeError $
        concat
          [ "Arity of index function (",
            pretty ixfun_rank,
            ") does not match rank of array ",
            pretty name,
            " (",
            show ident_rank,
            ")"
          ]

-- | Convert a pattern into body return annotations, split into the
-- context elements and the value elements of the pattern.
--
-- Array dimensions that are variables bound by a context element become
-- existential ('Ext') references to that element's position; all other
-- dimensions stay 'Free'.  An array whose memory block is a context
-- element bound to memory becomes a 'ReturnsNewBlock' (with its index
-- function existentialised over the context names); otherwise it is a
-- 'ReturnsInBlock' of its existing block.
bodyReturnsFromPattern ::
  PatternT (MemBound NoUniqueness) ->
  ([(VName, BodyReturns)], [(VName, BodyReturns)])
bodyReturnsFromPattern pat =
  ( map asReturns ctx,
    map asReturns $ patternValueElements pat
  )
  where
    ctx = patternContextElements pat

    -- Position (and element) of a name among the context elements, if
    -- it is bound there.  Shared by dimension and memory-block lookup.
    ctxIndex v = find ((== v) . patElemName . snd) $ zip [0 ..] ctx

    ext (Var v)
      | Just (i, _) <- ctxIndex v = Ext i
    ext se = Free se

    asReturns pe =
      ( patElemName pe,
        case patElemDec pe of
          MemPrim pt -> MemPrim pt
          MemMem space -> MemMem space
          MemArray pt shape u (ArrayIn mem ixfun) ->
            MemArray pt (Shape $ map ext $ shapeDims shape) u $
              case ctxIndex mem of
                Just (i, PatElem _ (MemMem space)) ->
                  ReturnsNewBlock space i $
                    existentialiseIxFun (map patElemName ctx) ixfun
                _ -> ReturnsInBlock mem $ existentialiseIxFun [] ixfun
          MemAcc acc ispace ts u -> MemAcc acc ispace ts u
      )

-- | Assign default memory returns to a list of extended types.
--
-- Existential arrays are numbered in order of appearance, starting at
-- 0.  Each is given a 'ReturnsNewBlock' in 'DefaultSpace' whose index
-- function is 'IxFun.iota' over the array's (possibly existential)
-- shape.  Non-existential arrays carry no memory return ('Nothing').
-- All other types are passed through unchanged.
extReturns :: [ExtType] -> [ExpReturns]
extReturns ets = evalState (traverse addDec ets) 0
  where
    addDec t =
      case t of
        Prim bt -> pure $ MemPrim bt
        Mem space -> pure $ MemMem space
        Array bt shape u
          | existential t -> do
              -- Claim the next new-block index.
              i <- get
              modify (+ 1)
              pure . MemArray bt shape u . Just $
                ReturnsNewBlock DefaultSpace i $
                  IxFun.iota (map convert (shapeDims shape))
          | otherwise -> pure $ MemArray bt shape u Nothing
        Acc acc ispace ts u -> pure $ MemAcc acc ispace ts u

    -- Translate a shape dimension into an index-function expression,
    -- keeping existential dimensions existential.
    convert d =
      case d of
        Ext i -> le64 (Ext i)
        Free v -> Free <$> pe64 v

-- | Look up an array variable, returning its element type, shape, memory
-- block, and index function.  Calls 'error' if the variable is not bound
-- to an array.
arrayVarReturns ::
  (HasScope rep m, Monad m, Mem rep) =>
  VName ->
  m (PrimType, Shape, VName, IxFun)
arrayVarReturns v = do
  summary <- lookupMemInfo v
  case summary of
    MemArray et shape _ (ArrayIn mem ixfun) ->
      -- The looked-up shape is already a 'Shape'; there is no need to
      -- take it apart and rebuild it.
      return (et, shape, mem, ixfun)
    _ ->
      error $ "arrayVarReturns: " ++ pretty v ++ " is not an array."

-- | The 'ExpReturns' of a variable.  An array is reported as residing in
-- its existing memory block ('ReturnsInBlock') with a fully known ('Free')
-- shape and a non-existential index function.  Other memory information
-- is passed through unchanged.
varReturns ::
  (HasScope rep m, Monad m, Mem rep) =>
  VName ->
  m ExpReturns
varReturns v = do
  info <- lookupMemInfo v
  pure $ case info of
    MemPrim bt -> MemPrim bt
    MemMem space -> MemMem space
    MemAcc acc ispace ts u -> MemAcc acc ispace ts u
    MemArray et shape _ (ArrayIn mem ixfun) ->
      MemArray et (Free <$> shape) NoUniqueness . Just . ReturnsInBlock mem $
        existentialiseIxFun [] ixfun

-- | The 'ExpReturns' of a 'SubExp'.  Variables are resolved through
-- 'varReturns'; constants are primitives of their value's type.
subExpReturns :: (HasScope rep m, Monad m, Mem rep) => SubExp -> m ExpReturns
subExpReturns se =
  case se of
    Var v -> varReturns v
    Constant v -> pure . MemPrim $ primValueType v

-- | The return information of an expression.  This can be seen as the
-- "return type with memory annotations" of the expression.
expReturns ::
  ( Monad m,
    LocalScope rep m,
    Mem rep
  ) =>
  Exp rep ->
  m [ExpReturns]
expReturns :: forall (m :: * -> *) rep.
(Monad m, LocalScope rep m, Mem rep) =>
Exp rep -> m [ExpReturns]
expReturns (BasicOp (SubExp SubExp
se)) =
  ExpReturns -> [ExpReturns]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> [ExpReturns]) -> m ExpReturns -> m [ExpReturns]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> m ExpReturns
forall rep (m :: * -> *).
(HasScope rep m, Monad m, Mem rep) =>
SubExp -> m ExpReturns
subExpReturns SubExp
se
expReturns (BasicOp (Opaque (Var VName
v))) =
  ExpReturns -> [ExpReturns]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> [ExpReturns]) -> m ExpReturns -> m [ExpReturns]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m ExpReturns
forall rep (m :: * -> *).
(HasScope rep m, Monad m, Mem rep) =>
VName -> m ExpReturns
varReturns VName
v
expReturns (BasicOp (Reshape ShapeChange SubExp
newshape VName
v)) = do
  (PrimType
et, Shape
_, VName
mem, IxFun
ixfun) <- VName -> m (PrimType, Shape, VName, IxFun)
forall rep (m :: * -> *).
(HasScope rep m, Monad m, Mem rep) =>
VName -> m (PrimType, Shape, VName, IxFun)
arrayVarReturns VName
v
  [ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return
    [ PrimType
-> ShapeBase (Ext SubExp)
-> NoUniqueness
-> Maybe MemReturn
-> ExpReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
et ([Ext SubExp] -> ShapeBase (Ext SubExp)
forall d. [d] -> ShapeBase d
Shape ([Ext SubExp] -> ShapeBase (Ext SubExp))
-> [Ext SubExp] -> ShapeBase (Ext SubExp)
forall a b. (a -> b) -> a -> b
$ (DimChange SubExp -> Ext SubExp)
-> ShapeChange SubExp -> [Ext SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> Ext SubExp
forall a. a -> Ext a
Free (SubExp -> Ext SubExp)
-> (DimChange SubExp -> SubExp) -> DimChange SubExp -> Ext SubExp
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. DimChange SubExp -> SubExp
forall d. DimChange d -> d
newDim) ShapeChange SubExp
newshape) NoUniqueness
NoUniqueness (Maybe MemReturn -> ExpReturns) -> Maybe MemReturn -> ExpReturns
forall a b. (a -> b) -> a -> b
$
        MemReturn -> Maybe MemReturn
forall a. a -> Maybe a
Just (MemReturn -> Maybe MemReturn) -> MemReturn -> Maybe MemReturn
forall a b. (a -> b) -> a -> b
$
          VName -> ExtIxFun -> MemReturn
ReturnsInBlock VName
mem (ExtIxFun -> MemReturn) -> ExtIxFun -> MemReturn
forall a b. (a -> b) -> a -> b
$
            [VName] -> IxFun -> ExtIxFun
existentialiseIxFun [] (IxFun -> ExtIxFun) -> IxFun -> ExtIxFun
forall a b. (a -> b) -> a -> b
$
              IxFun -> ShapeChange (TPrimExp Int64 VName) -> IxFun
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> ShapeChange num -> IxFun num
IxFun.reshape IxFun
ixfun (ShapeChange (TPrimExp Int64 VName) -> IxFun)
-> ShapeChange (TPrimExp Int64 VName) -> IxFun
forall a b. (a -> b) -> a -> b
$ (DimChange SubExp -> DimChange (TPrimExp Int64 VName))
-> ShapeChange SubExp -> ShapeChange (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map ((SubExp -> TPrimExp Int64 VName)
-> DimChange SubExp -> DimChange (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64) ShapeChange SubExp
newshape
    ]
expReturns (BasicOp (Rearrange [Int]
perm VName
v)) = do
  (PrimType
et, Shape Result
dims, VName
mem, IxFun
ixfun) <- VName -> m (PrimType, Shape, VName, IxFun)
forall rep (m :: * -> *).
(HasScope rep m, Monad m, Mem rep) =>
VName -> m (PrimType, Shape, VName, IxFun)
arrayVarReturns VName
v
  let ixfun' :: IxFun
ixfun' = IxFun -> [Int] -> IxFun
forall num. IntegralExp num => IxFun num -> [Int] -> IxFun num
IxFun.permute IxFun
ixfun [Int]
perm
      dims' :: Result
dims' = [Int] -> Result -> Result
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm Result
dims
  [ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return
    [ PrimType
-> ShapeBase (Ext SubExp)
-> NoUniqueness
-> Maybe MemReturn
-> ExpReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
et ([Ext SubExp] -> ShapeBase (Ext SubExp)
forall d. [d] -> ShapeBase d
Shape ([Ext SubExp] -> ShapeBase (Ext SubExp))
-> [Ext SubExp] -> ShapeBase (Ext SubExp)
forall a b. (a -> b) -> a -> b
$ (SubExp -> Ext SubExp) -> Result -> [Ext SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Ext SubExp
forall a. a -> Ext a
Free Result
dims') NoUniqueness
NoUniqueness (Maybe MemReturn -> ExpReturns) -> Maybe MemReturn -> ExpReturns
forall a b. (a -> b) -> a -> b
$
        MemReturn -> Maybe MemReturn
forall a. a -> Maybe a
Just (MemReturn -> Maybe MemReturn) -> MemReturn -> Maybe MemReturn
forall a b. (a -> b) -> a -> b
$ VName -> ExtIxFun -> MemReturn
ReturnsInBlock VName
mem (ExtIxFun -> MemReturn) -> ExtIxFun -> MemReturn
forall a b. (a -> b) -> a -> b
$ [VName] -> IxFun -> ExtIxFun
existentialiseIxFun [] IxFun
ixfun'
    ]
expReturns (BasicOp (Rotate Result
offsets VName
v)) = do
  (PrimType
et, Shape Result
dims, VName
mem, IxFun
ixfun) <- VName -> m (PrimType, Shape, VName, IxFun)
forall rep (m :: * -> *).
(HasScope rep m, Monad m, Mem rep) =>
VName -> m (PrimType, Shape, VName, IxFun)
arrayVarReturns VName
v
  let offsets' :: [TPrimExp Int64 VName]
offsets' = (SubExp -> TPrimExp Int64 VName)
-> Result -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 Result
offsets
      ixfun' :: IxFun
ixfun' = IxFun -> [TPrimExp Int64 VName] -> IxFun
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Indices num -> IxFun num
IxFun.rotate IxFun
ixfun [TPrimExp Int64 VName]
offsets'
  [ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return
    [ PrimType
-> ShapeBase (Ext SubExp)
-> NoUniqueness
-> Maybe MemReturn
-> ExpReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
et ([Ext SubExp] -> ShapeBase (Ext SubExp)
forall d. [d] -> ShapeBase d
Shape ([Ext SubExp] -> ShapeBase (Ext SubExp))
-> [Ext SubExp] -> ShapeBase (Ext SubExp)
forall a b. (a -> b) -> a -> b
$ (SubExp -> Ext SubExp) -> Result -> [Ext SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> Ext SubExp
forall a. a -> Ext a
Free Result
dims) NoUniqueness
NoUniqueness (Maybe MemReturn -> ExpReturns) -> Maybe MemReturn -> ExpReturns
forall a b. (a -> b) -> a -> b
$
        MemReturn -> Maybe MemReturn
forall a. a -> Maybe a
Just (MemReturn -> Maybe MemReturn) -> MemReturn -> Maybe MemReturn
forall a b. (a -> b) -> a -> b
$ VName -> ExtIxFun -> MemReturn
ReturnsInBlock VName
mem (ExtIxFun -> MemReturn) -> ExtIxFun -> MemReturn
forall a b. (a -> b) -> a -> b
$ [VName] -> IxFun -> ExtIxFun
existentialiseIxFun [] IxFun
ixfun'
    ]
expReturns (BasicOp (Index VName
v Slice SubExp
slice)) = do
  LParamMem
info <- VName -> Slice SubExp -> m LParamMem
forall (m :: * -> *) rep.
(Monad m, HasScope rep m, Mem rep) =>
VName -> Slice SubExp -> m LParamMem
sliceInfo VName
v Slice SubExp
slice
  case LParamMem
info of
    MemArray PrimType
et Shape
shape NoUniqueness
u (ArrayIn VName
mem IxFun
ixfun) ->
      [ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return
        [ PrimType
-> ShapeBase (Ext SubExp)
-> NoUniqueness
-> Maybe MemReturn
-> ExpReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
et ((SubExp -> Ext SubExp) -> Shape -> ShapeBase (Ext SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> Ext SubExp
forall a. a -> Ext a
Free Shape
shape) NoUniqueness
u (Maybe MemReturn -> ExpReturns) -> Maybe MemReturn -> ExpReturns
forall a b. (a -> b) -> a -> b
$
            MemReturn -> Maybe MemReturn
forall a. a -> Maybe a
Just (MemReturn -> Maybe MemReturn) -> MemReturn -> Maybe MemReturn
forall a b. (a -> b) -> a -> b
$ VName -> ExtIxFun -> MemReturn
ReturnsInBlock VName
mem (ExtIxFun -> MemReturn) -> ExtIxFun -> MemReturn
forall a b. (a -> b) -> a -> b
$ [VName] -> IxFun -> ExtIxFun
existentialiseIxFun [] IxFun
ixfun
        ]
    MemPrim PrimType
pt -> [ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return [PrimType -> ExpReturns
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
pt]
    MemAcc VName
acc Shape
ispace [Type]
ts NoUniqueness
u -> [ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return [VName -> Shape -> [Type] -> NoUniqueness -> ExpReturns
forall d u ret. VName -> Shape -> [Type] -> u -> MemInfo d u ret
MemAcc VName
acc Shape
ispace [Type]
ts NoUniqueness
u]
    MemMem Space
space -> [ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return [Space -> ExpReturns
forall d u ret. Space -> MemInfo d u ret
MemMem Space
space]
expReturns (BasicOp (Update VName
v Slice SubExp
_ SubExp
_)) =
  ExpReturns -> [ExpReturns]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExpReturns -> [ExpReturns]) -> m ExpReturns -> m [ExpReturns]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m ExpReturns
forall rep (m :: * -> *).
(HasScope rep m, Monad m, Mem rep) =>
VName -> m ExpReturns
varReturns VName
v
expReturns (BasicOp BasicOp
op) =
  [ExtType] -> [ExpReturns]
extReturns ([ExtType] -> [ExpReturns])
-> ([Type] -> [ExtType]) -> [Type] -> [ExpReturns]
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Type] -> [ExtType]
forall u.
[TypeBase Shape u] -> [TypeBase (ShapeBase (Ext SubExp)) u]
staticShapes ([Type] -> [ExpReturns]) -> m [Type] -> m [ExpReturns]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> BasicOp -> m [Type]
forall rep (m :: * -> *). HasScope rep m => BasicOp -> m [Type]
primOpType BasicOp
op
expReturns e :: ExpT rep
e@(DoLoop [(FParam rep, SubExp)]
ctx [(FParam rep, SubExp)]
val LoopForm rep
_ BodyT rep
_) = do
  [ExtType]
t <- ExpT rep -> m [ExtType]
forall rep (m :: * -> *).
(HasScope rep m, TypedOp (Op rep)) =>
Exp rep -> m [ExtType]
expExtType ExpT rep
e
  (ExtType -> Param FParamMem -> m ExpReturns)
-> [ExtType] -> [Param FParamMem] -> m [ExpReturns]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM ExtType -> Param FParamMem -> m ExpReturns
typeWithDec [ExtType]
t ([Param FParamMem] -> m [ExpReturns])
-> [Param FParamMem] -> m [ExpReturns]
forall a b. (a -> b) -> a -> b
$ ((Param FParamMem, SubExp) -> Param FParamMem)
-> [(Param FParamMem, SubExp)] -> [Param FParamMem]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
val
  where
    typeWithDec :: ExtType -> Param FParamMem -> m ExpReturns
typeWithDec ExtType
t Param FParamMem
p =
      case (ExtType
t, Param FParamMem -> FParamMem
forall dec. Param dec -> dec
paramDec Param FParamMem
p) of
        ( Array PrimType
pt ShapeBase (Ext SubExp)
shape NoUniqueness
u,
          MemArray PrimType
_ Shape
_ Uniqueness
_ (ArrayIn VName
mem IxFun
ixfun)
          )
            | Just (Int
i, Param FParamMem
mem_p) <- VName -> Maybe (Int, Param FParamMem)
isMergeVar VName
mem,
              Mem Space
space <- Param FParamMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param FParamMem
mem_p ->
              ExpReturns -> m ExpReturns
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpReturns -> m ExpReturns) -> ExpReturns -> m ExpReturns
forall a b. (a -> b) -> a -> b
$ PrimType
-> ShapeBase (Ext SubExp)
-> NoUniqueness
-> Maybe MemReturn
-> ExpReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ShapeBase (Ext SubExp)
shape NoUniqueness
u (Maybe MemReturn -> ExpReturns) -> Maybe MemReturn -> ExpReturns
forall a b. (a -> b) -> a -> b
$ MemReturn -> Maybe MemReturn
forall a. a -> Maybe a
Just (MemReturn -> Maybe MemReturn) -> MemReturn -> Maybe MemReturn
forall a b. (a -> b) -> a -> b
$ Space -> Int -> ExtIxFun -> MemReturn
ReturnsNewBlock Space
space Int
i ExtIxFun
ixfun'
            | Bool
otherwise ->
              ExpReturns -> m ExpReturns
forall (m :: * -> *) a. Monad m => a -> m a
return
                ( PrimType
-> ShapeBase (Ext SubExp)
-> NoUniqueness
-> Maybe MemReturn
-> ExpReturns
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ShapeBase (Ext SubExp)
shape NoUniqueness
u (Maybe MemReturn -> ExpReturns) -> Maybe MemReturn -> ExpReturns
forall a b. (a -> b) -> a -> b
$
                    MemReturn -> Maybe MemReturn
forall a. a -> Maybe a
Just (MemReturn -> Maybe MemReturn) -> MemReturn -> Maybe MemReturn
forall a b. (a -> b) -> a -> b
$ VName -> ExtIxFun -> MemReturn
ReturnsInBlock VName
mem ExtIxFun
ixfun'
                )
            where
              ixfun' :: ExtIxFun
ixfun' = [VName] -> IxFun -> ExtIxFun
existentialiseIxFun ((Param FParamMem -> VName) -> [Param FParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param FParamMem -> VName
forall dec. Param dec -> VName
paramName [Param FParamMem]
mergevars) IxFun
ixfun
        (Array {}, FParamMem
_) ->
          String -> m ExpReturns
forall a. HasCallStack => String -> a
error String
"expReturns: Array return type but not array merge variable."
        (Acc VName
acc Shape
ispace [Type]
ts NoUniqueness
u, FParamMem
_) ->
          ExpReturns -> m ExpReturns
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpReturns -> m ExpReturns) -> ExpReturns -> m ExpReturns
forall a b. (a -> b) -> a -> b
$ VName -> Shape -> [Type] -> NoUniqueness -> ExpReturns
forall d u ret. VName -> Shape -> [Type] -> u -> MemInfo d u ret
MemAcc VName
acc Shape
ispace [Type]
ts NoUniqueness
u
        (Prim PrimType
pt, FParamMem
_) ->
          ExpReturns -> m ExpReturns
forall (m :: * -> *) a. Monad m => a -> m a
return (ExpReturns -> m ExpReturns) -> ExpReturns -> m ExpReturns
forall a b. (a -> b) -> a -> b
$ PrimType -> ExpReturns
forall d u ret. PrimType -> MemInfo d u ret
MemPrim PrimType
pt
        (Mem {}, FParamMem
_) ->
          String -> m ExpReturns
forall a. HasCallStack => String -> a
error String
"expReturns: loop returns memory block explicitly."
    isMergeVar :: VName -> Maybe (Int, Param FParamMem)
isMergeVar VName
v = ((Int, Param FParamMem) -> Bool)
-> [(Int, Param FParamMem)] -> Maybe (Int, Param FParamMem)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v) (VName -> Bool)
-> ((Int, Param FParamMem) -> VName)
-> (Int, Param FParamMem)
-> Bool
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Int, Param FParamMem) -> Param FParamMem)
-> (Int, Param FParamMem)
-> VName
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (Int, Param FParamMem) -> Param FParamMem
forall a b. (a, b) -> b
snd) ([(Int, Param FParamMem)] -> Maybe (Int, Param FParamMem))
-> [(Int, Param FParamMem)] -> Maybe (Int, Param FParamMem)
forall a b. (a -> b) -> a -> b
$ [Int] -> [Param FParamMem] -> [(Int, Param FParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int
0 ..] [Param FParamMem]
mergevars
    mergevars :: [Param FParamMem]
mergevars = ((Param FParamMem, SubExp) -> Param FParamMem)
-> [(Param FParamMem, SubExp)] -> [Param FParamMem]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst ([(Param FParamMem, SubExp)] -> [Param FParamMem])
-> [(Param FParamMem, SubExp)] -> [Param FParamMem]
forall a b. (a -> b) -> a -> b
$ [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
ctx [(Param FParamMem, SubExp)]
-> [(Param FParamMem, SubExp)] -> [(Param FParamMem, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
val
expReturns (Apply Name
_ [(SubExp, Diet)]
_ [RetType rep]
ret (Safety, SrcLoc, [SrcLoc])
_) =
  [ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return ([ExpReturns] -> m [ExpReturns]) -> [ExpReturns] -> m [ExpReturns]
forall a b. (a -> b) -> a -> b
$ (FunReturns -> ExpReturns) -> [FunReturns] -> [ExpReturns]
forall a b. (a -> b) -> [a] -> [b]
map FunReturns -> ExpReturns
funReturnsToExpReturns [RetType rep]
[FunReturns]
ret
expReturns (If SubExp
_ BodyT rep
_ BodyT rep
_ (IfDec [BranchType rep]
ret IfSort
_)) =
  [ExpReturns] -> m [ExpReturns]
forall (m :: * -> *) a. Monad m => a -> m a
return ([ExpReturns] -> m [ExpReturns]) -> [ExpReturns] -> m [ExpReturns]
forall a b. (a -> b) -> a -> b
$ (BodyReturns -> ExpReturns) -> [BodyReturns] -> [ExpReturns]
forall a b. (a -> b) -> [a] -> [b]
map BodyReturns -> ExpReturns
bodyReturnsToExpReturns [BranchType rep]
[BodyReturns]
ret
expReturns (Op Op rep
op) =
  Op rep -> m [ExpReturns]
forall rep (m :: * -> *).
(OpReturns rep, Monad m, HasScope rep m) =>
Op rep -> m [ExpReturns]
opReturns Op rep
op
expReturns (WithAcc [(Shape, [VName], Maybe (Lambda rep, Result))]
inputs Lambda rep
lam) =
  [ExpReturns] -> [ExpReturns] -> [ExpReturns]
forall a. Semigroup a => a -> a -> a
(<>)
    ([ExpReturns] -> [ExpReturns] -> [ExpReturns])
-> m [ExpReturns] -> m ([ExpReturns] -> [ExpReturns])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ([[ExpReturns]] -> [ExpReturns]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[ExpReturns]] -> [ExpReturns])
-> m [[ExpReturns]] -> m [ExpReturns]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Shape, [VName], Maybe (Lambda rep, Result)) -> m [ExpReturns])
-> [(Shape, [VName], Maybe (Lambda rep, Result))]
-> m [[ExpReturns]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Shape, [VName], Maybe (Lambda rep, Result)) -> m [ExpReturns]
forall {rep} {m :: * -> *} {t :: * -> *} {a} {c}.
(HasScope rep m, Monad m, Traversable t, AllocOp (Op rep),
 ASTRep rep, OpReturns rep, LetDec rep ~ LParamMem,
 LParamInfo rep ~ LParamMem, RetType rep ~ FunReturns,
 FParamInfo rep ~ FParamMem, BranchType rep ~ BodyReturns) =>
(a, t VName, c) -> m (t ExpReturns)
inputReturns [(Shape, [VName], Maybe (Lambda rep, Result))]
inputs)
    m ([ExpReturns] -> [ExpReturns])
-> m [ExpReturns] -> m [ExpReturns]
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*>
    -- XXX: this is a bit dubious because it enforces extra copies.  I
    -- think WithAcc should perhaps have a return annotation like If.
    [ExpReturns] -> m [ExpReturns]
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([ExtType] -> [ExpReturns]
extReturns ([ExtType] -> [ExpReturns]) -> [ExtType] -> [ExpReturns]
forall a b. (a -> b) -> a -> b
$ [Type] -> [ExtType]
forall u.
[TypeBase Shape u] -> [TypeBase (ShapeBase (Ext SubExp)) u]
staticShapes ([Type] -> [ExtType]) -> [Type] -> [ExtType]
forall a b. (a -> b) -> a -> b
$ Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop Int
num_accs ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda rep
lam)
  where
    inputReturns :: (a, t VName, c) -> m (t ExpReturns)
inputReturns (a
_, t VName
arrs, c
_) = (VName -> m ExpReturns) -> t VName -> m (t ExpReturns)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m ExpReturns
forall rep (m :: * -> *).
(HasScope rep m, Monad m, Mem rep) =>
VName -> m ExpReturns
varReturns t VName
arrs
    num_accs :: Int
num_accs = [(Shape, [VName], Maybe (Lambda rep, Result))] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Shape, [VName], Maybe (Lambda rep, Result))]
inputs

-- | Memory information for the value obtained by indexing the array
-- @v@ with @slice@.
--
-- The element type, memory block and index function of @v@ are obtained
-- from 'arrayVarReturns'.  When 'sliceDims' of the slice is empty, the
-- result is a scalar ('MemPrim') of the element type.  Otherwise it is an
-- array whose shape is 'sliceDims' of the slice, living in the same memory
-- block as @v@.  Its index function is @v@'s index function sliced by
-- @slice@, with every slice component converted to an @int64@ primitive
-- expression.
sliceInfo ::
  (Monad m, HasScope rep m, Mem rep) =>
  VName ->
  Slice SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
sliceInfo v slice = do
  (et, _, mem, ixfun) <- arrayVarReturns v
  let asInt64 = isInt64 . primExpFromSubExp int64
      -- Only demanded when the slice leaves at least one dimension.
      slicedIxFun = IxFun.slice ixfun $ map (fmap asInt64) slice
  pure $ case sliceDims slice of
    [] -> MemPrim et
    dims -> MemArray et (Shape dims) NoUniqueness $ ArrayIn mem slicedIxFun

-- | Representations whose operations can describe the memory-annotated
-- return values ('ExpReturns') that they produce.
class TypedOp (Op rep) => OpReturns rep where
  -- | The 'ExpReturns' of an operation.  The default implementation
  -- derives them from the operation's 'opType' by way of 'extReturns'.
  opReturns ::
    (Monad m, HasScope rep m) =>
    Op rep ->
    m [ExpReturns]
  opReturns = fmap extReturns . opType

-- | Specialise the return annotations of a function to a concrete
-- application.
--
-- @applyFunReturns rets params args@ first asks 'applyRetType' whether the
-- declared return types of @rets@ accept the arguments @args@ for the
-- parameters @params@.  If they do not, the result is 'Nothing'.
-- Otherwise every annotation is rewritten as follows:
--
--  * a 'Free' dimension size that names a parameter is replaced by the
--    corresponding argument;
--
--  * a 'ReturnsInBlock' memory block that names a parameter is replaced by
--    the corresponding argument, provided that argument is a variable.
--
-- 'Ext' dimensions and 'ReturnsNewBlock' annotations are left untouched.
-- Index functions are currently not substituted (see the FIXME below).
applyFunReturns ::
  Typed dec =>
  [FunReturns] ->
  [Param dec] ->
  [(SubExp, Type)] ->
  Maybe [FunReturns]
applyFunReturns rets params args =
  -- The specialised types are discarded; they only serve as a validity check.
  map correctDims rets <$ applyRetType (map declExtTypeOf rets) params args
  where
    -- Maps each parameter name to the argument (and its type) passed for it.
    parammap :: M.Map VName (SubExp, Type)
    parammap = M.fromList $ zip (map paramName params) args

    substSubExp se@(Var v) = maybe se fst $ M.lookup v parammap
    substSubExp se = se

    -- Only arrays carry sizes and memory blocks that may need rewriting.
    correctDims (MemArray et shape u summary) =
      MemArray et (fmap correctDim shape) u (correctSummary summary)
    correctDims info = info

    correctDim (Free se) = Free $ substSubExp se
    correctDim ext = ext

    correctSummary (ReturnsInBlock mem ixfun) =
      -- FIXME: we should also do a replacement in ixfun here.
      ReturnsInBlock (substMem mem) ixfun
    correctSummary summary = summary

    -- A memory block can only be replaced by an argument that is a variable.
    substMem mem = case M.lookup mem parammap of
      Just (Var v, _) -> v
      _ -> mem