{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Simplify.Rule
(
RuleM
, cannotSimplify
, liftMaybe
, Rule(..)
, SimplificationRule(..)
, RuleGeneric
, RuleBasicOp
, RuleIf
, RuleDoLoop
, TopDown
, TopDownRule
, TopDownRuleGeneric
, TopDownRuleBasicOp
, TopDownRuleIf
, TopDownRuleDoLoop
, TopDownRuleOp
, BottomUp
, BottomUpRule
, BottomUpRuleGeneric
, BottomUpRuleBasicOp
, BottomUpRuleIf
, BottomUpRuleDoLoop
, BottomUpRuleOp
, RuleBook
, ruleBook
, topDownSimplifyStm
, bottomUpSimplifyStm
) where
import Control.Monad.State
import qualified Control.Monad.Fail as Fail
import Control.Monad.Except
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Representation.AST
import Futhark.Binder
-- | The ways in which running a rule can fail.
data RuleError
  = CannotSimplify    -- ^ The rule does not apply; 'simplify' yields 'Nothing'.
  | OtherError String -- ^ Raised by 'fail'; 'simplify' reports it via 'error'.

-- | The monad simplification rules are written in: a 'BinderT' over a
-- 'VNameSource' state and an 'Except' layer carrying 'RuleError'.  A rule
-- can therefore create fresh names, emit statements, inspect the scope,
-- and abort.
newtype RuleM lore a =
  RuleM (BinderT lore (StateT VNameSource (Except RuleError)) a)
  deriving ( Functor, Applicative, Monad
           , MonadFreshNames, HasScope lore, LocalScope lore
           , MonadError RuleError
           )

-- | Calls to 'fail' inside a rule raise an 'OtherError' carrying the
-- message.
instance Fail.MonadFail (RuleM lore) where
  fail msg = throwError (OtherError msg)
-- | 'RuleM' is a binder by delegation: every operation is forwarded to
-- the wrapped 'BinderT', with the 'RuleM' constructor peeled off and
-- re-applied around it.
instance (Attributes lore, BinderOps lore) => MonadBinder (RuleM lore) where
  type Lore (RuleM lore) = lore

  mkExpAttrM pat = RuleM . mkExpAttrM pat
  mkBodyM stms = RuleM . mkBodyM stms
  mkLetNamesM names = RuleM . mkLetNamesM names

  addStms stms = RuleM (addStms stms)
  collectStms (RuleM action) = RuleM (collectStms action)
  certifying cs (RuleM action) = RuleM (certifying cs action)
-- | Run a rule in the given scope, starting from the given name source.
--
-- On success, returns the statements the rule emitted together with the
-- updated name source.  'Skip' and a rule that aborts with
-- 'CannotSimplify' both give 'Nothing'.  An 'OtherError' (raised via
-- 'fail') is treated as an internal error and reported with 'error'.
simplify :: Scope lore -> VNameSource -> Rule lore
         -> Maybe (Stms lore, VNameSource)
simplify scope src rule =
  case rule of
    Skip -> Nothing
    Simplify (RuleM m) ->
      either onFailure onSuccess $
        runExcept $ runStateT (runBinderT m scope) src
  where onFailure CannotSimplify   = Nothing
        onFailure (OtherError err) = error $ "simplify: " ++ err

        onSuccess (((), stms), src') = Just (stms, src')
-- | Abort the current rule, signalling that it does not apply.
-- 'simplify' turns this into 'Nothing', so the next candidate rule is
-- tried.
cannotSimplify :: RuleM lore a
cannotSimplify = throwError CannotSimplify

-- | Extract the value from a 'Maybe' inside a rule, aborting with
-- 'cannotSimplify' when it is 'Nothing'.
liftMaybe :: Maybe a -> RuleM lore a
liftMaybe = maybe cannotSimplify return
-- | The result of matching a rule against a statement.  It is either a
-- 'RuleM' action to run, or 'Skip' when the rule does not target this
-- kind of statement at all.
data Rule lore
  = Simplify (RuleM lore ())
  | Skip
-- Every rule takes a context value @a@ as its first argument.  For
-- top-down rules this is a symbol table; for bottom-up rules it is a
-- symbol table plus a usage table.

-- | A rule that receives the whole statement, whatever its expression.
type RuleGeneric lore a = a -> Stm lore -> Rule lore

-- | A rule for statements whose expression is a 'BasicOp'.
type RuleBasicOp lore a =
  a -> Pattern lore -> StmAux (ExpAttr lore) -> BasicOp lore -> Rule lore

-- | A rule for 'If' statements.  It receives the condition, the true
-- branch, the false branch, and the branch attribute.
type RuleIf lore a =
  a -> Pattern lore -> StmAux (ExpAttr lore)
  -> (SubExp, BodyT lore, BodyT lore, IfAttr (BranchType lore))
  -> Rule lore

-- | A rule for 'DoLoop' statements.  It receives the context and value
-- (parameter, initial value) lists, the loop form, and the loop body.
type RuleDoLoop lore a =
  a -> Pattern lore -> StmAux (ExpAttr lore)
  -> ( [(FParam lore, SubExp)], [(FParam lore, SubExp)]
     , LoopForm lore, BodyT lore )
  -> Rule lore

-- | A rule for statements whose expression is a lore-specific 'Op'.
type RuleOp lore a =
  a -> Pattern lore -> StmAux (ExpAttr lore) -> Op lore -> Rule lore
-- | A simplification rule, tagged with the kind of expression it targets.
-- The tag lets rules be sorted into per-kind buckets ahead of time.
data SimplificationRule lore a
  = RuleGeneric (RuleGeneric lore a)
  | RuleBasicOp (RuleBasicOp lore a)
  | RuleIf      (RuleIf lore a)
  | RuleDoLoop  (RuleDoLoop lore a)
  | RuleOp      (RuleOp lore a)
-- | Rules bucketed by the expression kind they may fire on.  As built by
-- 'ruleBook', generic rules are placed in every bucket, and 'rulesAny'
-- holds every rule.
data Rules lore a = Rules
  { rulesAny     :: [SimplificationRule lore a] -- ^ All rules.
  , rulesBasicOp :: [SimplificationRule lore a] -- ^ Candidates for 'BasicOp'.
  , rulesIf      :: [SimplificationRule lore a] -- ^ Candidates for 'If'.
  , rulesDoLoop  :: [SimplificationRule lore a] -- ^ Candidates for 'DoLoop'.
  , rulesOp      :: [SimplificationRule lore a] -- ^ Candidates for 'Op'.
  }

-- | Combine bucket-wise, keeping the left operand's rules first.
instance Semigroup (Rules lore a) where
  Rules anys1 basics1 ifs1 loops1 ops1 <> Rules anys2 basics2 ifs2 loops2 ops2 =
    Rules (anys1 ++ anys2)
          (basics1 ++ basics2)
          (ifs1 ++ ifs2)
          (loops1 ++ loops2)
          (ops1 ++ ops2)

-- | The empty collection: every bucket is empty.
instance Monoid (Rules lore a) where
  mempty = Rules [] [] [] [] []
-- | The context handed to top-down rules: a symbol table.
type TopDown lore = ST.SymbolTable lore

-- | A top-down rule of any kind, followed by its kind-specific variants.
type TopDownRule lore = SimplificationRule lore (TopDown lore)
type TopDownRuleGeneric lore = RuleGeneric lore (TopDown lore)
type TopDownRuleBasicOp lore = RuleBasicOp lore (TopDown lore)
type TopDownRuleIf lore = RuleIf lore (TopDown lore)
type TopDownRuleDoLoop lore = RuleDoLoop lore (TopDown lore)
type TopDownRuleOp lore = RuleOp lore (TopDown lore)

-- | The context handed to bottom-up rules: a symbol table paired with a
-- usage table.
type BottomUp lore = (ST.SymbolTable lore, UT.UsageTable)

-- | A bottom-up rule of any kind, followed by its kind-specific variants.
type BottomUpRule lore = SimplificationRule lore (BottomUp lore)
type BottomUpRuleGeneric lore = RuleGeneric lore (BottomUp lore)
type BottomUpRuleBasicOp lore = RuleBasicOp lore (BottomUp lore)
type BottomUpRuleIf lore = RuleIf lore (BottomUp lore)
type BottomUpRuleDoLoop lore = RuleDoLoop lore (BottomUp lore)
type BottomUpRuleOp lore = RuleOp lore (BottomUp lore)

-- | Bucketed rule collections for each kind of context.
type TopDownRules lore = Rules lore (TopDown lore)
type BottomUpRules lore = Rules lore (BottomUp lore)
-- | A full set of simplification rules.  The top-down part is used by
-- 'topDownSimplifyStm' and the bottom-up part by 'bottomUpSimplifyStm'.
data RuleBook lore = RuleBook
  { bookTopDownRules  :: TopDownRules lore
  , bookBottomUpRules :: BottomUpRules lore
  }

-- | Combine both halves independently, keeping the left book's rules first.
instance Semigroup (RuleBook lore) where
  RuleBook tds1 bus1 <> RuleBook tds2 bus2 =
    RuleBook (tds1 <> tds2) (bus1 <> bus2)

-- | A book with no rules at all.
instance Monoid (RuleBook lore) where
  mempty = RuleBook mempty mempty
-- | Build a 'RuleBook' from lists of top-down and bottom-up rules.
--
-- Each list is sorted into per-expression-kind buckets.  A generic rule
-- goes into every bucket, and the relative order of rules is preserved
-- inside each bucket.
ruleBook :: [TopDownRule m]
         -> [BottomUpRule m]
         -> RuleBook m
ruleBook topdowns bottomups =
  RuleBook { bookTopDownRules  = groupRules topdowns
           , bookBottomUpRules = groupRules bottomups
           }
  where groupRules :: [SimplificationRule m a] -> Rules m a
        groupRules rs =
          Rules { rulesAny     = rs
                , rulesBasicOp = keeping isBasicOp rs
                , rulesIf      = keeping isIf rs
                , rulesDoLoop  = keeping isDoLoop rs
                , rulesOp      = keeping isOp rs
                }

        -- A rule belongs in a bucket if it is generic or targets that
        -- bucket's expression kind.
        keeping :: (SimplificationRule l b -> Bool)
                -> [SimplificationRule l b] -> [SimplificationRule l b]
        keeping wanted = filter (\r -> isGeneric r || wanted r)

        isGeneric, isBasicOp, isIf, isDoLoop, isOp
          :: SimplificationRule l b -> Bool
        isGeneric RuleGeneric{} = True
        isGeneric _             = False
        isBasicOp RuleBasicOp{} = True
        isBasicOp _             = False
        isIf RuleIf{}           = True
        isIf _                  = False
        isDoLoop RuleDoLoop{}   = True
        isDoLoop _              = False
        isOp RuleOp{}           = True
        isOp _                  = False
-- | Try the top-down rules of a 'RuleBook' on a statement, using the
-- given symbol table as the rule context.  Returns the replacement
-- statements from the first rule that succeeds, or 'Nothing' if none does.
topDownSimplifyStm :: (MonadFreshNames m, HasScope lore m) =>
                      RuleBook lore
                   -> TopDown lore
                   -> Stm lore
                   -> m (Maybe (Stms lore))
topDownSimplifyStm book = applyRules (bookTopDownRules book)
-- | Try the bottom-up rules of a 'RuleBook' on a statement, using the
-- given symbol table and usage table as the rule context.  Returns the
-- replacement statements from the first rule that succeeds, or 'Nothing'
-- if none does.
bottomUpSimplifyStm :: (MonadFreshNames m, HasScope lore m) =>
                       RuleBook lore
                    -> BottomUp lore
                    -> Stm lore
                    -> m (Maybe (Stms lore))
bottomUpSimplifyStm book = applyRules (bookBottomUpRules book)
-- | Pick the bucket of candidate rules for a statement, based on the
-- constructor of its expression.  Expressions without a dedicated
-- bucket are offered every rule.
rulesForStm :: Stm lore -> Rules lore a -> [SimplificationRule lore a]
rulesForStm stm = bucketFor (stmExp stm)
  where bucketFor BasicOp{} = rulesBasicOp
        bucketFor DoLoop{}  = rulesDoLoop
        bucketFor Op{}      = rulesOp
        bucketFor If{}      = rulesIf
        bucketFor _         = rulesAny
-- | Offer a statement to a single rule.
--
-- A generic rule always receives the whole statement.  A kind-specific
-- rule receives the pattern, the auxiliary information, and the
-- deconstructed expression, but only when the expression has the matching
-- constructor.  Otherwise the result is 'Skip'.
applyRule :: SimplificationRule lore a -> a -> Stm lore -> Rule lore
applyRule rule ctx stm =
  case (rule, stm) of
    (RuleGeneric f, _) ->
      f ctx stm
    (RuleBasicOp f, Let pat aux (BasicOp e)) ->
      f ctx pat aux e
    (RuleDoLoop f, Let pat aux (DoLoop loop_ctx loop_val form body)) ->
      f ctx pat aux (loop_ctx, loop_val, form, body)
    (RuleIf f, Let pat aux (If cond tbranch fbranch ifattr)) ->
      f ctx pat aux (cond, tbranch, fbranch, ifattr)
    (RuleOp f, Let pat aux (Op op)) ->
      f ctx pat aux op
    _ ->
      Skip
-- | Try the candidate rules for a statement in order, and return the
-- statements produced by the first rule that succeeds.
--
-- Every attempt starts from the same name source.  The monad's name
-- source is updated only when some rule succeeds; otherwise it is left
-- unchanged and the result is 'Nothing'.
applyRules :: (MonadFreshNames m, HasScope lore m) =>
              Rules lore a -> a -> Stm lore
           -> m (Maybe (Stms lore))
applyRules all_rules context stm = do
  scope <- askScope
  modifyNameSource $ \src ->
    let attempt rule = simplify scope src (applyRule rule context stm)
        -- 'foldr' is lazy in its accumulator, so a later rule is only
        -- tried once every earlier one has produced 'Nothing'.
        firstHit = foldr (\rule rest -> maybe rest Just (attempt rule))
                         Nothing
                         (rulesForStm stm all_rules)
    in maybe (Nothing, src) (\(stms, src') -> (Just stms, src')) firstHit