{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
--
-- Perform general rule-based simplification based on data dependency
-- information.  This module will:
--
--    * Perform common-subexpression elimination (CSE).
--
--    * Hoist expressions out of loops (including lambdas) and
--    branches.  This is done as aggressively as possible.
--
--    * Apply simplification rules (see
--    "Futhark.Optimise.Simplify.Rules").
--
-- If you just want to run the simplifier as simply as possible, you
-- may prefer to use the "Futhark.Optimise.Simplify" module.
module Futhark.Optimise.Simplify.Engine
  ( -- * Monadic interface
    SimpleM,
    runSimpleM,
    SimpleOps (..),
    SimplifyOp,
    bindableSimpleOps,
    Env (envHoistBlockers, envRules),
    emptyEnv,
    HoistBlockers (..),
    neverBlocks,
    noExtraHoistBlockers,
    neverHoist,
    BlockPred,
    orIf,
    hasFree,
    isConsumed,
    isFalse,
    isOp,
    isNotSafe,
    asksEngineEnv,
    askVtable,
    localVtable,

    -- * Building blocks
    SimplifiableRep,
    Simplifiable (..),
    simplifyStms,
    simplifyFun,
    simplifyLambda,
    simplifyLambdaNoHoisting,
    bindLParams,
    simplifyBody,
    SimplifiedBody,
    ST.SymbolTable,
    hoistStms,
    blockIf,
    enterLoop,
    module Futhark.Optimise.Simplify.Rep,
  )
where

import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Either
import Data.List (find, foldl', mapAccumL)
import Data.Maybe
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Util (nubOrd)

-- | Predicates, one per kind of enclosing construct, that decide whether a
-- statement is blocked from being hoisted out of that construct.
data HoistBlockers rep = HoistBlockers
  { -- | Blocker for hoisting out of parallel loops.
    blockHoistPar :: BlockPred (Wise rep),
    -- | Blocker for hoisting out of sequential loops.
    blockHoistSeq :: BlockPred (Wise rep),
    -- | Blocker for hoisting out of branches.
    blockHoistBranch :: BlockPred (Wise rep),
    -- | Predicate identifying statements to be treated as allocations.
    -- Both 'noExtraHoistBlockers' and 'neverHoist' set this to
    -- @const False@.
    isAllocation :: Stm (Wise rep) -> Bool
  }

-- | Hoist blockers that impose no restrictions of their own: every
-- hoisting predicate is 'neverBlocks', and no statement is classified
-- as an allocation.
noExtraHoistBlockers :: HoistBlockers rep
noExtraHoistBlockers =
  HoistBlockers
    { blockHoistPar = neverBlocks,
      blockHoistSeq = neverBlocks,
      blockHoistBranch = neverBlocks,
      isAllocation = const False
    }

-- | Hoist blockers that block everything: every hoisting predicate is
-- 'alwaysBlocks', and no statement is classified as an allocation.
neverHoist :: HoistBlockers rep
neverHoist =
  HoistBlockers
    { blockHoistPar = alwaysBlocks,
      blockHoistSeq = alwaysBlocks,
      blockHoistBranch = alwaysBlocks,
      isAllocation = const False
    }

-- | The engine environment: the 'Env' half of the reader context of
-- 'SimpleM'.
data Env rep = Env
  { -- | The simplification rules to apply.
    envRules :: RuleBook (Wise rep),
    -- | Predicates restricting which statements may be hoisted.
    envHoistBlockers :: HoistBlockers rep,
    -- | Symbol table of the bindings currently in scope; extended
    -- locally with 'localVtable'.
    envVtable :: ST.SymbolTable (Wise rep)
  }

-- | Build an initial environment from a rule book and a set of hoist
-- blockers.  The symbol table starts out empty ('mempty').
emptyEnv :: RuleBook (Wise rep) -> HoistBlockers rep -> Env rep
emptyEnv rules blockers =
  Env
    { envRules = rules,
      envHoistBlockers = blockers,
      envVtable = mempty
    }

-- | The shape of a function that makes a hoisted 'Op' safe.  It receives
-- a boolean 'SubExp', the pattern bound by the statement, and the
-- operation itself, and may produce an action in @m@.  A 'Nothing'
-- result presumably means no protected form is available; verify
-- against the callers.
type Protect m =
  SubExp ->
  Pat (Rep m) ->
  Op (Rep m) ->
  Maybe (m ())

-- | The representation-specific operations the simplifier engine is
-- parameterised over.  See 'bindableSimpleOps' for a default.
data SimpleOps rep = SimpleOps
  { -- | Construct the expression decoration for a statement with the
    -- given pattern and expression.
    mkExpDecS ::
      ST.SymbolTable (Wise rep) ->
      Pat (Wise rep) ->
      Exp (Wise rep) ->
      SimpleM rep (ExpDec (Wise rep)),
    -- | Construct a body from statements and a result.
    mkBodyS ::
      ST.SymbolTable (Wise rep) ->
      Stms (Wise rep) ->
      Result ->
      SimpleM rep (Body (Wise rep)),
    -- | Make a hoisted Op safe.  The SubExp is a boolean
    -- that is true when the value of the statement will
    -- actually be used.
    protectHoistedOpS :: Protect (Builder (Wise rep)),
    -- | Compute the usage table contributed by an 'Op'.
    opUsageS :: Op (Wise rep) -> UT.UsageTable,
    -- | Simplify an 'Op', returning the simplified operation together
    -- with any statements produced alongside it.
    simplifyOpS :: SimplifyOp rep (Op rep)
  }

-- | A simplifier for operations of type @op@.  It yields the
-- wisdom-annotated operation and a collection of accompanying statements.
type SimplifyOp rep op =
  op ->
  SimpleM rep (OpWithWisdom op, Stms (Wise rep))

-- | Default 'SimpleOps' for a 'Buildable' representation.  Only the 'Op'
-- simplifier has to be supplied; the other operations are fixed:
--
--   * expression decorations and bodies are built with 'mkExpDec' and
--     'mkBody', ignoring the symbol table argument;
--
--   * hoisted operations are never protected (always 'Nothing');
--
--   * operations contribute an empty usage table.
bindableSimpleOps ::
  (SimplifiableRep rep, Buildable rep) =>
  SimplifyOp rep (Op rep) ->
  SimpleOps rep
bindableSimpleOps simplifyOp =
  SimpleOps
    { mkExpDecS = \_ pat e -> pure $ mkExpDec pat e,
      mkBodyS = \_ stms res -> pure $ mkBody stms res,
      protectHoistedOpS = \_ _ _ -> Nothing,
      opUsageS = const mempty,
      simplifyOpS = simplifyOp
    }

-- | The simplifier monad.  It reads the representation-specific
-- 'SimpleOps' together with the engine 'Env', and threads a state
-- consisting of:
--
--   * the name source used for fresh names (see the 'MonadFreshNames'
--     instance);
--
--   * a flag recording whether anything has changed (set by 'changed',
--     reported by 'runSimpleM');
--
--   * the certificates used so far (added by 'usedCerts', extracted by
--     'collectCerts').
newtype SimpleM rep a
  = SimpleM
      ( ReaderT
          (SimpleOps rep, Env rep)
          (State (VNameSource, Bool, Certs))
          a
      )
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (SimpleOps rep, Env rep),
      MonadState (VNameSource, Bool, Certs)
    )

-- | The name source is the first component of the simplifier state; the
-- other components are left untouched.
instance MonadFreshNames (SimpleM rep) where
  putNameSource src = modify $ \(_, b, c) -> (src, b, c)
  getNameSource = gets $ \(a, _, _) -> a

-- | The scope is derived from the symbol table in the engine environment.
-- Looking up a variable that is absent from the symbol table calls
-- 'error'.
instance SimplifiableRep rep => HasScope (Wise rep) (SimpleM rep) where
  askScope = ST.toScope <$> askVtable
  lookupType name = do
    vtable <- askVtable
    case ST.lookupType name vtable of
      Just t -> return t
      Nothing ->
        error $
          "SimpleM.lookupType: cannot find variable "
            ++ pretty name
            ++ " in symbol table."

-- | Extending the scope combines the symbol table (with '<>') with one
-- built from the given scope, for the duration of the action only.
instance
  SimplifiableRep rep =>
  LocalScope (Wise rep) (SimpleM rep)
  where
  localScope types = localVtable (<> ST.fromScope types)

-- | Run a simplifier computation with the given operations, environment
-- and name source.  The state starts with the "changed" flag cleared and
-- no certificates.  Returns the result paired with the final "changed"
-- flag, plus the updated name source; any accumulated certificates are
-- discarded.
runSimpleM ::
  SimpleM rep a ->
  SimpleOps rep ->
  Env rep ->
  VNameSource ->
  ((a, Bool), VNameSource)
runSimpleM (SimpleM m) simpl env src =
  let initial = (src, False, mempty)
      (x, (src', b, _)) = runState (runReaderT m (simpl, env)) initial
   in ((x, b), src')

-- | Retrieve the engine environment (the second half of the reader
-- context).
askEngineEnv :: SimpleM rep (Env rep)
askEngineEnv = asks snd

-- | Retrieve a projection of the engine environment.
asksEngineEnv :: (Env rep -> a) -> SimpleM rep a
asksEngineEnv f = f <$> askEngineEnv

-- | Retrieve the current symbol table.
askVtable :: SimpleM rep (ST.SymbolTable (Wise rep))
askVtable = asksEngineEnv envVtable

-- | Run an action with the symbol table transformed by the given
-- function.  The change is scoped to that action; the 'SimpleOps' and
-- other environment fields are unchanged.
localVtable ::
  (ST.SymbolTable (Wise rep) -> ST.SymbolTable (Wise rep)) ->
  SimpleM rep a ->
  SimpleM rep a
localVtable f =
  local $ \(ops, env) -> (ops, env {envVtable = f $ envVtable env})

-- | Run an action, then return its result together with the certificates
-- held in the state, resetting them to 'mempty'.  Note that this
-- includes any certificates recorded before the action ran that had not
-- yet been collected.
collectCerts :: SimpleM rep a -> SimpleM rep (a, Certs)
collectCerts m = do
  x <- m
  (a, b, cs) <- get
  put (a, b, mempty)
  return (x, cs)

-- | Mark that we have changed something and it would be a good idea
-- to re-run the simplifier.  The flag is reported by 'runSimpleM'.
changed :: SimpleM rep ()
changed = modify $ \(src, _, cs) -> (src, True, cs)

-- | Record that the given certificates were used.  They are combined,
-- with '<>' and placed first, with those already in the state.
usedCerts :: Certs -> SimpleM rep ()
usedCerts cs = modify $ \(a, b, c) -> (a, b, cs <> c)

-- | Indicate in the symbol table that we have descended into a loop
-- (by applying 'ST.deepen' for the duration of the action).
enterLoop :: SimpleM rep a -> SimpleM rep a
enterLoop = localVtable ST.deepen

-- | Run an action with the given function parameters added to the
-- symbol table.
bindFParams :: SimplifiableRep rep => [FParam (Wise rep)] -> SimpleM rep a -> SimpleM rep a
bindFParams params =
  localVtable $ ST.insertFParams params

-- | Run an action with the given lambda parameters added to the symbol
-- table.  Parameters are inserted with 'foldr', i.e. the last parameter
-- is inserted first.
bindLParams :: SimplifiableRep rep => [LParam (Wise rep)] -> SimpleM rep a -> SimpleM rep a
bindLParams params =
  localVtable $ \vtable -> foldr ST.insertLParam vtable params

-- | Run an action with the given lambda parameters added to the symbol
-- table.  Unlike 'bindLParams', parameters are inserted left to right
-- with a strict left fold.
bindArrayLParams ::
  SimplifiableRep rep =>
  [LParam (Wise rep)] ->
  SimpleM rep a ->
  SimpleM rep a
bindArrayLParams params =
  localVtable $ \vtable -> foldl' (flip ST.insertLParam) vtable params

-- | Run an action with the given loop merge parameters (each paired with
-- an initial 'SubExp' and a 'SubExpRes') added to the symbol table via
-- 'ST.insertLoopMerge'.
bindMerge ::
  SimplifiableRep rep =>
  [(FParam (Wise rep), SubExp, SubExpRes)] ->
  SimpleM rep a ->
  SimpleM rep a
bindMerge = localVtable . ST.insertLoopMerge

-- | Run an action with a loop variable of the given integer type and
-- bound added to the symbol table via 'ST.insertLoopVar'.
bindLoopVar :: SimplifiableRep rep => VName -> IntType -> SubExp -> SimpleM rep a -> SimpleM rep a
bindLoopVar var it bound =
  localVtable $ ST.insertLoopVar var it bound

-- | We are willing to hoist potentially unsafe statements out of
-- branches, but they must be protected by adding a branch on top of
-- them.  (This means such hoisting is not worth it unless they are in
-- turn hoisted out of a loop somewhere.)
--
-- If every statement produced by the action is safe, the statements are
-- emitted unchanged.  Otherwise each statement is passed to 'protectIf',
-- guarded by the branch condition (negated when protecting the false
-- side).  The predicate given to 'protectIf' selects expressions that are
-- unsafe or not cheap.
protectIfHoisted ::
  SimplifiableRep rep =>
  -- | Branch condition.
  SubExp ->
  -- | Which side of the branch are we
  -- protecting here?
  Bool ->
  SimpleM rep (a, Stms (Wise rep)) ->
  SimpleM rep (a, Stms (Wise rep))
protectIfHoisted cond side m = do
  (x, stms) <- m
  protector <- asks (protectHoistedOpS . fst)
  runBuilder $ do
    if all (safeExp . stmExp) stms
      then addStms stms
      else do
        -- The true side is guarded by the condition itself, the false
        -- side by its negation.
        guard_cond <-
          if side
            then pure cond
            else letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
        mapM_ (protectIf protector needsGuard guard_cond) stms
    pure x
  where
    needsGuard e = not (safeExp e && cheapExp e)

-- | We are willing to hoist potentially unsafe statements out of
-- loops, but they must be protected by adding a branch on top of
-- them.
--
-- If every statement produced by the action is safe, the statements are
-- emitted unchanged.  Otherwise each unsafe statement is guarded (via
-- 'protectIf') by a boolean indicating that the loop is non-empty:
--
-- * for a @for@ loop, @0 < bound@;
--
-- * for a @while@ loop, the initial value paired with the merge parameter
--   named by the loop condition, or the constant @true@ if no such
--   parameter exists.
protectLoopHoisted ::
  SimplifiableRep rep =>
  [(FParam (Wise rep), SubExp)] ->
  LoopForm (Wise rep) ->
  SimpleM rep (a, Stms (Wise rep)) ->
  SimpleM rep (a, Stms (Wise rep))
protectLoopHoisted merge form m = do
  (x, stms) <- m
  protector <- asks (protectHoistedOpS . fst)
  runBuilder $ do
    if all (safeExp . stmExp) stms
      then addStms stms
      else do
        nonempty <- loopRunsAtLeastOnce
        mapM_ (protectIf protector (not . safeExp) nonempty) stms
    pure x
  where
    loopRunsAtLeastOnce =
      case form of
        ForLoop _ it bound _ ->
          letSubExp "loop_nonempty" . BasicOp $
            CmpOp (CmpSlt it) (intConst it 0) bound
        WhileLoop cond ->
          -- If the condition is not a merge parameter, the original code
          -- treats this as an infinite loop and uses 'true'.
          pure . maybe (constant True) snd $
            find ((== cond) . paramName . fst) merge

-- | Emit a hoisted statement so that it only takes effect when the
-- boolean @taken@ holds.
--
-- * A fallback @if@ has its condition conjoined with @taken@.
--
-- * An assertion has its condition weakened to @not taken || cond@.
--
-- * An 'Op' is offered to the representation-specific protector first.
--   If that returns 'Nothing', the 'Op' is handled like any other
--   expression.
--
-- * Any other expression satisfying the predicate is replaced by its
--   'makeSafe' variant if one exists.  Otherwise it is wrapped in a
--   fallback @if@ on @taken@, whose untaken branch produces placeholder
--   values built with 'emptyOfType'.
--
-- * All remaining statements are emitted unchanged.
protectIf ::
  MonadBuilder m =>
  Protect m ->
  (Exp (Rep m) -> Bool) ->
  SubExp ->
  Stm (Rep m) ->
  m ()
protectIf protect needsProtection taken stm@(Let pat aux e) =
  case e of
    If cond tbranch fbranch (IfDec rets IfFallback) -> do
      conj <- letSubExp "protect_cond_conj" $ BasicOp $ BinOp LogAnd taken cond
      bindWithAux $ fallbackIf conj tbranch fbranch rets
    BasicOp (Assert cond msg loc) -> do
      skipped <- letSubExp "loop_not_taken" $ BasicOp $ UnOp Not taken
      disj <- letSubExp "protect_assert_disj" $ BasicOp $ BinOp LogOr skipped cond
      bindWithAux $ BasicOp $ Assert disj msg loc
    Op op
      | Just guarded <- protect taken pat op ->
          auxing aux guarded
    -- A failed guard above falls through to this alternative.
    _
      | needsProtection e ->
          case makeSafe e of
            Just safe_e -> bindWithAux safe_e
            Nothing -> do
              tbranch <- eBody [pure e]
              fbranch <- eBody $ map (emptyOfType $ patNames pat) (patTypes pat)
              rets <- expTypesFromPat pat
              bindWithAux $ fallbackIf taken tbranch fbranch rets
      | otherwise ->
          addStm stm
  where
    bindWithAux = auxing aux . letBind pat
    fallbackIf c tb fb rets = If c tb fb $ IfDec rets IfFallback

-- | For an integer division, quotient, modulo or remainder binary
-- operation, return the same expression with its 'Safety' annotation set
-- to 'Safe'.  Returns 'Nothing' for every other expression.
makeSafe :: Exp rep -> Maybe (Exp rep)
makeSafe (BasicOp (BinOp op x y)) =
  (\op' -> BasicOp (BinOp op' x y)) <$> safeVariant op
  where
    safeVariant (SDiv t _) = Just $ SDiv t Safe
    safeVariant (SDivUp t _) = Just $ SDivUp t Safe
    safeVariant (SQuot t _) = Just $ SQuot t Safe
    safeVariant (UDiv t _) = Just $ UDiv t Safe
    safeVariant (UDivUp t _) = Just $ UDivUp t Safe
    safeVariant (SMod t _) = Just $ SMod t Safe
    safeVariant (SRem t _) = Just $ SRem t Safe
    safeVariant (UMod t _) = Just $ UMod t Safe
    safeVariant _ = Nothing
makeSafe _ = Nothing

-- | Produce an expression computing a placeholder value of the given type.
-- 'protectIf' uses it for the untaken branch of a guarding @if@.
--
-- Primitive types yield a blank constant ('blankPrimValue').  Array types
-- yield a 'Scratch' array, in which any dimension that is one of the given
-- context names is replaced by zero.  Memory and accumulator types are not
-- supported and result in an 'error'.
emptyOfType :: MonadBuilder m => [VName] -> Type -> m (Exp (Rep m))
emptyOfType _ Mem {} =
  error "emptyOfType: Cannot hoist non-existential memory."
emptyOfType _ Acc {} =
  error "emptyOfType: Cannot hoist accumulator."
emptyOfType _ (Prim pt) =
  return $ BasicOp $ SubExp $ Constant $ blankPrimValue pt
emptyOfType ctx_names (Array et shape _) = do
  let dims = map zeroIfContext $ shapeDims shape
  return $ BasicOp $ Scratch et dims
  where
    -- Array dimensions are 64-bit sizes in this IR, so the replacement
    -- zero must be an i64.  An i32 zero would make the Scratch ill-typed.
    zeroIfContext (Var v) | v `elem` ctx_names = intConst Int64 0
    zeroIfContext se = se

-- | Statements that are not worth hoisting out of loops, because they
-- are unsafe, and added safety (by 'protectLoopHoisted') may inhibit
-- further optimisation.  Concretely, this blocks statements whose
-- expression is not 'safeExp' and whose pattern binds at least one array.
notWorthHoisting :: ASTRep rep => BlockPred rep
notWorthHoisting _ _ (Let pat _ e)
  | safeExp e = False
  | otherwise = any ((> 0) . arrayRank) (patTypes pat)

hoistStms ::
  SimplifiableRep rep =>
  RuleBook (Wise rep) ->
  BlockPred (Wise rep) ->
  ST.SymbolTable (Wise rep) ->
  UT.UsageTable ->
  Stms (Wise rep) ->
  SimpleM
    rep
    ( Stms (Wise rep),
      Stms (Wise rep)
    )
hoistStms :: RuleBook (Wise rep)
-> BlockPred (Wise rep)
-> SymbolTable (Wise rep)
-> UsageTable
-> Stms (Wise rep)
-> SimpleM rep (Stms (Wise rep), Stms (Wise rep))
hoistStms RuleBook (Wise rep)
rules BlockPred (Wise rep)
block SymbolTable (Wise rep)
vtable UsageTable
uses Stms (Wise rep)
orig_stms = do
  ([Stm (Wise rep)]
blocked, [Stm (Wise rep)]
hoisted) <- SymbolTable (Wise rep)
-> UsageTable
-> Stms (Wise rep)
-> SimpleM rep ([Stm (Wise rep)], [Stm (Wise rep)])
simplifyStmsBottomUp SymbolTable (Wise rep)
vtable UsageTable
uses Stms (Wise rep)
orig_stms
  Bool -> SimpleM rep () -> SimpleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Stm (Wise rep)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Stm (Wise rep)]
hoisted) SimpleM rep ()
forall rep. SimpleM rep ()
changed
  (Stms (Wise rep), Stms (Wise rep))
-> SimpleM rep (Stms (Wise rep), Stms (Wise rep))
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm (Wise rep)] -> Stms (Wise rep)
forall rep. [Stm rep] -> Stms rep
stmsFromList [Stm (Wise rep)]
blocked, [Stm (Wise rep)] -> Stms (Wise rep)
forall rep. [Stm rep] -> Stms rep
stmsFromList [Stm (Wise rep)]
hoisted)
  where
    simplifyStmsBottomUp :: SymbolTable (Wise rep)
