module Csound.Dynamic.Tfm.Liveness (
    liveness
) where

import Prelude hiding (mapM, mapM_)

import Control.Monad.Trans.State.Strict
import Data.Traversable
import Data.Foldable
import qualified Data.Map as M

import Control.Monad.Trans.Class
import Control.Monad hiding (mapM, mapM_)
import Control.Monad.ST
import qualified Data.Array.Unboxed as A
import qualified Data.Array.MArray as A
import qualified Data.Array.ST as A

import qualified Csound.Dynamic.Tfm.DeduceTypes as D
import Csound.Dynamic.Tfm.DeduceTypes(varType, varId)
import Csound.Dynamic.Types.Exp(Rate(..))

-- | Renames the variables of a 'Dag' so that ids are reused once a variable
-- is no longer read.
--
-- @lastFreshId@ sizes the liveness and substitution tables; the variable ids
-- occurring in the input are expected to lie in @[0 .. lastFreshId]@.
--
-- Variables of rate 'Ir' or 'Sr' keep their original ids. Every other output
-- variable receives the smallest id currently free for its rate, and that id
-- is returned to the pool after the last line on which the variable is read.
liveness :: Traversable f => Int -> Dag f -> Dag f
liveness lastFreshId as = runST $ do
  st <- initSt lastFreshId (analyse lastFreshId as)
  evalStateT (mapM substExp (countLines as)) st

-- | Zero-based position of an expression within a 'Dag', as assigned by
-- 'countLines'.
type LineNumber = Int

-- | Pairs every element of a list with its zero-based position.
countLines :: [a] -> [(LineNumber, a)]
countLines = zip [0 ..]

-- | A variable tagged with its 'Rate'. The rate is read with 'varType' and
-- the numeric id with 'varId'.
type Var  = D.Var Rate

-- | Output variables of one line; they are renamed by 'substLhs'.
type Lhs   = [Var]
-- | Input variables of one line, held in the container @f@ ('liveness'
-- requires it to be 'Traversable'); they are renamed by 'substRhs'.
type Rhs f = f Var
-- | One line of the program: its outputs paired with its inputs.
type Exp f = (Lhs, Rhs f)

-- | The program as a list of lines, processed in list order ('countLines'
-- numbers them from 0).
type Dag f = [Exp f]

-----------------------------------------------

-- | Pool of ids for one rate.
--
-- The first field holds the ids that may still be handed out. 'initRegs'
-- starts it as @[0 ..]@ and 'insertSorted' keeps it in ascending order. The
-- second field records the biggest id handed out so far.
--
-- The second field is strict. Its value only ever feeds the next 'IdList'
-- built by 'allocId' / 'freeId', so nothing forces it. As a lazy field it
-- would accumulate an unbounded chain of @max@ / @if@ thunks, each one
-- retaining an older free list.
data IdList = IdList
    [Int] -- fresh ids
    !Int  -- the biggest used id

-- | Takes the first (smallest) free id out of the pool and records it as the
-- biggest used id if it exceeds the current record.
--
-- The pool starts as the infinite list @[0 ..]@ and 'freeId' only inserts
-- into it, so the empty case should be unreachable. It is reported with an
-- explicit message instead of failing inside 'head'.
allocId :: IdList -> (Int, IdList)
allocId (IdList (i : is) lastId) = (i, IdList is (max i lastId))
allocId (IdList [] _) =
  error "Csound.Dynamic.Tfm.Liveness.allocId: the pool of free ids is empty"

-- | Returns id @n@ to the pool. The pool stays sorted and does not gain a
-- duplicate if @n@ is already free (see 'insertSorted').
--
-- When @n@ equals the recorded biggest used id, the record is decremented by
-- one.
-- NOTE(review): after the decrement the record may name an id that is itself
-- free. The record is never read for allocation ('allocId' always takes the
-- head of the list), so this does not affect which ids are handed out.
freeId :: Int -> IdList -> IdList
freeId n (IdList is lastId) = IdList (insertSorted n is) lastId'
  where
    lastId'
      | n == lastId = lastId - 1
      | otherwise   = lastId

-- | Inserts @n@ into an ascending list, leaving the list unchanged if @n@ is
-- already present.
--
-- Only the prefix up to the insertion point is traversed, so this also works
-- on an infinite ascending list that contains some element @>= n@.
insertSorted :: Int -> [Int] -> [Int]
insertSorted n [] = [n]
insertSorted n xs@(x : rest) =
  case compare n x of
    LT -> n : xs
    EQ -> xs
    GT -> x : insertSorted n rest

-----------------------------------------------

-- | Mutable unboxed array from 'Int' to 'Int' living in 'ST'.
type StArr s = A.STUArray s Int Int

-- | For each original variable id, the last line number on which the variable
-- is read. The value is 0 when it is never read. Built by 'analyse'.
type LivenessTable = A.UArray Int Int
-- | Maps each original variable id to the id it has been renamed to. It starts
-- as the identity ('initSubstTable').
type SubstTable s  = StArr s

-- | State threaded through the renaming pass.
data Registers s = Registers
  { registers     :: M.Map Rate IdList
    -- ^ pool of free ids for each rate
  , livenessTable :: LivenessTable
    -- ^ last line on which each original id is read (see 'analyse')
  , substTable    :: SubstTable s
    -- ^ original id -> renamed id
  }

-- | Monad of the renaming pass: 'Registers' threaded as state on top of 'ST'.
type Memory s a = StateT (Registers s) (ST s) a

