-- | A primitive expression is an expression where the non-leaves are
-- primitive operators.  Our representation does not guarantee that
-- the expression is type-correct.
module Futhark.Analysis.PrimExp
  ( PrimExp (..),
    evalPrimExp,
    primExpType,
    coerceIntPrimExp,
    true,
    false,
    constFoldPrimExp,
    module Futhark.Representation.Primitive,
    (.&&.),
    (.||.),
    (.<.),
    (.<=.),
    (.>.),
    (.>=.),
    (.==.),
    (.&.),
    (.|.),
    (.^.),
  )
where

import Data.Foldable
import Data.Traversable
import qualified Data.Map as M

import Futhark.Representation.AST.Attributes.Names
import Futhark.Representation.Primitive
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty

-- | A primitive expression parametrised over the representation of
-- free variables.  Note that the 'Functor', 'Traversable', and 'Num'
-- instances perform automatic (but simple) constant folding.
data PrimExp v
  = -- | A free variable, annotated with its primitive type.
    LeafExp v PrimType
  | -- | A constant.
    ValueExp PrimValue
  | -- | A binary operator applied to two operands.
    BinOpExp BinOp (PrimExp v) (PrimExp v)
  | -- | A comparison of two operands.
    CmpOpExp CmpOp (PrimExp v) (PrimExp v)
  | -- | A unary operator applied to one operand.
    UnOpExp UnOp (PrimExp v)
  | -- | A conversion applied to one operand.
    ConvOpExp ConvOp (PrimExp v)
  | -- | A function, identified by name, applied to arguments; the
    -- last field is the type reported for the result.
    FunExp String [PrimExp v] PrimType
  deriving (Ord, Show)

-- The Eq instance upcoerces all integer constants to their largest
-- type before comparing for equality.  This is technically not a good
-- idea, but it avoids annoying mismatches caused by integer literals
-- built through the Num instance having a fixed width of their own.
-- Note that 'Ord' is derived structurally and does not perform this
-- upcoercion.
instance Eq v => Eq (PrimExp v) where
  lhs == rhs = case (lhs, rhs) of
    (LeafExp x xt, LeafExp y yt) ->
      x == y && xt == yt
    -- Integer constants are compared after conversion with
    -- 'intToInt64', so their original width does not matter.  This
    -- alternative must come before the general constant case.
    (ValueExp (IntValue x), ValueExp (IntValue y)) ->
      intToInt64 x == intToInt64 y
    (ValueExp x, ValueExp y) ->
      x == y
    (BinOpExp xop x1 x2, BinOpExp yop y1 y2) ->
      xop == yop && x1 == y1 && x2 == y2
    (CmpOpExp xop x1 x2, CmpOpExp yop y1 y2) ->
      xop == yop && x1 == y1 && x2 == y2
    (UnOpExp xop x, UnOpExp yop y) ->
      xop == yop && x == y
    (ConvOpExp xop x, ConvOpExp yop y) ->
      xop == yop && x == y
    -- The result-type annotation of a function application is not
    -- part of the comparison.
    (FunExp xf xargs _, FunExp yf yargs _) ->
      xf == yf && xargs == yargs
    _ ->
      False

-- 'fmap' is obtained from the 'Traversable' instance below.
instance Functor PrimExp where
  fmap = fmapDefault

-- 'foldMap' is obtained from the 'Traversable' instance below.
instance Foldable PrimExp where
  foldMap = foldMapDefault

instance Traversable PrimExp where
  traverse f = go
    where
      -- Rebuild the tree with every leaf replaced by the result of
      -- @f@.  Only binary-operator nodes are constant-folded after
      -- being rebuilt.
      go e = case e of
        LeafExp v t ->
          LeafExp <$> f v <*> pure t
        ValueExp v ->
          pure $ ValueExp v
        BinOpExp op x y ->
          constFoldPrimExp <$> (BinOpExp op <$> go x <*> go y)
        CmpOpExp op x y ->
          CmpOpExp op <$> go x <*> go y
        ConvOpExp op x ->
          ConvOpExp op <$> go x
        UnOpExp op x ->
          UnOpExp op <$> go x
        FunExp h args t ->
          FunExp h <$> traverse go args <*> pure t

