-- | Generalized tries. \"Normal\" tries encode finite maps from lists to arbitrary values, where the
-- common prefixes are shared. Here we do the same for trees, generically.
--
-- See also
--
-- * Connelly, Morris: A generalization of the trie data structure
--
-- * Ralf Hinze: Generalizing Generalized Tries
--
-- This module should be imported qualified.
--

{-# LANGUAGE CPP #-}
module Data.Generics.Fixplate.Trie 
  ( Trie 
    -- * Construction \/ deconstruction
  , empty , singleton
  , fromList , toList
  , bag , universeBag
    -- * Lookup
  , lookup 
    -- * Insertion \/ deletion
  , insert , insertWith
  , delete , update
    -- * Set operations
  , intersection , intersectionWith
  , union        , unionWith
  , difference   , differenceWith
#ifdef WITH_QUICKCHECK
    -- * Tests
  , runtests_Trie
#endif
  ) 
  where

---------------------------------------------------------------------------------

import Prelude hiding ( lookup )

import Data.Generics.Fixplate.Base
import Data.Generics.Fixplate.Open hiding ( toList )
import Data.Generics.Fixplate.Traversals ( universe )

import qualified Data.Foldable as Foldable
import Data.Foldable hiding ( toList )

import Data.Traversable as Traversable

import qualified Data.Map as Map ; import Data.Map (Map)

#ifdef WITH_QUICKCHECK
import Test.QuickCheck
import Data.Generics.Fixplate.Test.Tools
import Data.Generics.Fixplate.Misc
import Data.List ( sort , group , nubBy , nub , (\\) , foldl' )
import Control.Applicative ( (<$>) )
import Debug.Trace
#endif

---------------------------------------------------------------------------------

-- | Creates a trie-multiset from a list of trees: every distinct tree is
-- mapped to the number of times it occurs in the input list.
--
-- The fold forces the accumulated trie at every step, so no chain of
-- pending insertions is built up for long input lists.
bag :: (Functor f, Foldable f, OrdF f) => [Mu f] -> Trie f Int
bag ts = Foldable.foldl' countTree emptyTrie ts where
  -- the first occurrence stores 1, every further one adds 1 to the stored count
  countTree trie tree = trieInsertWith id (+) tree 1 trie

-- | The 'bag' of the list of trees returned by 'universe', that is,
--
-- > universeBag == bag . universe
--
-- TODO: more efficient implementation?
universeBag :: (Functor f, Foldable f, OrdF f) => Mu f -> Trie f Int
universeBag tree = bag (universe tree)

---------------------------------------------------------------------------------

-- | The trie containing no keys at all.
empty :: (Functor f, Foldable f, OrdF f) => Trie f a 
empty = Trie Map.empty

-- | A trie holding exactly one key (a tree) together with its value.
singleton :: (Functor f, Foldable f, OrdF f) => Mu f -> a -> Trie f a 
singleton tree value = trieSingleton tree value

-- | Looks up the value stored under a tree; 'Nothing' if the tree is not a key.
lookup :: (Functor f, Foldable f, OrdF f) => Mu f -> Trie f a -> Maybe a
lookup tree trie = trieLookup tree trie

-- | Inserts a key-value pair; if the key is already present, the new value
-- replaces the old one.
insert :: (Functor f, Foldable f, OrdF f) => Mu f -> a -> Trie f a -> Trie f a
insert tree value trie = trieInsertWith id (\new _ -> new) tree value trie

-- | General insertion. With @insertWith f g key x@: if the key is not yet
-- present then @f x@ is stored, otherwise the stored value @y@ is replaced
-- by @g x y@.
insertWith :: (Functor f, Foldable f, OrdF f) => (a -> b) -> (a -> b -> b) -> Mu f -> a -> Trie f b -> Trie f b
insertWith f g tree value trie = trieInsertWith f g tree value trie

-- | Applies the function to the value stored under the given key (if any):
-- a 'Just' result replaces the value, 'Nothing' removes it. Tries which do
-- not contain the key are left unchanged.
update :: (Functor f, Foldable f, OrdF f) => (a -> Maybe a) -> Mu f -> Trie f a -> Trie f a
update f tree trie = trieUpdate f tree trie

-- | Removes a key (and its value) from the trie, if present.
delete :: (Functor f, Foldable f, OrdF f) => Mu f -> Trie f a -> Trie f a
delete tree trie = trieUpdate (\_ -> Nothing) tree trie

-- | Builds a trie from a list of key-value pairs by inserting them one after
-- the other; if a key occurs more than once, the last occurrence wins.
--
-- The fold forces the accumulated trie at every step, so no chain of
-- pending insertions is built up for long input lists.
--
-- TODO: more efficient implementation?
fromList :: (Traversable f, OrdF f) => [(Mu f, a)] -> Trie f a
fromList ts = Foldable.foldl' insertPair emptyTrie ts where
  insertPair trie (tree,value) = trieInsertWith id const tree value trie 

-- | Lists all the key-value pairs stored in the trie.
toList :: (Traversable f, OrdF f) => Trie f a -> [(Mu f, a)] 
toList trie = trieToList trie

-- | Keeps the keys present in both tries, with the values of the first one.
intersection :: (Functor f, Foldable f, OrdF f) => Trie f a -> Trie f b -> Trie f a
intersection left right = trieIntersectionWith (\x _ -> x) left right

-- | Keeps the keys present in both tries, combining the two values with the
-- given function.
intersectionWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> c) -> Trie f a -> Trie f b -> Trie f c
intersectionWith f left right = trieIntersectionWith f left right

-- | Left-biased union: for keys present in both tries, the value of the first
-- trie is kept.
--
-- > union == unionWith const
--
union :: (Functor f, Foldable f, OrdF f) => Trie f a -> Trie f a -> Trie f a
union left right = trieUnionWith (\x _ -> x) left right

-- | Union of two tries; values of keys present in both are combined with the
-- given function (value from the first trie as first argument).
unionWith :: (Functor f, Foldable f, OrdF f) => (a -> a -> a) -> Trie f a -> Trie f a -> Trie f a
unionWith f left right = trieUnionWith f left right

-- | Removes from the first trie every key which is also present in the second.
difference :: (Functor f, Foldable f, OrdF f) => Trie f a -> Trie f b -> Trie f a
difference left right = trieDifferenceWith discard left right where
  discard _ _ = Nothing

-- | Generalized difference: keys only in the first trie are kept as they are;
-- for keys present in both, the function decides whether the entry is dropped
-- ('Nothing') or kept with a new value ('Just').
differenceWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> Maybe a) -> Trie f a -> Trie f b -> Trie f a
differenceWith f left right = trieDifferenceWith f left right

