module CsoundExpr.Translator.ExprTree.TreeSeq 
	(adjList,  reduceAdjList, 
	 adjLists, reduceAdjLists
	)	
where

import Control.Monad.State.Strict
import qualified Data.Map as M

import CsoundExpr.Translator.ExprTree.Tree


-- | Pre-order listing of every node of a tree, each paired with the number
-- of its direct children.
--
-- NOTE(review): not exported and not referenced elsewhere in this module.
numOfChildsList :: Tree a -> [(a, Int)]
numOfChildsList (Node value children) =
    (value, length children) : concatMap numOfChildsList children

-- | One entry of an adjacency list, @(outId, (op, inIds))@: the id given to a
-- node, the node's payload, and the ids of its direct inputs (its children).
type TreeLayer a b = (a, (b, [Int]))

-- | Assemble a 'TreeLayer' from its output id, payload and input ids.
mkLayer :: a -> b -> [Int] -> TreeLayer a b
mkLayer a b c = (a, (b, c))
           

-- | Flatten a list of trees into a single adjacency list, giving the first
-- tree's root the id @s@.
--
-- Each tree's layers come from 'adjList'' and are reversed, so within a tree
-- a node's layer follows the layers of its subtree (children before
-- parents). Trees appear in input order. The root id of the next tree is
-- one more than the id of the first layer emitted for the previous tree,
-- which is the last node of that tree in 'adjList'' pre-order.
--
-- NOTE(review): that id is not necessarily the largest id the tree used, so
-- ids may repeat across trees. 'reduceAdjList' only looks up a layer's own
-- inputs, which are registered just before it. Confirm no other consumer
-- needs globally unique ids.
--
-- Written as a right-leaning recursion so the result is produced lazily and
-- in linear time. A left fold that re-appends the accumulator for every
-- tree is quadratic.
adjList :: Int -> [Tree a] -> [TreeLayer Int a]
adjList = go
    where go _     []       = []
          go start (t : ts) =
              let layers = reverse $ evalState (adjList' start t) start
              in  layers ++ go (nextId layers) ts
          -- 'adjList'' always emits at least the root layer
          nextId (l : _) = 1 + layerOut l
          nextId []      = error "adjList: adjList' returned no layers"


-- | Merge structurally identical layers of an adjacency list and renumber the
-- surviving layers starting from 0.
--
-- Layers are processed left to right. Every input id of a layer must already
-- have been seen, i.e. children come before parents, as 'adjList' produces
-- them. A layer whose operation and renumbered inputs match an earlier one
-- is dropped, and its id is redirected to that earlier layer's new id.
reduceAdjList :: Ord a => [TreeLayer Int a] -> [TreeLayer Int a]
reduceAdjList = flip evalState (0, M.empty, M.empty) . foldM reduceAdjList' []


-- | Like 'adjList', but keeps each tree's layers separate. For every tree the
-- result holds its root layer together with its remaining layers, reversed
-- so that children come before parents.
--
-- The next tree's root id is one more than the id of the first of the
-- previous tree's remaining layers. If that tree was a single node, it is
-- one more than the id of its root layer.
--
-- NOTE(review): as in 'adjList', that id is not necessarily the largest the
-- tree used, so ids may repeat across trees. Confirm consumers tolerate this.
--
-- Uses a right-leaning recursion instead of a left fold with @++@, which was
-- quadratic in the number of trees.
adjLists :: Int -> [Tree a] -> [(TreeLayer Int a, [TreeLayer Int a])]
adjLists = go
    where go _     []       = []
          go start (t : ts) =
              case evalState (adjList' start t) start of
                (root : rest) -> let body = reverse rest
                                 in  (root, body) : go (nextId root body) ts
                -- 'adjList'' always emits the root layer first
                []            -> error "adjLists: adjList' returned no layers"
          nextId root []      = 1 + layerOut root
          nextId _    (l : _) = 1 + layerOut l


-- | Reduce the per-tree adjacency lists produced by 'adjLists' with a single
-- shared state, so identical layers are merged across trees as well.
--
-- For each tree, the non-root layers are merged exactly as in
-- 'reduceAdjList'. The root layer is then appended through
-- 'registerLastStm'. It is never merged, and afterwards node-table entries
-- with the root's operation are discarded.
reduceAdjLists :: Ord a => 
                  [(TreeLayer Int a, [TreeLayer Int a])]
               -> [TreeLayer Int a]                  
reduceAdjLists trees = fst $ foldl step ([], (0, M.empty, M.empty)) trees
    where step (acc, st) (root, body) =
              let merged = runState (foldM reduceAdjList' acc body) st
              in  registerLastStm root merged




-- | Flatten one tree into layers in pre-order. A node's layer comes first,
-- followed by the layers of each child's subtree in order.
--
-- @root@ is the id assigned to the tree's root. The state is an id counter
-- with the invariant that every id strictly greater than it is still unused.
-- A node with @n@ children at counter @s@ gives them ids @s+1 .. s+n@ and
-- advances the counter to @s+n+1@ before its subtrees are expanded.
adjList' :: Int -> Tree a -> State Int [TreeLayer Int a]
adjList' root (Node a xs) = 
    state $ \s -> let childIds  = ids s
                      s'        = nextId s
                      -- Return the counter reached after all subtrees were
                      -- expanded, not s'. Returning s' lets the next sibling
                      -- subtree reuse ids already given to this node's
                      -- descendants.
                      (vs, s'') = runState (foldM foldAdjList [] (zip childIds xs)) s'
                  in  (mkLayer root a childIds : vs, s'')
    where ids    x = [x + 1 .. x + n]
          nextId x = x + n + 1
          n        = length xs 
          -- append the layers of one child's subtree, threading the counter
          foldAdjList acc (i, t) = liftM (acc ++) (adjList' i t)

-- | Id of the node a layer (a 'TreeLayer') describes.
layerOut :: (a, (b, [Int])) -> a
layerOut = fst

-- | Payload (operation) of a layer.
layerOp :: (a, (b, [Int])) -> b
layerOp  = fst . snd

-- | Ids of a layer's inputs.
layerIns :: (a, (b, [Int])) -> [Int]
layerIns = snd . snd


-- | State threaded through the reduction: the next fresh id, the node
-- table, and the table renaming old ids to new ids.
type RegisterData op = (Int, TableNode op, TableId)

-- | Maps an id from the unreduced adjacency list to its new id.
type TableId = M.Map Int Int

-- | Maps a node key to the new id assigned to it.
type TableNode op = M.Map (LayerNode op) Int

-- | Lookup key for merging: an operation together with the already
-- renumbered ids of its inputs.
type LayerNode op = (op, [Int])


-- | Process one layer of an unreduced adjacency list.
--
-- The layer's inputs are translated to their new ids. Each input must
-- already be in the id table; otherwise 'M.!' fails. If an equal
-- (operation, inputs) key has been registered before, the layer is dropped
-- and its old id is mapped to the existing new id. Otherwise it is appended
-- to the result under a fresh id.
reduceAdjList' :: 
    Ord a => [TreeLayer Int a]
          -> TreeLayer Int a
          -> State (RegisterData a) [(TreeLayer Int a)]
reduceAdjList' res x = state step
    where step st@(_, nodes, renames) =
              let out = layerOut x
                  key = (layerOp x, map (renames M.!) (layerIns x))
              in  maybe (registerNewNode out key (res, st))
                        (\i -> registerNewId out i (res, st))
                        (M.lookup key nodes)


-- | Append a layer for a previously unseen node key @k@ under the next fresh
-- id. The key is recorded in the node table, old id @i@ is mapped to the
-- fresh id, and the counter is advanced.
registerNewNode :: Ord a => Int -> LayerNode a 
                -> ([TreeLayer Int a], RegisterData a) 
                -> ([TreeLayer Int a], RegisterData a)
registerNewNode i k (acc, (fresh, nodes, renames)) = (acc', st')
    where acc' = acc ++ [(fresh, k)]
          st'  = (fresh + 1, M.insert k fresh nodes, M.insert i fresh renames)


-- | Record that old id @x@ duplicates the already-registered new id @i@.
-- The result list, the fresh-id counter and the node table are unchanged.
registerNewId :: Ord a => Int -> Int
                -> ([TreeLayer Int a], RegisterData a) 
                -> ([TreeLayer Int a], RegisterData a)
registerNewId x i (acc, (fresh, nodes, renames)) =
    let renames' = M.insert x i renames
    in  (acc, (fresh, nodes, renames'))



-- | Append a tree's root layer unconditionally under the next fresh id.
--
-- Unlike 'registerNewNode', the root is neither looked up in nor added to
-- the node table. Instead, every table entry whose operation equals the
-- root's operation is removed (via 'dropLastStm'), and the root's old id is
-- mapped to its new id.
--
-- NOTE(review): presumably root statements must never be shared with other
-- trees; confirm against the callers of 'reduceAdjLists'.
registerLastStm :: Ord a => 
                   TreeLayer Int a
                -> ([TreeLayer Int a], RegisterData a) 
                -> ([TreeLayer Int a], RegisterData a)
registerLastStm lastStm (acc, (fresh, nodes, renames)) =
    (acc ++ [(fresh, key)], (fresh + 1, nodes', renames'))
    where key      = (layerOp lastStm, map (renames M.!) (layerIns lastStm))
          nodes'   = dropLastStm lastStm nodes
          renames' = M.insert (layerOut lastStm) fresh renames


-- | Remove every node-table entry whose operation equals the operation of
-- the given layer.
dropLastStm :: Ord a => TreeLayer Int a 
                     -> TableNode a -> TableNode a
dropLastStm lastStm = M.filterWithKey keep
    where op         = layerOp lastStm
          keep key _ = fst key /= op