-- The free names of an expression are those of all its leaves.
instance FreeIn v => FreeIn (PrimExp v) where
  freeIn = foldMap freeIn

-- | Perform quick and dirty constant folding on the top level of a
-- PrimExp.  This is necessary because we want to consider
-- e.g. equality modulo constant folding.
constFoldPrimExp :: PrimExp v -> PrimExp v
-- Identity-element rules: adding or subtracting a constant zero, and
-- multiplying or dividing by a constant one, yields the other
-- operand.  When no guard of an equation holds, matching continues
-- with the equations further down.
constFoldPrimExp (BinOpExp Add{} x y)
  | zeroIshExp x = y
  | zeroIshExp y = x
constFoldPrimExp (BinOpExp Sub{} x y)
  | zeroIshExp y = x
constFoldPrimExp (BinOpExp Mul{} x y)
  | oneIshExp x = y
  | oneIshExp y = x
constFoldPrimExp (BinOpExp SDiv{} x y)
  | oneIshExp y = x
constFoldPrimExp (BinOpExp SQuot{} x y)
  | oneIshExp y = x
constFoldPrimExp (BinOpExp UDiv{} x y)
  | oneIshExp y = x
-- Both operands are constants: evaluate the operator right away when
-- 'doBinOp' produces a result.
constFoldPrimExp (BinOpExp bop (ValueExp x) (ValueExp y))
  | Just z <- doBinOp bop x y =
      ValueExp z
-- Anything else is returned unchanged; subexpressions are not visited.
constFoldPrimExp e = e

-- The Num instance performs a little bit of magic: whenever an
-- expression and a constant is combined with a binary operator, the
-- type of the constant may be changed to be the type of the
-- expression, if they are not already the same.  This permits us to
-- write e.g. @x * 4@, where @x@ is an arbitrary PrimExp, and have the
-- @4@ converted to the proper primitive type.  We also support
-- converting integers to floating point values, but not the other way
-- around.  All numeric instances assume signed integers for such
-- conversions (constants are adapted with 'doSExt' and 'doSIToFP').
--
-- We also perform simple constant folding, in particular to reduce
-- expressions to constants so that the above works.  However, it is
-- still a bit of a hack.
instance Pretty v => Num (PrimExp v) where
  x + y | Just z <- msum [asIntOp Add x y, asFloatOp FAdd x y] = constFoldPrimExp z
        | otherwise = numBad "+" (x,y)

  x - y | Just z <- msum [asIntOp Sub x y, asFloatOp FSub x y] = constFoldPrimExp z
        | otherwise = numBad "-" (x,y)

  x * y | Just z <- msum [asIntOp Mul x y, asFloatOp FMul x y] = constFoldPrimExp z
        | otherwise = numBad "*" (x,y)

  abs x | IntType t <- primExpType x = UnOpExp (Abs t) x
        | FloatType t <- primExpType x = UnOpExp (FAbs t) x
        | otherwise = numBad "abs" x

  -- Only integer-typed expressions are supported here.
  signum x | IntType t <- primExpType x = UnOpExp (SSignum t) x
           | otherwise = numBad "signum" x

  -- Integer literals start out as 32-bit constants; the binary
  -- operators above re-type constants to match their other operand.
  fromInteger = fromInt32 . fromInteger

instance Pretty v => IntegralExp (PrimExp v) where
  x `div` y | Just z <- msum [asIntOp SDiv x y, asFloatOp FDiv x y] = constFoldPrimExp z
            | otherwise = numBad "div" (x,y)

  -- Constant-folded like 'div', 'quot' and 'rem', so that a modulus of
  -- two constants also reduces to a constant.
  x `mod` y | Just z <- msum [asIntOp SMod x y] = constFoldPrimExp z
            | otherwise = numBad "mod" (x,y)

  x `quot` y | oneIshExp y = x
             | Just z <- msum [asIntOp SQuot x y] = constFoldPrimExp z
             | otherwise = numBad "quot" (x,y)

  x `rem` y | Just z <- msum [asIntOp SRem x y] = constFoldPrimExp z
            | otherwise = numBad "rem" (x,y)

  -- The sign is only known for integer constants.
  sgn (ValueExp (IntValue i)) = Just $ signum $ valueIntegral i
  sgn _ = Nothing

  fromInt8  = ValueExp . IntValue . Int8Value
  fromInt16 = ValueExp . IntValue . Int16Value
  fromInt32 = ValueExp . IntValue . Int32Value
  fromInt64 = ValueExp . IntValue . Int64Value