---------------------------------------------------------------------------------

-- | 'Trie' is an efficient(?) implementation of finite maps from @(Mu f)@ to an arbitrary type @v@.
--
-- The top-level 'Map' is keyed by the shape of the root node (the node with
-- all of its children replaced by 'Hole', see 'HoleF'); the associated
-- 'Chain' then dispatches on the children of the root, one after the other.
newtype Trie f v = Trie { unTrie :: Map (HoleF f) (Chain f v) }

-- | What is stored under a node shape: a sequence of nested tries, one per
-- child of the node (in 'Foldable' order), terminated by the stored value.
data Chain f v
  = Value v                      -- ^ no more children to dispatch on: the stored value
  | Chain (Trie f (Chain f v))   -- ^ dispatch on the next child, then continue with the rest of the chain

-- this is only to be able to define an Ord instance
--
-- A single node whose children are all replaced by 'Hole', i.e. only the
-- \"shape\" of the node; this is the key type of the maps inside a 'Trie'.
newtype HoleF f = HoleF { unHoleF :: f Hole }

-- Equality and ordering of shapes are delegated to 'equalF' and 'compareF'.
instance EqF  f => Eq  (HoleF f) where (==)    (HoleF x) (HoleF y) = equalF   x y
instance OrdF f => Ord (HoleF f) where compare (HoleF x) (HoleF y) = compareF x y

-- | The trie without any keys (an empty map of node shapes).
emptyTrie :: (Functor f, Foldable f, OrdF f) => Trie f v 
emptyTrie = Trie Map.empty

---------------------------------------------------------------------------------

-- | Looks up a tree: first the shape of the root node is looked up in the
-- top-level map, then the children are followed along the chain found there.
trieLookup :: (Functor f, Foldable f, OrdF f) => Mu f -> Trie f v -> Maybe v
trieLookup (Fix node) (Trie table) = Map.lookup shape table >>= chainLookup children
  where
    shape    = HoleF (fmap (const Hole) node)
    children = Foldable.toList node

