module Csound.Dynamic.Tfm.Liveness (
    liveness
) where

import Prelude hiding (mapM, mapM_)

import Control.Monad.Trans.State.Strict
import Data.Traversable
import Data.Foldable
import qualified Data.Map as M

import Control.Monad.Trans.Class
import Control.Monad hiding (mapM, mapM_)
import Control.Monad.ST
import qualified Data.Array.Unboxed as A
import qualified Data.Array.MArray as A
import qualified Data.Array.ST as A

import qualified Csound.Dynamic.Tfm.DeduceTypes as D
import Csound.Dynamic.Tfm.DeduceTypes(varType, varId)
import Csound.Dynamic.Types.Exp(Rate(..))

-- | Renames variables so that registers can be reused once a variable has
-- been read for the last time. Variables of rate 'Ir' and 'Sr' keep their
-- original ids; all other rates are reassigned ids drawn from a per-rate
-- free list.
--
-- The per-variable tables are sized from @lastFreshId@: the liveness table
-- covers @[0 .. lastFreshId]@ and the substitution table covers
-- @[0 .. lastFreshId + 1]@. NOTE(review): presumably @lastFreshId@ bounds the
-- variable ids in the input; confirm with the caller.
liveness :: Traversable f => Int -> Dag f -> Dag f
liveness lastFreshId as = runST $
  initSt lastFreshId (analyse lastFreshId as)
    >>= evalStateT (mapM substExp (countLines as))

-- | Zero-based position of a statement in a 'Dag', as assigned by
-- 'countLines'. The liveness pass uses it as program time.
type LineNumber = Int

-- | Tags every element with its zero-based position in the list.
countLines :: [a] -> [(LineNumber, a)]
countLines = zipWith (,) [0 ..]

-- | A variable tagged with its Csound 'Rate'.
type Var  = D.Var Rate

-- | Variables written by a statement.
type Lhs   = [Var]
-- | A statement's expression. Every 'Var' it contains counts as a read
-- (see 'analyse' and 'substRhs').
type Rhs f = f Var
-- | One statement: its outputs paired with the expression that produces them.
type Exp f = (Lhs, Rhs f)

-- | A list of statements. Liveness treats list order as program order.
type Dag f = [Exp f]

-----------------------------------------------

-- | The pool of free register ids for a single rate.
data IdList = IdList
    [Int] -- ^ free ids in ascending order ('insertSorted' keeps it sorted; 'initRegs' starts it as @[0 ..]@)
    Int   -- ^ the biggest used id. This is bookkeeping only: allocation never reads it, and 'freeId' only decrements it by one, so it is approximate.

-- | Takes the first free id out of the pool. Because the free list is kept
-- sorted, this is the smallest free id. Also returns the updated pool, whose
-- recorded maximum now covers the taken id.
--
-- The pool starts as the infinite list @[0 ..]@ and 'insertSorted' never
-- shortens it, so the empty case is unreachable. It is matched explicitly,
-- rather than relying on the partial 'head'/'tail', so that a violation of
-- this invariant fails with a meaningful message.
allocId :: IdList -> (Int, IdList)
allocId (IdList (i : is) lastId) = (i, IdList is (max i lastId))
allocId (IdList [] _) =
  error "Csound.Dynamic.Tfm.Liveness.allocId: register pool exhausted (free-id list must be infinite)"

-- | Returns id @n@ to the pool. If @n@ is already free, the free list is
-- left unchanged. When @n@ was the recorded maximum, the maximum drops by
-- one.
freeId :: Int -> IdList -> IdList
freeId n (IdList free lastId) = IdList (insertSorted n free) newLast
  where
    -- NOTE(review): after freeing the top id, lastId - 1 may itself already be
    -- free, so this maximum is only approximate. Nothing here reads it.
    newLast
      | n == lastId = lastId - 1
      | otherwise   = lastId

-- | Inserts @n@ in front of the first element that is not smaller than it.
-- If that element equals @n@, the list is returned unchanged. The function is
-- lazy, so it also works on infinite lists (such as the free-id pool).
insertSorted :: Int -> [Int] -> [Int]
insertSorted n = go
  where
    go [] = [n]
    go xs@(a : rest) = case compare n a of
      LT -> n : xs
      EQ -> xs
      GT -> a : go rest

-----------------------------------------------

-- | A mutable, unboxed, 'Int'-indexed array of 'Int's in 'ST'.
type StArr s = A.STUArray s Int Int

-- | Maps a variable id to the last line on which that variable is read.
-- Ids that are never read map to 0. Built by 'analyse'.
type LivenessTable = A.UArray Int Int
-- | Maps an original variable id to the id it was renamed to. It starts as
-- the identity (see 'initSubstTable').
type SubstTable s  = StArr s

-- | State threaded through the renaming pass.
data Registers s = Registers
        { registers     :: M.Map Rate IdList  -- ^ free-id pool for each rate
        , livenessTable :: LivenessTable      -- ^ last-read line per variable id; not modified during the pass
        , substTable    :: SubstTable s       -- ^ original variable id -> assigned id
        }

-- | The renaming monad: 'Registers' state layered over 'ST', which holds the
-- mutable substitution array.
type Memory s a = StateT (Registers s) (ST s) a

