{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
module Futhark.IR.Prop.Types
( rankShaped,
arrayRank,
arrayShape,
setArrayShape,
existential,
uniqueness,
unique,
staticShapes,
staticShapes1,
primType,
arrayOf,
arrayOfRow,
arrayOfShape,
setOuterSize,
setDimSize,
setOuterDim,
setDim,
setArrayDims,
peelArray,
stripArray,
arrayDims,
arrayExtDims,
shapeSize,
arraySize,
arraysSize,
rowType,
elemType,
transposeType,
rearrangeType,
mapOnExtType,
mapOnType,
diet,
subtypeOf,
subtypesOf,
toDecl,
fromDecl,
isExt,
isFree,
extractShapeContext,
shapeContext,
hasStaticShape,
generaliseExtTypes,
existentialiseExtTypes,
shapeExtMapping,
int8,
int16,
int32,
int64,
float32,
float64,
Typed (..),
DeclTyped (..),
ExtTyped (..),
DeclExtTyped (..),
SetType (..),
FixExt (..),
)
where
import Control.Monad.State
import Data.List (elemIndex, foldl')
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Rearrange
import Futhark.IR.Syntax.Core
-- | Forget the concrete dimensions of an array type and keep only its
-- rank.  Primitive and memory types are rebuilt unchanged.
rankShaped :: ArrayShape shape => TypeBase shape u -> TypeBase Rank u
rankShaped t = case t of
  Prim pt -> Prim pt
  Mem sp -> Mem sp
  Array pt dims uniq -> Array pt (Rank (shapeRank dims)) uniq
-- | The number of dimensions of a type, computed as the rank of its
-- 'arrayShape'.  Non-array types have shape 'mempty'.
arrayRank :: ArrayShape shape => TypeBase shape u -> Int
arrayRank t = shapeRank (arrayShape t)
-- | The shape of an array type.  Any non-array type yields 'mempty'.
arrayShape :: ArrayShape shape => TypeBase shape u -> shape
arrayShape t = case t of
  Array _ dims _ -> dims
  _ -> mempty
-- | Apply a function to the shape of an array type.  If the new shape
-- has rank zero, the array collapses to its primitive element type.
-- Primitive and memory types are rebuilt unchanged (only the shape type
-- parameter changes).
modifyArrayShape ::
  ArrayShape newshape =>
  (oldshape -> newshape) ->
  TypeBase oldshape u ->
  TypeBase newshape u
modifyArrayShape f (Array t ds u)
  | shapeRank ds' == 0 = Prim t
  | otherwise = Array t ds' u
  where
    -- Compute the new shape once; it is needed both for the rank test
    -- and for building the resulting array type.
    ds' = f ds
modifyArrayShape _ (Prim t) = Prim t
modifyArrayShape _ (Mem space) = Mem space
-- | Replace the shape of an array type wholesale.  A rank-zero shape
-- turns the array into its element type; non-array types are rebuilt
-- unchanged.
setArrayShape ::
  ArrayShape newshape =>
  TypeBase oldshape u ->
  newshape ->
  TypeBase newshape u
setArrayShape t newShape = modifyArrayShape (\_ -> newShape) t
-- | True if at least one dimension of the type's shape is an 'Ext'
-- (existential) dimension.  Types with an empty shape are never
-- existential.
existential :: ExtType -> Bool
existential t = any isExtDim (shapeDims (arrayShape t))
  where
    isExtDim d = case d of
      Ext _ -> True
      Free _ -> False
-- | The uniqueness attribute of a type.  Only arrays carry one; every
-- other type is reported as 'Nonunique'.
uniqueness :: TypeBase shape Uniqueness -> Uniqueness
uniqueness t = case t of
  Array _ _ u -> u
  _ -> Nonunique
-- | True if the type's 'uniqueness' is 'Unique'.
unique :: TypeBase shape Uniqueness -> Bool
unique t = uniqueness t == Unique
-- | Apply 'staticShapes1' to every type in a list.
staticShapes :: [TypeBase Shape u] -> [TypeBase ExtShape u]
staticShapes ts = [staticShapes1 t | t <- ts]
-- | Lift a type with a known shape to an existential-shape type in which
-- every dimension is wrapped in 'Free'.
staticShapes1 :: TypeBase Shape u -> TypeBase ExtShape u
staticShapes1 t = case t of
  Prim pt -> Prim pt
  Mem sp -> Mem sp
  Array pt (Shape dims) u -> Array pt (Shape (Free <$> dims)) u
-- | Wrap a type in the outer dimensions given by @outer@, with
-- uniqueness @u@.
--
-- * For an existing array, @outer@ is prepended to its dimensions with
--   '<>', and the old uniqueness is discarded.
-- * A primitive type wrapped in a rank-zero shape stays primitive.
-- * Calling this on a 'Mem' type is an error.
arrayOf ::
  ArrayShape shape =>
  TypeBase shape u_unused ->
  shape ->
  u ->
  TypeBase shape u
arrayOf t outer u = case t of
  Array et inner _ -> Array et (outer <> inner) u
  Prim et
    | shapeRank outer == 0 -> Prim et
    | otherwise -> Array et outer u
  Mem {} -> error "arrayOf Mem"
-- | Build an array whose rows have the given type by adding one outer
-- dimension of size @outerDim@.  The result is 'NoUniqueness'.
arrayOfRow ::
  ArrayShape (ShapeBase d) =>
  TypeBase (ShapeBase d) NoUniqueness ->
  d ->
  TypeBase (ShapeBase d) NoUniqueness
arrayOfRow rowT outerDim = arrayOf rowT (Shape [outerDim]) NoUniqueness
-- | Wrap a type in the given outer shape via 'arrayOf', with
-- 'NoUniqueness'.
arrayOfShape :: Type -> Shape -> Type
arrayOfShape elemT outer = arrayOf elemT outer NoUniqueness
-- | Replace the dimensions of an array type with the given list of sizes
-- (see 'setArrayShape').
setArrayDims :: TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
setArrayDims t newDims = setArrayShape t (Shape newDims)
-- | Replace the outermost dimension of an array type; equivalent to
-- @'setDimSize' 0@.
setOuterSize ::
  ArrayShape (ShapeBase d) =>
  TypeBase (ShapeBase d) u ->
  d ->
  TypeBase (ShapeBase d) u
setOuterSize t newSize = setDimSize 0 t newSize
-- | Replace dimension @i@ of an array type's shape with @e@, using
-- 'setDim' on the shape and 'setArrayShape' to put it back.
setDimSize ::
  ArrayShape (ShapeBase d) =>
  Int ->
  TypeBase (ShapeBase d) u ->
  d ->
  TypeBase (ShapeBase d) u
setDimSize i t e = setArrayShape t updatedShape
  where
    updatedShape = setDim i (arrayShape t) e
-- | Replace the outermost dimension of a shape; equivalent to
-- @'setDim' 0@.
setOuterDim :: ShapeBase d -> d -> ShapeBase d
setOuterDim shape newDim = setDim 0 shape newDim
-- | Replace the dimension at zero-based index @i@ of a shape with @e@.
--
-- Out-of-range indices follow from 'take'/'drop'.  If @i@ is at or past
-- the end, @e@ is appended.  If @i@ is negative, @e@ is prepended and no
-- dimension is removed.
setDim :: Int -> ShapeBase d -> d -> ShapeBase d
setDim i (Shape ds) e = Shape (before ++ [e] ++ after)
  where
    before = take i ds
    after = drop (i + 1) ds
-- | Peel @n@ dimensions off a type.
--
-- * Peeling zero dimensions returns the type unchanged.
-- * Peeling exactly as many dimensions as the array has yields its
--   primitive element type.
-- * Peeling fewer dimensions than the array has applies 'stripDims'.
-- * In every other case (too many dimensions, or a non-array type with
--   @n /= 0@) the result is 'Nothing'.
peelArray ::
  ArrayShape shape =>
  Int ->
  TypeBase shape u ->
  Maybe (TypeBase shape u)
peelArray 0 t = Just t
peelArray n (Array et shape u) =
  case compare (shapeRank shape) n of
    EQ -> Just (Prim et)
    GT -> Just (Array et (stripDims n shape) u)
    LT -> Nothing
peelArray _ _ = Nothing
-- | Strip @n@ dimensions from an array type using 'stripDims'.  If @n@
-- is not smaller than the array's rank, the result is the primitive
-- element type.  Non-array types are returned unchanged.
stripArray :: ArrayShape shape => Int -> TypeBase shape u -> TypeBase shape u
stripArray n t = case t of
  Array et shape u
    | n < shapeRank shape -> Array et (stripDims n shape) u
    | otherwise -> Prim et
  _ -> t
-- | The size of dimension @i@ of a shape.  If the shape has no
-- dimension at that index, the constant @0 :: Int64@ is returned.
shapeSize :: Int -> Shape -> SubExp
shapeSize i shape =
  fromMaybe (constant (0 :: Int64)) $ listToMaybe $ drop i $ shapeDims shape
-- | The list of dimension sizes of a type; empty for non-arrays.
arrayDims :: TypeBase Shape u -> [SubExp]
arrayDims t = shapeDims (arrayShape t)
-- | The list of (possibly existential) dimensions of a type; empty for
-- non-arrays.
arrayExtDims :: TypeBase ExtShape u -> [ExtSize]
arrayExtDims t = shapeDims (arrayShape t)
-- | The size of dimension @i@ of a type, via 'shapeSize' (which yields
-- the constant 0 for a missing dimension).
arraySize :: Int -> TypeBase Shape u -> SubExp
arraySize i t = shapeSize i (arrayShape t)
-- | The size of dimension @i@ of the first type in the list, or the
-- constant @0 :: Int64@ if the list is empty.
arraysSize :: Int -> [TypeBase Shape u] -> SubExp
arraysSize i ts = maybe (constant (0 :: Int64)) (arraySize i) (listToMaybe ts)
-- | The type of a single row of an array, i.e. the type with its
-- outermost dimension stripped ('stripArray' 1).
rowType :: ArrayShape shape => TypeBase shape u -> TypeBase shape u
rowType t = stripArray 1 t
-- | True unless the type is an array or a memory block.
primType :: TypeBase shape u -> Bool
primType t = case t of
  Array {} -> False
  Mem {} -> False
  _ -> True
-- | The primitive element type of an array, or the type itself for a
-- primitive.  Calling this on a 'Mem' type is an error.
elemType :: TypeBase shape u -> PrimType
elemType t = case t of
  Array pt _ _ -> pt
  Prim pt -> pt
  Mem {} -> error "elemType Mem"
-- | Swap the two outermost dimensions of an array type, using
-- 'rearrangeType' with the permutation @[1, 0]@.
transposeType :: Type -> Type
transposeType t = rearrangeType [1, 0] t
-- | Permute the dimensions of an array type with 'rearrangeShape'.  The
-- permutation may cover only the outer dimensions; it is extended with
-- the identity on the remaining inner ones.
rearrangeType :: [Int] -> Type -> Type
rearrangeType perm t =
  setArrayShape t $ Shape $ rearrangeShape fullPerm $ arrayDims t
  where
    -- Indices length perm .. rank-1 keep their position.
    fullPerm = perm ++ [length perm .. arrayRank t - 1]
-- | Apply a monadic action to every 'SubExp' inside the dimensions of an
-- existential-shape type.  Existential ('Ext') dimensions are left
-- alone, since 'traverse' over an 'Ext' only visits 'Free' values.
-- Primitive and memory types are returned unchanged.
mapOnExtType ::
  Monad m =>
  (SubExp -> m SubExp) ->
  TypeBase ExtShape u ->
  m (TypeBase ExtShape u)
mapOnExtType f t = case t of
  Prim bt -> pure (Prim bt)
  Mem space -> pure (Mem space)
  Array et shape u -> do
    dims' <- traverse (traverse f) (shapeDims shape)
    pure (Array et (Shape dims') u)
-- | Apply a monadic action to every dimension size of an array type, in
-- order.  Primitive and memory types are returned unchanged.
mapOnType ::
  Monad m =>
  (SubExp -> m SubExp) ->
  TypeBase Shape u ->
  m (TypeBase Shape u)
mapOnType f t = case t of
  Prim bt -> pure (Prim bt)
  Mem space -> pure (Mem space)
  Array et shape u -> do
    dims' <- mapM f (shapeDims shape)
    pure (Array et (Shape dims') u)
-- | The 'Diet' corresponding to a declared type.
--
-- * Primitive types map to 'ObservePrim'.
-- * Unique arrays map to 'Consume'.
-- * Nonunique arrays and memory blocks map to 'Observe'.
diet :: TypeBase shape Uniqueness -> Diet
diet (Prim _) = ObservePrim
diet (Array _ _ u) = case u of
  Unique -> Consume
  Nonunique -> Observe
diet Mem {} = Observe
-- | @x `subtypeOf` y@ holds when @x@ can be used where @y@ is expected.
--
-- * Arrays need @uniq y <= uniq x@, equal element types, and the shape
--   of @x@ to be a 'subShapeOf' the shape of @y@.
-- * Primitive and memory types must be equal.
-- * Types of different kinds are never subtypes of each other.
subtypeOf ::
  (Ord u, ArrayShape shape) =>
  TypeBase shape u ->
  TypeBase shape u ->
  Bool
subtypeOf x y = case (x, y) of
  (Array et1 s1 u1, Array et2 s2 u2) ->
    u2 <= u1 && et1 == et2 && s1 `subShapeOf` s2
  (Prim pt1, Prim pt2) -> pt1 == pt2
  (Mem sp1, Mem sp2) -> sp1 == sp2
  _ -> False
-- | Pointwise 'subtypeOf' on two lists.  The lists must have equal
-- length.
subtypesOf ::
  (Ord u, ArrayShape shape) =>
  [TypeBase shape u] ->
  [TypeBase shape u] ->
  Bool
subtypesOf xs ys =
  length xs == length ys && all (uncurry subtypeOf) (zip xs ys)
-- | Turn a type without uniqueness into a declared type with the given
-- uniqueness.  Only arrays actually record the uniqueness.
toDecl ::
  TypeBase shape NoUniqueness ->
  Uniqueness ->
  TypeBase shape Uniqueness
toDecl t u = case t of
  Prim bt -> Prim bt
  Array et shape _ -> Array et shape u
  Mem space -> Mem space
-- | Drop the uniqueness information from a declared type.
fromDecl ::
  TypeBase shape Uniqueness ->
  TypeBase shape NoUniqueness
fromDecl t = case t of
  Prim bt -> Prim bt
  Array et shape _ -> Array et shape NoUniqueness
  Mem space -> Mem space
-- | The existential index of an 'Ext' value, or 'Nothing' for 'Free'.
isExt :: Ext a -> Maybe Int
isExt e = case e of
  Ext i -> Just i
  _ -> Nothing
-- | The payload of a 'Free' value, or 'Nothing' for an existential.
isFree :: Ext a -> Maybe a
isFree e = case e of
  Free d -> Just d
  _ -> Nothing
-- | Pair each existential-shape type with a list of concrete values for
-- its dimensions, and collect the values found at 'Ext' positions.
--
-- * Values appear in order of first appearance.
-- * An existential index that has already been reported is skipped.
-- * 'Free' dimensions never contribute.
-- * Types, dimensions and values are matched positionally; surplus
--   elements of the longer list at either level are ignored.
extractShapeContext :: [TypeBase ExtShape u] -> [[a]] -> [a]
extractShapeContext ts shapes = walkTypes S.empty (zip ts shapes)
  where
    -- Thread the set of already-reported existential indices through the
    -- types from left to right.
    walkTypes _ [] = []
    walkTypes seen ((t, vals) : rest) =
      let (seen', found) = walkDims seen (zip (shapeDims (arrayShape t)) vals)
       in found ++ walkTypes seen' rest

    walkDims seen [] = (seen, [])
    walkDims seen ((Ext x, v) : rest)
      | x `S.member` seen = walkDims seen rest
      | otherwise =
          let (seen', found) = walkDims (S.insert x seen) rest
           in (seen', v : found)
    walkDims seen ((Free _, _) : rest) = walkDims seen rest
-- | The set of all existential indices that occur in the shapes of the
-- given types.
shapeContext :: [TypeBase ExtShape u] -> S.Set Int
shapeContext ts =
  -- The pattern @Ext i@ silently filters out 'Free' dimensions.
  S.fromList [i | t <- ts, Ext i <- shapeDims (arrayShape t)]
-- | Convert an existential-shape type into one with a known shape.
-- Returns 'Nothing' if any dimension is existential.  Primitive and
-- memory types always succeed.
hasStaticShape :: TypeBase ExtShape u -> Maybe (TypeBase Shape u)
hasStaticShape t = case t of
  Prim bt -> Just (Prim bt)
  Mem space -> Just (Mem space)
  Array bt (Shape dims) u -> do
    dims' <- traverse isFree dims
    Just (Array bt (Shape dims') u)
-- | Given two lists of 'ExtType's of the same length, return a list of
-- 'ExtType's that both operands are subtypes of.
--
-- Dimensions are unified position by position:
--
-- * two equal 'Free' sizes stay 'Free';
--
-- * two 'Ext' sizes with the same index @x@ are mapped to a shared result
--   existential, so every position where /both/ operands use @Ext x@
--   ends up with the same existential in the result;
--
-- * any other combination becomes a fresh, unshared existential.
--
-- Existentials in the result are numbered consecutively from 0.
--
-- NOTE(review): corresponding types are assumed to have the same rank;
-- 'zipWithM' silently truncates to the shorter shape otherwise.
generaliseExtTypes ::
  [TypeBase ExtShape u] ->
  [TypeBase ExtShape u] ->
  [TypeBase ExtShape u]
generaliseExtTypes rt1 rt2 =
  evalState (zipWithM unifyExtShapes rt1 rt2) (0, M.empty)
  where
    -- Unify two types dimension by dimension, keeping the element type
    -- and uniqueness of the first.
    unifyExtShapes t1 t2 =
      setArrayShape t1 . Shape
        <$> zipWithM
          unifyExtDims
          (shapeDims $ arrayShape t1)
          (shapeDims $ arrayShape t2)

    unifyExtDims (Free se1) (Free se2)
      | se1 == se2 = pure $ Free se1 -- Arbitrary; both are equal.
      | otherwise = Ext <$> fresh
    unifyExtDims (Ext x) (Ext y)
      | x == y =
          -- Both operands use existential x here, so all such positions
          -- are equal in both operands and may share one result
          -- existential.
          Ext <$> (maybe (new x) pure =<< gets (M.lookup x . snd))
    -- Mismatched dimensions get a fresh existential that is NOT recorded
    -- in the sharing map. Recording it would make a later (Ext x, Ext x)
    -- position share it, claiming an equality that need not hold in the
    -- second operand (and in the (_, Ext x) case, x would even be an
    -- index from the second operand's numbering).
    unifyExtDims _ _ = Ext <$> fresh

    -- Allocate the next result existential index.
    fresh :: State (Int, M.Map Int Int) Int
    fresh = do
      (n, m) <- get
      put (n + 1, m)
      pure n

    -- Allocate the next result existential index and remember it as the
    -- shared result existential for operand existential x.
    new :: Int -> State (Int, M.Map Int Int) Int
    new x = do
      n <- fresh
      modify $ \(k, m) -> (k, M.insert x n m)
      pure n
-- | Given a list of "inaccessible" names and a list of 'ExtType's,
-- replace every dimension that is @Free (Var v)@ with @v@ among the
-- inaccessible names by @Ext i@. Here @i@ is the position of the first
-- occurrence of @v@ in that list. All other dimensions are left
-- unchanged.
existentialiseExtTypes :: [VName] -> [ExtType] -> [ExtType]
existentialiseExtTypes inaccessible = map makeBoundShapesFree
  where
    makeBoundShapesFree = modifyArrayShape $ fmap checkDim

    checkDim (Free (Var v))
      | Just i <- v `elemIndex` inaccessible = Ext i
    checkDim d = d
-- | Produce a mapping from existential dimension indices to concrete
-- sizes. The mapping pairs the dimensions of each 'ExtType' with those
-- of the corresponding 'Type'. 'Free' dimensions contribute nothing.
-- If an existential index occurs more than once, the size at its first
-- occurrence is kept, because 'M.Map''s 'mappend' is a left-biased
-- union and the results are combined left to right.
shapeExtMapping :: [TypeBase ExtShape u] -> [TypeBase Shape u1] -> M.Map Int SubExp
shapeExtMapping = dimMapping arrayExtDims arrayDims match mappend
  where
    match Free {} _ = mempty
    match (Ext i) dim = M.singleton i dim
-- | Pair up the dimensions of corresponding types in two lists and apply
-- @f@ to each pair of dimensions. All results are combined from left to
-- right with @comb@, starting from 'mempty'.
--
-- Like 'zipWith', this silently ignores extra types, and extra
-- dimensions within a pair of types, when the lengths differ.
dimMapping ::
  Monoid res =>
  (t1 -> [dim1]) ->
  (t2 -> [dim2]) ->
  (dim1 -> dim2 -> res) ->
  (res -> res -> res) ->
  [t1] ->
  [t2] ->
  res
dimMapping getDims1 getDims2 f comb ts1 ts2 =
  foldl' comb mempty . concat $
    zipWith (zipWith f) (map getDims1 ts1) (map getDims2 ts2)
-- | @IntType Int8@
int8 :: PrimType
int8 = IntType Int8
-- | @IntType Int16@
int16 :: PrimType
int16 = IntType Int16
-- | @IntType Int32@
int32 :: PrimType
int32 = IntType Int32
-- | @IntType Int64@
int64 :: PrimType
int64 = IntType Int64
-- | @FloatType Float32@
float32 :: PrimType
float32 = FloatType Float32
-- | @FloatType Float64@
float64 :: PrimType
float64 = FloatType Float64
-- | Things from which a 'Type' can be extracted.
class Typed t where
  typeOf :: t -> Type

instance Typed Type where
  typeOf = id

-- | Uniqueness information is discarded via 'fromDecl'.
instance Typed DeclType where
  typeOf = fromDecl

instance Typed Ident where
  typeOf = identType

-- | The type of the parameter's decoration.
instance Typed dec => Typed (Param dec) where
  typeOf = typeOf . paramDec

-- | The type of the pattern element's decoration.
instance Typed dec => Typed (PatElemT dec) where
  typeOf = typeOf . patElemDec

-- | The type of the second component; the first is ignored.
instance Typed b => Typed (a, b) where
  typeOf = typeOf . snd
-- | Things from which a 'DeclType' can be extracted.
class DeclTyped t where
  declTypeOf :: t -> DeclType

instance DeclTyped DeclType where
  declTypeOf = id

-- | The declared type of the parameter's decoration.
instance DeclTyped dec => DeclTyped (Param dec) where
  declTypeOf = declTypeOf . paramDec
-- | Things from which an 'ExtType' can be extracted.
class FixExt t => ExtTyped t where
  extTypeOf :: t -> ExtType

instance ExtTyped ExtType where
  extTypeOf = id
-- | Things from which a 'DeclExtType' can be extracted.
class FixExt t => DeclExtTyped t where
  declExtTypeOf :: t -> DeclExtType

instance DeclExtTyped DeclExtType where
  declExtTypeOf = id
-- | Things whose 'Type' can be replaced.
class Typed a => SetType a where
  setType :: a -> Type -> a

-- | The old type is discarded and the new one returned.
instance SetType Type where
  setType _ t = t

-- | Sets the type of the second component; the first is kept.
instance SetType b => SetType (a, b) where
  setType (a, b) t = (a, setType b t)

-- | Sets the type of the decoration; the name is kept.
instance SetType dec => SetType (PatElemT dec) where
  setType (PatElem name dec) t =
    PatElem name $ setType dec t
-- | Something with an existential context that can be (partially)
-- fixed.
class FixExt t where
  -- | @fixExt i se x@ replaces existential @i@ in @x@ by the free size
  -- @se@. Every existential with an index greater than @i@ is renumbered
  -- down by one, closing the gap left by @i@. Smaller indices and 'Free'
  -- sizes are unchanged.
  fixExt :: Int -> SubExp -> t -> t

instance (FixExt shape, ArrayShape shape) => FixExt (TypeBase shape u) where
  fixExt i se = modifyArrayShape $ fixExt i se

instance FixExt d => FixExt (ShapeBase d) where
  fixExt i se = fmap $ fixExt i se

instance FixExt a => FixExt [a] where
  fixExt i se = fmap $ fixExt i se

instance FixExt ExtSize where
  fixExt i se (Ext j)
    | j > i = Ext $ j - 1
    | j == i = Free se
    | otherwise = Ext j
  fixExt _ _ (Free x) = Free x

-- | Nothing to fix.
instance FixExt () where
  fixExt _ _ () = ()