{-# LANGUAGE TypeFamilies, FlexibleContexts, FlexibleInstances, ConstraintKinds #-} -- | This module provides various simple ways to query and manipulate -- fundamental Futhark terms, such as types and values. The intent is to -- keep "Futhark.Representation.AST.Syntax" simple, and put whatever -- embellishments we need here. This is an internal, desugared -- representation. module Futhark.Representation.AST.Attributes ( module Futhark.Representation.AST.Attributes.Reshape , module Futhark.Representation.AST.Attributes.Rearrange , module Futhark.Representation.AST.Attributes.Types , module Futhark.Representation.AST.Attributes.Constants , module Futhark.Representation.AST.Attributes.TypeOf , module Futhark.Representation.AST.Attributes.Patterns , module Futhark.Representation.AST.Attributes.Names , module Futhark.Representation.AST.RetType -- * Built-in functions , isBuiltInFunction , builtInFunctions -- * Extra tools , funDefByName , asBasicOp , safeExp , subExpVars , subExpVar , shapeVars , commutativeLambda , entryPointSize , defAux , stmCerts , certify , expExtTypesFromPattern , patternFromParams , IsOp (..) , Attributes (..) ) where import Data.List import Data.Maybe (mapMaybe, isJust) import Data.Monoid ((<>)) import qualified Data.Map.Strict as M import Futhark.Representation.AST.Attributes.Reshape import Futhark.Representation.AST.Attributes.Rearrange import Futhark.Representation.AST.Attributes.Types import Futhark.Representation.AST.Attributes.Constants import Futhark.Representation.AST.Attributes.Patterns import Futhark.Representation.AST.Attributes.Names import Futhark.Representation.AST.Attributes.TypeOf import Futhark.Representation.AST.RetType import Futhark.Representation.AST.Syntax import Futhark.Representation.AST.Pretty import Futhark.Transform.Rename (Rename, Renameable) import Futhark.Transform.Substitute (Substitute, Substitutable) import Futhark.Util.Pretty -- | @isBuiltInFunction k@ is 'True' if @k@ is an element of 'builtInFunctions'. 
isBuiltInFunction :: Name -> Bool
isBuiltInFunction = (`M.member` builtInFunctions)

-- | Every built-in function, keyed by its name and mapped to its
-- return type paired with the list of its parameter types.  The
-- table is derived from 'primFuns'.
builtInFunctions :: M.Map Name (PrimType,[PrimType])
builtInFunctions =
  M.fromList [ (nameFromString fname, (ret, paramts))
             | (fname, (paramts, ret, _)) <- M.toList primFuns ]

-- | Look up the function definition with the given name among the
-- functions of the program.  Yields 'Nothing' if no definition has
-- that name.
funDefByName :: Name -> Prog lore -> Maybe (FunDef lore)
funDefByName fname prog =
  find (\fundef -> fname == funDefName fundef) $ progFunctions prog

-- | Extract the 'BasicOp' from an expression that is one; any other
-- kind of expression yields 'Nothing'.
asBasicOp :: Exp lore -> Maybe (BasicOp lore)
asBasicOp e = case e of
  BasicOp op -> Just op
  _          -> Nothing

-- | An expression is safe if it is always well-defined (assuming that
-- any required certificates have been checked) in any context.  For
-- example, array indexing is not safe, as the index may be out of
-- bounds.  On the other hand, adding two numbers cannot fail.
safeExp :: IsOp (Op lore) => Exp lore -> Bool
safeExp (BasicOp op) = safeBasicOp op
  where
    -- Division-like operators are accepted only when the second
    -- operand is a constant that is not zero-ish.
    safeBasicOp (BinOp SDiv{} _ y) = nonZeroConstant y
    safeBasicOp (BinOp UDiv{} _ y) = nonZeroConstant y
    safeBasicOp (BinOp SMod{} _ y) = nonZeroConstant y
    safeBasicOp (BinOp UMod{} _ y) = nonZeroConstant y
    safeBasicOp (BinOp SQuot{} _ y) = nonZeroConstant y
    safeBasicOp (BinOp SRem{} _ y) = nonZeroConstant y
    -- Exponentiation is accepted only when the exponent is a
    -- constant that is not negative-ish.
    safeBasicOp (BinOp Pow{} _ y) = nonNegativeConstant y
    -- Every remaining binary operator, and the operations listed
    -- below, are considered safe.
    safeBasicOp ArrayLit{} = True
    safeBasicOp BinOp{} = True
    safeBasicOp SubExp{} = True
    safeBasicOp UnOp{} = True
    safeBasicOp CmpOp{} = True
    safeBasicOp ConvOp{} = True
    safeBasicOp Scratch{} = True
    safeBasicOp Concat{} = True
    safeBasicOp Reshape{} = True
    safeBasicOp Manifest{} = True
    safeBasicOp Iota{} = True
    safeBasicOp Replicate{} = True
    safeBasicOp Copy{} = True
    -- Anything not explicitly listed is conservatively unsafe.
    safeBasicOp _ = False

    -- A variable operand is never accepted, as its value is unknown.
    nonZeroConstant (Constant y) = not $ zeroIsh y
    nonZeroConstant _ = False

    nonNegativeConstant (Constant y) = not $ negativeIsh y
    nonNegativeConstant _ = False

safeExp (DoLoop _ _ _ body) = safeBody body
safeExp (Apply fname _ _ _) = isBuiltInFunction fname
safeExp (If _ tbranch fbranch _) = safeBody tbranch && safeBody fbranch
safeExp (Op op) = safeOp op

-- A body is safe when the expression of every statement in it is
-- safe according to 'safeExp'.
safeBody :: IsOp (Op lore) => Body lore -> Bool
safeBody body = all (safeExp . stmExp) (bodyStms body)

-- | Collect the names of all 'Var' subexpressions in the list,
-- skipping constants.  A name occurring several times in the input
-- occurs several times in the result.
subExpVars :: [SubExp] -> [VName]
subExpVars = mapMaybe subExpVar

-- | The variable name of a 'Var' subexpression; 'Nothing' for a
-- 'Constant'.
subExpVar :: SubExp -> Maybe VName
subExpVar se = case se of
  Var v      -> Just v
  Constant{} -> Nothing

-- | The names of those dimensions of the shape that are variables.
-- May contain duplicates.
shapeVars :: Shape -> [VName]
shapeVars shape = subExpVars (shapeDims shape)

-- | Does the given lambda represent a known commutative function?
-- This is a purely syntactic check: the parameters are split into
-- two halves, and every result must be a variable bound by a
-- statement applying a commutative binary operator directly to the
-- corresponding pair of parameters (in either order).  Don't expect
-- anything clever here.
commutativeLambda :: Lambda lore -> Bool
commutativeLambda lam =
  2 * half == length params &&
  half == length results &&
  all okComponent (zip3 xparams yparams results)
  where
    params = lambdaParams lam
    body = lambdaBody lam
    results = bodyResult body
    half = length params `div` 2
    (xparams, yparams) = splitAt half params

    -- A component is fine if some statement of the body computes it
    -- in the expected form.
    okComponent c = any (okBinOp c) $ bodyStms body

    okBinOp (xp,yp,Var r) (Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y)))) =
      patElemName pe == r &&
      commutativeBinOp op &&
      ((x == paramName xp && y == paramName yp) ||
       (y == paramName xp && x == paramName yp))
    okBinOp _ _ = False

