-- | We collect all if-blocks under the if-then-else expressions and statements.
--
-- For a given if-block of code the task is to aggregate all expressions
-- that are used only inside that block and do not affect expressions
-- outside of that block.
--
-- For example, consider the expression:
--
-- > k3 opcA k2 k1
-- > k4 opcB 1, 120
-- >
-- > if cond then
-- >   k5 = k3
-- > else
-- >   k5 = k4
-- > endif
-- >
-- It can be transformed to:
--
-- > if cond then
-- >  k3 opcA k2 k1
-- >  k5 = k3
-- > else
-- >   k4 opcB 1, 120
-- >   k5 = k4
-- > endif
--
-- We move the expressions that are relevant only to an if-block inside that block.
-- But we should be careful not to touch the expressions that are dependencies
-- of expressions outside of the block.
--
-- The algorithm to find groups of such expressions proceeds as follows:
--
-- * count how many times each expression is used in the RHS of other expressions.
--    Create a table for fast access (O (expr-size)). Let's call it the global count.
--
-- * for a given expression definition, follow its dependencies recursively
--    and count for all of its descendants how many times they are used in the RHS of the expression.
--    Let's call it the local count.
--
-- * The rule: a given integer label/name can be brought inside the if-block when
--      * its global count equals its local count, because then all of its usages are inside
--          the sub-expressions of that block and do not leak to the outer scope, and
--      * it is not a descendant of a node for which the rule does not hold.
--
--  There are cases when a node is inside the if sub-graph, but one of its
--    parents does not fit into the sub-graph. To solve this problem we go over the sub-graph 2 times:
--
--    1) to collect local counts we create an IntMap of usage counts local to the if-block
--    2) we mark as False all nodes that are not local to the if-block and also (IMPORTANT) mark as False all their children.
--        As we traverse the graph breadth-first we recursively mark all unfit descendants.
--        I hope that it works :)
--        At this stage we create the set of nodes which are truly local;
--    this is the set of local variables.
--
--    One buggy solution was to traverse the sub-graph and put into the set the
--     nodes which are local with respect to the usage count. But this does not work, because
--     a valid node can have an invalid parent: the algorithm would exclude the parent but
--     keep the child, which leads to broken code.
--
-- This rule works for generic expressions defined on traversable functor F.
--
-- But there are some Csound peculiarities:
--
-- * reminder:
--      * if-blocks can work at Ir and at Kr rates.
--      * Kr if-blocks are ignored during the initialization (Ir) stage.
--
-- * this leads to Csound-specific rules:
--
--    * init expressions cannot be brought inside a Kr if-block (they would be ignored);
--       the same holds for opcodes that run at I-rate.
--
--    * variable / array initialisation cannot be brought inside a Kr if-block
--
--    * all constants inside the block should have the same rate as the block itself,
--       i.e. ir constants inside an Ir block and kr constants inside a kr block
--
--  So we should recursively follow the dependencies of the if-block root variable definition,
--  but we also exclude nodes early if they cannot be present inside the block because of their rate.
module Csound.Dynamic.Tfm.IfBlocks
  ( collectIfBlocks
  ) where

import Csound.Dynamic.Types.Exp hiding (Var(..))
import Csound.Dynamic.Types.Exp qualified as Exp
import Control.Monad
import Control.Monad.ST
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict
import Data.Maybe (fromMaybe)
import Data.Vector.Mutable qualified as Vector
import Data.Vector.Unboxed.Mutable qualified as UnboxedVector
import Data.List qualified as List
import Data.IntMap.Strict (IntMap)
import Data.IntMap.Strict qualified as IntMap
import Data.IntSet (IntSet)
import Data.IntSet qualified as IntSet
import Data.STRef
import Data.Bifunctor (first)
import Csound.Dynamic.Tfm.InferTypes (InferenceResult (..), Stmt(..), Var(..))
import Data.Text qualified as Text
-- import Debug.Trace

type Expr  = Stmt Var

-- | Moves expressions that are used only inside an if-block into that block.
--
-- When 'programHasIfs' is 'False' the inference result is returned unchanged.
-- Otherwise a fresh 'Env' is built from the typed program. The statements are
-- then fed to 'collectIter' in reverse order, and the rewritten program is
-- stored back into the result. The final value of the fresh-id counter is
-- stored as well, because the rewrite may allocate new ids through 'freshId'.
collectIfBlocks :: InferenceResult -> InferenceResult
collectIfBlocks infRes@InferenceResult{..}
  | programHasIfs = runST $ do
      env <- newEnv programLastFreshId typedProgram
      uncurry toResult =<< runStateT (collectIter [] $ List.reverse typedProgram) env
  | otherwise = infRes
  where
    -- Packs the rewritten program together with the current value of the
    -- fresh-id counter of the final environment.
    toResult :: [Stmt Var] -> Env s -> ST s InferenceResult
    toResult prog env = do
      lastId <- readSTRef (envLastFreshId env)
      pure $ infRes { typedProgram = prog, programLastFreshId = lastId }

