{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.IR.Prop
( module Futhark.IR.Prop.Reshape,
module Futhark.IR.Prop.Rearrange,
module Futhark.IR.Prop.Types,
module Futhark.IR.Prop.Constants,
module Futhark.IR.Prop.TypeOf,
module Futhark.IR.Prop.Patterns,
module Futhark.IR.Prop.Names,
module Futhark.IR.RetType,
isBuiltInFunction,
builtInFunctions,
asBasicOp,
safeExp,
subExpVars,
subExpVar,
commutativeLambda,
entryPointSize,
defAux,
stmCerts,
certify,
expExtTypesFromPat,
attrsForAssert,
lamIsBinOp,
ASTConstraints,
IsOp (..),
ASTRep (..),
)
where
import Control.Monad
import Data.List (elemIndex, find)
import qualified Data.Map.Strict as M
import Data.Maybe (isJust, mapMaybe)
import qualified Data.Set as S
import Futhark.IR.Pretty
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Names
import Futhark.IR.Prop.Patterns
import Futhark.IR.Prop.Rearrange
import Futhark.IR.Prop.Reshape
import Futhark.IR.Prop.TypeOf
import Futhark.IR.Prop.Types
import Futhark.IR.RetType
import Futhark.IR.Syntax
import Futhark.Transform.Rename (Rename, Renameable)
import Futhark.Transform.Substitute (Substitutable, Substitute)
import Futhark.Util (maybeNth)
import Futhark.Util.Pretty
-- | @isBuiltInFunction fnm@ is 'True' exactly when @fnm@ is a key of
-- 'builtInFunctions'.
isBuiltInFunction :: Name -> Bool
isBuiltInFunction fnm = fnm `M.member` builtInFunctions
-- | A map from the name of every entry in 'primFuns' to its type,
-- given as @(return type, parameter types)@.
builtInFunctions :: M.Map Name (PrimType, [PrimType])
builtInFunctions = M.fromList $ map namify $ M.toList primFuns
  where
    -- 'primFuns' stores (parameter types, return type, third component);
    -- here the key is turned into a 'Name', the two types are swapped,
    -- and the third component is discarded.
    namify (k, (paramts, ret, _)) = (nameFromString k, (ret, paramts))
-- | If the expression is built with the 'BasicOp' constructor, return
-- the wrapped operation; otherwise return 'Nothing'.
asBasicOp :: Exp rep -> Maybe BasicOp
asBasicOp (BasicOp op) = Just op
asBasicOp _ = Nothing
-- | Conservative syntactic check of whether an expression is "safe".
--
-- The rules implemented here are:
--
-- * A 'BasicOp' is judged by the local @safeBasicOp@ (see below).
--
-- * A 'DoLoop' is safe if its body is safe according to 'safeBody'.
--
-- * An 'Apply' is safe if the applied function is a built-in function
--   (see 'isBuiltInFunction').
--
-- * An 'If' is safe if both branches are safe according to 'safeBody'.
--
-- * 'WithAcc' is always considered safe.
--
-- * An 'Op' defers to 'safeOp' of the representation's operation type.
safeExp :: IsOp (Op rep) => Exp rep -> Bool
safeExp (BasicOp op) = safeBasicOp op
  where
    -- Division-like operators explicitly tagged 'Safe' are always safe.
    safeBasicOp (BinOp (SDiv _ Safe) _ _) = True
    safeBasicOp (BinOp (SDivUp _ Safe) _ _) = True
    safeBasicOp (BinOp (SQuot _ Safe) _ _) = True
    safeBasicOp (BinOp (UDiv _ Safe) _ _) = True
    safeBasicOp (BinOp (UDivUp _ Safe) _ _) = True
    safeBasicOp (BinOp (SMod _ Safe) _ _) = True
    safeBasicOp (BinOp (SRem _ Safe) _ _) = True
    safeBasicOp (BinOp (UMod _ Safe) _ _) = True
    -- Otherwise a division-like operator is safe only when its second
    -- operand is a constant that is not zero-ish.  Clause order matters:
    -- the constant case must precede the catch-all for each operator.
    safeBasicOp (BinOp SDiv {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SDiv {} _ _) = False
    safeBasicOp (BinOp SDivUp {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SDivUp {} _ _) = False
    safeBasicOp (BinOp UDiv {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp UDiv {} _ _) = False
    safeBasicOp (BinOp UDivUp {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp UDivUp {} _ _) = False
    safeBasicOp (BinOp SMod {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SMod {} _ _) = False
    safeBasicOp (BinOp UMod {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp UMod {} _ _) = False
    safeBasicOp (BinOp SQuot {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SQuot {} _ _) = False
    safeBasicOp (BinOp SRem {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SRem {} _ _) = False
    -- 'Pow' is safe only with a constant second operand that is not
    -- negative-ish.
    safeBasicOp (BinOp Pow {} _ (Constant y)) = not $ negativeIsh y
    safeBasicOp (BinOp Pow {} _ _) = False
    -- The remaining operations listed here are unconditionally safe;
    -- this includes every 'BinOp' not matched above.
    safeBasicOp ArrayLit {} = True
    safeBasicOp BinOp {} = True
    safeBasicOp SubExp {} = True
    safeBasicOp UnOp {} = True
    safeBasicOp CmpOp {} = True
    safeBasicOp ConvOp {} = True
    safeBasicOp Scratch {} = True
    safeBasicOp Concat {} = True
    safeBasicOp Reshape {} = True
    safeBasicOp Rearrange {} = True
    safeBasicOp Manifest {} = True
    safeBasicOp Iota {} = True
    safeBasicOp Replicate {} = True
    safeBasicOp Copy {} = True
    -- Anything not explicitly listed is conservatively unsafe.
    safeBasicOp _ = False
safeExp (DoLoop _ _ body) = safeBody body
safeExp (Apply fname _ _ _) =
  isBuiltInFunction fname
safeExp (If _ tbranch fbranch _) =
  safeBody tbranch && safeBody fbranch
safeExp WithAcc {} = True
safeExp (Op op) = safeOp op
-- | A body is safe when the expression of every one of its statements
-- is safe according to 'safeExp'.  The body result is not inspected.
safeBody :: IsOp (Op rep) => Body rep -> Bool
safeBody = all (safeExp . stmExp) . bodyStms
-- | Return the variable names among the given subexpressions, in
-- order, dropping every subexpression for which 'subExpVar' yields
-- 'Nothing'.
subExpVars :: [SubExp] -> [VName]
subExpVars = mapMaybe subExpVar
-- | If the 'SubExp' is a 'Var', return the variable name; a
-- 'Constant' gives 'Nothing'.
subExpVar :: SubExp -> Maybe VName
subExpVar (Var v) = Just v
subExpVar Constant {} = Nothing
-- | Syntactic check of whether a lambda is a (tuple of) commutative
-- binary operator(s).
--
-- The lambda parameters are split into a first half @xps@ and a second
-- half @yps@.  The result is 'True' when
--
-- * the number of parameters is even,
--
-- * the body has exactly half as many results as there are
--   parameters, and
--
-- * for every triple of @(x parameter, y parameter, result)@ the body
--   contains a statement that binds a single pattern element, named
--   like the (variable) result, to a 'BinOp' for which
--   'commutativeBinOp' holds and whose two operands are exactly the
--   two parameters, in either order.
--
-- Any other shape yields 'False'.
commutativeLambda :: Lambda rep -> Bool
commutativeLambda lam =
  let body = lambdaBody lam
      n2 = length (lambdaParams lam) `div` 2
      (xps, yps) = splitAt n2 (lambdaParams lam)

      -- Is there some statement in the body implementing this component?
      okComponent c = isJust $ find (okBinOp c) $ bodyStms body

      -- Does this statement compute the given result as a commutative
      -- binary operation on the two given parameters?
      okBinOp
        (xp, yp, SubExpRes _ (Var r))
        (Let (Pat [pe]) _ (BasicOp (BinOp op (Var x) (Var y)))) =
          patElemName pe == r
            && commutativeBinOp op
            && ( (x == paramName xp && y == paramName yp)
                   || (y == paramName xp && x == paramName yp)
               )
      okBinOp _ _ = False
   in n2 * 2 == length (lambdaParams lam)
        && n2 == length (bodyResult body)
        && all okComponent (zip3 xps yps $ bodyResult body)
-- | The size associated with an entry point type: the 'Int' stored in
-- a 'TypeOpaque', and @1@ for 'TypeUnsigned' and 'TypeDirect'.
--
-- NOTE(review): presumably this is the number of underlying values the
-- entry point type corresponds to -- confirm against 'EntryPointType'.
entryPointSize :: EntryPointType -> Int
entryPointSize (TypeOpaque _ _ x) = x
entryPointSize (TypeUnsigned _) = 1
entryPointSize (TypeDirect _) = 1
-- | A 'StmAux' with empty ('mempty') certificates and attributes,
-- carrying the given decoration.
defAux :: dec -> StmAux dec
defAux = StmAux mempty mempty
-- | The certificates stored in the auxiliary information of a
-- statement.
stmCerts :: Stm rep -> Certs
stmCerts = stmAuxCerts . stmAux
-- | Add certificates to a statement.  The new certificates @cs1@ are
-- appended after the certificates already present; the pattern,
-- attributes, decoration and expression are left untouched.
certify :: Certs -> Stm rep -> Stm rep
certify cs1 (Let pat (StmAux cs2 attrs dec) e) =
  Let pat (StmAux (cs2 <> cs1) attrs dec) e
-- | A constraint synonym bundling the type classes 'Eq', 'Ord', 'Show',
-- 'Rename', 'Substitute', 'FreeIn' and 'Pretty' for a single type.
type ASTConstraints a =
  (Eq a, Ord a, Show a, Rename a, Substitute a, FreeIn a, Pretty a)
-- | A type class for operations that can be embedded in expressions
-- via the 'Op' constructor.
class (ASTConstraints op, TypedOp op) => IsOp op where
  -- | Is the operation safe?  This is what 'safeExp' consults for an
  -- 'Op' expression.
  safeOp :: op -> Bool

  -- | Is the operation cheap?
  --
  -- NOTE(review): the precise notion of "cheap" is not defined in this
  -- module -- confirm against the users of 'cheapOp'.
  cheapOp :: op -> Bool
-- | The unit operation is both safe and cheap.
instance IsOp () where
  safeOp () = True
  cheapOp () = True
-- | The class of representations usable for generic AST processing.
-- An instance must satisfy all of the superclass constraints listed
-- in the context (pretty-printing, renaming, substitution and
-- free-variable support for the various decorations, and 'IsOp' for
-- the representation's operation type).
class
  ( RepTypes rep,
    PrettyRep rep,
    Renameable rep,
    Substitutable rep,
    FreeDec (ExpDec rep),
    FreeIn (LetDec rep),
    FreeDec (BodyDec rep),
    FreeIn (FParamInfo rep),
    FreeIn (LParamInfo rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    IsOp (Op rep)
  ) =>
  ASTRep rep
  where
  -- | Given a pattern, compute a list of branch types for it, in a
  -- monad that provides access to the scope.
  --
  -- NOTE(review): 'expExtTypesFromPat' looks like the intended basis
  -- for typical implementations -- confirm against the instances.
  expTypesFromPat ::
    (HasScope rep m, Monad m) =>
    Pat (LetDec rep) ->
    m [BranchType rep]
-- | Construct existential types from a pattern: take the type of every
-- pattern element, lift these to 'ExtType's with 'staticShapes', and
-- then existentialise them with respect to the names bound by the
-- pattern (via 'existentialiseExtTypes').
expExtTypesFromPat :: Typed dec => Pat dec -> [ExtType]
expExtTypesFromPat pat =
  existentialiseExtTypes (patNames pat) $
    staticShapes $ map patElemType $ patElems pat
-- | Restrict a set of attributes to those relevant for assertions.
-- Only the attribute @warn(safety_checks)@ is retained; every other
-- attribute is dropped.
attrsForAssert :: Attrs -> Attrs
attrsForAssert (Attrs attrs) =
  Attrs $ S.filter attrForAssert attrs
  where
    -- The string literals rely on OverloadedStrings.
    attrForAssert = (== AttrComp "warn" ["safety_checks"])
-- | Try to view a lambda as a tuple of independent binary operators.
--
-- On success there is one @(operator, primitive type, x, y)@ entry per
-- body result, where @x@ and @y@ are the names of the lambda
-- parameters the operator is applied to.  The result is 'Nothing' if
-- any body result fails one of these conditions:
--
-- * it is a variable with no certificates attached;
--
-- * the body contains a statement whose pattern binds exactly that
--   variable to a 'BinOp' applied to two variables;
--
-- * with @i@ the first position of the variable among the body results
--   and @n@ the number of lambda return types, the first operand is
--   parameter @i@ and the second operand is parameter @n + i@;
--
-- * the bound pattern element has a primitive type.
lamIsBinOp :: ASTRep rep => Lambda rep -> Maybe [(BinOp, PrimType, VName, VName)]
lamIsBinOp lam = mapM splitStm $ bodyResult $ lambdaBody lam
  where
    n = length $ lambdaReturnType lam
    splitStm (SubExpRes cs (Var res)) = do
      guard $ cs == mempty
      -- A failed pattern match in this do-block yields 'Nothing'.
      Let (Pat [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
        find (([res] ==) . patNames . stmPat) $
          stmsToList $ bodyStms $ lambdaBody lam
      i <- Var res `elemIndex` map resSubExp (bodyResult (lambdaBody lam))
      xp <- maybeNth i $ lambdaParams lam
      yp <- maybeNth (n + i) $ lambdaParams lam
      guard $ paramName xp == x
      guard $ paramName yp == y
      Prim t <- Just $ patElemType pe
      return (op, t, paramName xp, paramName yp)
    splitStm _ = Nothing