-- | Follows a chain along a list of children. The length of the list and the
-- depth of the chain are expected to agree; a mismatch is reported via 'error'.
chainLookup :: (Functor f, Foldable f, OrdF f) => [Mu f] -> Chain f v -> Maybe v
chainLookup []     (Value x)   = Just x
chainLookup []     (Chain _)   = error "chainLookup: shouldn't happen #1"
chainLookup (k:ks) (Chain sub) = trieLookup k sub >>= chainLookup ks
chainLookup (_:_)  (Value _)   = error "chainLookup: shouldn't happen #2"

---------------------------------------------------------------------------------

-- | The chain leading through the given children (each one a singleton trie)
-- to the given value.
chainSingleton :: (Functor f, Foldable f, OrdF f) => [Mu f] -> a -> Chain f a
chainSingleton trees x = Prelude.foldr link (Value x) trees where
  link t rest = Chain (trieSingleton t rest)

-- | The trie with a single key: the shape of the root node maps to the
-- singleton chain of its children.
trieSingleton :: (Functor f, Foldable f, OrdF f) => Mu f -> a -> Trie f a 
trieSingleton (Fix node) x = Trie (Map.singleton shape chain) where
  shape = HoleF (fmap (const Hole) node)
  chain = chainSingleton (Foldable.toList node) x

---------------------------------------------------------------------------------

-- | Insertion into a 'Map' where the inserted thing and the stored values may
-- have different types: a missing key gets @create new@, an existing value
-- @cur@ is replaced by @combine new cur@. The inserted argument, the old
-- value and the resulting value are all forced (to weak head normal form).
mapInsertWith :: Ord k => (a -> v) -> (a -> v -> v) -> k -> a -> Map k v ->  Map k v
mapInsertWith create combine key new = new `seq` Map.alter step key where
  step old = case old of
    Nothing  -> Just $! create new
    Just cur -> cur `seq` (Just $! combine new cur)

-- | General insertion into a trie. If the shape of the root node is new, a
-- singleton chain (holding @uf value@) is created for its children; otherwise
-- the insertion continues in the existing chain.
trieInsertWith :: (Functor f, Foldable f, OrdF f) => (a -> b) -> (a -> b -> b) -> Mu f -> a -> Trie f b -> Trie f b
trieInsertWith uf ug (Fix node) value (Trie table) = Trie (mapInsertWith fresh merge shape value table) where
  shape    = HoleF (fmap (const Hole) node)
  children = Foldable.toList node
  fresh v       = chainSingleton children (uf v)
  merge v chain = chainInsertWith uf ug children v chain

-- | General insertion into a chain, along the given list of children. At the
-- end of the list the stored value @y@ is replaced by @ug x y@; a mismatch
-- between the length of the list and the depth of the chain is reported via
-- 'error'.
chainInsertWith :: (Functor f, Foldable f, OrdF f) => (a -> b) -> (a -> b -> b) -> [Mu f] -> a -> Chain f b -> Chain f b
chainInsertWith _  ug []       x (Value y)    = Value (ug x y)
chainInsertWith _  _  []       _ (Chain _)    = error "chainInsertWith: shouldn't happen #1" 
chainInsertWith uf ug (t:rest) x (Chain trie) = Chain (trieInsertWith fresh merge t x trie) where
  -- the values of the sub-trie are themselves chains (for the remaining children)
  fresh z   = chainSingleton rest (uf z)
  merge z c = chainInsertWith uf ug rest z c
chainInsertWith _  _  (_:_)    _ (Value _)    = error "chainInsertWith: shouldn't happen #2" 

---------------------------------------------------------------------------------

-- | Updates (or, when the function returns 'Nothing', removes) the value
-- stored under the given tree. Nothing happens if the shape of the root node
-- is not present in the top-level map.
trieUpdate :: (Functor f, Foldable f, OrdF f) => (a -> Maybe a) -> Mu f -> Trie f a -> Trie f a
trieUpdate user (Fix node) (Trie table) = Trie (Map.update (chainUpdate user children) shape table) where
  children = Foldable.toList node
  shape    = HoleF (fmap (const Hole) node)