-> UsageTable
-> Stms (Wise rep)
-> SimpleM rep ([Stm (Wise rep)], [Stm (Wise rep)])
simplifyStmsBottomUp SymbolTable (Wise rep)
vtable' UsageTable
uses' Stms (Wise rep)
stms = do
      (UsageTable
_, [Either (Stm (Wise rep)) (Stm (Wise rep))]
stms') <- SymbolTable (Wise rep)
-> UsageTable
-> Stms (Wise rep)
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
simplifyStmsBottomUp' SymbolTable (Wise rep)
vtable' UsageTable
uses' Stms (Wise rep)
stms
      -- We need to do a final pass to ensure that nothing is
      -- hoisted past something that it depends on.
      let ([Stm (Wise rep)]
blocked, [Stm (Wise rep)]
hoisted) = [Either (Stm (Wise rep)) (Stm (Wise rep))]
-> ([Stm (Wise rep)], [Stm (Wise rep)])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either (Stm (Wise rep)) (Stm (Wise rep))]
 -> ([Stm (Wise rep)], [Stm (Wise rep)]))
-> [Either (Stm (Wise rep)) (Stm (Wise rep))]
-> ([Stm (Wise rep)], [Stm (Wise rep)])
forall a b. (a -> b) -> a -> b
$ [Either (Stm (Wise rep)) (Stm (Wise rep))]
-> [Either (Stm (Wise rep)) (Stm (Wise rep))]
forall rep.
ASTRep rep =>
[Either (Stm rep) (Stm rep)] -> [Either (Stm rep) (Stm rep)]
blockUnhoistedDeps [Either (Stm (Wise rep)) (Stm (Wise rep))]
stms'
      ([Stm (Wise rep)], [Stm (Wise rep)])