-- | Monad of the algorithm: a state over the mutable 'Env', running in 'ST'.
type Collect s a = StateT (Env s) (ST s) a

-- | Global usage counts, indexed by variable id.
type UsageCounts s = UnboxedVector.STVector s Int
-- | Right-hand side of the statement defining each variable, indexed by the
-- id of the variable on its left-hand side.
type DagGraph s = Vector.STVector s (RatedExp Var)
-- | Per-variable flag telling whether its defining statement satisfies
-- 'isInitExpr', indexed by variable id.
type IsInits s =  UnboxedVector.STVector s Bool

-- | Internal mutable state of the algorithm
data Env s = Env
  { envUsageCount  :: UsageCounts s
    -- ^ how many times each variable occurs in the right-hand sides of the program
  , envDag         :: DagGraph s
    -- ^ definition (right-hand side) of each variable
  , envIsInit      :: IsInits s
    -- ^ which definitions are initialization statements
  , envLastFreshId :: STRef s Int
    -- ^ counter used by 'freshId': it hands out the current value, then increments it
  , envDagSize     :: Int
    -- ^ length of the tables above; ids at or beyond it are treated as not in the tables
  }

---------------------------------------------------
-- collect interface

-- | Length of the DAG tables. Variable ids at or above this value have no
-- entry in them.
getDagSize :: Collect s Int
getDagSize = gets envDagSize

-- | Global usage count of the variable with the given id, i.e. how many times
-- it occurs in right-hand sides of the original program.
--
-- Ids outside the tables (for instance ids produced later by 'freshId', which
-- start at the table size) have no recorded usages, so @0@ is returned.
readGlobalUsages :: Int -> Collect s Int
readGlobalUsages n = do
  dagSize <- getDagSize
  if n < dagSize
    then do
      usages <- gets envUsageCount
      lift $ UnboxedVector.read usages n
    else pure 0

-- | Whether the statement defining the variable with the given id was
-- flagged by 'isInitExpr' when the environment was built.
--
-- Ids outside the tables are reported as 'False'.
readIsInit :: Int -> Collect s Bool
readIsInit n = do
  dagSize <- getDagSize
  if n < dagSize
    then do
      inits <- gets envIsInit
      lift $ UnboxedVector.read inits n
    else pure False

-- | Definition of the given variable as a statement, or 'Nothing' when its id
-- lies outside the DAG tables.
--
-- NOTE(review): the table is allocated with 'Vector.new'. An in-range slot that
-- no statement wrote to still holds the placeholder element, so evaluating the
-- right-hand side returned for such an id would throw.
readDag :: Var -> Collect s (Maybe Expr)
readDag lhs = do
  dagSize <- getDagSize
  if varId lhs < dagSize
    then do
      dag <- gets envDag
      fmap (Just . Stmt lhs) $ lift $ Vector.read dag (varId lhs)
    else pure Nothing

-- | Runs the continuation on the definition of the given variable. Does
-- nothing when 'readDag' finds no definition for it.
withDag :: Var -> (Expr -> Collect s ()) -> Collect s ()
withDag n cont = do
  mExpr <- readDag n
  forM_ mExpr cont

-- | Allocates a new variable id: returns the current value of the counter and
-- advances it by one.
freshId :: Collect s Int
freshId = do
  ref <- gets envLastFreshId
  lift $ do
    newId <- readSTRef ref
    modifySTRef' ref succ
    pure newId

---------------------------------------------------------------------------
-- working with DAG-graph

-- | Folds over the definitions reachable from a root, visiting them in the
-- order of 'traverseDag'.
--
-- * @update@ combines a visited statement with the current accumulator;
-- * @initSt@ is the initial accumulator. It is also the result when the root
--   is a primitive value or has no definition;
-- * @getIsEnd@ stops the descent. A statement for which it returns 'True' is
--   not passed to @update@, and nothing below it is visited.
--
-- The former @Show a@ constraint was used only by a commented-out debug trace
-- and has been dropped. Re-add it if the trace below is re-enabled.
traverseAccumDag :: forall s a .
     (Expr -> a -> Collect s a)
  -> a
  -> (Expr -> Collect s Bool)
  -> PrimOr Var
  -> Collect s a