-- | Updates the value at the end of a chain, following the given list of
-- children. The result is 'Nothing' if nothing is left of the chain: either
-- the user function removed the value, or the sub-trie below this point
-- became empty because of the update. Returning 'Nothing' in the latter case
-- lets the caller drop the now useless entry instead of keeping an empty
-- sub-trie around after deletions.
--
-- A mismatch between the length of the list and the depth of the chain is
-- reported via 'error'.
chainUpdate :: (Functor f, Foldable f, OrdF f) => (a -> Maybe a) -> [Mu f] -> Chain f a -> Maybe (Chain f a)
chainUpdate user = go where
  go trees chain = case trees of
    [] -> case chain of 
      Value x -> fmap Value (user x)
      Chain _ -> error "chainUpdate: shouldn't happen #1" 
    (t:ts) -> case chain of
      -- the values of the sub-trie are chains for the remaining children
      Chain trie -> prune (trieUpdate (go ts) t trie)
      Value _    -> error "chainUpdate: shouldn't happen #2" 
  -- an empty sub-trie carries no information, so it is not kept
  prune trie
    | Map.null (unTrie trie) = Nothing
    | otherwise              = Just (Chain trie)

---------------------------------------------------------------------------------

-- | Lists the key-value pairs of a trie. Each key is rebuilt (using 'builder')
-- from a stored node shape and a list of children read off the chain.
trieToList :: (Traversable f, OrdF f) => Trie f a -> [(Mu f, a)] 
trieToList (Trie table) = do
  (HoleF shape, chain) <- Map.toList table
  (children, val)      <- chainToList chain
  return (Fix (builder shape children), val)

-- | Lists all the paths through a chain: the list of children along the path,
-- together with the value found at its end.
chainToList :: (Traversable f, OrdF f) => Chain f a -> [([Mu f], a)] 
chainToList (Value x)    = [([],x)]
chainToList (Chain trie) = do
  (t , rest) <- trieToList trie
  (ts, val ) <- chainToList rest
  return (t:ts, val)

---------------------------------------------------------------------------------

-- | Intersection of two chains stored under the same node shape. The two
-- chains are expected to have the same depth; a mismatch is reported via
-- 'error'.
chainIntersectionWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> c) -> Chain f a -> Chain f b -> Chain f c
chainIntersectionWith f chain1 chain2 = case (chain1, chain2) of
  (Value x , Value y ) -> Value (f x y)
  (Chain t1, Chain t2) -> Chain (trieIntersectionWith (chainIntersectionWith f) t1 t2)
  _                    -> error "chainIntersectionWith: shouldn't happen"

-- | Intersection of two tries: the top-level maps are intersected, and the
-- chains stored under common node shapes are intersected recursively.
trieIntersectionWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> c) -> Trie f a -> Trie f b -> Trie f c
trieIntersectionWith f (Trie left) (Trie right) = 
  Trie (Map.intersectionWith (chainIntersectionWith f) left right)

---------------------------------------------------------------------------------

-- | Union of two chains stored under the same node shape. The two chains are
-- expected to have the same depth; a mismatch is reported via 'error'.
chainUnionWith :: (Functor f, Foldable f, OrdF f) => (a -> a -> a) -> Chain f a -> Chain f a -> Chain f a
chainUnionWith f chain1 chain2 = case (chain1, chain2) of
  (Value x , Value y ) -> Value (f x y)
  (Chain t1, Chain t2) -> Chain (trieUnionWith (chainUnionWith f) t1 t2)
  _                    -> error "chainUnionWith: shouldn't happen"

-- | Union of two tries: the top-level maps are united, and the chains stored
-- under common node shapes are merged recursively.
trieUnionWith :: (Functor f, Foldable f, OrdF f) => (a -> a -> a) -> Trie f a -> Trie f a -> Trie f a
trieUnionWith f (Trie left) (Trie right) = 
  Trie (Map.unionWith (chainUnionWith f) left right)

---------------------------------------------------------------------------------

-- | Difference of two chains stored under the same node shape. The result is
-- 'Nothing' if nothing is left of the first chain: either the user function
-- dropped the value, or the difference of the sub-tries is empty. Returning
-- 'Nothing' in the latter case lets the caller drop the entry instead of
-- keeping an empty sub-trie around.
--
-- The two chains are expected to have the same depth; a mismatch is reported
-- via 'error'.
chainDifferenceWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> Maybe a) -> Chain f a -> Chain f b -> Maybe (Chain f a)
chainDifferenceWith f (Value x ) (Value y ) = fmap Value (f x y)
chainDifferenceWith f (Chain t1) (Chain t2) 
  | Map.null (unTrie rest) = Nothing      -- an empty sub-trie carries no information
  | otherwise              = Just (Chain rest)
  where
    rest = trieDifferenceWith (chainDifferenceWith f) t1 t2
