{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | This module implements a strict 'TreeMap',
-- which is like a 'Map'
-- but whose keys are 'NonEmpty' lists of 'Map' keys (called 'Path's),
-- so that mapped values can be gathered
-- by 'Path' prefix (inside a 'Node').
module Data.TreeMap.Strict where

import           Control.Applicative (Applicative(..))
import           Control.DeepSeq (NFData(..))
import           Control.Monad (Monad(..))
import           Data.Bool
import           Data.Data (Data)
import           Data.Eq (Eq)
import           Data.Foldable (Foldable, foldMap)
import           Data.Function (($), (.), const, flip, id)
import           Data.Functor (Functor(..), (<$>))
import qualified Data.List
import qualified Data.List.NonEmpty
import           Data.List.NonEmpty (NonEmpty(..))
import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import           Data.Maybe (Maybe(..), maybe)
import           Data.Monoid (Monoid(..))
import           Data.Ord (Ord(..))
import qualified Data.Strict.Maybe as Strict
import           Data.Traversable (Traversable(..))
import           Data.Typeable (Typeable)
import           Prelude (Int, Num(..), seq)
import           Text.Show (Show(..))

-- @Data.Strict@ orphan instances
deriving instance Data x => Data (Strict.Maybe x)
deriving instance Typeable Strict.Maybe
-- | Two 'Strict.Just's are combined with the inner 'mappend';
-- a 'Strict.Nothing' on either side is neutral.
instance Monoid x => Monoid (Strict.Maybe x) where
	mempty = Strict.Nothing
	mappend a b =
		case (a, b) of
		 (Strict.Just x, Strict.Just y) -> Strict.Just (x `mappend` y)
		 (_, Strict.Nothing)            -> a
		 (Strict.Nothing, _)            -> b
-- | Fully evaluates the wrapped value, if any.
instance NFData x => NFData (Strict.Maybe x) where
	rnf = Strict.maybe () rnf

-- * Type 'TreeMap'

-- | A strict tree of 'Node's indexed by 'Path's.
-- Each key of the root 'Map' is the first section of some 'Path';
-- the associated 'Node' holds the (optional) value at that section
-- and the sub-'TreeMap' of the longer 'Path's starting with it.
--
-- The derived 'Eq' is structural: it also compares the cached
-- 'node_size's and any intermediary ('Strict.Nothing'-valued) 'Node's.
newtype TreeMap k x
 =      TreeMap (Map k (Node k x))
 deriving (Data, Eq, Show, Typeable)

-- | Merges two 'TreeMap's with 'union', combining values that share
-- a 'Path' with their own 'mappend' (value of the left 'TreeMap' first).
instance (Ord k, Monoid v) => Monoid (TreeMap k v) where
	mempty = empty
	mappend t0 t1 = union mappend t0 t1
	-- mconcat = Data.List.foldr mappend mempty
-- | Maps every value; the 'Path's themselves are left untouched.
instance Ord k => Functor (TreeMap k) where
	fmap f = TreeMap . fmap (fmap f) . nodes
-- | Folds every value, root 'Node' by root 'Node'
-- (each 'Node''s own value before its descendants').
instance Ord k => Foldable (TreeMap k) where
	foldMap f = foldMap (foldMap f) . nodes
-- | Traverses every value, in the same order as the 'Foldable' instance.
instance Ord k => Traversable (TreeMap k) where
	traverse f = fmap TreeMap . traverse (traverse f) . nodes
instance (Ord k, NFData k, NFData x) => NFData (TreeMap k x) where
	rnf = rnf . nodes

-- * Type 'Path'

-- | A 'Path' is a non-empty list of 'Map' keys,
-- its first section being looked up in the root 'TreeMap'.
type Path = NonEmpty

-- | Build a 'Path' from its first section and the remaining ones.
path :: k -> [k] -> Path k
path k ks = k :| ks

-- | Return the sections of the given 'Path' as a plain list.
list :: Path k -> [k]
list (k :| ks) = k : ks

-- | Return the sections of the given 'Path' in reverse order.
reverse :: Path k -> Path k
reverse = Data.List.NonEmpty.reverse

-- * Type 'Node'
-- | A 'Node' of a 'TreeMap': the value (if any) stored at the 'Path'
-- leading to it, plus the sub-'TreeMap' of the 'Path's extending it.
--
-- 'node_size' is a cache: build 'Node's with 'node' (or 'leaf')
-- to keep it consistent with 'node_value' and 'node_descendants'.
data Node k x
 =   Node
 {   node_size        :: !Int -- ^ The number of non-'Strict.Nothing' 'node_value's reachable from this 'Node' (its own value included).
 ,   node_value       :: !(Strict.Maybe x) -- ^ Some value, or 'Strict.Nothing' if this 'Node' is intermediary.
 ,   node_descendants :: !(TreeMap k x) -- ^ The sub-'TreeMap' of this 'Node''s descendants.
 } deriving (Data, Eq, Show, Typeable)

-- | Merges two 'Node's: their own values are combined with the values'
-- 'mappend' (left value first), and so are the values of their descendants,
-- in the same way as the 'Monoid' instance of 'TreeMap'.
instance (Ord k, Monoid v) => Monoid (Node k v) where
	mempty = node Strict.Nothing (TreeMap mempty)
	mappend
	 Node{node_value=x0, node_descendants=m0}
	 Node{node_value=x1, node_descendants=m1} =
		-- Descendants are merged with 'mappend' too: merging them with 'const'
		-- would silently drop the right 'Node''s values wherever both sides hold one.
		node (x0 `mappend` x1) (union mappend m0 m1)
	-- mconcat = Data.List.foldr mappend mempty
-- | Maps the value of this 'Node' and of all its descendants;
-- 'node_size' is kept as is, since no value appears or disappears.
instance Ord k => Functor (Node k) where
	fmap f n@Node{node_value, node_descendants} =
		n{ node_value       = fmap f node_value
		 , node_descendants = map f node_descendants
		 }
-- | Folds in pre-order: a 'Node''s own value comes before its descendants'.
instance Ord k => Foldable (Node k) where
	foldMap f (Node _size value (TreeMap m)) =
		let rest = foldMap (foldMap f) m in
		case value of
		 Strict.Nothing -> rest
		 Strict.Just x  -> f x `mappend` rest
-- | Traverses in the same pre-order as the 'Foldable' instance;
-- 'node_size' is carried over unchanged.
instance Ord k => Traversable (Node k) where
	traverse f (Node size_ value (TreeMap m)) =
		let rest = TreeMap <$> traverse (traverse f) m in
		case value of
		 Strict.Nothing -> Node size_ <$> pure Strict.Nothing <*> rest
		 Strict.Just x  -> Node size_ <$> (Strict.Just <$> f x) <*> rest
instance (Ord k, NFData k, NFData x) => NFData (Node k x) where
	rnf Node{node_size, node_value, node_descendants} =
		rnf node_size `seq` rnf node_value `seq` rnf node_descendants

-- | Smart constructor for 'Node': derives 'node_size'
-- from the given value and the given descendants.
node :: Strict.Maybe x -> TreeMap k x -> Node k x
node value descendants =
	Node (size descendants + own) value descendants
	where
		-- The 'Node''s own value counts for one, when present.
		own = Strict.maybe 0 (const 1) value

-- | An intermediary 'Node' with neither value nor descendants
-- (hence a 'node_size' of 0).
node_empty :: Node k x
node_empty = Node{node_size = 0, node_value = Strict.Nothing, node_descendants = empty}

-- | Return the 'Node' (if any) reached by following the given keys
-- from the given 'Node'; an empty list of keys returns that 'Node' itself.
node_find :: Ord k => [k] -> Node k x -> Strict.Maybe (Node k x)
node_find ks n =
	case ks of
	 [] -> Strict.Just n
	 k:rest ->
		case Map.lookup k (nodes (node_descendants n)) of
		 Nothing    -> Strict.Nothing
		 Just child -> node_find rest child

-- * Construct

-- | Return the empty 'TreeMap', mapping no 'Path' at all.
empty :: TreeMap k x
empty = TreeMap Map.empty

-- | Return a 'TreeMap' mapping only the given 'Path' to the given value.
singleton :: Ord k => Path k -> x -> TreeMap k x
singleton p v = insert (\new _old -> new) p v empty

-- | Return a 'Node' holding only the given value, without descendants.
leaf :: x -> Node k x
leaf = (`node` empty) . Strict.Just

-- | Return the given 'TreeMap' associating the given 'Path' with the given value.
-- If the given 'TreeMap' already associates the given 'Path'
-- with a non-'Strict.Nothing' 'node_value', the stored value becomes
-- @merge new old@, where @new@ is the given value.
-- Missing intermediary 'Node's are created with a 'Strict.Nothing' value.
insert :: Ord k => (x -> x -> x) -> Path k -> x -> TreeMap k x -> TreeMap k x
insert merge (k:|ks) x (TreeMap m) =
	TreeMap $
	case ks of
	 [] ->
		-- Last section: store the value here, merging it with any existing one.
		Map.insertWith
		 (\_new old -> node (merged (node_value old)) (node_descendants old))
		 k (leaf x) m
	 k':ks' ->
		-- Inner section: recurse into the descendants, creating them if needed.
		let deeper = insert merge (path k' ks') x in
		Map.insertWith
		 (\_new old -> node (node_value old) (deeper (node_descendants old)))
		 k (node Strict.Nothing (deeper empty)) m
	where
		merged = Strict.maybe (Strict.Just x) (Strict.Just . merge x)

-- | Return a 'TreeMap' associating, for each tuple of the given list,
-- the 'Path' to the value.
-- Tuples are inserted from left to right: when a 'Path' occurs several times,
-- the stored value becomes @merge later accumulated@ (see 'insert'),
-- where @accumulated@ is the value stored so far for that 'Path'.
--
-- NOTE: uses 'Data.List.foldl'' so that the accumulated 'TreeMap'
-- is forced at each step instead of building a chain of thunks.
from_List :: Ord k => (x -> x -> x) -> [(Path k, x)] -> TreeMap k x
from_List merge = Data.List.foldl' (\acc (p, x) -> insert merge p x acc) empty

-- | Return a 'TreeMap' associating, for each key and value of the given 'Map',
-- the 'Path' to the value.
--
-- The keys of a 'Map' are distinct and 'insert' only merges values
-- stored at the very same 'Path', so @merge@ is never actually called here.
--
-- NOTE: uses the strict 'Map.foldlWithKey'' so that the accumulated 'TreeMap'
-- is forced at each step instead of building a chain of thunks.
from_Map :: Ord k => (x -> x -> x) -> Map (Path k) x -> TreeMap k x
from_Map merge = Map.foldlWithKey' (\acc p x -> insert merge p x acc) empty

-- * Size

-- | Return the root 'Map' wrapped by the given 'TreeMap'.
nodes :: TreeMap k x -> Map k (Node k x)
nodes t = case t of TreeMap m -> m

-- | Return 'True' iff the given 'TreeMap' is 'empty', i.e. has no root 'Node'.
-- A 'TreeMap' holding only intermediary 'Node's is not 'null',
-- even though its 'size' is 0.
null :: TreeMap k x -> Bool
null = Map.null . nodes

-- | Return the number of non-'Strict.Nothing' 'node_value's in the given 'TreeMap'.
--
--   * Complexity: O(r) where r is the size of the root 'Map',
--     since each root 'Node' caches its own count in 'node_size'.
size :: TreeMap k x -> Int
size (TreeMap m) = Map.foldr (\n acc -> node_size n + acc) 0 m

-- * Find

-- | Return the value (if any) associated with the given 'Path'.
-- Both a missing 'Path' and an intermediary 'Node' yield 'Strict.Nothing'.
find :: Ord k => Path k -> TreeMap k x -> Strict.Maybe x
find p t = maybe Strict.Nothing node_value (find_node p t)

-- | Return the values (if any) associated with the prefixes of the given 'Path'
-- (the 'Path' itself included), from the shortest prefix to the longest.
-- The walk stops at the first section missing from the 'TreeMap'.
find_along :: Ord k => Path k -> TreeMap k x -> [x]
find_along p t = foldr_path (\_prefix x xs -> x : xs) p t []

-- | Return the 'Node' (if any) associated with the given 'Path'.
find_node :: Ord k => Path k -> TreeMap k x -> Maybe (Node k x)
find_node (k:|ks) (TreeMap m) =
	case ks of
	 [] -> Map.lookup k m
	 k':ks' -> do
		n <- Map.lookup k m
		find_node (path k' ks') (node_descendants n)

-- * Union

-- | Return a 'TreeMap' associating the same 'Path's as both given 'TreeMap's,
-- merging values (in respective order) when a 'Path' leads
-- to a non-'Strict.Nothing' 'node_value' in both given 'TreeMap's:
-- the value from the first 'TreeMap' is the first argument of @merge@.
union :: Ord k => (x -> x -> x) -> TreeMap k x -> TreeMap k x -> TreeMap k x
union merge (TreeMap tm0) (TreeMap tm1) =
	TreeMap (Map.unionWith merge_nodes tm0 tm1)
	where
		-- Both 'TreeMap's have a 'Node' under the same key:
		-- merge their own values and, recursively, their descendants.
		merge_nodes n0 n1 =
			node
			 (merge_values (node_value n0) (node_value n1))
			 (union merge (node_descendants n0) (node_descendants n1))
		merge_values (Strict.Just x0) (Strict.Just x1) = Strict.Just (merge x0 x1)
		merge_values Strict.Nothing x1 = x1
		merge_values x0 Strict.Nothing = x0

-- | Return the 'union' of the given 'TreeMap's, accumulated from left to right.
--
-- NOTE: use 'Data.List.foldl'' to reduce demand on the control-stack.
unions :: Ord k => (x -> x -> x) -> [TreeMap k x] -> TreeMap k x
unions merge ts = Data.List.foldl' (\acc t -> union merge acc t) empty ts

-- foldl' :: (a -> b -> a) -> a -> [b] -> a
-- foldl' f = go
-- 	where
-- 		go z []     = z
-- 		go z (x:xs) = z `seq` go (f z x) xs

-- * Map

-- | Return the given 'TreeMap' with each non-'Strict.Nothing' 'node_value'
-- mapped by the given function; 'Path's and 'node_size's are unchanged.
map :: Ord k => (x -> y) -> TreeMap k x -> TreeMap k y
map f (TreeMap m) = TreeMap (Map.map on_node m)
	where
		on_node n =
			n{ node_value       = fmap f (node_value n)
			 , node_descendants = map f (node_descendants n)
			 }

-- | Return the given 'TreeMap' with each 'Path' section
-- and each non-'Strict.Nothing' 'node_value'
-- mapped by the given functions.
--
-- WARNING: the function mapping 'Path' sections must be monotonic,
-- like in 'Map.mapKeysMonotonic'; this precondition is not checked.
map_monotonic :: (Ord k, Ord l) => (k -> l) -> (x -> y) -> TreeMap k x -> TreeMap l y
map_monotonic fk fx (TreeMap m) =
	TreeMap (Map.mapKeysMonotonic fk (Map.map on_node m))
	where
		on_node n =
			n{ node_value       = fmap fx (node_value n)
			 , node_descendants = map_monotonic fk fx (node_descendants n)
			 }

-- | Return the given 'TreeMap' where the 'node_value' of every 'Node'
-- (intermediary ones included) is replaced by the result of the given function,
-- which receives the already mapped 'node_descendants' of that 'Node'
-- and its original 'node_value'.
-- Every 'Node' of the result thus holds a non-'Strict.Nothing' value.
map_by_depth_first :: Ord k => (TreeMap k y -> Strict.Maybe x -> y) -> TreeMap k x -> TreeMap k y
map_by_depth_first f (TreeMap m) = TreeMap (Map.map visit m)
	where
		visit n =
			let mapped = map_by_depth_first f (node_descendants n)
			in node (Strict.Just (f mapped (node_value n))) mapped

-- * Alter

-- | Return the given 'TreeMap' where the given function has been applied
-- to the 'node_value' of every 'Node' along the given 'Path'
-- (each prefix of the 'Path', the 'Path' itself included).
-- For a 'Node' missing along the 'Path', the function receives 'Strict.Nothing'.
-- A 'Node' ending up with a 'Strict.Nothing' value and a sub-'TreeMap'
-- of 'size' 0 is removed.
alterl_path :: Ord k => (Strict.Maybe x -> Strict.Maybe x) -> Path k -> TreeMap k x -> TreeMap k x
alterl_path fct =
	go fct . list
	where
		go :: Ord k
		 => (Strict.Maybe x -> Strict.Maybe x) -> [k]
		 -> TreeMap k x -> TreeMap k x
		go _f [] m = m
		go f (k:p) (TreeMap m) =
			TreeMap $
			Map.alter
			 (\c ->
				let (cv, cm) =
					case c of
					 Just Node{node_value=v, node_descendants=d} -> (v, d)
					 Nothing -> (Strict.Nothing, empty) in
				let fx = f cv in
				let gm = go f p cm in
				case (fx, size gm) of
				 (Strict.Nothing, 0) -> Nothing
				 -- Build through the 'node' smart constructor so that 'node_size'
				 -- counts this 'Node''s own value only when it is present.
				 _ -> Just (node fx gm)
			 ) k m

-- * Fold

-- | Return the given accumulator folded by the given function
-- applied on non-'Strict.Nothing' 'node_value's (with their 'Path')
-- from left to right through the given 'TreeMap'
-- (a 'Node''s value before its descendants').
foldl_with_Path :: Ord k => (a -> Path k -> x -> a) -> a -> TreeMap k x -> a
foldl_with_Path fct =
	foldl_with_Path_and_Node (\acc _node p x -> fct acc p x)

-- | Return the given accumulator folded by the given function
-- applied on each 'Node' holding a non-'Strict.Nothing' 'node_value'
-- (together with its 'Path' and that value),
-- from left to right through the given 'TreeMap'
-- (a 'Node' before its descendants).
foldl_with_Path_and_Node :: Ord k => (a -> Node k x -> Path k -> x -> a) -> a -> TreeMap k x -> a
foldl_with_Path_and_Node fct = walk []
	where
		-- @rev@ is the 'Path' of the current sub-'TreeMap', reversed.
		walk rev acc0 (TreeMap m) = Map.foldlWithKey step acc0 m
			where
				step acc k n =
					let acc' =
						case node_value n of
						 Strict.Nothing -> acc
						 Strict.Just x  -> fct acc n (reverse (path k rev)) x
					in walk (k:rev) acc' (node_descendants n)

-- | Return the given accumulator folded by the given function
-- applied on non-'Strict.Nothing' 'node_value's (with their 'Path')
-- from right to left through the given 'TreeMap'.
foldr_with_Path :: Ord k => (Path k -> x -> a -> a) -> a -> TreeMap k x -> a
foldr_with_Path fct =
	foldr_with_Path_and_Node (\_node p x acc -> fct p x acc)

-- | Return the given accumulator folded by the given function
-- applied on each 'Node' holding a non-'Strict.Nothing' 'node_value'
-- (together with its 'Path' and that value),
-- from right to left through the given 'TreeMap':
-- a 'Node' is applied outside the fold of its descendants.
foldr_with_Path_and_Node :: Ord k => (Node k x -> Path k -> x -> a -> a) -> a -> TreeMap k x -> a
foldr_with_Path_and_Node fct = walk []
	where
		-- @rev@ is the 'Path' of the current sub-'TreeMap', reversed.
		walk rev acc0 (TreeMap m) = Map.foldrWithKey step acc0 m
			where
				step k n acc =
					let below = walk (k:rev) acc (node_descendants n) in
					case node_value n of
					 Strict.Nothing -> below
					 Strict.Just x  -> fct n (reverse (path k rev)) x below

-- | Return the given accumulator folded by the given function
-- applied on non-'Strict.Nothing' 'node_value's
-- from left to right along the given 'Path'
-- (that is, on the values of its prefixes, shortest first),
-- stopping at the first section missing from the 'TreeMap'.
foldl_path :: Ord k => (Path k -> x -> a -> a) -> Path k -> TreeMap k x -> a -> a
foldl_path fct =
	walk [] . list
	where
		-- @rev@ is the prefix already walked, reversed.
		walk _rev [] _t acc = acc
		walk rev (k:ks) (TreeMap t) acc =
			case Map.lookup k t of
			 Nothing -> acc
			 Just n ->
				let acc' =
					case node_value n of
					 Strict.Nothing -> acc
					 Strict.Just x  -> fct (reverse (path k rev)) x acc
				in walk (k:rev) ks (node_descendants n) acc'

-- | Return the given accumulator folded by the given function
-- applied on non-'Strict.Nothing' 'node_value's
-- from right to left along the given 'Path'
-- (the value of the shortest prefix ends up outermost),
-- stopping at the first section missing from the 'TreeMap'.
foldr_path :: Ord k => (Path k -> x -> a -> a) -> Path k -> TreeMap k x -> a -> a
foldr_path fct =
	walk [] . list
	where
		-- @rev@ is the prefix already walked, reversed.
		walk _rev [] _t acc = acc
		walk rev (k:ks) (TreeMap t) acc =
			maybe acc visit (Map.lookup k t)
			where
				visit n =
					let rest = walk (k:rev) ks (node_descendants n) acc in
					Strict.maybe rest (\x -> fct (reverse (path k rev)) x rest) (node_value n)

-- * Flatten

-- | Return a 'Map' associating each 'Path'
-- leading to a non-'Strict.Nothing' 'node_value' in the given 'TreeMap'
-- with that value mapped by the given function.
flatten :: Ord k => (x -> y) -> TreeMap k x -> Map (Path k) y
flatten f = flatten_with_Path (\_path x -> f x)

-- | Like 'flatten' but with also the current 'Path' given to the mapping function.
flatten_with_Path :: Ord k => (Path k -> x -> y) -> TreeMap k x -> Map (Path k) y
flatten_with_Path f =
	go []
	where
		-- @rev@ is the 'Path' of the current sub-'TreeMap', reversed.
		go rev (TreeMap m) =
			Map.unions $
			here :
			Map.foldrWithKey
			 (\k n below -> go (k:rev) (node_descendants n) : below)
			 [] m
			where
				-- Full 'Path' of the root key @k@ of the current sub-'TreeMap'.
				full k = reverse (path k rev)
				-- Values held directly by the root 'Node's of the current sub-'TreeMap'.
				-- All these 'Path's share the same prefix, so 'full' preserves
				-- the key order and 'Map.mapKeysMonotonic' applies.
				here =
					Map.mapKeysMonotonic full $
					Map.mapMaybeWithKey
					 (\k n -> case node_value n of
						 Strict.Nothing -> Nothing
						 Strict.Just x  -> Just (f (full k) x))
					 m

-- * Filter

-- | Return the given 'TreeMap'
--   keeping only its non-'Strict.Nothing' 'node_value's
--   passing the given predicate
--   (intermediary 'Node's left without descendants are removed).
filter :: Ord k => (x -> Bool) -> TreeMap k x -> TreeMap k x
filter f =
	map_Maybe (\x -> bool Strict.Nothing (Strict.Just x) (f x))

-- | Like 'filter' but with also the current 'Path' given to the predicate.
filter_with_Path :: Ord k => (Path k -> x -> Bool) -> TreeMap k x -> TreeMap k x
filter_with_Path f =
	filter_with_Path_and_Node (\_node p x -> f p x)

-- | Like 'filter_with_Path' but with also the current 'Node' given to the predicate.
filter_with_Path_and_Node :: Ord k => (Node k x -> Path k -> x -> Bool) -> TreeMap k x -> TreeMap k x
filter_with_Path_and_Node f =
	map_Maybe_with_Path_and_Node keep
	where
		keep n p x = bool Strict.Nothing (Strict.Just x) (f n p x)

-- | Return the given 'TreeMap'
--   mapping its non-'Strict.Nothing' 'node_value's
--   and keeping only the non-'Strict.Nothing' results.
map_Maybe :: Ord k => (x -> Strict.Maybe y) -> TreeMap k x -> TreeMap k y
map_Maybe f = map_Maybe_with_Path (\_path x -> f x)

-- | Like 'map_Maybe' but with also the current 'Path' given to the mapping function.
map_Maybe_with_Path :: Ord k => (Path k -> x -> Strict.Maybe y) -> TreeMap k x -> TreeMap k y
map_Maybe_with_Path f = map_Maybe_with_Path_and_Node (\_node p x -> f p x)

-- | Like 'map_Maybe_with_Path' but with also the current 'Node' given to the mapping function.
--
-- A 'Node' whose value is mapped to 'Strict.Nothing' is kept as an intermediary
-- 'Node' if it still has descendants, and removed otherwise;
-- intermediary 'Node's left without descendants are removed as well.
map_Maybe_with_Path_and_Node :: Ord k => (Node k x -> Path k -> x -> Strict.Maybe y) -> TreeMap k x -> TreeMap k y
map_Maybe_with_Path_and_Node =
	go []
	where
		go :: Ord k
		 => [k] -> (Node k x -> Path k -> x -> Strict.Maybe y)
		 -> TreeMap k x
		 -> TreeMap k y
		go p test (TreeMap m) =
			TreeMap $
			Map.mapMaybeWithKey
			 (\k nod@Node{node_value=v, node_descendants=ns} ->
				let descendants = go (k:p) test ns in
				let value =
					case v of
					 Strict.Nothing -> Strict.Nothing
					 Strict.Just x  -> test nod (reverse $ path k p) x in
				case value of
				 Strict.Nothing | null descendants -> Nothing
				 -- Build through the 'node' smart constructor so that 'node_size'
				 -- counts this 'Node''s own value only when it is kept.
				 _ -> Just (node value descendants)
			 ) m