-- | Lifted logical conjunction.
(.&&.) :: PrimExp v -> PrimExp v -> PrimExp v
x .&&. y = BinOpExp LogAnd x y

-- | Lifted logical disjunction.
(.||.) :: PrimExp v -> PrimExp v -> PrimExp v
x .||. y = BinOpExp LogOr x y

-- | Lifted relational operators; assuming signed numbers in case of
-- integers.
(.<.), (.>.), (.<=.), (.>=.), (.==.) :: PrimExp v -> PrimExp v -> PrimExp v
x .<. y = CmpOpExp cmp x y
  where cmp = case primExpType x of
                IntType t -> CmpSlt $ t `min` primExpIntType y
                FloatType t -> FCmpLt t
                _ -> CmpLlt
x .<=. y = CmpOpExp cmp x y
  where cmp = case primExpType x of
                IntType t -> CmpSle $ t `min` primExpIntType y
                FloatType t -> FCmpLe t
                _ -> CmpLle
x .==. y = CmpOpExp (CmpEq $ primExpType x `min` primExpType y) x y
-- "Greater than" comparisons are expressed by swapping the operands.
x .>. y = y .<. x
x .>=. y = y .<=. x

-- | Lifted bitwise operators.
(.&.), (.|.), (.^.) :: PrimExp v -> PrimExp v -> PrimExp v
x .&. y = BinOpExp (And $ primExpIntType x `min` primExpIntType y) x y
x .|. y = BinOpExp (Or $ primExpIntType x `min` primExpIntType y) x y
x .^. y = BinOpExp (Xor $ primExpIntType x `min` primExpIntType y) x y

infix 4 .==., .<., .>., .<=., .>=.
infixr 3 .&&.
infixr 2 .||.
-- Build an integer binary operation.  The integer type is taken from
-- the left operand if the right one can be adapted to it, otherwise
-- from the right operand if the left one can be adapted.
asIntOp :: (IntType -> BinOp) -> PrimExp v -> PrimExp v -> Maybe (PrimExp v)
asIntOp f x y =
  case (primExpType x, primExpType y) of
    (IntType t, _)
      | Just y' <- asIntExp t y -> Just $ BinOpExp (f t) x y'
    (_, IntType t)
      | Just x' <- asIntExp t x -> Just $ BinOpExp (f t) x' y
    _ -> Nothing

-- View an expression at the given integer type: it is kept as is when
-- it already has that type, and an integer constant is converted with
-- 'doSExt'.
asIntExp :: IntType -> PrimExp v -> Maybe (PrimExp v)
asIntExp t e
  | primExpType e == IntType t = Just e
  | ValueExp (IntValue v) <- e = Just $ ValueExp $ IntValue $ doSExt v t
  | otherwise = Nothing

-- Floating-point counterpart of 'asIntOp'.
asFloatOp :: (FloatType -> BinOp) -> PrimExp v -> PrimExp v -> Maybe (PrimExp v)
asFloatOp f x y =
  case (primExpType x, primExpType y) of
    (FloatType t, _)
      | Just y' <- asFloatExp t y -> Just $ BinOpExp (f t) x y'
    (_, FloatType t)
      | Just x' <- asFloatExp t x -> Just $ BinOpExp (f t) x' y
    _ -> Nothing

-- View an expression at the given floating-point type.  Constants are
-- converted: floats with 'doFPConv', integers with 'doSIToFP'.
asFloatExp :: FloatType -> PrimExp v -> Maybe (PrimExp v)
asFloatExp t e
  | primExpType e == FloatType t = Just e
  | ValueExp (FloatValue v) <- e = Just $ ValueExp $ FloatValue $ doFPConv v t
  | ValueExp (IntValue v) <- e = Just $ ValueExp $ FloatValue $ doSIToFP v t
  | otherwise = Nothing

-- Abort with an error naming the offending method and its operands.
numBad :: Pretty a => String -> a -> b
numBad s x =
  error $ concat ["Invalid argument to PrimExp method ", s, ": ", pretty x]