chainDifferenceWith _ _ _ = error "chainDifferenceWith: shouldn't happen"

-- | Difference of two tries: node shapes only present in the first trie are
-- kept unchanged, for common node shapes the difference of the chains is taken.
trieDifferenceWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> Maybe a) -> Trie f a -> Trie f b -> Trie f a
trieDifferenceWith f (Trie left) (Trie right) = 
  Trie (Map.differenceWith (chainDifferenceWith f) left right)

---------------------------------------------------------------------------------
-- Tests

#ifdef WITH_QUICKCHECK

-- | Runs all the QuickCheck properties of this module, one after the other.
runtests_Trie :: IO ()
runtests_Trie = Prelude.sequence_
  [ -- set operations
    quickCheck prop_difference
  , quickCheck prop_differenceWith
  , quickCheck prop_union
  , quickCheck prop_intersection
    -- construction, insertion, deletion, lookup
  , quickCheck prop_unibag_naive
  , quickCheck prop_fromList_naive
  , quickCheck prop_bag
  , quickCheck prop_bag_b
  , quickCheck prop_fromList_toList
  , quickCheck prop_multiSetToList_b
  , quickCheck prop_insert
  , quickCheck prop_delete
  , quickCheck prop_update
  , quickCheck prop_insert_delete
  , quickCheck prop_delete_insert
  , quickCheck prop_lookup
  , quickCheck prop_lookup_notfound
  , quickCheck prop_singleton
  ]

--------------------

-- | How many times an element occurs in a test multiset.
newtype Multiplicity = Multiplicity { unMultiplicity :: Int } deriving (Eq,Ord,Show)

-- | Random multiplicities are chosen between 1 and 7 (inclusive).
instance Arbitrary Multiplicity where
  arbitrary = Multiplicity <$> choose (1, 7)

-- | A test multiset: a list of trees, each paired with its multiplicity.
newtype MultiSet = MultiSet { unMultiSet :: [(Multiplicity, FixT Label)] } deriving (Eq,Ord,Show)

-- | Each tree appears at most once in a generated multiset. Without this,
-- 'bag' would merge the counts of a repeated tree into a single entry, and the
-- properties which compare against 'unMultiSet' directly would fail spuriously.
instance Arbitrary MultiSet where arbitrary = (MultiSet . nubBy (equating snd)) <$> arbitrary

-- | Expands a multiset into a list; the copies of the same element are
-- adjacent to each other.
multiSetToList :: MultiSet -> [FixT Label]
multiSetToList (MultiSet mxs) = [ copy | (Multiplicity n, x) <- mxs , copy <- replicate n x ]

-- | Expands a multiset into a list, but in rounds: each round emits one copy
-- of every element which still has copies left, so that the copies of the
-- same element are spread out in the result.
multiSetToList_b :: MultiSet -> [FixT Label]
multiSetToList_b (MultiSet mxs) = rounds mxs [] where
  -- second argument: the elements (with decreased counts) kept for the next round
  rounds [] []      = []
  rounds [] pending = rounds pending []
  rounds ((Multiplicity n, x):rest) pending
    | n > 0     = x : rounds rest ( (Multiplicity (n-1), x) : pending )
    | otherwise = rounds rest pending

-- | A test finite map, given as an association list.
newtype FiniteMap = FiniteMap { unFiniteMap :: [(FixT Label,Char)] } deriving (Eq,Ord,Show)

-- | Each key appears at most once in a generated association list.
instance Arbitrary FiniteMap where 
  arbitrary = fmap (FiniteMap . nubBy (equating fst)) arbitrary

-- | The concrete trie type used in the tests: @TreeF Label@ trees as keys, characters as values.
type TrieT = Trie (TreeF Label) Char

-- | Converts a test finite map into a trie.
finiteMap :: FiniteMap -> TrieT
finiteMap = fromList . unFiniteMap

--------------------

-- | Reference implementation of 'fromList': the pairs are inserted one by
-- one, from left to right (so later pairs override earlier ones).
fromListNaive :: (Traversable f, OrdF f) => [(Mu f, a)] -> Trie f a
fromListNaive = go emptyTrie where
  go trie []                  = trie
  go trie ((tree,value):rest) = go (trieInsertWith id const tree value trie) rest

