{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, TypeFamilies, FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE Strict #-}
-- |
--
-- Perform general rule-based simplification based on data dependency
-- information.  This module will:
--
--    * Perform common-subexpression elimination (CSE).
--
--    * Hoist expressions out of loops (including lambdas) and
--    branches.  This is done as aggressively as possible.
--
--    * Apply simplification rules (see
--    "Futhark.Optimise.Simplify.Rules").
--
-- If you just want to run the simplifier as simply as possible, you
-- may prefer to use the "Futhark.Optimise.Simplify" module.
--
module Futhark.Optimise.Simplify.Engine
       ( -- * Monadic interface
         SimpleM
       , runSimpleM
       , SimpleOps (..)
       , SimplifyOp
       , bindableSimpleOps

       , Env (envHoistBlockers, envRules)
       , emptyEnv
       , HoistBlockers(..)
       , neverBlocks
       , noExtraHoistBlockers
       , neverHoist
       , BlockPred
       , orIf
       , hasFree
       , isConsumed
       , isFalse
       , isOp
       , isNotSafe
       , asksEngineEnv
       , askVtable
       , localVtable

         -- * Building blocks
       , SimplifiableLore
       , Simplifiable (..)
       , simplifyStms
       , simplifyFun
       , simplifyLambda
       , simplifyLambdaNoHoisting
       , bindLParams
       , simplifyBody
       , SimplifiedBody
       , ST.SymbolTable

       , hoistStms
       , blockIf

       , module Futhark.Optimise.Simplify.Lore
       ) where

import Control.Monad.Writer
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Either
import Data.List (find, foldl', nub, mapAccumL)
import Data.Maybe

import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.Optimise.Simplify.Lore
import Futhark.Util (splitFromEnd)

-- | Predicates that can block a statement from being hoisted out of
-- an enclosing construct, one per kind of construct.
data HoistBlockers lore = HoistBlockers
                          { blockHoistPar :: BlockPred (Wise lore)
                            -- ^ Blocker for hoisting out of parallel loops.
                          , blockHoistSeq :: BlockPred (Wise lore)
                            -- ^ Blocker for hoisting out of sequential loops.
                          , blockHoistBranch :: BlockPred (Wise lore)
                            -- ^ Blocker for hoisting out of branches.
                          , isAllocation  :: Stm (Wise lore) -> Bool
                            -- ^ Classifies a statement as an allocation.
                            -- NOTE(review): how the engine uses this is not
                            -- visible here; the blocker sets defined in this
                            -- module use @const False@.
                          }

-- | Hoist blockers that use 'neverBlocks' for parallel loops,
-- sequential loops and branches alike, and that never classify a
-- statement as an allocation.
noExtraHoistBlockers :: HoistBlockers lore
noExtraHoistBlockers =
  HoistBlockers { blockHoistPar = neverBlocks
                , blockHoistSeq = neverBlocks
                , blockHoistBranch = neverBlocks
                , isAllocation = const False
                }

-- | Hoist blockers that use 'alwaysBlocks' for parallel loops,
-- sequential loops and branches alike (i.e., going by the name, no
-- hoisting is permitted anywhere), and that never classify a
-- statement as an allocation.
neverHoist :: HoistBlockers lore
neverHoist =
  HoistBlockers { blockHoistPar = alwaysBlocks
                , blockHoistSeq = alwaysBlocks
                , blockHoistBranch = alwaysBlocks
                , isAllocation = const False
                }

-- | The environment read by 'SimpleM' computations.
data Env lore = Env { envRules         :: RuleBook (Wise lore)
                      -- ^ The simplification rules available.
                    , envHoistBlockers :: HoistBlockers lore
                      -- ^ Predicates restricting hoisting.
                    , envVtable        :: ST.SymbolTable (Wise lore)
                      -- ^ Symbol table of the variables in scope.  It is
                      -- empty in 'emptyEnv' and extended with 'localVtable'.
                    }

-- | Build an environment from a rule book and a set of hoist
-- blockers.  The symbol table starts out empty.
emptyEnv :: RuleBook (Wise lore) -> HoistBlockers lore -> Env lore
emptyEnv rules blockers =
  Env { envRules = rules
      , envHoistBlockers = blockers
      , envVtable = mempty
      }

-- | A way of making a hoisted 'Op' safe, expressed as an action in the
-- monad @binder@.  The 'SubExp' argument is a boolean that is true when
-- the value of the statement will actually be used; 'Nothing' means no
-- protecting action is offered.
type Protect binder =
     SubExp
  -> Pattern (Lore binder)
  -> Op (Lore binder)
  -> Maybe (binder ())

-- | The lore-specific operations that the simplifier engine is
-- parameterised over.
data SimpleOps lore =
  SimpleOps { mkExpDecS :: ST.SymbolTable (Wise lore)
                         -> Pattern (Wise lore) -> Exp (Wise lore)
                         -> SimpleM lore (ExpDec (Wise lore))
              -- ^ Construct the expression decoration for a statement
              -- with the given pattern and expression.  The current
              -- symbol table is supplied as well.
            , mkBodyS :: ST.SymbolTable (Wise lore)
                      -> Stms (Wise lore) -> Result
                      -> SimpleM lore (Body (Wise lore))
              -- ^ Construct a body from statements and a result.
              -- The current symbol table is supplied as well.
            , protectHoistedOpS :: Protect (Binder (Wise lore))
              -- ^ Make a hoisted Op safe.  The SubExp is a boolean
              -- that is true when the value of the statement will
              -- actually be used.
            , opUsageS :: Op (Wise lore) -> UT.UsageTable
              -- ^ The usage table contributed by an 'Op'.
            , simplifyOpS :: SimplifyOp lore (Op lore)
              -- ^ Simplify a lore-specific 'Op'.
            }

-- | A simplifier for a lore-specific operation @op@.  It yields the
-- simplified operation (annotated with wisdom) together with a
-- collection of additional statements.
type SimplifyOp lore op =
  op -> SimpleM lore ( OpWithWisdom op
                     , Stms (Wise lore)
                     )

-- | Default 'SimpleOps' for a lore that is 'Bindable':
--
--   * expression decorations and bodies are built with 'mkExpDec' and
--     'mkBody', ignoring the symbol table;
--
--   * hoisted ops are never protected ('Nothing');
--
--   * ops contribute an empty usage table;
--
--   * ops are simplified with the given function.
bindableSimpleOps :: (SimplifiableLore lore, Bindable lore) =>
                     SimplifyOp lore (Op lore) -> SimpleOps lore
bindableSimpleOps simplify =
  SimpleOps { mkExpDecS = mkExpDecS'
            , mkBodyS = mkBodyS'
            , protectHoistedOpS = protectHoistedOpS'
            , opUsageS = const mempty
            , simplifyOpS = simplify
            }
  where mkExpDecS' _ pat e = return $ mkExpDec pat e
        mkBodyS' _ bnds res = return $ mkBody bnds res
        protectHoistedOpS' _ _ _ = Nothing

-- | The simplifier monad.  It reads the lore-specific 'SimpleOps'
-- together with the engine 'Env', and it threads a state made up of:
--
--   * the name source used for fresh names;
--
--   * a flag, set by 'changed', that 'runSimpleM' reports back;
--
--   * the certificates recorded by 'usedCerts' and drained by
--     'collectCerts'.
newtype SimpleM lore a =
  SimpleM (ReaderT (SimpleOps lore, Env lore)
           (State (VNameSource, Bool, Certificates)) a)
  deriving (Applicative, Functor, Monad,
            MonadReader (SimpleOps lore, Env lore),
            MonadState (VNameSource, Bool, Certificates))

-- | The name source is the first component of the state.  Updating it
-- leaves the changed flag and the recorded certificates untouched.
instance MonadFreshNames (SimpleM lore) where
  putNameSource src = modify $ \(_, b, c) -> (src, b, c)
  getNameSource = gets $ \(a, _, _) -> a

-- | The scope of a 'SimpleM' computation is derived from the symbol
-- table in its environment.  Looking up the type of a variable that is
-- not in the symbol table is a fatal error.
instance SimplifiableLore lore => HasScope (Wise lore) (SimpleM lore) where
  askScope = ST.toScope <$> askVtable
  lookupType name = do
    vtable <- askVtable
    case ST.lookupType name vtable of
      Just t -> return t
      Nothing -> error $
                 "SimpleM.lookupType: cannot find variable " ++
                 pretty name ++ " in symbol table."

-- | Extending the scope extends the symbol table.  The new entries are
-- converted with 'ST.fromScope' and combined on the right of the
-- current table with '<>'.
instance SimplifiableLore lore =>
         LocalScope (Wise lore) (SimpleM lore) where
  localScope types = localVtable (<>ST.fromScope types)

-- | Run a simplifier computation with the given operations,
-- environment and name source.
--
-- Returns the result paired with the changed flag, plus the final name
-- source.  The flag starts as 'False' and is set by 'changed'.  Any
-- certificates still recorded in the state at the end are discarded.
runSimpleM :: SimpleM lore a
           -> SimpleOps lore
           -> Env lore
           -> VNameSource
           -> ((a, Bool), VNameSource)
runSimpleM (SimpleM m) simpl env src =
  let (x, (src', b, _)) =
        runState (runReaderT m (simpl, env)) (src, False, mempty)
  in ((x, b), src')

-- | Retrieve the engine environment, which is the second component of
-- the reader context.
askEngineEnv :: SimpleM lore (Env lore)
askEngineEnv = asks snd

-- | Retrieve a projection of the engine environment.
asksEngineEnv :: (Env lore -> a) -> SimpleM lore a
asksEngineEnv f = f <$> askEngineEnv

-- | Retrieve the current symbol table.
askVtable :: SimpleM lore (ST.SymbolTable (Wise lore))
askVtable = asksEngineEnv envVtable

-- | Run a computation with the symbol table transformed by the given
-- function.  The 'SimpleOps' and the rest of the environment are
-- unchanged.
localVtable :: (ST.SymbolTable (Wise lore) -> ST.SymbolTable (Wise lore))
            -> SimpleM lore a -> SimpleM lore a
localVtable f =
  local $ \(ops, env) -> (ops, env { envVtable = f $ envVtable env })

-- | Run a computation, then return its result together with the
-- certificates recorded in the state.  The recorded certificates are
-- reset to empty afterwards.
--
-- NOTE(review): the certificates are not cleared before running @m@,
-- so any recorded earlier are also returned (and cleared) here.
-- Confirm that callers rely on this.
collectCerts :: SimpleM lore a -> SimpleM lore (a, Certificates)
collectCerts m = do
  x <- m
  (a, b, cs) <- get
  put (a, b, mempty)
  return (x, cs)

-- | Mark that we have changed something and it would be a good idea
-- to re-run the simplifier.  This sets the flag reported by
-- 'runSimpleM'.
changed :: SimpleM lore ()
changed = modify $ \(src, _, cs) -> (src, True, cs)

-- | Record certificates as used.  The new certificates are combined
-- with '<>' in front of those already recorded; 'collectCerts'
-- retrieves them.
usedCerts :: Certificates -> SimpleM lore ()
usedCerts cs = modify $ \(a, b, c) -> (a, b, cs <> c)

-- | Run a computation with the symbol table deepened via 'ST.deepen'.
-- Judging by the name, this is meant for processing a loop body.
enterLoop :: SimpleM lore a -> SimpleM lore a
enterLoop = localVtable ST.deepen

-- | Run a computation with the given function parameters inserted into
-- the symbol table.
bindFParams :: SimplifiableLore lore =>
               [FParam (Wise lore)] -> SimpleM lore a -> SimpleM lore a
bindFParams params =
  localVtable $ ST.insertFParams params

-- | Run a computation with the given lambda parameters inserted into
-- the symbol table.  The parameters are folded in with 'foldr', so the
-- last parameter is inserted first.
bindLParams :: SimplifiableLore lore =>
               [LParam (Wise lore)] -> SimpleM lore a -> SimpleM lore a
bindLParams params =
  localVtable $ \vtable -> foldr ST.insertLParam vtable params

-- | Run a computation with the given lambda parameters inserted into
-- the symbol table.  Unlike 'bindLParams', the parameters are inserted
-- left to right with a strict left fold.
bindArrayLParams :: SimplifiableLore lore => [LParam (Wise lore)] -> SimpleM lore a
                 -> SimpleM lore a
bindArrayLParams params =
  localVtable $ \vtable -> foldl' (flip ST.insertLParam) vtable params

-- | Run a computation with a loop variable of the given integer type
-- and bound inserted into the symbol table.
bindLoopVar :: SimplifiableLore lore =>
               VName -> IntType -> SubExp -> SimpleM lore a -> SimpleM lore a
bindLoopVar var it bound =
  localVtable $ ST.insertLoopVar var it bound

-- | We are willing to hoist potentially unsafe statements out of
-- branches, but they must be protected by adding a branch on top of
-- them.  (This means such hoisting is not worth it unless they are in
-- turn hoisted out of a loop somewhere.)
--
-- Runs @m@ and inspects the statements it produced.  If all of them
-- are safe ('safeExp'), they are emitted unchanged.  Otherwise every
-- statement that is unsafe or not cheap ('cheapExp') is guarded by
-- 'protectIf' with the branch condition (or its negation, when
-- protecting the false side).
protectIfHoisted :: SimplifiableLore lore =>
                    SubExp -- ^ Branch condition.
                 -> Bool -- ^ Which side of the branch are we
                         -- protecting here?
                 -> SimpleM lore (a, Stms (Wise lore))
                 -> SimpleM lore (a, Stms (Wise lore))
protectIfHoisted cond side m = do
  (x, stms) <- m
  ops <- asks $ protectHoistedOpS . fst
  runBinder $ do
    if not $ all (safeExp . stmExp) stms
      then do cond' <- if side
                       then return cond
                       -- The false side runs exactly when the
                       -- condition does not hold.
                       else letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
              mapM_ (protectIf ops unsafeOrCostly cond') stms
      else addStms stms
    return x
  where unsafeOrCostly e = not (safeExp e) || not (cheapExp e)

-- | We are willing to hoist potentially unsafe statements out of
-- loops, but they must be protected by adding a branch on top of
-- them.
--
-- Runs @m@ and inspects the statements it produced.  If all of them
-- are safe ('safeExp'), they are emitted unchanged.  Otherwise every
-- unsafe statement is guarded by 'protectIf' with a condition stating
-- that the loop executes at least once:
--
-- * for a @while@ loop, the initial value bound to the loop-condition
--   parameter among the context and value merge parameters, or the
--   constant @True@ if no such parameter is found;
--
-- * for a @for@ loop, the comparison @0 < bound@.
protectLoopHoisted :: SimplifiableLore lore =>
                      [(FParam (Wise lore),SubExp)]
                   -> [(FParam (Wise lore),SubExp)]
                   -> LoopForm (Wise lore)
                   -> SimpleM lore (a, Stms (Wise lore))
                   -> SimpleM lore (a, Stms (Wise lore))
protectLoopHoisted ctx val form m = do
  (x, stms) <- m
  ops <- asks $ protectHoistedOpS . fst
  runBinder $ do
    if not $ all (safeExp . stmExp) stms
      then do is_nonempty <- checkIfNonEmpty
              mapM_ (protectIf ops (not . safeExp) is_nonempty) stms
      else addStms stms
    return x
  where checkIfNonEmpty =
          case form of
            WhileLoop cond
              | Just (_, cond_init) <-
                  find ((==cond) . paramName . fst) $ ctx ++ val ->
                    return cond_init
              | otherwise -> return $ constant True -- infinite loop
            ForLoop _ it bound _ ->
              letSubExp "loop_nonempty" $
              BasicOp $ CmpOp (CmpSlt it) (intConst it 0) bound

-- | Emit a statement so that its potentially unsafe effects only
-- happen when the boolean @taken@ holds.  The equations are tried in
-- order:
--
-- * A fallback 'If' gets @taken@ conjoined onto its own condition.
--
-- * An 'Assert' is weakened to @not taken || cond@.
--
-- * An 'Op' is handed to the lore-specific 'Protect' callback.  If the
--   callback returns 'Nothing', the following equations apply.
--
-- * Any other expression satisfying the predicate is rewritten into a
--   safe form by 'makeSafe' when possible.  Otherwise it is wrapped in
--   an 'If' on @taken@ whose untaken branch yields blank values built
--   by 'emptyOfType'.
--
-- * Everything else is added unchanged.
protectIf :: MonadBinder m =>
             Protect m
          -> (Exp (Lore m) -> Bool)
          -> SubExp -> Stm (Lore m) -> m ()
protectIf _ _ taken (Let pat aux
                     (If cond taken_body untaken_body (IfDec if_ts IfFallback))) = do
  cond' <- letSubExp "protect_cond_conj" $ BasicOp $ BinOp LogAnd taken cond
  auxing aux $
    letBind pat $ If cond' taken_body untaken_body $
    IfDec if_ts IfFallback
protectIf _ _ taken (Let pat aux (BasicOp (Assert cond msg loc))) = do
  -- The assertion only needs to hold when the protected code would
  -- actually have run.
  not_taken <- letSubExp "loop_not_taken" $ BasicOp $ UnOp Not taken
  cond' <- letSubExp "protect_assert_disj" $ BasicOp $ BinOp LogOr not_taken cond
  auxing aux $ letBind pat $ BasicOp $ Assert cond' msg loc
protectIf protect _ taken (Let pat aux (Op op))
  | Just m <- protect taken pat op =
      auxing aux m
protectIf _ f taken (Let pat aux e)
  | f e =
      case makeSafe e of
        Just e' ->
          auxing aux $ letBind pat e'
        Nothing -> do
          taken_body <- eBody [pure e]
          untaken_body <- eBody $ map (emptyOfType $ patternContextNames pat)
                                      (patternValueTypes pat)
          if_ts <- expTypesFromPattern pat
          auxing aux $
            letBind pat $ If taken taken_body untaken_body $
            IfDec if_ts IfFallback
protectIf _ _ _ stm =
  addStm stm

-- | Rewrite an integer division or modulo operation ('SDiv', 'SDivUp',
-- 'SQuot', 'UDiv', 'UDivUp', 'SMod', 'SRem', 'UMod') so that its
-- 'Safety' annotation is 'Safe', leaving the operand type and operands
-- untouched.  Returns 'Nothing' for every other expression, signalling
-- that the caller must protect it some other way.
makeSafe :: Exp lore -> Maybe (Exp lore)
makeSafe (BasicOp (BinOp op x y)) = do
  op' <- safeBinOp op
  Just $ BasicOp $ BinOp op' x y
  where safeBinOp :: BinOp -> Maybe BinOp
        safeBinOp (SDiv t _)   = Just $ SDiv t Safe
        safeBinOp (SDivUp t _) = Just $ SDivUp t Safe
        safeBinOp (SQuot t _)  = Just $ SQuot t Safe
        safeBinOp (UDiv t _)   = Just $ UDiv t Safe
        safeBinOp (UDivUp t _) = Just $ UDivUp t Safe
        safeBinOp (SMod t _)   = Just $ SMod t Safe
        safeBinOp (SRem t _)   = Just $ SRem t Safe
        safeBinOp (UMod t _)   = Just $ UMod t Safe
        safeBinOp _            = Nothing
makeSafe _ =
  Nothing

-- | Produce an expression yielding a placeholder value of the given
-- type, for use in the untaken branch of a protecting 'If'.
--
-- * Primitive types yield 'blankPrimValue'.
--
-- * Array types yield a 'Scratch' array.  Any dimension that is one of
--   the given context names is replaced by the constant 0, presumably
--   because those names are not bound in the untaken branch.
--
-- * Memory types are rejected with 'error'.
emptyOfType :: MonadBinder m => [VName] -> Type -> m (Exp (Lore m))
emptyOfType _ Mem{} =
  error "emptyOfType: Cannot hoist non-existential memory."
emptyOfType _ (Prim pt) =
  return $ BasicOp $ SubExp $ Constant $ blankPrimValue pt
emptyOfType ctx_names (Array pt shape _) = do
  let dims = map zeroIfContext $ shapeDims shape
  return $ BasicOp $ Scratch pt dims
  where zeroIfContext :: SubExp -> SubExp
        -- NOTE(review): assumes array dimensions are 32-bit in this
        -- IR version; confirm against the size type used elsewhere.
        zeroIfContext (Var v) | v `elem` ctx_names = intConst Int32 0
        zeroIfContext se = se

-- | Statements that are not worth hoisting out of loops, because they
-- are unsafe, and added safety (by 'protectLoopHoisted') may inhibit
-- further optimisation.
--
-- Concretely, a statement is blocked when its expression is not
-- 'safeExp' and at least one of its pattern elements has array type
-- (rank greater than zero).
notWorthHoisting :: ASTLore lore => BlockPred lore
notWorthHoisting _ _ (Let pat _ e) =
  not (safeExp e) && any ((>0) . arrayRank) (patternTypes pat)

hoistStms :: SimplifiableLore lore =>
             RuleBook (Wise lore) -> BlockPred (Wise lore)
          -> ST.SymbolTable (Wise lore) -> UT.UsageTable
          -> Stms (Wise lore)
          -> SimpleM lore (Stms (Wise lore),
                           Stms (Wise lore))
hoistStms :: RuleBook (Wise lore)
-> BlockPred (Wise lore)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
hoistStms RuleBook (Wise lore)
rules BlockPred (Wise lore)
block SymbolTable (Wise lore)
vtable UsageTable
uses Stms (Wise lore)
orig_stms = do
  ([Stm (Wise lore)]
blocked, [Stm (Wise lore)]
hoisted) <- SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore ([Stm (Wise lore)], [Stm (Wise lore)])
simplifyStmsBottomUp SymbolTable (Wise lore)
vtable UsageTable
uses Stms (Wise lore)
orig_stms
  Bool -> SimpleM lore () -> SimpleM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Stm (Wise lore)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Stm (Wise lore)]
hoisted) SimpleM lore ()
forall lore. SimpleM lore ()
changed
  (Stms (Wise lore), Stms (Wise lore))
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm (Wise lore)] -> Stms (Wise lore)
forall lore. [Stm lore] -> Stms lore
stmsFromList [Stm (Wise lore)]
blocked, [Stm (Wise lore)] -> Stms (Wise lore)
forall lore. [Stm lore] -> Stms lore
stmsFromList [Stm (Wise lore)]
hoisted)
  where simplifyStmsBottomUp :: SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore ([Stm (Wise lore)], [Stm (Wise lore)])
simplifyStmsBottomUp SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
stms = do
          (UsageTable
_, [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms') <- SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
simplifyStmsBottomUp' SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
stms
          -- We need to do a final pass to ensure that nothing is
          -- hoisted past something that it depends on.
          let ([Stm (Wise lore)]
blocked, [Stm (Wise lore)]
hoisted) = [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> ([Stm (Wise lore)], [Stm (Wise lore)])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either (Stm (Wise lore)) (Stm (Wise lore))]
 -> ([Stm (Wise lore)], [Stm (Wise lore)]))
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> ([Stm (Wise lore)], [Stm (Wise lore)])
forall a b. (a -> b) -> a -> b
$ [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall lore.
ASTLore lore =>
[Either (Stm lore) (Stm lore)] -> [Either (Stm lore) (Stm lore)]
blockUnhoistedDeps [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms'
          ([Stm (Wise lore)], [Stm (Wise lore)])
-> SimpleM lore ([Stm (Wise lore)], [Stm (Wise lore)])
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm (Wise lore)]
blocked, [Stm (Wise lore)]
hoisted)

        simplifyStmsBottomUp' :: SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
simplifyStmsBottomUp' SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
stms = do
          OpWithWisdom (Op lore) -> UsageTable
opUsage <- ((SimpleOps lore, Env lore)
 -> OpWithWisdom (Op lore) -> UsageTable)
-> SimpleM lore (OpWithWisdom (Op lore) -> UsageTable)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (((SimpleOps lore, Env lore)
  -> OpWithWisdom (Op lore) -> UsageTable)
 -> SimpleM lore (OpWithWisdom (Op lore) -> UsageTable))
-> ((SimpleOps lore, Env lore)
    -> OpWithWisdom (Op lore) -> UsageTable)
-> SimpleM lore (OpWithWisdom (Op lore) -> UsageTable)
forall a b. (a -> b) -> a -> b
$ SimpleOps lore -> OpWithWisdom (Op lore) -> UsageTable
forall lore. SimpleOps lore -> Op (Wise lore) -> UsageTable
opUsageS (SimpleOps lore -> OpWithWisdom (Op lore) -> UsageTable)
-> ((SimpleOps lore, Env lore) -> SimpleOps lore)
-> (SimpleOps lore, Env lore)
-> OpWithWisdom (Op lore)
-> UsageTable
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SimpleOps lore, Env lore) -> SimpleOps lore
forall a b. (a, b) -> a
fst
          let usageInStm :: Stm (Wise lore) -> UsageTable
usageInStm Stm (Wise lore)
stm =
                Stm (Wise lore) -> UsageTable
forall lore. (ASTLore lore, Aliased lore) => Stm lore -> UsageTable
UT.usageInStm Stm (Wise lore)
stm UsageTable -> UsageTable -> UsageTable
forall a. Semigroup a => a -> a -> a
<>
                case Stm (Wise lore) -> Exp (Wise lore)
forall lore. Stm lore -> Exp lore
stmExp Stm (Wise lore)
stm of
                  Op Op (Wise lore)
op -> OpWithWisdom (Op lore) -> UsageTable
opUsage Op (Wise lore)
OpWithWisdom (Op lore)
op
                  Exp (Wise lore)
_ -> UsageTable
forall a. Monoid a => a
mempty
          ((UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
 -> (Stm (Wise lore), SymbolTable (Wise lore))
 -> SimpleM
      lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))]))