traverseAccumDag update initSt getIsEnd (PrimOr root) =
  case root of
    Left _    -> pure initSt
    Right var -> do
      ref <- lift $ newSTRef initSt
      traverseDag var getIsEnd (go ref)
      lift $ readSTRef ref
  where
    -- Replaces the accumulator with the result of @update@ on the visited statement.
    go :: STRef s a -> Expr -> Collect s ()
    go ref expr = do
      val <- lift $ readSTRef ref
      newVal <- update expr val
      -- debug (needs Show a):
      -- trace (unlines ["GO", show $ stmtLhs expr, show $ ratedExpExp $ stmtRhs expr, show newVal]) $
      lift $ writeSTRef ref newVal

-- | Depth-first, pre-order traversal of the definitions reachable from @root@.
--
-- Each definition is first checked with @getIsEnd@. If the check returns
-- 'True', the definition and everything below it are skipped. Otherwise @go@
-- runs on the definition, and then the traversal recurses into every variable
-- of its right-hand side (in the 'Foldable' order of 'RatedExp'). Variables
-- without a definition (see 'withDag') are ignored.
--
-- There is no visited set. A definition reachable through several paths is
-- visited once per path.
traverseDag :: Var -> (Expr -> Collect s Bool) -> (Expr -> Collect s ()) -> Collect s ()
traverseDag root getIsEnd go =
  withDag root $ \expr -> do
    isTerminal <- getIsEnd expr
    unless isTerminal $ do
      go expr
      mapM_ (\var -> traverseDag var getIsEnd go) (stmtRhs expr)


-----------------------------------------------------------

-- | Builds the initial algorithm state for a program whose variable ids are
-- all below @exprSize@. The vector operations are bounds-checked, so an id at
-- or above @exprSize@ makes them fail.
--
-- A single pass over the statements fills:
--
-- * the global usage counts: one increment for every variable occurrence in a
--   right-hand side;
-- * the DAG: each right-hand side stored at the id of its left-hand side;
-- * the init flags: set for statements satisfying 'isInitExpr'.
--
-- The fresh-id counter starts at @exprSize@, so ids handed out by 'freshId'
-- lie outside the tables.
newEnv :: forall s . Int -> [Expr] -> ST s (Env s)
newEnv exprSize exprs = do
  usageCount <- UnboxedVector.replicate exprSize 0
  dag <- Vector.new exprSize
  isInit <- UnboxedVector.replicate exprSize False
  lastFreshIdRef <- newSTRef exprSize
  let env = Env usageCount dag isInit lastFreshIdRef exprSize
  mapM_ (go env) exprs
  pure env
  where
    -- Registers one statement in all three tables.
    go :: Env s -> Expr -> ST s ()
    go env expr = do
      updateUsageCount (envUsageCount env) expr
      updateDag (envDag env) expr
      updateIsInit (envIsInit env) expr

    updateUsageCount :: UsageCounts s -> Expr -> ST s ()
    updateUsageCount usageCounts expr =
      mapM_ count (stmtRhs expr)
      where
        count :: Var -> ST s ()
        count v = UnboxedVector.modify usageCounts succ (varId v)

    updateDag :: DagGraph s -> Expr -> ST s ()
    updateDag dag (Stmt lhs rhs) =
      Vector.write dag (varId lhs) rhs

    updateIsInit :: IsInits s -> Expr -> ST s ()
    updateIsInit isInit expr =
      when (isInitExpr expr) $
        UnboxedVector.write isInit (varId $ stmtLhs expr) True

-- | Be sure not to bring initialization expressions inside the if-blocks.
--
-- A statement counts as an initialization when its left-hand side has rate
-- @Ir@, or when its right-hand side is one of:
--
-- * a variable, array, pure-array or macro initialization;
-- * a 'TfmArr' whose first (Bool) argument is 'True';
-- * a 'ConvertRate' or 'Select' whose rate argument is @Ir@.
isInitExpr :: Stmt Var -> Bool
isInitExpr expr =
  varType (stmtLhs expr) == Ir || checkExpr (ratedExpExp $ stmtRhs expr)
  where
    checkExpr :: MainExp a -> Bool
    checkExpr = \case
      InitVar _ _            -> True
      InitArr _ _            -> True
      TfmArr isInit _ _ _    -> isInit
      InitPureArr _ _ _      -> True
      InitMacrosInt _ _      -> True
      InitMacrosDouble _ _   -> True
      InitMacrosString _ _   -> True
      ConvertRate Ir _ _     -> True
      Select Ir _ _          -> True
      _                      -> False