-- | Reference implementation of 'universeBag': the 'bag' of the list returned
-- by 'universe'.
universeBagNaive :: (Functor f, Foldable f, OrdF f) => Mu f -> Trie f Int
universeBagNaive tree = bag (universe tree)

-- | The 'Map'-based counterpart of 'bag': maps every distinct element of the
-- list to its number of occurrences.
mapBag :: Ord a => [a] -> Map a Int
mapBag xs = Map.fromListWith (+) [ (x,1) | x <- xs ]

--------------------
    
-- | 'universeBag' agrees with its reference implementation.
prop_unibag_naive :: FixT Label -> Bool
prop_unibag_naive tree = fast == naive where
  fast  = toList (universeBag      tree)
  naive = toList (universeBagNaive tree)

-- | 'fromList' agrees with its reference implementation.
prop_fromList_naive :: FiniteMap -> Bool
prop_fromList_naive (FiniteMap list) = fast == naive where
  fast  = toList (fromList      list)
  naive = toList (fromListNaive list)

-- | The 'bag' of an expanded multiset gives back the multiplicities.
prop_bag :: MultiSet -> Bool
prop_bag mset = sort counted == sort expected where
  counted  = toList (bag (multiSetToList mset))
  expected = [ (x,k) | (Multiplicity k, x) <- unMultiSet mset ]

-- | Same as 'prop_bag', but with the interleaved expansion 'multiSetToList_b'.
prop_bag_b :: MultiSet -> Bool
prop_bag_b mset = sort counted == sort expected where
  counted  = toList (bag (multiSetToList_b mset))
  expected = [ (x,k) | (Multiplicity k, x) <- unMultiSet mset ]

-- | Converting an association list to a trie and back gives the same pairs
-- (up to ordering).
prop_fromList_toList :: FiniteMap -> Bool
prop_fromList_toList (FiniteMap list) = sort roundTrip == sort list where
  roundTrip = toList (fromList list)

-- | The two ways of expanding a multiset result in the same 'bag'.
prop_multiSetToList_b :: MultiSet -> Bool
prop_multiSetToList_b mset = viaBlocks == viaRounds where
  viaBlocks = toList (bag (multiSetToList   mset))
  viaRounds = toList (bag (multiSetToList_b mset))

-- | After an insertion, the trie contains the new pair and all the old pairs
-- stored under other keys.
--
-- A possible old binding of the inserted key is left out of the expected
-- list: 'insert' replaces it, so keeping it there would make the property
-- fail whenever the random key happens to be present already.
prop_insert :: FixT Label -> Char -> FiniteMap -> Bool
prop_insert key ch (FiniteMap list) = sort (toList (insert key ch trie)) == sort expected where
  trie     = fromList list
  expected = (key,ch) : [ kv | kv@(k,_) <- toList trie , k /= key ]

-- | Deleting an existing key removes exactly that key-value pair. The key is
-- picked from the list using the random index (taken modulo the length);
-- empty lists pass trivially.
prop_delete :: Int -> FiniteMap -> Bool
prop_delete i (FiniteMap list) 
  | null list = True
  | otherwise = toList (delete key trie) == toList trie \\ [(key,value)]
  where
    trie        = fromList list
    (key,value) = list !! mod i (length list)

-- | 'update' on an existing key behaves like the same update done on the
-- association list. The update function removes values below @\'A\'@ and
-- replaces every other value by the new one. The key is picked from the list
-- using the random index (taken modulo the length); empty lists pass trivially.
prop_update :: Char -> Int -> FiniteMap -> Bool
prop_update new i (FiniteMap list) = null list || (toList (update f key trie) == expected) where
  trie     = fromList list
  (key,_)  = list !! mod i (length list)
  expected = replace (toList trie)
  -- keys are unique, so there is nothing more to do after the first match
  replace [] = []
  replace (entry@(k,x):rest)
    | k /= key  = entry : replace rest
    | otherwise = case f x of 
        Nothing -> rest
        Just y  -> (k,y) : replace rest
  f old 
    | old < 'A' = Nothing 
    | otherwise = Just new

-- | Inserting a key and then deleting it again gives back the original trie,
-- provided that the key was not present before.
prop_insert_delete :: FixT Label -> Char -> FiniteMap -> Bool
prop_insert_delete key ch (FiniteMap list) = toList restored == toList trie where
  -- the key is deleted first, to make sure that it is really a new one
  trie     = delete key (fromList list)
  restored = delete key (insert key ch trie)

