module Csound.Dynamic.Tfm.Liveness ( liveness ) where import Prelude hiding (mapM, mapM_) import Control.Monad.Trans.State.Strict import Data.Traversable import Data.Foldable import qualified Data.Map as M import Control.Monad.Trans.Class import Control.Monad hiding (mapM, mapM_) import Control.Monad.ST import qualified Data.Array.Unboxed as A import qualified Data.Array.MArray as A import qualified Data.Array.ST as A import qualified Csound.Dynamic.Tfm.DeduceTypes as D import Csound.Dynamic.Tfm.DeduceTypes(varType, varId) import Csound.Dynamic.Types.Exp(Rate(..)) liveness :: Traversable f => Int -> Dag f -> Dag f liveness lastFreshId as = runST $ do st <- initSt lastFreshId $ analyse lastFreshId as evalStateT (mapM substExp $ countLines $ as) st type LineNumber = Int countLines :: [a] -> [(LineNumber, a)] countLines = zip [0 ..] type Var = D.Var Rate type Lhs = [Var] type Rhs f = f Var type Exp f = (Lhs, Rhs f) type Dag f = [Exp f] ----------------------------------------------- data IdList = IdList [Int] -- fresh ids Int -- the biggest used id allocId :: IdList -> (Int, IdList) allocId (IdList is lastId) = (head is, IdList (tail is) (max (head is) lastId)) freeId :: Int -> IdList -> IdList freeId n (IdList is lastId) = IdList (insertSorted n is) lastId1 where lastId1 = if (n == lastId) then (lastId - 1) else lastId insertSorted :: Int -> [Int] -> [Int] insertSorted n (a:as) | n < a = n : a : as | n == a = a : as | otherwise = a : insertSorted n as insertSorted n [] = [n] ----------------------------------------------- type StArr s = A.STUArray s Int Int type LivenessTable = A.UArray Int Int type SubstTable s = StArr s data Registers s = Registers { registers :: M.Map Rate IdList , livenessTable :: LivenessTable , substTable :: SubstTable s } type Memory s a = StateT (Registers s) (ST s) a onRegs :: (M.Map Rate IdList -> M.Map Rate IdList) -> (Registers s -> Registers s) onRegs f rs = rs { registers = f $ registers rs } initRegs :: M.Map Rate IdList initRegs = M.fromList $ fmap (\x -> (x, initIdList)) [(minBound :: Rate) .. maxBound] where initIdList = IdList [0..] 0 isAlive :: LineNumber -> Var -> Memory s Bool isAlive lineNum v = do tab <- fmap livenessTable get return $ lineNum < tab A.! (varId v) lookUpSubst :: Int -> Memory s Int lookUpSubst i = do tab <- fmap substTable get lift $ A.readArray tab i saveSubst :: Int -> Int -> Memory s () saveSubst from to = do tab <- fmap substTable get lift $ A.writeArray tab from to substLhs :: Var -> Memory s Var substLhs v = do v1 <- allocAndSkipInits v saveSubst (varId v) (varId v1) return v1 substRhs :: LineNumber -> Var -> Memory s Var substRhs lineNum v = do i1 <- lookUpSubst (varId v) let v1 = D.Var i1 (varType v) b <- isAlive lineNum v unless b $ free v1 return v1 allocAndSkipInits :: Var -> Memory s Var allocAndSkipInits v | isInit r = return v | otherwise = alloc r where r = varType v isInit x = x == Ir || x == Sr alloc :: Rate -> Memory s Var alloc rate = state $ \mem -> let (i, mem1) = allocRegister rate mem in (D.Var i rate, mem1) where allocRegister :: Rate -> Registers s -> (Int, Registers s) allocRegister r mem = (i, onRegs (M.update (const $ Just is) r) mem) where (i, is) = allocId $ registers mem M.! r free :: Var -> Memory s () free v = state $ \mem -> let mem1 = freeRegister (varType v) (varId v) mem in ((), mem1) where freeRegister :: Rate -> Int -> Registers s -> Registers s freeRegister rate i = onRegs $ M.update (Just . freeId i) rate -------------------------------------------------------------------------- analyse :: Traversable f => Int -> Dag f -> LivenessTable analyse lastFreshId as = A.runSTUArray $ do arr <- A.newArray (0, lastFreshId) 0 mapM_ (go arr) $ countLines as return arr where go :: Traversable f => StArr s -> (LineNumber, Exp f) -> ST s () go arr (lineNum, (_, rhs)) = mapM (countVar arr lineNum) rhs >> return () countVar :: StArr s -> LineNumber -> Var -> ST s () countVar arr lineNum v = do val <- A.readArray arr i A.writeArray arr i (val `max` lineNum) where i = varId v substExp :: Traversable f => (LineNumber, Exp f) -> Memory s (Exp f) substExp (lineNum, (lhs, rhs)) = do freshLhs <- traverse substLhs lhs freshRhs <- traverse (substRhs lineNum) rhs return (freshLhs, freshRhs) initSt :: Int -> LivenessTable -> ST s (Registers s) initSt lastFreshId livenessTab = fmap (Registers initRegs livenessTab) (initSubstTable lastFreshId) initSubstTable :: Int -> ST s (SubstTable s) initSubstTable n = A.newListArray (0, n+1) [0 .. n + 1]