{-# OPTIONS -fglasgow-exts -fallow-undecidable-instances #-} 

-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Trie
-- Copyright   :  (c) Keith Wansbrough 2005, Christian Maeder 2006, Jean-Philippe Bernardy 2006
-- License     :  BSD-style
-- 
-- Maintainer  :  jeanphilippe.bernardy; google mail.
-- Stability   :  volatile
-- Portability :  unknown
--
--  This module provides a basic implementation of the Trie data type.
--
-- Note: performance is currently rather bad. See the benchmark directory. Please contribute :)
--
-----------------------------------------------------------------------------

module Data.Trie
    (
    -- * Data type
    Trie(..)
    -- * Operators
    , (!)
    -- , (\\)
    -- * Query
    , null
    -- , size
    , member
    , lookup
    , prefixLookup
    -- * Construction
    , empty
    , singleton
    -- ** Insertion
    , insert
    , insertWith
    -- ** Delete\/Update
    -- , delete
    -- , adjust
    -- , update
    , alter
    -- * Combine
    -- ** Union
    , union         
    , unionWith          
    -- , unions
    -- , unionsWith
    -- ** Difference
    , difference
    , differenceWith
    -- ** Intersection
    , intersection           
    , intersectionWith
    -- * Traversal
    -- ** Map

    -- , map
    -- ** Fold

    -- , foldr
    -- * Conversion
    , retypeKeys

    -- , elems
    -- , keys
    , fromAscList
    , fromList
    , fromListWith
    , toList
    -- * Filter 

    , filter
    -- , partition
    --, split         
    --, splitLookup   
    -- * Submap

    -- , isSubmapOf
    , isSubmapOfBy
    -- * Primitive accessors
    , upwards, downwards
    -- * Derived operations
    , takeWhile, takeWhile', fringe
    -- * Debugging
    , toTree
    ) where
           
import Control.Monad
import Data.Collections (Sequence, (|>), (><))
import Data.Maybe
import Data.Monoid
import Data.Tree
import Data.Typeable
import qualified Data.List as List
import Prelude hiding (takeWhile, null, lookup, map, foldr, filter)
import qualified Data.Collections as C
import qualified Data.Foldable as F
import qualified Data.Map.AVL as M

-- | A Trie with key elements of type @k@ (whole keys of type @s@, a
-- foldable sequence of @k@ such as @[k]@) and values of type @v@.
--
-- * 'value' is the value bound to the key that ends at this node, if any.
--
-- * 'children' maps each next key element to the corresponding subtrie.
--
-- Note that the type is not opaque: users can pattern match on it and construct any Trie value.
-- This is because there is no non-trivial invariant to preserve.
-- NOTE(review): 'null' only inspects the root node, while operations such as
-- 'alter', 'filter' and 'intersectionWith' prune empty child tries; a
-- hand-built trie with empty children may therefore be reported as non-null
-- even though it holds no bindings.
data Trie s k v = Trie { value :: !(Maybe v),
                         children :: !(M.Map k (Trie s k v))
                       } 
-- FIXME: Strictness annotations should NOT be needed.
-- The @s@ type parameter (the full key type) is phantom: it is only there to satisfy the
-- functional dependencies of the Data.Collections classes. It could perhaps be removed
-- if this module is ported to associated type synonyms.

-- Typeable instance for the three-parameter Trie type, presumably generated by the
-- INSTANCE_TYPEABLE3 macro from Typeable.h (this requires running the C preprocessor).
#include "Typeable.h"
INSTANCE_TYPEABLE3(Trie,theTc,"Data.Trie.Trie")

-- | Change the phantom key-sequence type parameter @s@ of a trie.
-- The structure and all values are copied unchanged.
retypeKeys :: Trie s1 k v -> Trie s2 k v
retypeKeys t = Trie { value = value t, children = fmap retypeKeys (children t) }

-- | @toMaybe p x@ is 'Nothing' when @x@ satisfies @p@, and @Just x@ otherwise.
-- Note the polarity: the predicate describes the values to discard.
toMaybe :: (a -> Bool) -> a -> Maybe a
toMaybe discard x
    | discard x = Nothing
    | otherwise = Just x

-- | @alter f ks t@ replaces the value stored under key @ks@ by @f@ applied to it
-- (@f Nothing@ when the key is unbound; returning 'Nothing' removes the binding).
-- Missing nodes along the key path are created empty, and child tries that end up
-- empty are pruned.
alter :: forall s k v. (C.Foldable s k, Ord k) => (Maybe v -> Maybe v) -> s -> Trie s k v -> Trie s k v
alter f s t = C.foldr step atKey s t
    where -- End of the key: update the value of the node reached.
          atKey node = node { value = f (value node) }
          -- One key element: descend into the child for @k@ (an empty node if it
          -- is absent, so that f can create a new leaf) and drop it if it is empty.
          step k descend node = node { children = C.alter (prune . descend . fromMaybe empty) k (children node) }
          prune = toMaybe null

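-- Alternative version: faster for insertions, since it builds missing nodes directly (with singleton)
-- instead of creating empty nodes and pruning them afterwards, but it requires the key to be a Sequence.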
-- alter f s (Trie v cs) = case C.front s of
--                           Nothing -> Trie (f v) cs
--                           Just (k,ks) -> Trie v (M.alter (f' ks) k cs)
--     where f' ks Nothing = fmap (singleton ks) (f Nothing)
--           f' ks (Just t) = toMaybe (alter f ks t)

-- | @adjust f ks t@ applies @f@ to the value bound to @ks@, if any.
-- An unbound key leaves the trie unchanged.
adjust :: forall s k v. (C.Foldable s k, Ord k) => (v -> v) -> s -> Trie s k v -> Trie s k v
adjust f s t = C.foldr step atKey s t
    where atKey node = node { value = fmap f (value node) }
          step k descend node = node { children = C.adjust descend k (children node) }
        
-- | Modify the 'value' field of a trie (the value bound at its root node),
-- leaving its children untouched.
value_u :: (Maybe v -> Maybe v) -> Trie s k v -> Trie s k v
value_u f p = p { value = f (value p) }

-- | Apply a function to the map of children of a trie, keeping the root value.
children_u :: (M.Map k (Trie s k v) -> M.Map k (Trie s k v)) -> Trie s k v -> Trie s k v
children_u f (Trie v cs) = Trie v (f cs)

-- | The trie with no bindings: no root value and no children.
empty :: Ord k => Trie s k v
empty = Trie Nothing C.empty

-- | Is the trie empty? Only the root node is inspected: the trie counts as
-- empty when it has no root value and no children at all.
null :: Trie s k v -> Bool
null t = isNothing (value t) && C.null (children t)

-- | @singleton ks x@ is the trie binding only the key @ks@ to @x@.
-- It is built from the leaf upwards, with one single-child node per key element.
singleton :: (Ord k, C.Foldable s k) => s -> v -> Trie s k v
singleton ks x = C.foldr wrap leaf ks
    where leaf = Trie (Just x) C.empty
          wrap k sub = Trie Nothing (C.singleton (k, sub))

-- | Left-biased union of two tries: when both bind the same key, the value
-- from the first trie is kept.
union :: Ord k => Trie s k v -> Trie s k v -> Trie s k v
union l r = Trie { value = picked, children = merged }
    where picked = value l `mplus` value r
          merged = C.unionWith union (children l) (children r)

-- | Union of two tries. A key bound in only one trie keeps its value; a key
-- bound in both gets @f x y@, where @x@ comes from the first trie.
unionWith :: Ord k => (v -> v -> v) -> Trie s k v -> Trie s k v -> Trie s k v
unionWith f l r = Trie { value = merge (value l) (value r)
                       , children = C.unionWith (unionWith f) (children l) (children r)
                       }
    where merge (Just x) (Just y) = Just (f x y)
          merge mx my = mx `mplus` my
  
-- | Intersection of two tries. Only keys bound in both tries are kept; their
-- values are combined with @f@, the value from the first trie being the first
-- argument. Child tries that become empty are pruned.
intersectionWith :: Ord k => (v -> v -> v) -> Trie s k v -> Trie s k v -> Trie s k v
intersectionWith f p1 p2 =
    Trie {
          value = lift (value p1) (value p2),
          -- recurse on shared children, then drop those left without bindings
          children = C.filter (not . null . snd) $ C.intersectionWith (intersectionWith f) (children p1) (children p2)
         }
    where lift (Just x) (Just y) = Just (f x y)
          lift _ _ = Nothing

-- | Left-biased intersection: keys bound in both tries, with the values of the first.
intersection :: Ord k => Trie s k v -> Trie s k v -> Trie s k v
intersection p1 p2 = intersectionWith (\x _ -> x) p1 p2

-- | @differenceWith f t1 t2@ keeps the bindings of @t1@ whose keys are unbound
-- in @t2@. For keys bound in both tries, @f x y@ decides: 'Nothing' removes the
-- binding, @Just z@ replaces the value. Child tries left empty are pruned.
differenceWith :: Ord k => (v -> v -> Maybe v) -> Trie s k v -> Trie s k v -> Trie s k v
differenceWith f p1 p2 =
    Trie { value = keep (value p1) (value p2)
         , children = C.differenceWith diffChild (children p1) (children p2)
         }
    where keep mx my = mx >>= \x -> maybe (Just x) (f x) my
          diffChild c1 c2 = toMaybe null (differenceWith f c1 c2)

-- | Remove from the first trie every key that is bound in the second.
difference :: Ord k => Trie s k v -> Trie s k v -> Trie s k v
difference = differenceWith (const (const Nothing))

-- | @isSubmapOfBy f t1 t2@ holds when every binding of @t1@ has a binding at
-- the same key in @t2@, and @f@ holds between the two values.
isSubmapOfBy :: Ord k => (v -> v -> Bool) -> Trie s k v -> Trie s k v -> Bool
isSubmapOfBy f p1 p2 = rootOk && C.isSubmapBy (isSubmapOfBy f) (children p1) (children p2)
    where rootOk = case (value p1, value p2) of
                     (Nothing, _)      -> True
                     (Just x, Just y)  -> f x y
                     (Just _, Nothing) -> False

-- | Look up the value bound to a key, failing in the monad @m@ (via 'fail')
-- when the key is not bound.
lookup :: forall s m k v. (C.Foldable s k, Monad m, Ord k) => s -> Trie s k v -> m v
lookup s t = case C.foldl' descend (Just t) s >>= value of
               Just x  -> return x
               Nothing -> fail "key not found in Trie"
    where -- Move one key element down; Nothing once a child is missing.
          descend acc k = do
              node <- acc
              C.lookup k (children node)

-- | @t ! ks@ returns the value bound to @ks@, via the 'C.Indexed' instance;
-- presumably a runtime error when the key is unbound.
(!) :: forall s k v. (C.Foldable s k, Ord k) => Trie s k v -> s -> v
t ! ks = t C.! ks

-- | Is a value bound to the given key?
member :: forall s k v. (C.Foldable s k, Ord k) => s -> Trie s k v -> Bool
member k t = maybe False (const True) (lookup k t)

-- | Insert a binding, replacing any value already bound to the key.
insert :: forall s k v. (C.Foldable s k, Ord k) => s -> v -> Trie s k v -> Trie s k v
insert k x t = insertWith (\new _ -> new) k x t

-- | @insertWith f ks x t@ binds @x@ to @ks@ when the key is unbound; otherwise
-- the stored value @old@ is replaced by @f x old@.
insertWith :: forall s k v. (C.Foldable s k, Ord k) => (v -> v -> v) -> s -> v -> Trie s k v -> Trie s k v
insertWith f k a c = alter (Just . maybe a (f a)) k c

-- | @prefixLookup k p@ returns a sequence of all @(k',v)@ pairs of @p@ such that
-- @k@ is a prefix of @k'@ (including @k@ itself, if it is bound).
-- The sequence is sorted by lexicographic order of keys.
--
-- The trie is first descended along @k@; if no node exists there, the result
-- is empty.
prefixLookup :: forall s k v result. (Ord k, Sequence s k, Sequence result (s,v)) => s -> Trie s k v -> result
prefixLookup ks p = maybe C.empty (enumerate ks) (descend p)
    where -- Follow the key elements of @ks@ down the trie.
          descend :: Trie s k v -> Maybe (Trie s k v)
          descend t = C.foldl' (\mt k -> mt >>= C.lookup k . children) (Just t) ks
          -- Every binding of the subtrie @t@, with keys extended from @prefix@:
          -- the node's own value first, then its children in map order.
          enumerate :: s -> Trie s k v -> result
          enumerate prefix t = getNode prefix t
                               >< C.concatMap (\(k, t') -> enumerate (prefix |> k) t') (C.toList (children t))
          getNode :: s -> Trie s k v -> result
          getNode prefix t = maybe C.empty (\v -> C.singleton (prefix, v)) (value t)

-- | An upwards accumulation on the trie: children are transformed first, then
-- @f@ is applied to the node holding the already-transformed children.
upwards :: Ord k => (Trie s k v -> Trie s k v) -> Trie s k v -> Trie s k v
upwards f t = f (children_u (fmap (upwards f)) t)

-- | A downwards accumulation on the trie: @f@ is applied to the node first,
-- then recursively to every child of the result.
downwards :: Ord k => (Trie s k v -> Trie s k v) -> Trie s k v -> Trie s k v
downwards f t = children_u (fmap (downwards f)) (f t)

-- | Return the prefix of the trie satisfying @f@: every child subtrie failing
-- @f@ is cut off together with all its descendants. The root itself is never
-- tested.
takeWhile :: Ord k => (Trie s k v -> Bool) -> Trie s k v -> Trie s k v
takeWhile keep = downwards prune
    where prune node = children_u (C.filter (\(_, sub) -> keep sub)) node

-- | Return the prefix of the trie satisfying @f@ on all values present:
-- subtries whose root value fails @f@ are cut off; value-less nodes are kept.
takeWhile' :: Ord k => (v -> Bool) -> Trie s k v -> Trie s k v
takeWhile' p = takeWhile (\node -> maybe True p (value node))

-- | Return the fringe of the trie: values are kept only on nodes without
-- children; every inner node loses its value.
fringe :: Ord k => Trie s k v -> Trie s k v
fringe = upwards dropInner
    where dropInner node
              | C.null (children node) = node
              | otherwise              = value_u (const Nothing) node


-- | All bindings of the trie as (key, value) pairs, in the order of the trie's
-- 'C.Foldable' instance (a node's value before those of its children).
toList :: (Sequence s k, Ord k) => Trie s k v -> [(s,v)]
toList t = C.toList t

-- TODO: put those in the class instances.

-- | Build a trie from a list of bindings sorted in ascending key order.
-- If several leading bindings have the empty key, the first one is kept at the root.
fromAscList :: forall s k v. (Sequence s k, Ord k) => [(s,v)] -> Trie s k v
fromAscList l = Trie rootValue (M.fromAscList (List.map branch groups))
    where -- span only collects the leading empty-key bindings (first in sorted input).
          (atRoot, rest) = span (C.null . fst) l
          rootValue = case atRoot of
                        []           -> Nothing
                        ((_, x) : _) -> Just x
          -- One group per distinct first key element.
          groups = List.groupBy (testing (C.head . fst)) rest
          branch [] = error "Data.Trie.fromAscList: groupBy produced an empty group"
          branch grp@((ks, _) : _) = (C.head ks, fromAscList [ (C.tail k, x) | (k, x) <- grp ])

-- | @testing f@ relates @x@ and @y@ exactly when @f x == f y@;
-- handy as a grouping predicate for 'List.groupBy'.
testing :: Eq b => (a -> b) -> (a -> a -> Bool)
testing f = \x y -> f x == f y

-- | Build a trie from a list of bindings. Duplicate keys are resolved by
-- 'fromListWith' with a function that keeps its first argument.
fromList :: forall s k v. (Sequence s k, Ord k) => [(s,v)] -> Trie s k v
fromList l = fromListWith const l

-- | Build a trie from a list of bindings; the values of a repeated key are
-- combined with @f@ (folded from the right, in input order).
fromListWith :: forall s k v. (Sequence s k, Ord k) => (v -> v -> v) -> [(s,v)] -> Trie s k v
fromListWith f l = Trie (combine (List.map snd here)) (fmap (fromListWith f) below)
    where -- Bindings whose key ends here, and the rest.
          (here, deeper) = List.partition (C.null . fst) l
          -- Remaining bindings grouped by first key element, keeping input order.
          below = M.fromListWith (flip (++)) [ (C.head ks, [(C.tail ks, x)]) | (ks, x) <- deeper ]
          combine [] = Nothing
          combine xs = Just (List.foldr1 f xs)



-- | Keep only the bindings whose full key and value satisfy the predicate;
-- child tries left empty are dropped.
filterWithKey :: forall k v s. (Ord k, Sequence s k) => (s -> v -> Bool) -> Trie s k v -> Trie s k v
filterWithKey f t = go C.empty t
    where go :: s -> Trie s k v -> Trie s k v
          go ks node = Trie kept (C.filter (not . null . snd) subtries)
              where kept = case value node of
                             Just x | f ks x -> Just x
                             _               -> Nothing
                    subtries = C.mapWithKey (\k -> go (ks |> k)) (children node)

-- | Keep only the values satisfying the predicate; child tries left empty are dropped.
filter :: forall k v s. (Ord k, Sequence s k) => (v -> Bool) -> Trie s k v -> Trie s k v
filter p (Trie v cs) = Trie kept (C.filter (not . null . snd) (fmap (filter p) cs))
    where kept = v >>= \x -> if p x then Just x else Nothing

-- | Map a function over all values, also passing it the full key of each binding.
mapWithKey :: forall k v v' s. (Ord k, Sequence s k) => (s -> v -> v') -> Trie s k v -> Trie s k v'
mapWithKey f t = go C.empty t
    where go :: s -> Trie s k v -> Trie s k v'
          go prefix node = Trie (fmap (f prefix) (value node))
                                (C.mapWithKey' (\k -> go (prefix |> k)) (children node))

-- | Folds over the values only: the root value, then those of the children.
instance F.Foldable (Trie s k) where
    foldMap f t = here `mappend` below
        where here  = F.foldMap f (value t)
              below = F.foldMap (F.foldMap f) (children t)

-- | Folds over the (full key, value) bindings; a node's own binding comes
-- before those of its children.
instance Sequence s k => C.Foldable (Trie s k v) (s,v) where
    null = null
    foldMap f t = go C.empty t
        where go prefix node = C.foldMap f (fmap ((,) prefix) (value node))
                               `mappend`
                               C.foldMap (\(k, sub) -> go (prefix |> k) sub) (children node)

instance (Ord k, Sequence s k) => C.Unfoldable (Trie s k v) (s,v) where
    -- Insert one binding; a clash keeps the first argument of the combining function.
    insert kv t = C.insertWith (\new _ -> new) (fst kv) (snd kv) t
    empty = empty
    -- Bulk insertion: build directly from the elements when the target is empty,
    -- otherwise insert them one at a time.
    insertMany l c = if null c then fromList (C.toList l) else C.foldr C.insert c l
    insertManySorted l c = if null c then fromAscList (C.toList l) else C.foldr C.insert c l
    {-# SPECIALIZE instance C.Unfoldable (Trie String Char v) (String,v) #-}    

instance (Ord k, Sequence s k) => C.Collection (Trie s k v) (s,v) where
    -- Keep exactly the bindings whose (key, value) pair satisfies the predicate.
    filter p = filterWithKey (\ks x -> p (ks, x))

-- | Map interface: every method delegates to the corresponding top-level function.
instance (Ord k,Sequence s k) => C.Map (Trie s k v) s v where
    lookup = lookup
    alter = alter
    unionWith = unionWith
    intersectionWith = intersectionWith
    differenceWith = differenceWith
    isSubmapBy = isSubmapOfBy
    mapWithKey = mapWithKey
    fromFoldableWith f xs = fromListWith f (C.toList xs)
    {-# SPECIALIZE instance C.Map (Trie String Char v) String v #-}

instance (Ord k,C.Foldable s k) => C.Indexed (Trie s k v) s v where
    -- Partial: an unbound key is a runtime error. Report it explicitly instead
    -- of the uninformative failure of fromJust.
    index k t = fromMaybe (error "Data.Trie.index: key not found in Trie") (lookup k t)
    adjust = adjust
    inDomain = member

-- | Shows a trie with list keys as @fromList@ applied to its list of bindings.
instance (Show k, Show v) => Show (Trie [k] k v) where
    show t = "fromList " ++ show bindings
        where bindings = C.toList t :: [([k],v)]

-- | Tries form a monoid under left-biased 'union', with 'empty' as identity.
instance Ord k => Monoid (Trie s k v) where
    mappend a b = union a b
    mempty = empty

-- | Structural equality: equal root values and equal children maps.
instance (Eq k, Eq v) => Eq (Trie s k v) where
    t1 == t2 = value t1 == value t2 && children t1 == children t2

-- | Convert to a rose tree for debugging. Each node is labelled with the key
-- element leading to it (the given label for the root) and its optional value.
toTree :: k -> Trie s k v -> Tree (k,Maybe v)
toTree label (Trie v cs) = Node (label, v) (C.foldr addChild [] cs)
    where addChild (k, sub) rest = toTree k sub : rest


-- foldWithKey :: Ord k => ([k] -> a -> b -> b) -> b -> Map k a -> b
-- foldWithKey f k t = f' [] k t
--     where f' :: [k] -> b -> Map k a -> b
--           f' ks t = Trie (do {x <- value t;
--                               if f ks x then return x else Nothing}) 
--                     (C.mapWithKey (\k -> f' (ks++[k])) (children t))
--                     -- C.filter (not . null) . 
--