{-# LANGUAGE TypeFamilies #-}
module Futhark.IR.Prop
( module Futhark.IR.Prop.Reshape,
module Futhark.IR.Prop.Rearrange,
module Futhark.IR.Prop.Types,
module Futhark.IR.Prop.Constants,
module Futhark.IR.Prop.TypeOf,
module Futhark.IR.Prop.Patterns,
module Futhark.IR.Prop.Names,
module Futhark.IR.RetType,
isBuiltInFunction,
builtInFunctions,
asBasicOp,
safeExp,
subExpVars,
subExpVar,
commutativeLambda,
defAux,
stmCerts,
certify,
expExtTypesFromPat,
attrsForAssert,
lamIsBinOp,
ASTConstraints,
IsOp (..),
ASTRep (..),
)
where
import Control.Monad
import Data.List (elemIndex, find)
import Data.Map.Strict qualified as M
import Data.Maybe (isJust, mapMaybe)
import Data.Set qualified as S
import Futhark.IR.Pretty
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Names
import Futhark.IR.Prop.Patterns
import Futhark.IR.Prop.Rearrange
import Futhark.IR.Prop.Reshape
import Futhark.IR.Prop.TypeOf
import Futhark.IR.Prop.Types
import Futhark.IR.RetType
import Futhark.IR.Syntax
import Futhark.Transform.Rename (Rename, Renameable)
import Futhark.Transform.Substitute (Substitutable, Substitute)
import Futhark.Util (maybeNth)
-- | Is the given name a key of 'builtInFunctions'?
isBuiltInFunction :: Name -> Bool
isBuiltInFunction = (`M.member` builtInFunctions)
-- | The built-in functions, keyed by name.  Derived from 'primFuns':
-- every key string is converted with 'nameFromString', and every entry
-- keeps the return type followed by the parameter types.  The third
-- component of a 'primFuns' entry (a @[PrimValue] -> Maybe PrimValue@
-- function) is dropped.
builtInFunctions :: M.Map Name (PrimType, [PrimType])
builtInFunctions =
  M.fromList
    [ (nameFromString fname, (ret, paramts))
      | (fname, (paramts, ret, _)) <- M.toList primFuns
    ]
-- | Extract the 'BasicOp' from an expression, or 'Nothing' if the
-- expression is not a 'BasicOp'.
asBasicOp :: Exp rep -> Maybe BasicOp
asBasicOp e =
  case e of
    BasicOp op -> Just op
    _ -> Nothing
-- | Is the expression known never to fail?  This is a conservative,
-- purely syntactic check:
--
-- * A 'BasicOp' is safe only if it is an array literal, a 'SubExp',
--   'UnOp', 'CmpOp', 'ConvOp', 'Scratch', 'Concat', 'Reshape',
--   'Rearrange', 'Manifest', 'Iota', 'Replicate' or 'Copy', or a
--   'BinOp' subject to the following: the integer division and
--   remainder operators ('SDiv', 'SDivUp', 'SQuot', 'UDiv', 'UDivUp',
--   'SMod', 'SRem', 'UMod') are safe if marked 'Safe' or if the divisor
--   is a constant that is not zero-ish; 'Pow' is safe only if the
--   exponent is a constant that is not negative-ish; every other
--   'BinOp' is safe.  All remaining 'BasicOp's are considered unsafe.
--
-- * A loop is safe if its body is safe (see 'safeBody').
--
-- * A 'Match' is safe if every case body and the default body are safe.
--
-- * An application is safe only if it calls a built-in function
--   (see 'isBuiltInFunction').
--
-- * 'WithAcc' is always considered safe; an 'Op' defers to 'safeOp'.
safeExp :: IsOp (Op rep) => Exp rep -> Bool
safeExp (BasicOp op) = safeBasicOp op
  where
    safeBasicOp :: BasicOp -> Bool
    safeBasicOp (BinOp bop _ y) = safeBinOp bop y
    safeBasicOp ArrayLit {} = True
    safeBasicOp SubExp {} = True
    safeBasicOp UnOp {} = True
    safeBasicOp CmpOp {} = True
    safeBasicOp ConvOp {} = True
    safeBasicOp Scratch {} = True
    safeBasicOp Concat {} = True
    safeBasicOp Reshape {} = True
    safeBasicOp Rearrange {} = True
    safeBasicOp Manifest {} = True
    safeBasicOp Iota {} = True
    safeBasicOp Replicate {} = True
    safeBasicOp Copy {} = True
    safeBasicOp _ = False

    -- The second operand is the one that can make the operator fail:
    -- the divisor, or the exponent for 'Pow'.
    safeBinOp :: BinOp -> SubExp -> Bool
    safeBinOp bop y =
      case bop of
        SDiv _ safety -> divisorOk safety y
        SDivUp _ safety -> divisorOk safety y
        SQuot _ safety -> divisorOk safety y
        UDiv _ safety -> divisorOk safety y
        UDivUp _ safety -> divisorOk safety y
        SMod _ safety -> divisorOk safety y
        SRem _ safety -> divisorOk safety y
        UMod _ safety -> divisorOk safety y
        Pow {} -> constantSatisfies (not . negativeIsh) y
        _ -> True

    -- An operator marked 'Safe' is accepted as is; otherwise the divisor
    -- must be a constant that is not zero-ish.
    divisorOk :: Safety -> SubExp -> Bool
    divisorOk Safe _ = True
    divisorOk _ divisor = constantSatisfies (not . zeroIsh) divisor

    -- Non-constant operands are never accepted.
    constantSatisfies :: (PrimValue -> Bool) -> SubExp -> Bool
    constantSatisfies p (Constant v) = p v
    constantSatisfies _ _ = False
safeExp (DoLoop _ _ body) = safeBody body
safeExp (Apply fname _ _ _) = isBuiltInFunction fname
safeExp (Match _ cases def_case _) =
  all (safeBody . caseBody) cases && safeBody def_case
safeExp WithAcc {} = True
safeExp (Op op) = safeOp op
-- | A body is safe if the expression of every statement in it is safe
-- according to 'safeExp'.  The body's result is not inspected.
safeBody :: IsOp (Op rep) => Body rep -> Bool
safeBody body = all safeStm (bodyStms body)
  where
    safeStm stm = safeExp (stmExp stm)
-- | The variable names occurring among the given subexpressions, in
-- their original order.  Constants are skipped.
subExpVars :: [SubExp] -> [VName]
subExpVars ses = mapMaybe subExpVar ses
-- | The variable name of a subexpression, or 'Nothing' if it is a
-- constant.
subExpVar :: SubExp -> Maybe VName
subExpVar se =
  case se of
    Var v -> Just v
    Constant {} -> Nothing
-- | Does the given lambda represent a known commutative function?
-- This is purely syntactic pattern matching; nothing clever is
-- attempted.  All of the following must hold:
--
-- * the parameters split into two halves of equal length, @xs@ and
--   @ys@;
--
-- * the body has exactly as many results as each half has parameters;
--
-- * for every @i@, result @i@ is a variable bound by a statement with a
--   single pattern element whose expression applies a 'BinOp' accepted
--   by 'commutativeBinOp' to parameter @xs !! i@ and parameter
--   @ys !! i@, in either order.  Certificates on the result are ignored.
commutativeLambda :: Lambda rep -> Bool
commutativeLambda lam =
  evenArity
    && n2 == length results
    && all okComponent (zip3 xps yps results)
  where
    params = lambdaParams lam
    body = lambdaBody lam
    results = bodyResult body
    n2 = length params `div` 2
    evenArity = n2 * 2 == length params
    (xps, yps) = splitAt n2 params

    -- Some statement of the body must compute this component.
    okComponent c = isJust $ find (computesComponent c) $ bodyStms body

    computesComponent ::
      (Param dec, Param dec, SubExpRes) -> Stm rep -> Bool
    computesComponent
      (xp, yp, SubExpRes _ (Var r))
      (Let (Pat [pe]) _ (BasicOp (BinOp op (Var x) (Var y)))) =
        patElemName pe == r
          && commutativeBinOp op
          && ((x, y) == (px, py) || (x, y) == (py, px))
        where
          px = paramName xp
          py = paramName yp
    computesComponent _ _ = False
-- | A 'StmAux' holding the given decoration, with no certificates and
-- no attributes.
defAux :: dec -> StmAux dec
defAux = StmAux (mempty :: Certs) (mempty :: Attrs)
-- | The certificates stored in a statement's auxiliary information.
stmCerts :: Stm rep -> Certs
stmCerts stm = stmAuxCerts (stmAux stm)
-- | Add certificates to a statement.  The new certificates are combined
-- with '<>' after the ones the statement already carries; the pattern,
-- attributes, decoration and expression are left unchanged.
certify :: Certs -> Stm rep -> Stm rep
certify extra (Let pat aux e) = Let pat (addCerts aux) e
  where
    addCerts (StmAux existing attrs dec) = StmAux (existing <> extra) attrs dec
-- | The class constraints demanded of an AST component such as an
-- operation type: 'Eq', 'Ord', 'Show', 'Rename', 'Substitute', 'FreeIn'
-- and 'Pretty'.
type ASTConstraints a =
  ( Eq a,
    Ord a,
    Show a,
    Rename a,
    Substitute a,
    FreeIn a,
    Pretty a
  )
-- | Operations usable as the 'Op' of a representation (every 'ASTRep'
-- requires @IsOp (Op rep)@).  Besides 'ASTConstraints' and 'TypedOp',
-- an operation must be able to report whether it is safe and whether
-- it is cheap.
class (ASTConstraints op, TypedOp op) => IsOp op where
  -- | Is the operation safe?  'safeExp' defers to this for 'Op'
  -- expressions.
  safeOp :: op -> Bool

  -- | Is the operation cheap?  NOTE(review): what "cheap" means is
  -- decided by the callers of this method, which are not in this
  -- module.
  cheapOp :: op -> Bool
-- | The unit operation is considered both safe and cheap.
instance IsOp () where
  safeOp () = True
  cheapOp () = True
-- | Representations whose decorations and operations support
-- everything listed in the superclass context: pretty-printing,
-- renaming, substitution, free-variable computation on the various
-- decorations, and an 'IsOp' operation type.
class
  ( RepTypes rep,
    PrettyRep rep,
    Renameable rep,
    Substitutable rep,
    FreeDec (ExpDec rep),
    FreeIn (LetDec rep),
    FreeDec (BodyDec rep),
    FreeIn (FParamInfo rep),
    FreeIn (LParamInfo rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    IsOp (Op rep)
  ) =>
  ASTRep rep
  where
  -- | Compute branch types from a pattern, with access to the scope.
  -- NOTE(review): presumably these are the types a branch body must
  -- have for its result to be bound by the pattern; 'expExtTypesFromPat'
  -- is a related helper for 'Typed' patterns — confirm with instances.
  expTypesFromPat ::
    (HasScope rep m, Monad m) =>
    Pat (LetDec rep) ->
    m [BranchType rep]
-- | Extended types for the elements of a pattern.  Each element's type
-- is lifted with 'staticShapes', and the result is passed through
-- 'existentialiseExtTypes' together with the names the pattern binds.
-- NOTE(review): presumably this makes sizes that refer to the pattern's
-- own names existential — confirm against 'existentialiseExtTypes'.
expExtTypesFromPat :: Typed dec => Pat dec -> [ExtType]
expExtTypesFromPat pat =
  existentialiseExtTypes bound (staticShapes elemTypes)
  where
    bound = patNames pat
    elemTypes = map patElemType (patElems pat)
-- | Restrict a set of attributes to the ones kept on assertions: only
-- @AttrComp "warn" ["safety_checks"]@ survives, if it is present.
attrsForAssert :: Attrs -> Attrs
attrsForAssert (Attrs attrs) = Attrs (S.filter isSafetyCheckWarning attrs)
  where
    isSafetyCheckWarning attr = attr == AttrComp "warn" ["safety_checks"]
-- | Recognise a lambda that applies one binary operator per result,
-- component-wise.  Let @n@ be the number of return types.  For every
-- position @i@, result @i@ must be a variable without certificates,
-- bound by a statement with a single pattern element of primitive type
-- whose expression is @'BinOp' op x y@, where @x@ is parameter @i@ and
-- @y@ is parameter @n + i@.
--
-- On success, returns one @(op, type, x, y)@ tuple per result, in
-- result order; otherwise returns 'Nothing'.
lamIsBinOp :: ASTRep rep => Lambda rep -> Maybe [(BinOp, PrimType, VName, VName)]
lamIsBinOp lam = zipWithM splitStm [0 ..] $ bodyResult body
  where
    body = lambdaBody lam
    params = lambdaParams lam
    n = length $ lambdaReturnType lam

    -- The index is the result's own position.  Searching the result
    -- list for the variable instead would only ever find its first
    -- occurrence, so a variable returned more than once would have its
    -- later occurrences checked against (and reported with) the
    -- parameters of the first occurrence.
    splitStm :: Int -> SubExpRes -> Maybe (BinOp, PrimType, VName, VName)
    splitStm i (SubExpRes cs (Var res)) = do
      guard $ cs == mempty
      Let (Pat [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
        find (([res] ==) . patNames . stmPat) $ stmsToList $ bodyStms body
      xp <- maybeNth i params
      yp <- maybeNth (n + i) params
      guard $ paramName xp == x
      guard $ paramName yp == y
      Prim t <- Just $ patElemType pe
      pure (op, t, paramName xp, paramName yp)
    splitStm _ _ = Nothing