-- | Applies a function to the per-rate id pools and leaves the other fields
-- of the state untouched.
onRegs :: (M.Map Rate IdList -> M.Map Rate IdList) -> (Registers s -> Registers s)
onRegs f rs = rs { registers = f (registers rs) }

-- | Initial id pools. Every rate from 'minBound' to 'maxBound' starts with all
-- ids @[0 ..]@ free and a recorded biggest used id of 0.
initRegs :: M.Map Rate IdList
initRegs = M.fromList [ (rate, IdList [0 ..] 0) | rate <- [minBound .. maxBound :: Rate] ]

-- | @isAlive lineNum v@ holds when the original variable @v@ is still read on
-- some line after @lineNum@, according to the liveness table.
isAlive :: LineNumber -> Var -> Memory s Bool
isAlive lineNum v = do
  tab <- gets livenessTable
  return (lineNum < tab A.! varId v)

-- | Reads the id that original id @i@ is currently renamed to. This is @i@
-- itself until 'saveSubst' records something else.
lookUpSubst :: Int -> Memory s Int
lookUpSubst i = do
  tab <- gets substTable
  lift (A.readArray tab i)

-- | Records that original id @from@ is renamed to @to@.
saveSubst :: Int -> Int -> Memory s ()
saveSubst from to = do
  tab <- gets substTable
  lift (A.writeArray tab from to)

-- | Renames an output variable and records the substitution, so that later
-- reads of the original variable are renamed consistently.
--
-- A new id is allocated unless the rate is 'Ir' or 'Sr'; in those cases the
-- variable is kept as-is (see 'allocAndSkipInits').
substLhs :: Var -> Memory s Var
substLhs v = do
  v' <- allocAndSkipInits v
  saveSubst (varId v) (varId v')
  return v'

-- | Renames an input variable read on line @lineNum@.
--
-- If the original variable is not read on any later line, its new id is
-- returned to the pool of its rate, so later outputs can reuse it. Freeing
-- the same id twice (a variable read twice on one line) is harmless, because
-- 'insertSorted' ignores duplicates.
substRhs :: LineNumber -> Var -> Memory s Var
substRhs lineNum v = do
  newId <- lookUpSubst (varId v)
  let v' = D.Var newId (varType v)
  alive <- isAlive lineNum v
  unless alive (free v')
  return v'

-- | Returns a variable of rate 'Ir' or 'Sr' unchanged. Any other variable is
-- replaced by a freshly allocated variable of the same rate.
allocAndSkipInits :: Var -> Memory s Var
allocAndSkipInits v
  | rate == Ir || rate == Sr = return v
  | otherwise                = alloc rate
  where
    rate = varType v

-- | Allocates a variable of the given rate, using the first (smallest) free id
-- in that rate's pool.
alloc :: Rate -> Memory s Var
alloc rate = state $ \mem ->
  let (i, pool) = allocId (registers mem M.! rate)
      mem'      = onRegs (M.update (const (Just pool)) rate) mem
  in  (D.Var i rate, mem')

-- | Returns the id of a variable to the pool of the variable's rate.
free :: Var -> Memory s ()
free v = modify (onRegs (M.update (Just . freeId (varId v)) (varType v)))

--------------------------------------------------------------------------

-- | Computes, for every variable id in @[0 .. lastFreshId]@, the last line
-- number on which that variable appears among a line's inputs.
--
-- Ids that are never read keep the initial value 0. Only inputs are counted;
-- the outputs of a line are ignored here.
analyse :: Traversable f => Int -> Dag f -> LivenessTable
analyse lastFreshId as = A.runSTUArray $ do
  arr <- A.newArray (0, lastFreshId) 0
  mapM_ (go arr) (countLines as)
  return arr
  where
    -- Records every input variable of one line.
    go :: Traversable f => StArr s -> (LineNumber, Exp f) -> ST s ()
    go arr (lineNum, (_, rhs)) = traverse_ (countVar arr lineNum) rhs

    -- Keeps the maximum line number seen for the variable's id.
    countVar :: StArr s -> LineNumber -> Var -> ST s ()
    countVar arr lineNum v = do
      let i = varId v
      val <- A.readArray arr i
      A.writeArray arr i (max val lineNum)

-- | Renames one line: first its outputs, then its inputs.
--
-- Outputs are allocated before the inputs that die on this line are freed.
-- The ids released by this line's inputs therefore only become available to
-- later lines.
--
-- NOTE(review): ids are only released by 'substRhs', so an output variable
-- that is never read is never freed; its id stays allocated for the rest of
-- the pass.
substExp :: Traversable f => (LineNumber, Exp f) -> Memory s (Exp f)
substExp (lineNum, (lhs, rhs)) = do
  lhs' <- traverse substLhs lhs
  rhs' <- traverse (substRhs lineNum) rhs
  return (lhs', rhs')

-- | Builds the initial state of the renaming pass. Every rate gets a fresh id
-- pool, the given liveness table is stored, and the substitution table starts
-- as the identity.
initSt :: Int -> LivenessTable -> ST s (Registers s)
initSt lastFreshId livenessTab =
  fmap (Registers initRegs livenessTab) (initSubstTable lastFreshId)

-- | Creates a substitution table indexed by @0 .. n + 1@ in which every id is
-- mapped to itself.
initSubstTable :: Int -> ST s (SubstTable s)
initSubstTable n = A.newListArray (0, n + 1) [0 .. n + 1]