-- | Functions for inspecting and constructing various types.
module Futhark.IR.Prop.Types
  ( rankShaped,
    arrayRank,
    arrayShape,
    setArrayShape,
    isEmptyArray,
    existential,
    uniqueness,
    unique,
    staticShapes,
    staticShapes1,
    primType,
    isAcc,
    arrayOf,
    arrayOfRow,
    arrayOfShape,
    setOuterSize,
    setDimSize,
    setOuterDim,
    setOuterDims,
    setDim,
    setArrayDims,
    peelArray,
    stripArray,
    arrayDims,
    arrayExtDims,
    shapeSize,
    arraySize,
    arraysSize,
    elemType,
    rowType,
    transposeType,
    rearrangeType,
    mapOnExtType,
    mapOnType,
    diet,
    subtypeOf,
    subtypesOf,
    toDecl,
    fromDecl,
    isExt,
    isFree,
    extractShapeContext,
    shapeContext,
    hasStaticShape,
    generaliseExtTypes,
    existentialiseExtTypes,
    shapeExtMapping,

    -- * Abbreviations
    int8,
    int16,
    int32,
    int64,
    float32,
    float64,

    -- * The Typed typeclass
    Typed (..),
    DeclTyped (..),
    ExtTyped (..),
    DeclExtTyped (..),
    SetType (..),
    FixExt (..),
  )
where