-- | How many value parameters are accepted by this entry point?  This
-- is used to determine which of the function parameters correspond to
-- the parameters of the original function (they must all come at the
-- end).
entryPointSize :: EntryPointType -> Int
entryPointSize ept = case ept of
  TypeOpaque _ x -> x
  TypeUnsigned   -> 1
  TypeDirect     -> 1

-- | Construct a 'StmAux' carrying the given attribute and no
-- certificates ('mempty').
defAux :: attr -> StmAux attr
defAux = StmAux mempty

-- | Project out the certificates stored in the auxiliary
-- information of a statement.
stmCerts :: Stm lore -> Certificates
stmCerts stm = stmAuxCerts (stmAux stm)

-- | Attach additional certificates to a statement.  The new
-- certificates are appended after the ones the statement already
-- carries; everything else is left untouched.
certify :: Certificates -> Stm lore -> Stm lore
certify cs1 (Let pat (StmAux cs2 attr) e) =
  Let pat (StmAux allCerts attr) e
  where allCerts = cs2 <> cs1

-- | A type class for operations.
class ( Eq op
      , Ord op
      , Show op
      , TypedOp op
      , Rename op
      , Substitute op
      , FreeIn op
      , Pretty op
      ) => IsOp op where
  -- | The counterpart of 'safeExp' for arbitrary ops.
  safeOp :: op -> Bool
  -- | Should we try to hoist this out of branches?
  cheapOp :: op -> Bool

-- The unit op is reported as both safe and cheap.
instance IsOp () where
  safeOp () = True
  cheapOp () = True

-- | Lore-specific attributes; also means the lore supports some basic
-- facilities.
class ( Annotations lore
      , PrettyLore lore
      , Renameable lore
      , Substitutable lore
      , FreeAttr (ExpAttr lore)
      , FreeIn (LetAttr lore)
      , FreeAttr (BodyAttr lore)
      , FreeIn (FParamAttr lore)
      , FreeIn (LParamAttr lore)
      , FreeIn (RetType lore)
      , FreeIn (BranchType lore)
      , IsOp (Op lore)
      ) => Attributes lore where
  -- | Given a pattern, construct the type of a body that would match
  -- it.  An implementation for many lores would be
  -- 'expExtTypesFromPattern'.
  expTypesFromPattern :: (HasScope lore m, Monad m) =>
                         Pattern lore -> m [BranchType lore]

-- | Construct the type of an expression that would match the
-- pattern: the types of the value elements, given static shapes and
-- then existentialised over the context names of the pattern.
expExtTypesFromPattern :: Typed attr => PatternT attr -> [ExtType]
expExtTypesFromPattern pat =
  existentialiseExtTypes ctxNames valueTypes
  where ctxNames = patternContextNames pat
        valueTypes = staticShapes $ map patElemType $ patternValueElements pat

-- | Create a pattern corresponding to some parameters.  The pattern
-- has no context elements, and one value element per parameter,
-- carrying the name and attribute of that parameter.
patternFromParams :: [Param attr] -> PatternT attr
patternFromParams params =
  Pattern [] [ PatElem (paramName p) (paramAttr p) | p <- params ]