{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, TypeFamilies, FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE StrictData #-}
module Futhark.Optimise.Simplify.Engine
(
SimpleM
, runSimpleM
, SimpleOps (..)
, SimplifyOp
, bindableSimpleOps
, Env (envHoistBlockers, envRules)
, emptyEnv
, HoistBlockers(..)
, neverBlocks
, noExtraHoistBlockers
, BlockPred
, orIf
, hasFree
, isConsumed
, isFalse
, isOp
, isNotSafe
, asksEngineEnv
, askVtable
, localVtable
, SimplifiableLore
, Simplifiable (..)
, simplifyStms
, simplifyFun
, simplifyLambda
, simplifyLambdaNoHoisting
, simplifyParam
, bindLParams
, simplifyBody
, SimplifiedBody
, hoistStms
, blockIf
, module Futhark.Optimise.Simplify.Lore
) where
import Control.Monad.Writer
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Either
import Data.List (find, foldl', nub, mapAccumL)
import Data.Maybe
import Futhark.Representation.AST
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Optimise.Simplify.Rule
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Analysis.Usage
import Futhark.Construct
import Futhark.Optimise.Simplify.Lore
import Futhark.Util (splitFromEnd)
-- | Predicates and statement queries that control hoisting in the
-- simplifier.  Stored in 'Env' as 'envHoistBlockers'.
data HoistBlockers lore = HoistBlockers
  { blockHoistPar :: BlockPred (Wise lore)
    -- ^ Hoisting blocker; by its name, used for parallel constructs
    -- (NOTE(review): confirm against the hoisting code).
  , blockHoistSeq :: BlockPred (Wise lore)
    -- ^ Hoisting blocker; presumably used for sequential loops.
  , blockHoistBranch :: BlockPred (Wise lore)
    -- ^ Hoisting blocker; presumably used for branches.
  , getArraySizes :: Stm (Wise lore) -> Names
    -- ^ Names associated with a statement; by its name, the array
    -- sizes it involves (TODO confirm).
  , isAllocation :: Stm (Wise lore) -> Bool
    -- ^ Whether a statement should be treated as an allocation.
  }
-- | A 'HoistBlockers' that adds no blocking of its own: all three
-- blocking predicates are 'neverBlocks', no statement reports any
-- array sizes, and no statement is considered an allocation.
noExtraHoistBlockers :: HoistBlockers lore
noExtraHoistBlockers =
  HoistBlockers
    { blockHoistPar = neverBlocks
    , blockHoistSeq = neverBlocks
    , blockHoistBranch = neverBlocks
    , getArraySizes = const mempty
    , isAllocation = const False
    }
-- | The read-only environment of the simplifier engine.
data Env lore = Env
  { envRules :: RuleBook (Wise lore)
    -- ^ The simplification rules in use.
  , envHoistBlockers :: HoistBlockers lore
    -- ^ Predicates controlling hoisting.
  , envVtable :: ST.SymbolTable (Wise lore)
    -- ^ Symbol table; also the source of the monad's 'Scope'.
  }
-- | Build an environment from a rule book and hoist blockers.  The
-- symbol table starts out empty.
emptyEnv :: RuleBook (Wise lore) -> HoistBlockers lore -> Env lore
emptyEnv rules blockers =
  Env { envRules = rules
      , envHoistBlockers = blockers
      , envVtable = mempty
      }
-- | A hook for protecting a hoisted operation.  Given a 'SubExp'
-- (presumably the guarding condition; see the @cond@ argument of
-- 'protectIfHoisted'), the pattern and the operation, it may return
-- an action in @m@.
type Protect m = SubExp -> Pattern (Lore m) -> Op (Lore m) -> Maybe (m ())

-- | The lore-specific operations the simplifier engine is
-- parameterised over.
data SimpleOps lore =
  SimpleOps { mkExpAttrS :: ST.SymbolTable (Wise lore)
                         -> Pattern (Wise lore) -> Exp (Wise lore)
                         -> SimpleM lore (ExpAttr (Wise lore))
              -- ^ Compute the attribute of an expression bound to a pattern.
            , mkBodyS :: ST.SymbolTable (Wise lore)
                      -> Stms (Wise lore) -> Result
                      -> SimpleM lore (Body (Wise lore))
              -- ^ Build a body from statements and a result.
            , mkLetNamesS :: ST.SymbolTable (Wise lore)
                          -> [VName] -> Exp (Wise lore)
                          -> SimpleM lore (Stm (Wise lore), Stms (Wise lore))
              -- ^ Bind names to an expression.  Returns the binding
              -- statement plus any extra statements it needs.
            , protectHoistedOpS :: Protect (Binder (Wise lore))
              -- ^ Protection hook for hoisted operations.
            , simplifyOpS :: SimplifyOp lore (Op lore)
              -- ^ Simplifier for the lore-specific 'Op'.
            }

-- | Simplify an operation.  Produces the operation with wisdom attached,
-- plus any statements generated along the way.
type SimplifyOp lore op = op -> SimpleM lore (OpWithWisdom op, Stms (Wise lore))
-- | 'SimpleOps' for any 'Bindable' lore, built from the generic
-- 'mkExpAttr', 'mkBody' and 'mkLetNames'.
--
-- * The symbol-table argument of each operation is ignored.
-- * 'mkLetNamesS' never produces extra statements.
-- * Hoisted operations are never protected.
bindableSimpleOps :: (SimplifiableLore lore, Bindable lore) =>
                     SimplifyOp lore (Op lore) -> SimpleOps lore
bindableSimpleOps = SimpleOps mkExpAttrS' mkBodyS' mkLetNamesS' protectHoistedOpS'
  where mkExpAttrS' _ pat e = return $ mkExpAttr pat e
        mkBodyS' _ bnds res = return $ mkBody bnds res
        mkLetNamesS' _ name e = (,) <$> mkLetNames name e <*> pure mempty
        protectHoistedOpS' _ _ _ = Nothing
-- | The simplifier monad.
--
-- It reads the lore-specific 'SimpleOps' and the engine 'Env'.  It
-- threads a state of:
--
-- * the name source;
-- * a flag set by 'changed';
-- * the certificates recorded by 'usedCerts'.
newtype SimpleM lore a =
  SimpleM (ReaderT (SimpleOps lore, Env lore)
           (State (VNameSource, Bool, Certificates)) a)
  deriving (Applicative, Functor, Monad,
            MonadReader (SimpleOps lore, Env lore),
            MonadState (VNameSource, Bool, Certificates))
-- | The name source is the first component of the monad's state.
instance MonadFreshNames (SimpleM lore) where
  putNameSource src = modify $ \(_, b, c) -> (src, b, c)
  getNameSource = gets $ \(a, _, _) -> a
-- | Scope queries are answered from the environment's symbol table.
-- Looking up a name that is absent from the table is a fatal error.
instance SimplifiableLore lore => HasScope (Wise lore) (SimpleM lore) where
  askScope = ST.toScope <$> askVtable
  lookupType name = do
    vtable <- askVtable
    case ST.lookupType name vtable of
      Just t -> return t
      Nothing -> error $
                 "SimpleM.lookupType: cannot find variable " ++
                 pretty name ++ " in symbol table."
-- | Extending the scope combines the current symbol table with one
-- built from the new scope, using '<>' with the current table on the
-- left.
instance SimplifiableLore lore =>
         LocalScope (Wise lore) (SimpleM lore) where
  localScope types = localVtable (<> ST.fromScope types)
-- | Run a simplifier computation.
--
-- It starts with the given name source, the "changed" flag cleared
-- and no certificates.  It returns:
--
-- * the result, paired with whether 'changed' was called;
-- * the final name source.
--
-- Any certificates still accumulated at the end are discarded.
runSimpleM :: SimpleM lore a
           -> SimpleOps lore
           -> Env lore
           -> VNameSource
           -> ((a, Bool), VNameSource)
runSimpleM (SimpleM m) simpl env src =
  let (x, (src', b, _)) = runState (runReaderT m (simpl, env)) (src, False, mempty)
  in ((x, b), src')
-- | Retrieve the engine environment (the second component of the reader).
askEngineEnv :: SimpleM lore (Env lore)
askEngineEnv = asks snd
-- | Project a value out of the engine environment.
asksEngineEnv :: (Env lore -> a) -> SimpleM lore a
asksEngineEnv f = f <$> askEngineEnv
-- | Retrieve the current symbol table.
askVtable :: SimpleM lore (ST.SymbolTable (Wise lore))
askVtable = asksEngineEnv envVtable
-- | Run a computation with the symbol table transformed by the given
-- function.  The 'SimpleOps' and the other 'Env' fields are left
-- unchanged.
localVtable :: (ST.SymbolTable (Wise lore) -> ST.SymbolTable (Wise lore))
            -> SimpleM lore a -> SimpleM lore a
localVtable f =
  local $ \(ops, env) -> (ops, env { envVtable = f (envVtable env) })
-- | Run a computation, then return its result together with the
-- certificates accumulated in the state, and reset that accumulator
-- to empty.
--
-- NOTE(review): this also drains any certificates that were recorded
-- before the computation started.  Callers presumably invoke it with
-- an empty accumulator; confirm.
collectCerts :: SimpleM lore a -> SimpleM lore (a, Certificates)
collectCerts m = do
  x <- m
  (a, b, cs) <- get
  put (a, b, mempty)
  return (x, cs)
-- | Record that something changed.  The flag is reported by 'runSimpleM'.
changed :: SimpleM lore ()
changed = modify $ \(src, _, cs) -> (src, True, cs)
-- | Add certificates to the accumulator in the state; they are later
-- drained by 'collectCerts'.  New certificates are combined on the
-- left with '<>'.
usedCerts :: Certificates -> SimpleM lore ()
usedCerts cs = modify $ \(a, b, c) -> (a, b, cs <> c)
-- | Run a computation with 'ST.deepen' applied to the symbol table.
-- By its name, this is used when descending into a loop.
enterLoop :: SimpleM lore a -> SimpleM lore a
enterLoop = localVtable ST.deepen
-- | Run a computation with the given function parameters inserted
-- into the symbol table.
bindFParams :: SimplifiableLore lore =>
               [FParam (Wise lore)] -> SimpleM lore a -> SimpleM lore a
bindFParams params = localVtable $ ST.insertFParams params
-- | Run a computation with the given lambda parameters inserted into
-- the symbol table.  They are inserted one at a time with a right fold.
bindLParams :: SimplifiableLore lore =>
               [LParam (Wise lore)] -> SimpleM lore a -> SimpleM lore a
bindLParams params =
  localVtable $ \vtable -> foldr ST.insertLParam vtable params
-- | Run a computation with the given lambda parameters inserted into
-- the symbol table via 'ST.insertArrayLParam'.  Each parameter is
-- paired with an optional 'VName'; by the name, this is presumably
-- the array it is drawn from (TODO confirm).
bindArrayLParams :: SimplifiableLore lore =>
                    [(LParam (Wise lore), Maybe VName)] -> SimpleM lore a
                 -> SimpleM lore a
bindArrayLParams params =
  localVtable $ \vtable -> foldr (uncurry ST.insertArrayLParam) vtable params
-- | Run a computation with a loop variable of the given integer type
-- and bound inserted into the symbol table.  If the bound is a
-- variable, the table also records that it is at least 1.
bindLoopVar :: SimplifiableLore lore =>
               VName -> IntType -> SubExp -> SimpleM lore a -> SimpleM lore a
bindLoopVar var it bound =
  localVtable $ clampUpper . clampVar
  where clampVar = ST.insertLoopVar var it bound
        -- Presumably valid because the loop body only executes when
        -- the bound is at least one.
        clampUpper = case bound of
                       Var v -> ST.isAtLeast v 1
                       _     -> id
-- | Run an action that yields a result together with statements hoisted
-- out of one side of a branch on @cond@, and re-emit those statements
-- in a fresh binder.
--
-- If every hoisted statement is 'safeExp', the statements are added
-- unchanged.  Otherwise each statement goes through 'protectIf'.  The
-- predicate passed to 'protectIf' selects expressions that are unsafe
-- or not 'cheapExp'.  The guard is @cond@ itself when @side@ is 'True',
-- and a freshly bound negation of @cond@ when it is 'False'.
protectIfHoisted :: SimplifiableLore lore =>
                    SubExp -- ^ Branch condition.
                 -> Bool   -- ^ Which side of the branch the statements
                           -- come from: 'True' guards them by the
                           -- condition, 'False' by its negation.
                 -> SimpleM lore (a, Stms (Wise lore))
                 -> SimpleM lore (a, Stms (Wise lore))
protectIfHoisted cond side m = do
  (res, hoisted) <- m
  protect <- asks (protectHoistedOpS . fst)
  runBinder $ do
    if all (safeExp . stmExp) hoisted
      then addStms hoisted
      else do guard_se <- guardCond
              mapM_ (protectIf protect needsGuard guard_se) hoisted
    return res
  where -- Equivalent to @not (safeExp e) || not (cheapExp e)@.
        needsGuard e = not (safeExp e && cheapExp e)
        -- Only run (inside the binder) when some statement is unsafe.
        guardCond
          | side      = return cond
          | otherwise = letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
-- | Run an action that yields a result together with statements hoisted
-- out of a loop, and re-emit those statements in a fresh binder.
--
-- If every hoisted statement is 'safeExp', the statements are added
-- unchanged.  Otherwise each statement goes through 'protectIf' with
-- predicate @not . safeExp@.  The guard is a condition computed from
-- the loop form:
--
-- * @for@ loop: a fresh @loop_nonempty@ binding of @0 < bound@
--   (signed comparison at the loop's integer type);
-- * @while@ loop: the 'SubExp' paired with the condition parameter in
--   @ctx ++ val@, or the constant 'True' if no parameter of that name
--   is found.
protectLoopHoisted :: SimplifiableLore lore =>
                      [(FParam (Wise lore),SubExp)] -- ^ Context parameters and their paired 'SubExp's.
                   -> [(FParam (Wise lore),SubExp)] -- ^ Value parameters and their paired 'SubExp's.
                   -> LoopForm (Wise lore)
                   -> SimpleM lore (a, Stms (Wise lore))
                   -> SimpleM lore (a, Stms (Wise lore))
protectLoopHoisted ctx val form m = do
  (res, hoisted) <- m
  protect <- asks (protectHoistedOpS . fst)
  runBinder $ do
    if any (not . safeExp . stmExp) hoisted
      then do runs <- loopRunsAtLeastOnce
              mapM_ (protectIf protect (not . safeExp) runs) hoisted
      else addStms hoisted
    return res
  where loopRunsAtLeastOnce =
          case form of
            ForLoop _ it bound _ ->
              letSubExp "loop_nonempty" $
              BasicOp $ CmpOp (CmpSlt it) (intConst it 0) bound
            WhileLoop cond ->
              return $ maybe (constant True) snd $
              find ((== cond) . paramName . fst) (ctx ++ val)
-- | Emit a statement so that it only takes effect when @taken@ holds.
-- The cases are tried in this order:
--
-- * an @if@ whose sort is 'IfFallback' is re-bound with condition
--   @taken && cond@;
-- * an 'Assert' is re-bound with condition @not taken || cond@;
-- * an 'Op' is passed to the 'Protect' function, and the returned
--   action is used if it is 'Just';
-- * any other expression satisfying the predicate is wrapped in
--   @if taken@.  The other branch produces 'emptyOfType' values for the
--   pattern's value types, with context names zeroed;
-- * everything else is added unchanged.
--
-- Whenever the statement is rebuilt, its certificates are re-attached
-- with 'certifying'.
protectIf :: MonadBinder m =>
             Protect m
          -> (Exp (Lore m) -> Bool)
          -> SubExp -> Stm (Lore m) -> m ()
protectIf protect needsGuard taken stm@(Let pat (StmAux cs _) e) =
  case e of
    If cond tbranch fbranch (IfAttr ts IfFallback) -> do
      cond' <- letSubExp "protect_cond_conj" $
               BasicOp $ BinOp LogAnd taken cond
      rebind $ If cond' tbranch fbranch $ IfAttr ts IfFallback

    BasicOp (Assert cond msg loc) -> do
      not_taken <- letSubExp "loop_not_taken" $
                   BasicOp $ UnOp Not taken
      cond' <- letSubExp "protect_assert_disj" $
               BasicOp $ BinOp LogOr not_taken cond
      rebind $ BasicOp $ Assert cond' msg loc

    -- If the Protect function declines (Nothing), fall through to the
    -- generic alternatives below.
    Op op | Just protected <- protect taken pat op ->
      certifying cs protected

    _ | needsGuard e -> do
          tbranch <- eBody [pure e]
          fbranch <- eBody $ map (emptyOfType $ patternContextNames pat)
                                 (patternValueTypes pat)
          ts <- expTypesFromPattern pat
          rebind $ If taken tbranch fbranch $ IfAttr ts IfFallback
      | otherwise ->
          addStm stm
  where rebind = certifying cs . letBind_ pat
-- | Build an expression producing a stand-in value of the given type.
-- 'protectIf' uses these for the untaken branch of a protecting @if@.
--
-- * Primitive types yield the constant 'blankPrimValue'.
-- * Array types yield a 'Scratch' array.  Each dimension that is a
--   variable listed in the first argument is replaced by the 'Int32'
--   constant 0; other dimensions are kept as-is.
-- * Memory types are rejected with 'error'.
emptyOfType :: MonadBinder m => [VName] -> Type -> m (Exp (Lore m))
emptyOfType ctx_names t =
  case t of
    Mem{} ->
      error "emptyOfType: Cannot hoist non-existential memory."
    Prim pt ->
      return $ BasicOp $ SubExp $ Constant $ blankPrimValue pt
    Array pt shape _ ->
      return $ BasicOp $ Scratch pt $ map dimOrZero $ shapeDims shape
  where dimOrZero se@(Var v)
          | v `elem` ctx_names = intConst Int32 0
          | otherwise          = se
        dimOrZero se = se
-- | A 'BlockPred' that holds for a statement whose expression is not
-- 'safeExp' and whose pattern binds at least one value of rank greater
-- than zero.  The symbol table and usage table arguments are ignored.
notWorthHoisting :: Attributes lore => BlockPred lore
notWorthHoisting _ _ (Let pat _ e)
  | safeExp e = False
  | otherwise = any ((> 0) . arrayRank) (patternTypes pat)
hoistStms :: SimplifiableLore lore =>
RuleBook (Wise lore) -> BlockPred (Wise lore)
-> ST.SymbolTable (Wise lore) -> UT.UsageTable
-> Stms (Wise lore)
-> SimpleM lore (Stms (Wise lore),
Stms (Wise lore))
hoistStms :: RuleBook (Wise lore)
-> BlockPred (Wise lore)
-> SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
hoistStms RuleBook (Wise lore)
rules BlockPred (Wise lore)
block SymbolTable (Wise lore)
vtable UsageTable
uses Stms (Wise lore)
orig_stms = do
([Stm (Wise lore)]
blocked, [Stm (Wise lore)]
hoisted) <- SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore ([Stm (Wise lore)], [Stm (Wise lore)])
simplifyStmsBottomUp SymbolTable (Wise lore)
vtable UsageTable
uses Stms (Wise lore)
orig_stms
Bool -> SimpleM lore () -> SimpleM lore ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Stm (Wise lore)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Stm (Wise lore)]
hoisted) SimpleM lore ()
forall lore. SimpleM lore ()
changed
(Stms (Wise lore), Stms (Wise lore))
-> SimpleM lore (Stms (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm (Wise lore)] -> Stms (Wise lore)
forall lore. [Stm lore] -> Stms lore
stmsFromList [Stm (Wise lore)]
blocked, [Stm (Wise lore)] -> Stms (Wise lore)
forall lore. [Stm lore] -> Stms lore
stmsFromList [Stm (Wise lore)]
hoisted)
where simplifyStmsBottomUp :: SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM lore ([Stm (Wise lore)], [Stm (Wise lore)])
simplifyStmsBottomUp SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
stms = do
(UsageTable
_, [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms') <- SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
simplifyStmsBottomUp' SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
stms
let ([Stm (Wise lore)]
blocked, [Stm (Wise lore)]
hoisted) = [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> ([Stm (Wise lore)], [Stm (Wise lore)])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either (Stm (Wise lore)) (Stm (Wise lore))]
-> ([Stm (Wise lore)], [Stm (Wise lore)]))
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> ([Stm (Wise lore)], [Stm (Wise lore)])
forall a b. (a -> b) -> a -> b
$ [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall lore.
Attributes lore =>
[Either (Stm lore) (Stm lore)] -> [Either (Stm lore) (Stm lore)]
blockUnhoistedDeps [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms'
([Stm (Wise lore)], [Stm (Wise lore)])
-> SimpleM lore ([Stm (Wise lore)], [Stm (Wise lore)])
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm (Wise lore)]
blocked, [Stm (Wise lore)]
hoisted)
simplifyStmsBottomUp' :: SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
simplifyStmsBottomUp' SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
stms =
((UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> (Stm (Wise lore), SymbolTable (Wise lore))
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))]))
-> (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> (Stm (Wise lore), SymbolTable (Wise lore))
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
hoistable (UsageTable
uses',[]) ([(Stm (Wise lore), SymbolTable (Wise lore))]
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))]))
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall a b. (a -> b) -> a -> b
$ [(Stm (Wise lore), SymbolTable (Wise lore))]
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
forall a. [a] -> [a]
reverse ([(Stm (Wise lore), SymbolTable (Wise lore))]
-> [(Stm (Wise lore), SymbolTable (Wise lore))])
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
forall a b. (a -> b) -> a -> b
$ [Stm (Wise lore)]
-> [SymbolTable (Wise lore)]
-> [(Stm (Wise lore), SymbolTable (Wise lore))]
forall a b. [a] -> [b] -> [(a, b)]
zip (Stms (Wise lore) -> [Stm (Wise lore)]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms (Wise lore)
stms) [SymbolTable (Wise lore)]
vtables
where vtables :: [SymbolTable (Wise lore)]
vtables = (SymbolTable (Wise lore)
-> Stm (Wise lore) -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore)
-> [Stm (Wise lore)]
-> [SymbolTable (Wise lore)]
forall b a. (b -> a -> b) -> b -> [a] -> [b]
scanl ((Stm (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore)
-> Stm (Wise lore)
-> SymbolTable (Wise lore)
forall a b c. (a -> b -> c) -> b -> a -> c
flip Stm (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall lore.
(IndexOp (Op lore), Ranged lore, Aliased lore) =>
Stm lore -> SymbolTable lore -> SymbolTable lore
ST.insertStm) SymbolTable (Wise lore)
vtable' ([Stm (Wise lore)] -> [SymbolTable (Wise lore)])
-> [Stm (Wise lore)] -> [SymbolTable (Wise lore)]
forall a b. (a -> b) -> a -> b
$ Stms (Wise lore) -> [Stm (Wise lore)]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms (Wise lore)
stms
hoistable :: (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> (Stm (Wise lore), SymbolTable (Wise lore))
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
hoistable (UsageTable
uses',[Either (Stm (Wise lore)) (Stm (Wise lore))]
stms) (Stm (Wise lore)
stm, SymbolTable (Wise lore)
vtable')
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> UsageTable -> Bool
`UT.isUsedDirectly` UsageTable
uses') ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Stm (Wise lore) -> [VName]
forall lore. Stm lore -> [VName]
provides Stm (Wise lore)
stm =
(UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return (UsageTable
uses', [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms)
| Bool
otherwise = do
Maybe (Stms (Wise lore))
res <- (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore (Maybe (Stms (Wise lore)))
-> SimpleM lore (Maybe (Stms (Wise lore)))
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
localVtable (SymbolTable (Wise lore)
-> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a b. a -> b -> a
const SymbolTable (Wise lore)
vtable') (SimpleM lore (Maybe (Stms (Wise lore)))
-> SimpleM lore (Maybe (Stms (Wise lore))))
-> SimpleM lore (Maybe (Stms (Wise lore)))
-> SimpleM lore (Maybe (Stms (Wise lore)))
forall a b. (a -> b) -> a -> b
$
RuleBook (Wise lore)
-> (SymbolTable (Wise lore), UsageTable)
-> Stm (Wise lore)
-> SimpleM lore (Maybe (Stms (Wise lore)))
forall (m :: * -> *) lore.
(MonadFreshNames m, HasScope lore m) =>
RuleBook lore
-> (SymbolTable lore, UsageTable)
-> Stm lore
-> m (Maybe (Stms lore))
bottomUpSimplifyStm RuleBook (Wise lore)
rules (SymbolTable (Wise lore)
vtable', UsageTable
uses') Stm (Wise lore)
stm
case Maybe (Stms (Wise lore))
res of
Maybe (Stms (Wise lore))
Nothing
| BlockPred (Wise lore)
block SymbolTable (Wise lore)
vtable' UsageTable
uses' Stm (Wise lore)
stm ->
(UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return (SymbolTable (Wise lore)
-> UsageTable -> Stm (Wise lore) -> UsageTable
forall lore.
(Attributes lore, Aliased lore) =>
SymbolTable lore -> UsageTable -> Stm lore -> UsageTable
expandUsage SymbolTable (Wise lore)
vtable' UsageTable
uses' Stm (Wise lore)
stm
UsageTable -> [VName] -> UsageTable
`UT.without` Stm (Wise lore) -> [VName]
forall lore. Stm lore -> [VName]
provides Stm (Wise lore)
stm,
Stm (Wise lore) -> Either (Stm (Wise lore)) (Stm (Wise lore))
forall a b. a -> Either a b
Left Stm (Wise lore)
stm Either (Stm (Wise lore)) (Stm (Wise lore))
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall a. a -> [a] -> [a]
: [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms)
| Bool
otherwise ->
(UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return (SymbolTable (Wise lore)
-> UsageTable -> Stm (Wise lore) -> UsageTable
forall lore.
(Attributes lore, Aliased lore) =>
SymbolTable lore -> UsageTable -> Stm lore -> UsageTable
expandUsage SymbolTable (Wise lore)
vtable' UsageTable
uses' Stm (Wise lore)
stm, Stm (Wise lore) -> Either (Stm (Wise lore)) (Stm (Wise lore))
forall a b. b -> Either a b
Right Stm (Wise lore)
stm Either (Stm (Wise lore)) (Stm (Wise lore))
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall a. a -> [a] -> [a]
: [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms)
Just Stms (Wise lore)
optimstms -> do
SimpleM lore ()
forall lore. SimpleM lore ()
changed
(UsageTable
uses'',[Either (Stm (Wise lore)) (Stm (Wise lore))]
stms') <- SymbolTable (Wise lore)
-> UsageTable
-> Stms (Wise lore)
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
simplifyStmsBottomUp' SymbolTable (Wise lore)
vtable' UsageTable
uses' Stms (Wise lore)
optimstms
(UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
-> SimpleM
lore (UsageTable, [Either (Stm (Wise lore)) (Stm (Wise lore))])
forall (m :: * -> *) a. Monad m => a -> m a
return (UsageTable
uses'', [Either (Stm (Wise lore)) (Stm (Wise lore))]
stms'[Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
-> [Either (Stm (Wise lore)) (Stm (Wise lore))]
forall a. [a] -> [a] -> [a]
++[Either (Stm (Wise lore)) (Stm (Wise lore))]
stms)
-- | Walk a list of statements that have already been classified as
-- blocked ('Left') or hoistable ('Right'), front to back.  A hoistable
-- statement is demoted to 'Left' when one of its free variables is bound
-- by a statement that is (or has just become) blocked earlier in the list.
blockUnhoistedDeps :: Attributes lore =>
                      [Either (Stm lore) (Stm lore)]
                   -> [Either (Stm lore) (Stm lore)]
blockUnhoistedDeps classified = snd (mapAccumL step mempty classified)
  where
    -- The accumulator holds every name bound by a blocked statement so far.
    step stuck e =
      case e of
        Left s -> demote s
        Right s
          | stuck `namesIntersect` freeIn s -> demote s
          | otherwise -> (stuck, Right s)
      where
        -- Keep the statement blocked and remember the names it binds.
        demote blockedStm =
          (stuck <> namesFromList (provides blockedStm), Left blockedStm)
-- | The names bound by the pattern of a statement.
provides :: Stm lore -> [VName]
provides stm = patternNames (stmPattern stm)
-- | Extend a usage table with the usages arising from a statement.
--
-- Two contributions are expanded through the aliases recorded in the
-- symbol table and then merged with the incoming table:
--
-- * the usages of the statement itself ('usageInStm'), and
-- * for each name bound by the pattern that already has usages in
--   @utable@, those same usages recorded for each of that name's pattern
--   aliases.
expandUsage :: (Attributes lore, Aliased lore) =>
               ST.SymbolTable lore -> UT.UsageTable -> Stm lore -> UT.UsageTable
expandUsage vtable utable bnd = expanded <> utable
  where
    expanded = UT.expand aliasesOf (usageInStm bnd <> viaAliases)

    aliasesOf v = ST.lookupAliases v vtable

    pat = stmPattern bnd

    -- Names without an entry in the usage table contribute nothing.
    viaAliases =
      mconcat
        [ mconcat [ UT.usage a uses | a <- namesToList als ]
        | (v, als) <- zip (patternNames pat) (patternAliases pat)
        , Just uses <- [UT.lookup v utable]
        ]
-- | A predicate deciding whether a statement is blocked from being
-- hoisted, given the current symbol table and usage table.  'True'
-- means the statement stays where it is.
type BlockPred lore = ST.SymbolTable lore -> UT.UsageTable -> Stm lore -> Bool
-- | The trivial blocking predicate: nothing is ever blocked.
neverBlocks :: BlockPred lore
neverBlocks = \_ _ _ -> False
-- | @isFalse b@ ignores the symbol table, usage table and statement, and
-- blocks exactly when @b@ is 'False'.
isFalse :: Bool -> BlockPred lore
isFalse b = \_ _ _ -> not b
-- | Block a statement if either predicate blocks it.  The second
-- predicate is only consulted when the first returns 'False'.
orIf :: BlockPred lore -> BlockPred lore -> BlockPred lore
orIf p1 p2 symtab usage stm =
  p1 symtab usage stm || p2 symtab usage stm
-- | Block a statement only if both predicates block it.  The second
-- predicate is only consulted when the first returns 'True'.
andAlso :: BlockPred lore -> BlockPred lore -> BlockPred lore
andAlso p1 p2 symtab usage stm =
  p1 symtab usage stm && p2 symtab usage stm
-- | Block a statement when any name bound by its pattern is marked as
-- consumed in the usage table.  The symbol table is not consulted.
isConsumed :: BlockPred lore
isConsumed _ utable stm = any (`UT.isConsumed` utable) (provides stm)
-- | Block every statement whose expression is an 'Op'.
isOp :: BlockPred lore
isOp _ _ stm =
  case stmExp stm of
    Op{} -> True
    _ -> False
-- | Build a body from the given statements followed by the given result.
-- Inside 'insertStmsM', the statements are added with 'addStms' and the
-- body is then produced by 'resultBodyM'.  Only the body component of
-- 'runBinder''s result is returned; its statement component is dropped.
constructBody :: SimplifiableLore lore => Stms (Wise lore) -> Result
              -> SimpleM lore (Body (Wise lore))
constructBody stms res =
  fst <$> runBinder (insertStmsM (addStms stms *> resultBodyM res))
-- | The outcome of simplifying a body: a value of type @a@ (for example
-- the body's 'Result', as produced by 'simplifyBody') paired with its
-- usage table, together with the simplified statements.
type SimplifiedBody lore a = ((a, UT.UsageTable), Stms (Wise lore))
-- | Run a body simplification, then split its statements with 'hoistStms'
-- using the given blocking predicate, the current symbol table and the
-- engine's rule book.
--
-- Returns @((kept, x), lifted)@, where @kept@ are the statements that
-- stayed behind, @x@ is the value produced by the simplification, and
-- @lifted@ are the hoisted statements.
blockIf :: SimplifiableLore lore =>
           BlockPred (Wise lore)
        -> SimpleM lore (SimplifiedBody lore a)
        -> SimpleM lore ((Stms (Wise lore), a), Stms (Wise lore))
blockIf block m = do
  ((x, usages), stms) <- m
  vtable <- askVtable
  rules <- asksEngineEnv envRules
  (kept, lifted) <- hoistStms rules block vtable usages stms
  pure ((kept, x), lifted)
-- | Block a statement that mentions (has free) any of the given names.
hasFree :: Attributes lore => Names -> BlockPred lore
hasFree ks _vtable _usage stm = namesIntersect ks (freeIn stm)
-- | Block a statement whose expression is not considered safe by 'safeExp'.
isNotSafe :: Attributes lore => BlockPred lore
isNotSafe _ _ stm = not (safeExp (stmExp stm))
-- | Block a statement whose expression is an in-place 'Update'.
isInPlaceBound :: BlockPred lore
isInPlaceBound _ _ stm =
  case stmExp stm of
    BasicOp Update{} -> True
    _ -> False
-- | Block a statement that 'cheapStm' does not consider cheap.
isNotCheap :: Attributes lore => BlockPred lore
isNotCheap _ _ stm = not (cheapStm stm)
-- | A statement is cheap when its expression is cheap (see 'cheapExp').
cheapStm :: Attributes lore => Stm lore -> Bool
cheapStm stm = cheapExp (stmExp stm)
-- | A conservative syntactic cost estimate for an expression.
--
-- * Every 'BasicOp' is cheap except 'Copy'.
-- * A 'DoLoop' is never cheap.
-- * An 'If' is cheap when every statement in both branches is cheap.
-- * An 'Op' is delegated to 'cheapOp'.
-- * Anything else is considered cheap.
cheapExp :: Attributes lore => Exp lore -> Bool
cheapExp e =
  case e of
    BasicOp bop -> cheapBasicOp bop
    DoLoop{} -> False
    If _ tbranch fbranch _ -> cheapBody tbranch && cheapBody fbranch
    Op op -> cheapOp op
    _ -> True
  where
    cheapBody b = all cheapStm (bodyStms b)

    cheapBasicOp bop =
      case bop of
        Copy{} -> False
        _ -> True
-- | Lift a plain statement predicate to a 'BlockPred' by ignoring the
-- symbol table and usage table.
stmIs :: (Stm lore -> Bool) -> BlockPred lore
stmIs f = \_ _ -> f
-- | A statement is loop-invariant when every one of its free variables is
-- in 'ST.availableAtClosestLoop' for the given symbol table.
loopInvariantStm :: Attributes lore => ST.SymbolTable lore -> Stm lore -> Bool
loopInvariantStm vtable stm =
  all (`nameIn` ST.availableAtClosestLoop vtable) $ namesToList $ freeIn stm
-- | Hoist statements out of both branches of an @if@.
--
-- Each branch is processed by 'hoistStms' (wrapped in 'protectIfHoisted'
-- with @cond@, passing 'True' for the first branch and 'False' for the
-- second) under a shared blocking predicate.  Returns the rebuilt first
-- and second branch bodies, plus the statements hoisted out of either
-- branch (first branch's before second's).
hoistCommon :: SimplifiableLore lore =>
               SubExp -> IfSort
            -> SimplifiedBody lore Result
            -> SimplifiedBody lore Result
            -> SimpleM lore (Body (Wise lore),
                             Body (Wise lore),
                             Stms (Wise lore))
hoistCommon cond ifsort ((res1, usages1), stms1) ((res2, usages2), stms2) = do
  is_alloc_fun <- fromBlockers isAllocation
  getArrSz_fun <- fromBlockers getArraySizes
  branch_blocker <- fromBlockers blockHoistBranch
  vtable <- askVtable
  let -- Every free variable of the condition is available at the
      -- closest enclosing loop.
      cond_loop_invariant =
        all (`nameIn` ST.availableAtClosestLoop vtable) $ namesToList $ freeIn cond

      -- Allocations are always worth hoisting.  Otherwise we only want
      -- statements that are loop-invariant, and only when we are inside
      -- a loop, the condition is loop-invariant, and this is not an
      -- 'IfFallback' conditional.
      desirableToHoist stm
        | is_alloc_fun stm = True
        | otherwise =
            ST.loopDepth vtable > 0 &&
            cond_loop_invariant &&
            ifsort /= IfFallback &&
            loopInvariantStm vtable stm

      -- Names bound by desirable statements, plus those bound by the
      -- statements that (transitively) compute the sizes reported by
      -- 'getArraySizes' for any statement in either branch.
      hoistbl_nms =
        filterBnds desirableToHoist getArrSz_fun (stmsToList (stms1 <> stms2))

      -- Array literals and plain 'SubExp' statements are never blocked by
      -- this predicate; anything else is blocked unless it binds a name
      -- in @nms@.
      isNotHoistableBnd nms _ _ stm =
        case stmExp stm of
          BasicOp ArrayLit{} -> False
          BasicOp SubExp{} -> False
          _ -> not (hasPatName nms stm)

      block = branch_blocker
              `orIf` ((isNotSafe `orIf` isNotCheap)
                      `andAlso` stmIs (not . desirableToHoist))
              `orIf` isInPlaceBound
              `orIf` isNotHoistableBnd hoistbl_nms

  rules <- asksEngineEnv envRules
  let hoistBranch which usages stms =
        protectIfHoisted cond which $ hoistStms rules block vtable usages stms
  (kept1, lifted1) <- hoistBranch True usages1 stms1
  (kept2, lifted2) <- hoistBranch False usages2 stms2
  body1' <- constructBody kept1 res1
  body2' <- constructBody kept2 res2
  pure (body1', body2', lifted1 <> lifted2)
  where
    -- Read one field of the engine's hoist blockers.
    fromBlockers :: (HoistBlockers lore -> a) -> SimpleM lore a
    fromBlockers f = asksEngineEnv (f . envHoistBlockers)

    -- Names bound by the statements selected by @interesting@, together
    -- with those bound by the transitive closure of statements needed to
    -- compute the names returned by @getArrSz_fn@.
    filterBnds interesting getArrSz_fn all_bnds =
      namesFromList $ concatMap provides (sz_needs ++ alloc_bnds)
      where
        sz_nms = foldMap getArrSz_fn all_bnds
        sz_needs = transClosSizes all_bnds sz_nms []
        alloc_bnds = filter interesting all_bnds

    -- Repeatedly collect the statements in @all_bnds@ that bind any of
    -- @scal_nms@, continuing with the free variables of their
    -- expressions, until no further statement is found.
    transClosSizes all_bnds scal_nms hoist_bnds
      | null new_bnds = hoist_bnds
      | otherwise = transClosSizes all_bnds new_nms (new_bnds ++ hoist_bnds)
      where
        new_bnds = filter (hasPatName scal_nms) all_bnds
        new_nms = foldMap (freeIn . stmExp) new_bnds

    -- Does the statement bind at least one of the given names?
    hasPatName nms bnd = any (\v -> v `nameIn` nms) (provides bnd)
-- | Simplify a body.  The statements are simplified first, and the
-- result is simplified inside their scope using the given diets for
-- the value part of the result.  The result itself contributes no
-- extra statements, hence the empty 'Stms' paired with it.
simplifyBody :: SimplifiableLore lore =>
                [Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody ds (Body _ body_stms body_res) =
  simplifyStms body_stms $
  (\simplified_res -> (simplified_res, mempty)) <$> simplifyResult ds body_res
-- | Simplify the result of a body.
--
-- The diets describe only the value part of the result, which is the
-- trailing @length ds@ elements.  The leading elements are the context.
-- Returns the simplified result together with a usage table.  The table
-- combines the free variables of the new result with the consumption
-- information that 'consumeResult' derives from the diets.
simplifyResult :: SimplifiableLore lore =>
                  [Diet] -> Result -> SimpleM lore (Result, UT.UsageTable)
simplifyResult ds res = do
  -- Context elements are simplified normally.  The certificates
  -- collected while doing so are intentionally thrown away.
  (ctx_res', _ctx_res_cs) <- collectCerts $ mapM simplify ctx_res
  val_res' <- mapM copyPropagate val_res
  let res' = ctx_res' <> val_res'
      usage = UT.usages (freeIn res') <> consumeResult (zip ds val_res')
  return (res', usage)
  where (ctx_res, val_res) = splitFromEnd (length ds) res

        -- A value result is replaced by what its variable is bound to
        -- only when that binding carries no certificates.  A result
        -- element has nowhere to keep certificates.
        copyPropagate :: SubExp -> SimpleM lore SubExp
        copyPropagate se@Constant{} = return se
        copyPropagate (Var name) = do
          vtable <- askVtable
          case ST.lookupSubExp name vtable of
            Just (se, cs) | cs == mempty -> return se
            _                            -> return $ Var name
-- | Combine 'UT.inResultUsage' for every variable in the given result.
-- Constant elements contribute nothing.
isDoLoopResult :: Result -> UT.UsageTable
isDoLoopResult res = mconcat [ UT.inResultUsage v | Var v <- res ]
-- | Simplify a sequence of statements, then run the continuation @m@
-- in their scope.
--
-- For each statement, the certificates, the expression and the pattern
-- are simplified.  Every certificate gathered along the way is attached
-- to the rebuilt statement.  Statements emitted while simplifying the
-- expression are inspected before the statement itself.
simplifyStms :: SimplifiableLore lore =>
                Stms lore -> SimpleM lore (a, Stms (Wise lore))
             -> SimpleM lore (a, Stms (Wise lore))
simplifyStms stms m
  | Just (Let pat (StmAux stm_cs attr) e, rest) <- stmsHead stms = do
      stm_cs' <- simplify stm_cs
      ((e', e_stms), e_cs) <- collectCerts $ simplifyExp e
      (pat', pat_cs) <- collectCerts $ simplifyPattern pat
      let all_cs = stm_cs' <> e_cs <> pat_cs
          stm'   = mkWiseLetStm pat' (StmAux all_cs attr) e'
      inspectStms e_stms $ inspectStm stm' $ simplifyStms rest m
  | otherwise =
      inspectStms mempty m
-- | Inspect a single statement by treating it as a one-element
-- statement sequence for 'inspectStms'.
inspectStm :: SimplifiableLore lore =>
              Stm (Wise lore) -> SimpleM lore (a, Stms (Wise lore))
           -> SimpleM lore (a, Stms (Wise lore))
inspectStm stm m = inspectStms (oneStm stm) m
-- | Run the top-down rules from the engine environment over each
-- statement in turn, then run the continuation @m@.
--
-- If a rule rewrites a statement, the change is recorded and the
-- replacement statements are inspected in its place.  Otherwise the
-- statement is kept and inserted into the symbol table for the
-- remaining statements and the continuation.
inspectStms :: SimplifiableLore lore =>
               Stms (Wise lore)
            -> SimpleM lore (a, Stms (Wise lore))
            -> SimpleM lore (a, Stms (Wise lore))
inspectStms stms m = maybe m inspectFirst $ stmsHead stms
  where inspectFirst (stm, rest) = do
          vtable <- askVtable
          rules <- asksEngineEnv envRules
          rewritten <- topDownSimplifyStm rules vtable stm
          case rewritten of
            Just replacement -> do
              changed
              inspectStms (replacement <> rest) m
            Nothing -> do
              (x, rest') <- localVtable (ST.insertStm stm) $ inspectStms rest m
              return (x, oneStm stm <> rest')
-- | Simplify an operation with the lore-specific simplifier stored in
-- the 'SimpleOps' component of the engine's reader environment.
simplifyOp :: Op lore -> SimpleM lore (Op (Wise lore), Stms (Wise lore))
simplifyOp op = do
  ops <- asks fst
  simplifyOpS ops op
simplifyExp :: SimplifiableLore lore =>
Exp lore -> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
simplifyExp :: Exp lore -> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
simplifyExp (If SubExp
cond BodyT lore
tbranch BodyT lore
fbranch (IfAttr [BranchType lore]
ts IfSort
ifsort)) = do
SubExp
cond' <- SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify SubExp
cond
[BranchType lore]
ts' <- (BranchType lore -> SimpleM lore (BranchType lore))
-> [BranchType lore] -> SimpleM lore [BranchType lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM BranchType lore -> SimpleM lore (BranchType lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify [BranchType lore]
ts
let ds :: [Diet]
ds = (BranchType lore -> Diet) -> [BranchType lore] -> [Diet]
forall a b. (a -> b) -> [a] -> [b]
map (Diet -> BranchType lore -> Diet
forall a b. a -> b -> a
const Diet
Consume) [BranchType lore]
ts
SimplifiedBody lore Result
tbranch' <- (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore (SimplifiedBody lore Result)
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
localVtable (Bool
-> SubExp -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall lore.
Attributes lore =>
Bool -> SubExp -> SymbolTable lore -> SymbolTable lore
ST.updateBounds Bool
True SubExp
cond) (SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore (SimplifiedBody lore Result))
-> SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore (SimplifiedBody lore Result)
forall a b. (a -> b) -> a -> b
$ [Diet] -> BodyT lore -> SimpleM lore (SimplifiedBody lore Result)
forall lore.
SimplifiableLore lore =>
[Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody [Diet]
ds BodyT lore
tbranch
SimplifiedBody lore Result
fbranch' <- (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore (SimplifiedBody lore Result)
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
localVtable (Bool
-> SubExp -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall lore.
Attributes lore =>
Bool -> SubExp -> SymbolTable lore -> SymbolTable lore
ST.updateBounds Bool
False SubExp
cond) (SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore (SimplifiedBody lore Result))
-> SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore (SimplifiedBody lore Result)
forall a b. (a -> b) -> a -> b
$ [Diet] -> BodyT lore -> SimpleM lore (SimplifiedBody lore Result)
forall lore.
SimplifiableLore lore =>
[Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody [Diet]
ds BodyT lore
fbranch
(Body (Wise lore)
tbranch'',Body (Wise lore)
fbranch'', Stms (Wise lore)
hoisted) <- SubExp
-> IfSort
-> SimplifiedBody lore Result
-> SimplifiedBody lore Result
-> SimpleM
lore (Body (Wise lore), Body (Wise lore), Stms (Wise lore))
forall lore.
SimplifiableLore lore =>
SubExp
-> IfSort
-> SimplifiedBody lore Result
-> SimplifiedBody lore Result
-> SimpleM
lore (Body (Wise lore), Body (Wise lore), Stms (Wise lore))
hoistCommon SubExp
cond' IfSort
ifsort SimplifiedBody lore Result
tbranch' SimplifiedBody lore Result
fbranch'
(Exp (Wise lore), Stms (Wise lore))
-> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
-> Body (Wise lore)
-> Body (Wise lore)
-> IfAttr (BranchType (Wise lore))
-> Exp (Wise lore)
forall lore.
SubExp
-> BodyT lore
-> BodyT lore
-> IfAttr (BranchType lore)
-> ExpT lore
If SubExp
cond' Body (Wise lore)
tbranch'' Body (Wise lore)
fbranch'' (IfAttr (BranchType (Wise lore)) -> Exp (Wise lore))
-> IfAttr (BranchType (Wise lore)) -> Exp (Wise lore)
forall a b. (a -> b) -> a -> b
$ [BranchType lore] -> IfSort -> IfAttr (BranchType lore)
forall rt. [rt] -> IfSort -> IfAttr rt
IfAttr [BranchType lore]
ts' IfSort
ifsort, Stms (Wise lore)
hoisted)
simplifyExp (DoLoop [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
val LoopForm lore
form BodyT lore
loopbody) = do
let ([FParam lore]
ctxparams, Result
ctxinit) = [(FParam lore, SubExp)] -> ([FParam lore], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(FParam lore, SubExp)]
ctx
([FParam lore]
valparams, Result
valinit) = [(FParam lore, SubExp)] -> ([FParam lore], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(FParam lore, SubExp)]
val
[FParam lore]
ctxparams' <- (FParam lore -> SimpleM lore (FParam lore))
-> [FParam lore] -> SimpleM lore [FParam lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((FParamAttr lore -> SimpleM lore (FParamAttr lore))
-> FParam lore -> SimpleM lore (FParam lore)
forall attr lore.
(attr -> SimpleM lore attr)
-> Param attr -> SimpleM lore (Param attr)
simplifyParam FParamAttr lore -> SimpleM lore (FParamAttr lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify) [FParam lore]
ctxparams
Result
ctxinit' <- (SubExp -> SimpleM lore SubExp) -> Result -> SimpleM lore Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify Result
ctxinit
[FParam lore]
valparams' <- (FParam lore -> SimpleM lore (FParam lore))
-> [FParam lore] -> SimpleM lore [FParam lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((FParamAttr lore -> SimpleM lore (FParamAttr lore))
-> FParam lore -> SimpleM lore (FParam lore)
forall attr lore.
(attr -> SimpleM lore attr)
-> Param attr -> SimpleM lore (Param attr)
simplifyParam FParamAttr lore -> SimpleM lore (FParamAttr lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify) [FParam lore]
valparams
Result
valinit' <- (SubExp -> SimpleM lore SubExp) -> Result -> SimpleM lore Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify Result
valinit
let ctx' :: [(FParam lore, SubExp)]
ctx' = [FParam lore] -> Result -> [(FParam lore, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [FParam lore]
ctxparams' Result
ctxinit'
val' :: [(FParam lore, SubExp)]
val' = [FParam lore] -> Result -> [(FParam lore, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [FParam lore]
valparams' Result
valinit'
diets :: [Diet]
diets = (FParam lore -> Diet) -> [FParam lore] -> [Diet]
forall a b. (a -> b) -> [a] -> [b]
map (TypeBase Shape Uniqueness -> Diet
forall shape. TypeBase shape Uniqueness -> Diet
diet (TypeBase Shape Uniqueness -> Diet)
-> (FParam lore -> TypeBase Shape Uniqueness)
-> FParam lore
-> Diet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. FParam lore -> TypeBase Shape Uniqueness
forall attr.
DeclTyped attr =>
Param attr -> TypeBase Shape Uniqueness
paramDeclType) [FParam lore]
valparams'
(LoopForm (Wise lore)
form', Names
boundnames, SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
wrapbody) <- case LoopForm lore
form of
ForLoop VName
loopvar IntType
it SubExp
boundexp [(LParam lore, VName)]
loopvars -> do
SubExp
boundexp' <- SubExp -> SimpleM lore SubExp
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify SubExp
boundexp
let ([LParam lore]
loop_params, [VName]
loop_arrs) = [(LParam lore, VName)] -> ([LParam lore], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(LParam lore, VName)]
loopvars
[LParam lore]
loop_params' <- (LParam lore -> SimpleM lore (LParam lore))
-> [LParam lore] -> SimpleM lore [LParam lore]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((LParamAttr lore -> SimpleM lore (LParamAttr lore))
-> LParam lore -> SimpleM lore (LParam lore)
forall attr lore.
(attr -> SimpleM lore attr)
-> Param attr -> SimpleM lore (Param attr)
simplifyParam LParamAttr lore -> SimpleM lore (LParamAttr lore)
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify) [LParam lore]
loop_params
[VName]
loop_arrs' <- (VName -> SimpleM lore VName) -> [VName] -> SimpleM lore [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> SimpleM lore VName
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify [VName]
loop_arrs
let form' :: LoopForm (Wise lore)
form' = VName
-> IntType
-> SubExp
-> [(LParam (Wise lore), VName)]
-> LoopForm (Wise lore)
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
loopvar IntType
it SubExp
boundexp' ([LParam lore] -> [VName] -> [(LParam lore, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [LParam lore]
loop_params' [VName]
loop_arrs')
(LoopForm (Wise lore), Names,
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM
lore
(LoopForm (Wise lore), Names,
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
forall (m :: * -> *) a. Monad m => a -> m a
return (LoopForm (Wise lore)
form',
[VName] -> Names
namesFromList (VName
loopvar VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: (LParam lore -> VName) -> [LParam lore] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map LParam lore -> VName
forall attr. Param attr -> VName
paramName [LParam lore]
loop_params') Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
fparamnames,
VName
-> IntType
-> SubExp
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
VName -> IntType -> SubExp -> SimpleM lore a -> SimpleM lore a
bindLoopVar VName
loopvar IntType
it SubExp
boundexp' (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall b c a. (b -> c) -> (a -> b) -> a -> c
.
[(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore (a, Stms (Wise lore))
-> SimpleM lore (a, Stms (Wise lore))
protectLoopHoisted [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
ctx' [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
val' LoopForm (Wise lore)
form' (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall b c a. (b -> c) -> (a -> b) -> a -> c
.
[(LParam (Wise lore), Maybe VName)]
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[(LParam (Wise lore), Maybe VName)]
-> SimpleM lore a -> SimpleM lore a
bindArrayLParams ([LParam lore] -> [Maybe VName] -> [(LParam lore, Maybe VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [LParam lore]
loop_params' ((VName -> Maybe VName) -> [VName] -> [Maybe VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Maybe VName
forall a. a -> Maybe a
Just [VName]
loop_arrs')))
WhileLoop VName
cond -> do
VName
cond' <- VName -> SimpleM lore VName
forall e lore.
(Simplifiable e, SimplifiableLore lore) =>
e -> SimpleM lore e
simplify VName
cond
(LoopForm (Wise lore), Names,
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM
lore
(LoopForm (Wise lore), Names,
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
forall (m :: * -> *) a. Monad m => a -> m a
return (VName -> LoopForm (Wise lore)
forall lore. VName -> LoopForm lore
WhileLoop VName
cond',
Names
fparamnames,
[(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> SimpleM lore (a, Stms (Wise lore))
-> SimpleM lore (a, Stms (Wise lore))
protectLoopHoisted [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
ctx' [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
val' (VName -> LoopForm (Wise lore)
forall lore. VName -> LoopForm lore
WhileLoop VName
cond'))
BlockPred (Wise lore)
seq_blocker <- (Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore))
forall lore a. (Env lore -> a) -> SimpleM lore a
asksEngineEnv ((Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore)))
-> (Env lore -> BlockPred (Wise lore))
-> SimpleM lore (BlockPred (Wise lore))
forall a b. (a -> b) -> a -> b
$ HoistBlockers lore -> BlockPred (Wise lore)
forall lore. HoistBlockers lore -> BlockPred (Wise lore)
blockHoistSeq (HoistBlockers lore -> BlockPred (Wise lore))
-> (Env lore -> HoistBlockers lore)
-> Env lore
-> BlockPred (Wise lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env lore -> HoistBlockers lore
forall lore. Env lore -> HoistBlockers lore
envHoistBlockers
((Stms (Wise lore)
loopstms, Result
loopres), Stms (Wise lore)
hoisted) <-
SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a. SimpleM lore a -> SimpleM lore a
enterLoop (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$ SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
consumeMerge (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
[FParam (Wise lore)]
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
[FParam (Wise lore)] -> SimpleM lore a -> SimpleM lore a
bindFParams ([FParam lore]
ctxparams'[FParam lore] -> [FParam lore] -> [FParam lore]
forall a. [a] -> [a] -> [a]
++[FParam lore]
valparams') (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$ SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
wrapbody (SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$
BlockPred (Wise lore)
-> SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
SimplifiableLore lore =>
BlockPred (Wise lore)
-> SimpleM lore (SimplifiedBody lore a)
-> SimpleM lore ((Stms (Wise lore), a), Stms (Wise lore))
blockIf
(Names -> BlockPred (Wise lore)
forall lore. Attributes lore => Names -> BlockPred lore
hasFree Names
boundnames BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
forall lore. BlockPred lore
isConsumed
BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
seq_blocker BlockPred (Wise lore)
-> BlockPred (Wise lore) -> BlockPred (Wise lore)
forall lore. BlockPred lore -> BlockPred lore -> BlockPred lore
`orIf` BlockPred (Wise lore)
forall lore. Attributes lore => BlockPred lore
notWorthHoisting) (SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> SimpleM lore (SimplifiedBody lore Result)
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$ do
((Result
res, UsageTable
uses), Stms (Wise lore)
stms) <- [Diet] -> BodyT lore -> SimpleM lore (SimplifiedBody lore Result)
forall lore.
SimplifiableLore lore =>
[Diet] -> Body lore -> SimpleM lore (SimplifiedBody lore Result)
simplifyBody [Diet]
diets BodyT lore
loopbody
SimplifiedBody lore Result
-> SimpleM lore (SimplifiedBody lore Result)
forall (m :: * -> *) a. Monad m => a -> m a
return ((Result
res, UsageTable
uses UsageTable -> UsageTable -> UsageTable
forall a. Semigroup a => a -> a -> a
<> Result -> UsageTable
isDoLoopResult Result
res), Stms (Wise lore)
stms)
Body (Wise lore)
loopbody' <- Stms (Wise lore) -> Result -> SimpleM lore (Body (Wise lore))
forall lore.
SimplifiableLore lore =>
Stms (Wise lore) -> Result -> SimpleM lore (Body (Wise lore))
constructBody Stms (Wise lore)
loopstms Result
loopres
(Exp (Wise lore), Stms (Wise lore))
-> SimpleM lore (Exp (Wise lore), Stms (Wise lore))
forall (m :: * -> *) a. Monad m => a -> m a
return ([(FParam (Wise lore), SubExp)]
-> [(FParam (Wise lore), SubExp)]
-> LoopForm (Wise lore)
-> Body (Wise lore)
-> Exp (Wise lore)
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
ctx' [(FParam lore, SubExp)]
[(FParam (Wise lore), SubExp)]
val' LoopForm (Wise lore)
form' Body (Wise lore)
loopbody', Stms (Wise lore)
hoisted)
where fparamnames :: Names
fparamnames =
[VName] -> Names
namesFromList (((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall attr. Param attr -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctx[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++[(FParam lore, SubExp)]
val)
consumeMerge :: SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
consumeMerge =
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall lore a.
(SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore a -> SimpleM lore a
localVtable ((SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore)))
-> (SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
-> SimpleM lore ((Stms (Wise lore), Result), Stms (Wise lore))
forall a b. (a -> b) -> a -> b
$ (SymbolTable (Wise lore) -> [VName] -> SymbolTable (Wise lore))
-> [VName] -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((SymbolTable (Wise lore) -> VName -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore) -> [VName] -> SymbolTable (Wise lore)
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' ((VName -> SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> SymbolTable (Wise lore) -> VName -> SymbolTable (Wise lore)
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall lore.
Attributes lore =>
VName -> SymbolTable lore -> SymbolTable lore
ST.consume)) ([VName] -> SymbolTable (Wise lore) -> SymbolTable (Wise lore))
-> [VName] -> SymbolTable (Wise lore) -> SymbolTable (Wise lore)
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList Names
consumed_by_merge
consumed_by_merge :: Names
consumed_by_merge =
Result -> Names
forall a. FreeIn a => a -> Names
freeIn (Result -> Names) -> Result -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> SubExp)
-> [(FParam lore, SubExp)] -> Result
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore, SubExp) -> SubExp
forall a b. (a, b) -> b
snd ([(FParam lore, SubExp)] -> Result)
-> [(FParam lore, SubExp)] -> Result
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> Bool)
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (TypeBase Shape Uniqueness -> Bool)
-> ((FParam lore, SubExp) -> TypeBase Shape Uniqueness)
-> (FParam lore, SubExp)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. FParam lore -> TypeBase Shape Uniqueness
forall attr.
DeclTyped attr =>
Param attr -> TypeBase Shape Uniqueness
paramDeclType (FParam lore -> TypeBase Shape Uniqueness)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> TypeBase Shape Uniqueness
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
val
-- Operators are delegated to the lore-specific 'simplifyOp', and whatever
-- statements it produces are passed along with the rebuilt expression.
simplifyExp (Op op) = do
  (op', op_stms) <- simplifyOp op
  return (Op op', op_stms)
-- For a commutative binary operator the simplified operands are placed in
-- sorted order ('min' first, 'max' second), so that @x op y@ and @y op x@
-- end up as the same expression.  If the operator is not commutative the
-- guard fails and the general equation below is used instead.
simplifyExp (BasicOp (BinOp op x y))
  | commutativeBinOp op = do
      sx <- simplify x
      sy <- simplify y
      return (BasicOp $ BinOp op (min sx sy) (max sx sy), mempty)
-- Everything else is simplified structurally by 'simplifyExpBase', with no
-- accompanying statements.
simplifyExp e = do
  simplified <- simplifyExpBase e
  return (simplified, mempty)
-- | Simplify an expression structurally: every 'SubExp', 'VName', return
-- type and branch type directly inside it is passed through 'simplify'.
--
-- Bodies, function parameters, lambda parameters and 'Op's are not
-- handled here.  Meeting one of them is an internal error (the
-- corresponding mapper field is 'error').  In this module 'simplifyExp'
-- deals with 'Op' and 'DoLoop' itself before falling back to this function.
simplifyExpBase :: SimplifiableLore lore =>
                   Exp lore -> SimpleM lore (Exp (Wise lore))
simplifyExpBase = mapExpM Mapper
  { mapOnBody =
      error "Unhandled body in simplification engine."
  , mapOnSubExp = simplify
  , mapOnVName = simplify
  , mapOnRetType = simplify
  , mapOnBranchType = simplify
  , mapOnFParam =
      error "Unhandled FParam in simplification engine."
  , mapOnLParam =
      error "Unhandled LParam in simplification engine."
  , mapOnOp =
      error "Unhandled Op in simplification engine."
  }
-- | The constraints a lore must satisfy before the simplification engine in
-- this module (e.g. 'simplify', 'simplifyExpBase') can operate on it.
type SimplifiableLore lore =
  ( Attributes lore
  , Simplifiable (LetAttr lore)
  , Simplifiable (FParamAttr lore)
  , Simplifiable (LParamAttr lore)
  , Simplifiable (RetType lore)
  , Simplifiable (BranchType lore)
  , CanBeWise (Op lore)
  , ST.IndexOp (OpWithWisdom (Op lore))
  , BinderOps (Wise lore)
  , IsOp (Op lore)
  )
-- | Values that the engine can simplify in place.  'simplify' maps a value
-- to a (possibly) simplified value of the same type, running in 'SimpleM'
-- for any lore satisfying 'SimplifiableLore'.
class Simplifiable e where
  simplify :: SimplifiableLore lore => e -> SimpleM lore e
-- | Pairs are simplified componentwise, left component first.
instance (Simplifiable a, Simplifiable b) => Simplifiable (a, b) where
  simplify (l, r) = (,) <$> simplify l <*> simplify r
-- | Triples are simplified componentwise, from left to right.
instance (Simplifiable a, Simplifiable b, Simplifiable c) =>
         Simplifiable (a, b, c) where
  simplify (p, q, r) = (,,) <$> simplify p <*> simplify q <*> simplify r
-- | An 'Int' contains nothing to simplify, so it is returned unchanged.
instance Simplifiable Int where
  simplify = return
-- | 'Nothing' stays 'Nothing'; a 'Just' has its payload simplified.
instance Simplifiable a => Simplifiable (Maybe a) where
  simplify = traverse simplify
-- | Lists are simplified element by element, in order.
instance Simplifiable a => Simplifiable [a] where
  simplify = traverse simplify
-- | A variable is looked up with 'ST.lookupSubExp' in the current symbol
-- table.  If the table knows a replacement (a constant or another
-- variable), 'changed' is signalled, the certificates returned by the
-- lookup are passed to 'usedCerts', and the replacement is returned.
-- Otherwise the variable is kept.  Constants are returned untouched.
instance Simplifiable SubExp where
  simplify (Var name) = do
    bnd <- ST.lookupSubExp name <$> askVtable
    case bnd of
      -- Constant and variable replacements are treated identically, so a
      -- single branch returns the looked-up 'SubExp' as-is instead of
      -- rebuilding each constructor by hand.
      Just (se, cs) -> do
        changed
        usedCerts cs
        return se
      Nothing -> return $ Var name
  simplify (Constant v) = return $ Constant v
-- | Simplify the attribute of every element of a pattern, context
-- elements first and then value elements.  Element names are unchanged.
simplifyPattern :: (SimplifiableLore lore, Simplifiable attr) =>
                   PatternT attr
                -> SimpleM lore (PatternT attr)
simplifyPattern (Pattern ctx vals) =
  Pattern <$> traverse simplifyPatElem ctx <*> traverse simplifyPatElem vals
  where simplifyPatElem (PatElem name attr) = PatElem name <$> simplify attr
-- | Simplify a parameter's attribute with the given function.  The
-- parameter's name is kept as-is.
simplifyParam :: (attr -> SimpleM lore attr) -> Param attr -> SimpleM lore (Param attr)
simplifyParam onAttr (Param name attr) = fmap (Param name) (onAttr attr)
-- | A name is replaced by another name only when the symbol table maps it
-- to a variable.  In that case 'changed' is signalled and the lookup's
-- certificates go to 'usedCerts'.  In every other case (no entry, or a
-- constant, which a 'VName' cannot hold) the original name is returned.
instance Simplifiable VName where
  simplify v = do
    known <- ST.lookupSubExp v <$> askVtable
    case known of
      Just (Var v', cs) -> do
        changed
        usedCerts cs
        return v'
      _ -> return v
-- | A shape is simplified by simplifying each of its dimensions.
instance Simplifiable d => Simplifiable (ShapeBase d) where
  simplify (Shape dims) = Shape <$> simplify dims
-- | A 'Free' size has its 'SubExp' simplified.  An existential ('Ext')
-- size only carries an 'Int' index and is returned unchanged.
instance Simplifiable ExtSize where
  simplify extSize = case extSize of
    Free se -> Free <$> simplify se
    Ext i -> pure (Ext i)
-- | Only 'ScalarSpace' contains anything to simplify: its dimension list.
-- Its element type and all other spaces are returned unchanged.
instance Simplifiable Space where
  simplify space = case space of
    ScalarSpace dims t -> flip ScalarSpace t <$> simplify dims
    _ -> pure space
-- | Simplify the parts of a type that may mention variables: an array's
-- shape and a memory type's space.  Primitive types, element types and
-- uniqueness are left unchanged.
instance Simplifiable shape => Simplifiable (TypeBase shape u) where
  simplify t = case t of
    Prim bt ->
      return $ Prim bt
    Mem space ->
      Mem <$> simplify space
    Array et shape u -> do
      shape' <- simplify shape
      return $ Array et shape' u
-- | Every component of a slice index is simplified: the single index of a
-- 'DimFix', or all three components of a 'DimSlice', from left to right.
instance Simplifiable d => Simplifiable (DimIndex d) where
  simplify idx = case idx of
    DimFix i ->
      DimFix <$> simplify i
    DimSlice i n s ->
      DimSlice <$> simplify i <*> simplify n <*> simplify s
-- | Simplify a lambda, using the engine's 'blockHoistPar' predicate
-- (read from the environment's 'envHoistBlockers') as the hoisting
-- blocker passed to 'simplifyLambdaMaybeHoist'.
--
-- The @[Maybe VName]@ list is forwarded unchanged to
-- 'simplifyLambdaMaybeHoist'.  That function pairs its entries with the
-- trailing parameters of the lambda.
--
-- Returns the simplified lambda together with the statements hoisted
-- out of its body.
simplifyLambda :: SimplifiableLore lore =>
                  Lambda lore
               -> [Maybe VName]
               -> SimpleM lore (Lambda (Wise lore), Stms (Wise lore))
simplifyLambda lam arrs =
  asksEngineEnv (blockHoistPar . envHoistBlockers) >>= \parBlocker ->
    simplifyLambdaMaybeHoist parBlocker lam arrs
-- | Simplify a lambda via 'simplifyLambdaMaybeHoist' with @isFalse False@
-- as the hoisting blocker.  Only the simplified lambda is returned; the
-- accompanying statement list is dropped.
--
-- NOTE(review): the function name indicates that @isFalse False@ blocks
-- every statement (so the dropped list is empty).  Confirm against the
-- definition of 'isFalse'.
simplifyLambdaNoHoisting :: SimplifiableLore lore =>
                            Lambda lore
                         -> [Maybe VName]
                         -> SimpleM lore (Lambda (Wise lore))
simplifyLambdaNoHoisting lam arr = do
  (lam', _) <- simplifyLambdaMaybeHoist (isFalse False) lam arr
  return lam'
-- | Simplify a lambda, hoisting statements out of its body where the
-- blocker permits.
--
-- The parameters are simplified first.  The last @length arrs@ of them
-- are bound with 'bindArrayLParams', each zipped with the corresponding
-- entry of @arrs@.  The leading remainder is bound with 'bindLParams'.
-- The whole body simplification runs under 'enterLoop'.
--
-- A statement is blocked from hoisting if any of these holds:
--
-- * the given predicate blocks it;
-- * it has a free variable among the names from 'boundByLambda';
-- * 'isConsumed' blocks it.
--
-- The body is simplified with an 'Observe' diet for every return type.
-- The return types are simplified after the body.
--
-- Returns the rebuilt lambda and the hoisted statements.
simplifyLambdaMaybeHoist :: SimplifiableLore lore =>
                            BlockPred (Wise lore) -> Lambda lore
                         -> [Maybe VName]
                         -> SimpleM lore (Lambda (Wise lore), Stms (Wise lore))
simplifyLambdaMaybeHoist blocked lam@(Lambda params body rettype) arrs = do
  simpleParams <- mapM (simplifyParam simplify) params
  let numNonArr = length simpleParams - length arrs
      (nonArrParams, arrParams) = splitAt numNonArr simpleParams
      boundHere = namesFromList $ boundByLambda lam
      blocker = blocked `orIf` hasFree boundHere `orIf` isConsumed
      diets = map (const Observe) rettype
  ((bodyStms, bodyRes), hoistedStms) <-
    enterLoop $
      bindLParams nonArrParams $
        bindArrayLParams (zip arrParams arrs) $
          blockIf blocker $
            simplifyBody diets body
  newBody <- constructBody bodyStms bodyRes
  newRettype <- simplify rettype
  return (Lambda simpleParams newBody newRettype, hoistedStms)
-- | Build a usage table from a list of diet-annotated results.
--
-- For every result whose diet is 'Consume', each name returned by
-- 'subExpAliases' is recorded with 'UT.consumedUsage'.  Results with any
-- other diet contribute 'mempty'.
consumeResult :: [(Diet, SubExp)] -> UT.UsageTable
consumeResult = foldMap usageOf
  where usageOf (Consume, se) =
          foldMap UT.consumedUsage $ namesToList $ subExpAliases se
        usageOf _ = mempty
-- | Simplify a certificate list by resolving each certificate through
-- the symbol table ('ST.lookupSubExp'):
--
-- * if it resolves to the constant 'Checked', it is replaced by the
--   certificates returned alongside that lookup;
-- * if it resolves to another variable, it is replaced by that variable;
-- * otherwise it is kept as-is.
--
-- The resulting lists are concatenated and de-duplicated with 'nub',
-- which keeps first occurrences.
instance Simplifiable Certificates where
  simplify (Certificates names) =
    Certificates . nub . concat <$> traverse resolve names
    where resolve v = do
            vtable <- askVtable
            return $ case ST.lookupSubExp v vtable of
              Just (Constant Checked, Certificates cs) -> cs
              Just (Var v', _) -> [v']
              _ -> [v]
-- | Run a body simplification through 'blockIf' with @isFalse False@ as
-- the blocking predicate.  A 'Body' is then built with 'constructBody'
-- from the resulting statements and result.  The statement list returned
-- as the second component by 'blockIf' is discarded.
insertAllStms :: SimplifiableLore lore =>
                 SimpleM lore (SimplifiedBody lore Result)
              -> SimpleM lore (Body (Wise lore))
insertAllStms simplified = do
  ((stms, res), _) <- blockIf (isFalse False) simplified
  constructBody stms res
-- | Simplify a function definition: its return types, its parameters and
-- its body.  The entry-point information and the function name are kept
-- unchanged.
--
-- The body is simplified with diets derived (via 'diet') from the
-- simplified return types.  It is simplified with the simplified
-- parameters in scope and rebuilt through 'insertAllStms'.
simplifyFun :: SimplifiableLore lore =>
               FunDef lore -> SimpleM lore (FunDef (Wise lore))
simplifyFun (FunDef entry fname rettype params body) = do
  rettype' <- simplify rettype
  params' <- mapM (simplifyParam simplify) params
  let ds = map diet (retTypeValues rettype')
  -- Bind the *simplified* parameters.  They are the ones placed in the
  -- resulting FunDef, so the symbol table used while simplifying the body
  -- should describe them.  This is also what simplifyLambdaMaybeHoist
  -- does for lambda parameters.
  body' <- bindFParams params' $ insertAllStms $ simplifyBody ds body
  return $ FunDef entry fname rettype' params' body'