-- Tree module
-- By Gregory W. Schwartz

-- | Collects all functions pertaining to trees

{-# LANGUAGE BangPatterns #-}

module Math.TreeFun.Tree where

-- Built-in
import Data.List
import Data.Tree
import qualified Data.Map as M
import Data.Maybe
import qualified Data.Sequence as S
import Control.Applicative
import qualified Data.Foldable as F
import Control.Monad.State

-- Local
import Math.TreeFun.Types

-- | Map a boolean onto an integer: 'True' becomes 1 and 'False' becomes 0.
boolToInt :: Bool -> Int
boolToInt flag
    | flag      = 1
    | otherwise = 0

-- | Test whether a node is a leaf, i.e. whether it has no children.
isLeaf :: Tree a -> Bool
isLeaf = null . subForest

-- | Collect the labels of every leaf of the tree, visiting children from
-- left to right. A tree consisting of a single node yields that node's label.
leaves :: Tree a -> [a]
leaves (Node label children) = case children of
    [] -> [label]
    _  -> concatMap leaves children

-- | Map every leaf label to its height, where the given node sits at the
-- supplied starting height (normally 0) and each step down adds one.
-- The per-child maps are merged with 'M.unions', which is left-biased, so
-- for a label carried by several leaves the first one encountered wins.
leavesHeight :: (Ord a) => Int -> Tree a -> M.Map a Int
leavesHeight !height (Node label children)
    | null children = M.singleton label height
    | otherwise     = M.unions [leavesHeight (height + 1) t | t <- children]

-- | Map every leaf label to a pair of its height and a group label. The
-- height is counted from the supplied starting height (normally 0), one per
-- step down from the given node. The group label is an integer counter
-- starting at 0: all leaves that are direct children of the same node
-- receive the same value, and the counter is advanced once per inner node.
leavesCommonHeight :: (Ord a) => Int -> Tree a -> M.Map a (Int, Int)
leavesCommonHeight startHeight tree = evalState (go startHeight tree) 0
  where
    -- A leaf records its height together with the current group label
    go !height (Node leafLabel []) = do
        groupLabel <- get
        return (M.singleton leafLabel (height, groupLabel))
    go !height (Node _ children) = do
        let (leafKids, innerKids) = partition isLeaf children
            below                 = height + 1

        -- Leaf children of this node all see the same group label
        leafMaps  <- mapM (go below) leafKids

        -- Advance the counter before descending into the inner children
        modify (+ 1)

        innerMaps <- mapM (go below) innerKids

        -- Inner results come first, so they win on duplicate keys
        return (M.unions (innerMaps ++ leafMaps))

-- | Map every leaf label to a pair of weight and distance. The weight
-- starts at the first argument and is multiplied by the number of children
-- of each node passed on the way down to the leaf; the distance starts at
-- the second argument and grows by one per step down. Both are Doubles.
-- Maps are merged with the left-biased 'M.unions'.
leavesParentMult :: (Ord a) => Double
                            -> Double
                            -> Tree a
                            -> M.Map a (Double, Double)
leavesParentMult !w !d (Node label children) = case children of
    [] -> M.singleton label (w, d)
    _  -> M.unions (map descend children)
  where
    -- Every child inherits the weight scaled by this node's child count
    descend = leavesParentMult (w * genericLength children) (d + 1)

-- | Map every leaf label to a pair of weight and group label. The weight
-- starts at the first argument and is multiplied by the number of children
-- of each node passed on the way down to the leaf. The group label is an
-- integer counter starting at 0: leaves that are direct children of the
-- same node share one value, and the counter is advanced once per inner
-- node.
leavesCommonParentMult :: (Ord a) => Int -> Tree a -> M.Map a (Int, Int)
leavesCommonParentMult numChildren tree = evalState (go numChildren tree) 0
  where
    -- A leaf records the accumulated weight and the current group label
    go weight (Node leafLabel []) = do
        groupLabel <- get
        return (M.singleton leafLabel (weight, groupLabel))
    go weight (Node _ children) = do
        let (leafKids, innerKids) = partition isLeaf children
            childWeight           = weight * length children

        -- Leaf children of this node all see the same group label
        leafMaps  <- mapM (go childWeight) leafKids

        -- Advance the counter before descending into the inner children
        modify (+ 1)

        innerMaps <- mapM (go childWeight) innerKids

        -- Inner results come first, so they win on duplicate keys
        return (M.unions (innerMaps ++ leafMaps))

-- | Pair every leaf label with its height, where the given node sits at the
-- supplied starting height (normally 0) and each step down adds one. This
-- variant returns an association list instead of a Map, so it does not
-- need an Ord instance and keeps duplicate labels, in left-to-right order.
leavesHeightList :: Int -> Tree a -> [(a, Int)]
leavesHeightList h (Node label children)
    | null children = [(label, h)]
    | otherwise     = [ entry
                      | child <- children
                      , entry <- leavesHeightList (h + 1) child ]

-- | Collect the labels of all inner nodes (nodes with at least one child),
-- parent before children and children from left to right.
innerNodes :: Tree a -> [a]
innerNodes (Node label children)
    | null children = []
    | otherwise     = label : concatMap innerNodes children

-- | Count the leaves of a tree, in any numeric type the caller asks for.
numLeaves :: (Num b) => Tree a -> b
numLeaves tree = genericLength (leaves tree)

-- | Count the inner nodes of a tree, in any numeric type the caller asks
-- for.
numInner :: (Num b) => Tree a -> b
numInner tree = genericLength (innerNodes tree)

-- | Return True if at least one direct child of the given node is a leaf.
hasRootLeaf :: Tree a -> Bool
hasRootLeaf = any isLeaf . subForest

-- | Return the labels of the direct children of the given node that are
-- leaves, in their original order.
getRootLeaves :: Tree a -> [a]
getRootLeaves tree = [rootLabel child | child <- subForest tree, isLeaf child]

-- | Return the distinct properties found in a property map: the sequences
-- stored as the map's values are concatenated in key order and duplicates
-- are dropped with 'nub', keeping the first occurrence of each property.
getProperties :: (Eq b) => PropertyMap a b -> [b]
getProperties propertyMap = nub (F.toList allProperties)
  where
    allProperties = F.foldl' (S.><) S.empty (M.elems propertyMap)

-- | Remove the leaves that are direct children of the given node. Only the
-- immediate children are examined; leaves deeper in the tree are kept, and
-- the given node itself is always returned.
filterLeaves :: Tree a -> Tree a
filterLeaves (Node label children) =
    Node label [child | child <- children, not (isLeaf child)]

-- | Drop every direct child of the root that is a leaf, leaving the root
-- label and all non-leaf children untouched.
filterRootLeaves :: Tree a -> Tree a
filterRootLeaves root = root { subForest = keepInner (subForest root) }
  where
    keepInner = filter (not . isLeaf)

-- | Build the distance map of a tree: for every leaf, a map from distance
-- to the sequence of leaves found at that distance. Every ordered pair of
-- leaves is considered, including each leaf with itself (distance 0).
getDistanceMap :: (Eq a, Ord a) => Tree a -> DistanceMap a
getDistanceMap tree = M.fromListWith (M.unionWith (S.><)) entries
  where
    leafLabels = leaves tree
    -- One singleton entry per ordered pair (x, y) of leaf labels
    entries    = [ (x, M.singleton (distance x y) (S.singleton y))
                 | x <- leafLabels
                 , y <- leafLabels ]
    -- A leaf is at distance 0 from itself; otherwise ask 'getDistance'
    distance x y
        | x == y    = 0
        | otherwise = getDistance tree x y

-- | Find the distance between two leaves in a tree.
--
-- The result is the number of counted nodes in the tree, where
--
--   * a leaf counts if its label is @x@ or @y@, and
--   * an inner node counts if exactly one of @x@ and @y@ occurs among the
--     leaves below it (nodes containing both or neither do not count).
--
-- For two distinct leaf labels that each occur once, this is the number of
-- edges on the path between them. If neither label occurs the result is 0.
-- Note that for @x == y@ the leaf itself is still counted, so callers that
-- want a zero self-distance have to handle that case themselves.
--
-- The tree is traversed once; each subtree reports whether it contains
-- @x@, whether it contains @y@, and how many counted nodes it holds.
getDistance :: (Eq a) => Tree a -> a -> a -> Int
getDistance tree x y = total
  where
    (_, _, total) = go tree

    -- (x occurs among the leaves, y occurs among the leaves, counted nodes)
    go (Node { rootLabel = l, subForest = [] }) =
        (l == x, l == y, fromEnum (l == x || l == y))
    go (Node { subForest = ts }) =
        (hasX, hasY, fromEnum (hasX /= hasY) + sum dists)
      where
        (foundXs, foundYs, dists) = unzip3 (map go ts)
        hasX = or foundXs
        hasY = or foundYs
        -- hasX /= hasY is exclusive or: only nodes holding one of the two
        -- leaves lie strictly between a leaf and the common ancestor

-- | Build the distance map of a SuperNode tree: for every leaf recorded in
-- the root's 'myLeaves', a map from distance to the sequence of leaves
-- found at that distance. Every ordered pair of leaves is considered,
-- including each leaf with itself (distance 0).
getDistanceMapSuperNode :: (Eq a, Ord a) => Tree (SuperNode a) -> DistanceMap a
getDistanceMapSuperNode tree = M.fromListWith (M.unionWith (S.><)) entries
  where
    -- The root's leaf map already lists every leaf of the tree
    allLeaves = M.keys . myLeaves . rootLabel $ tree
    -- One singleton entry per ordered pair (x, y) of leaf labels
    entries   = [ (x, M.singleton (distance x y) (S.singleton y))
                | x <- allLeaves
                , y <- allLeaves ]
    -- A leaf is at distance 0 from itself; otherwise ask
    -- 'getDistanceSuperNode'
    distance x y
        | x == y    = 0
        | otherwise = getDistanceSuperNode tree x y

-- | Find the distance between two leaves in a SuperNode tree.
--
-- While a node's 'myLeaves' contains both @x@ and @y@, the search descends
-- into its children and takes the first non-zero result; children holding
-- neither leaf answer 0 and are skipped. The first node holding exactly one
-- of the two leaves answers with the sum of the heights of @x@ and @y@ as
-- recorded in its parent's 'myLeaves'. A node holding neither leaf answers
-- 0, and so does a node holding both leaves none of whose children gives a
-- non-zero answer (for instance when @x == y@).
--
-- In the "exactly one" case both @x@ and @y@ must be present in the
-- parent's 'myLeaves', otherwise the lookup fails with 'fromJust'. That
-- holds whenever the node was reached by descending from a node that
-- contains both leaves.
getDistanceSuperNode :: (Eq a, Ord a) => Tree (SuperNode a) -> a -> a -> Int
getDistanceSuperNode (Node { rootLabel = SuperNode { myLeaves = ls
                                                   , myParent = p }
                           , subForest = ts } ) x y
    | shared ls    = fromMaybe 0
                   . listToMaybe
                   . filter (/= 0)
                   . map (\a -> getDistanceSuperNode a x y)
                   $ ts
    | notShared ls = getParentLeafDist x p + getParentLeafDist y p
    | otherwise    = 0
  where
    -- Only count nodes that have one or the other, not shared or empty
    notShared xs = (M.member x xs || M.member y xs)
                && not (M.member x xs && M.member y xs)
    shared xs    = M.member x xs && M.member y xs
    -- Height of leaf a as recorded in the leaf map of SuperNode b
    getParentLeafDist a b = fst . fromJust . M.lookup a . myLeaves $ b

-- | Add up all labels of a tree with numeric labels, using a strict left
-- fold over the labels in pre-order.
sumTree :: (Num a) => Tree a -> a
sumTree tree = F.foldl' (+) 0 (flatten tree)

-- | Convert a tree into a SuperNode tree. Every node is replaced by a
-- SuperNode holding its original label, the map produced by
-- 'leavesCommonHeight' (started at 0) for the subtree rooted at that node,
-- and the SuperNode of its parent. The first argument is used as the
-- parent of the root.
toSuperNodeTree :: (Ord a) => SuperNode a -> Tree a -> Tree (SuperNode a)
toSuperNodeTree p n@(Node label children) = Node current converted
  where
    -- Children get the freshly built SuperNode as their parent
    converted = map (toSuperNodeTree current) children
    current   = SuperNode { myRootLabel = label
                          , myLeaves    = leavesCommonHeight 0 n
                          , myParent    = p }