{-# LANGUAGE Strict #-}

module Futhark.Optimise.Simplify
  ( simplifyProg,
    simplifySomething,
    simplifyFun,
    simplifyLambda,
    simplifyStms,
    Engine.SimpleOps (..),
    Engine.SimpleM,
    Engine.SimplifyOp,
    Engine.bindableSimpleOps,
    Engine.noExtraHoistBlockers,
    Engine.neverHoist,
    Engine.SimplifiableRep,
    Engine.HoistBlockers,
    RuleBook,
  )
where

import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Pass

-- | Simplify the given program.  Even if the output differs from the
-- input, meaningful simplification may not have taken place - the
-- order of bindings may simply have been rearranged.
--
-- The top-level constants are simplified twice: first with usage
-- information taken from the free variables of the original
-- functions, and then again with usage information taken from the
-- free variables of the simplified functions.  The functions
-- themselves are simplified (via 'parPass') with the symbol table of
-- the simplified constants in scope.
simplifyProg ::
  Engine.SimplifiableRep rep =>
  Engine.SimpleOps rep ->
  RuleBook (Engine.Wise rep) ->
  Engine.HoistBlockers rep ->
  Prog rep ->
  PassM (Prog rep)
simplifyProg simpl rules blockers prog = do
  let consts = progConsts prog
      funs = progFuns prog

  -- First pass over the constants: their usages are everything the
  -- (still unsimplified) functions refer to.
  (consts_vtable, consts') <-
    simplifyConsts
      (UT.usages $ foldMap freeIn funs)
      (mempty, informStms consts)

  -- We deepen the vtable so it will look like the constants are in an
  -- "outer loop"; this communicates useful information to some
  -- simplification rules (e.g. see issue #1302).
  funs' <- parPass (simplifyFun' (ST.deepen consts_vtable) . informFunDef) funs

  -- Simplifying the functions may have changed which constants they
  -- refer to, so the constants are simplified once more against the
  -- usages of the simplified functions.
  let funs_uses = UT.usages $ foldMap freeIn funs'
  (_, consts'') <- simplifyConsts funs_uses (mempty, consts')

  pure $
    prog
      { progConsts = fmap removeStmWisdom consts'',
        progFuns = fmap removeFunDefWisdom funs'
      }
  where
    -- Simplify one function to convergence, with the given constants'
    -- symbol table combined into the engine's symbol table.
    simplifyFun' consts_vtable =
      simplifySomething
        (Engine.localVtable (consts_vtable <>) . Engine.simplifyFun)
        id
        simpl
        rules
        blockers
        mempty

    -- Simplify the constants to convergence.  The symbol-table half of
    -- the incoming pair is discarded ('snd'); 'onConsts' rebuilds it
    -- from the simplified statements every round.
    simplifyConsts uses =
      simplifySomething
        (onConsts uses . snd)
        id
        simpl
        rules
        blockers
        mempty

    -- One round of constant simplification: simplify the statements
    -- against the given usage table and build a symbol table holding
    -- exactly the resulting statements.
    onConsts uses stms = do
      stms' <- Engine.simplifyStmsWithUsage uses stms
      pure (ST.insertStms stms' mempty, stms')

-- | Run a simplification operation to convergence.
--
-- Each round runs @f@ on the current value, with @vtable@ combined
-- (via '<>') into the engine's symbol table.  The result is converted
-- back to the input type with @g@, and rounds repeat until the engine
-- reports that nothing changed.  The engine environment is built from
-- @rules@ and @blockers@ with 'Engine.emptyEnv'.
simplifySomething ::
  (MonadFreshNames m, Engine.SimplifiableRep rep) =>
  (a -> Engine.SimpleM rep b) ->
  (b -> a) ->
  Engine.SimpleOps rep ->
  RuleBook (Wise rep) ->
  Engine.HoistBlockers rep ->
  ST.SymbolTable (Wise rep) ->
  a ->
  m a
simplifySomething f g simpl rules blockers vtable x =
  loopUntilConvergence env simpl withVtable g x
  where
    env = Engine.emptyEnv rules blockers
    -- Make the caller-supplied symbol table visible to every round.
    withVtable = Engine.localVtable (vtable <>) . f

-- | Simplify the given function.  Even if the output differs from the
-- input, meaningful simplification may not have taken place - the
-- order of bindings may simply have been rearranged.  Runs in a loop
-- until convergence.
--
-- The given symbol table is combined into the simplifier's symbol
-- table for every round (see 'simplifySomething').
simplifyFun ::
  (MonadFreshNames m, Engine.SimplifiableRep rep) =>
  Engine.SimpleOps rep ->
  RuleBook (Engine.Wise rep) ->
  Engine.HoistBlockers rep ->
  ST.SymbolTable (Wise rep) ->
  FunDef rep ->
  m (FunDef rep)
simplifyFun simpl rules blockers vtable fd =
  removeFunDefWisdom
    <$> simplifySomething
      Engine.simplifyFun
      id
      simpl
      rules
      blockers
      vtable
      (informFunDef fd)

-- | Simplify just a single t'Lambda'.
--
-- The symbol table used during simplification is built from the scope
-- provided by the surrounding monad ('askScope').  The lambda itself is
-- simplified with 'Engine.simplifyLambdaNoHoisting', repeated until
-- convergence.
simplifyLambda ::
  ( MonadFreshNames m,
    HasScope rep m,
    Engine.SimplifiableRep rep
  ) =>
  Engine.SimpleOps rep ->
  RuleBook (Engine.Wise rep) ->
  Engine.HoistBlockers rep ->
  Lambda rep ->
  m (Lambda rep)
simplifyLambda simpl rules blockers orig_lam = do
  scope <- askScope
  let vtable = ST.fromScope $ addScopeWisdom scope
  removeLambdaWisdom
    <$> simplifySomething
      Engine.simplifyLambdaNoHoisting
      id
      simpl
      rules
      blockers
      vtable
      (informLambda orig_lam)

-- | Simplify a sequence of 'Stm's, repeated until convergence.
--
-- The given 'Scope' is turned into the symbol table used during
-- simplification.
simplifyStms ::
  (MonadFreshNames m, Engine.SimplifiableRep rep) =>
  Engine.SimpleOps rep ->
  RuleBook (Engine.Wise rep) ->
  Engine.HoistBlockers rep ->
  Scope rep ->
  Stms rep ->
  m (Stms rep)
simplifyStms simpl rules blockers scope stms =
  fmap removeStmWisdom
    <$> simplifySomething
      Engine.simplifyStms
      id
      simpl
      rules
      blockers
      vtable
      (informStms stms)
  where
    vtable = ST.fromScope $ addScopeWisdom scope

-- | Repeatedly run the simplification action @f@, converting each
-- result back to the input type with @g@, until the engine reports that
-- the last round made no change.  The final converted result is
-- returned.
--
-- Every round uses the same environment and 'Engine.SimpleOps', and
-- draws fresh names from the surrounding monad via 'modifyNameSource'.
--
-- NOTE: there is no bound on the number of rounds; termination relies
-- on the simplification rules eventually reaching a fixed point.
loopUntilConvergence ::
  (MonadFreshNames m, Engine.SimplifiableRep rep) =>
  Engine.Env rep ->
  Engine.SimpleOps rep ->
  (a -> Engine.SimpleM rep b) ->
  (b -> a) ->
  a ->
  m a
loopUntilConvergence env simpl f g = go
  where
    -- The environment, ops and conversion functions are fixed across
    -- rounds, so only the current value is threaded through 'go'.
    go x = do
      (x', changed) <- modifyNameSource $ Engine.runSimpleM (f x) simpl env
      if changed then go (g x') else pure $ g x'