-- | Classification of a statement's right-hand side, as consumed by 'collectIter'.
--
-- For the conditional forms the fields are, in order: the rate of the
-- conditional, its condition, the root of the then-branch, the root of the
-- else-branch (two-branch forms only), and, where present, the constructors
-- passed on to the matching @redefine*@ function.
data ExprType a
  = PlainType
  -- ^ not a conditional; 'collectIter' keeps the statement as it is
  | IfType IfRate (CondInfo a) a (IfCons a)
  -- ^ conditional with a single (then) branch
  | IfElseType IfRate (CondInfo a) a a (IfElseCons a)
  -- ^ conditional with then- and else-branches
  | IfExpType IfRate (CondInfo a) a a
  -- ^ two-branch conditional without its own constructors; 'collectIter'
  --   supplies 'IfBegin' / 'ElseBegin' / 'IfEnd' for it

-- | Statements that open and close a single-branch if-block (consumed by
-- 'redefineIf'). The opening statement is built from the if-rate and the
-- condition.
data IfCons a = IfCons
  { ifBegin :: IfRate -> CondInfo a -> MainExp a
  , ifEnd   :: MainExp a
  }

-- | Statements that open an if-then-else block, start its else-branch and
-- close it. The opening statement is built from the if-rate and the condition.
data IfElseCons a = IfElseCons
  { ifElseBegin :: IfRate -> CondInfo a -> MainExp a
  , elseBegin   :: MainExp a
  , ifElseEnd   :: MainExp a
  }

-- | Usage counts local to one if-branch, keyed by variable id.
type LocalUsageCounts = IntMap Int
-- | Ids of the variables that may be moved inside an if-branch.
type LocalVars = IntSet