import Control.Monad
import Control.Monad.State
import Data.List (elemIndex, foldl')
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Rearrange
import Futhark.IR.Syntax.Core

-- | Remove shape information from a type, keeping only the rank of
-- arrays.  Primitive, accumulator and memory types are passed through
-- unchanged.
rankShaped :: ArrayShape shape => TypeBase shape u -> TypeBase Rank u
rankShaped (Array et sz u) = Array et (Rank $ shapeRank sz) u
rankShaped (Prim pt) = Prim pt
rankShaped (Acc acc ispace ts u) = Acc acc ispace ts u
rankShaped (Mem space) = Mem space

-- | Return the dimensionality of a type.  For non-arrays, this is
-- zero.  For a one-dimensional array it is one, for a two-dimensional
-- it is two, and so forth.
arrayRank :: ArrayShape shape => TypeBase shape u -> Int
arrayRank = shapeRank . arrayShape

-- | Return the shape of a type - for non-arrays, this is the
-- 'mempty'.
arrayShape :: ArrayShape shape => TypeBase shape u -> shape
arrayShape (Array _ ds _) = ds
arrayShape _ = mempty

-- | Modify the shape of an array - for non-arrays, this does nothing.
-- If the new shape has rank zero, the array collapses to its
-- primitive element type.
modifyArrayShape ::
  ArrayShape newshape =>
  (oldshape -> newshape) ->
  TypeBase oldshape u ->
  TypeBase newshape u
modifyArrayShape f (Array t ds u)
  | shapeRank ds' == 0 = Prim t
  | otherwise = Array t ds' u
  where
    -- Computed once and shared by the guard and the result.
    ds' = f ds
modifyArrayShape _ (Prim t) = Prim t
modifyArrayShape _ (Acc acc ispace ts u) = Acc acc ispace ts u
modifyArrayShape _ (Mem space) = Mem space

-- | Set the shape of an array.  If the given type is not an
-- array, return the type unchanged.  If the new shape has rank
-- zero, an array type becomes its primitive element type.
setArrayShape ::
  ArrayShape newshape =>
  TypeBase oldshape u ->
  newshape ->
  TypeBase newshape u
setArrayShape t ds = modifyArrayShape (const ds) t

-- | If the array is statically an empty array (meaning any dimension
-- is a static zero), return the element type and the shape.
-- Otherwise, including for non-array types, return 'Nothing'.
isEmptyArray :: Type -> Maybe (PrimType, Shape)
isEmptyArray (Array pt (Shape ds) _)
  | intConst Int64 0 `elem` ds = Just (pt, Shape ds)
isEmptyArray _ = Nothing

-- | True if the given type has a dimension that is existentially sized.
existential :: ExtType -> Bool
existential = any ext . shapeDims . arrayShape
  where
    ext (Ext _) = True
    ext (Free _) = False

-- | Return the uniqueness of a type.  Only arrays and accumulators
-- carry a uniqueness attribute; every other type is 'Nonunique'.
uniqueness :: TypeBase shape Uniqueness -> Uniqueness
uniqueness (Array _ _ u) = u
uniqueness (Acc _ _ _ u) = u
uniqueness _ = Nonunique

-- | @unique t@ is 'True' if the type of the argument is unique.
unique :: TypeBase shape Uniqueness -> Bool
unique = (== Unique) . uniqueness

-- | Convert types with non-existential shapes to types with
-- existential shapes.  Only the representation is changed, so all
-- the shapes will be 'Free'.
staticShapes :: [TypeBase Shape u] -> [TypeBase ExtShape u]
staticShapes = map staticShapes1

-- | As 'staticShapes', but on a single type.
staticShapes1 :: TypeBase Shape u -> TypeBase ExtShape u
staticShapes1 (Prim t) =
  Prim t
staticShapes1 (Acc acc ispace ts u) =
  Acc acc ispace ts u
staticShapes1 (Array bt (Shape shape) u) =
  Array bt (Shape $ map Free shape) u
staticShapes1 (Mem space) =
  Mem space

-- | @arrayOf t s u@ constructs an array type.  The convenience
-- compared to using the 'Array' constructor directly is that @t@ can
-- itself be an array.  If @t@ is an @n@-dimensional array, and @s@ is
-- a shape of rank @m@, the resulting type has @n+m@ dimensions, with
-- the dimensions of @s@ outermost.  The uniqueness of the new array
-- will be @u@, no matter the uniqueness of @t@.  If the shape @s@ has
-- rank 0, then @t@ will be returned, although if it is an array, with
-- the uniqueness changed to @u@.
--
-- For an accumulator type the shape is ignored and only the
-- uniqueness is replaced.  Memory types are not supported and result
-- in a call to 'error'.
arrayOf ::
  ArrayShape shape =>
  TypeBase shape u_unused ->
  shape ->
  u ->
  TypeBase shape u
arrayOf (Array et size1 _) size2 u =
  Array et (size2 <> size1) u
arrayOf (Prim t) shape u
  | 0 <- shapeRank shape = Prim t
  | otherwise = Array t shape u
arrayOf (Acc acc ispace ts _) _shape u =
  Acc acc ispace ts u
arrayOf Mem {} _ _ =
  error "arrayOf Mem"

-- | Construct an array whose rows are the given type, and the outer
-- size is the given dimension.  This is just a convenient wrapper
-- around 'arrayOf'.
arrayOfRow ::
  ArrayShape (ShapeBase d) =>
  TypeBase (ShapeBase d) NoUniqueness ->
  d ->
  TypeBase (ShapeBase d) NoUniqueness
arrayOfRow t size = arrayOf t (Shape [size]) NoUniqueness

-- | Construct an array whose rows are the given type, and the outer
-- size is the given t'Shape'.  This is just a convenient wrapper
-- around 'arrayOf'.
arrayOfShape :: Type -> Shape -> Type
arrayOfShape t shape = arrayOf t shape NoUniqueness

-- | Set the dimensions of an array.  If the given type is not an
-- array, return the type unchanged.
setArrayDims :: TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
setArrayDims t dims = t `setArrayShape` Shape dims

-- | Replace the size of the outermost dimension of an array.  If the
-- given type is not an array, it is returned unchanged.
setOuterSize ::
  ArrayShape (ShapeBase d) =>
  TypeBase (ShapeBase d) u ->
  d ->
  TypeBase (ShapeBase d) u
setOuterSize = setDimSize 0

-- | Replace the size of the given dimension of an array.  If the
-- given type is not an array, it is returned unchanged.
setDimSize ::
  ArrayShape (ShapeBase d) =>
  Int ->
  TypeBase (ShapeBase d) u ->
  d ->
  TypeBase (ShapeBase d) u
setDimSize i t e = t `setArrayShape` setDim i (arrayShape t) e

-- | Replace the outermost dimension of an array shape.
setOuterDim :: ShapeBase d -> d -> ShapeBase d
setOuterDim = setDim 0

-- | Replace some outermost dimensions of an array shape.
-- @setOuterDims old k new@ strips @k@ dimensions from @old@ (via
-- 'stripDims') and puts the dimensions of @new@ in front of what
-- remains.
setOuterDims :: ShapeBase d -> Int -> ShapeBase d -> ShapeBase d
setOuterDims old k new = new <> stripDims k old

-- | Replace the specified dimension of an array shape.  Dimensions
-- are numbered from zero, outermost first.  Note that if the index is
-- past the last dimension, the new size is appended instead, so the
-- rank grows by one.
setDim :: Int -> ShapeBase d -> d -> ShapeBase d
setDim i (Shape ds) e = Shape $ take i ds ++ e : drop (i + 1) ds

-- | @peelArray n t@ returns the type resulting from peeling the first
-- @n@ array dimensions from @t@.  Returns @Nothing@ if @t@ has less
-- than @n@ dimensions.  Peeling all dimensions of an array yields its
-- primitive element type.
peelArray :: Int -> TypeBase Shape u -> Maybe (TypeBase Shape u)
peelArray 0 t = Just t
peelArray n (Array et shape u)
  | shapeRank shape == n = Just $ Prim et
  | shapeRank shape > n = Just $ Array et (stripDims n shape) u
peelArray _ _ = Nothing

-- | @stripArray n t@ removes the @n@ outermost layers of the array.
-- Essentially, it is the type of indexing an array of type @t@ with
-- @n@ indexes.  If @n@ is at least the rank of the array, the
-- primitive element type is returned.  Non-array types are returned
-- unchanged.
stripArray :: Int -> TypeBase Shape u -> TypeBase Shape u
stripArray n (Array et shape u)
  | n < shapeRank shape = Array et (stripDims n shape) u
  | otherwise = Prim et
stripArray _ t = t

-- | Return the size of the given dimension.  If the dimension does
-- not exist, the zero constant is returned.
shapeSize :: Int -> Shape -> SubExp
shapeSize i shape = case drop i $ shapeDims shape of
  e : _ -> e
  [] -> constant (0 :: Int64)

-- | Return the dimensions of a type - for non-arrays, this is the
-- empty list.
arrayDims :: TypeBase Shape u -> [SubExp]
arrayDims = shapeDims . arrayShape

-- | Return the existential dimensions of a type - for non-arrays,
-- this is the empty list.
arrayExtDims :: TypeBase ExtShape u -> [ExtSize]
arrayExtDims = shapeDims . arrayShape

-- | Return the size of the given dimension.  If the dimension does
-- not exist, the zero constant is returned.
arraySize :: Int -> TypeBase Shape u -> SubExp
arraySize i = shapeSize i . arrayShape

-- | Return the size of the given dimension in the first element of
-- the given type list.  If the dimension does not exist, or no types
-- are given, the zero constant is returned.
arraysSize :: Int -> [TypeBase Shape u] -> SubExp
arraysSize _ [] = constant (0 :: Int64)
arraysSize i (t : _) = arraySize i t

-- | Return the immediate row-type of an array.  For @[[int]]@, this
-- would be @[int]@.
rowType :: TypeBase Shape u -> TypeBase Shape u
rowType = stripArray 1

-- | A type is a primitive type if it is not an array, accumulator, or
-- memory block.
primType :: TypeBase shape u -> Bool
primType Prim {} = True
primType _ = False

-- | Is this an accumulator?
isAcc :: TypeBase shape u -> Bool
isAcc Acc {} = True
isAcc _ = False

-- | Returns the bottommost type of an array.  For @[][]i32@, this
-- would be @i32@.  If the given type is a primitive type, that type
-- is returned.  This function is partial: it calls 'error' when given
-- an accumulator or a memory block.
elemType :: TypeBase shape u -> PrimType
elemType (Array t _ _) = t
elemType (Prim t) = t
elemType Acc {} = error "elemType Acc"
elemType Mem {} = error "elemType Mem"

-- | Swap the two outer dimensions of the type.
transposeType :: Type -> Type
transposeType = rearrangeType [1, 0]

-- | Rearrange the dimensions of the type.  If the length of the
-- permutation does not match the rank of the type, the permutation
-- will be extended with identity.
rearrangeType :: [Int] -> Type -> Type
rearrangeType perm t =
  t `setArrayShape` Shape (rearrangeShape perm' $ arrayDims t)
  where
    -- Pad the permutation with the identity mapping for any trailing
    -- dimensions it does not mention.
    perm' = perm ++ [length perm .. arrayRank t - 1]

-- | Transform any t'SubExp's in the type.  For arrays, only the
-- 'Free' dimensions are passed to the function.  For accumulators,
-- the function is applied to the accumulator name (wrapped as a
-- 'Var'), to the index space, and to the element types.
mapOnExtType ::
  Monad m =>
  (SubExp -> m SubExp) ->
  TypeBase ExtShape u ->
  m (TypeBase ExtShape u)
mapOnExtType _ (Prim bt) =
  pure $ Prim bt
mapOnExtType f (Acc acc ispace ts u) =
  Acc <$> f' acc <*> traverse f ispace <*> mapM (mapOnType f) ts <*> pure u
  where
    -- The accumulator must remain a name, so if the function maps it
    -- to a constant we keep the original name.
    f' v = do
      x <- f $ Var v
      case x of
        Var v' -> pure v'
        Constant {} -> pure v
mapOnExtType _ (Mem space) =
  pure $ Mem space
mapOnExtType f (Array t shape u) =
  Array t
    <$> (Shape <$> mapM (traverse f) (shapeDims shape))
    <*> pure u

-- | Transform any t'SubExp's in the type.  For accumulators, the
-- function is applied to the accumulator name (wrapped as a 'Var'),
-- to the index space, and recursively to the element types.
mapOnType ::
  Monad m =>
  (SubExp -> m SubExp) ->
  TypeBase Shape u ->
  m (TypeBase Shape u)
mapOnType _ (Prim bt) = pure $ Prim bt
mapOnType f (Acc acc ispace ts u) =
  Acc <$> f' acc <*> traverse f ispace <*> mapM (mapOnType f) ts <*> pure u
  where
    -- The accumulator must remain a name, so if the function maps it
    -- to a constant we keep the original name.
    f' v = do
      x <- f $ Var v
      case x of
        Var v' -> pure v'
        Constant {} -> pure v
mapOnType _ (Mem space) = pure $ Mem space
mapOnType f (Array t shape u) =
  Array t
    <$> (Shape <$> mapM f (shapeDims shape))
    <*> pure u

-- | @diet t@ returns a description of how a function parameter of
-- type @t@ might consume its argument.  Unique arrays and
-- accumulators are consumed; nonunique ones, as well as memory
-- blocks, are merely observed.
diet :: TypeBase shape Uniqueness -> Diet
diet Prim {} = ObservePrim
diet (Acc _ _ _ Unique) = Consume
diet (Acc _ _ _ Nonunique) = Observe
diet (Array _ _ Unique) = Consume
diet (Array _ _ Nonunique) = Observe
diet Mem {} = Observe

-- | @x \`subtypeOf\` y@ is true if @x@ is a subtype of @y@ (or equal to
-- @y@), meaning @x@ is valid whenever @y@ is.  For two arrays this
-- compares uniqueness, element type and shape (via 'subShapeOf');
-- all other combinations must be equal.
subtypeOf ::
  (Ord u, ArrayShape shape) =>
  TypeBase shape u ->
  TypeBase shape u ->
  Bool
subtypeOf (Array t1 shape1 u1) (Array t2 shape2 u2) =
  u2 <= u1
    && t1 == t2
    && shape1 `subShapeOf` shape2
subtypeOf t1 t2 = t1 == t2

-- | @xs \`subtypesOf\` ys@ is true if @xs@ is the same size as @ys@,
-- and each element in @xs@ is a subtype of the corresponding element
-- in @ys@.
subtypesOf ::
  (Ord u, ArrayShape shape) =>
  [TypeBase shape u] ->
  [TypeBase shape u] ->
  Bool
subtypesOf xs ys =
  length xs == length ys
    && and (zipWith subtypeOf xs ys)

-- | Add the given uniqueness information to the types.  Primitive
-- and memory types carry no uniqueness, so for those the uniqueness
-- argument is ignored.
toDecl ::
  TypeBase shape NoUniqueness ->
  Uniqueness ->
  TypeBase shape Uniqueness
toDecl (Prim t) _ = Prim t
toDecl (Acc acc ispace ts _) u = Acc acc ispace ts u
toDecl (Array et shape _) u = Array et shape u
toDecl (Mem space) _ = Mem space

-- | Remove uniqueness information from the type.
fromDecl ::
  TypeBase shape Uniqueness ->
  TypeBase shape NoUniqueness
fromDecl (Prim t) = Prim t
fromDecl (Acc acc ispace ts _) = Acc acc ispace ts NoUniqueness
fromDecl (Array et shape _) = Array et shape NoUniqueness
fromDecl (Mem space) = Mem space

-- | If an existential, then return its existential index.
isExt :: Ext a -> Maybe Int
isExt (Ext i) = Just i
isExt _ = Nothing

-- | If a known size, then return that size.
isFree :: Ext a -> Maybe a
isFree (Free d) = Just d
isFree _ = Nothing

-- | Given the existential return type of a function, and the shapes
-- of the values returned by the function, return the existential
-- shape context.  That is, those sizes that are existential in the
-- return type.  Each existential index contributes only the size
-- found at its first occurrence.
extractShapeContext :: [TypeBase ExtShape u] -> [[a]] -> [a]
extractShapeContext ts shapes =
  evalState (concat <$> zipWithM extract ts shapes) S.empty
  where
    extract t shape =
      catMaybes <$> zipWithM extract' (shapeDims $ arrayShape t) shape
    -- The state is the set of existential indices already emitted.
    extract' (Ext x) v = do
      seen <- gets $ S.member x
      if seen
        then pure Nothing
        else do
          modify $ S.insert x
          pure $ Just v
    extract' (Free _) _ = pure Nothing

-- | The 'Ext' integers used for existential sizes in the given types.
shapeContext :: [TypeBase ExtShape u] -> S.Set Int
shapeContext = S.fromList . concatMap (mapMaybe isExt . shapeDims . arrayShape)

-- | If all dimensions of the given 'ExtShape' are statically known,
-- change to the corresponding t'Shape'.  Returns 'Nothing' if any
-- dimension of an array is existential; non-array types always
-- succeed.
hasStaticShape :: TypeBase ExtShape u -> Maybe (TypeBase Shape u)
hasStaticShape (Prim bt) = Just $ Prim bt
hasStaticShape (Acc acc ispace ts u) = Just $ Acc acc ispace ts u
hasStaticShape (Mem space) = Just $ Mem space
hasStaticShape (Array bt (Shape shape) u) =
  Array bt <$> (Shape <$> mapM isFree shape) <*> pure u

-- | Given two lists of 'ExtType's of the same length, return a list
-- of 'ExtType's that is a subtype of the two operands.
--
-- Each result type takes everything except its shape from the
-- corresponding type in the first list.  The dimensions of the two
-- shapes are paired positionally and unified as follows:
--
-- * Two equal free dimensions are kept as they are.
--
-- * Two different free dimensions become a fresh existential.
--
-- * The same existential on both sides is renumbered; every later
--   occurrence of that existential on both sides reuses the number
--   that is currently recorded for it.
--
-- * Any other pairing that involves an existential gets a fresh
--   number, which is also recorded for that existential (replacing
--   any number recorded earlier).
--
-- Existentials in the result are numbered from zero in the order in
-- which they are allocated.
generaliseExtTypes ::
  [TypeBase ExtShape u] ->
  [TypeBase ExtShape u] ->
  [TypeBase ExtShape u]
generaliseExtTypes rt1 rt2 =
  -- The state is the next unused existential number together with a
  -- mapping from old existential numbers to new ones.
  evalState (zipWithM unifyExtShapes rt1 rt2) (0, M.empty)
  where
    unifyExtShapes t1 t2 =
      setArrayShape t1 . Shape
        <$> zipWithM
          unifyExtDims
          (shapeDims $ arrayShape t1)
          (shapeDims $ arrayShape t2)

    unifyExtDims (Free se1) (Free se2)
      | se1 == se2 = pure $ Free se1 -- Arbitrary
      | otherwise = do
          -- Disagreeing sizes: allocate a number without recording it,
          -- as there is no old existential to associate it with.
          (n, m) <- get
          put (n + 1, m)
          pure $ Ext n
    unifyExtDims (Ext x) (Ext y)
      -- When the guard fails we fall through to the equations below.
      | x == y = Ext <$> (maybe (new x) pure =<< gets (M.lookup x . snd))
    unifyExtDims (Ext x) _ = Ext <$> new x
    unifyExtDims _ (Ext x) = Ext <$> new x

    -- Allocate the next number and record it as the renaming of @x@.
    new x = do
      (n, m) <- get
      put (n + 1, M.insert x n m)
      pure n

-- | Given a list of 'ExtType's and a list of "forbidden" names,
-- modify the dimensions of the 'ExtType's such that they are 'Ext'
-- where they were previously 'Free' with a variable in the set of
-- forbidden names.
--
-- The existential index given to such a dimension is the position of
-- the first occurrence of the variable in the list of forbidden
-- names.  All other dimensions are left untouched.
existentialiseExtTypes :: [VName] -> [ExtType] -> [ExtType]
existentialiseExtTypes inaccessible = map hideInaccessible
  where
    hideInaccessible = modifyArrayShape $ fmap checkDim

    checkDim (Free (Var v))
      | Just i <- v `elemIndex` inaccessible =
          Ext i
    -- Constants, accessible variables and existing existentials.
    checkDim d = d

-- | Produce a mapping for the dimensions context.
--
-- The dimensions of the existential types are paired positionally
-- with the dimensions of the concrete types.  Every existential
-- dimension @'Ext' i@ contributes a binding from @i@ to the dimension
-- it is paired with; free dimensions contribute nothing.  Bindings
-- are merged with 'mappend', so when an index occurs more than once,
-- the binding from its first occurrence is the one that is kept.
shapeExtMapping :: [TypeBase ExtShape u] -> [TypeBase Shape u1] -> M.Map Int SubExp
shapeExtMapping = dimMapping arrayExtDims arrayDims match mappend
  where
    match Free {} _ = mempty
    match (Ext i) dim = M.singleton i dim

-- | Combine the dimensions of two lists of types into a single
-- result.
--
-- The types are paired positionally, and so are the dimensions of
-- each pair of types (as extracted by the two projection functions).
-- The third argument turns each pair of dimensions into a result, and
-- these results are merged from left to right with the fourth
-- argument, starting from 'mempty'.  Surplus types or dimensions on
-- either side are ignored.
dimMapping ::
  Monoid res =>
  (t1 -> [dim1]) ->
  (t2 -> [dim2]) ->
  (dim1 -> dim2 -> res) ->
  (res -> res -> res) ->
  [t1] ->
  [t2] ->
  res
dimMapping getDims1 getDims2 f comb ts1 ts2 =
  foldl' comb mempty $ concat $ zipWith pairDims ts1 ts2
  where
    -- The per-dimension results for one pair of types.
    pairDims t1 t2 = zipWith f (getDims1 t1) (getDims2 t2)

-- | @IntType Int8@: the 8-bit integer type as a 'PrimType'.
int8 :: PrimType
int8 = IntType Int8

-- | @IntType Int16@: the 16-bit integer type as a 'PrimType'.
int16 :: PrimType
int16 = IntType Int16

-- | @IntType Int32@: the 32-bit integer type as a 'PrimType'.
int32 :: PrimType
int32 = IntType Int32

-- | @IntType Int64@: the 64-bit integer type as a 'PrimType'.
int64 :: PrimType
int64 = IntType Int64

-- | @FloatType Float32@: the 32-bit floating-point type as a
-- 'PrimType'.
float32 :: PrimType
float32 = FloatType Float32

-- | @FloatType Float64@: the 64-bit floating-point type as a
-- 'PrimType'.
float64 :: PrimType
float64 = FloatType Float64

-- | Typeclass for things that contain 'Type's.
class Typed t where
  -- | The 'Type' contained in the value.
  typeOf :: t -> Type

-- | A 'Type' is its own type.
instance Typed Type where
  typeOf = id

-- | The type of a 'DeclType' is obtained with 'fromDecl'.
instance Typed DeclType where
  typeOf = fromDecl

-- | The type of an 'Ident' is its 'identType'.
instance Typed Ident where
  typeOf = identType

-- | The type of a parameter is the type of its decoration.
instance Typed dec => Typed (Param dec) where
  typeOf = typeOf . paramDec

-- | The type of a pattern element is the type of its decoration.
instance Typed dec => Typed (PatElem dec) where
  typeOf = typeOf . patElemDec

-- | The type of a pair is the type of its second component; the
-- first component is ignored.
instance Typed b => Typed (a, b) where
  typeOf = typeOf . snd

-- | Typeclass for things that contain 'DeclType's.
class DeclTyped t where
  -- | The 'DeclType' contained in the value.
  declTypeOf :: t -> DeclType

-- | A 'DeclType' is its own declared type.
instance DeclTyped DeclType where
  declTypeOf = id

-- | The declared type of a parameter is that of its decoration.
instance DeclTyped dec => DeclTyped (Param dec) where
  declTypeOf = declTypeOf . paramDec

-- | Typeclass for things that contain 'ExtType's.
class FixExt t => ExtTyped t where
  -- | The 'ExtType' contained in the value.
  extTypeOf :: t -> ExtType

-- | An 'ExtType' is its own existential type.
instance ExtTyped ExtType where
  extTypeOf = id

-- | Typeclass for things that contain 'DeclExtType's.
class FixExt t => DeclExtTyped t where
  -- | The 'DeclExtType' contained in the value.
  declExtTypeOf :: t -> DeclExtType

-- | A 'DeclExtType' is its own declared existential type.
instance DeclExtTyped DeclExtType where
  declExtTypeOf = id

-- | Typeclass for things whose type can be changed.
class Typed a => SetType a where
  -- | Replace the type contained in the value with the given one.
  setType :: a -> Type -> a

-- | Setting the type of a 'Type' discards the old type entirely.
instance SetType Type where
  setType _ t = t

-- | Only the second component of a pair is updated; the first
-- component is passed through unchanged.
instance SetType b => SetType (a, b) where
  setType (a, b) t = (a, setType b t)

-- | The name of the pattern element is kept; the new type is set on
-- its decoration.
instance SetType dec => SetType (PatElem dec) where
  setType (PatElem name dec) t =
    PatElem name $ setType dec t

-- | Something with an existential context that can be (partially)
-- fixed.
class FixExt t where
  -- | Fix the given existential variable to the indicated free
  -- value.
  fixExt :: Int -> SubExp -> t -> t

-- | Fixing an existential in a type means fixing it in its shape.
instance (FixExt shape, ArrayShape shape) => FixExt (TypeBase shape u) where
  fixExt i se = modifyArrayShape $ fixExt i se

-- | Fixing an existential in a shape fixes it in every dimension.
instance FixExt d => FixExt (ShapeBase d) where
  fixExt i se = fmap $ fixExt i se

-- | Fixing an existential in a list fixes it in every element.
instance FixExt a => FixExt [a] where
  fixExt i se = fmap $ fixExt i se

-- | The existential with the given index becomes the free value.
-- Existentials with a higher index are shifted down by one, and
-- those with a lower index, as well as free sizes, are unchanged.
instance FixExt ExtSize where
  fixExt i se (Ext j)
    | j > i = Ext $ j - 1
    | j == i = Free se
    | otherwise = Ext j
  fixExt _ _ (Free x) = Free x

-- | The unit value has no existential context, so nothing changes.
instance FixExt () where
  fixExt _ _ () = ()