-> SimpleM rep ([Stm (Wise rep)], [Stm (Wise rep)])
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm (Wise rep)]
blocked, [Stm (Wise rep)]
hoisted)

    simplifyStmsBottomUp' :: SymbolTable (Wise rep)
-> UsageTable
-> Stms (Wise rep)
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
simplifyStmsBottomUp' SymbolTable (Wise rep)
vtable' UsageTable
uses' Stms (Wise rep)
stms = do
      OpWithWisdom (Op rep) -> UsageTable
opUsage <- ((SimpleOps rep, Env rep) -> OpWithWisdom (Op rep) -> UsageTable)
-> SimpleM rep (OpWithWisdom (Op rep) -> UsageTable)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (((SimpleOps rep, Env rep) -> OpWithWisdom (Op rep) -> UsageTable)
 -> SimpleM rep (OpWithWisdom (Op rep) -> UsageTable))
-> ((SimpleOps rep, Env rep)
    -> OpWithWisdom (Op rep) -> UsageTable)
-> SimpleM rep (OpWithWisdom (Op rep) -> UsageTable)
forall a b. (a -> b) -> a -> b
$ SimpleOps rep -> OpWithWisdom (Op rep) -> UsageTable
forall rep. SimpleOps rep -> Op (Wise rep) -> UsageTable
opUsageS (SimpleOps rep -> OpWithWisdom (Op rep) -> UsageTable)
-> ((SimpleOps rep, Env rep) -> SimpleOps rep)
-> (SimpleOps rep, Env rep)
-> OpWithWisdom (Op rep)
-> UsageTable
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SimpleOps rep, Env rep) -> SimpleOps rep
forall a b. (a, b) -> a
fst
      let usageInStm :: Stm (Wise rep) -> UsageTable
usageInStm Stm (Wise rep)
stm =
            Stm (Wise rep) -> UsageTable
forall rep. (ASTRep rep, Aliased rep) => Stm rep -> UsageTable
UT.usageInStm Stm (Wise rep)
stm
              UsageTable -> UsageTable -> UsageTable
forall a. Semigroup a => a -> a -> a
<> case Stm (Wise rep) -> Exp (Wise rep)
forall rep. Stm rep -> Exp rep
stmExp Stm (Wise rep)
stm of
                Op Op (Wise rep)
op -> OpWithWisdom (Op rep) -> UsageTable
opUsage Op (Wise rep)
OpWithWisdom (Op rep)
op
                Exp (Wise rep)
_ -> UsageTable
forall a. Monoid a => a
mempty
      ((UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
 -> (Stm (Wise rep), SymbolTable (Wise rep))
 -> SimpleM
      rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))]))
-> (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
-> [(Stm (Wise rep), SymbolTable (Wise rep))]
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ((Stm (Wise rep) -> UsageTable)
-> (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
-> (Stm (Wise rep), SymbolTable (Wise rep))
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
hoistable Stm (Wise rep) -> UsageTable
usageInStm) (UsageTable
uses', []) ([(Stm (Wise rep), SymbolTable (Wise rep))]
 -> SimpleM
      rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))]))
-> [(Stm (Wise rep), SymbolTable (Wise rep))]
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
forall a b. (a -> b) -> a -> b
$ [(Stm (Wise rep), SymbolTable (Wise rep))]
-> [(Stm (Wise rep), SymbolTable (Wise rep))]
forall a. [a] -> [a]
reverse ([(Stm (Wise rep), SymbolTable (Wise rep))]
 -> [(Stm (Wise rep), SymbolTable (Wise rep))])
-> [(Stm (Wise rep), SymbolTable (Wise rep))]
-> [(Stm (Wise rep), SymbolTable (Wise rep))]
forall a b. (a -> b) -> a -> b
$ [Stm (Wise rep)]
-> [SymbolTable (Wise rep)]
-> [(Stm (Wise rep), SymbolTable (Wise rep))]
forall a b. [a] -> [b] -> [(a, b)]
zip (Stms (Wise rep) -> [Stm (Wise rep)]
forall rep. Stms rep -> [Stm rep]
stmsToList Stms (Wise rep)
stms) [SymbolTable (Wise rep)]
vtables
      where
        vtables :: [SymbolTable (Wise rep)]
vtables = (SymbolTable (Wise rep)
 -> Stm (Wise rep) -> SymbolTable (Wise rep))
