{-# LANGUAGE TypeFamilies, FlexibleContexts, FlexibleInstances, ConstraintKinds #-}
-- | This module provides various simple ways to query and manipulate
-- fundamental Futhark terms, such as types and values.  The intent is to
-- keep "Futhark.Representation.AST.Syntax" simple, and put whatever
-- embellishments we need here.  This is an internal, desugared
-- representation.
module Futhark.Representation.AST.Attributes
  ( module Futhark.Representation.AST.Attributes.Reshape
  , module Futhark.Representation.AST.Attributes.Rearrange
  , module Futhark.Representation.AST.Attributes.Types
  , module Futhark.Representation.AST.Attributes.Constants
  , module Futhark.Representation.AST.Attributes.TypeOf
  , module Futhark.Representation.AST.Attributes.Patterns
  , module Futhark.Representation.AST.Attributes.Names
  , module Futhark.Representation.AST.RetType

  -- * Built-in functions
  , isBuiltInFunction
  , builtInFunctions

  -- * Extra tools
  , asBasicOp
  , safeExp
  , subExpVars
  , subExpVar
  , shapeVars
  , commutativeLambda
  , entryPointSize
  , defAux
  , stmCerts
  , certify
  , expExtTypesFromPattern

  , ASTConstraints
  , IsOp (..)
  , Attributes (..)
  )
  where

import Data.List (find)
import Data.Maybe (mapMaybe, isJust)
import qualified Data.Map.Strict as M

import Futhark.Representation.AST.Attributes.Reshape
import Futhark.Representation.AST.Attributes.Rearrange
import Futhark.Representation.AST.Attributes.Types
import Futhark.Representation.AST.Attributes.Constants
import Futhark.Representation.AST.Attributes.Patterns
import Futhark.Representation.AST.Attributes.Names
import Futhark.Representation.AST.Attributes.TypeOf
import Futhark.Representation.AST.RetType
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Pretty
import Futhark.Transform.Rename (Rename, Renameable)
import Futhark.Transform.Substitute (Substitute, Substitutable)
import Futhark.Util.Pretty

-- | @isBuiltInFunction k@ is 'True' if @k@ is a key of
-- 'builtInFunctions'.
isBuiltInFunction :: Name -> Bool
isBuiltInFunction fnm = M.member fnm builtInFunctions

-- | A map of all built-in functions and their types.  Every entry of
-- 'primFuns' is included, keyed by its name (converted to a 'Name')
-- and mapped to a pair of its return type and its parameter types.
builtInFunctions :: M.Map Name (PrimType,[PrimType])
builtInFunctions =
  M.fromList [ (nameFromString fname, (ret, paramts))
             | (fname, (paramts, ret, _)) <- M.toList primFuns ]

-- | Project out the 'BasicOp' of an expression: @Just op@ when the
-- expression is @BasicOp op@, and 'Nothing' for every other kind of
-- expression.
asBasicOp :: Exp lore -> Maybe BasicOp
asBasicOp e = case e of
  BasicOp op -> Just op
  _          -> Nothing

-- | An expression is safe if it is always well-defined (assuming that
-- any required certificates have been checked) in any context.  For
-- example, array indexing is not safe, as the index may be out of
-- bounds.  On the other hand, adding two numbers cannot fail.
--
-- Loops and branches are safe when every statement in their bodies is
-- safe, function applications are safe only when they call a built-in
-- function, and lore-specific operations defer to 'safeOp'.
safeExp :: IsOp (Op lore) => Exp lore -> Bool
safeExp e = case e of
  BasicOp op             -> safeBasicOp op
  DoLoop _ _ _ body      -> safeBody body
  Apply fname _ _ _      -> isBuiltInFunction fname
  If _ tbranch fbranch _ -> safeBody tbranch && safeBody fbranch
  Op op                  -> safeOp op
  where
    -- Binary operators are checked individually; the listed
    -- operations are always safe; everything else (e.g. indexing) is
    -- conservatively unsafe.
    safeBasicOp (BinOp bop _ divisor) = safeBinOp bop divisor
    safeBasicOp ArrayLit{}  = True
    safeBasicOp SubExp{}    = True
    safeBasicOp UnOp{}      = True
    safeBasicOp CmpOp{}     = True
    safeBasicOp ConvOp{}    = True
    safeBasicOp Scratch{}   = True
    safeBasicOp Concat{}    = True
    safeBasicOp Reshape{}   = True
    safeBasicOp Rearrange{} = True
    safeBasicOp Manifest{}  = True
    safeBasicOp Iota{}      = True
    safeBasicOp Replicate{} = True
    safeBasicOp Copy{}      = True
    safeBasicOp _           = False

    -- Division-like operators are only safe with a constant divisor
    -- that is not zero; 'Pow' only with a constant exponent that is
    -- not negative.  Any other binary operator is safe.
    safeBinOp SDiv{}  y = nonZeroConst y
    safeBinOp UDiv{}  y = nonZeroConst y
    safeBinOp SMod{}  y = nonZeroConst y
    safeBinOp UMod{}  y = nonZeroConst y
    safeBinOp SQuot{} y = nonZeroConst y
    safeBinOp SRem{}  y = nonZeroConst y
    safeBinOp Pow{}   y = nonNegativeConst y
    safeBinOp _       _ = True

    -- A variable operand is unknown, hence treated as unsafe.
    nonZeroConst (Constant v) = not (zeroIsh v)
    nonZeroConst _            = False

    nonNegativeConst (Constant v) = not (negativeIsh v)
    nonNegativeConst _            = False

-- A body is safe when the expression of every one of its statements
-- is safe according to 'safeExp'.
safeBody :: IsOp (Op lore) => Body lore -> Bool
safeBody body = all (safeExp . stmExp) (bodyStms body)

-- | Collect the names of all 'Var' subexpressions, in order, skipping
-- constants.  The result may contain duplicates.
subExpVars :: [SubExp] -> [VName]
subExpVars ses = mapMaybe subExpVar ses

-- | The variable name of a 'Var' subexpression, or 'Nothing' for a
-- 'Constant'.
subExpVar :: SubExp -> Maybe VName
subExpVar se = case se of
  Var v      -> Just v
  Constant{} -> Nothing

-- | The names of those dimensions of a shape that are variables
-- rather than constants.  The result may contain duplicates.
shapeVars :: Shape -> [VName]
shapeVars shape = subExpVars (shapeDims shape)

-- | Does the given lambda represent a known commutative function?
-- Based on pattern matching and checking whether the lambda
-- represents a known arithmetic operator; don't expect anything
-- clever here.
--
-- Concretely, all of the following must hold:
--
-- * The lambda has an even number @2*n@ of parameters.
-- * Its body has exactly @n@ results.
-- * Each result @i@ is a variable bound by a statement of the form
--   @Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y)))@.
-- * In that statement, @op@ satisfies 'commutativeBinOp'.
-- * @x@ and @y@ are parameter @i@ and parameter @n+i@, in either order.
commutativeLambda :: Lambda lore -> Bool
commutativeLambda lam =
  evenArity &&
  n2 == length results &&
  all okComponent (zip3 xps yps results)
  where params = lambdaParams lam
        body = lambdaBody lam
        results = bodyResult body
        n2 = length params `div` 2
        evenArity = n2 * 2 == length params
        -- First half of the parameters pairs up with the second half.
        (xps, yps) = splitAt n2 params

        -- Some statement of the body must compute this result component.
        okComponent c = isJust $ find (okBinOp c) $ bodyStms body

        okBinOp (xp, yp, Var r)
                (Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y)))) =
          patElemName pe == r &&
          commutativeBinOp op &&
          ((x == paramName xp && y == paramName yp) ||
           (y == paramName xp && x == paramName yp))
        okBinOp _ _ = False

