module Csound.Dynamic.Tfm.TmpVars
( removeTmpVars
) where
import Control.Monad.Trans.State.Strict
import Data.IntMap.Strict (IntMap)
import Data.IntMap.Strict qualified as IntMap
import Data.Maybe
import Control.Monad
import Csound.Dynamic.Types.Exp
( RatedExp (..),
TmpVar (..),
MainExp (..),
PrimOr (..),
Prim (..),
Rate (..),
getTmpVars,
IfRate (..),
Info (..),
TmpVarRate (..),
getSingleTmpRate,
)
-- | A DAG node: the node's own result id paired with its expression, whose
-- arguments reference other nodes by their 'Int' id.
type Node f = (Int, f Int)

-- | A DAG represented as a list of nodes.
type Dag f = [Node f]

-- | State monad used while removing temporary variables.
type RemoveTmp a = State St a

-- | Bookkeeping state for 'removeTmpVars'.
--
-- Both fields are strict. The state is only ever updated with 'modify'',
-- which evaluates the 'St' constructor but not its fields, so lazy fields
-- would accumulate a chain of unevaluated 'IntMap.insert' thunks over a
-- whole pass.
data St = St
  { stIds :: !(IntMap Int)
  -- ^ temporary-variable id -> id of the node that now provides its value
  , stRates :: !(IntMap (Maybe TmpVarRate, Maybe Info))
  -- ^ temporary-variable id -> rate and info recorded for that variable
  }
  deriving (Show)
-- | Eliminate temporary variables from a DAG.
--
-- Works in two passes over the node list:
--
-- 1. For every node, each 'TmpVar' reported by 'getTmpVars' has its rate and
--    info recorded.
-- 2. Nodes are rewritten in list order. A node that introduces a temporary
--    ('ReadVarTmp', 'ReadArrTmp', 'TfmInit') becomes its plain counterpart
--    and the temporary is bound to that node's id. Then every 'PrimTmpVar'
--    argument of the node is replaced by the id its temporary is bound to.
--
-- If a node uses a temporary that neither an earlier node nor the node itself
-- introduced, forcing the corresponding argument raises an 'error'.
removeTmpVars :: Dag RatedExp -> Dag RatedExp
removeTmpVars dag = evalState rewriteAll emptySt
  where
    emptySt = St {stIds = IntMap.empty, stRates = IntMap.empty}

    rewriteAll = do
      -- pass 1: collect rate/info of every temporary mentioned in the DAG
      forM_ dag $ \node ->
        forM_ (getTmpVars (ratedExpExp (snd node))) saveTmpVarRate
      -- pass 2: bind each node's own temporary first, then resolve its args
      forM dag (saveTmpVar >=> substArgs)
-- | Rate for a plain read of a temporary variable: an 'IfIr' read is always
-- 'Ir'; any other read keeps the rate looked up for the temporary.
requestRate :: IfRate -> Maybe Rate -> Maybe Rate
requestRate IfIr _ = Just Ir
requestRate _ lookedUp = lookedUp
-- | Rewrite a node that introduces a temporary variable into its plain form,
-- and bind that temporary to the node's id.
--
-- * 'ReadVarTmp' becomes 'ReadVar', and 'ReadArrTmp' becomes 'ReadArr'. The
--   rate is chosen by 'requestRate' from the temporary's recorded rate.
-- * 'TfmInit' becomes 'Tfm'. The rate is the recorded rate when it is a
--   'SingleTmpRate', and 'Nothing' for a 'MultiTmpRate' or no record.
--
-- Any other node is returned unchanged and nothing is bound.
saveTmpVar :: (Int, RatedExp Int) -> RemoveTmp (Int, RatedExp Int)
saveTmpVar node@(resId, expr) =
  case ratedExpExp expr of
    ReadVarTmp ifRate tmp v -> do
      mRate <- lookupRate tmp
      bindTo tmp (ReadVar ifRate v) (requestRate ifRate mRate)
    ReadArrTmp ifRate tmp v index -> do
      mRate <- lookupRate tmp
      bindTo tmp (ReadArr ifRate v index) (requestRate ifRate mRate)
    TfmInit tmp info args -> do
      mTmpRate <- lookupTmpRate tmp
      bindTo tmp (Tfm info args) (singleRate =<< mTmpRate)
    _ -> pure node
  where
    -- Record that @tmp@ is now provided by this node (after the rate lookup
    -- above), then install the rewritten expression and its rate.
    bindTo tmp newExp newRate = do
      insertTmpVar tmp resId
      pure (resId, expr {ratedExpExp = newExp, ratedExpRate = newRate})

    -- Only a single recorded rate is propagated to a 'Tfm' node; a
    -- multi-rate temporary leaves the node's rate unset.
    singleRate (SingleTmpRate rate) = Just rate
    singleRate (MultiTmpRate _) = Nothing
