{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | This module defines the concept of a simplification rule for
-- bindings.  The intent is that you pass some context (such as a symbol
-- table) and a binding, and are given back a sequence of bindings that
-- compute the same result, but are "better" in some sense.
--
-- These rewrite rules are "local", in that they do not maintain any
-- state or look at the program as a whole.  Compare this to the
-- fusion algorithm in @Futhark.Optimise.Fusion.Fusion@, which must be implemented
-- as its own pass.
module Futhark.Optimise.Simplify.Rule
  ( -- * The rule monad
    RuleM,
    cannotSimplify,
    liftMaybe,

    -- * Rule definition
    Rule (..),
    SimplificationRule (..),
    RuleGeneric,
    RuleBasicOp,
    RuleIf,
    RuleDoLoop,

    -- * Top-down rules
    TopDown,
    TopDownRule,
    TopDownRuleGeneric,
    TopDownRuleBasicOp,
    TopDownRuleIf,
    TopDownRuleDoLoop,
    TopDownRuleOp,

    -- * Bottom-up rules
    BottomUp,
    BottomUpRule,
    BottomUpRuleGeneric,
    BottomUpRuleBasicOp,
    BottomUpRuleIf,
    BottomUpRuleDoLoop,
    BottomUpRuleOp,

    -- * Assembling rules
    RuleBook,
    ruleBook,

    -- * Applying rules
    topDownSimplifyStm,
    bottomUpSimplifyStm,
  )
where

import Control.Monad.State
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Builder
import Futhark.IR

-- | The monad in which simplification rules are evaluated.
--
-- It is a 'BuilderT' (for emitting replacement statements) on top of
-- a 'StateT' carrying the 'VNameSource', on top of 'Maybe'.  The
-- 'Maybe' at the bottom is what lets a rule give up; see
-- 'cannotSimplify'.
newtype RuleM rep a = RuleM (BuilderT rep (StateT VNameSource Maybe) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadFreshNames,
      HasScope rep,
      LocalScope rep
    )

-- | Every method simply delegates to the 'MonadBuilder' instance of
-- the wrapped 'BuilderT' and re-wraps the result in 'RuleM'.
instance (ASTRep rep, BuilderOps rep) => MonadBuilder (RuleM rep) where
  type Rep (RuleM rep) = rep
  mkExpDecM pat e = RuleM $ mkExpDecM pat e
  mkBodyM stms res = RuleM $ mkBodyM stms res
  mkLetNamesM names e = RuleM $ mkLetNamesM names e

  addStms = RuleM . addStms
  collectStms (RuleM m) = RuleM $ collectStms m

-- | Execute a 'Rule'.
--
-- 'Skip' yields 'Nothing' without running anything.  For 'Simplify',
-- the wrapped 'RuleM' action is run in the given scope, starting from
-- the given name source.  If it succeeds, the statements it produced
-- are returned together with the updated name source; if it gives up
-- (see 'cannotSimplify'), the result is 'Nothing'.
simplify ::
  Scope rep ->
  VNameSource ->
  Rule rep ->
  Maybe (Stms rep, VNameSource)
simplify _ _ Skip = Nothing
simplify scope src (Simplify (RuleM m)) =
  runStateT (runBuilderT_ m scope) src

-- | Give up on the current rule.  This injects 'Nothing' into the
-- 'Maybe' at the bottom of the 'RuleM' stack, so running the rule
-- produces no result.
cannotSimplify :: RuleM rep a
cannotSimplify = RuleM $ lift $ lift Nothing

-- | Lift a 'Maybe' into 'RuleM': 'Just' yields its value, while
-- 'Nothing' gives up on the rule via 'cannotSimplify'.
liftMaybe :: Maybe a -> RuleM rep a
liftMaybe = maybe cannotSimplify pure

-- | The result of matching a simplification rule against a statement:
-- an efficient way of encoding whether the rule should even be
-- attempted.
data Rule rep
  = -- | Give it a shot: run this action, which may still give up via
    -- 'cannotSimplify'.
    Simplify (RuleM rep ())
  | -- | Don't bother: the rule does not apply.
    Skip

-- | A rule that is handed the context and the whole statement,
-- whatever form its expression has.
type RuleGeneric rep a = a -> Stm rep -> Rule rep

-- | A rule for statements whose expression is a 'BasicOp'.  It is
-- handed the context, the pattern and auxiliary information of the
-- statement, and the operation itself.
type RuleBasicOp rep a =
  a ->
  Pat rep ->
  StmAux (ExpDec rep) ->
  BasicOp ->
  Rule rep

-- | A rule for 'If' statements.  The tuple holds the condition, the
-- two branch bodies, and the 'IfDec' of the expression.
type RuleIf rep a =
  a ->
  Pat rep ->
  StmAux (ExpDec rep) ->
  ( SubExp,
    BodyT rep,
    BodyT rep,
    IfDec (BranchType rep)
  ) ->
  Rule rep

-- | A rule for 'DoLoop' statements.  The tuple holds the merge
-- parameters with their initial values, the loop form, and the loop
-- body.
type RuleDoLoop rep a =
  a ->
  Pat rep ->
  StmAux (ExpDec rep) ->
  ( [(FParam rep, SubExp)],
    LoopForm rep,
    BodyT rep
  ) ->
  Rule rep

-- | A rule for statements whose expression is a
-- representation-specific 'Op'.
type RuleOp rep a =
  a ->
  Pat rep ->
  StmAux (ExpDec rep) ->
  Op rep ->
  Rule rep

-- | A simplification rule takes some argument and a statement, and
-- tries to simplify the statement.  The constructor records which
-- form of statement the rule is written for; applying a rule to a
-- statement of any other form results in 'Skip'.
data SimplificationRule rep a
  = RuleGeneric (RuleGeneric rep a)
  -- ^ Applicable to any statement.
  | RuleBasicOp (RuleBasicOp rep a)
  -- ^ Applicable only to 'BasicOp' statements.
  | RuleIf (RuleIf rep a)
  -- ^ Applicable only to 'If' statements.
  | RuleDoLoop (RuleDoLoop rep a)
  -- ^ Applicable only to 'DoLoop' statements.
  | RuleOp (RuleOp rep a)
  -- ^ Applicable only to 'Op' statements.

-- | A collection of rules grouped by which forms of statements they
-- may apply to.
data Rules rep a = Rules
  { -- | Rules to try on statements not covered by the more specific
    -- fields below.
    rulesAny :: [SimplificationRule rep a],
    -- | Rules to try on 'BasicOp' statements.
    rulesBasicOp :: [SimplificationRule rep a],
    -- | Rules to try on 'If' statements.
    rulesIf :: [SimplificationRule rep a],
    -- | Rules to try on 'DoLoop' statements.
    rulesDoLoop :: [SimplificationRule rep a],
    -- | Rules to try on 'Op' statements.
    rulesOp :: [SimplificationRule rep a]
  }

-- | Field-wise concatenation; the rules of the left operand come
-- first in every group.
instance Semigroup (Rules rep a) where
  Rules as1 bs1 cs1 ds1 es1 <> Rules as2 bs2 cs2 ds2 es2 =
    Rules (as1 <> as2) (bs1 <> bs2) (cs1 <> cs2) (ds1 <> ds2) (es1 <> es2)

-- | The empty collection: no rules in any group.
instance Monoid (Rules rep a) where
  mempty = Rules mempty mempty mempty mempty mempty

-- | Context for a rule applied during top-down traversal of the
-- program.  Takes a symbol table as argument.
type TopDown rep = ST.SymbolTable rep

-- | A 'RuleGeneric' whose context is a 'TopDown'.
type TopDownRuleGeneric rep = RuleGeneric rep (TopDown rep)

-- | A 'RuleBasicOp' whose context is a 'TopDown'.
type TopDownRuleBasicOp rep = RuleBasicOp rep (TopDown rep)

-- | A 'RuleIf' whose context is a 'TopDown'.
type TopDownRuleIf rep = RuleIf rep (TopDown rep)

-- | A 'RuleDoLoop' whose context is a 'TopDown'.
type TopDownRuleDoLoop rep = RuleDoLoop rep (TopDown rep)

-- | A 'RuleOp' whose context is a 'TopDown'.
type TopDownRuleOp rep = RuleOp rep (TopDown rep)

-- | A 'SimplificationRule' whose context is a 'TopDown'.
type TopDownRule rep = SimplificationRule rep (TopDown rep)

-- | Context for a rule applied during bottom-up traversal of the
-- program.  Takes a symbol table and usage table as arguments.
type BottomUp rep = (ST.SymbolTable rep, UT.UsageTable)

-- | A 'RuleGeneric' whose context is a 'BottomUp'.
type BottomUpRuleGeneric rep = RuleGeneric rep (BottomUp rep)

-- | A 'RuleBasicOp' whose context is a 'BottomUp'.
type BottomUpRuleBasicOp rep = RuleBasicOp rep (BottomUp rep)

-- | A 'RuleIf' whose context is a 'BottomUp'.
type BottomUpRuleIf rep = RuleIf rep (BottomUp rep)

-- | A 'RuleDoLoop' whose context is a 'BottomUp'.
type BottomUpRuleDoLoop rep = RuleDoLoop rep (BottomUp rep)

-- | A 'RuleOp' whose context is a 'BottomUp'.
type BottomUpRuleOp rep = RuleOp rep (BottomUp rep)

-- | A 'SimplificationRule' whose context is a 'BottomUp'.
type BottomUpRule rep = SimplificationRule rep (BottomUp rep)

-- | A collection of top-down rules, grouped by statement form (see
-- 'Rules').
type TopDownRules rep = Rules rep (TopDown rep)

-- | A collection of bottom-up rules, grouped by statement form (see
-- 'Rules').
type BottomUpRules rep = Rules rep (BottomUp rep)

-- | A collection of both top-down and bottom-up rules.
data RuleBook rep = RuleBook
  { -- | The rules applied during top-down traversal.
    bookTopDownRules :: TopDownRules rep,
    -- | The rules applied during bottom-up traversal.
    bookBottomUpRules :: BottomUpRules rep
  }

-- | Combines the top-down rules and the bottom-up rules separately,
-- with the left operand's rules first.
instance Semigroup (RuleBook rep) where
  RuleBook ts1 bs1 <> RuleBook ts2 bs2 =
    RuleBook (ts1 <> ts2) (bs1 <> bs2)

-- | The empty rule book: no top-down and no bottom-up rules.
instance Monoid (RuleBook rep) where
  mempty = RuleBook mempty mempty

-- | Construct a rule book from a collection of rules.
--
-- Each list is grouped by the statement form its rules apply to.  A
-- 'RuleGeneric' is placed in every group, and the complete list is
-- kept as the fallback group.  The relative order of the rules is
-- preserved within each group.
ruleBook ::
  [TopDownRule m] ->
  [BottomUpRule m] ->
  RuleBook m
ruleBook topdowns bottomups =
  RuleBook (groupRules topdowns) (groupRules bottomups)
  where
    groupRules :: [SimplificationRule m a] -> Rules m a
    groupRules rs =
      Rules
        rs
        (filter forBasicOp rs)
        (filter forIf rs)
        (filter forDoLoop rs)
        (filter forOp rs)

    -- Each predicate accepts the rules written for one statement
    -- form, plus the generic rules.
    forBasicOp :: SimplificationRule rep a -> Bool
    forBasicOp RuleBasicOp {} = True
    forBasicOp RuleGeneric {} = True
    forBasicOp _ = False

    forIf :: SimplificationRule rep a -> Bool
    forIf RuleIf {} = True
    forIf RuleGeneric {} = True
    forIf _ = False

    forDoLoop :: SimplificationRule rep a -> Bool
    forDoLoop RuleDoLoop {} = True
    forDoLoop RuleGeneric {} = True
    forDoLoop _ = False

    forOp :: SimplificationRule rep a -> Bool
    forOp RuleOp {} = True
    forOp RuleGeneric {} = True
    forOp _ = False

-- | @topDownSimplifyStm book vtable stm@ tries the top-down rules of
-- @book@ on the binding @stm@, passing the symbol table @vtable@ to
-- each rule as its context.  If simplification is possible, a
-- replacement list of bindings is returned, which is expected to bind
-- at least the same names as the original binding (and possibly more,
-- for intermediate results).  If no rule applies, the result is
-- 'Nothing'.
topDownSimplifyStm ::
  (MonadFreshNames m, HasScope rep m) =>
  RuleBook rep ->
  ST.SymbolTable rep ->
  Stm rep ->
  m (Maybe (Stms rep))
topDownSimplifyStm = applyRules . bookTopDownRules

-- | @bottomUpSimplifyStm book (vtable, uses) stm@ tries the bottom-up
-- rules of @book@ on the binding @stm@.  The second argument, a pair
-- of a symbol table and a usage table, is passed to each rule as its
-- context.  If simplification is possible, a replacement list of
-- bindings is returned, which is expected to bind at least the same
-- names as the original binding (and possibly more, for intermediate
-- results).  If no rule applies, the result is 'Nothing'.
bottomUpSimplifyStm ::
  (MonadFreshNames m, HasScope rep m) =>
  RuleBook rep ->
  (ST.SymbolTable rep, UT.UsageTable) ->
  Stm rep ->
  m (Maybe (Stms rep))
bottomUpSimplifyStm = applyRules . bookBottomUpRules

-- | Select the group of rules worth trying on a statement, based on
-- the form of its expression.  Expressions that are none of
-- 'BasicOp', 'DoLoop', 'Op' or 'If' fall back to 'rulesAny'.
rulesForStm :: Stm rep -> Rules rep a -> [SimplificationRule rep a]
rulesForStm stm = case stmExp stm of
  BasicOp {} -> rulesBasicOp
  DoLoop {} -> rulesDoLoop
  Op {} -> rulesOp
  If {} -> rulesIf
  _ -> rulesAny

-- | Instantiate a rule with a context and a statement.  A generic
-- rule receives the statement as-is.  A form-specific rule receives
-- the statement's pattern, auxiliary information and the components
-- of its expression, provided the expression has the matching form.
-- Any other combination of rule and statement is a 'Skip'.
applyRule :: SimplificationRule rep a -> a -> Stm rep -> Rule rep
applyRule (RuleGeneric f) ctx stm =
  f ctx stm
applyRule (RuleBasicOp f) ctx (Let pat aux (BasicOp e)) =
  f ctx pat aux e
applyRule (RuleDoLoop f) ctx (Let pat aux (DoLoop merge form body)) =
  f ctx pat aux (merge, form, body)
applyRule (RuleIf f) ctx (Let pat aux (If cond tbody fbody ifsort)) =
  f ctx pat aux (cond, tbody, fbody, ifsort)
applyRule (RuleOp f) ctx (Let pat aux (Op op)) =
  f ctx pat aux op
applyRule _ _ _ =
  Skip

-- | Try the rules relevant to a statement (as chosen by
-- 'rulesForStm') in order, and return the statements produced by the
-- first rule that succeeds, or 'Nothing' if none does.
--
-- Every attempt starts from the same name source.  The name source of
-- the enclosing monad is advanced only by the successful rule, and is
-- left untouched when no rule applies.
applyRules ::
  (MonadFreshNames m, HasScope rep m) =>
  Rules rep a ->
  a ->
  Stm rep ->
  m (Maybe (Stms rep))
applyRules all_rules context stm = do
  scope <- askScope

  modifyNameSource $ \src ->
    let -- Run one rule; fall through to the remaining ones only if
        -- it fails.  'foldr' is lazy in @rest@, so rules after the
        -- first success are never run.
        tryRule rule rest =
          case simplify scope src (applyRule rule context stm) of
            Nothing -> rest
            found -> found
     in case foldr tryRule Nothing (rulesForStm stm all_rules) of
          Just (stms, src') -> (Just stms, src')
          Nothing -> (Nothing, src)