-- | Applies a function to the per-rate register pools, leaving the other
-- fields untouched.
onRegs :: (M.Map Rate IdList -> M.Map Rate IdList) -> (Registers s -> Registers s)
onRegs f (Registers regs live subst) = Registers (f regs) live subst

-- | Gives every rate (from 'minBound' to 'maxBound') a fresh pool in which
-- all ids @[0 ..]@ are free and the recorded maximum is 0.
initRegs :: M.Map Rate IdList
initRegs = M.fromList [ (rate, IdList [0 ..] 0) | rate <- [minBound .. maxBound] ]

-- | Returns 'True' if the variable is still read on a line after @lineNum@.
isAlive :: LineNumber -> Var -> Memory s Bool
isAlive lineNum v = gets $ \mem -> lineNum < livenessTable mem A.! varId v

-- | Reads the id that original variable id @i@ has been renamed to.
lookUpSubst :: Int -> Memory s Int
lookUpSubst i = gets substTable >>= \tab -> lift (A.readArray tab i)

-- | Records that original variable id @from@ is renamed to id @to@.
saveSubst :: Int -> Int -> Memory s ()
saveSubst from to = gets substTable >>= \tab -> lift (A.writeArray tab from to)

-- | Renames a variable being defined. It picks a new id (or keeps the old one
-- for 'Ir'/'Sr') and remembers the mapping so later reads are rewritten too.
substLhs :: Var -> Memory s Var
substLhs v = do
  fresh <- allocAndSkipInits v
  fresh <$ saveSubst (varId v) (varId fresh)

-- | Rewrites a variable read on line @lineNum@ to the id it was assigned.
-- If this line is the variable's last read, its id goes back to the rate's
-- free pool.
substRhs :: LineNumber -> Var -> Memory s Var
substRhs lineNum v = do
  renamed <- (\reg -> D.Var reg (varType v)) <$> lookUpSubst (varId v)
  stillLive <- isAlive lineNum v
  if stillLive then pure () else free renamed
  pure renamed

-- | Picks an id for a newly defined variable. Variables of rate 'Ir' or 'Sr'
-- (treated as init-time here) keep their original id. Every other rate takes
-- a fresh id from its pool.
allocAndSkipInits :: Var -> Memory s Var
allocAndSkipInits v =
  if rate `elem` [Ir, Sr] then return v else alloc rate
  where
    rate = varType v

-- | Takes the smallest free id from the pool of the given rate and returns a
-- variable of that rate with that id.
alloc :: Rate -> Memory s Var
alloc rate = do
  mem <- get
  let (reg, pool) = allocId (registers mem M.! rate)
  -- adjust leaves the map unchanged if the key is missing, just like the
  -- `update (const (Just ..))` form.
  put (onRegs (M.adjust (const pool) rate) mem)
  return (D.Var reg rate)

-- | Returns the variable's id to the free pool of the variable's rate.
free :: Var -> Memory s ()
free v = modify (onRegs (M.adjust (freeId (varId v)) (varType v)))

--------------------------------------------------------------------------

-- | Builds the liveness table. For every variable id in @[0 .. lastFreshId]@
-- it holds the largest line number on which that variable is read (appears
-- on a right-hand side). Ids that are never read keep the initial value 0.
--
-- NOTE(review): this table covers @[0 .. lastFreshId]@, while
-- 'initSubstTable' allocates @[0 .. lastFreshId + 1]@. A variable id equal to
-- @lastFreshId + 1@ would be out of bounds here.
analyse :: Traversable f => Int -> Dag f -> LivenessTable
analyse lastFreshId as = A.runSTUArray $ do
        arr <- A.newArray (0, lastFreshId) 0
        mapM_ (go arr) $ countLines as
        return arr
        where
                -- Only reads extend a variable's lifetime, so the LHS is ignored.
                -- mapM_ avoids building and then discarding an `f ()` result.
                go :: Foldable f => StArr s -> (LineNumber, Exp f) -> ST s ()
                go arr (lineNum, (_, rhs)) = mapM_ (countVar arr lineNum) rhs

                -- Keep the largest line number seen for this variable id.
                countVar :: StArr s  -> LineNumber -> Var -> ST s ()
                countVar arr lineNum v = do
                        val <- A.readArray arr i
                        A.writeArray arr i (val `max` lineNum)
                        where i = varId v

-- | Renames one statement. The left-hand side is processed first: it takes
-- new ids before the right-hand side releases the ids of variables whose last
-- read is on this line. As a result, an output never reuses the id of an
-- input from the same statement.
substExp :: Traversable f => (LineNumber, Exp f) -> Memory s (Exp f)
substExp (lineNum, (lhs, rhs)) =
  (,) <$> traverse substLhs lhs <*> traverse (substRhs lineNum) rhs

-- | Builds the initial pass state: full register pools for every rate, the
-- given liveness table, and an identity substitution table.
initSt :: Int -> LivenessTable -> ST s (Registers s)
initSt lastFreshId livenessTab = do
  substTab <- initSubstTable lastFreshId
  return Registers
    { registers     = initRegs
    , livenessTable = livenessTab
    , substTable    = substTab
    }

-- | Allocates a substitution table over @[0 .. n + 1]@ in which every id
-- maps to itself.
initSubstTable :: Int ->  ST s (SubstTable s)
initSubstTable n = do
  tab <- A.newArray_ (0, n + 1)
  forM_ [0 .. n + 1] $ \i -> A.writeArray tab i i
  return tab