{-# LANGUAGE TypeFamilies #-}

-- | This module provides facilities for obtaining the types of
-- various Futhark constructs.  Typically, you will need to execute
-- these in a context where type information is available as a
-- 'Scope'; usually by using a monad that is an instance of
-- 'HasScope'.  The information is returned as a list of 'ExtType'
-- values - one for each of the values the Futhark construct returns.
-- Some constructs (such as subexpressions) can produce only a single
-- value, and their typing functions hence do not return a list.
--
-- Some representations may have more specialised facilities enabling
-- even more information - for example,
-- "Futhark.IR.Mem" exposes functionality for
-- also obtaining information about the storage location of results.
module Futhark.IR.Prop.TypeOf
  ( expExtType,
    subExpType,
    subExpResType,
    basicOpType,
    mapType,

    -- * Return type
    module Futhark.IR.RetType,

    -- * Type environment
    module Futhark.IR.Prop.Scope,

    -- * Extensibility
    TypedOp (..),
  )
where

import Data.List.NonEmpty (NonEmpty (..))
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Scope
import Futhark.IR.Prop.Types
import Futhark.IR.RetType
import Futhark.IR.Syntax

-- | The type of a subexpression.  A constant has the primitive type
-- of its value, while the type of a variable is looked up in the
-- scope.
subExpType :: (HasScope t m) => SubExp -> m Type
subExpType (Constant val) = pure $ Prim $ primValueType val
subExpType (Var name) = lookupType name

-- | The type of a 'SubExpRes' - note that this might refer to names
-- bound in the body containing the result.
subExpResType :: (HasScope t m) => SubExpRes -> m Type
subExpResType = subExpType . resSubExp

-- | @mapType outersize f@ wraps each element in the return type of
-- @f@ in an array whose outermost dimension is @outersize@.
mapType :: SubExp -> Lambda rep -> [Type]
mapType outersize f =
  [ arrayOf t (Shape [outersize]) NoUniqueness
    | t <- lambdaReturnType f
  ]

-- | The type of a primitive operation.  The result is a list of
-- types; every equation below produces exactly one element.
basicOpType :: (HasScope rep m) => BasicOp -> m [Type]
basicOpType (SubExp se) =
  pure <$> subExpType se
basicOpType (Opaque _ se) =
  pure <$> subExpType se
-- Array literals: a one-dimensional array whose size is the number of
-- elements, expressed as an 'Int64' constant.
basicOpType (ArrayVal vs t) =
  pure [arrayOf (Prim t) (Shape [n]) NoUniqueness]
  where
    n = intConst Int64 $ toInteger $ length vs
basicOpType (ArrayLit es rt) =
  pure [arrayOf rt (Shape [n]) NoUniqueness]
  where
    n = intConst Int64 $ toInteger $ length es
basicOpType (BinOp bop _ _) =
  pure [Prim $ binOpType bop]
basicOpType (UnOp _ x) =
  pure <$> subExpType x
basicOpType CmpOp {} =
  pure [Prim Bool]
-- 'convOpType' yields a pair of types; the result type is the second
-- component.
basicOpType (ConvOp conv _) =
  pure [Prim $ snd $ convOpType conv]
-- Indexing keeps the element type of the indexed array and takes its
-- shape from the dimensions of the slice.
basicOpType (Index ident slice) =
  result <$> lookupType ident
  where
    result t = [Prim (elemType t) `arrayOfShape` shape]
    shape = Shape $ sliceDims slice
-- An update has the same type as the array being updated.
basicOpType (Update _ src _ _) =
  pure <$> lookupType src
basicOpType (FlatIndex ident slice) =
  result <$> lookupType ident
  where
    result t = [Prim (elemType t) `arrayOfShape` shape]
    shape = Shape $ flatSliceDims slice
basicOpType (FlatUpdate src _ _) =
  pure <$> lookupType src
basicOpType (Iota n _ _ et) =
  pure [arrayOf (Prim (IntType et)) (Shape [n]) NoUniqueness]
-- Replicating with an empty shape leaves the type of the value
-- unchanged.
basicOpType (Replicate (Shape []) e) =
  pure <$> subExpType e
basicOpType (Replicate shape e) =
  pure . flip arrayOfShape shape <$> subExpType e
basicOpType (Scratch t shape) =
  pure [arrayOf (Prim t) (Shape shape) NoUniqueness]
-- Reshaping to an empty shape yields the primitive element type of
-- the operand.
basicOpType (Reshape _ (Shape []) e) =
  result <$> lookupType e
  where
    result t = [Prim $ elemType t]
basicOpType (Reshape _ shape e) =
  result <$> lookupType e
  where
    result t = [t `setArrayShape` shape]
basicOpType (Rearrange perm e) =
  result <$> lookupType e
  where
    result t = [rearrangeType perm t]
-- The result type is derived from the first operand only, by applying
-- 'setDimSize' with the dimension index @i@ and the size @ressize@.
basicOpType (Concat i (x :| _) ressize) =
  result <$> lookupType x
  where
    result xt = [setDimSize i xt ressize]
basicOpType (Manifest _ v) =
  pure <$> lookupType v
basicOpType Assert {} =
  pure [Prim Unit]
basicOpType (UpdateAcc _ v _ _) =
  pure <$> lookupType v

-- | The type of an expression, as one 'ExtType' for each value the
-- expression returns.
expExtType ::
  (HasScope rep m, TypedOp (OpC rep)) =>
  Exp rep ->
  m [ExtType]
-- The return types are stored in the application itself; only the
-- uniqueness information is stripped.
expExtType (Apply _ _ rt _) = pure $ map (fromDecl . declExtTypeOf . fst) rt
expExtType (Match _ _ _ rt) = pure $ map extTypeOf $ matchReturns rt
expExtType (Loop merge _ _) =
  pure $ loopExtType $ map fst merge
expExtType (BasicOp op) = staticShapes <$> basicOpType op
-- The result consists of the types of all input arrays, followed by
-- the lambda's return types with the first @num_accs@ entries (one
-- per input) dropped.
expExtType (WithAcc inputs lam) =
  fmap staticShapes $
    (<>)
      <$> (concat <$> traverse inputType inputs)
      <*> pure (drop num_accs (lambdaReturnType lam))
  where
    inputType (_, arrs, _) = traverse lookupType arrs
    num_accs = length inputs
expExtType (Op op) = opType op

-- | Given the parameters of a loop, produce the return type.  The
-- parameter types are passed through 'existentialiseExtTypes'
-- together with the names of all the parameters themselves.
loopExtType :: (Typed dec) => [Param dec] -> [ExtType]
loopExtType params =
  existentialiseExtTypes inaccessible $ staticShapes $ map typeOf params
  where
    -- NOTE(review): presumably the loop parameters are not in scope
    -- outside the loop, hence "inaccessible" - confirm against
    -- 'existentialiseExtTypes'.
    inaccessible = map paramName params

-- | Any operation must define an instance of this class, which
-- describes the type of the operation (at the value level).
class TypedOp op where
  -- | The types of the values produced by the operation, computed in
  -- a monad that provides type information through 'HasScope'.
  opType :: (HasScope rep m) => op rep -> m [ExtType]

-- | The type of 'NoOp' is the empty list of types.
instance TypedOp NoOp where
  opType NoOp = pure []