-- | Deleting an existing key and then inserting it back with the same value
-- gives back the original trie. The key is picked from the list using the
-- random index (taken modulo the length); empty lists pass trivially.
prop_delete_insert :: Int -> FiniteMap -> Bool
prop_delete_insert i (FiniteMap list) 
  | null list = True
  | otherwise = toList restored == toList trie
  where
    trie        = fromList list
    (key,value) = list !! mod i (length list)
    restored    = insert key value (delete key trie)

-- | Looking up an existing key finds its value. The key is picked from the
-- list using the random index (taken modulo the length); empty lists pass
-- trivially.
prop_lookup :: Int -> FiniteMap -> Bool
prop_lookup i (FiniteMap list) 
  | null list = True
  | otherwise = lookup key trie == Just value
  where
    trie        = fromList list
    (key,value) = list !! mod i (length list)

-- | Looking up a key which is not present gives 'Nothing'.
prop_lookup_notfound :: FixT Label -> FiniteMap -> Bool
prop_lookup_notfound key (FiniteMap list) = 
  case lookup key trie of
    Nothing -> True
    Just _  -> False
  where
    -- the key is deleted first, to make sure that it is really absent
    trie = delete key (fromList list)

-- | A singleton trie contains exactly the given key-value pair.
prop_singleton :: FixT Label -> Char -> Bool
prop_singleton tree ch = 
  case toList (singleton tree ch) of
    [(t,c)] -> t == tree && c == ch
    _       -> False

-- | 'intersectionWith' on bags agrees with 'Map.intersectionWith' on the
-- corresponding maps. The two bags are built from the first two thirds and
-- the last two thirds of the same list of trees, so that they overlap.
prop_intersection :: MultiSet -> Bool
prop_intersection mset = viaTrie == viaMap where

  trees = multiSetToList_b mset
  total = length trees
  xs = take (div (2*total) 3) trees
  ys = drop (div total     3) trees

  viaTrie = sort (    toList (    intersectionWith (+) (   bag xs) (   bag ys)))
  viaMap  = sort (Map.toList (Map.intersectionWith (+) (mapBag xs) (mapBag ys)))

-- | 'unionWith' on bags agrees with 'Map.unionWith' on the corresponding
-- maps. The two bags are built from the first two thirds and the last two
-- thirds of the same list of trees, so that they overlap.
prop_union :: MultiSet -> Bool
prop_union mset = viaTrie == viaMap where

  trees = multiSetToList_b mset
  total = length trees
  xs = take (div (2*total) 3) trees
  ys = drop (div total     3) trees

  viaTrie = sort (    toList (    unionWith (+) (   bag xs) (   bag ys)))
  viaMap  = sort (Map.toList (Map.unionWith (+) (mapBag xs) (mapBag ys)))

-- | 'difference' on bags agrees with 'Map.difference' on the corresponding
-- maps. The two bags are built from the first two thirds and the last two
-- thirds of the same list of trees, so that they overlap.
prop_difference :: MultiSet -> Bool
prop_difference mset = viaTrie == viaMap where

  trees = multiSetToList_b mset
  total = length trees
  xs = take (div (2*total) 3) trees
  ys = drop (div total     3) trees

  viaTrie = sort (    toList (    difference (   bag xs) (   bag ys)))
  viaMap  = sort (Map.toList (Map.difference (mapBag xs) (mapBag ys)))

-- | 'differenceWith' on bags agrees with 'Map.differenceWith' on the
-- corresponding maps. The two bags are built from the first two thirds and
-- the last two thirds of the same list of trees, so that they overlap.
prop_differenceWith :: MultiSet -> Bool
prop_differenceWith mset = viaTrie == viaMap where

  trees = multiSetToList_b mset
  total = length trees
  xs = take (div (2*total) 3) trees
  ys = drop (div total     3) trees

  -- a common key survives (with an increased count) only if its count on
  -- the right-hand side is at most 2
  f x y 
    | y <= 2    = Just (x+1) 
    | otherwise = Nothing

  viaTrie = sort (    toList (    differenceWith f (   bag xs) (   bag ys)))
  viaMap  = sort (Map.toList (Map.differenceWith f (mapBag xs) (mapBag ys)))

#endif

---------------------------------------------------------------------------------