{-# LANGUAGE TypeFamilies #-}

-- | This module provides various simple ways to query and manipulate
-- fundamental Futhark terms, such as types and values.  The intent is
-- to keep "Futhark.IR.Syntax" simple, and put whatever embellishments
-- we need here.  This is an internal, desugared representation.
module Futhark.IR.Prop
  ( module Futhark.IR.Prop.Reshape,
    module Futhark.IR.Prop.Rearrange,
    module Futhark.IR.Prop.Types,
    module Futhark.IR.Prop.Constants,
    module Futhark.IR.Prop.TypeOf,
    module Futhark.IR.Prop.Patterns,
    module Futhark.IR.Prop.Names,
    module Futhark.IR.RetType,
    module Futhark.IR.Rephrase,

    -- * Built-in functions
    isBuiltInFunction,
    builtInFunctions,

    -- * Extra tools
    asBasicOp,
    safeExp,
    subExpVars,
    subExpVar,
    commutativeLambda,
    defAux,
    stmCerts,
    certify,
    expExtTypesFromPat,
    attrsForAssert,
    lamIsBinOp,
    ASTConstraints,
    IsOp (..),
    ASTRep (..),
  )
where

import Control.Monad
import Data.List (elemIndex, find)
import Data.Map.Strict qualified as M
import Data.Maybe (isJust, mapMaybe)
import Data.Set qualified as S
import Futhark.IR.Pretty
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Names
import Futhark.IR.Prop.Patterns
import Futhark.IR.Prop.Rearrange
import Futhark.IR.Prop.Reshape
import Futhark.IR.Prop.TypeOf
import Futhark.IR.Prop.Types
import Futhark.IR.Rephrase
import Futhark.IR.RetType
import Futhark.IR.Syntax
import Futhark.Transform.Rename (Rename, Renameable)
import Futhark.Transform.Substitute (Substitutable, Substitute)
import Futhark.Util (maybeNth)

-- | @isBuiltInFunction k@ is 'True' if @k@ is an element of 'builtInFunctions'.
isBuiltInFunction :: Name -> Bool
isBuiltInFunction fnm = fnm `M.member` builtInFunctions

-- | A map of all built-in functions and their types.  Each entry maps
-- the function name to its return type paired with the list of
-- parameter types; the table is derived from 'primFuns'.
builtInFunctions :: M.Map Name (PrimType, [PrimType])
builtInFunctions = M.fromList $ map namify $ M.toList primFuns
  where
    -- Convert the string key to a 'Name', swap the order to
    -- (return type, parameter types), and drop the third component
    -- of the 'primFuns' entry.
    namify (k, (paramts, ret, _)) = (nameFromString k, (ret, paramts))

-- | If the expression is a t'BasicOp', return it, otherwise 'Nothing'.
asBasicOp :: Exp rep -> Maybe BasicOp
asBasicOp (BasicOp op) = Just op
asBasicOp _ = Nothing

-- | An expression is safe if it is always well-defined (assuming that
-- any required certificates have been checked) in any context.  For
-- example, array indexing is not safe, as the index may be out of
-- bounds.  On the other hand, adding two numbers cannot fail.
safeExp :: IsOp (Op rep) => Exp rep -> Bool
safeExp (BasicOp op) = safeBasicOp op
  where
    -- The order of the equations matters: the specific cases come
    -- first, and the catch-all @BinOp {}@ case further down only
    -- applies to operators not mentioned above it.

    -- Division-like operators explicitly marked 'Safe' are always
    -- considered safe.
    safeBasicOp (BinOp (SDiv _ Safe) _ _) = True
    safeBasicOp (BinOp (SDivUp _ Safe) _ _) = True
    safeBasicOp (BinOp (SQuot _ Safe) _ _) = True
    safeBasicOp (BinOp (UDiv _ Safe) _ _) = True
    safeBasicOp (BinOp (UDivUp _ Safe) _ _) = True
    safeBasicOp (BinOp (SMod _ Safe) _ _) = True
    safeBasicOp (BinOp (SRem _ Safe) _ _) = True
    safeBasicOp (BinOp (UMod _ Safe) _ _) = True
    -- Otherwise, a division-like operator is safe only when its
    -- second operand is a constant that is not zero-ish.
    safeBasicOp (BinOp SDiv {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SDiv {} _ _) = False
    safeBasicOp (BinOp SDivUp {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SDivUp {} _ _) = False
    safeBasicOp (BinOp UDiv {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp UDiv {} _ _) = False
    safeBasicOp (BinOp UDivUp {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp UDivUp {} _ _) = False
    safeBasicOp (BinOp SMod {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SMod {} _ _) = False
    safeBasicOp (BinOp UMod {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp UMod {} _ _) = False
    safeBasicOp (BinOp SQuot {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SQuot {} _ _) = False
    safeBasicOp (BinOp SRem {} _ (Constant y)) = not $ zeroIsh y
    safeBasicOp (BinOp SRem {} _ _) = False
    -- 'Pow' is safe only when its second operand is a constant that
    -- is not negative-ish.
    safeBasicOp (BinOp Pow {} _ (Constant y)) = not $ negativeIsh y
    safeBasicOp (BinOp Pow {} _ _) = False
    -- The remaining operations listed here are always safe.
    safeBasicOp ArrayLit {} = True
    safeBasicOp BinOp {} = True
    safeBasicOp SubExp {} = True
    safeBasicOp UnOp {} = True
    safeBasicOp CmpOp {} = True
    safeBasicOp ConvOp {} = True
    safeBasicOp Scratch {} = True
    safeBasicOp Concat {} = True
    safeBasicOp Reshape {} = True
    safeBasicOp Rearrange {} = True
    safeBasicOp Manifest {} = True
    safeBasicOp Iota {} = True
    safeBasicOp Replicate {} = True
    safeBasicOp Copy {} = True
    -- Anything not listed above is conservatively treated as unsafe.
    safeBasicOp _ = False
-- A loop is safe if every statement in its body is safe.
safeExp (DoLoop _ _ body) = safeBody body
-- A function call is safe only if it calls a built-in function.
safeExp (Apply fname _ _ _) =
  isBuiltInFunction fname
-- A match is safe if every case body and the default body are safe.
safeExp (Match _ cases def_case _) =
  all (safeBody . caseBody) cases
    && safeBody def_case
safeExp WithAcc {} = True -- Although unlikely to matter.
safeExp (Op op) = safeOp op

-- | A body is safe if the expression of every one of its statements
-- is safe according to 'safeExp'.
safeBody :: IsOp (Op rep) => Body rep -> Bool
safeBody = all (safeExp . stmExp) . bodyStms

-- | Return the variable names used in 'Var' subexpressions.  May contain
-- duplicates.
subExpVars :: [SubExp] -> [VName]
subExpVars = mapMaybe subExpVar

-- | If the t'SubExp' is a 'Var' return the variable name; a
-- 'Constant' yields 'Nothing'.
subExpVar :: SubExp -> Maybe VName
subExpVar (Var v) = Just v
subExpVar Constant {} = Nothing

-- | Does the given lambda represent a known commutative function?
-- Based on pattern matching and checking whether the lambda
-- represents a known arithmetic operator; don't expect anything
-- clever here.
--
-- The parameters are split into two halves of equal length.  The
-- lambda is accepted only if the parameter count is even, there is
-- one result per parameter pair, and every result is a variable bound
-- by a single-pattern statement applying a commutative 'BinOp'
-- directly to the corresponding pair of parameters (in either order).
commutativeLambda :: Lambda rep -> Bool
commutativeLambda lam =
  let body = lambdaBody lam
      n2 = length (lambdaParams lam) `div` 2
      (xps, yps) = splitAt n2 (lambdaParams lam)

      -- A component is acceptable if some statement in the body
      -- computes it with a commutative operator.
      okComponent c = isJust $ find (okBinOp c) $ bodyStms body
      okBinOp
        (xp, yp, SubExpRes _ (Var r))
        (Let (Pat [pe]) _ (BasicOp (BinOp op (Var x) (Var y)))) =
          patElemName pe == r
            && commutativeBinOp op
            && ( (x == paramName xp && y == paramName yp)
                   || (y == paramName xp && x == paramName yp)
               )
      okBinOp _ _ = False
   in n2 * 2 == length (lambdaParams lam)
        && n2 == length (bodyResult body)
        && all okComponent (zip3 xps yps $ bodyResult body)

-- | A 'StmAux' with empty 'Certs' and empty 'Attrs', carrying only
-- the given decoration.
defAux :: dec -> StmAux dec
defAux = StmAux mempty mempty

-- | The certificates associated with a statement.
stmCerts :: Stm rep -> Certs
stmCerts = stmAuxCerts . stmAux

-- | Add certificates to a statement.  The statement's existing
-- certificates are kept, with the new ones combined on the right;
-- the attributes and decoration are left unchanged.
certify :: Certs -> Stm rep -> Stm rep
certify cs1 (Let pat (StmAux cs2 attrs dec) e) =
  Let pat (StmAux (cs2 <> cs1) attrs dec) e

-- | A handy shorthand for properties that we usually want for things
-- we stuff into ASTs: equality, ordering, showing, renaming,
-- substitution, free-variable computation, and pretty-printing.
type ASTConstraints a =
  (Eq a, Ord a, Show a, Rename a, Substitute a, FreeIn a, Pretty a)

-- | A type class for operations.  Besides the methods below, every
-- instance must also satisfy 'ASTConstraints' and 'TypedOp'.
class (ASTConstraints op, TypedOp op) => IsOp op where
  -- | Like 'safeExp', but for arbitrary ops.
  safeOp :: op -> Bool

  -- | Should we try to hoist this out of branches?
  cheapOp :: op -> Bool

-- | The trivial 'NoOp' is both safe and cheap.
instance IsOp (NoOp rep) where
  safeOp NoOp = True
  cheapOp NoOp = True

-- | Representation-specific attributes; also means the rep supports
-- some basic facilities.  The superclass context lists those
-- facilities: pretty-printing, renaming, substitution, free-variable
-- computation for the various decorations, and an 'IsOp' instance
-- for the representation's operations.
class
  ( RepTypes rep,
    PrettyRep rep,
    Renameable rep,
    Substitutable rep,
    FreeDec (ExpDec rep),
    FreeIn (LetDec rep),
    FreeDec (BodyDec rep),
    FreeIn (FParamInfo rep),
    FreeIn (LParamInfo rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    IsOp (Op rep),
    RephraseOp (OpC rep)
  ) =>
  ASTRep rep
  where
  -- | Given a pattern, construct the type of a body that would match
  -- it.  An implementation for many representations would be
  -- 'expExtTypesFromPat'.
  expTypesFromPat ::
    (HasScope rep m, Monad m) =>
    Pat (LetDec rep) ->
    m [BranchType rep]

-- | Construct the type of an expression that would match the pattern.
-- The types of the pattern elements are turned into static-shape
-- 'ExtType's, and the names bound by the pattern itself are then
-- existentialised.
expExtTypesFromPat :: Typed dec => Pat dec -> [ExtType]
expExtTypesFromPat pat =
  existentialiseExtTypes (patNames pat) $
    staticShapes $
      map patElemType $
        patElems pat

-- | Keep only those attributes that are relevant for 'Assert'
-- expressions.  Only the compound attribute @warn(safety_checks)@ is
-- retained; everything else is dropped.
attrsForAssert :: Attrs -> Attrs
attrsForAssert (Attrs attrs) =
  Attrs $ S.filter attrForAssert attrs
  where
    attrForAssert = (== AttrComp "warn" ["safety_checks"])

-- | Horizontally fission a lambda that models a binary operator.
--
-- Returns 'Nothing' unless every result of the lambda is a
-- certificate-free variable bound by a single-pattern statement that
-- applies a 'BinOp' to the @i@th and @(n+i)@th lambda parameters (in
-- that order), where @i@ is the position of the result and @n@ is the
-- number of return types, and the bound value has primitive type.  On
-- success, one tuple of operator, primitive type and the two
-- parameter names is produced per result.
lamIsBinOp :: ASTRep rep => Lambda rep -> Maybe [(BinOp, PrimType, VName, VName)]
lamIsBinOp lam = mapM splitStm $ bodyResult $ lambdaBody lam
  where
    n = length $ lambdaReturnType lam
    splitStm (SubExpRes cs (Var res)) = do
      guard $ cs == mempty
      -- Find the statement that binds exactly this result; a failing
      -- pattern match here yields 'Nothing' in the 'Maybe' monad.
      Let (Pat [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
        find (([res] ==) . patNames . stmPat) $
          stmsToList $
            bodyStms $
              lambdaBody lam
      -- Position of this result among the lambda's results.
      i <- Var res `elemIndex` map resSubExp (bodyResult (lambdaBody lam))
      xp <- maybeNth i $ lambdaParams lam
      yp <- maybeNth (n + i) $ lambdaParams lam
      guard $ paramName xp == x
      guard $ paramName yp == y
      Prim t <- Just $ patElemType pe
      pure (op, t, paramName xp, paramName yp)
    splitStm _ = Nothing