-- | Replace every 'PrimTmpVar' argument of a node with a reference to the
-- node id its temporary is bound to. The node's own id is kept and is used
-- in the error message if a temporary is unbound.
substArgs :: (Int, RatedExp Int) -> RemoveTmp (Int, RatedExp Int)
substArgs (resId, expr) =
  fmap
    (\resolved -> (resId, expr {ratedExpExp = resolved}))
    (traverse (substTmp resId) (ratedExpExp expr))
-- | Resolve a single argument. A 'PrimTmpVar' becomes a 'Right' reference to
-- the node id its temporary is bound to; every other argument (node
-- references and other primitives) is returned unchanged.
substTmp :: Int -> PrimOr Int -> RemoveTmp (PrimOr Int)
substTmp resId (PrimOr arg) =
  PrimOr <$> case arg of
    Left (PrimTmpVar tmp) -> Right <$> lookupTmpVar resId tmp
    untouched -> pure untouched
-- | Bind a temporary variable (keyed by its integer id) to the node id
-- @resId@ that now provides its value. Any previous binding for the same
-- variable is overwritten.
insertTmpVar :: TmpVar -> Int -> RemoveTmp ()
insertTmpVar (TmpVar _ _ v) resId =
  modify' $ \st ->
    -- modify' evaluates only the St constructor; force the new map as well so
    -- repeated inserts don't pile up as a chain of thunks inside the field.
    let ids = IntMap.insert v resId (stIds st)
     in ids `seq` st {stIds = ids}
-- | Look up the node id a temporary variable is bound to.
--
-- The result is returned lazily. If the variable was never bound, forcing it
-- raises an 'error' naming the variable and the node @resId@ that referenced
-- it.
lookupTmpVar :: Int -> TmpVar -> RemoveTmp Int
lookupTmpVar resId (TmpVar _ _ n) =
  gets $ \st -> fromMaybe notFound (IntMap.lookup n (stIds st))
  where
    notFound :: Int
    notFound =
      error ("TmpVar not found: " <> show n <> " on result id: " <> show resId)
-- | Record the rate and info carried by a temporary variable, keyed by its
-- integer id. A later record for the same id replaces an earlier one.
saveTmpVarRate :: TmpVar -> RemoveTmp ()
saveTmpVarRate (TmpVar mRate mInfo n) =
  modify' $ \st ->
    -- This runs for every temporary in the DAG before anything reads
    -- stRates. modify' does not evaluate the field, so force the map here to
    -- keep the inserts from forming one long thunk chain.
    let rates = IntMap.insert n (mRate, mInfo) (stRates st)
     in rates `seq` st {stRates = rates}
-- | Single rate for a temporary variable: the recorded 'TmpVarRate' passed
-- through 'getSingleTmpRate'. Returns 'Nothing' when no rate was recorded.
lookupRate :: TmpVar -> RemoveTmp (Maybe Rate)
lookupRate tmp = do
  recorded <- lookupTmpRate tmp
  pure (recorded >>= getSingleTmpRate)
-- | Rate recorded for a temporary variable. Returns 'Nothing' both when the
-- variable has no entry and when its entry carries no rate.
lookupTmpRate :: TmpVar -> RemoveTmp (Maybe TmpVarRate)
lookupTmpRate (TmpVar _ _ n) =
  gets $ \st ->
    case IntMap.lookup n (stRates st) of
      Just (recordedRate, _info) -> recordedRate
      Nothing -> Nothing