module CsoundExpr.Translator.ExprTree.Tree 
    (Tree(..),
     treeVal, treeChilds, subTree,
     substTreeByFunc, substTreeById,
     foldTree, toIdList, 
     equalTreeStructureBy)
where

import Data.List (foldl')

import Text.PrettyPrint

-- | A rose tree: every node holds exactly one value together with an
-- ordered list of child subtrees (an empty list marks a leaf).
data Tree a
    = Node a [Tree a]
    deriving (Eq, Ord)

-- | Lay a tree out as a 'Doc'.
--
-- Each node becomes a line @Node <show value>@.  Its children are rendered
-- recursively and stacked beneath it, nested four columns to the right.
ppTree :: Show a => Tree a -> Doc
ppTree (Node val kids) = header $+$ nest 4 body
  where
    header = text "Node" <+> text (show val)
    body   = vcat (map ppTree kids)

-- | Multi-line, indented rendering of a tree produced by 'ppTree'.
-- The output is a layout, not Haskell constructor syntax, so it cannot be
-- read back (there is no 'Read' instance).
instance Show a => Show (Tree a) where
    show = render . ppTree


-- | The value stored at the root node of a tree.
treeVal :: Tree a -> a
treeVal t = case t of
    Node v _ -> v

-- | The immediate child subtrees of the root, in order.
treeChilds :: Tree a -> [Tree a]
treeChilds t = case t of
    Node _ kids -> kids


-- | Rewrite a tree top-down with a partial rewriting function.
--
-- At every node the rewriter @f@ is consulted first:
--
-- * @Just t'@: rewriting restarts on @t'@ itself.  @f@ is therefore applied
--   repeatedly at the same position until it answers 'Nothing', and this
--   does not terminate if it never does.
-- * 'Nothing': the node's value is kept and each child is rewritten.
substTreeByFunc :: (Tree a -> Maybe (Tree a)) -> Tree a -> Tree a
substTreeByFunc f = go
  where
    go t@(Node v kids) = maybe (Node v (map go kids)) go (f t)



-- | Strict pre-order left fold over a tree.
--
-- The accumulator is combined with the root value first.  Each child subtree
-- is then folded in turn, left to right, and recursively in the same way.
--
-- The accumulator is forced to weak head normal form after every step
-- (@seq@ on the root step, 'foldl'' across the children).  The previous lazy
-- 'foldl' built a chain of unevaluated thunks as long as the number of nodes
-- before anything was evaluated.
foldTree :: (a -> b -> a) -> a -> Tree b -> a
foldTree f s (Node a ts) = s' `seq` foldl' (foldTree f) s' ts
  where
    s' = f s a


-- | Flatten a tree into (path, value) pairs in pre-order.
--
-- A path is the list of child indices (0-based) leading from the root to
-- the node.  The root itself has the empty path @[]@.
toIdList :: Tree a -> [([Int], a)]
toIdList = go []
  where
    go path (Node v kids) =
        (path, v) : concat (zipWith (\i k -> go (path ++ [i]) k) [0 ..] kids)
 
    
-- | Replace the subtree found at a path with a new subtree.
--
-- The path is a list of 0-based child indices from the root, as produced by
-- 'toIdList'.  The empty path replaces the whole tree.
--
-- Implemented as a single top-down pass that rebuilds only the spine along
-- the path.  The previous version re-walked from the root at every level
-- (quadratic in depth) and handled bad indices inconsistently:
--
-- * a bad intermediate index crashed inside @!!@;
-- * a too-large final index silently appended the new subtree;
-- * a negative final index silently prepended it.
--
-- Any index outside @[0, number of children)@ now raises an error.
substTreeById :: [Int] -> Tree a -> Tree a -> Tree a
substTreeById [] subtree _ = subtree
substTreeById (i : is) subtree (Node v kids)
    | i < 0     = outOfRange
    | otherwise =
        case splitAt i kids of
          (before, child : after) ->
              Node v (before ++ (substTreeById is subtree child : after))
          _ -> outOfRange
  where
    outOfRange =
        error ("substTreeById: child index " ++ show i ++ " out of range")


-- | Follow a path of 0-based child indices down from the root and return
-- the subtree found there; the empty path yields the tree itself.
-- An out-of-range index is a runtime error (from @!!@).
subTree :: [Int] -> Tree a -> Tree a
subTree []       t             = t
subTree (i : is) (Node _ kids) = subTree is (kids !! i)



-- | Structural equality of two trees under a caller-supplied relation on
-- node values.
--
-- Two trees are equal when all of the following hold:
--
-- * their root values satisfy @eq@;
-- * they have the same number of children;
-- * corresponding children are equal in the same (recursive) sense.
--
-- Children are compared pairwise in one lockstep pass.  A length mismatch is
-- detected when one list runs out before the other, so neither child list
-- is traversed separately just to compute its length.  The comparison stops
-- at the first mismatch.
equalTreeStructureBy :: (a -> a -> Bool) -> Tree a -> Tree a -> Bool
equalTreeStructureBy eq (Node a xs) (Node b ys) = eq a b && sameKids xs ys
  where
    sameKids (l : ls) (r : rs) = equalTreeStructureBy eq l r && sameKids ls rs
    sameKids []       []       = True
    sameKids _        _        = False