-- | In this module we substitute temporary variables with graph variables
-- (references to DAG nodes).
--
-- After this stage no TmpVars should be left in the DAG.
-- Typical problem: for some reason a TmpVar is not inlined in the next
-- expression down the flow.
--
-- Assumption: every TmpVar is inlined in the next expression.
module Csound.Dynamic.Tfm.TmpVars
  ( removeTmpVars
  ) where

import Control.Monad.Trans.State.Strict
import Data.IntMap.Strict (IntMap)
import Data.IntMap.Strict qualified as IntMap
import Data.Maybe
import Control.Monad
-- import Debug.Trace

import Csound.Dynamic.Types.Exp
  ( RatedExp (..),
    TmpVar (..),
    MainExp (..),
    PrimOr (..),
    Prim (..),
    Rate (..),
    getTmpVars,
    IfRate (..),
    Info (..),
    TmpVarRate (..),
    getSingleTmpRate,
  )

-- | A DAG node: the node id paired with its expression, whose argument
-- slots are 'Int' ids.
type Node f = (Int, f Int)

-- | A DAG given as a list of nodes.
type Dag f = [Node f]

-- | State monad in which temporary variables are removed.
type RemoveTmp = State St

-- | State threaded through the TmpVar-removal pass.
--
-- Both fields are strict. The updates are done with 'modify'', which only
-- evaluates the record to WHNF. With lazy fields every update would leave an
-- unevaluated 'IntMap.insert' thunk chained inside the field until the first
-- lookup forces it. This is especially bad for 'stRates', which is filled by
-- a whole pass over the DAG before anything is read back.
data St = St
  { stIds :: !(IntMap Int)
    -- ^ ids of tmp vars LHS in equations: maps a TmpVar id to the id of the
    -- node that defines it
  , stRates :: !(IntMap (Maybe TmpVarRate, Maybe Info))
    -- ^ rates (and info) requested for TmpVar's, keyed by TmpVar id
  }
  deriving (Show)

-- | Replace every temporary variable in the DAG with a reference to the
-- node that defines it.
--
-- The node list is traversed twice, threading a single 'St':
--
--   1. Every 'TmpVar' returned by 'getTmpVars' for any node's expression is
--      registered with its rate/info annotation ('saveTmpVarRate').
--
--   2. Nodes are visited in list order. A 'ReadVarTmp', 'ReadArrTmp' or
--      'TfmInit' node is rewritten into 'ReadVar', 'ReadArr' or 'Tfm', and
--      the node's own id is recorded as the binding of the 'TmpVar' it
--      carries. Then every 'PrimTmpVar' argument of the node is replaced by
--      the id recorded for that 'TmpVar'.
--
-- An argument that refers to a 'TmpVar' which was not bound by an earlier
-- (or the same) node results in the 'error' produced by 'lookupTmpVar'.
removeTmpVars :: Dag RatedExp -> Dag RatedExp
removeTmpVars nodes = evalState rewriteAll (St IntMap.empty IntMap.empty)
  where
    rewriteAll :: RemoveTmp (Dag RatedExp)
    rewriteAll = do
      -- first pass: remember the annotations of all referenced TmpVars
      forM_ nodes $ \(_, rexp) ->
        forM_ (getTmpVars (ratedExpExp rexp)) saveTmpVarRate
      -- second pass: bind TmpVars to their defining nodes, then substitute uses
      forM nodes $ \node -> saveTmpVar node >>= substArgs

    -- 'IfIr' always yields 'Ir'; any other 'IfRate' keeps the rate found
    -- for the TmpVar.
    requestRate :: IfRate -> Maybe Rate -> Maybe Rate
    requestRate IfIr _ = Just Ir
    requestRate _ found = found

    -- Rewrite a node that carries a TmpVar into its plain form and record the
    -- node id as that TmpVar's binding. Other nodes pass through unchanged.
    saveTmpVar :: (Int, RatedExp Int) -> RemoveTmp (Int, RatedExp Int)
    saveTmpVar node@(resId, rexp) =
      case ratedExpExp rexp of
        ReadVarTmp ifRate tmp v -> do
          found <- lookupRate tmp
          bindNode tmp (ReadVar ifRate v) (requestRate ifRate found)
        ReadArrTmp ifRate tmp v ix -> do
          found <- lookupRate tmp
          bindNode tmp (ReadArr ifRate v ix) (requestRate ifRate found)
        TfmInit tmp info args -> do
          tmpRate <- lookupTmpRate tmp
          -- only a single-rate annotation fixes the rate of the Tfm node
          bindNode tmp (Tfm info args) $ case tmpRate of
            Just (SingleTmpRate r) -> Just r
            Just (MultiTmpRate _) -> Nothing
            Nothing -> Nothing
        _ -> pure node
      where
        -- bind the TmpVar to this node, then swap in the new expression/rate
        bindNode tmp newExp newRate = do
          insertTmpVar tmp resId
          pure (resId, rexp { ratedExpExp = newExp, ratedExpRate = newRate })

    -- Replace every TmpVar argument of the node by the node id bound to it.
    substArgs :: (Int, RatedExp Int) -> RemoveTmp (Int, RatedExp Int)
    substArgs (resId, rexp) =
      withExp <$> traverse (substTmp resId) (ratedExpExp rexp)
      where
        withExp e = (resId, rexp { ratedExpExp = e })

    -- A 'PrimTmpVar' becomes a reference to its defining node; any other
    -- argument is kept as is.
    substTmp :: Int -> PrimOr Int -> RemoveTmp (PrimOr Int)
    substTmp resId (PrimOr arg) =
      PrimOr <$> case arg of
        Left (PrimTmpVar tmp) -> Right <$> lookupTmpVar resId tmp
        other -> pure other