-> SymbolTable (Wise rep)
-> [Stm (Wise rep)]
-> [SymbolTable (Wise rep)]
forall b a. (b -> a -> b) -> b -> [a] -> [b]
scanl ((Stm (Wise rep)
 -> SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SymbolTable (Wise rep)
-> Stm (Wise rep)
-> SymbolTable (Wise rep)
forall a b c. (a -> b -> c) -> b -> a -> c
flip Stm (Wise rep) -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall rep.
(ASTRep rep, IndexOp (Op rep), Aliased rep) =>
Stm rep -> SymbolTable rep -> SymbolTable rep
ST.insertStm) SymbolTable (Wise rep)
vtable' ([Stm (Wise rep)] -> [SymbolTable (Wise rep)])
-> [Stm (Wise rep)] -> [SymbolTable (Wise rep)]
forall a b. (a -> b) -> a -> b
$ Stms (Wise rep) -> [Stm (Wise rep)]
forall rep. Stms rep -> [Stm rep]
stmsToList Stms (Wise rep)
stms

    hoistable :: (Stm (Wise rep) -> UsageTable)
-> (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
-> (Stm (Wise rep), SymbolTable (Wise rep))
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
hoistable Stm (Wise rep) -> UsageTable
usageInStm (UsageTable
uses', [Either (Stm (Wise rep)) (Stm (Wise rep))]
stms) (Stm (Wise rep)
stm, SymbolTable (Wise rep)
vtable')
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> UsageTable -> Bool
`UT.isUsedDirectly` UsageTable
uses') ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Stm (Wise rep) -> [VName]
forall rep. Stm rep -> [VName]
provides Stm (Wise rep)
stm -- Dead statement.
        =
        (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
forall (m :: * -> *) a. Monad m => a -> m a
return (UsageTable
uses', [Either (Stm (Wise rep)) (Stm (Wise rep))]
stms)
      | Bool
otherwise = do
        Maybe (Stms (Wise rep))
res <-
          (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep (Maybe (Stms (Wise rep)))
-> SimpleM rep (Maybe (Stms (Wise rep)))
forall rep a.
(SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep a -> SimpleM rep a
localVtable (SymbolTable (Wise rep)
-> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall a b. a -> b -> a
const SymbolTable (Wise rep)
vtable') (SimpleM rep (Maybe (Stms (Wise rep)))
 -> SimpleM rep (Maybe (Stms (Wise rep))))
-> SimpleM rep (Maybe (Stms (Wise rep)))
-> SimpleM rep (Maybe (Stms (Wise rep)))
forall a b. (a -> b) -> a -> b
$
            RuleBook (Wise rep)
-> (SymbolTable (Wise rep), UsageTable)
-> Stm (Wise rep)
-> SimpleM rep (Maybe (Stms (Wise rep)))
forall (m :: * -> *) rep.
(MonadFreshNames m, HasScope rep m) =>
RuleBook rep
-> (SymbolTable rep, UsageTable) -> Stm rep -> m (Maybe (Stms rep))
bottomUpSimplifyStm RuleBook (Wise rep)
rules (SymbolTable (Wise rep)
vtable', UsageTable
uses') Stm (Wise rep)
stm
        case Maybe (Stms (Wise rep))
res of
          Maybe (Stms (Wise rep))
Nothing -- Nothing to optimise - see if hoistable.
            | BlockPred (Wise rep)
block SymbolTable (Wise rep)
vtable' UsageTable
uses' Stm (Wise rep)
stm ->
              (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
forall (m :: * -> *) a. Monad m => a -> m a
return
                ( (Stm (Wise rep) -> UsageTable)
-> SymbolTable (Wise rep)
-> UsageTable
-> Stm (Wise rep)
-> UsageTable
forall rep.
(ASTRep rep, Aliased rep) =>
(Stm rep -> UsageTable)
-> SymbolTable rep -> UsageTable -> Stm rep -> UsageTable
expandUsage Stm (Wise rep) -> UsageTable
usageInStm SymbolTable (Wise rep)
vtable' UsageTable
uses' Stm (Wise rep)
stm
                    UsageTable -> [VName] -> UsageTable
`UT.without` Stm (Wise rep) -> [VName]
forall rep. Stm rep -> [VName]
provides Stm (Wise rep)
stm,
                  Stm (Wise rep) -> Either (Stm (Wise rep)) (Stm (Wise rep))
forall a b. a -> Either a b
Left Stm (Wise rep)
stm Either (Stm (Wise rep)) (Stm (Wise rep))
-> [Either (Stm (Wise rep)) (Stm (Wise rep))]
-> [Either (Stm (Wise rep)) (Stm (Wise rep))]
forall a. a -> [a] -> [a]
: [Either (Stm (Wise rep)) (Stm (Wise rep))]
stms
                )
            | Bool
otherwise ->
              (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
forall (m :: * -> *) a. Monad m => a -> m a
return
                ( (Stm (Wise rep) -> UsageTable)
-> SymbolTable (Wise rep)
-> UsageTable
-> Stm (Wise rep)
-> UsageTable
forall rep.
(ASTRep rep, Aliased rep) =>
(Stm rep -> UsageTable)
-> SymbolTable rep -> UsageTable -> Stm rep -> UsageTable
expandUsage Stm (Wise rep) -> UsageTable
usageInStm SymbolTable (Wise rep)
vtable' UsageTable
uses' Stm (Wise rep)
stm,
                  Stm (Wise rep) -> Either (Stm (Wise rep)) (Stm (Wise rep))
forall a b. b -> Either a b
Right Stm (Wise rep)
stm Either (Stm (Wise rep)) (Stm (Wise rep))
-> [Either (Stm (Wise rep)) (Stm (Wise rep))]
-> [Either (Stm (Wise rep)) (Stm (Wise rep))]
forall a. a -> [a] -> [a]
: [Either (Stm (Wise rep)) (Stm (Wise rep))]
stms
                )
          Just Stms (Wise rep)
optimstms -> do
            SimpleM rep ()
forall rep. SimpleM rep ()
changed
            (UsageTable
uses'', [Either (Stm (Wise rep)) (Stm (Wise rep))]
stms') <- SymbolTable (Wise rep)
-> UsageTable
-> Stms (Wise rep)
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
simplifyStmsBottomUp' SymbolTable (Wise rep)
vtable' UsageTable
uses' Stms (Wise rep)
optimstms
            (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
-> SimpleM
     rep (UsageTable, [Either (Stm (Wise rep)) (Stm (Wise rep))])
forall (m :: * -> *) a. Monad m => a -> m a
return (UsageTable
uses'', [Either (Stm (Wise rep)) (Stm (Wise rep))]
stms' [Either (Stm (Wise rep)) (Stm (Wise rep))]
-> [Either (Stm (Wise rep)) (Stm (Wise rep))]
-> [Either (Stm (Wise rep)) (Stm (Wise rep))]
forall a. [a] -> [a] -> [a]
++ [Either (Stm (Wise rep)) (Stm (Wise rep))]
stms)

-- | Walk, front to back, a list of statements classified as kept in
-- place ('Left') or hoistable ('Right').  A hoistable statement that
-- mentions a name bound by an earlier kept statement is demoted to
-- 'Left'.  Every kept statement, original or demoted, adds the names it
-- binds to the blocked set seen by the statements after it.
blockUnhoistedDeps ::
  ASTRep rep =>
  [Either (Stm rep) (Stm rep)] ->
  [Either (Stm rep) (Stm rep)]
blockUnhoistedDeps stms = snd (mapAccumL step mempty stms)
  where
    -- Keep a statement where it is and block everything it binds.
    keepInPlace blocked stm = (blocked <> namesFromList (provides stm), Left stm)

    step blocked (Left stm) = keepInPlace blocked stm
    step blocked (Right stm)
      | freeIn stm `namesIntersect` blocked = keepInPlace blocked stm
      | otherwise = (blocked, Right stm)

-- | The names bound by the pattern of a statement.
provides :: Stm rep -> [VName]
provides stm = patNames (stmPat stm)

-- | Extend a usage table with the usages arising from a statement.
--
-- The statement's own usages (from the supplied @usageInStm@) are
-- combined with usages propagated from each bound name onto its
-- aliases, and the result is expanded through the alias information
-- of the symbol table.  If any name bound by the statement is used as
-- a size in the incoming table, the free variables of its expression
-- are additionally recorded as size usages.  The incoming table is
-- retained.
expandUsage ::
  (ASTRep rep, Aliased rep) =>
  (Stm rep -> UT.UsageTable) ->
  ST.SymbolTable rep ->
  UT.UsageTable ->
  Stm rep ->
  UT.UsageTable
expandUsage usageInStm vtable utable stm@(Let pat _ e) =
  expanded <> sizeUsage <> utable
  where
    expanded =
      UT.expand (`ST.lookupAliases` vtable) (usageInStm stm <> usageThroughAliases)

    sizeUsage
      | any (`UT.isSize` utable) (patNames pat) = UT.sizeUsages (freeIn e)
      | otherwise = mempty

    -- Whatever usages a bound name has are also attributed to every
    -- name it aliases.
    usageThroughAliases =
      mconcat . mapMaybe usageThroughBindeeAliases $
        zip (patNames pat) (patAliases pat)

    usageThroughBindeeAliases (name, aliases) = do
      uses <- UT.lookup name utable
      pure $ foldMap (`UT.usage` uses) (namesToList aliases)

-- | A predicate deciding whether a statement must stay where it is
-- ('True') instead of being hoisted.  It receives the symbol table,
-- the usage table and the statement.
type BlockPred rep =
  ST.SymbolTable rep ->
  UT.UsageTable ->
  Stm rep ->
  Bool

-- | A 'BlockPred' that never blocks anything.
neverBlocks :: BlockPred rep
neverBlocks = \_ _ _ -> False

-- | A 'BlockPred' that blocks every statement.
alwaysBlocks :: BlockPred rep
alwaysBlocks = \_ _ _ -> True

-- | A constant 'BlockPred' that blocks exactly when the given flag is
-- 'False', regardless of the statement.
isFalse :: Bool -> BlockPred rep
isFalse b _ _ _
  | b = False
  | otherwise = True

-- | Disjunction of two blocking predicates.  The second predicate is
-- only consulted when the first one does not block.
orIf :: BlockPred rep -> BlockPred rep -> BlockPred rep
orIf p1 p2 vtable usages stm
  | p1 vtable usages stm = True
  | otherwise = p2 vtable usages stm

-- | Conjunction of two blocking predicates.  The second predicate is
-- only consulted when the first one blocks.
andAlso :: BlockPred rep -> BlockPred rep -> BlockPred rep
andAlso p1 p2 vtable usages stm
  | p1 vtable usages stm = p2 vtable usages stm
  | otherwise = False

-- | Block a statement if any name it binds is marked as consumed in the
-- usage table.
isConsumed :: BlockPred rep
isConsumed _ utable stm = any (`UT.isConsumed` utable) (provides stm)

-- | Block a statement whose expression is an 'Op'.
isOp :: BlockPred rep
isOp _ _ stm =
  case stmExp stm of
    Op {} -> True
    _ -> False

-- | Assemble a body from already-simplified statements and a result.
-- The body is built with 'buildBody_' inside 'runBuilder'.  The
-- statements 'runBuilder' returns alongside the body are discarded.
constructBody ::
  SimplifiableRep rep =>
  Stms (Wise rep) ->
  Result ->
  SimpleM rep (Body (Wise rep))
constructBody stms res = do
  (body, _) <- runBuilder $ buildBody_ $ addStms stms >> pure res
  pure body

-- | The outcome of simplifying a body.  It pairs a payload of type @a@
-- with a usage table, together with the simplified statements.
type SimplifiedBody rep a =
  ( (a, UT.UsageTable),
    Stms (Wise rep)
  )

-- | Run a body simplification, then split its statements with
-- 'hoistStms' using the given blocking predicate, the current symbol
-- table and the environment's rule book.  The statements that must
-- stay are returned next to the payload.  The hoisted statements are
-- returned separately.
blockIf ::
  SimplifiableRep rep =>
  BlockPred (Wise rep) ->
  SimpleM rep (SimplifiedBody rep a) ->
  SimpleM rep ((Stms (Wise rep), a), Stms (Wise rep))
blockIf block m = do
  ((x, usages), stms) <- m
  rules <- asksEngineEnv envRules
  vtable <- askVtable
  let attach (blocked, hoisted) = ((blocked, x), hoisted)
  attach <$> hoistStms rules block vtable usages stms

-- | Block any statement that refers to one of the given names.
hasFree :: ASTRep rep => Names -> BlockPred rep
hasFree ks _ _ stm = freeIn stm `namesIntersect` ks

-- | Block a statement whose expression is not accepted by 'safeExp'.
isNotSafe :: ASTRep rep => BlockPred rep
isNotSafe _ _ stm = not $ safeExp $ stmExp stm

-- | Block a statement whose expression is an in-place 'Update'.
isInPlaceBound :: BlockPred rep
isInPlaceBound _ _ stm =
  case stmExp stm of
    BasicOp Update {} -> True
    _ -> False

-- | Block a statement that 'cheapStm' does not consider cheap.
isNotCheap :: ASTRep rep => BlockPred rep
isNotCheap _ _ stm = not (cheapStm stm)

-- | A statement is cheap when its expression is cheap.
cheapStm :: ASTRep rep => Stm rep -> Bool
cheapStm stm = cheapExp (stmExp stm)

-- | A rough cost classification of expressions.  Copies, replications,
-- manifests and loops are expensive.  An @if@ is cheap when every
-- statement in both branches is cheap.  An 'Op' defers to 'cheapOp'.
-- Everything else is considered cheap.
cheapExp :: ASTRep rep => Exp rep -> Bool
cheapExp (BasicOp op) =
  case op of
    Copy {} -> False
    Replicate {} -> False
    Manifest {} -> False
    -- Remaining basic operations (arithmetic, comparisons, conversions,
    -- plain subexpressions, ...) count as cheap.
    _ -> True
cheapExp DoLoop {} = False
cheapExp (If _ tbranch fbranch _) =
  all cheapStm (bodyStms tbranch <> bodyStms fbranch)
cheapExp (Op op) = cheapOp op
cheapExp _ = True -- Used to be False, but
-- let's try it out.

-- | Lift a plain statement predicate into a 'BlockPred' that ignores the
-- symbol and usage tables.
stmIs :: (Stm rep -> Bool) -> BlockPred rep
stmIs f _ _ stm = f stm

-- | A statement is loop invariant when every one of its free variables
-- is available at the closest enclosing loop.
loopInvariantStm :: ASTRep rep => ST.SymbolTable rep -> Stm rep -> Bool
loopInvariantStm vtable stm =
  all (`nameIn` available) (namesToList (freeIn stm))
  where
    available = ST.availableAtClosestLoop vtable

-- | Hoist statements out of both branches of an @if@.
--
-- Each branch is passed through 'hoistStms' with a blocking predicate
-- built from the environment's branch blocker and the local rules
-- below.  The escaping statements are wrapped by 'protectIfHoisted'
-- (with 'True' for the first branch and 'False' for the second) and
-- concatenated.  The result is the rebuilt first branch, the rebuilt
-- second branch, and the hoisted statements.
hoistCommon ::
  SimplifiableRep rep =>
  SubExp ->
  IfSort ->
  SimplifiedBody rep Result ->
  SimplifiedBody rep Result ->
  SimpleM
    rep
    ( Body (Wise rep),
      Body (Wise rep),
      Stms (Wise rep)
    )
hoistCommon cond ifsort ((res1, usages1), stms1) ((res2, usages2), stms2) = do
  blockers <- asksEngineEnv envHoistBlockers
  vtable <- askVtable
  rules <- asksEngineEnv envRules
  let is_alloc_fun = isAllocation blockers
      branch_blocker = blockHoistBranch blockers

      -- We are unwilling to hoist things that are unsafe or costly,
      -- except if they are invariant to the most enclosing loop,
      -- because in that case they will also be hoisted past that
      -- loop.
      --
      -- We also try very hard to hoist allocations or anything that
      -- contributes to memory or array size, because that will allow
      -- allocations to be hoisted.
      cond_loop_invariant =
        all (`nameIn` ST.availableAtClosestLoop vtable) $ namesToList $ freeIn cond

      desirableToHoist stm =
        is_alloc_fun stm
          || ( ST.loopDepth vtable > 0
                 && cond_loop_invariant
                 && ifsort /= IfFallback
                 && loopInvariantStm vtable stm
             )

      -- No matter what, we always want to hoist constants as much as
      -- possible, as well as size computations and allocations.
      isNotHoistableBnd _ usages stm@(Let pat _ e)
        | BasicOp ArrayLit {} <- e = False
        | BasicOp SubExp {} <- e = False
        | any (`UT.isSize` usages) (patNames pat) = False
        | is_alloc_fun stm = False
        -- Hoist aggressively out of versioning branches.
        | otherwise = ifsort /= IfEquiv

      -- Unsafe or costly statements stay put unless hoisting them is
      -- desirable (see above).
      costlyAndUndesirable =
        (isNotSafe `orIf` isNotCheap) `andAlso` stmIs (not . desirableToHoist)

      block =
        branch_blocker
          `orIf` costlyAndUndesirable
          `orIf` isInPlaceBound
          `orIf` isNotHoistableBnd

  (body1_stms', safe1) <-
    protectIfHoisted cond True $
      hoistStms rules block vtable usages1 stms1
  (body2_stms', safe2) <-
    protectIfHoisted cond False $
      hoistStms rules block vtable usages2 stms2
  body1' <- constructBody body1_stms' res1
  body2' <- constructBody body2_stms' res2
  pure (body1', body2', safe1 <> safe2)

-- | Simplify a single body.  The @[Diet]@ only covers the value
-- elements, because the context cannot be consumed.  The statements
-- are simplified via 'simplifyStms' around the simplification of the
-- result.  The inner computation contributes no statements of its own.
simplifyBody ::
  SimplifiableRep rep =>
  [Diet] ->
  Body rep ->
  SimpleM rep (SimplifiedBody rep Result)
simplifyBody ds (Body _ stms res) =
  simplifyStms stms $ (,mempty) <$> simplifyResult ds res

-- | Simplify a single 'Result'.  The @[Diet]@ only covers the value
-- elements, because the context cannot be consumed.  The returned usage
-- table records the free variables of the simplified result.  It also
-- includes the consumption implied by pairing each result element
-- with its diet.
simplifyResult ::
  SimplifiableRep rep => [Diet] -> Result -> SimpleM rep (Result, UT.UsageTable)
simplifyResult ds res = do
  simplified <- traverse simplify res
  vtable <- askVtable
  let consumption = consumeResult vtable (zip ds simplified)
      used = UT.usages (freeIn simplified)
  pure (simplified, used <> consumption)

-- | Record an 'UT.inResultUsage' for every variable returned in the
-- result.  Constant result elements contribute nothing.
isDoLoopResult :: Result -> UT.UsageTable
isDoLoopResult res =
  mconcat [UT.inResultUsage v | SubExpRes _ (Var v) <- res]

-- | Simplify a sequence of statements, then run the continuation @m@.
--
-- Statements are handled front to back.  For the first statement:
--
--   * its own certificates are simplified;
--
--   * its expression and its pattern are simplified, and any
--     certificates collected along the way (via 'collectCerts') are
--     appended to the statement's certificates;
--
--   * the statements emitted while simplifying the expression are
--     inspected first, then the rebuilt statement, then the remaining
--     statements (recursively), and finally @m@.
--
-- When no statements remain, @m@ is handed to 'inspectStms' with an
-- empty statement sequence.
simplifyStms ::
  SimplifiableRep rep =>
  Stms rep ->
  SimpleM rep (a, Stms (Wise rep)) ->
  SimpleM rep (a, Stms (Wise rep))
simplifyStms stms m = maybe (inspectStms mempty m) simplifyFirst (stmsHead stms)
  where
    simplifyFirst (Let pat (StmAux stm_cs attrs dec) e, rest) = do
      stm_cs' <- simplify stm_cs
      ((e', e_stms), e_cs) <- collectCerts $ simplifyExp e
      (pat', pat_cs) <- collectCerts $ simplifyPat pat
      -- Certificates from the expression and pattern must guard the
      -- rebuilt statement as well.
      let cs = stm_cs' <> e_cs <> pat_cs
      inspectStms e_stms
        . inspectStm (mkWiseLetStm pat' (StmAux cs attrs dec) e')
        $ simplifyStms rest m

-- | Inspect a single, already simplified statement before running the
-- continuation.  This is 'inspectStms' applied to the one-element
-- sequence built by 'oneStm'.
inspectStm ::
  SimplifiableRep rep =>
  Stm (Wise rep) ->
  SimpleM rep (a, Stms (Wise rep)) ->
  SimpleM rep (a, Stms (Wise rep))
inspectStm stm = inspectStms (oneStm stm)

-- | Apply the top-down rules from the engine's 'RuleBook' to each
-- statement, then run the continuation @m@.
--
-- For the first statement, 'topDownSimplifyStm' is tried against the
-- current symbol table:
--
--   * If a rule fires, 'changed' is recorded.  The replacement
--     statements are put in front of the remaining ones and the whole
--     sequence is inspected again.
--
--   * Otherwise the statement is kept.  It is added to the symbol
--     table (via 'ST.insertStm') for the rest of the traversal and is
--     prepended to the statements returned by the continuation.
--
-- With no statements left, @m@ is run unchanged.
inspectStms ::
  SimplifiableRep rep =>
  Stms (Wise rep) ->
  SimpleM rep (a, Stms (Wise rep)) ->
  SimpleM rep (a, Stms (Wise rep))
inspectStms stms m = maybe m inspectFirst (stmsHead stms)
  where
    inspectFirst (stm, rest) = do
      vtable <- askVtable
      rules <- asksEngineEnv envRules
      rewritten <- topDownSimplifyStm rules vtable stm
      case rewritten of
        Just replacement -> do
          changed
          inspectStms (replacement <> rest) m
        Nothing -> do
          (x, rest') <- localVtable (ST.insertStm stm) $ inspectStms rest m
          pure (x, oneStm stm <> rest')

-- | Simplify an 'Op' with the representation-specific simplifier.
--
-- The simplifier is the 'simplifyOpS' field of the 'SimpleOps' held in
-- the first component of the reader environment.  It is fetched with
-- 'asks' and applied to the operation.
simplifyOp :: Op rep -> SimpleM rep (Op (Wise rep), Stms (Wise rep))
simplifyOp op = asks (simplifyOpS . fst) >>= \simplifyWith -> simplifyWith op

simplifyExp ::
  SimplifiableRep rep =>
  Exp rep ->
  SimpleM rep (Exp (Wise rep), Stms (Wise rep))
simplifyExp :: Exp rep -> SimpleM rep (Exp (Wise rep), Stms (Wise rep))
simplifyExp (If SubExp
cond BodyT rep
tbranch BodyT rep
fbranch (IfDec [BranchType rep]
ts IfSort
ifsort)) = do
  -- Here, we have to check whether 'cond' puts a bound on some free
  -- variable, and if so, chomp it.  We should also try to do CSE
  -- across branches.
  SubExp
cond' <- SubExp -> SimpleM rep SubExp
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify SubExp
cond
  [BranchType rep]
ts' <- (BranchType rep -> SimpleM rep (BranchType rep))
-> [BranchType rep] -> SimpleM rep [BranchType rep]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM BranchType rep -> SimpleM rep (BranchType rep)
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify [BranchType rep]
ts
  -- FIXME: we have to be conservative about the diet here, because we
  -- lack proper information.  Something is wrong with the order in
  -- which the simplifier does things - it should be purely bottom-up
  -- (or else, If expressions should indicate explicitly the diet of
  -- their return types).
  let ds :: [Diet]
ds = (BranchType rep -> Diet) -> [BranchType rep] -> [Diet]
forall a b. (a -> b) -> [a] -> [b]
map (Diet -> BranchType rep -> Diet
forall a b. a -> b -> a
const Diet
Consume) [BranchType rep]
ts
  SimplifiedBody rep Result
tbranch' <- [Diet] -> BodyT rep -> SimpleM rep (SimplifiedBody rep Result)
forall rep.
SimplifiableRep rep =>
[Diet] -> Body rep -> SimpleM rep (SimplifiedBody rep Result)
simplifyBody [Diet]
ds BodyT rep
tbranch
  SimplifiedBody rep Result
fbranch' <- [Diet] -> BodyT rep -> SimpleM rep (SimplifiedBody rep Result)
forall rep.
SimplifiableRep rep =>
[Diet] -> Body rep -> SimpleM rep (SimplifiedBody rep Result)
simplifyBody [Diet]
ds BodyT rep
fbranch
  (Body (Wise rep)
tbranch'', Body (Wise rep)
fbranch'', Stms (Wise rep)
hoisted) <- SubExp
-> IfSort
-> SimplifiedBody rep Result
-> SimplifiedBody rep Result
-> SimpleM rep (Body (Wise rep), Body (Wise rep), Stms (Wise rep))
forall rep.
SimplifiableRep rep =>
SubExp
-> IfSort
-> SimplifiedBody rep Result
-> SimplifiedBody rep Result
-> SimpleM rep (Body (Wise rep), Body (Wise rep), Stms (Wise rep))
hoistCommon SubExp
cond' IfSort
ifsort SimplifiedBody rep Result
tbranch' SimplifiedBody rep Result
fbranch'
  (Exp (Wise rep), Stms (Wise rep))
-> SimpleM rep (Exp (Wise rep), Stms (Wise rep))
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
-> Body (Wise rep)
-> Body (Wise rep)
-> IfDec (BranchType (Wise rep))
-> Exp (Wise rep)
forall rep.
SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
If SubExp
cond' Body (Wise rep)
tbranch'' Body (Wise rep)
fbranch'' (IfDec (BranchType (Wise rep)) -> Exp (Wise rep))
-> IfDec (BranchType (Wise rep)) -> Exp (Wise rep)
forall a b. (a -> b) -> a -> b
$ [BranchType rep] -> IfSort -> IfDec (BranchType rep)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType rep]
ts' IfSort
ifsort, Stms (Wise rep)
hoisted)
simplifyExp (DoLoop [(FParam rep, SubExp)]
merge LoopForm rep
form BodyT rep
loopbody) = do
  let ([FParam rep]
params, [SubExp]
args) = [(FParam rep, SubExp)] -> ([FParam rep], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(FParam rep, SubExp)]
merge
  [FParam rep]
params' <- (FParam rep -> SimpleM rep (FParam rep))
-> [FParam rep] -> SimpleM rep [FParam rep]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((FParamInfo rep -> SimpleM rep (FParamInfo rep))
-> FParam rep -> SimpleM rep (FParam rep)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse FParamInfo rep -> SimpleM rep (FParamInfo rep)
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify) [FParam rep]
params
  [SubExp]
args' <- (SubExp -> SimpleM rep SubExp) -> [SubExp] -> SimpleM rep [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> SimpleM rep SubExp
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify [SubExp]
args
  let merge' :: [(FParam rep, SubExp)]
merge' = [FParam rep] -> [SubExp] -> [(FParam rep, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [FParam rep]
params' [SubExp]
args'
      diets :: [Diet]
diets = (FParam rep -> Diet) -> [FParam rep] -> [Diet]
forall a b. (a -> b) -> [a] -> [b]
map (TypeBase Shape Uniqueness -> Diet
forall shape. TypeBase shape Uniqueness -> Diet
diet (TypeBase Shape Uniqueness -> Diet)
-> (FParam rep -> TypeBase Shape Uniqueness) -> FParam rep -> Diet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. FParam rep -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType) [FParam rep]
params'
  (LoopForm (Wise rep)
form', Names
boundnames, SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
wrapbody) <- case LoopForm rep
form of
    ForLoop VName
loopvar IntType
it SubExp
boundexp [(LParam rep, VName)]
loopvars -> do
      SubExp
boundexp' <- SubExp -> SimpleM rep SubExp
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify SubExp
boundexp
      let ([LParam rep]
loop_params, [VName]
loop_arrs) = [(LParam rep, VName)] -> ([LParam rep], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(LParam rep, VName)]
loopvars
      [LParam rep]
loop_params' <- (LParam rep -> SimpleM rep (LParam rep))
-> [LParam rep] -> SimpleM rep [LParam rep]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((LParamInfo rep -> SimpleM rep (LParamInfo rep))
-> LParam rep -> SimpleM rep (LParam rep)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse LParamInfo rep -> SimpleM rep (LParamInfo rep)
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify) [LParam rep]
loop_params
      [VName]
loop_arrs' <- (VName -> SimpleM rep VName) -> [VName] -> SimpleM rep [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> SimpleM rep VName
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify [VName]
loop_arrs
      let form' :: LoopForm (Wise rep)
form' = VName
-> IntType
-> SubExp
-> [(LParam (Wise rep), VName)]
-> LoopForm (Wise rep)
forall rep.
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
loopvar IntType
it SubExp
boundexp' ([LParam rep] -> [VName] -> [(LParam rep, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [LParam rep]
loop_params' [VName]
loop_arrs')
      (LoopForm (Wise rep), Names,
 SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> SimpleM
     rep
     (LoopForm (Wise rep), Names,
      SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
      -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( LoopForm (Wise rep)
form',
          [VName] -> Names
namesFromList (VName
loopvar VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: (LParam rep -> VName) -> [LParam rep] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map LParam rep -> VName
forall dec. Param dec -> VName
paramName [LParam rep]
loop_params') Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
fparamnames,
          VName
-> IntType
-> SubExp
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall rep a.
SimplifiableRep rep =>
VName -> IntType -> SubExp -> SimpleM rep a -> SimpleM rep a
bindLoopVar VName
loopvar IntType
it SubExp
boundexp'
            (SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> (SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
    -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(FParam (Wise rep), SubExp)]
-> LoopForm (Wise rep)
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall rep a.
SimplifiableRep rep =>
[(FParam (Wise rep), SubExp)]
-> LoopForm (Wise rep)
-> SimpleM rep (a, Stms (Wise rep))
-> SimpleM rep (a, Stms (Wise rep))
protectLoopHoisted [(FParam rep, SubExp)]
[(FParam (Wise rep), SubExp)]
merge' LoopForm (Wise rep)
form'
            (SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> (SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
    -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [LParam (Wise rep)]
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall rep a.
SimplifiableRep rep =>
[LParam (Wise rep)] -> SimpleM rep a -> SimpleM rep a
bindArrayLParams [LParam rep]
[LParam (Wise rep)]
loop_params'
        )
    WhileLoop VName
cond -> do
      VName
cond' <- VName -> SimpleM rep VName
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify VName
cond
      (LoopForm (Wise rep), Names,
 SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> SimpleM
     rep
     (LoopForm (Wise rep), Names,
      SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
      -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( VName -> LoopForm (Wise rep)
forall rep. VName -> LoopForm rep
WhileLoop VName
cond',
          Names
fparamnames,
          [(FParam (Wise rep), SubExp)]
-> LoopForm (Wise rep)
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall rep a.
SimplifiableRep rep =>
[(FParam (Wise rep), SubExp)]
-> LoopForm (Wise rep)
-> SimpleM rep (a, Stms (Wise rep))
-> SimpleM rep (a, Stms (Wise rep))
protectLoopHoisted [(FParam rep, SubExp)]
[(FParam (Wise rep), SubExp)]
merge' (VName -> LoopForm (Wise rep)
forall rep. VName -> LoopForm rep
WhileLoop VName
cond')
        )
  BlockPred (Wise rep)
seq_blocker <- (Env rep -> BlockPred (Wise rep))
-> SimpleM rep (BlockPred (Wise rep))
forall rep a. (Env rep -> a) -> SimpleM rep a
asksEngineEnv ((Env rep -> BlockPred (Wise rep))
 -> SimpleM rep (BlockPred (Wise rep)))
-> (Env rep -> BlockPred (Wise rep))
-> SimpleM rep (BlockPred (Wise rep))
forall a b. (a -> b) -> a -> b
$ HoistBlockers rep -> BlockPred (Wise rep)
forall rep. HoistBlockers rep -> BlockPred (Wise rep)
blockHoistSeq (HoistBlockers rep -> BlockPred (Wise rep))
-> (Env rep -> HoistBlockers rep)
-> Env rep
-> BlockPred (Wise rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env rep -> HoistBlockers rep
forall rep. Env rep -> HoistBlockers rep
envHoistBlockers
  ((Stms (Wise rep)
loopstms, Result
loopres), Stms (Wise rep)
hoisted) <-
    SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall rep a. SimpleM rep a -> SimpleM rep a
enterLoop (SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> (SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
    -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
consumeMerge (SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall a b. (a -> b) -> a -> b
$
      [(FParam (Wise rep), SubExp, SubExpRes)]
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall rep a.
SimplifiableRep rep =>
[(FParam (Wise rep), SubExp, SubExpRes)]
-> SimpleM rep a -> SimpleM rep a
bindMerge (((FParam rep, SubExp)
 -> SubExpRes -> (FParam rep, SubExp, SubExpRes))
-> [(FParam rep, SubExp)]
-> Result
-> [(FParam rep, SubExp, SubExpRes)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (FParam rep, SubExp)
-> SubExpRes -> (FParam rep, SubExp, SubExpRes)
forall a b c. (a, b) -> c -> (a, b, c)
withRes [(FParam rep, SubExp)]
merge' (BodyT rep -> Result
forall rep. BodyT rep -> Result
bodyResult BodyT rep
loopbody)) (SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall a b. (a -> b) -> a -> b
$
        SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
wrapbody (SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall a b. (a -> b) -> a -> b
$
          BlockPred (Wise rep)
-> SimpleM rep (SimplifiedBody rep Result)
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall rep a.
SimplifiableRep rep =>
BlockPred (Wise rep)
-> SimpleM rep (SimplifiedBody rep a)
-> SimpleM rep ((Stms (Wise rep), a), Stms (Wise rep))
blockIf
            ( Names -> BlockPred (Wise rep)
forall rep. ASTRep rep => Names -> BlockPred rep
hasFree Names
boundnames BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`orIf` BlockPred (Wise rep)
forall rep. BlockPred rep
isConsumed
                BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`orIf` BlockPred (Wise rep)
seq_blocker
                BlockPred (Wise rep)
-> BlockPred (Wise rep) -> BlockPred (Wise rep)
forall rep. BlockPred rep -> BlockPred rep -> BlockPred rep
`orIf` BlockPred (Wise rep)
forall rep. ASTRep rep => BlockPred rep
notWorthHoisting
            )
            (SimpleM rep (SimplifiedBody rep Result)
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> SimpleM rep (SimplifiedBody rep Result)
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall a b. (a -> b) -> a -> b
$ do
              ((Result
res, UsageTable
uses), Stms (Wise rep)
stms) <- [Diet] -> BodyT rep -> SimpleM rep (SimplifiedBody rep Result)
forall rep.
SimplifiableRep rep =>
[Diet] -> Body rep -> SimpleM rep (SimplifiedBody rep Result)
simplifyBody [Diet]
diets BodyT rep
loopbody
              SimplifiedBody rep Result
-> SimpleM rep (SimplifiedBody rep Result)
forall (m :: * -> *) a. Monad m => a -> m a
return ((Result
res, UsageTable
uses UsageTable -> UsageTable -> UsageTable
forall a. Semigroup a => a -> a -> a
<> Result -> UsageTable
isDoLoopResult Result
res), Stms (Wise rep)
stms)
  Body (Wise rep)
loopbody' <- Stms (Wise rep) -> Result -> SimpleM rep (Body (Wise rep))
forall rep.
SimplifiableRep rep =>
Stms (Wise rep) -> Result -> SimpleM rep (Body (Wise rep))
constructBody Stms (Wise rep)
loopstms Result
loopres
  (Exp (Wise rep), Stms (Wise rep))
-> SimpleM rep (Exp (Wise rep), Stms (Wise rep))
forall (m :: * -> *) a. Monad m => a -> m a
return ([(FParam (Wise rep), SubExp)]
-> LoopForm (Wise rep) -> Body (Wise rep) -> Exp (Wise rep)
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> BodyT rep -> ExpT rep
DoLoop [(FParam rep, SubExp)]
[(FParam (Wise rep), SubExp)]
merge' LoopForm (Wise rep)
form' Body (Wise rep)
loopbody', Stms (Wise rep)
hoisted)
  where
    fparamnames :: Names
fparamnames =
      [VName] -> Names
namesFromList (((FParam rep, SubExp) -> VName)
-> [(FParam rep, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam rep -> VName
forall dec. Param dec -> VName
paramName (FParam rep -> VName)
-> ((FParam rep, SubExp) -> FParam rep)
-> (FParam rep, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam rep, SubExp) -> FParam rep
forall a b. (a, b) -> a
fst) [(FParam rep, SubExp)]
merge)
    consumeMerge :: SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
consumeMerge =
      (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall rep a.
(SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep a -> SimpleM rep a
localVtable ((SymbolTable (Wise rep) -> SymbolTable (Wise rep))
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
 -> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep)))
-> (SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
-> SimpleM rep ((Stms (Wise rep), Result), Stms (Wise rep))
forall a b. (a -> b) -> a -> b
$ (SymbolTable (Wise rep) -> [VName] -> SymbolTable (Wise rep))
-> [VName] -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((SymbolTable (Wise rep) -> VName -> SymbolTable (Wise rep))
-> SymbolTable (Wise rep) -> [VName] -> SymbolTable (Wise rep)
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' ((VName -> SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> SymbolTable (Wise rep) -> VName -> SymbolTable (Wise rep)
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall rep. VName -> SymbolTable rep -> SymbolTable rep
ST.consume)) ([VName] -> SymbolTable (Wise rep) -> SymbolTable (Wise rep))
-> [VName] -> SymbolTable (Wise rep) -> SymbolTable (Wise rep)
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList Names
consumed_by_merge
    consumed_by_merge :: Names
consumed_by_merge =
      [SubExp] -> Names
forall a. FreeIn a => a -> Names
freeIn ([SubExp] -> Names) -> [SubExp] -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam rep, SubExp) -> SubExp)
-> [(FParam rep, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (FParam rep, SubExp) -> SubExp
forall a b. (a, b) -> b
snd ([(FParam rep, SubExp)] -> [SubExp])
-> [(FParam rep, SubExp)] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ((FParam rep, SubExp) -> Bool)
-> [(FParam rep, SubExp)] -> [(FParam rep, SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (TypeBase Shape Uniqueness -> Bool)
-> ((FParam rep, SubExp) -> TypeBase Shape Uniqueness)
-> (FParam rep, SubExp)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. FParam rep -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType (FParam rep -> TypeBase Shape Uniqueness)
-> ((FParam rep, SubExp) -> FParam rep)
-> (FParam rep, SubExp)
-> TypeBase Shape Uniqueness
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam rep, SubExp) -> FParam rep
forall a b. (a, b) -> a
fst) [(FParam rep, SubExp)]
merge
    withRes :: (a, b) -> c -> (a, b, c)
withRes (a
p, b
x) c
y = (a
p, b
x, c
y)
simplifyExp (Op Op rep
op) = do
  (OpWithWisdom (Op rep)
op', Stms (Wise rep)
stms) <- Op rep -> SimpleM rep (Op (Wise rep), Stms (Wise rep))
forall rep. Op rep -> SimpleM rep (Op (Wise rep), Stms (Wise rep))
simplifyOp Op rep
op
  (Exp (Wise rep), Stms (Wise rep))
-> SimpleM rep (Exp (Wise rep), Stms (Wise rep))
forall (m :: * -> *) a. Monad m => a -> m a
return (Op (Wise rep) -> Exp (Wise rep)
forall rep. Op rep -> ExpT rep
Op Op (Wise rep)
OpWithWisdom (Op rep)
op', Stms (Wise rep)
stms)
simplifyExp (WithAcc [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs Lambda rep
lam) = do
  ([(Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))]
inputs', [Stms (Wise rep)]
inputs_stms) <- ([((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
   Stms (Wise rep))]
 -> ([(Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))],
     [Stms (Wise rep)]))
-> SimpleM
     rep
     [((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
       Stms (Wise rep))]
-> SimpleM
     rep
     ([(Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))],
      [Stms (Wise rep)])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
  Stms (Wise rep))]
-> ([(Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))],
    [Stms (Wise rep)])
forall a b. [(a, b)] -> ([a], [b])
unzip (SimpleM
   rep
   [((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
     Stms (Wise rep))]
 -> SimpleM
      rep
      ([(Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))],
       [Stms (Wise rep)]))
-> (((Shape, [VName], Maybe (Lambda rep, [SubExp]))
     -> SimpleM
          rep
          ((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
           Stms (Wise rep)))
    -> SimpleM
         rep
         [((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
           Stms (Wise rep))])
-> ((Shape, [VName], Maybe (Lambda rep, [SubExp]))
    -> SimpleM
         rep
         ((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
          Stms (Wise rep)))
-> SimpleM
     rep
     ([(Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))],
      [Stms (Wise rep)])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
-> ((Shape, [VName], Maybe (Lambda rep, [SubExp]))
    -> SimpleM
         rep
         ((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
          Stms (Wise rep)))
-> SimpleM
     rep
     [((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
       Stms (Wise rep))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
inputs (((Shape, [VName], Maybe (Lambda rep, [SubExp]))
  -> SimpleM
       rep
       ((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
        Stms (Wise rep)))
 -> SimpleM
      rep
      ([(Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))],
       [Stms (Wise rep)]))
-> ((Shape, [VName], Maybe (Lambda rep, [SubExp]))
    -> SimpleM
         rep
         ((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
          Stms (Wise rep)))
-> SimpleM
     rep
     ([(Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))],
      [Stms (Wise rep)])
forall a b. (a -> b) -> a -> b
$ \(Shape
shape, [VName]
arrs, Maybe (Lambda rep, [SubExp])
op) -> do
    (Maybe (Lambda (Wise rep), [SubExp])
op', Stms (Wise rep)
op_stms) <- case Maybe (Lambda rep, [SubExp])
op of
      Maybe (Lambda rep, [SubExp])
Nothing ->
        (Maybe (Lambda (Wise rep), [SubExp]), Stms (Wise rep))
-> SimpleM
     rep (Maybe (Lambda (Wise rep), [SubExp]), Stms (Wise rep))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (Lambda (Wise rep), [SubExp])
forall a. Maybe a
Nothing, Stms (Wise rep)
forall a. Monoid a => a
mempty)
      Just (Lambda rep
op_lam, [SubExp]
nes) -> do
        (Lambda (Wise rep)
op_lam', Stms (Wise rep)
op_lam_stms) <- Lambda rep -> SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
forall rep.
SimplifiableRep rep =>
Lambda rep -> SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda Lambda rep
op_lam
        [SubExp]
nes' <- [SubExp] -> SimpleM rep [SubExp]
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify [SubExp]
nes
        (Maybe (Lambda (Wise rep), [SubExp]), Stms (Wise rep))
-> SimpleM
     rep (Maybe (Lambda (Wise rep), [SubExp]), Stms (Wise rep))
forall (m :: * -> *) a. Monad m => a -> m a
return ((Lambda (Wise rep), [SubExp])
-> Maybe (Lambda (Wise rep), [SubExp])
forall a. a -> Maybe a
Just (Lambda (Wise rep)
op_lam', [SubExp]
nes'), Stms (Wise rep)
op_lam_stms)
    (,Stms (Wise rep)
op_stms) ((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))
 -> ((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
     Stms (Wise rep)))
-> SimpleM
     rep (Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))
-> SimpleM
     rep
     ((Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])),
      Stms (Wise rep))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((,,Maybe (Lambda (Wise rep), [SubExp])
op') (Shape
 -> [VName]
 -> (Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])))
-> SimpleM rep Shape
-> SimpleM
     rep
     ([VName] -> (Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Shape -> SimpleM rep Shape
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify Shape
shape SimpleM
  rep
  ([VName] -> (Shape, [VName], Maybe (Lambda (Wise rep), [SubExp])))
-> SimpleM rep [VName]
-> SimpleM
     rep (Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [VName] -> SimpleM rep [VName]
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify [VName]
arrs)
  (Lambda (Wise rep)
lam', Stms (Wise rep)
lam_stms) <- Lambda rep -> SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
forall rep.
SimplifiableRep rep =>
Lambda rep -> SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda Lambda rep
lam
  (Exp (Wise rep), Stms (Wise rep))
-> SimpleM rep (Exp (Wise rep), Stms (Wise rep))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([(Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))]
-> Lambda (Wise rep) -> Exp (Wise rep)
forall rep.
[(Shape, [VName], Maybe (Lambda rep, [SubExp]))]
-> Lambda rep -> ExpT rep
WithAcc [(Shape, [VName], Maybe (Lambda (Wise rep), [SubExp]))]
inputs' Lambda (Wise rep)
lam', [Stms (Wise rep)] -> Stms (Wise rep)
forall a. Monoid a => [a] -> a
mconcat [Stms (Wise rep)]
inputs_stms Stms (Wise rep) -> Stms (Wise rep) -> Stms (Wise rep)
forall a. Semigroup a => a -> a -> a
<> Stms (Wise rep)
lam_stms)

-- Special case for simplification of commutative BinOps where we
-- arrange the operands in sorted order.  This can make expressions
-- more identical, which helps CSE.
simplifyExp (BasicOp (BinOp BinOp
op SubExp
x SubExp
y))
  | BinOp -> Bool
commutativeBinOp BinOp
op = do
    SubExp
x' <- SubExp -> SimpleM rep SubExp
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify SubExp
x
    SubExp
y' <- SubExp -> SimpleM rep SubExp
forall e rep.
(Simplifiable e, SimplifiableRep rep) =>
e -> SimpleM rep e
simplify SubExp
y
    (Exp (Wise rep), Stms (Wise rep))
-> SimpleM rep (Exp (Wise rep), Stms (Wise rep))
forall (m :: * -> *) a. Monad m => a -> m a
return (BasicOp -> Exp (Wise rep)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Wise rep)) -> BasicOp -> Exp (Wise rep)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp BinOp
op (SubExp -> SubExp -> SubExp
forall a. Ord a => a -> a -> a
min SubExp
x' SubExp
y') (SubExp -> SubExp -> SubExp
forall a. Ord a => a -> a -> a
max SubExp
x' SubExp
y'), Stms (Wise rep)
forall a. Monoid a => a
mempty)
simplifyExp Exp rep
e = do
  Exp (Wise rep)
e' <- Exp rep -> SimpleM rep (Exp (Wise rep))
forall rep.
SimplifiableRep rep =>
Exp rep -> SimpleM rep (Exp (Wise rep))
simplifyExpBase Exp rep
e
  (Exp (Wise rep), Stms (Wise rep))
-> SimpleM rep (Exp (Wise rep), Stms (Wise rep))
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp (Wise rep)
e', Stms (Wise rep)
forall a. Monoid a => a
mempty)

-- | Generic fallback for simplifying an expression: 'mapExpM' walks the
-- expression's immediate components and applies 'simplify' to every
-- subexpression, variable name, return type and branch type.
--
-- Bodies, function/lambda parameters and 'Op's are not handled here.
-- The corresponding mapper fields call 'error', so reaching any of them
-- aborts with an "Unhandled ... in simplification engine." message.
-- Constructs containing them must be dealt with by a dedicated code
-- path before falling back to this function.
simplifyExpBase ::
  SimplifiableRep rep =>
  Exp rep ->
  SimpleM rep (Exp (Wise rep))
simplifyExpBase = mapExpM hoist
  where
    hoist =
      Mapper
        { -- Bodies are handled explicitly because we need to
          -- provide their result diet.
          mapOnBody =
            error "Unhandled body in simplification engine.",
          mapOnSubExp = simplify,
          mapOnVName = simplify,
          mapOnRetType = simplify,
          mapOnBranchType = simplify,
          -- Parameters and operators have no generic treatment here;
          -- presumably parameters must be bound in the symbol table by
          -- the caller (TODO confirm against the loop/lambda paths).
          mapOnFParam =
            error "Unhandled FParam in simplification engine.",
          mapOnLParam =
            error "Unhandled LParam in simplification engine.",
          mapOnOp =
            error "Unhandled Op in simplification engine."
        }

-- | Everything a representation must provide before the simplifier can
-- process it.
type SimplifiableRep rep =
  ( ASTRep rep,
    -- Annotation types that the engine rewrites via 'simplify'.
    Simplifiable (LetDec rep),
    Simplifiable (FParamInfo rep),
    Simplifiable (LParamInfo rep),
    Simplifiable (RetType rep),
    Simplifiable (BranchType rep),
    -- Requirements on the representation's operator type.
    IsOp (Op rep),
    CanBeWise (Op rep),
    ST.IndexOp (OpWithWisdom (Op rep)),
    -- Builder support for the wisdom-annotated representation.
    BuilderOps (Wise rep)
  )

-- | Values that may mention variables which the simplifier is able to
-- rewrite, e.g. by substituting a name for the subexpression the symbol
-- table associates with it.
class Simplifiable e where
  simplify ::
    SimplifiableRep rep =>
    e ->
    SimpleM rep e

-- | Simplify both components of a pair, the first one first.
instance (Simplifiable a, Simplifiable b) => Simplifiable (a, b) where
  simplify (x, y) = do
    x' <- simplify x
    y' <- simplify y
    pure (x', y')

-- | Simplify all three components of a triple, from left to right.
instance
  (Simplifiable a, Simplifiable b, Simplifiable c) =>
  Simplifiable (a, b, c)
  where
  simplify (x, y, z) = do
    x' <- simplify x
    y' <- simplify y
    z' <- simplify z
    pure (x', y', z')

-- | Plain integers contain nothing to simplify and are returned as is.
-- Convenient for Scatter.
instance Simplifiable Int where
  simplify = return

-- | Simplify the payload of a 'Just'; 'Nothing' stays 'Nothing'.
instance Simplifiable a => Simplifiable (Maybe a) where
  simplify = traverse simplify

-- | Simplify every element of a list, from left to right.
instance Simplifiable a => Simplifiable [a] where
  simplify = traverse simplify

-- | Constants are returned unchanged.  A variable is looked up with
-- 'ST.lookupSubExp'.  When the symbol table maps it to some
-- subexpression, that subexpression replaces it, the rewrite is
-- recorded with 'changed', and the certificates that come with the
-- replacement are registered through 'usedCerts'.  Otherwise the
-- variable is kept.
instance Simplifiable SubExp where
  simplify (Constant v) = pure (Constant v)
  simplify (Var name) = do
    vtable <- askVtable
    case ST.lookupSubExp name vtable of
      Nothing -> pure (Var name)
      Just (replacement, cs) -> do
        changed
        usedCerts cs
        pure replacement

-- | Simplify a result's certificates, then its subexpression.  The
-- certificates collected (via 'collectCerts') while simplifying the
-- subexpression are placed in front of the simplified certificates.
instance Simplifiable SubExpRes where
  simplify (SubExpRes certs subexp) = do
    certs' <- simplify certs
    let attach (subexp', subexp_certs) =
          SubExpRes (subexp_certs <> certs') subexp'
    attach <$> collectCerts (simplify subexp)

-- | Simplify the annotation attached to each element of a pattern.  The
-- names bound by the pattern are left as they are.
simplifyPat ::
  (SimplifiableRep rep, Simplifiable dec) =>
  PatT dec ->
  SimpleM rep (PatT dec)
simplifyPat (Pat pes) =
  Pat <$> mapM (\(PatElem name dec) -> PatElem name <$> simplify dec) pes

-- | Unit holds nothing to simplify.
instance Simplifiable () where
  simplify = return

-- | A name is replaced only when the symbol table maps it to another
-- variable.  In that case the rewrite is recorded with 'changed' and
-- the associated certificates are registered with 'usedCerts'.  If the
-- table maps it to a constant, or has no entry, the name is kept.
instance Simplifiable VName where
  simplify v =
    askVtable >>= \vtable ->
      case ST.lookupSubExp v vtable of
        Just (Var v', cs) -> changed >> usedCerts cs >> pure v'
        _ -> pure v

-- | Simplify every dimension of a shape.
instance Simplifiable d => Simplifiable (ShapeBase d) where
  simplify (Shape dims) = Shape <$> simplify dims

-- | A 'Free' size is simplified like any subexpression.  An 'Ext' size
-- only holds an 'Int' index and is returned unchanged.
instance Simplifiable ExtSize where
  simplify size =
    case size of
      Free se -> Free <$> simplify se
      Ext i -> pure (Ext i)

-- | The dimensions of a 'ScalarSpace' are simplified and its primitive
-- type is kept.  Every other space is returned as is.
instance Simplifiable Space where
  simplify space =
    case space of
      ScalarSpace dims pt -> (\dims' -> ScalarSpace dims' pt) <$> simplify dims
      _ -> pure space

-- | Primitive types contain no variables and are returned as is.
instance Simplifiable PrimType where
  simplify = return

-- | Simplify the parts of a type that may mention variables:
--
-- * an array's element type and shape;
-- * an accumulator's name, index space and element types;
-- * a memory block's space.
--
-- Uniqueness annotations and primitive types are returned untouched.
instance Simplifiable shape => Simplifiable (TypeBase shape u) where
  simplify ty =
    case ty of
      Prim bt -> pure (Prim bt)
      Array et shape u -> do
        et' <- simplify et
        shape' <- simplify shape
        pure (Array et' shape' u)
      Acc acc ispace ts u -> do
        acc' <- simplify acc
        ispace' <- simplify ispace
        ts' <- simplify ts
        pure (Acc acc' ispace' ts' u)
      Mem space -> Mem <$> simplify space

-- | Simplify the components of a slice dimension: the single index of a
-- 'DimFix', or the three components of a 'DimSlice' (presumably start,
-- count and stride), from left to right.
instance Simplifiable d => Simplifiable (DimIndex d) where
  simplify (DimFix i) = DimFix <$> simplify i
  simplify (DimSlice i n s) = do
    i' <- simplify i
    n' <- simplify n
    s' <- simplify s
    pure (DimSlice i' n' s')

-- | Simplify every element stored in a slice.
instance Simplifiable d => Simplifiable (Slice d) where
  simplify = mapM simplify

-- | Simplify a lambda, using the environment's 'blockHoistPar' predicate
-- as the hoisting blocker passed to 'simplifyLambdaMaybeHoist'.  Returns
-- the simplified lambda together with the statements hoisted out of it.
simplifyLambda ::
  SimplifiableRep rep =>
  Lambda rep ->
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda lam =
  asksEngineEnv (blockHoistPar . envHoistBlockers)
    >>= \par_blocker -> simplifyLambdaMaybeHoist par_blocker lam

-- | Simplify a lambda with the blocker @isFalse False@ and return only
-- the lambda, discarding the statement list that
-- 'simplifyLambdaMaybeHoist' reports as hoisted.  This assumes
-- @isFalse False@ blocks every statement, so that the discarded list is
-- empty; TODO confirm against the definition of 'isFalse'.
simplifyLambdaNoHoisting ::
  SimplifiableRep rep =>
  Lambda rep ->
  SimpleM rep (Lambda (Wise rep))
simplifyLambdaNoHoisting lam = do
  (lam', _hoisted) <- simplifyLambdaMaybeHoist (isFalse False) lam
  pure lam'

-- | Simplify a lambda under a caller-supplied hoisting blocker.
--
-- Steps, in order:
--
-- 1. Simplify the annotations of the lambda's parameters.
-- 2. Inside 'enterLoop', with those parameters bound, simplify the body,
--    treating every result with the 'Observe' diet.  The body goes
--    through 'blockIf' with a predicate built by combining @blocked@
--    ('orIf') with 'hasFree' over the names bound by the lambda and
--    with 'isConsumed'.
-- 3. Rebuild the body and simplify the return types.
--
-- Returns the new lambda together with the statements 'blockIf'
-- reported as hoisted.
simplifyLambdaMaybeHoist ::
  SimplifiableRep rep =>
  BlockPred (Wise rep) ->
  Lambda rep ->
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambdaMaybeHoist blocked lam@(Lambda params body rettype) = do
  params' <- traverse (traverse simplify) params
  let bound = namesFromList (boundByLambda lam)
      blocker = blocked `orIf` hasFree bound `orIf` isConsumed
      diets = replicate (length rettype) Observe
  ((body_stms, body_res), hoisted) <-
    enterLoop . bindLParams params' . blockIf blocker $
      simplifyBody diets body
  body' <- constructBody body_stms body_res
  rettype' <- simplify rettype
  pure (Lambda params' body' rettype', hoisted)

-- | Build a usage table for a body result, given the diet of each
-- result element.  Every variable returned in a 'Consume' position is
-- recorded as consumed, together with every name that
-- 'ST.lookupAliases' reports for it in the given symbol table.
-- Constants, and variables returned with any other diet, contribute
-- nothing.
consumeResult :: ST.SymbolTable rep -> [(Diet, SubExpRes)] -> UT.UsageTable
consumeResult vtable = foldMap inspect
  where
    inspect (Consume, SubExpRes _ (Var v)) =
      -- Consuming 'v' must also mark everything it aliases as consumed.
      foldMap UT.consumedUsage $ v : namesToList (ST.lookupAliases v vtable)
    inspect _ = mempty

-- Certificates are simplified by consulting the symbol table for each
-- one: a certificate bound to a constant is replaced by the
-- certificates attached to that binding, one bound to another variable
-- is replaced by that variable, and anything else is kept unchanged.
-- Duplicates are then removed from the combined list with 'nubOrd'.
instance Simplifiable Certs where
  simplify (Certs ocs) = Certs . nubOrd . concat <$> mapM check ocs
    where
      check idd = do
        vv <- ST.lookupSubExp idd <$> askVtable
        case vv of
          -- A constant needs no certificate itself, but whatever
          -- certificates guarded its binding must be kept.
          Just (Constant _, Certs cs) -> return cs
          Just (Var idd', _) -> return [idd']
          _ -> return [idd]

-- | Run a body simplification and turn its outcome into a 'Body'.
--
-- The simplification runs under 'blockIf' with the predicate
-- @'isFalse' False@.  Only the first component of the outcome (the
-- retained statements and the result) is passed to 'constructBody'.
-- The second component, holding the statements that would have been
-- hoisted, is discarded.  This is therefore only sound if that
-- predicate blocks every statement, as this function's name implies.
-- NOTE(review): relies on the semantics of 'isFalse', which is
-- defined elsewhere.
insertAllStms ::
  SimplifiableRep rep =>
  SimpleM rep (SimplifiedBody rep Result) ->
  SimpleM rep (Body (Wise rep))
insertAllStms = uncurry constructBody . fst <=< blockIf (isFalse False)

-- | Simplify a function definition.
--
-- The return types and the parameter annotations are simplified.  The
-- body is simplified with the parameters bound in the symbol table and
-- wrapped in 'insertAllStms', so the result is a complete body.  The
-- diet of each body result is taken from the corresponding simplified
-- return type.  The entry point, attributes and name are kept as is.
simplifyFun ::
  SimplifiableRep rep =>
  FunDef rep ->
  SimpleM rep (FunDef (Wise rep))
simplifyFun (FunDef entry attrs fname rettype params body) = do
  rettype' <- simplify rettype
  params' <- mapM (traverse simplify) params
  let ds = map (diet . declExtTypeOf) rettype'
  -- NOTE(review): the body is simplified with the original @params@
  -- in scope, while the resulting definition carries the simplified
  -- @params'@.  Presumably intentional; confirm.
  body' <- bindFParams params $ insertAllStms $ simplifyBody ds body
  return $ FunDef entry attrs fname rettype' params' body'