{-# LANGUAGE FlexibleContexts #-} module Futhark.Analysis.ScalExp ( RelOp0(..) , ScalExp(..) , scalExpType , scalExpSize , subExpToScalExp , toScalExp , expandScalExp , LookupVar , module Futhark.Representation.Primitive ) where import Data.List import qualified Data.Set as S import Data.Maybe import Data.Monoid ((<>)) import Futhark.Representation.Primitive hiding (SQuot, SRem, SDiv, SMod, SSignum) import Futhark.Representation.AST hiding (SQuot, SRem, SDiv, SMod, SSignum) import qualified Futhark.Representation.AST as AST import Futhark.Transform.Substitute import Futhark.Transform.Rename import Futhark.Util.Pretty hiding (pretty) ----------------------------------------------------------------- -- BINARY OPERATORS for Numbers -- -- Note that MOD, BAND, XOR, BOR, SHIFTR, SHIFTL not supported -- -- `a SHIFTL/SHIFTR p' can be translated if desired as as -- -- `a * 2^p' or `a / 2^p -- ----------------------------------------------------------------- -- | Relational operators. data RelOp0 = LTH0 | LEQ0 deriving (Eq, Ord, Enum, Bounded, Show) -- | Representation of a scalar expression, which is: -- -- (i) an algebraic expression, e.g., min(a+b, a*b), -- -- (ii) a relational expression: a+b < 5, -- -- (iii) a logical expression: e1 and (not (a+b>5) data ScalExp= Val PrimValue | Id VName PrimType | SNeg ScalExp | SNot ScalExp | SAbs ScalExp | SSignum ScalExp | SPlus ScalExp ScalExp | SMinus ScalExp ScalExp | STimes ScalExp ScalExp | SPow ScalExp ScalExp | SDiv ScalExp ScalExp | SMod ScalExp ScalExp | SQuot ScalExp ScalExp | SRem ScalExp ScalExp | MaxMin Bool [ScalExp] | RelExp RelOp0 ScalExp | SLogAnd ScalExp ScalExp | SLogOr ScalExp ScalExp deriving (Eq, Ord, Show) instance Num ScalExp where 0 + y = y x + 0 = x x + y = SPlus x y x - 0 = x x - y = SMinus x y 0 * _ = 0 _ * 0 = 0 1 * y = y y * 1 = y x * y = STimes x y abs = SAbs signum = SSignum fromInteger = Val . IntValue . Int32Value . fromInteger -- probably not OK negate = SNeg instance Pretty ScalExp where pprPrec _ (Val val) = ppr val pprPrec _ (Id v _) = ppr v pprPrec _ (SNeg e) = text "-" <> pprPrec 9 e pprPrec _ (SNot e) = text "not" <+> pprPrec 9 e pprPrec _ (SAbs e) = text "abs" <+> pprPrec 9 e pprPrec _ (SSignum e) = text "signum" <+> pprPrec 9 e pprPrec prec (SPlus x y) = ppBinOp prec "+" 4 4 x y pprPrec prec (SMinus x y) = ppBinOp prec "-" 4 10 x y pprPrec prec (SPow x y) = ppBinOp prec "^" 6 6 x y pprPrec prec (STimes x y) = ppBinOp prec "*" 5 5 x y pprPrec prec (SDiv x y) = ppBinOp prec "/" 5 10 x y pprPrec prec (SMod x y) = ppBinOp prec "%" 5 10 x y pprPrec prec (SQuot x y) = ppBinOp prec "//" 5 10 x y pprPrec prec (SRem x y) = ppBinOp prec "%%" 5 10 x y pprPrec prec (SLogOr x y) = ppBinOp prec "||" 0 0 x y pprPrec prec (SLogAnd x y) = ppBinOp prec "&&" 1 1 x y pprPrec prec (RelExp LTH0 e) = ppBinOp prec "<" 2 2 e (0::Int) pprPrec prec (RelExp LEQ0 e) = ppBinOp prec "<=" 2 2 e (0::Int) pprPrec _ (MaxMin True es) = text "min" <> parens (commasep $ map ppr es) pprPrec _ (MaxMin False es) = text "max" <> parens (commasep $ map ppr es) ppBinOp :: (Pretty a, Pretty b) => Int -> String -> Int -> Int -> a -> b -> Doc ppBinOp p bop precedence rprecedence x y = parensIf (p > precedence) $ pprPrec precedence x <+/> text bop <+> pprPrec rprecedence y instance Substitute ScalExp where substituteNames subst e = case e of Id v t -> Id (substituteNames subst v) t Val v -> Val v SNeg x -> SNeg $ substituteNames subst x SNot x -> SNot $ substituteNames subst x SAbs x -> SAbs $ substituteNames subst x SSignum x -> SSignum $ substituteNames subst x SPlus x y -> substituteNames subst x `SPlus` substituteNames subst y SMinus x y -> substituteNames subst x `SMinus` substituteNames subst y SPow x y -> substituteNames subst x `SPow` substituteNames subst y STimes x y -> substituteNames subst x `STimes` substituteNames subst y SDiv x y -> substituteNames subst x `SDiv` substituteNames subst y SMod x y -> substituteNames subst x `SMod` substituteNames subst y SQuot x y -> substituteNames subst x `SDiv` substituteNames subst y SRem x y -> substituteNames subst x `SRem` substituteNames subst y MaxMin m es -> MaxMin m $ map (substituteNames subst) es RelExp r x -> RelExp r $ substituteNames subst x SLogAnd x y -> substituteNames subst x `SLogAnd` substituteNames subst y SLogOr x y -> substituteNames subst x `SLogOr` substituteNames subst y instance Rename ScalExp where rename = substituteRename scalExpType :: ScalExp -> PrimType scalExpType (Val v) = primValueType v scalExpType (Id _ t) = t scalExpType (SNeg e) = scalExpType e scalExpType (SNot _) = Bool scalExpType (SAbs e) = scalExpType e scalExpType (SSignum e) = scalExpType e scalExpType (SPlus e _) = scalExpType e scalExpType (SMinus e _) = scalExpType e scalExpType (STimes e _) = scalExpType e scalExpType (SDiv e _) = scalExpType e scalExpType (SMod e _) = scalExpType e scalExpType (SPow e _) = scalExpType e scalExpType (SQuot e _) = scalExpType e scalExpType (SRem e _) = scalExpType e scalExpType (SLogAnd _ _) = Bool scalExpType (SLogOr _ _) = Bool scalExpType (RelExp _ _) = Bool scalExpType (MaxMin _ []) = IntType Int32 -- arbitrary and probably wrong. scalExpType (MaxMin _ (e:_)) = scalExpType e -- | Number of nodes in the scalar expression. scalExpSize :: ScalExp -> Int scalExpSize Val{} = 1 scalExpSize Id{} = 1 scalExpSize (SNeg e) = scalExpSize e scalExpSize (SNot e) = scalExpSize e scalExpSize (SAbs e) = scalExpSize e scalExpSize (SSignum e) = scalExpSize e scalExpSize (SPlus x y) = scalExpSize x + scalExpSize y scalExpSize (SMinus x y) = scalExpSize x + scalExpSize y scalExpSize (STimes x y) = scalExpSize x + scalExpSize y scalExpSize (SDiv x y) = scalExpSize x + scalExpSize y scalExpSize (SMod x y) = scalExpSize x + scalExpSize y scalExpSize (SPow x y) = scalExpSize x + scalExpSize y scalExpSize (SQuot x y) = scalExpSize x + scalExpSize y scalExpSize (SRem x y) = scalExpSize x + scalExpSize y scalExpSize (SLogAnd x y) = scalExpSize x + scalExpSize y scalExpSize (SLogOr x y) = scalExpSize x + scalExpSize y scalExpSize (RelExp _ x) = scalExpSize x scalExpSize (MaxMin _ []) = 0 scalExpSize (MaxMin _ es) = sum $ map scalExpSize es -- | A function that checks whether a variable name corresponds to a -- scalar expression. type LookupVar = VName -> Maybe ScalExp -- | Non-recursively convert a subexpression to a 'ScalExp'. The -- (scalar) type of the subexpression must be given in advance. subExpToScalExp :: SubExp -> PrimType -> ScalExp subExpToScalExp (Var v) t = Id v t subExpToScalExp (Constant val) _ = Val val toScalExp :: (HasScope t f, Monad f) => LookupVar -> Exp lore -> f (Maybe ScalExp) toScalExp look (BasicOp (SubExp (Var v))) | Just se <- look v = return $ Just se | otherwise = do t <- lookupType v case t of Prim bt | typeIsOK bt -> return $ Just $ Id v bt _ -> return Nothing toScalExp _ (BasicOp (SubExp (Constant val))) | typeIsOK $ primValueType val = return $ Just $ Val val toScalExp look (BasicOp (CmpOp (CmpSlt _) x y)) = Just . RelExp LTH0 <$> (sminus <$> subExpToScalExp' look x <*> subExpToScalExp' look y) toScalExp look (BasicOp (CmpOp (CmpSle _) x y)) = Just . RelExp LEQ0 <$> (sminus <$> subExpToScalExp' look x <*> subExpToScalExp' look y) toScalExp look (BasicOp (CmpOp (CmpEq t) x y)) | typeIsOK t = do x' <- subExpToScalExp' look x y' <- subExpToScalExp' look y return $ Just $ case t of Bool -> SLogAnd x' y' `SLogOr` SLogAnd (SNot x') (SNot y') _ -> RelExp LEQ0 (x' `sminus` y') `SLogAnd` RelExp LEQ0 (y' `sminus` x') toScalExp look (BasicOp (BinOp (Sub t) (Constant x) y)) | typeIsOK $ IntType t, zeroIsh x = Just . SNeg <$> subExpToScalExp' look y toScalExp look (BasicOp (UnOp AST.Not e)) = Just . SNot <$> subExpToScalExp' look e toScalExp look (BasicOp (BinOp bop x y)) | Just f <- binOpScalExp bop = Just <$> (f <$> subExpToScalExp' look x <*> subExpToScalExp' look y) toScalExp _ _ = return Nothing typeIsOK :: PrimType -> Bool typeIsOK = (`elem` Bool : map IntType allIntTypes) subExpToScalExp' :: HasScope t f => LookupVar -> SubExp -> f ScalExp subExpToScalExp' look (Var v) | Just se <- look v = pure se | otherwise = withType <$> lookupType v where withType (Prim t) = subExpToScalExp (Var v) t withType t = error $ "Cannot create ScalExp from variable " ++ pretty v ++ " of type " ++ pretty t subExpToScalExp' _ (Constant val) = pure $ Val val -- | If you have a scalar expression that has been created with -- incomplete symbol table information, you can use this function to -- grow its 'Id' leaves. expandScalExp :: LookupVar -> ScalExp -> ScalExp expandScalExp _ (Val v) = Val v expandScalExp look (Id v t) = fromMaybe (Id v t) $ look v expandScalExp look (SNeg se) = SNeg $ expandScalExp look se expandScalExp look (SNot se) = SNot $ expandScalExp look se expandScalExp look (SAbs se) = SAbs $ expandScalExp look se expandScalExp look (SSignum se) = SSignum $ expandScalExp look se expandScalExp look (MaxMin b ses) = MaxMin b $ map (expandScalExp look) ses expandScalExp look (SPlus x y) = SPlus (expandScalExp look x) (expandScalExp look y) expandScalExp look (SMinus x y) = SMinus (expandScalExp look x) (expandScalExp look y) expandScalExp look (STimes x y) = STimes (expandScalExp look x) (expandScalExp look y) expandScalExp look (SDiv x y) = SDiv (expandScalExp look x) (expandScalExp look y) expandScalExp look (SMod x y) = SMod (expandScalExp look x) (expandScalExp look y) expandScalExp look (SQuot x y) = SQuot (expandScalExp look x) (expandScalExp look y) expandScalExp look (SRem x y) = SRem (expandScalExp look x) (expandScalExp look y) expandScalExp look (SPow x y) = SPow (expandScalExp look x) (expandScalExp look y) expandScalExp look (SLogAnd x y) = SLogAnd (expandScalExp look x) (expandScalExp look y) expandScalExp look (SLogOr x y) = SLogOr (expandScalExp look x) (expandScalExp look y) expandScalExp look (RelExp relop x) = RelExp relop $ expandScalExp look x -- | "Smart constructor" that checks whether we are subtracting zero, -- and if so just returns the first argument. sminus :: ScalExp -> ScalExp -> ScalExp sminus x (Val v) | zeroIsh v = x sminus x y = x `SMinus` y -- XXX: Only integers and booleans, OK? binOpScalExp :: BinOp -> Maybe (ScalExp -> ScalExp -> ScalExp) binOpScalExp bop = fmap snd . find ((==bop) . fst) $ concatMap intOps allIntTypes ++ [ (LogAnd, SLogAnd), (LogOr, SLogOr) ] where intOps t = [ (Add t, SPlus) , (Sub t, SMinus) , (Mul t, STimes) , (AST.SDiv t, SDiv) , (AST.Pow t, SPow) ] instance FreeIn ScalExp where freeIn (Val _) = mempty freeIn (Id i _) = S.singleton i freeIn (SNeg e) = freeIn e freeIn (SNot e) = freeIn e freeIn (SAbs e) = freeIn e freeIn (SSignum e) = freeIn e freeIn (SPlus x y) = freeIn x <> freeIn y freeIn (SMinus x y) = freeIn x <> freeIn y freeIn (SPow x y) = freeIn x <> freeIn y freeIn (STimes x y) = freeIn x <> freeIn y freeIn (SDiv x y) = freeIn x <> freeIn y freeIn (SMod x y) = freeIn x <> freeIn y freeIn (SQuot x y) = freeIn x <> freeIn y freeIn (SRem x y) = freeIn x <> freeIn y freeIn (SLogOr x y) = freeIn x <> freeIn y freeIn (SLogAnd x y) = freeIn x <> freeIn y freeIn (RelExp LTH0 e) = freeIn e freeIn (RelExp LEQ0 e) = freeIn e freeIn (MaxMin _ es) = mconcat $ map freeIn es