{-# language GADTs #-}
{-# language TypeOperators #-}
{-# language PatternSynonyms #-}
{-# language ViewPatterns #-}
{-# language FlexibleContexts #-}
{-# language ScopedTypeVariables #-}

module Feldspar.Hardware.Optimize where

import Feldspar.Representation
import Feldspar.Hardware.Primitive
import Feldspar.Hardware.Expression
import Data.Struct

import Control.Monad.Writer hiding (Any (..))
import Data.Maybe
import Data.Constraint (Dict (..))
import Data.Set (Set)
import qualified Data.Monoid as Monoid
import qualified Data.Set as Set

-- syntactic.
import Language.Syntactic
import Language.Syntactic.Functional
import Language.Syntactic.Functional.Tuple
import Language.Syntactic.Functional.Sharing

--------------------------------------------------------------------------------
-- * Optimize hardware expressions.
--------------------------------------------------------------------------------

-- | Extract the value of an expression whose top-level symbol is a literal
--   ('Lit'); any other expression yields 'Nothing'.
viewLit :: ASTF HardwareDomain a -> Maybe a
viewLit e
  | Just (Lit val) <- prj e = Just val
  | otherwise               = Nothing

-- | Produce evidence that the result type of an expression is 'Integral' and
--   'Ord', based on the type representation stored in the expression's
--   decoration. Only the eight integer representations listed below are
--   accepted; every other decoration gives 'Nothing'.
witInteger :: ASTF HardwareDomain a -> Maybe (Dict (Integral a, Ord a))
witInteger e = case getDecor e of
  ValT (Node Int8HT)   -> Just Dict
  ValT (Node Int16HT)  -> Just Dict
  ValT (Node Int32HT)  -> Just Dict
  ValT (Node Int64HT)  -> Just Dict
  ValT (Node Word8HT)  -> Just Dict
  ValT (Node Word16HT) -> Just Dict
  ValT (Node Word32HT) -> Just Dict
  ValT (Node Word64HT) -> Just Dict
  _                    -> Nothing

-- | 'True' exactly when 'witInteger' finds a witness for the expression,
--   i.e. when its decoration is one of the integer representations above.
isExact :: ASTF HardwareDomain a -> Bool
isExact e = isJust (witInteger e)

-- | Projection with a stronger constraint to allow using it in
--   bidirectional patterns.
prjBi :: (sub :<: sup) => sup sig -> Maybe (sub sig)
prjBi = prj

--------------------------------------------------------------------------------

-- | Symbol node whose symbol projects to @s@ and whose decoration is
--   @ValT t@. Used as an expression, it builds such a node by injecting @s@.
pattern SymP t s <- Sym ((prjBi -> Just s) :&: ValT t)
  where
    SymP t s = Sym ((inj s) :&: ValT t)

-- | Variable node ('VarT') named @v@ with decoration @t@.
pattern VarP t v <- Sym ((prjBi -> Just (VarT v)) :&: t)
  where
    VarP t v = Sym (inj (VarT v) :&: t)

-- | Lambda node ('LamT') binding @v@, with decoration @t@, applied to @body@.
pattern LamP t v body <- Sym ((prjBi -> Just (LamT v)) :&: t) :$ body
  where
    LamP t v body = Sym (inj (LamT v) :&: t) :$ body

--------------------------------------------------------------------------------

pattern LitP :: () => (Eq a, Show a)
             => HTypeRep a -> a -> ASTF HardwareDomain a
pattern AddP :: () => (Num a, HardwarePrimType a)
             => HTypeRep a -> ASTF HardwareDomain a -> ASTF HardwareDomain a -> ASTF HardwareDomain a
pattern SubP :: () => (Num a, HardwarePrimType a)
             => HTypeRep a -> ASTF HardwareDomain a -> ASTF HardwareDomain a -> ASTF HardwareDomain a
pattern MulP :: () => (Num a, HardwarePrimType a)
             => HTypeRep a -> ASTF HardwareDomain a -> ASTF HardwareDomain a -> ASTF HardwareDomain a
pattern NegP :: () => (Num a, HardwarePrimType a)
             => HTypeRep a -> ASTF HardwareDomain a -> ASTF HardwareDomain a
pattern DivP :: () => (Integral a, HardwarePrimType a)
             => HTypeRep a -> ASTF HardwareDomain a -> ASTF HardwareDomain a -> ASTF HardwareDomain a
pattern ModP :: () => (Integral a, HardwarePrimType a)
             => HTypeRep a -> ASTF HardwareDomain a -> ASTF HardwareDomain a -> ASTF HardwareDomain a

--------------------------------------------------------------------------------

-- | Matches any expression for which 'viewLit' finds no literal.
pattern NonLitP <- (viewLit -> Nothing)

-- | Literal node carrying value @a@ and type representation @t@.
pattern LitP t a <- Sym ((prj -> Just (Lit a)) :&: ValT t)
  where
    LitP t a = Sym (inj (Lit a) :&: ValT t)

-- The arithmetic patterns below match the plain application of the respective
-- symbol. Used as expressions they act as smart constructors: the freshly
-- built node is immediately passed through 'simplifyUp'.

pattern AddP t a b <- SymP t Add :$ a :$ b
  where
    AddP t a b = simplifyUp $ SymP t Add :$ a :$ b

pattern SubP t a b <- SymP t Sub :$ a :$ b
  where
    SubP t a b = simplifyUp $ SymP t Sub :$ a :$ b

pattern MulP t a b <- SymP t Mul :$ a :$ b
  where
    MulP t a b = simplifyUp $ SymP t Mul :$ a :$ b

pattern NegP t a <- SymP t Neg :$ a
  where
    NegP t a = simplifyUp $ SymP t Neg :$ a

pattern DivP t a b <- SymP t Div :$ a :$ b
  where
    DivP t a b = simplifyUp $ SymP t Div :$ a :$ b

pattern ModP t a b <- SymP t Mod :$ a :$ b
  where
    ModP t a b = simplifyUp $ SymP t Mod :$ a :$ b

--------------------------------------------------------------------------------

-- | Simplify the top-level node of an expression by trying the rewrite rules
--   below in order; the first rule that matches wins. Right-hand sides are
--   built with the smart constructors ('AddP', 'SubP', ...) and therefore
--   re-enter 'simplifyUp' for every node they introduce. When no rule
--   applies, the node is handed to 'constFold'.
--
--   The arithmetic rules are guarded by 'isExact' / 'witInteger', so they
--   only fire for the integer types accepted by 'witInteger'.
simplifyUp :: ASTF HardwareDomain a -> ASTF HardwareDomain a
-- Addition with zero.
simplifyUp (AddP _ (LitP _ 0) b) | isExact b = b
simplifyUp (AddP _ a (LitP _ 0)) | isExact a = a
-- Simplify additions with literals.
simplifyUp (AddP t (AddP _ a (LitP _ b)) (LitP _ c)) | isExact a = AddP t a (LitP t (b+c))
simplifyUp (AddP t (SubP _ a (LitP _ b)) (LitP _ c)) | isExact a = AddP t a (LitP t (c-b))
-- `a + (-n)` becomes `a - n`. The extra `negate b > 0` check excludes the
-- minimum value of a fixed-width signed type, whose negation wraps around to
-- itself; without it this rule and its subtraction counterpart would keep
-- rewriting into each other forever.
simplifyUp (AddP t a (LitP _ b))
  | Just Dict <- witInteger a, b < 0, negate b > 0 = SubP t a (LitP t (negate b))
-- Subtraction with zero.
simplifyUp (SubP t (LitP _ 0) b) | isExact b = NegP t b
simplifyUp (SubP _ a (LitP _ 0)) | isExact a = a
-- Simplify subtractions with literals.
simplifyUp (SubP t (AddP _ a (LitP _ b)) (LitP _ c)) | isExact a = AddP t a (LitP t (b-c))
simplifyUp (SubP t (SubP _ a (LitP _ b)) (LitP _ c)) | isExact a = SubP t a (LitP t (b+c))
-- `a - (-n)` becomes `a + n`; see the matching addition rule for the reason
-- behind the `negate b > 0` check.
simplifyUp (SubP t a (LitP _ b))
  | Just Dict <- witInteger a, b < 0, negate b > 0 = AddP t a (LitP t (negate b))
-- Multiplication with zero.
simplifyUp (MulP t (LitP _ 0) b) | isExact b = LitP t 0
simplifyUp (MulP t a (LitP _ 0)) | isExact a = LitP t 0
-- Multiplication with one.
simplifyUp (MulP _ (LitP _ 1) b) | isExact b = b
simplifyUp (MulP _ a (LitP _ 1)) | isExact a = a
-- Simplify multiplications with literals.
simplifyUp (MulP t (MulP _ a (LitP _ b)) (LitP _ c)) | isExact a = MulP t a (LitP t (b*c))
-- Simplify negations.
simplifyUp (NegP _ (NegP _ a))   | isExact a = a
simplifyUp (NegP t (AddP _ a b)) | isExact a = SubP t (NegP t a) b
simplifyUp (NegP t (SubP _ a b)) | isExact a = SubP t b a
simplifyUp (NegP t (MulP _ a b)) | isExact a = MulP t a (NegP t b)
-- Move literals to the right.
simplifyUp (AddP t a@(LitP _ _) b@NonLitP) | isExact a = AddP t b a
simplifyUp (SubP t a@(LitP _ _) b@NonLitP) | isExact a = AddP t (NegP t b) a
simplifyUp (MulP t a@(LitP _ _) b@NonLitP) | isExact a = MulP t b a
-- Simplify not-expressions.
-- Not sure (yet) why I can't write `simplifyUp (NotP _ (NotP _ a)) = a`
simplifyUp (SymP _ Not :$ (SymP _ Not :$ a)) = a
simplifyUp (SymP t Not :$ (SymP _ Lt  :$ a :$ b)) = simplifyUp $ SymP t Gte :$ a :$ b
simplifyUp (SymP t Not :$ (SymP _ Gt  :$ a :$ b)) = simplifyUp $ SymP t Lte :$ a :$ b
simplifyUp (SymP t Not :$ (SymP _ Lte :$ a :$ b)) = simplifyUp $ SymP t Gt  :$ a :$ b
simplifyUp (SymP t Not :$ (SymP _ Gte :$ a :$ b)) = simplifyUp $ SymP t Lt  :$ a :$ b
-- Simplify and-expressions.
simplifyUp (SymP _ And :$ LitP t False :$ _) = LitP t False
simplifyUp (SymP _ And :$ _ :$ LitP t False) = LitP t False
simplifyUp (SymP _ And :$ LitP _ True :$ b)  = b
simplifyUp (SymP _ And :$ a :$ LitP _ True)  = a
-- Simplify or-expressions.
simplifyUp (SymP _ Or :$ LitP _ False :$ b) = b
simplifyUp (SymP _ Or :$ a :$ LitP _ False) = a
simplifyUp (SymP _ Or :$ LitP t True :$ _)  = LitP t True
simplifyUp (SymP _ Or :$ _ :$ LitP t True)  = LitP t True
-- Simplify conditional expressions.
simplifyUp (SymP _ Cond :$ LitP _ True  :$ t :$ _) = t
simplifyUp (SymP _ Cond :$ LitP _ False :$ _ :$ f) = f
simplifyUp (SymP _ Cond :$ _ :$ t :$ f) | equal t f = t
-- Simplify pairs.
simplifyUp (SymP t Pair :$ (SymP _ Fst :$ a) :$ (SymP _ Snd :$ b))
  | alphaEq a b
  , ValT t' <- getDecor a
  , Just Dict <- hardwareTypeEq t t' = a
simplifyUp (SymP _ Fst :$ (SymP _ Pair :$ a :$ _)) = a
simplifyUp (SymP _ Snd :$ (SymP _ Pair :$ _ :$ a)) = a
--
simplifyUp a = constFold a
  -- `constFold` here ensures that `simplifyUp` does not produce any expressions
  -- that can be statically constant folded. This property is needed, e.g. to
  -- fully simplify the expression `negate (2*x)`. The simplification should go
  -- as follows:
  --
  --     negate (2*x)  ->  negate (x*2)  ->  x * negate 2  ->  x*(-2)
  --
  -- There is no explicit rule for the last step; it is done by `constFold`.
  -- Furthermore, this constant folding would not be performed by `simplifyM`
  -- since it never sees the sub-expression `negate 2`. (Note that the constant
  -- folding in `simplifyM` is still needed, because constructs such as
  -- `ForLoop` cannot be folded by simple literal propagation.)
  --
  -- In order to see that `simplifyUp` doesn't produce any "junk"
  -- (sub-expressions that can be folded by `constFold`), we reason as follows:
  --
  --   * Assume that the arguments of the top-level node are junk-free
  --   * `simplifyUp` will either apply an explicit rewrite or apply `constFold`
  --   * In the latter case, the result will be junk-free
  --   * In case of an explicit rewrite, the resulting expression is constructed
  --     by applying `simplifyUp` to each newly introduced node; thus the
  --     resulting expression must be junk-free

--------------------------------------------------------------------------------

-- | Reduce an expression to a literal if the following conditions are met:
--
--   * The top-most symbol can be evaluated statically (e.g. not a variable or a
--     lifted side-effecting program)
--   * All immediate sub-terms are literals
--   * The type of the expression is an allowed type for literals (e.g. not a
--     function)
--
-- Note that this function only folds the top-level node of an expression (if
-- possible). It does not reduce an expression like @1+(2+3)@ where the
-- sub-expression @2+3@ is not a literal.
constFold :: ASTF HardwareDomain a -> ASTF HardwareDomain a
constFold e
  | constArgs e
  , canFold e
  , ValT t@(Node _) <- getDecor e = LitP t (evalClosed e)
  | otherwise                     = e
  where
    -- The top-most symbol may be folded unless it is a free variable, an
    -- array index, or a binding construct.
    canFold :: ASTF HardwareDomain a -> Bool
    canFold = simpleMatch (\(s :&: ValT _) _ -> go s)
      where
        go :: HardwareConstructs sig -> Bool
        go s
          | Just (FreeVar _) <- prj s         = False
          | Just (ArrIx _) <- prj s           = False
          | Just (_ :: BindingT sig) <- prj s = False
          | otherwise                         = True

-- | Check whether all arguments of a symbol are literals
constArgs :: AST HardwareDomain sig -> Bool
constArgs ast = case ast of
  Sym _         -> True
  f :$ LitP _ _ -> constArgs f
  _             -> False

--------------------------------------------------------------------------------

-- | Writer used during simplification. It accumulates the set of variable
--   names reported via 'freeVar' and a flag that is raised by 'tellUnsafe'.
type Opt = Writer (Set Name, Monoid.Any)

-- | Record an occurrence of variable @v@.
freeVar :: Name -> Opt ()
freeVar v = tell (Set.singleton v, mempty)

-- | Run a computation and remove @v@ from the variables it reported; the
--   unsafe flag is passed through untouched.
bindVar :: Name -> Opt a -> Opt a
bindVar v = censor dropBound
  where
    dropBound (names, flag) = (Set.delete v names, flag)

-- | Raise the unsafe flag.
tellUnsafe :: Opt ()
tellUnsafe = tell (mempty, Monoid.Any True)

-- | Simplify an expression bottom-up: sub-terms are simplified first, then
--   the rebuilt node is passed to 'simplifyUp'. While doing so, the reported
--   variables and the unsafe flag of the sub-terms are observed in order to
--   decide whether the result can be replaced by a literal.
simplifyM :: ASTF HardwareDomain a -> Opt (ASTF HardwareDomain a)
simplifyM expr = case expr of
  VarP _ v -> do
    freeVar v
    return expr
  LamP t v body ->
    bindVar v (LamP t v <$> simplifyM body)
  _ -> simpleMatch
    (\s@(sym :&: t) args -> do
        (expr', (names, Monoid.Any flagged)) <-
          listen $ fmap (simplifyUp . appArgs (Sym s)) (mapArgsM simplifyM args)
        case () of
          _ | Just (FreeVar _) <- prj sym -> tellUnsafe >> return expr'
            -- Array indexing is actually not unsafe. It's more like
            -- an expression with a free variable. But setting the
            -- unsafe flag does the trick.
            | Just (ArrIx _) <- prj sym -> tellUnsafe >> return expr'
            -- Constant fold if expression is closed and does not
            -- contain unsafe operations.
            | null names && not flagged
            , ValT t'@(Node _) <- t -> return $ LitP t' $ evalClosed expr'
            | otherwise -> return expr'
    )
    expr

--------------------------------------------------------------------------------

-- | Simplify an expression with 'simplifyM' and discard the bookkeeping
--   collected by the writer.
optimize :: ASTF HardwareDomain a -> ASTF HardwareDomain a
optimize e = fst (runWriter (simplifyM e))

--------------------------------------------------------------------------------