-- | We process statements in reverse order
-- and then also accumulation happens in reverse
-- so we don't need to reverse twice.
--
-- @results@ is the accumulator, and every processed statement is consed onto
-- it. Feeding the program reversed therefore yields the result in the original
-- order. A plain statement is copied as is. For a conditional:
--
-- 1. the variables local to each branch are computed ('getLocalUsage' and then
--    'getLocalVars');
-- 2. the matching @redefine*@ function builds the new block and returns the
--    statements still left to process;
-- 3. the new block is reversed onto the accumulator and processing continues
--    with the remaining statements.
collectIter :: [Stmt Var] -> [Stmt Var] -> Collect s [Stmt Var]
collectIter results = \case
  [] -> pure results
  expr : exprs ->
    case getExprType (stmtRhs expr) of
      PlainType                        -> onPlain expr exprs
      IfType rate check th cons        -> onIf rate check th cons (stmtLhs expr) exprs
      IfElseType rate check th el cons -> onIfElse rate check th el cons (stmtLhs expr) exprs
      IfExpType rate check th el       -> onIfExp rate check th el (stmtLhs expr) exprs
  where
    onPlain :: Stmt Var -> [Stmt Var] -> Collect s [Stmt Var]
    onPlain expr rest = collectIter (expr : results) rest

    onIf
      :: IfRate -> CondInfo (PrimOr Var) -> PrimOr Var -> IfCons (PrimOr Var)
      -> Var -> [Stmt Var] -> Collect s [Stmt Var]
    onIf ifRate check th cons lhs exprs = do
      vs <- blockLocalVars ifRate th
      (newIfBlock, rest) <- redefineIf vs lhs ifRate check cons exprs
      emitBlock newIfBlock rest

    onIfElse
      :: IfRate -> CondInfo (PrimOr Var) -> PrimOr Var -> PrimOr Var
      -> IfElseCons (PrimOr Var) -> Var -> [Stmt Var] -> Collect s [Stmt Var]
    onIfElse ifRate check th el cons lhs exprs = do
      thVars <- blockLocalVars ifRate th
      elVars <- blockLocalVars ifRate el
      (newIfBlock, rest) <- redefineIfElse thVars elVars lhs ifRate check cons exprs
      emitBlock newIfBlock rest

    onIfExp
      :: IfRate -> CondInfo (PrimOr Var) -> PrimOr Var -> PrimOr Var
      -> Var -> [Stmt Var] -> Collect s [Stmt Var]
    onIfExp ifRate check th el lhs exprs = do
      thVars <- blockLocalVars ifRate th
      elVars <- blockLocalVars ifRate el
      (newIfBlock, rest) <- redefineIfElseExp thVars elVars th el lhs ifRate check ifExpCons exprs
      emitBlock newIfBlock rest
      where
        -- if-expressions carry no constructors of their own, so the generic
        -- begin / else / end statements are used
        ifExpCons :: IfElseCons (PrimOr Var)
        ifExpCons = IfElseCons { ifElseBegin = IfBegin, elseBegin = ElseBegin, ifElseEnd = IfEnd }

    -- Reverses the rebuilt block onto the accumulator, then processes the
    -- remaining statements.
    emitBlock :: [Stmt Var] -> [Stmt Var] -> Collect s [Stmt Var]
    emitBlock newIfBlock rest =
      collectIter (copyToResult newIfBlock results) rest

    -- Prepends the items in reverse order (equivalent to the former
    -- @foldl' (flip (:)) result items@).
    copyToResult :: [a] -> [a] -> [a]
    copyToResult items result = List.reverse items ++ result

    -- Variables that can be moved inside the branch rooted at @root@.
    blockLocalVars :: IfRate -> PrimOr Var -> Collect s LocalVars
    blockLocalVars ifRate root = do
      localUsage <- getLocalUsage ifRate root
      -- globals <- mapM (\v -> (\g -> (v, (g, localUsage IntMap.! v))) <$> readGlobalUsages v) $ IntMap.keys localUsage
      -- trace (unlines $ show <$> globals) $
      getLocalVars localUsage ifRate root

-- | Post-processes the statements that were moved into a freshly built
-- if-block.
--
-- If the block contains if-statements (@hasIfs@), it is run through
-- 'collectIter' with an empty accumulator and the result is reversed, so
-- (presumably) nested if-blocks get the same hoisting treatment.
-- Otherwise the block is returned unchanged.
collectSubs :: Bool -> [Expr] -> Collect s [Expr]
collectSubs hasIfs newIfBlock
  | hasIfs    = List.reverse <$> collectIter [] newIfBlock
  | otherwise = pure newIfBlock

-- | Builds an @if … endif@ block (no else branch).
--
-- 'iterRedefine' pulls the statements from @exprs@ that define one of
-- @localVars@ into the block. All other statements are returned untouched
-- as the second component of the result.
--
-- The first component is the block wrapped in its begin/end statements.
-- The begin statement reuses @ifBeginId@ as its left-hand side. The end
-- statement gets a fresh 'Xr' variable. The list is in reversed order,
-- like the block body: end statement first, begin statement last.
redefineIf ::
     LocalVars
  -> Var
  -> IfRate
  -> CondInfo (PrimOr Var)
  -> IfCons (PrimOr Var)
  -> [Expr]
  -> Collect s ([Expr], [Expr])
redefineIf localVars ifBeginId ifRate condInfo IfCons{..} exprs = do
  -- the fresh id for the end marker is drawn before the block is collected
  ifStmts <- getIfStmts
  first (toResult ifStmts) <$> iterRedefine ifRate localVars blockSize [] False [] exprs
  where
    -- number of local definitions that still have to be found
    blockSize :: Int
    blockSize = IntSet.size localVars

    -- the block expressions are expected to be in reversed order
    toResult :: (Expr, Expr) -> [Expr] -> [Expr]
    toResult (ifBeginStmt, ifEndStmt) blockExprs =
      ifEndStmt : blockExprs <> [ifBeginStmt]

    getIfStmts = do
      ifEndId <- freshId
      let ifEndStmt   = Stmt (Var Xr ifEndId) (toRatedExp ifEnd)
          ifBeginStmt = Stmt ifBeginId (toRatedExp $ ifBegin ifRate condInfo)
      pure (ifBeginStmt, ifEndStmt)

-- | Walks a (reversed) statement list and splits it into the statements
-- that go inside an if-block and the statements that stay outside.
--
-- Arguments, in order:
--
-- * @ifRate@ — passed unchanged to the recursive calls; not otherwise used.
-- * @localVars@ — ids of the variables whose definitions move into the block.
-- * @currentBlockSize@ — number of local definitions still to be found.
--   The walk stops as soon as it drops to zero.
-- * @resultIfExprs@ — accumulator of block statements (built with @(:)@).
-- * @hasIfs@ — becomes @True@ once any inspected statement satisfies
--   'isIfExpr'.
-- * @resultRest@ — accumulator of outside statements (built with @(:)@).
-- * @nextExprs@ — statements still to be inspected.
--
-- On stop, both accumulators are reversed back into input order. The block
-- statements are passed through 'collectSubs'. Statements that were never
-- inspected are appended unchanged after the outside statements.
--
-- NOTE(review): @hasIfs@ is also set by statements that end up outside the
-- block, which may trigger an unnecessary 'collectSubs' pass — confirm intent.
iterRedefine :: IfRate -> LocalVars -> Int -> [Expr] -> Bool -> [Expr] -> [Expr] -> Collect s ([Expr], [Expr])
iterRedefine ifRate localVars currentBlockSize resultIfExprs hasIfs resultRest nextExprs
  | currentBlockSize <= 0 = result
  | otherwise =
      case nextExprs of
        [] -> result
        e@(Stmt lhs _) : es
          | isLocal lhs -> appendLocal e es
          | otherwise   -> appendRest e es
  where
    -- continue with the remaining statements, updating the block counter and
    -- the two accumulators with the supplied functions
    advance onBlockSize expr onIfExprs onRestExprs newNextExprs =
      iterRedefine
        ifRate
        localVars
        (onBlockSize currentBlockSize)
        (onIfExprs resultIfExprs)
        (hasIfs || isIfExpr (stmtRhs expr))
        (onRestExprs resultRest)
        newNextExprs

    -- stop: restore the order of the accumulators, keep the unvisited tail
    result = recollect
      ( List.reverse resultIfExprs
      , hasIfs
      , List.reverse resultRest <> nextExprs
      )

    recollect (newIfBlock, finalHasIfs, rest) = do
      newIfBlockCollected <- collectSubs finalHasIfs newIfBlock
      pure (newIfBlockCollected, rest)

    -- move a statement into the block: one local definition fewer to find
    appendLocal e es = advance pred e (e :) id es

    -- keep a statement outside of the block
    appendRest e es = advance id e id (e :) es

    isLocal :: Var -> Bool
    isLocal var = IntSet.member (varId var) localVars

-- | Builds an @if … else … endif@ block.
--
-- Statements from @exprs@ that define one of @thLocalVars@ are pulled into
-- the then-branch. From what is left, statements that define one of
-- @elLocalVars@ are pulled into the else-branch. Both steps use
-- 'iterRedefine'. The remaining statements are returned as the second
-- component of the result.
--
-- The begin statement reuses @ifBeginId@ as its left-hand side. The else and
-- end statements get fresh 'Xr' variables. The block is returned in reversed
-- order: end, else-branch, else, then-branch, begin.
redefineIfElse ::
     LocalVars
  -> LocalVars
  -> Var
  -> IfRate
  -> CondInfo (PrimOr Var)
  -> IfElseCons (PrimOr Var)
  -> [Expr]
  -> Collect s ([Expr], [Expr])
redefineIfElse thLocalVars elLocalVars ifBeginId ifRate condInfo IfElseCons{..} exprs = do
  ifStmts <- getIfElseStmts
  (ifBlockExprs, rest1) <- collectBranch thLocalVars exprs
  (elseBlockExprs, rest2) <- collectBranch elLocalVars rest1
  pure (toResult ifStmts ifBlockExprs elseBlockExprs, rest2)
  where
    -- note that block expressions are reversed
    toResult :: (Expr, Expr, Expr) -> [Expr] -> [Expr] -> [Expr]
    toResult (ifBeginStmt, elseBeginStmt, ifEndStmt) thenExprs elseExprs =
      ifEndStmt : mconcat
        [ elseExprs
        , [elseBeginStmt]
        , thenExprs
        , [ifBeginStmt]
        ]

    -- the else and end markers draw fresh ids, in that order
    getIfElseStmts = do
      let ifBeginStmt = Stmt ifBeginId (toRatedExp $ ifElseBegin ifRate condInfo)
      elseBeginStmt <- (\elId -> Stmt (Var Xr elId) (toRatedExp elseBegin)) <$> freshId
      ifEndStmt <- (\endId -> Stmt (Var Xr endId) (toRatedExp ifElseEnd)) <$> freshId
      pure (ifBeginStmt, elseBeginStmt, ifEndStmt)

    -- pull the definitions of the given variables out of a statement list;
    -- the walk stops once all of them have been found
    collectBranch vars = iterRedefine ifRate vars (IntSet.size vars) [] False []

-- | Builds an @if … else … endif@ block for an if-else /expression/.
--
-- Statements from @exprs@ that define one of @thLocalVars@ are pulled into
-- the then-branch. From what is left, statements that define one of
-- @elLocalVars@ are pulled into the else-branch. Both steps use
-- 'iterRedefine'. In addition, each branch ends with a 'WriteVar' statement
-- that stores the branch value (@th@ / @el@) into a verbatim variable
-- derived from @ifResultId@.
--
-- All marker statements (begin, else, end) and both write statements get
-- fresh 'Xr' left-hand sides. The block is returned in reversed order:
-- end, else-write, else-branch, else, then-write, then-branch, begin.
-- The second component holds the statements that stay outside the block.
redefineIfElseExp ::
     forall s
   . LocalVars
  -> LocalVars
  -> PrimOr Var
  -> PrimOr Var
  -> Var
  -> IfRate
  -> CondInfo (PrimOr Var)
  -> IfElseCons (PrimOr Var)
  -> [Expr]
  -> Collect s ([Expr], [Expr])
redefineIfElseExp thLocalVars elLocalVars th el ifResultId ifRate condInfo IfElseCons{..} exprs = do
  ifStmts <- getIfElseStmts
  -- note that blocks are returned in reversed order
  (ifBlockExprs, rest1) <- collectBranch thLocalVars exprs
  (elseBlockExprs, rest2) <- collectBranch elLocalVars rest1
  ifResult <- toResult ifStmts ifBlockExprs elseBlockExprs
  pure (ifResult, rest2)
  where
    -- note that expressions in the blocks are in reversed order
    toResult :: (Expr, Expr, Expr) -> [Expr] -> [Expr] -> Collect s [Expr]
    toResult (ifBeginStmt, elseBeginStmt, ifEndStmt) thenExprs elseExprs = do
      thAssign <- writeRes ifResultId th
      elAssign <- writeRes ifResultId el
      pure $
        ifEndStmt : elAssign : mconcat
          [ elseExprs
          , [elseBeginStmt, thAssign]
          , thenExprs
          , [ifBeginStmt]
          ]

    -- fresh ids are drawn in the order: begin, else, end
    getIfElseStmts = do
      ifBeginStmt   <- freshStmt (ifElseBegin ifRate condInfo)
      elseBeginStmt <- freshStmt elseBegin
      ifEndStmt     <- freshStmt ifElseEnd
      pure (ifBeginStmt, elseBeginStmt, ifEndStmt)

    -- a statement with a fresh Xr-rate left-hand side for the given expression
    freshStmt mainExp = (\newId -> Stmt (Var Xr newId) (toRatedExp mainExp)) <$> freshId

    -- pull the definitions of the given variables out of a statement list;
    -- the walk stops once all of them have been found
    collectBranch vars = iterRedefine ifRate vars (IntSet.size vars) [] False []

    -- statement that writes a branch value into the shared result variable
    writeRes :: Var -> PrimOr Var -> Collect s Expr
    writeRes resId expr = do
      varWriteId <- freshId
      pure $ Stmt
        { stmtLhs = Var Xr varWriteId
        , stmtRhs = toRatedExp $ WriteVar (toVar resId) expr
        }

    -- verbatim variable named after the rate and id of the given variable
    -- (lower-cased @show@ of the rate followed by the id).
    -- NOTE: the argument and result are different @Var@ types, so no local
    -- signature is given here.
    toVar v = Exp.VarVerbatim (varType v) name
      where
        name = Text.toLower $ Text.pack $ show (varType v) ++ show (varId v)


-- | Wraps a bare expression into a 'RatedExp' with an empty hash, no
-- dependency and no rate. This module uses it for the synthetic statements
-- it creates (if/else/end markers and result writes).
toRatedExp :: MainExp (PrimOr a) -> RatedExp a
toRatedExp expr =
  RatedExp
    { ratedExpHash    = ""
    , ratedExpDepends = Nothing
    , ratedExpRate    = Nothing
    , ratedExpExp     = expr
    }

-- | Marks used by 'getLocalVars', keyed by variable id. A variable whose mark
-- is @True@ at the end of the traversal is reported as local to the if-block;
-- @False@ marks are dropped.
type LocalMarks = IntMap Bool

-- | Second pass of the algorithm described in the module header: decides
-- which variables reachable from @root@ may be moved inside the if-block.
--
-- The sub-graph below @root@ is traversed with 'traverseAccumDag', using
-- @isEnd ifRate@ as the end predicate. A statement's variable stays local
-- when two conditions hold:
--
-- * no parent has marked it non-local;
-- * its global usage count equals its usage count in @localUsages@ (as
--   produced by 'getLocalUsage').
--
-- Otherwise it is marked non-local, and so is every variable on its
-- right-hand side. The result is the set of ids still marked local.
getLocalVars :: forall s . LocalUsageCounts -> IfRate -> PrimOr Var -> Collect s LocalVars
getLocalVars localUsages ifRate root =
  toSet <$> traverseAccumDag update initMarks (isEnd ifRate) root
  where
    -- the root starts as local (primitives carry no variable to mark)
    initMarks :: LocalMarks
    initMarks = either (const IntMap.empty) (\var -> IntMap.singleton (varId var) True) (unPrimOr root)

    update :: Expr -> LocalMarks -> Collect s LocalMarks
    update (Stmt lhs rhs) localMarks
      | isParentLocal = do
          isLocal <- fullyInsideLocal lhs
          -- under this guard the old mark is absent or True, so the new mark
          -- is exactly isLocal
          let marked = IntMap.insert (varId lhs) isLocal localMarks
          pure $ if isLocal then marked else onFalseLocal marked
      | otherwise = pure $ onFalseLocal localMarks
      where
        -- variables that nobody has marked yet are treated as local
        isParentLocal :: Bool
        isParentLocal = fromMaybe True $ IntMap.lookup (varId lhs) localMarks

        -- mark every variable on the right-hand side as non-local
        onFalseLocal :: LocalMarks -> LocalMarks
        onFalseLocal =
          execState (mapM_ (\v -> modify' $ IntMap.insert (varId v) False) rhs)

    -- all usages of the variable happen inside the traversed sub-graph
    fullyInsideLocal :: Var -> Collect s Bool
    fullyInsideLocal lhs = do
      globalCount <- readGlobalUsages (varId lhs)
      let localCount = IntMap.lookup (varId lhs) localUsages
      pure $ Just globalCount == localCount

    toSet :: LocalMarks -> LocalVars
    toSet = IntMap.keysSet . IntMap.filter id

-- | First pass of the algorithm described in the module header: counts how
-- many times each variable is used on the right-hand side of the statements
-- reachable from @root@.
--
-- The traversal uses 'traverseAccumDag' with @isEnd ifRate@ as the end
-- predicate. A variable @root@ starts with a count of 1, presumably for its
-- use by the if-statement itself.
getLocalUsage :: forall s . IfRate -> PrimOr Var -> Collect s LocalUsageCounts
getLocalUsage ifRate root =
  traverseAccumDag update initCount (isEnd ifRate) root
  where
    initCount :: LocalUsageCounts
    initCount = either (const IntMap.empty) (\var -> IntMap.singleton (varId var) 1) (unPrimOr root)

    update :: Expr -> LocalUsageCounts -> Collect s LocalUsageCounts
    update (Stmt _lhs rhs) st = pure $ execState (mapM_ count rhs) st

    -- increment a variable's usage count (first usage yields 1)
    count var = modify' $ IntMap.alter (Just . maybe 1 succ) (varId var)

---------------------------------------------------------------------------

-- | Stop predicate for the dependency walk over an if-block: presumably
-- 'True' marks a statement at which the walk ends, i.e. one that must not
-- be moved inside the block (confirm against 'traverseAccumDag').
--
-- * A statement whose right-hand side is an 'InitVar' or 'InitArr' always
--   yields 'True', regardless of the block rate.
--
-- * Inside an 'IfIr' block no other statement yields 'True'.
--
-- * Inside an 'IfKr' block the answer is 'readIsInit' of the defined
--   variable. This encodes the rule that Ir-expressions can not be brought
--   inside a Kr if-block (presumably 'readIsInit' flags init-rate
--   variables; verify).
isEnd :: IfRate -> Expr -> Collect s Bool
isEnd _ (Stmt _ rhsExpr)
  | isInitVar rhsExpr = pure True
isEnd IfIr _ = pure False
isEnd IfKr (Stmt lhsVar _) = readIsInit (varId lhsVar)

-- | 'True' when the expression is a variable or array initialisation
-- ('InitVar' or 'InitArr'); 'False' for every other expression.
isInitVar :: RatedExp Var -> Bool
isInitVar = isInitExp . ratedExpExp
  where
    isInitExp InitVar{} = True
    isInitExp InitArr{} = True
    isInitExp _         = False

-- | 'True' unless 'getExprType' classifies the expression as 'PlainType',
-- i.e. for if-expressions and for if/if-else/while/until blocks.
isIfExpr :: RatedExp Var -> Bool
isIfExpr rhs = notPlain (getExprType rhs)
  where
    notPlain PlainType = False
    notPlain _         = True

-- | Classify an expression by its control-flow shape.
--
-- * 'If' (an if-expression) becomes 'IfExpType'.
--
-- * 'IfBlock', 'WhileBlock' and 'UntilBlock' become 'IfType'. Each carries
--   its own opening/closing markers ('IfBegin'/'IfEnd',
--   'WhileBegin'/'WhileEnd', 'UntilBegin'/'UntilEnd').
--
-- * 'IfElseBlock' becomes 'IfElseType' with 'IfBegin', 'ElseBegin' and
--   'IfEnd' as markers.
--
-- * Every other expression is 'PlainType'.
getExprType :: RatedExp Var -> ExprType (PrimOr Var)
getExprType expr =
  case ratedExpExp expr of
    If rate c th el ->
      IfExpType rate c th el
    IfBlock rate c (CodeBlock th) ->
      singleBranch rate c th IfBegin IfEnd
    IfElseBlock rate c (CodeBlock th) (CodeBlock el) ->
      IfElseType rate c th el
        (IfElseCons
          { ifElseBegin = IfBegin
          , elseBegin   = ElseBegin
          , ifElseEnd   = IfEnd
          })
    WhileBlock rate c (CodeBlock th) ->
      singleBranch rate c th WhileBegin WhileEnd
    UntilBlock rate c (CodeBlock th) ->
      singleBranch rate c th UntilBegin UntilEnd
    -- TODO: "While Ref" case is not handled yet; such expressions fall
    -- through to 'PlainType'.
    _ -> PlainType
  where
    -- Single-branch blocks (if / while / until) differ only in the
    -- opening and closing markers stored in 'IfCons'.
    singleBranch rate c th begin end =
      IfType rate c th (IfCons { ifBegin = begin, ifEnd = end })