-- | Record that the TmpVar is defined by the node with id @resId@.
-- A later call for the same TmpVar id overwrites the earlier binding.
insertTmpVar :: TmpVar -> Int -> RemoveTmp ()
insertTmpVar (TmpVar _ _ tmpId) resId = modify' bind
  where
    bind st = st { stIds = IntMap.insert tmpId resId (stIds st) }

-- | Id of the node bound to the TmpVar. @resId@ is the node being processed
-- and is used only in the error message.
--
-- The result is lazy: if the TmpVar is unbound, the returned value is an
-- 'error' thunk that fails when the id is demanded.
lookupTmpVar :: Int -> TmpVar -> RemoveTmp Int
lookupTmpVar resId (TmpVar _ _ tmpId) = do
  ids <- gets stIds
  pure (fromMaybe notFound (IntMap.lookup tmpId ids))
  where
    notFound = error $ "TmpVar not found: " <> show tmpId <> " on result id: " <> show resId

-- | Remember the rate and info annotations carried by the TmpVar, keyed by
-- its id. A later call for the same id overwrites the earlier entry.
saveTmpVarRate :: TmpVar -> RemoveTmp ()
saveTmpVarRate (TmpVar mRate mInfo tmpId) = modify' remember
  where
    remember st = st { stRates = IntMap.insert tmpId (mRate, mInfo) (stRates st) }

-- | Single rate recorded for the TmpVar. This is 'Nothing' when no
-- annotation was saved, or when 'getSingleTmpRate' gives 'Nothing' for it.
lookupRate :: TmpVar -> RemoveTmp (Maybe Rate)
lookupRate tmp = do
  mTmpRate <- lookupTmpRate tmp
  pure (mTmpRate >>= getSingleTmpRate)

-- | Rate annotation saved for the TmpVar's id. This is 'Nothing' when the id
-- was never saved, or when it was saved without a rate.
lookupTmpRate :: TmpVar -> RemoveTmp (Maybe TmpVarRate)
lookupTmpRate (TmpVar _ _ tmpId) = gets $ \st ->
  case IntMap.lookup tmpId (stRates st) of
    Just (mTmpRate, _) -> mTmpRate
    Nothing -> Nothing