-- | Evaluate a 'PrimExp' in the given monad.  Invokes 'fail' on type
-- errors.
evalPrimExp :: (Pretty v, Monad m) => (v -> m PrimValue) -> PrimExp v -> m PrimValue
evalPrimExp f = eval
  where
    -- Leaves are resolved through the supplied function.
    eval (LeafExp v _) = f v
    eval (ValueExp v) = return v
    -- Operands are evaluated left to right; errors are reported in
    -- terms of the unevaluated operands.
    eval (BinOpExp op x y) = do
      xv <- eval x
      yv <- eval y
      maybe (evalBad op (x,y)) return $ doBinOp op xv yv
    eval (CmpOpExp op x y) = do
      xv <- eval x
      yv <- eval y
      maybe (evalBad op (x,y)) (return . BoolValue) $ doCmpOp op xv yv
    eval (UnOpExp op x) = do
      xv <- eval x
      maybe (evalBad op x) return $ doUnOp op xv
    eval (ConvOpExp op x) = do
      xv <- eval x
      maybe (evalBad op x) return $ doConvOp op xv
    -- Functions are looked up by name in 'primFuns'; an unknown name
    -- is reported in the same way as a failed application.
    eval (FunExp h args _) = do
      argvs <- mapM eval args
      maybe (evalBad h args) return $
        M.lookup h primFuns >>= \(_, _, fun) -> fun argvs

-- Report an evaluation failure through 'fail'.
evalBad :: (Pretty a, Pretty b, Monad m) => a -> b -> m c
evalBad op arg =
  fail $ "evalPrimExp: Type error when applying " ++
         pretty op ++ " to " ++ pretty arg

-- | The type of values returned by a 'PrimExp'.  This function
-- returning does not imply that the 'PrimExp' is type-correct.
primExpType :: PrimExp v -> PrimType
primExpType e = case e of
  LeafExp _ t -> t
  ValueExp v -> primValueType v
  BinOpExp op _ _ -> binOpType op
  CmpOpExp{} -> Bool
  UnOpExp op _ -> unOpType op
  ConvOpExp op _ -> snd $ convOpType op
  FunExp _ _ t -> t

-- | Is the expression a constant zero of some sort?
zeroIshExp :: PrimExp v -> Bool
zeroIshExp e = case e of
  ValueExp v -> zeroIsh v
  _ -> False

-- | Is the expression a constant one of some sort?
oneIshExp :: PrimExp v -> Bool
oneIshExp e = case e of
  ValueExp v -> oneIsh v
  _ -> False

-- | If the given 'PrimExp' is a constant of the wrong integer type,
-- coerce it to the given integer type.  This is a workaround for an
-- issue in the 'Num' instance.
coerceIntPrimExp :: IntType -> PrimExp v -> PrimExp v
coerceIntPrimExp t e
  | ValueExp (IntValue v) <- e = ValueExp $ IntValue $ doSExt v t
  | otherwise = e

-- The integer type of an expression, defaulting to 'Int64' when the
-- expression is not of integer type.
primExpIntType :: PrimExp v -> IntType
primExpIntType e
  | IntType t <- primExpType e = t
  | otherwise = Int64

-- | Boolean-valued PrimExps.
true, false :: PrimExp v
true = ValueExp (BoolValue True)
false = ValueExp (BoolValue False)

-- Prettyprinting instances

-- Operators are printed in prefix form with every operand wrapped in
-- parentheses; function applications print the function name followed
-- by a parenthesised, comma-separated argument list.
instance Pretty v => Pretty (PrimExp v) where
  ppr e = case e of
    LeafExp v _ -> ppr v
    ValueExp v -> ppr v
    BinOpExp op x y -> binary (ppr op) x y
    CmpOpExp op x y -> binary (ppr op) x y
    ConvOpExp op x -> unary (ppr op) x
    UnOpExp op x -> unary (ppr op) x
    FunExp h args _ -> text h <+> parens (commasep $ map ppr args)
    where
      unary opDoc x = opDoc <+> parens (ppr x)
      binary opDoc x y = opDoc <+> parens (ppr x) <+> parens (ppr y)