-> (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ((Stm (Wise lore) -> UsageTable)
-> (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> (Stm (Wise lore), SymbolTable (Wise lore))
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
hoistable Stm (Wise lore) -> UsageTable
usageInStm) (UsageTable
uses',[]) ([(Stm (Wise lore), SymbolTable (Wise lore))]
 -> SimpleM
      lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))]))
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall a b. (a -> b) -> a -> b
$ [(Stm (Wise lore), SymbolTable (Wise lore))]
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
forall a. [a] -> [a]
reverse ([(Stm (Wise lore), SymbolTable (Wise lore))]
 -> [(Stm (Wise lore), SymbolTable (Wise lore))])
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
forall a b. (a -> b) -> a -> b
$ [Stm (Wise lore)]
-> [SymbolTable (Wise lore)]
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
forall a b. [a] -> [b] -> [(a, b)]
zip (Stms (Wise lore) -> [Stm (Wise lore)]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms (Wise lore)
stms) [SymbolTable (Wise lore)]
vtables
            where vtables :: [SymbolTable (Wise lore)]
vtables = (SymbolTable (Wise lore)
 -> Stm (Wise lore) -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore)
-> [Stm (Wise lore)]
-> [SymbolTable (Wise lore)]
forall b a. (b -> a -> b) -> b -> [a] -> [b]
scanl ((Stm (Wise lore)
 -> SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore)
-> Stm (Wise lore)
-> SymbolTable (Wise lore)
forall a b c. (a -> b -> c) -> b -> a -> c
flip Stm (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall lore.
(ASTLore lore, IndexOp (Op lore), Aliased lore) =>
Stm lore -> SymbolTable lore -> SymbolTable lore
ST.insertStm) SymbolTable (Wise lore)
vtable' ([Stm (Wise lore)] -> [SymbolTable (Wise lore)])
-> [Stm (Wise lore)] -> [SymbolTable (Wise lore)]
forall a b. (a -> b) -> a -> b
$ Stms (Wise lore) -> [Stm (Wise lore)]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms (Wise lore)
stms

        hoistable :: (Stm (Wise lore) -> UsageTable)
-> (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> (Stm (Wise lore), SymbolTable (Wise lore))
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
hoistable Stm (Wise lore) -> UsageTable
usageInStm (UsageTable
uses',[Either (Stm (Wise lore)) (Stm (Wise lore))]
stms) (Stm (Wise lore)
stm, SymbolTable (Wise lore)
vtable')
          | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> UsageTable -> Bool
`UT.isUsedDirectly` UsageTable
uses') ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Stm (Wise lore) -> [VName]
forall lore. Stm lore -> [VName]
provides Stm (Wise lore)
stm = -- Dead statement.
            (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return (UsageTable
uses', [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms)
          | Bool
otherwise = do
            Maybe (Stms (Wise lore))
res <- (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore (Maybe (Stms (Wise lore)))
-> SimpleM lore (Maybe (Stms (Wise lore)))
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
localVtable (SymbolTable (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a b. a -> b -> a
const SymbolTable (Wise lore)
vtable') (SimpleM lore (Maybe (Stms (Wise lore)))
 -> SimpleM lore (Maybe (Stms (Wise lore))))
-> SimpleM lore (Maybe (Stms (Wise lore)))
-> SimpleM lore (Maybe (Stms (Wise lore)))
forall a b. (a -> b) -> a -> b
$
                   RuleBook (Wise lore)
-> (SymbolTable (Wise lore), UsageTable)
-> Stm (Wise lore)
-> SimpleM lore (Maybe (Stms (Wise lore)))
forall (m :: * -> *) lore.
(MonadFreshNames m, HasScope lore m) =>
RuleBook lore
-> (SymbolTable lore, UsageTable)
-> Stm lore
-> m (Maybe (Stms lore))
bottomUpSimplifyStm RuleBook (Wise lore)
rules (SymbolTable (Wise lore)
vtable', UsageTable
uses') Stm (Wise lore)
stm
            case Maybe (Stms (Wise lore))
res of
              Maybe (Stms (Wise lore))
Nothing -- Nothing to optimise - see if hoistable.
                | BlockPred (Wise lore)
block SymbolTable (Wise lore)
vtable' UsageTable
uses' Stm (Wise lore)
stm ->
                    (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return ((Stm (Wise lore) -> UsageTable)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stm (Wise lore)
-> UsageTable
forall lore.
(ASTLore lore, Aliased lore) =>
(Stm lore -> UsageTable)
-> SymbolTable lore -> UsageTable -> Stm lore -> UsageTable
expandUsage Stm (Wise lore) -> UsageTable
usageInStm SymbolTable (Wise lore)
vtable' UsageTable
uses' Stm (Wise lore)
stm
                            UsageTable -> [VName] -> UsageTable
`UT.without` Stm (Wise lore) -> [VName]
forall lore. Stm lore -> [VName]
provides Stm (Wise lore)
stm,
                            Stm (Wise lore) -> Either (Stm (Wise lore)) (Stm (Wise lore))
forall a b. a -> Either a b
Left Stm (Wise lore)
stm Either (Stm (Wise lore)) (Stm (Wise lore))
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall a. a -> [a] -> [a]
: [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms)
                | Bool
otherwise ->
                    (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return ((Stm (Wise lore) -> UsageTable)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stm (Wise lore)
-> UsageTable
forall lore.
(ASTLore lore, Aliased lore) =>
(Stm lore -> UsageTable)
-> SymbolTable lore -> UsageTable -> Stm lore -> UsageTable
expandUsage Stm (Wise lore) -> UsageTable
usageInStm SymbolTable (Wise lore)
vtable' UsageTable
uses' Stm (Wise lore)
stm,
                            Stm (Wise lore) -> Either (Stm (Wise lore)) (Stm (Wise lore))
forall a b. b -> Either a b
Right Stm (Wise lore)
stm Either (Stm (Wise lore)) (Stm (Wise lore))
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall a. a -> [a] -> [a]
: [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms)
              Just Stms (Wise lore)
optimstms -> do
                SimpleM lore ()
forall lore. SimpleM lore ()
changed
                (UsageTable
uses'',[Either (Stm (Wise lore)) (Stm (Wise lore))]
stms') <- SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
simplifyStmsBottomUp' SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
optimstms
                (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
     lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return (UsageTable
uses'', [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms'[Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall a. [a] -> [a] -> [a]
++[Either (Stm (Wise lore)) (Stm (Wise lore))]
stms)

-- | Final pass over the output of bottom-up hoisting.  'Left' elements
-- are blocked statements and 'Right' elements are hoistable ones.
--
-- The list is walked front to back while accumulating the names bound by
-- every blocked statement.  A hoistable statement that has any of those
-- names free is demoted to blocked, and the names it binds are added to
-- the set as well, so blocking propagates transitively.
blockUnhoistedDeps :: ASTLore lore =>
                      [Either (Stm lore) (Stm lore)]
                   -> [Either (Stm lore) (Stm lore)]
blockUnhoistedDeps = snd . mapAccumL block mempty
  where -- Mark a statement as blocked and record the names it binds.
        blockStm blocked stm =
          (blocked <> namesFromList (provides stm), Left stm)

        block blocked (Left need) = blockStm blocked need
        block blocked (Right need)
          -- Depends on something that stays behind, so it must stay too.
          | blocked `namesIntersect` freeIn need = blockStm blocked need
          | otherwise = (blocked, Right need)

-- | The names bound by the pattern of a statement.
provides :: Stm lore -> [VName]
provides = patternNames . stmPattern

-- | Extend the usage table @utable@ with the usage arising from @stm@.
--
-- The result is the combination of:
--
--   * the statement's own usage (from @usageInStm@) together with the
--     usages of each bound name transferred to the names it aliases,
--     passed through 'UT.expand' with the aliases recorded in @vtable@;
--
--   * size usages for every free variable of the expression, if any name
--     bound by the pattern is used as a size in @utable@;
--
--   * @utable@ itself.
expandUsage :: (ASTLore lore, Aliased lore) =>
               (Stm lore -> UT.UsageTable) -> ST.SymbolTable lore -> UT.UsageTable
            -> Stm lore -> UT.UsageTable
expandUsage usageInStm vtable utable stm@(Let pat _ e) =
  UT.expand (`ST.lookupAliases` vtable) (usageInStm stm <> usageThroughAliases)
  <> sizeUsage
  <> utable
  where bound = patternNames pat

        -- If something bound here is used as a size, then everything the
        -- expression refers to is also recorded as a size usage.
        sizeUsage
          | any (`UT.isSize` utable) bound = UT.sizeUsages (freeIn e)
          | otherwise = mempty

        -- Each bound name's recorded usages are transferred to its aliases.
        usageThroughAliases =
          foldMap usageThroughBindeeAliases $ zip bound (patternAliases pat)

        usageThroughBindeeAliases (name, aliases) =
          case UT.lookup name utable of
            Nothing -> mempty
            Just uses -> foldMap (`UT.usage` uses) $ namesToList aliases

-- | A predicate deciding whether a statement is blocked (kept where it is)
-- rather than hoisted.  It receives a symbol table, a usage table and
-- the statement in question; 'True' means "block".
type BlockPred lore =
  ST.SymbolTable lore -> UT.UsageTable -> Stm lore -> Bool

-- | A blocking predicate that never blocks anything.
neverBlocks :: BlockPred lore
neverBlocks _ _ _ = False

-- | A blocking predicate that blocks every statement.
alwaysBlocks :: BlockPred lore
alwaysBlocks _ _ _ = True

-- | @isFalse b@ blocks every statement when @b@ is 'False' and none when
-- @b@ is 'True'.
isFalse :: Bool -> BlockPred lore
isFalse b _ _ _ = not b

-- | Block a statement if either predicate blocks it.  The second
-- predicate is only consulted when the first returns 'False'.
orIf :: BlockPred lore -> BlockPred lore -> BlockPred lore
orIf p1 p2 vtable usages stm = p1 vtable usages stm || p2 vtable usages stm

-- | Block a statement only if both predicates block it.  The second
-- predicate is only consulted when the first returns 'True'.
andAlso :: BlockPred lore -> BlockPred lore -> BlockPred lore
andAlso p1 p2 vtable usages stm = p1 vtable usages stm && p2 vtable usages stm

-- | Block a statement if any name bound by its pattern is marked as
-- consumed in the usage table.
isConsumed :: BlockPred lore
isConsumed _ utable = any (`UT.isConsumed` utable) . patternNames . stmPattern

-- | Block any statement whose expression is an 'Op'.
isOp :: BlockPred lore
isOp _ _ (Let _ _ Op{}) = True
isOp _ _ _ = False

-- | Build a body from the given statements and result inside the
-- simplifier monad.  The construction is wrapped in 'insertStmsM'
-- (presumably so that the added statements end up inside the returned
-- body); the statement list reported by 'runBinder' is discarded.
constructBody :: SimplifiableLore lore => Stms (Wise lore) -> Result
              -> SimpleM lore (Body (Wise lore))
constructBody stms res =
  fst <$> runBinder (insertStmsM build)
  where build = do
          addStms stms
          resultBodyM res

-- | The outcome of simplifying a body: a payload @a@ (for example the
-- body's 'Result') paired with a usage table, together with the
-- simplified statements.
type SimplifiedBody lore a =
  ((a, UT.UsageTable), Stms (Wise lore))

-- | Run an action producing a simplified body.  Then pass its statements
-- through 'hoistStms' using the engine's rule book, the current symbol
-- table and the body's usage table.  @block@ decides which statements
-- must stay.
--
-- Returns the blocked statements paired with the action's payload,
-- alongside the statements that were hoisted out.
blockIf :: SimplifiableLore lore =>
           BlockPred (Wise lore)
        -> SimpleM lore (SimplifiedBody lore a)
        -> SimpleM lore ((Stms (Wise lore), a), Stms (Wise lore))
blockIf block m = do
  ((x, usages), stms) <- m
  vtable <- askVtable
  rules <- asksEngineEnv envRules
  (blocked, hoisted) <- hoistStms rules block vtable usages stms
  pure ((blocked, x), hoisted)

-- | Block a statement if any of the given names occurs free in it.
hasFree :: ASTLore lore => Names -> BlockPred lore
hasFree ks _ _ need = ks `namesIntersect` freeIn need

-- | Block a statement whose expression is not considered safe by
-- 'safeExp'.
isNotSafe :: ASTLore lore => BlockPred lore
isNotSafe _ _ = not . safeExp . stmExp

-- | Block a statement whose expression is an in-place 'Update'.
isInPlaceBound :: BlockPred lore
isInPlaceBound _ _ = isUpdate . stmExp
  where isUpdate (BasicOp Update{}) = True
        isUpdate _ = False

-- | Block a statement that 'cheapStm' does not consider cheap.
isNotCheap :: ASTLore lore => BlockPred lore
isNotCheap _ _ = not . cheapStm

-- | A statement is cheap if its expression is cheap (see 'cheapExp').
cheapStm :: ASTLore lore => Stm lore -> Bool
cheapStm = cheapExp . stmExp

-- | Whether an expression is considered cheap to evaluate.
--
-- Copies, manifests and loops are never cheap.  A branch is cheap when
-- every statement in both of its bodies is cheap.  An 'Op' defers to
-- 'cheapOp'.  Everything else is currently treated as cheap.
cheapExp :: ASTLore lore => Exp lore -> Bool
-- Explicitly cheap, independent of the catch-all default at the bottom.
cheapExp (BasicOp BinOp{})    = True
cheapExp (BasicOp SubExp{})   = True
cheapExp (BasicOp UnOp{})     = True
cheapExp (BasicOp CmpOp{})    = True
cheapExp (BasicOp ConvOp{})   = True
-- Explicitly not cheap.
cheapExp (BasicOp Copy{})     = False
cheapExp (BasicOp Manifest{}) = False
cheapExp DoLoop{}             = False
cheapExp (If _ tbranch fbranch _) = cheapBody tbranch && cheapBody fbranch
  where cheapBody body = all cheapStm (bodyStms body)
cheapExp (Op op)              = cheapOp op
-- Catch-all default.  This used to be False; True is being tried out.
cheapExp _                    = True

-- | Lift a predicate on statements to a 'BlockPred' that ignores the
-- symbol table and usage table.
stmIs :: (Stm lore -> Bool) -> BlockPred lore
stmIs f _ _ = f

-- | Whether every free variable of the statement is among the names that
-- 'ST.availableAtClosestLoop' reports for the given symbol table.
loopInvariantStm :: ASTLore lore => ST.SymbolTable lore -> Stm lore -> Bool
loopInvariantStm vtable =
  let available = ST.availableAtClosestLoop vtable
  in all (`nameIn` available) . namesToList . freeIn

hoistCommon :: SimplifiableLore lore =>
               SubExp -> IfSort
            -> SimplifiedBody lore Result
            -> SimplifiedBody lore Result
            -> SimpleM lore (Body (Wise lore),
                             Body (Wise lore),
                             Stms (Wise lore))
hoistCommon :: SubExp
-> IfSort
-> SimplifiedBody lore Result
-> SimplifiedBody lore Result
-> SimpleM
     lore (Body (Wise lore), Body (Wise lore), Stms (Wise lore))
hoistCommon SubExp
cond IfSort
ifsort ((Result
res1, UsageTable
usages1), Stms (Wise lore)
stms1) ((Result
res2, UsageTable
usages2), Stms (Wise lore)
stms2) = do
  Stm (Wise lore) -> Bool
is_alloc_fun <- (Env lore -> Stm (Wise lore) -> Bool)
-> SimpleM lore (Stm (Wise lore) -> Bool)
forall lore a. (Env lore -> a) -> SimpleM lore a
asksEngineEnv ((Env lore -> Stm (Wise lore) -> Bool)
 -> SimpleM lore (Stm (Wise lore) -> Bool))
-> (Env lore -> Stm (Wise lore) -> Bool)
-> SimpleM lore (Stm (Wise lore) -> Bool)
forall a b. (a -> b) -> a -> b
$ HoistBlockers lore -> Stm (Wise lore) -> Bool
forall lore. HoistBlockers lore -> Stm (Wise lore) -> Bool
isAllocation  (HoistBlockers lore -> Stm (Wise lore) -> Bool)
-> (Env lore -> HoistBlockers lore)
-> Env lore
-> Stm (Wise lore)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env lore -> HoistBlockers lore
forall lore. Env lore -> HoistBlockers lore
envHoistBlockers
  BlockPred (Wise lore)
branch_blocker <- (Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore))
forall lore a. (Env lore -> a) -> SimpleM lore a
asksEngineEnv ((Env lore -> BlockPred (Wise lore))
 -> SimpleM lore (BlockPred (Wise lore)))
-> (Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore))
forall a b. (a -> b) -> a -> b
$ HoistBlockers lore -> BlockPred (Wise lore)
forall lore. HoistBlockers lore -> BlockPred (Wise lore)
blockHoistBranch (HoistBlockers lore -> BlockPred (Wise lore))
-> (Env lore -> HoistBlockers lore)
-> Env lore
-> BlockPred (Wise lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env lore -> HoistBlockers lore
forall lore. Env lore -> HoistBlockers lore
envHoistBlockers
  SymbolTable (Wise lore)
vtable <- SimpleM lore (SymbolTable (Wise lore))
forall lore. SimpleM lore (SymbolTable (Wise lore))
askVtable
  let -- We are unwilling to hoist things that are unsafe or costly,
      -- *except* if they are invariant to the most enclosing loop,
      -- because in that case they will also be hoisted past that
      -- loop.
      --
      -- We also try very hard to hoist allocations or anything that
      -- contributes to memory or array size, because that will allow
      -- allocations to be hoisted.
      cond_loop_invariant :: Bool
cond_loop_invariant =
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> Names -> Bool
`nameIn` SymbolTable (Wise lore) -> Names
forall lore. SymbolTable lore -> Names
ST.availableAtClosestLoop SymbolTable (Wise lore)
vtable) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ SubExp -> Names
forall a. FreeIn a => a -> Names
freeIn SubExp
cond

      desirableToHoist :: Stm (Wise lore) -> Bool
desirableToHoist Stm (Wise lore)
stm =
          Stm (Wise lore) -> Bool
is_alloc_fun Stm (Wise lore)
stm Bool -> Bool -> Bool
||
          (SymbolTable (Wise lore) -> Int
forall lore. SymbolTable lore -> Int
ST.loopDepth SymbolTable (Wise lore)
vtable Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0 Bool -> Bool -> Bool
&&
           Bool
cond_loop_invariant Bool -> Bool -> Bool
&&
           IfSort
ifsort IfSort -> IfSort -> Bool
forall a. Eq a => a -> a -> Bool
/= IfSort
IfFallback Bool -> Bool -> Bool
&&
           SymbolTable (Wise lore) -> Stm (Wise lore) -> Bool
forall lore. ASTLore lore => SymbolTable lore -> Stm lore -> Bool
loopInvariantStm SymbolTable (Wise lore)
vtable Stm (Wise lore)
stm)

      -- No matter what, we always want to hoist constants as much as
      -- possible.
      isNotHoistableBnd :: BlockPred (Wise lore)
isNotHoistableBnd SymbolTable (Wise lore)
_ UsageTable
_ (Let Pattern (Wise lore)
_ StmAux (ExpDec (Wise lore))
_ (BasicOp ArrayLit{})) = Bool
False
      isNotHoistableBnd SymbolTable (Wise lore)
_ UsageTable
_ (Let Pattern (Wise lore)
_ StmAux (ExpDec (Wise lore))
_ (BasicOp SubExp{})) = Bool
False
      isNotHoistableBnd SymbolTable (Wise lore)
_ UsageTable
usages (Let Pattern (Wise lore)
pat StmAux (ExpDec (Wise lore))
_ ExpT (Wise lore)
_)
        | (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> UsageTable -> Bool
`UT.isSize` UsageTable
usages) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ PatternT (VarWisdom, LetDec lore) -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT (VarWisdom, LetDec lore)
Pattern (Wise lore)
pat =
            Bool
False
      isNotHoistableBnd SymbolTable (Wise lore)
_ UsageTable
_ Stm (Wise lore)
_ =
        -- Hoist aggressively out of versioning branches.
        IfSort
ifsort IfSort -> IfSort -> Bool
forall a. Eq a => a -> a -> Bool
/= IfSort
IfEquiv

      block :: BlockPred (Wise lore)
block = BlockPred (Wise lore)
branch_blocker BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf`
              ((BlockPred (Wise lore)
forall lore. ASTLore lore => BlockPred lore
isNotSafe BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
forall lore. ASTLore lore => BlockPred lore
isNotCheap) BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`andAlso` (Stm (Wise lore) -> Bool) -> BlockPred (Wise lore)
forall lore. (Stm lore -> Bool) -> BlockPred lore
stmIs (Bool -> Bool
not (Bool -> Bool)
-> (Stm (Wise lore) -> Bool) -> Stm (Wise lore) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm (Wise lore) -> Bool
desirableToHoist))
              BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
forall lore. BlockPred lore
isInPlaceBound BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
isNotHoistableBnd

  RuleBook (Wise lore)
rules <- (Env lore -> RuleBook (Wise lore))
-> SimpleM lore (RuleBook (Wise lore))
forall lore a. (Env lore -> a) -> SimpleM lore a
asksEngineEnv Env lore -> RuleBook (Wise lore)
forall lore. Env lore -> RuleBook (Wise lore)
envRules
  (Stms (Wise lore)
body1_bnds', Stms (Wise lore)
safe1) <- SubExp
-> Bool
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
SubExp
-> Bool
-> SimpleM lore (a, Stms (Wise lore))
-> SimpleM lore (a, Stms (Wise lore))
protectIfHoisted SubExp
cond Bool
True (SimpleM lore (Stms (Wise lore), Stms (Wise lore))
 -> SimpleM lore (Stms (Wise lore), Stms (Wise lore)))
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
                          RuleBook (Wise lore)
-> BlockPred (Wise lore)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
forall lore.
SimplifiableLore lore =>
RuleBook (Wise lore)
-> BlockPred (Wise lore)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
hoistStms RuleBook (Wise lore)
rules BlockPred (Wise lore)
block SymbolTable (Wise lore)
vtable UsageTable
usages1 Stms (Wise lore)
stms1
  (Stms (Wise lore)
body2_bnds', Stms (Wise lore)
safe2) <- SubExp
-> Bool
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
SubExp
-> Bool
-> SimpleM lore (a, Stms (Wise lore))
-> SimpleM lore (a, Stms (Wise lore))
protectIfHoisted SubExp
cond Bool
False (SimpleM lore (Stms (Wise lore), Stms (Wise lore))
 -> SimpleM lore (Stms (Wise lore), Stms (Wise lore)))
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
                          RuleBook (Wise lore)
-> BlockPred (Wise lore)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
forall lore.
SimplifiableLore lore =>
RuleBook (Wise lore)
-> BlockPred (Wise lore)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
hoistStms RuleBook (Wise lore)
rules BlockPred (Wise lore)
block SymbolTable (Wise lore)
vtable UsageTable
usages2 Stms (Wise lore)
stms2
  let hoistable :: Stms (Wise lore)
hoistable = Stms (Wise lore)
safe1 Stms (Wise lore) -> Stms (Wise lore) -> Stms (Wise lore)
forall a. Semigroup a => a -> a -> a
<> Stms (Wise lore)
safe2
  Body (Wise lore)
body1' <- Stms (Wise lore) -> Result -> SimpleM lore (Body (Wise lore))
forall lore.
SimplifiableLore lore =>
Stms (Wise lore) -> Result -> SimpleM lore (Body (Wise lore))
constructBody Stms (Wise lore)
body1_bnds' Result
res1
  Body (Wise lore)
body2' <- Stms (Wise lore) -> Result -> SimpleM lore (Body (Wise lore))
forall lore.
SimplifiableLore lore =>
Stms (Wise lore) -> Result -> SimpleM lore (Body (Wise lore))
constructBody Stms (Wise lore)
body2_bnds' Result
res2
  (Body (Wise lore), Body (Wise lore), Stms (Wise lore))
-> SimpleM
     lore (Body (Wise lore), Body (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (Body (Wise lore)
body1', Body (Wise lore)
body2', Stms (Wise lore)
hoistable)

-- | Simplify a single body.  The @[Diet]@ only covers the value
-- elements, because the context cannot be consumed.
--
-- The body's statements are simplified first.  The result is then
-- simplified in the environment those statements establish.  The
-- continuation that simplifies the result contributes no statements
-- of its own (it returns 'mempty'), so the returned statements are
-- exactly those produced while processing the body's statements.
simplifyBody :: SimplifiableLore lore =>
                [Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody ds (Body _ bnds res) =
  simplifyStms bnds $ do
    res' <- simplifyResult ds res
    return (res', mempty)

-- | Simplify a single 'Result'.  The @[Diet]@ only covers the value
-- elements, because the context cannot be consumed.
--
-- The result is split so that the last @length ds@ elements are the
-- value part and everything before them is the context part.
--
-- * Context elements are simplified with 'simplify', and any
--   certificates this produces are discarded (see the comment below).
-- * Value elements are copy-propagated only when the symbol table
--   knows a replacement that carries no certificates.
--
-- Returns the new result together with a usage table.  The table
-- combines the free variables of the new result with the consumption
-- implied by pairing each value element with its 'Diet'.
simplifyResult :: SimplifiableLore lore =>
                  [Diet] -> Result -> SimpleM lore (Result, UT.UsageTable)
simplifyResult ds res = do
  let (ctx_res, val_res) = splitFromEnd (length ds) res
  -- Copy propagation is a little trickier here, because there is no
  -- place to put the certificates when copy-propagating a certified
  -- statement.  However, for results in the *context*, it is OK to
  -- just throw away the certificates, because for the program to be
  -- type-correct, those statements must anyway be used (or
  -- copy-propagated into) the statements producing the value result.
  (ctx_res', _ctx_res_cs) <- collectCerts $ mapM simplify ctx_res
  val_res' <- mapM simplify' val_res

  let consumption = consumeResult $ zip ds val_res'
      res' = ctx_res' <> val_res'
  return (res', UT.usages (freeIn res') <> consumption)

  where -- Replace a variable by what the symbol table says it is equal
        -- to, but only when doing so does not require certificates,
        -- since a result has nowhere to put them.
        simplify' (Var name) = do
          bnd <- ST.lookupSubExp name <$> askVtable
          return $ case bnd of
            Just (se, cs) | cs == mempty -> se
            _                            -> Var name
        -- Constants are already as simple as they get.
        simplify' se = return se

-- | Build a usage table marking every variable in the given 'Result'
-- with 'UT.inResultUsage'.  Constants contribute nothing.
isDoLoopResult :: Result -> UT.UsageTable
isDoLoopResult = foldMap checkForVar
  where checkForVar (Var ident) = UT.inResultUsage ident
        checkForVar _           = mempty

-- | Simplify a sequence of statements front to back, then run the
-- continuation @m@ in the environment they establish.
--
-- For each statement, three things are simplified:
--
-- * its certificates;
-- * its expression, which may also yield auxiliary statements;
-- * its pattern.
--
-- Certificates collected from all three are attached to the rebuilt
-- statement.  The auxiliary statements from the expression are
-- inspected first, then the rebuilt statement, then the remaining
-- statements, and finally @m@.
simplifyStms :: SimplifiableLore lore =>
                Stms lore -> SimpleM lore (a, Stms (Wise lore))
             -> SimpleM lore (a, Stms (Wise lore))
simplifyStms stms m =
  case stmsHead stms of
    Nothing -> inspectStms mempty m
    Just (Let pat (StmAux stm_cs attrs dec) e, stms') -> do
      stm_cs' <- simplify stm_cs
      ((e', e_stms), e_cs) <- collectCerts $ simplifyExp e
      (pat', pat_cs) <- collectCerts $ simplifyPattern pat
      let cs = stm_cs' <> e_cs <> pat_cs
      inspectStms e_stms $
        inspectStm (mkWiseLetStm pat' (StmAux cs attrs dec) e') $
        simplifyStms stms' m

-- | Inspect a single already-simplified statement before running the
-- continuation.  This is 'inspectStms' applied to a one-element
-- statement sequence.
inspectStm :: SimplifiableLore lore =>
              Stm (Wise lore) -> SimpleM lore (a, Stms (Wise lore))
           -> SimpleM lore (a, Stms (Wise lore))
inspectStm = inspectStms . oneStm

-- | Run already-simplified statements through the top-down rules
-- before running the continuation @m@.
--
-- Statements are processed front to back.  Each one is offered to
-- 'topDownSimplifyStm' with the rule book from 'envRules' and the
-- current symbol table.
--
-- * If a rule fires, 'changed' is signalled and the replacement
--   statements are inspected again, ahead of the remaining ones.
-- * Otherwise, the statement is added to the symbol table for the
--   rest of the traversal and is kept in the output, in front of
--   whatever the rest of the traversal produces.
--
-- When no statements remain, @m@ runs as-is.
inspectStms :: SimplifiableLore lore =>
               Stms (Wise lore)
            -> SimpleM lore (a, Stms (Wise lore))
            -> SimpleM lore (a, Stms (Wise lore))
inspectStms stms m =
  case stmsHead stms of
    Nothing -> m
    Just (stm, stms') -> do
      vtable <- askVtable
      rules <- asksEngineEnv envRules
      simplified <- topDownSimplifyStm rules vtable stm
      case simplified of
        Just newbnds -> changed >> inspectStms (newbnds <> stms') m
        Nothing -> do
          (x, stms'') <- localVtable (ST.insertStm stm) $ inspectStms stms' m
          return (x, oneStm stm <> stms'')

-- | Simplify a lore-specific operation.  This reads the
-- 'simplifyOpS' function from the 'SimpleOps' half of the reader
-- environment, a @(SimpleOps lore, Env lore)@ pair, and applies it.
simplifyOp :: Op lore -> SimpleM lore (Op (Wise lore), Stms (Wise lore))
simplifyOp op = do
  f <- asks $ simplifyOpS . fst
  f op

simplifyExp :: SimplifiableLore lore =>
               Exp lore -> SimpleM lore (Exp (Wise lore), Stms (Wise lore))

simplifyExp :: Exp lore -> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
simplifyExp (If SubExp
cond BodyT lore
tbranch BodyT lore
fbranch (IfDec [BranchType lore]
ts IfSort
ifsort)) = do
  -- Here, we have to check whether 'cond' puts a bound on some free
  -- variable, and if so, chomp it.  We should also try to do CSE
  -- across branches.
  SubExp
cond' <- SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify SubExp
cond
  [BranchType lore]
ts' <- (BranchType lore -> SimpleM lore (BranchType lore))
-> [BranchType lore] -> SimpleM lore [BranchType lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM BranchType lore -> SimpleM lore (BranchType lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify [BranchType lore]
ts
  -- FIXME: we have to be conservative about the diet here, because we
  -- lack proper information.  Something is wrong with the order in
  -- which the simplifier does things - it should be purely bottom-up
  -- (or else, If expressions should indicate explicitly the diet of
  -- their return types).
  let ds :: [Diet]
ds = (BranchType lore -> Diet) -> [BranchType lore] -> [Diet]
forall a b. (a -> b) -> [a] -> [b]
map (Diet -> BranchType lore -> Diet
forall a b. a -> b -> a
const Diet
Consume) [BranchType lore]
ts
  SimplifiedBody lore Result
tbranch' <- [Diet] -> BodyT lore -> SimpleM lore (SimplifiedBody lore Result)
forall lore.
SimplifiableLore lore =>
[Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody [Diet]
ds BodyT lore
tbranch
  SimplifiedBody lore Result
fbranch' <- [Diet] -> BodyT lore -> SimpleM lore (SimplifiedBody lore Result)
forall lore.
SimplifiableLore lore =>
[Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody [Diet]
ds BodyT lore
fbranch
  (Body (Wise lore)
tbranch'',Body (Wise lore)
fbranch'', Stms (Wise lore)
hoisted) <- SubExp
-> IfSort
-> SimplifiedBody lore Result
-> SimplifiedBody lore Result
-> SimpleM
     lore (Body (Wise lore), Body (Wise lore), Stms (Wise lore))
forall lore.
SimplifiableLore lore =>
SubExp
-> IfSort
-> SimplifiedBody lore Result
-> SimplifiedBody lore Result
-> SimpleM
     lore (Body (Wise lore), Body (Wise lore), Stms (Wise lore))
hoistCommon SubExp
cond' IfSort
ifsort SimplifiedBody lore Result
tbranch' SimplifiedBody lore Result
fbranch'
  (Exp (Wise lore), Stms (Wise lore))
-> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
-> Body (Wise lore)
-> Body (Wise lore)
-> IfDec (BranchType (Wise lore))
-> Exp (Wise lore)
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond' Body (Wise lore)
tbranch'' Body (Wise lore)
fbranch'' (IfDec (BranchType (Wise lore)) -> Exp (Wise lore))
-> IfDec (BranchType (Wise lore)) -> Exp (Wise lore)
forall a b. (a -> b) -> a -> b
$ [BranchType lore] -> IfSort -> IfDec (BranchType lore)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType lore]
ts' IfSort
ifsort, Stms (Wise lore)
hoisted)

simplifyExp (DoLoop [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
val LoopForm lore
form BodyT lore
loopbody) = do
  let ([FParam lore]
ctxparams, Result
ctxinit) = [(FParam lore, SubExp)] -> ([FParam lore], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(FParam lore, SubExp)]
ctx
      ([FParam lore]
valparams, Result
valinit) = [(FParam lore, SubExp)] -> ([FParam lore], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(FParam lore, SubExp)]
val
  [FParam lore]
ctxparams' <- (FParam lore -> SimpleM lore (FParam lore))
-> [FParam lore] -> SimpleM lore [FParam lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((FParamInfo lore -> SimpleM lore (FParamInfo lore))
-> FParam lore -> SimpleM lore (FParam lore)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse FParamInfo lore -> SimpleM lore (FParamInfo lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify) [FParam lore]
ctxparams
  Result
ctxinit' <- (SubExp -> SimpleM lore SubExp) -> Result -> SimpleM lore Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify Result
ctxinit
  [FParam lore]
valparams' <- (FParam lore -> SimpleM lore (FParam lore))
-> [FParam lore] -> SimpleM lore [FParam lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((FParamInfo lore -> SimpleM lore (FParamInfo lore))
-> FParam lore -> SimpleM lore (FParam lore)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse FParamInfo lore -> SimpleM lore (FParamInfo lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify) [FParam lore]
valparams
  Result
valinit' <- (SubExp -> SimpleM lore SubExp) -> Result -> SimpleM lore Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify Result
valinit
  let ctx' :: [(FParam lore, SubExp)]
ctx' = [FParam lore] -> Result -> [(FParam lore, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [FParam lore]
ctxparams' Result
ctxinit'
      val' :: [(FParam lore, SubExp)]
val' = [FParam lore] -> Result -> [(FParam lore, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [FParam lore]
valparams' Result
valinit'
      diets :: [Diet]
diets = (FParam lore -> Diet) -> [FParam lore] -> [Diet]
forall a b. (a -> b) -> [a] -> [b]
map (TypeBase Shape Uniqueness -> Diet
forall shape. TypeBase shape Uniqueness -> Diet
diet (TypeBase Shape Uniqueness -> Diet)
-> (FParam lore -> TypeBase Shape Uniqueness)
-> FParam lore
-> Diet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType) [FParam lore]
valparams'
  (LoopForm (Wise lore)
form', Names
boundnames, SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
wrapbody) <- case LoopForm lore
form of
    ForLoop VName
loopvar IntType
it SubExp
boundexp [(LParam lore, VName)]
loopvars -> do
      SubExp
boundexp' <- SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify SubExp
boundexp
      let ([LParam lore]
loop_params, [VName]
loop_arrs) = [(LParam lore, VName)] -> ([LParam lore], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(LParam lore, VName)]
loopvars
      [LParam lore]
loop_params' <- (LParam lore -> SimpleM lore (LParam lore))
-> [LParam lore] -> SimpleM lore [LParam lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((LParamInfo lore -> SimpleM lore (LParamInfo lore))
-> LParam lore -> SimpleM lore (LParam lore)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse LParamInfo lore -> SimpleM lore (LParamInfo lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify) [LParam lore]
loop_params
      [VName]
loop_arrs' <- (VName -> SimpleM lore VName) -> [VName] -> SimpleM lore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> SimpleM lore VName
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify [VName]
loop_arrs
      let form' :: LoopForm (Wise lore)
form' = VName
-> IntType
-> SubExp
-> [(LParam (Wise lore), VName)]
-> LoopForm (Wise lore)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
loopvar IntType
it SubExp
boundexp' ([LParam lore] -> [VName] -> [(LParam lore, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [LParam lore]
loop_params' [VName]
loop_arrs')
      (LoopForm (Wise lore), Names,
 SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM
     lore
     (LoopForm (Wise lore), Names,
      SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
      -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
forall (m :: * -> *) a. Monad m => a -> m a
return (LoopForm (Wise lore)
form',
              [VName] -> Names
namesFromList (VName
loopvar VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: (LParam lore -> VName) -> [LParam lore] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map LParam lore -> VName
forall dec. Param dec -> VName
paramName [LParam lore]
loop_params') Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
fparamnames,
              VName
-> IntType
-> SubExp
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
VName -> IntType -> SubExp -> SimpleM lore a -> SimpleM lore a
bindLoopVar VName
loopvar IntType
it SubExp
boundexp' (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
    -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall b c a. (b -> c) -> (a -> b) -> a -> c
.
              [(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore (a, Stms (Wise lore))
-> SimpleM lore (a, Stms (Wise lore))
protectLoopHoisted [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
ctx' [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
val' LoopForm (Wise lore)
form' (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
    -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall b c a. (b -> c) -> (a -> b) -> a -> c
.
              [LParam (Wise lore)]
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[LParam (Wise lore)] -> SimpleM lore a -> SimpleM lore a
bindArrayLParams [LParam lore]
[LParam (Wise lore)]
loop_params')
    WhileLoop VName
cond -> do
      VName
cond' <- VName -> SimpleM lore VName
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify VName
cond
      (LoopForm (Wise lore), Names,
 SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM
     lore
     (LoopForm (Wise lore), Names,
      SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
      -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
forall (m :: * -> *) a. Monad m => a -> m a
return (VName -> LoopForm (Wise lore)
forall lore. VName -> LoopForm lore
WhileLoop VName
cond',
              Names
fparamnames,
              [(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore (a, Stms (Wise lore))
-> SimpleM lore (a, Stms (Wise lore))
protectLoopHoisted [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
ctx' [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
val' (VName -> LoopForm (Wise lore)
forall lore. VName -> LoopForm lore
WhileLoop VName
cond'))
  BlockPred (Wise lore)
seq_blocker <- (Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore))
forall lore a. (Env lore -> a) -> SimpleM lore a
asksEngineEnv ((Env lore -> BlockPred (Wise lore))
 -> SimpleM lore (BlockPred (Wise lore)))
-> (Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore))
forall a b. (a -> b) -> a -> b
$ HoistBlockers lore -> BlockPred (Wise lore)
forall lore. HoistBlockers lore -> BlockPred (Wise lore)
blockHoistSeq (HoistBlockers lore -> BlockPred (Wise lore))
-> (Env lore -> HoistBlockers lore)
-> Env lore
-> BlockPred (Wise lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env lore -> HoistBlockers lore
forall lore. Env lore -> HoistBlockers lore
envHoistBlockers
  ((Stms (Wise lore)
loopstms, Result
loopres), Stms (Wise lore)
hoisted) <-
    SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a. SimpleM lore a -> SimpleM lore a
enterLoop (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$ SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
consumeMerge (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
    [FParam (Wise lore)]
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[FParam (Wise lore)] -> SimpleM lore a -> SimpleM lore a
bindFParams ([FParam lore]
ctxparams'[FParam lore] -> [FParam lore] -> [FParam lore]
forall a. [a] -> [a] -> [a]
++[FParam lore]
valparams') (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$ SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
wrapbody (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
    BlockPred (Wise lore)
-> SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
BlockPred (Wise lore)
-> SimpleM lore (SimplifiedBody lore a)
-> SimpleM lore ((Stms (Wise lore), a), Stms (Wise lore))
blockIf
    (Names -> BlockPred (Wise lore)
forall lore. ASTLore lore => Names -> BlockPred lore
hasFree Names
boundnames BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
forall lore. BlockPred lore
isConsumed
     BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
seq_blocker BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
forall lore. ASTLore lore => BlockPred lore
notWorthHoisting) (SimpleM lore (SimplifiedBody lore Result)
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$ do
      ((Result
res, UsageTable
uses), Stms (Wise lore)
stms) <- [Diet] -> BodyT lore -> SimpleM lore (SimplifiedBody lore Result)
forall lore.
SimplifiableLore lore =>
[Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody [Diet]
diets BodyT lore
loopbody
      SimplifiedBody lore Result
-> SimpleM lore (SimplifiedBody lore Result)
forall (m :: * -> *) a. Monad m => a -> m a
return ((Result
res, UsageTable
uses UsageTable -> UsageTable -> UsageTable
forall a. Semigroup a => a -> a -> a
<> Result -> UsageTable
isDoLoopResult Result
res), Stms (Wise lore)
stms)
  Body (Wise lore)
loopbody' <- Stms (Wise lore) -> Result -> SimpleM lore (Body (Wise lore))
forall lore.
SimplifiableLore lore =>
Stms (Wise lore) -> Result -> SimpleM lore (Body (Wise lore))
constructBody Stms (Wise lore)
loopstms Result
loopres
  (Exp (Wise lore), Stms (Wise lore))
-> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return ([(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> Body (Wise lore)
-> Exp (Wise lore)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
ctx' [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
val' LoopForm (Wise lore)
form' Body (Wise lore)
loopbody', Stms (Wise lore)
hoisted)
  where fparamnames :: Names
fparamnames =
          [VName] -> Names
namesFromList (((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctx[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++[(FParam lore, SubExp)]
val)
        consumeMerge :: SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
consumeMerge =
          (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
localVtable ((SymbolTable (Wise lore) -> SymbolTable (Wise lore))
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
 -> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$ (SymbolTable (Wise lore) -> [VName] -> SymbolTable (Wise lore))
-> [VName] -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((SymbolTable (Wise lore) -> VName -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore) -> [VName] -> SymbolTable (Wise lore)
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' ((VName -> SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore) -> VName -> SymbolTable (Wise lore)
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall lore. VName -> SymbolTable lore -> SymbolTable lore
ST.consume)) ([VName] -> SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> [VName] -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList Names
consumed_by_merge
        consumed_by_merge :: Names
consumed_by_merge =
          Result -> Names
forall a. FreeIn a => a -> Names
freeIn (Result -> Names) -> Result -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> SubExp)
-> [(FParam lore, SubExp)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore, SubExp) -> SubExp
forall a b. (a, b) -> b
snd ([(FParam lore, SubExp)] -> Result)
-> [(FParam lore, SubExp)] -> Result
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> Bool)
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (TypeBase Shape Uniqueness -> Bool)
-> ((FParam lore, SubExp) -> TypeBase Shape Uniqueness)
-> (FParam lore, SubExp)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType (FParam lore -> TypeBase Shape Uniqueness)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> TypeBase Shape Uniqueness
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
val

simplifyExp (Op Op lore
op) = do (OpWithWisdom (Op lore)
op', Stms (Wise lore)
stms) <- Op lore -> SimpleM lore (Op (Wise lore), Stms (Wise lore))
forall lore.
Op lore -> SimpleM lore (Op (Wise lore), Stms (Wise lore))
simplifyOp Op lore
op
                         (Exp (Wise lore), Stms (Wise lore))
-> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (Op (Wise lore) -> Exp (Wise lore)
forall lore. Op lore -> ExpT lore
Op Op (Wise lore)
OpWithWisdom (Op lore)
op', Stms (Wise lore)
stms)

-- Special case for simplification of commutative BinOps where we
-- arrange the operands in sorted order.  This can make expressions
-- more identical, which helps CSE.
simplifyExp (BasicOp (BinOp BinOp
op SubExp
x SubExp
y))
  | BinOp -> Bool
commutativeBinOp BinOp
op = do
  SubExp
x' <- SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify SubExp
x
  SubExp
y' <- SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify SubExp
y
  (Exp (Wise lore), Stms (Wise lore))
-> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (BasicOp -> Exp (Wise lore)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Wise lore)) -> BasicOp -> Exp (Wise lore)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp BinOp
op (SubExp -> SubExp -> SubExp
forall a. Ord a => a -> a -> a
min SubExp
x' SubExp
y') (SubExp -> SubExp -> SubExp
forall a. Ord a => a -> a -> a
max SubExp
x' SubExp
y'), Stms (Wise lore)
forall a. Monoid a => a
mempty)

simplifyExp Exp lore
e = do Exp (Wise lore)
e' <- Exp lore -> SimpleM lore (Exp (Wise lore))
forall lore.
SimplifiableLore lore =>
Exp lore -> SimpleM lore (Exp (Wise lore))
simplifyExpBase Exp lore
e
                   (Exp (Wise lore), Stms (Wise lore))
-> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp (Wise lore)
e', Stms (Wise lore)
forall a. Monoid a => a
mempty)

-- | Generic simplification of an expression: 'simplify' is mapped over
-- its immediate subexpressions, variable names, return types and branch
-- types.  It is used as the fallback case of 'simplifyExp'.  Bodies,
-- parameters and 'Op's cannot be handled generically, so meeting one
-- here is an internal error; the caller must deal with them first.
simplifyExpBase :: SimplifiableLore lore =>
                   Exp lore -> SimpleM lore (Exp (Wise lore))
simplifyExpBase = mapExpM mapper
  where mapper =
          Mapper { mapOnSubExp = simplify
                 , mapOnVName = simplify
                 , mapOnRetType = simplify
                 , mapOnBranchType = simplify
                   -- A body can only be simplified when its result diet
                   -- is known, which the generic mapper does not provide.
                 , mapOnBody =
                     error "Unhandled body in simplification engine."
                   -- Parameters must be bound in the symbol table by the
                   -- construct that introduces them.
                 , mapOnFParam =
                     error "Unhandled FParam in simplification engine."
                 , mapOnLParam =
                     error "Unhandled LParam in simplification engine."
                   -- Ops go through the lore-specific 'simplifyOp'.
                 , mapOnOp =
                     error "Unhandled Op in simplification engine."
                 }

-- | Everything the simplification engine needs from a lore: each of its
-- annotation types must be 'Simplifiable', its 'Op' must support being
-- wrapped with wisdom and indexed through the symbol table, and the
-- wise version of the lore must support building statements.
type SimplifiableLore lore =
  ( ASTLore lore
  , Simplifiable (LetDec lore)
  , Simplifiable (FParamInfo lore)
  , Simplifiable (LParamInfo lore)
  , Simplifiable (RetType lore)
  , Simplifiable (BranchType lore)
  , CanBeWise (Op lore)
  , ST.IndexOp (OpWithWisdom (Op lore))
  , BinderOps (Wise lore)
  , IsOp (Op lore)
  )

-- | Values that can be simplified inside the 'SimpleM' monad.  The
-- simplification never changes the type of the value.
class Simplifiable e where
  -- | Simplify a value.  For variables and subexpressions this means
  -- replacing them with what the symbol table knows about them.
  simplify :: SimplifiableLore lore
           => e
           -> SimpleM lore e

-- | Simplifies the first component, then the second.
instance (Simplifiable a, Simplifiable b) => Simplifiable (a, b) where
  simplify (x, y) = pure (,) <*> simplify x <*> simplify y

-- | Simplifies the three components from left to right.
instance (Simplifiable a, Simplifiable b, Simplifiable c) =>
         Simplifiable (a, b, c) where
  simplify (x, y, z) =
    pure (,,) <*> simplify x <*> simplify y <*> simplify z

-- | Integers carry nothing to simplify.  This instance exists for
-- convenience in Scatter.
instance Simplifiable Int where
  simplify = return

-- | Simplifies the payload, if there is one.
instance Simplifiable a => Simplifiable (Maybe a) where
  simplify = traverse simplify

-- | Simplifies every element, in list order.
instance Simplifiable a => Simplifiable [a] where
  simplify = traverse simplify

-- | A variable is looked up with 'ST.lookupSubExp'.  If the symbol table
-- knows a replacement, the change is recorded with 'changed', the
-- accompanying certificates are passed to 'usedCerts', and the
-- replacement is returned.  Otherwise the variable is kept.  Constants
-- are returned as they are.
instance Simplifiable SubExp where
  simplify se@(Var name) = do
    vtable <- askVtable
    case ST.lookupSubExp name vtable of
      Just (se', cs) -> do
        changed
        usedCerts cs
        return se'
      Nothing ->
        return se
  simplify se@Constant{} =
    return se

-- | Simplify the annotation of every element of a pattern: context
-- elements first, then value elements.  Element names are not changed.
simplifyPattern :: (SimplifiableLore lore, Simplifiable dec) =>
                   PatternT dec
                -> SimpleM lore (PatternT dec)
simplifyPattern pat =
  Pattern
    <$> traverse simplifyPatElem (patternContextElements pat)
    <*> traverse simplifyPatElem (patternValueElements pat)
  where simplifyPatElem (PatElem name dec) = PatElem name <$> simplify dec

-- | Unit carries nothing to simplify.
instance Simplifiable () where
  simplify = return

-- | A name is replaced only when the symbol table maps it to another
-- variable.  In that case the change and the certificates are recorded.
-- A constant replacement cannot be expressed as a 'VName', so in that
-- case, and when there is no entry, the name is kept.
instance Simplifiable VName where
  simplify v =
    askVtable >>= \vtable ->
      case ST.lookupSubExp v vtable of
        Just (Var v', cs) -> changed >> usedCerts cs >> return v'
        _ -> return v

-- | Simplifies each dimension of a shape.
instance Simplifiable d => Simplifiable (ShapeBase d) where
  simplify = (Shape <$>) . simplify . shapeDims

-- | Only a 'Free' size holds a subexpression to simplify.  An existential
-- ('Ext') index is returned unchanged.
instance Simplifiable ExtSize where
  simplify size =
    case size of
      Free se -> Free <$> simplify se
      Ext i -> pure (Ext i)

-- | The dimensions of a 'ScalarSpace' are simplified and its element
-- type is kept.  All other spaces are returned unchanged.
instance Simplifiable Space where
  simplify space =
    case space of
      ScalarSpace ds t -> (`ScalarSpace` t) <$> simplify ds
      _ -> pure space

-- | Array shapes and memory spaces are simplified.  Element types and
-- uniqueness are kept, and primitive types are returned unchanged.
instance Simplifiable shape => Simplifiable (TypeBase shape u) where
  simplify t =
    case t of
      Prim bt ->
        return $ Prim bt
      Mem space ->
        Mem <$> simplify space
      Array et shape u -> do
        shape' <- simplify shape
        return $ Array et shape' u

-- | Simplifies the index of a fixed dimension.  For a slice, it
-- simplifies the start, the length and the stride, in that order.
instance Simplifiable d => Simplifiable (DimIndex d) where
  simplify idx =
    case idx of
      DimFix i ->
        DimFix <$> simplify i
      DimSlice start len stride ->
        DimSlice <$> simplify start <*> simplify len <*> simplify stride

-- | Simplify a lambda, using the environment's parallel hoisting blocker
-- ('blockHoistPar') to decide what may leave its body.  Returns the
-- simplified lambda together with the statements hoisted out of it.
simplifyLambda :: SimplifiableLore lore =>
                  Lambda lore
               -> SimpleM lore (Lambda (Wise lore), Stms (Wise lore))
simplifyLambda lam =
  asksEngineEnv (blockHoistPar . envHoistBlockers) >>= \par_blocker ->
    simplifyLambdaMaybeHoist par_blocker lam

-- | Simplify a lambda with @isFalse False@ as the hoisting blocker.  Only
-- the lambda is returned: any statements reported as hoisted are dropped.
--
-- NOTE(review): dropping them is only sound if @isFalse False@ blocks
-- every statement, so that the hoisted list is always empty.  Confirm
-- this against the definition of 'isFalse'.
simplifyLambdaNoHoisting :: SimplifiableLore lore =>
                            Lambda lore
                         -> SimpleM lore (Lambda (Wise lore))
simplifyLambdaNoHoisting lam =
  fmap fst $ simplifyLambdaMaybeHoist (isFalse False) lam

-- | Simplify a lambda and hoist statements out of its body.  A statement
-- stays in the body if the given predicate blocks it, if it refers to a
-- name bound by the lambda ('hasFree'), or if it is consuming
-- ('isConsumed').  The body is simplified inside 'enterLoop', with the
-- simplified parameters bound.  Returns the rebuilt lambda and the
-- hoisted statements.
simplifyLambdaMaybeHoist :: SimplifiableLore lore =>
                            BlockPred (Wise lore) -> Lambda lore
                         -> SimpleM lore (Lambda (Wise lore), Stms (Wise lore))
simplifyLambdaMaybeHoist blocked lam@(Lambda params body rettype) = do
  -- Parameter annotations are simplified before the body.
  params' <- traverse (traverse simplify) params
  let bound = namesFromList $ boundByLambda lam
      hoistBlocker = blocked `orIf` hasFree bound `orIf` isConsumed
      -- Every lambda result is observed, never consumed.
      diets = Observe <$ rettype
  ((lamstms, lamres), hoisted) <-
    enterLoop . bindLParams params' . blockIf hoistBlocker $
      simplifyBody diets body
  body' <- constructBody lamstms lamres
  rettype' <- simplify rettype
  return (Lambda params' body' rettype', hoisted)

-- | Build the usage table implied by a body's result list.
--
-- Each result is paired with the 'Diet' it is returned under.  For every
-- result returned in 'Consume' position, all names in its
-- 'subExpAliases' are marked with 'UT.consumedUsage'.  Results in any
-- other position contribute nothing (the monoid identity).
consumeResult :: [(Diet, SubExp)] -> UT.UsageTable
consumeResult results =
  mconcat [ UT.consumedUsage alias
          | (Consume, se) <- results               -- non-Consume entries are filtered out
          , alias <- namesToList (subExpAliases se)
          ]

-- | Certificates are simplified by resolving each certificate name
-- through the current symbol table and removing duplicates.
instance Simplifiable Certificates where
  -- For each certificate name, consult 'ST.lookupSubExp':
  --   * bound to the constant 'Checked' -> replaced by the certificates
  --     attached to that binding;
  --   * bound to another variable       -> replaced by that variable;
  --   * anything else                   -> kept unchanged.
  -- The expanded lists are concatenated and deduplicated with 'nub',
  -- which keeps the first occurrence of each name.
  simplify (Certificates ocs) = do
      resolved <- mapM resolve ocs
      return $ Certificates $ nub $ concat resolved
    where
      resolve certName = do
        vtable <- askVtable
        return $ case ST.lookupSubExp certName vtable of
          Just (Constant Checked, Certificates attached) -> attached
          Just (Var aliasName, _)                        -> [aliasName]
          _                                              -> [certName]


-- | Run a body simplification under 'blockIf' with the block predicate
-- @'isFalse' False@, then assemble the resulting statements and result
-- into a 'Body' via 'constructBody'.
--
-- The second component returned by 'blockIf' is discarded.
-- NOTE(review): this is presumably safe because @isFalse False@ blocks
-- all hoisting, so nothing is hoisted out. Confirm against the
-- definition of 'isFalse'.
insertAllStms :: SimplifiableLore lore =>
                 SimpleM lore (SimplifiedBody lore Result)
              -> SimpleM lore (Body (Wise lore))
insertAllStms simplifyAction = do
  ((stms, res), _discarded) <- blockIf (isFalse False) simplifyAction
  constructBody stms res


-- | Simplify a whole function definition.
--
-- The return types are simplified first.  The parameter annotations are
-- then simplified element-wise with @'traverse' 'simplify'@.  The body is
-- simplified with each result observed or consumed according to the
-- 'diet' of the corresponding simplified return type, and is wrapped in
-- 'insertAllStms' so that its statements end up inside the function body.
-- Entry-point information, attributes and the name are passed through
-- unchanged.
simplifyFun :: SimplifiableLore lore =>
               FunDef lore -> SimpleM lore (FunDef (Wise lore))
simplifyFun (FunDef entry attrs fname rettype params body) = do
  newRetType <- simplify rettype
  newParams <- mapM (traverse simplify) params
  let resultDiets = map (diet . declExtTypeOf) newRetType
      bodyAction = insertAllStms (simplifyBody resultDiets body)
  -- NOTE(review): the body is simplified with the original, unsimplified
  -- 'params' in scope rather than 'newParams'; lambda simplification in
  -- this module binds the simplified parameters instead.  Kept as-is;
  -- confirm this is intentional.
  newBody <- bindFParams params bodyAction
  return (FunDef entry attrs fname newRetType newParams newBody)