-- | How many value parameters are accepted by this entry point?  This
-- is used to determine which of the function parameters correspond to
-- the parameters of the original function (they must all come at the
-- end).  An opaque type carries its own count; unsigned and direct
-- types always account for exactly one parameter.
entryPointSize :: EntryPointType -> Int
entryPointSize ept = case ept of
  TypeOpaque _ x -> x
  TypeUnsigned   -> 1
  TypeDirect     -> 1

-- | Wrap an attribute in a 'StmAux' whose 'Certificates' are empty
-- ('mempty').
defAux :: attr -> StmAux attr
defAux attr = StmAux mempty attr

-- | The certificates associated with a statement, as stored in its
-- 'StmAux'.
stmCerts :: Stm lore -> Certificates
stmCerts stm = stmAuxCerts $ stmAux stm

-- | Add certificates to a statement.  The new certificates are
-- appended (with '<>') after the ones the statement already carries;
-- the pattern, expression attribute and expression are unchanged.
certify :: Certificates -> Stm lore -> Stm lore
certify extra (Let pat (StmAux existing attr) e) = Let pat aux' e
  where aux' = StmAux (existing <> extra) attr

-- | A handy shorthand for the properties that we usually want of the
-- things we stuff into ASTs: comparison, printing, renaming,
-- substitution and free-variable computation.
type ASTConstraints a =
  ( Eq a, Ord a, Show a
  , Rename a, Substitute a
  , FreeIn a, Pretty a
  )

-- | A type class for operations, i.e. the types that can occupy the
-- 'Op' constructor of an expression.
class (ASTConstraints op, TypedOp op) => IsOp op where
  -- | Like 'safeExp', but for arbitrary ops.
  safeOp :: op -> Bool

  -- | Should we try to hoist this out of branches?
  cheapOp :: op -> Bool

-- The trivial operation type: @()@ is both cheap and safe.
instance IsOp () where
  cheapOp () = True
  safeOp () = True

-- | Lore-specific attributes; also means the lore supports some basic
-- facilities (pretty-printing, renaming, substitution, free-variable
-- computation for all annotations, and an 'IsOp' instance for its
-- operations).
class ( Annotations lore
      , PrettyLore lore
      , Renameable lore
      , Substitutable lore
      , FreeAttr (ExpAttr lore)
      , FreeIn (LetAttr lore)
      , FreeAttr (BodyAttr lore)
      , FreeIn (FParamAttr lore)
      , FreeIn (LParamAttr lore)
      , FreeIn (RetType lore)
      , FreeIn (BranchType lore)
      , IsOp (Op lore)
      ) => Attributes lore where
  -- | Given a pattern, construct the type of a body that would match
  -- it.  An implementation for many lores would be
  -- 'expExtTypesFromPattern'.
  expTypesFromPattern :: (HasScope lore m, Monad m) =>
                         Pattern lore -> m [BranchType lore]

-- | Construct the type of an expression that would match the pattern.
-- The types of the pattern's value elements are lifted to 'ExtType'
-- with 'staticShapes'.  Then 'existentialiseExtTypes' is applied with
-- the names of the pattern's context elements; presumably this makes
-- dimensions that refer to those names existential (see that
-- function).
expExtTypesFromPattern :: Typed attr => PatternT attr -> [ExtType]
expExtTypesFromPattern pat =
  existentialiseExtTypes ctxNames valueTypes
  where ctxNames = patternContextNames pat
        valueTypes = staticShapes $ map patElemType $ patternValueElements pat