{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
module Data.ListTrie.Base
   ( Trie(..), CMap
   , null, size, size', member, notMember, lookup, lookupWithDefault
   , isSubmapOfBy, isProperSubmapOfBy
   , empty, singleton
   , insert, insert', insertWith, insertWith'
   , delete, adjust, adjust', updateLookup, alter, alter'
   -- This next section maybe should be in Map only
   , union, unions, isSubmapOf, isProperSubmapOf, union', unions', difference, intersection
   , intersection', filter, partition, mapMaybe, mapMaybeWithKey, mapEither
   -- end Map only
   , unionWith, unionWithKey, unionWith', unionWithKey'
   , unionsWith, unionsWithKey, unionsWith', unionsWithKey'
   , differenceWith, differenceWithKey
   , intersectionWith,  intersectionWithKey
   , intersectionWith', intersectionWithKey'
   , filterWithKey, partitionWithKey
   , split, splitLookup
   , mapKeysWith, mapInKeysWith, mapInKeysWith'
   , foldrWithKey,  foldrAscWithKey,  foldrDescWithKey
   , foldlWithKey,  foldlAscWithKey,  foldlDescWithKey
   , foldlWithKey', foldlAscWithKey', foldlDescWithKey'
   , foldTrie
   , toList, toAscList, toDescList
   , fromList, fromListWith, fromListWith', fromListWithKey, fromListWithKey'
   , findMin, findMax, deleteMin, deleteMax, minView, maxView
   , findPredecessor, findSuccessor
   , lookupPrefix, addPrefix, deletePrefix, deleteSuffixes
   , splitPrefix, children, children1
   , showTrieWith
   ) where

import           Control.Applicative (Applicative(..), (<$>))
import           Control.Arrow       ((***), first)
import qualified Data.DList as DL
import           Data.DList          (DList)
import           Data.Either         (partitionEithers)
import           Data.Foldable       (foldr, foldl')
import qualified Data.List as List
import           Data.Maybe          (fromJust)
import qualified Data.Maybe as Maybe
import           Prelude hiding      (lookup, filter, foldr, null)
import qualified Prelude

import qualified Data.ListTrie.Base.Map.Internal as Map
import Data.ListTrie.Base.Classes
   ( Boolable(..)
   , Unwrappable(..)
   , Unionable(..), Differentiable(..), Intersectable(..)
   , Alt(..)
   , fmap', (<$!>)
   )
import Data.ListTrie.Base.Map (Map, OrdMap)
import Data.ListTrie.Util     ((.:), both)

-- | The child map of a trie node: it maps each next key element to the
-- subtrie reached through that element.
type CMap trie k v = TMap trie k (trie k v)

-- | Class of the list-keyed trie types operated on by this module.
--
-- A node is made of a value slot of type @'St' trie a@ and a child map of
-- type @'CMap' trie k a@. 'mkTrie' and 'tParts' convert between a node and
-- that pair.
class (Functor (St trie), Unwrappable (St trie), Map (TMap trie) k) => Trie trie k where
   -- | Container for the (possibly absent) value stored at a node.
   type St trie :: * -> *
   -- | Map type used for a node's children, keyed by the next key element.
   type TMap trie :: * -> * -> *

   -- | Builds a node from its value slot and its child map.
   mkTrie :: (St trie) a -> CMap trie k a -> trie k a
   -- | Splits a node into its value slot and its child map.
   tParts :: trie k a -> (St trie a, CMap trie k a)

-- | Rewrites a trie bottom-up. Every child of a node is rewritten with
-- @foldTrie f@. Then @f@ receives the node's value slot and the rewritten
-- child map, and its result pair becomes the value slot and child map of
-- the new node.
foldTrie :: (Trie trie k, Map (TMap trie) k)
         => ((St trie) a -> CMap trie k a -> ((St trie) a, CMap trie k a))
         -> trie k a -> trie k a
foldTrie f tr =
   let (val, children) = tParts tr
       children'       = Map.map (foldTrie f) children
    in uncurry mkTrie (f val children')

-- | 'True' iff the given value slot is non-empty according to its
-- 'Boolable' instance.
hasValue :: Boolable b => b -> Bool
hasValue slot = toBool slot

-- | The negation of 'hasValue'.
noValue :: Boolable b => b -> Bool
noValue slot = not (hasValue slot)

-- | The value slot of a node.
tVal :: Trie trie k => trie k a -> St trie a
tVal tr = let (v, _) = tParts tr in v

-- | The child map of a node.
tMap :: Trie trie k => trie k a -> CMap trie k a
tMap tr = let (_, m) = tParts tr in m

-- | Rebuilds a node, transforming only its value slot. The first argument is
-- the application operator (e.g. '$' or '$!') used to hand the new value
-- to 'mkTrie'.
mapVal :: Trie trie k
       => (forall x y. (x -> y) -> x -> y)
       -> trie k a
       -> ((St trie) a -> (St trie) a)
       -> trie k a
mapVal ($$) tr f =
   let (val, children) = tParts tr
    in (mkTrie $$ f val) children

-- | Rebuilds a node, transforming only its child map (possibly changing the
-- key type). The value slot is handed to 'mkTrie' through the given
-- application operator.
mapMap :: (Trie trie k1, Trie trie k2)
       => (forall x y. (x -> y) -> x -> y)
       -> trie k1 a
       -> (CMap trie k1 a -> CMap trie k2 a)
       -> trie k2 a
mapMap ($$) tr f =
   let (val, children) = tParts tr
    in (mkTrie $$ val) (f children)

-- | Combines the value slots of two nodes with @f@. Each slot is passed
-- through the given application operator.
onVals :: Trie trie k
       => (forall x y. (x -> y) -> x -> y)
       -> ((St trie) a -> (St trie) b -> (St trie) c)
       -> trie k a
       -> trie k b
       -> (St trie) c
onVals ($$) f t1 t2 =
   let v1 = tVal t1
       v2 = tVal t2
    in (f $$ v1) $$ v2

-- | Combines the child maps of two nodes with @f@. Each map is passed
-- through the given application operator.
onMaps :: Trie trie k
       => (forall x y. (x -> y) -> x -> y)
       -> (CMap trie k a -> CMap trie k b -> CMap trie k c)
       -> trie k a
       -> trie k b
       -> CMap trie k c
onMaps ($$) f t1 t2 =
   let m1 = tMap t1
       m2 = tMap t2
    in (f $$ m1) $$ m2

-----------------------

-- * Construction

-- | @O(1)@. The map containing no keys at all.
empty :: (Alt (St trie) a, Trie trie k) => trie k a
empty = mkTrie altEmpty Map.empty

-- | @O(s)@. The map whose only member is the given key, bound to the given
-- value.
singleton :: (Alt (St trie) a, Trie trie k) => [k] -> a -> trie k a
singleton key val = addPrefix key leaf
 where
   leaf = mkTrie (pure val) Map.empty

-- * Modification

-- | @O(min(m,s))@. Inserts the key-value pair into the map. If the key is
-- already a member of the map, the given value replaces the old one.
--
-- > insert = insertWith const
insert :: (Alt (St trie) a, Trie trie k)
       => [k] -> a -> trie k a -> trie k a
insert = insertWith const

-- | @O(min(m,s))@. Like 'insert', but the given value is reduced to weak head
-- normal form before being placed into the map.
--
-- > insert' = insertWith' const
insert' :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
        => [k] -> a -> trie k a -> trie k a
insert' = insertWith' const

-- | @O(min(m,s))@. Inserts the key-value pair into the map. If the key is
-- already present, the stored value becomes @f givenValue oldValue@, where
-- @f@ is the given function.
insertWith :: (Alt (St trie) a, Trie trie k)
           => (a -> a -> a) -> [k] -> a -> trie k a -> trie k a
insertWith f = genericInsertWith ($) (<$>) f

-- | @O(min(m,s))@. Like 'insertWith', but the value that ends up in the map
-- (the given value or the combining function's result) is reduced to weak
-- head normal form first.
insertWith' :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
            => (a -> a -> a) -> [k] -> a -> trie k a -> trie k a
insertWith' f key val tr =
   val `seq` genericInsertWith ($!) (<$!>) f key val tr

-- Shared worker for 'insertWith' and 'insertWith''. @($$)@ is the application
-- operator passed on to 'mapVal' and 'mapMap'. @(<$$>)@ maps the partially
-- applied combining function over an existing value slot.
genericInsertWith :: (Alt (St trie) a, Trie trie k)
                  => (forall x y. (x -> y) -> x -> y)
                  -> ((a -> a) -> St trie a -> St trie a)
                  -> (a -> a -> a) -> [k] -> a -> trie k a -> trie k a
genericInsertWith ($$) (<$$>) f = go
 where
   go key new tr = case key of
      []     -> mapVal ($$) tr (\old -> (f new <$$> old) <|> pure new)
      (x:xs) -> mapMap ($$) tr
                   (Map.insertWith (\_ old -> go xs new old) x (singleton xs new))

-- | @O(min(m,s))@. Removes the key and its value from the map. A key that is
-- not a member leaves the map unchanged.
delete :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
       => [k] -> trie k a -> trie k a
delete key = alter (\_ -> altEmpty) key

-- | @O(min(m,s))@. Applies the given function to the value stored at the
-- given key. A key that is not a member leaves the map unchanged.
adjust :: Trie trie k
       => (a -> a) -> [k] -> trie k a -> trie k a
adjust f = genericAdjust ($) fmap f

-- | @O(min(m,s))@. Like 'adjust', but the function is applied strictly.
adjust' :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
        => (a -> a) -> [k] -> trie k a -> trie k a
adjust' f = genericAdjust ($!) fmap' f

-- Shared worker for 'adjust' and 'adjust''. @($$)@ is the application operator
-- passed on to 'mapVal' and 'mapMap'. @(<$$>)@ maps the function over the
-- value slot found at the key.
genericAdjust :: Trie trie k
              => (forall x y. (x -> y) -> x -> y)
              -> ((a -> a) -> St trie a -> St trie a)
              -> (a -> a) -> [k] -> trie k a -> trie k a
genericAdjust ($$) (<$$>) f = go
 where
   go key tr = case key of
      []     -> mapVal ($$) tr (\v -> f <$$> v)
      (x:xs) -> mapMap ($$) tr (Map.adjust (go xs) x)

-- | @O(min(m,s))@. If the given key holds a value, replaces that value slot
-- with the result of the given function applied to the value. When the new
-- slot is empty, subtries left empty are pruned from the map.
--
-- Also returns the original value slot at the key, or 'altEmpty' if the key
-- is not present in the map.
updateLookup :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
             => (a -> St trie a) -> [k] -> trie k a -> ((St trie) a, trie k a)
updateLookup f = go
 where
   go [] tr =
      let (old, children) = tParts tr
          new | hasValue old = f (unwrap old)
              | otherwise    = old
       in (old, mkTrie new children)

   go (x:xs) orig = maybe (altEmpty, orig) descend (Map.lookup x m)
    where
      m = tMap orig
      descend child =
         let (ret, upd) = go xs child
             m' | null upd  = Map.delete x m
                | otherwise = Map.adjust (const upd) x m
          in (ret, mkTrie (tVal orig) m')

-- | @O(min(m,s))@. The most general modification function. It lets you
-- modify the value at the given key whether or not that key is a member of
-- the map. In short, the given function is passed the value slot at the key,
-- which is empty when the key is not a member. If the function returns a
-- non-empty slot, that value is stored at the key. Otherwise the old value,
-- if any, is removed. More precisely, for @alter f k m@:
--
-- If @k@ is a member of @m@, @f@ is called on the slot holding @oldValue@.
-- Then:
--
-- - If @f@ returned a slot holding @newValue@, @oldValue@ is replaced with
--   @newValue@.
--
-- - If @f@ returned an empty slot, @k@ and @oldValue@ are removed from the
--   map.
--
-- If @k@ is not a member of @m@, @f@ is called on an empty slot. Then:
--
-- - If @f@ returned a slot holding @value@, @value@ is inserted into the map
--   at @k@.
--
-- - If @f@ returned an empty slot, the map is unchanged.
--
-- The function is applied lazily only if the given key is a prefix of
-- another key in the map.
alter :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
      => ((St trie) a -> St trie a) -> [k] -> trie k a -> trie k a
alter = genericAlter ($) (\_ t -> t)

-- | @O(min(m,s))@. Like 'alter', but the function is always applied strictly.
alter' :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
       => ((St trie) a -> St trie a) -> [k] -> trie k a -> trie k a
alter' = genericAlter ($!) seq

-- Shared worker for 'alter' and 'alter''. @($$)@ is the application operator
-- passed on to 'mapMap'. @seeq@ decides whether the new value slot at the key
-- is forced before its node is built.
genericAlter :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
             => (forall x y. (x -> y) -> x -> y)
             -> ((St trie) a -> trie k a -> trie k a)
             -> ((St trie) a -> St trie a) -> [k] -> trie k a -> trie k a
genericAlter ($$) seeq f = go
 where
   go [] tr =
      let (v, m) = tParts tr
          v'     = f v
       in v' `seeq` mkTrie v' m
   go (x:xs) tr = mapMap ($$) tr (Map.alter step x)
    where
      -- No child under x yet: create a path only if f produces a value.
      step Nothing = fresh (f altEmpty)
      -- Existing child: recurse, then drop the child if it became empty.
      step (Just old) =
         let t = go xs old
          in if null t then Nothing else Just t
      fresh v
         | hasValue v = Just (singleton xs (unwrap v))
         | otherwise  = Nothing

-- * Querying

-- | @O(1)@. 'True' iff the map is empty: no children and no value at the
-- root.
null :: (Boolable ((St trie) a), Trie trie k) => trie k a -> Bool
null tr =
   let (v, m) = tParts tr
       -- The value slot is inspected only after the child map turned out to
       -- be empty, which keeps this test as lazy as possible.
    in Map.null m && noValue v

-- | @O(n m)@. The number of elements in the map. The count is built with a
-- right fold, so partial results can be delivered without traversing the
-- whole map.
size :: (Boolable ((St trie) a), Trie trie k, Num n) => trie k a -> n
size tr = foldr (\child acc -> size child + acc) own (tMap tr)
 where
   own = if hasValue (tVal tr) then 1 else 0

-- | @O(n m)@. The number of elements in the map, accumulated with a strict
-- left fold: nothing is returned until the whole map has been traversed.
size' :: (Boolable ((St trie) a), Trie trie k, Num n) => trie k a -> n
size' tr = foldl' (\acc child -> size' child + acc) own (tMap tr)
 where
   own = if hasValue (tVal tr) then 1 else 0

-- | @O(min(m,s))@. 'True' iff the map holds a value at the given key.
member :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
       => [k] -> trie k a -> Bool
member key tr = hasValue (lookup key tr)

-- | @O(min(m,s))@. 'True' iff the map holds no value at the given key; the
-- negation of 'member'.
notMember :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
          => [k] -> trie k a -> Bool
notMember key tr = not (member key tr)

-- | @O(min(m,s))@. The value slot stored at the given key. It is 'altEmpty'
-- if the key's path does not exist in the map.
lookup :: (Alt (St trie) a, Trie trie k) => [k] -> trie k a -> St trie a
lookup key tr = case key of
   []     -> tVal tr
   (x:xs) -> case Map.lookup x (tMap tr) of
                Nothing    -> altEmpty
                Just child -> lookup xs child

-- | @O(min(m,s))@. Like 'lookup', but returns the given default when the key
-- has no value in the map.
lookupWithDefault :: (Alt (St trie) a, Trie trie k)
                  => a -> [k] -> trie k a -> a
lookupWithDefault def key tr = unwrap (lookup key tr <|> pure def)

-- | @O(min(n1 m1,n2 m2))@. 'True' iff the first map is a submap of the
-- second: every key that is a member of the first map is also a member of
-- the second, with an equal associated value.
--
-- > isSubmapOf = isSubmapOfBy (==)
isSubmapOf :: (Boolable ((St trie) a), Trie trie k, Eq a) => trie k a -> trie k a -> Bool
isSubmapOf tr1 tr2 = isSubmapOfBy (==) tr1 tr2

-- | @O(min(n1 m1,n2 m2))@. Like 'isSubmapOf', but one can specify the equality
-- relation applied to the values.
--
-- 'True' iff all keys that are members of the first map are also members of
-- the second map, and the given function @f@ returns 'True' for all @f
-- firstMapValue secondMapValue@ where @firstMapValue@ and @secondMapValue@ are
-- associated with the same key.
isSubmapOfBy :: (Boolable ((St trie) a), Boolable ((St trie) b), Trie trie k)
             => (a -> b -> Bool)
             -> trie k a
             -> trie k b
             -> Bool
isSubmapOfBy f = go
 where
   go tr1 tr2 =
      let (v1,m1) = tParts tr1
          (v2,m2) = tParts tr2
          -- A value in the first map needs a value in the second map that
          -- satisfies f. When the first map has no value at this node there
          -- is nothing to check, and v1 must not be unwrapped (it is empty).
          valueOk
             | hasValue v1 = hasValue v2 && f (unwrap v1) (unwrap v2)
             | otherwise   = True
       in valueOk && Map.isSubmapOfBy go m1 m2

-- | @O(min(n1 m1,n2 m2))@. 'True' iff the first map is a proper submap of the
-- second. Every key of the first map must be a member of the second with an
-- equal associated value, and the second map must contain at least one key
-- that is not a member of the first.
--
-- > isProperSubmapOf = isProperSubmapOfBy (==)
isProperSubmapOf :: (Boolable ((St trie) a), Trie trie k, Eq a)
                 => trie k a -> trie k a -> Bool
isProperSubmapOf tr1 tr2 = isProperSubmapOfBy (==) tr1 tr2

-- | @O(min(n1 m1,n2 m2))@. Like 'isProperSubmapOf', but one can specify the
-- equality relation applied to the values.
--
-- 'True' iff all keys that are members of the first map are also members of
-- the second map, and the given function @f@ returns 'True' for all @f
-- firstMapValue secondMapValue@ where @firstMapValue@ and @secondMapValue@ are
-- associated with the same key, and at least one key in the second map is not
-- a member of the first.
isProperSubmapOfBy :: (Boolable ((St trie) a), Boolable ((St trie) b), Trie trie k)
                   => (a -> b -> Bool)
                   -> trie k a
                   -> trie k b
                   -> Bool
isProperSubmapOfBy f tr1 tr2 = go tr1 tr2 == Just True
 where
   -- Returns 'Nothing' if the first trie is not a submap of the second.
   -- Otherwise returns 'Just' whether the second trie has a key the first
   -- lacks. That extra key may sit in any single branch, so the children's
   -- results are combined with 'or'; not every branch has to be proper.
   go t1 t2 =
      let (v1,m1) = tParts t1
          (v2,m2) = tParts t2
          hv1     = hasValue v1
          hv2     = hasValue v2
          -- v1 is only unwrapped when it actually holds a value.
          valueOk    = not hv1 || (hv2 && f (unwrap v1) (unwrap v2))
          properHere = (not hv1 && hv2)
                    || not (Map.null $ Map.difference m2 m1)
          children   = sequence [ Map.lookup x m2 >>= go c1
                                | (x, c1) <- Map.toListKV m1
                                ]
       in if valueOk
             then (\ps -> or (properHere : ps)) <$> children
             else Nothing


-- * Combination

-- Combining function used by the plain unions: the value from the first
-- (left) map is kept.
defaultUnion :: a -> a -> a
defaultUnion left _ = left

-- | @O(min(n1 m1,n2 m2))@. The union of the two maps, containing every key
-- that is a member of either map. It is left-biased: when a key is a member
-- of both maps, the value from the first map is kept.
--
-- The worst-case performance occurs when the two maps are identical.
--
-- > union = unionWith const
union :: (Unionable (St trie) a, Trie trie k) => trie k a -> trie k a -> trie k a
union tr1 tr2 = unionWith defaultUnion tr1 tr2

-- | @O(min(n1 m1,n2 m2))@. Like 'union', but the combining function ('const')
-- is applied strictly.
--
-- > union' = unionWith' const
union' :: (Unionable (St trie) a, Trie trie k) => trie k a -> trie k a -> trie k a
union' tr1 tr2 = unionWith' defaultUnion tr1 tr2

-- | @O(min(n1 m1,n2 m2))@. Like 'union', but the given function is used to
-- determine the new value if a key is a member of both given maps. For a
-- function @f@, the new value is @f firstMapValue secondMapValue@.
unionWith :: (Unionable (St trie) a, Trie trie k)
          => (a -> a -> a) -> trie k a -> trie k a -> trie k a
unionWith f = genericUnionWith ($) (unionVals f) (flip const)

-- | @O(min(n1 m1,n2 m2))@. Like 'unionWith', but the combining function is
-- applied strictly.
unionWith' :: (Unionable (St trie) a, Trie trie k)
           => (a -> a -> a) -> trie k a -> trie k a -> trie k a
unionWith' f = genericUnionWith ($!) (unionVals' f) seq

-- Shared worker for 'unionWith' and 'unionWith''. @($$)@ is the application
-- operator passed on to 'onVals' and 'onMaps'. @valUnion@ merges the value
-- slots of two corresponding nodes. @seeq@ decides whether the merged slot
-- is forced before the node is built.
genericUnionWith :: Trie trie k
                 => (forall x y. (x -> y) -> x -> y)
                 -> ((St trie) a -> St trie a -> St trie a)
                 -> ((St trie) a -> trie k a -> trie k a)
                 -> trie k a
                 -> trie k a
                 -> trie k a
genericUnionWith ($$) valUnion seeq = go
 where
   go tr1 tr2 =
      let v = onVals ($$) valUnion tr1 tr2
       in v `seeq` (mkTrie v $ onMaps ($$) (Map.unionWith go) tr1 tr2)

-- | @O(min(n1 m1,n2 m2))@. Like 'unionWith', but the combining function also
-- receives the key at which the two values collide.
unionWithKey :: (Unionable (St trie) a, Trie trie k) => ([k] -> a -> a -> a)
                                                     -> trie k a
                                                     -> trie k a
                                                     -> trie k a
unionWithKey f = genericUnionWithKey ($) unionVals (\_ t -> t) f

-- | @O(min(n1 m1,n2 m2))@. Like 'unionWithKey', but the combining function is
-- applied strictly.
unionWithKey' :: (Unionable (St trie) a, Trie trie k) => ([k] -> a -> a -> a)
                                                      -> trie k a
                                                      -> trie k a
                                                      -> trie k a
unionWithKey' f = genericUnionWithKey ($!) unionVals' seq f

-- Shared worker for 'unionWithKey' and 'unionWithKey''. The key leading to the
-- current node is accumulated in a difference list so that extending it by
-- one element at each level is cheap.
genericUnionWithKey :: Trie trie k
                    => (forall x y. (x -> y) -> x -> y)
                    -> ((a -> a -> a) -> St trie a -> St trie a -> St trie a)
                    -> ((St trie) a -> trie k a -> trie k a)
                    -> ([k] -> a -> a -> a)
                    -> trie k a
                    -> trie k a
                    -> trie k a
genericUnionWithKey ($$) valUnion seeq f = go DL.empty
 where
   go prefix tr1 tr2 = v `seeq` mkTrie v children
    where
      v        = onVals ($$) (valUnion (f (DL.toList prefix))) tr1 tr2
      children = onMaps ($$)
                        (Map.unionWithKey (\x -> go (prefix `DL.snoc` x)))
                        tr1 tr2

-- | @O(sum(n))@. The union of all the given maps, containing every key that
-- is a member of any of them. When a key occurs in several maps, the value
-- from the earliest map in the list is kept.
--
-- The worst-case performance occurs when all the maps are identical.
--
-- > unions = unionsWith const
unions :: (Alt (St trie) a, Unionable (St trie) a, Trie trie k) => [trie k a] -> trie k a
unions trs = unionsWith defaultUnion trs

-- | @O(sum(n))@. Like 'unions', but the combining function ('const') is
-- applied strictly.
--
-- > unions' = unionsWith' const
unions' :: (Alt (St trie) a, Unionable (St trie) a, Trie trie k) => [trie k a] -> trie k a
unions' trs = unionsWith' defaultUnion trs

-- | @O(sum(n))@. Like 'unions', but the given function computes the final
-- value of a key that is a member of more than one map. The function is
-- applied as a left fold over the values, in the order of the given list.
-- For example:
--
-- > unionsWith (-) [fromList [("a",1)],fromList [("a",2)],fromList [("a",3)]]
-- >    == fromList [("a",(1-2)-3)]
-- >    == fromList [("a",-4)]
unionsWith :: (Alt (St trie) a, Unionable (St trie) a, Trie trie k)
           => (a -> a -> a) -> [trie k a] -> trie k a
unionsWith f trs = foldl' (\acc tr -> unionWith f acc tr) empty trs

-- | @O(sum(n))@. Like 'unionsWith', but the combining function is applied
-- strictly.
unionsWith' :: (Alt (St trie) a, Unionable (St trie) a, Trie trie k)
            => (a -> a -> a) -> [trie k a] -> trie k a
unionsWith' f trs = foldl' (\acc tr -> unionWith' f acc tr) empty trs

-- | @O(sum(n))@. Like 'unionsWith', but the combining function also receives
-- the key whose values are being combined.
unionsWithKey :: (Alt (St trie) a, Unionable (St trie) a, Trie trie k)
              => ([k] -> a -> a -> a) -> [trie k a] -> trie k a
unionsWithKey j trs = foldl' (\acc tr -> unionWithKey j acc tr) empty trs

-- | @O(sum(n))@. Like 'unionsWithKey', but the combining function is applied
-- strictly.
unionsWithKey' :: (Alt (St trie) a, Unionable (St trie) a, Trie trie k)
               => ([k] -> a -> a -> a) -> [trie k a] -> trie k a
unionsWithKey' j trs = foldl' (\acc tr -> unionWithKey' j acc tr) empty trs

-- | @O(min(n1 m1,n2 m2))@. The difference of the two maps, containing every
-- key that is a member of the first map but not of the second.
--
-- The worst-case performance occurs when the two maps are identical.
--
-- > difference = differenceWith (\_ _ -> Nothing)
difference :: (Boolable ((St trie) a), Differentiable (St trie) a b, Trie trie k)
           => trie k a -> trie k b -> trie k a
difference tr1 tr2 = differenceWith (\_ _ -> Nothing) tr1 tr2

-- | @O(min(n1 m1,n2 m2))@. Like 'difference', but the given function decides
-- what happens to a key that is a member of both maps. 'Nothing' removes the
-- key. 'Just' a new value replaces the first map's value.
differenceWith :: (Boolable ((St trie) a), Differentiable (St trie) a b, Trie trie k)
               => (a -> b -> Maybe a)
               -> trie k a
               -> trie k b
               -> trie k a
differenceWith f = go
 where
   -- The new value slot is always forced. Leaving it lazy would only help
   -- when the differing keys were at [], and most operations force it anyway,
   -- so it is forced uniformly for consistency with the Patricia variant.
   go tr1 tr2 = v `seq` mkTrie v children
    where
      v        = onVals ($!) (differenceVals f) tr1 tr2
      children = onMaps ($!) (Map.differenceWith keepNonEmpty) tr1 tr2

   -- Children that end up empty are removed from the result.
   keepNonEmpty t1 t2
      | null t'   = Nothing
      | otherwise = Just t'
    where
      t' = go t1 t2

-- | @O(min(n1 m1,n2 m2))@. Like 'differenceWith', but the combining function
-- also receives the key with which both values are associated.
differenceWithKey :: (Boolable ((St trie) a), Differentiable (St trie) a b, Trie trie k)
                  => ([k] -> a -> b -> Maybe a)
                  -> trie k a
                  -> trie k b
                  -> trie k a
differenceWithKey f = go DL.empty
 where
   -- The value slot is forced unconditionally, for the same reason as in
   -- 'differenceWith'.
   go prefix tr1 tr2 = v `seq` mkTrie v children
    where
      v        = onVals ($!) (differenceVals (f (DL.toList prefix))) tr1 tr2
      children = onMaps ($!) (Map.differenceWithKey (step prefix)) tr1 tr2

   -- Recurse into a shared child, dropping it if it became empty.
   step prefix x t1 t2
      | null t'   = Nothing
      | otherwise = Just t'
    where
      t' = go (prefix `DL.snoc` x) t1 t2

-- | @O(min(n1 m1,n2 m2))@. The intersection of the two maps, containing every
-- key that is a member of both. Values are taken from the first map.
--
-- The worst-case performance occurs when the two maps are identical.
--
-- > intersection = intersectionWith const
intersection :: (Boolable ((St trie) a), Intersectable (St trie) a b a, Trie trie k)
             => trie k a -> trie k b -> trie k a
intersection tr1 tr2 = intersectionWith (\x _ -> x) tr1 tr2

-- | @O(min(n1 m1,n2 m2))@. Like 'intersection', but the combining function is
-- applied strictly.
--
-- > intersection' = intersectionWith' const
intersection' :: (Boolable ((St trie) a), Intersectable (St trie) a b a, Trie trie k)
              => trie k a -> trie k b -> trie k a
intersection' tr1 tr2 = intersectionWith' (\x _ -> x) tr1 tr2

-- | @O(min(n1 m1,n2 m2))@. Like 'intersection', but the given function
-- determines the new values.
intersectionWith :: (Boolable ((St trie) c), Intersectable (St trie) a b c, Trie trie k)
                 => (a -> b -> c)
                 -> trie k a
                 -> trie k b
                 -> trie k c
intersectionWith f = genericIntersectionWith ($) (intersectionVals f) (flip const)

-- | @O(min(n1 m1,n2 m2))@. Like 'intersectionWith', but the combining function
-- is applied strictly.
intersectionWith' :: ( Boolable ((St trie) c), Intersectable (St trie) a b c
                     , Trie trie k
                     )
                  => (a -> b -> c)
                  -> trie k a
                  -> trie k b
                  -> trie k c
intersectionWith' f = genericIntersectionWith ($!) (intersectionVals' f) seq

-- Shared worker for 'intersectionWith' and 'intersectionWith''. @($$)@ is the
-- application operator passed on to 'onVals' and 'onMaps'. @valIntersection@
-- combines two value slots. @seeq@ decides whether the combined slot is
-- forced before the node is built.
genericIntersectionWith :: (Boolable ((St trie) c), Trie trie k)
                        => (forall x y. (x -> y) -> x -> y)
                        -> ((St trie) a -> St trie b -> St trie c)
                        -> (St trie c -> trie k c -> trie k c)
                        -> trie k a
                        -> trie k b
                        -> trie k c
genericIntersectionWith ($$) valIntersection seeq = go
 where
   go tr1 tr2 = v `seeq` mkTrie v children
    where
      v        = onVals ($$) valIntersection tr1 tr2
      -- Children whose intersection came out empty are filtered away here,
      -- so every remaining child is non-empty and no further pruning of the
      -- child map is needed.
      children = onMaps ($$) (Map.filter (not.null) .: Map.intersectionWith go)
                        tr1 tr2

-- | @O(min(n1 m1,n2 m2))@. Like 'intersectionWith', but in addition to the two
-- values, the key they are associated with is passed to the combining
-- function.
intersectionWithKey :: (Boolable ((St trie) c), Intersectable (St trie) a b c, Trie trie k)
                    => ([k] -> a -> b -> c)
                    -> trie k a
                    -> trie k b
                    -> trie k c
intersectionWithKey =
   genericIntersectionWithKey ($) intersectionVals (flip const)

-- | @O(min(n1 m1,n2 m2))@. Like 'intersectionWithKey', but the combining
-- function is applied strictly.
intersectionWithKey' :: ( Boolable ((St trie) c), Intersectable (St trie) a b c
                        , Trie trie k
                        )
                     => ([k] -> a -> b -> c)
                     -> trie k a
                     -> trie k b
                     -> trie k c
intersectionWithKey' = genericIntersectionWithKey ($!) intersectionVals' seq

-- Shared worker for 'intersectionWithKey' and 'intersectionWithKey''. The key
-- leading to the current node is accumulated in a difference list.
genericIntersectionWithKey :: (Boolable ((St trie) c), Trie trie k)
                           => (forall x y. (x -> y) -> x -> y)
                           -> ((a -> b -> c) -> St trie a -> St trie b -> St trie c)
                           -> (St trie c -> trie k c -> trie k c)
                           -> ([k] -> a -> b -> c)
                           -> trie k a
                           -> trie k b
                           -> trie k c
genericIntersectionWithKey ($$) valIntersection seeq f = go DL.empty
 where
   go prefix tr1 tr2 = v `seeq` mkTrie v children
    where
      v        = onVals ($$) (valIntersection (f $ DL.toList prefix)) tr1 tr2
      -- Empty child intersections are filtered away, so the remaining child
      -- map needs no further pruning.
      children = onMaps ($$) (Map.filter (not.null) .:
                                 Map.intersectionWithKey (go . (prefix `DL.snoc`)))
                        tr1 tr2

-- * Filtering

-- | @O(n m)@. Keeps only the elements for which the given predicate returns
-- 'True'.
filter :: (Alt (St trie) a, Boolable (St trie a), Trie trie k)
       => (a -> Bool) -> trie k a -> trie k a
filter p = filterWithKey (\_ v -> p v)

-- | @O(n m)@. Like 'filter', but the predicate also receives the key of each
-- element.
filterWithKey :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
              => ([k] -> a -> Bool) -> trie k a -> trie k a
filterWithKey p tr = fromList [ kv | kv <- toList tr, uncurry p kv ]


-- | @O(n m)@. Splits the map by the given predicate. The first map holds the
-- elements for which it returned 'True', the second those for which it
-- returned 'False'.
partition :: (Alt (St trie) a, Boolable (St trie a), Trie trie k)
          => (a -> Bool)
          -> trie k a
          -> (trie k a, trie k a)
partition p = partitionWithKey (\_ v -> p v)

-- | @O(n m)@. Like 'partition', but the predicate also receives the key of
-- each element.
partitionWithKey :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
                 => ([k] -> a -> Bool)
                 -> trie k a
                 -> (trie k a, trie k a)
partitionWithKey p tr = both fromList (List.partition (uncurry p) (toList tr))

-- | @O(n m)@. Applies the given function to every element and keeps only the
-- 'Just' results.
mapMaybe :: (Alt (St trie) b, Boolable ((St trie) a), Trie trie k)
         => (a -> Maybe b) -> trie k a -> trie k b
mapMaybe f = mapMaybeWithKey (\_ v -> f v)

-- | @O(n m)@. Like 'mapMaybe', but the function also receives the key of
-- each element.
mapMaybeWithKey :: (Alt (St trie) b, Boolable ((St trie) a), Trie trie k)
                => ([k] -> a -> Maybe b) -> trie k a -> trie k b
mapMaybeWithKey f tr =
   fromList [ (k, w) | (k, v) <- toList tr, Just w <- [f k v] ]

-- | @O(n m)@. Applies the given function to every element and splits the
-- results. The first map collects the 'Left' results, the second the 'Right'
-- ones.
mapEither :: (Alt (St trie) b, Alt (St trie) c, Boolable ((St trie) a), Trie trie k)
          => (a -> Either b c)
          -> trie k a
          -> (trie k b, trie k c)
mapEither f = mapEitherWithKey (\_ v -> f v)

-- | @O(n m)@. Like 'mapEither', but the function also receives the key of
-- each element.
mapEitherWithKey :: (Alt (St trie) b, Alt (St trie) c, Boolable ((St trie) a), Trie trie k)
                 => ([k] -> a -> Either b c)
                 -> trie k a
                 -> (trie k b, trie k c)
mapEitherWithKey f tr =
   let tag (k, v) = either (Left . (,) k) (Right . (,) k) (f k v)
       (ls, rs)   = partitionEithers (Prelude.map tag (toList tr))
    in (fromList ls, fromList rs)

-- * Mapping

{-
-- | @O(n m)@. Apply the given function to all the keys in a map.
--
-- > mapKeys = mapKeysWith const
mapKeys :: (Trie trie k1, Trie trie k2)
        => ([k1] -> [k2]) -> trie k1 a -> trie k2 a
mapKeys = mapKeysWith . fromListWith const
-}

-- | @O(n m)@. Applies the second given function to every key of the map. The
-- result is built from the renamed @(key, value)@ pairs, in 'toList' order,
-- by the first given function. That builder therefore decides what happens
-- when two keys are mapped to the same new key; for instance, a partially
-- applied @fromListWith@ can combine their values.
mapKeysWith :: (Boolable ((St trie) a), Trie trie k1, Trie trie k2)
            => ([([k2],a)] -> trie k2 a)
            -> ([k1] -> [k2])
            -> trie k1 a
            -> trie k2 a
mapKeysWith fromlist f = fromlist . map (first f) . toList

-- | @O(n m)@. Applies the second given function to every element of every
-- key. When two children of a node end up under the same new key element,
-- they are merged with 'unionWith', using the first given function to
-- combine their values.
mapInKeysWith :: (Unionable (St trie) a, Trie trie k1, Trie trie k2)
              => (a -> a -> a)
              -> (k1 -> k2)
              -> trie k1 a
              -> trie k2 a
mapInKeysWith j f = genericMapInKeysWith ($) unionWith j f

-- | @O(n m)@. Like 'mapInKeysWith', but merges colliding children with
-- 'unionWith'', so the combining function is applied strictly.
mapInKeysWith' :: (Unionable (St trie) a, Trie trie k1, Trie trie k2)
               => (a -> a -> a)
               -> (k1 -> k2)
               -> trie k1 a
               -> trie k2 a
mapInKeysWith' j f = genericMapInKeysWith ($!) unionWith' j f

-- Shared worker for 'mapInKeysWith' and 'mapInKeysWith''. @unionW j@ merges
-- children that collide after their key elements are renamed.
genericMapInKeysWith :: ( Unionable (St trie) a
                        , Trie trie k1, Trie trie k2
                        )
                     => (forall x y. (x -> y) -> x -> y)
                     -> (f -> trie k2 a -> trie k2 a -> trie k2 a)
                     -> f
                     -> (k1 -> k2)
                     -> trie k1 a
                     -> trie k2 a
genericMapInKeysWith ($$) unionW j f = go
 where
   go tr = mapMap ($$) tr rekey
   rekey m = Map.fromListKVWith (unionW j)
                [ (f x, go child) | (x, child) <- Map.toListKV m ]

-- * Folding

-- | @O(n m)@. Right fold over the key-value pairs of the map, visiting them
-- in the same unspecified order as 'toList'.
foldrWithKey :: (Boolable ((St trie) a), Trie trie k)
             => ([k] -> a -> b -> b) -> b -> trie k a -> b
foldrWithKey f z tr = foldr step z (toList tr)
 where
   step kv acc = f (fst kv) (snd kv) acc

-- | @O(n m)@. Right fold over the key-value pairs of the map, visiting them
-- in ascending key order (as 'toAscList' lists them).
foldrAscWithKey :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
                => ([k] -> a -> b -> b) -> b -> trie k a -> b
foldrAscWithKey f z tr = foldr step z (toAscList tr)
 where
   step kv acc = f (fst kv) (snd kv) acc

-- | @O(n m)@. Right fold over the key-value pairs of the map, visiting them
-- in descending key order (as 'toDescList' lists them).
foldrDescWithKey :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
                 => ([k] -> a -> b -> b) -> b -> trie k a -> b
foldrDescWithKey f z tr = foldr step z (toDescList tr)
 where
   step kv acc = f (fst kv) (snd kv) acc

-- | @O(n m)@. Lazy left fold over the key-value pairs of the map, visiting
-- them in the same unspecified order as 'toList'. See 'foldlWithKey'' for the
-- strict variant.
foldlWithKey :: (Boolable ((St trie) a), Trie trie k)
             => ([k] -> a -> b -> b) -> b -> trie k a -> b
foldlWithKey f z tr = foldl step z (toList tr)
 where
   step acc kv = f (fst kv) (snd kv) acc

-- | @O(n m)@. Lazy left fold over the key-value pairs of the map, visiting
-- them in ascending key order (as 'toAscList' lists them).
foldlAscWithKey :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
                => ([k] -> a -> b -> b) -> b -> trie k a -> b
foldlAscWithKey f z tr = foldl step z (toAscList tr)
 where
   step acc kv = f (fst kv) (snd kv) acc


-- | @O(n m)@. Lazy left fold over the key-value pairs of the map, visiting
-- them in descending key order (as 'toDescList' lists them).
foldlDescWithKey :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
                 => ([k] -> a -> b -> b) -> b -> trie k a -> b
foldlDescWithKey f z tr = foldl step z (toDescList tr)
 where
   step acc kv = f (fst kv) (snd kv) acc

-- | @O(n m)@. Strict left fold (the accumulator is forced at every step) over
-- the key-value pairs of the map, in the same unspecified order as 'toList'.
foldlWithKey' :: (Boolable ((St trie) a), Trie trie k)
              => ([k] -> a -> b -> b) -> b -> trie k a -> b
foldlWithKey' f z tr = foldl' step z (toList tr)
 where
   step acc kv = f (fst kv) (snd kv) acc

-- | @O(n m)@. Strict left fold over the key-value pairs of the map, in
-- ascending key order (as 'toAscList' lists them).
foldlAscWithKey' :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
                 => ([k] -> a -> b -> b) -> b -> trie k a -> b
foldlAscWithKey' f z tr = foldl' step z (toAscList tr)
 where
   step acc kv = f (fst kv) (snd kv) acc

-- | @O(n m)@. Strict left fold over the key-value pairs of the map, in
-- descending key order (as 'toDescList' lists them).
foldlDescWithKey' :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
                  => ([k] -> a -> b -> b) -> b -> trie k a -> b
foldlDescWithKey' f z tr = foldl' step z (toDescList tr)
 where
   step acc kv = f (fst kv) (snd kv) acc

-- * Conversion between lists

-- | @O(n m)@. Lists every key-value pair stored in the map. The order of the
-- pairs is unspecified; use 'toAscList' or 'toDescList' when order matters.
toList :: (Boolable ((St trie) a), Trie trie k) => trie k a -> [([k],a)]
toList tr = genericToList Map.toListKV (\kv rest -> kv `DL.cons` rest) tr

-- | @O(n m)@. Lists every key-value pair stored in the map, smallest key
-- first.
toAscList :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
          => trie k a -> [([k],a)]
-- A node's own pair goes in front of its descendants' pairs, because a key
-- precedes every key it is a prefix of.
toAscList tr = genericToList Map.toAscList (\kv rest -> kv `DL.cons` rest) tr

-- | @O(n m)@. Lists every key-value pair stored in the map, largest key
-- first.
toDescList :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
           => trie k a -> [([k],a)]
-- Children are visited largest-first, and a node's own pair is appended after
-- its descendants' pairs, since every extension of a key is greater than it.
toDescList tr =
   genericToList (\m -> reverse (Map.toAscList m))
                 (\kv rest -> rest `DL.snoc` kv)
                 tr

-- | Shared traversal behind 'toList', 'toAscList' and 'toDescList'.
--
-- @tolist@ decides the order in which a node's children are visited. @add@
-- decides how a node's own key-value pair is combined with the pairs
-- collected from its children. The key leading to the current node is
-- accumulated as a 'DList', so each extension is a cheap 'DL.snoc'.
genericToList :: (Boolable ((St trie) a), Trie trie k)
              => (CMap trie k a -> [(k, trie k a)])
              -> (([k],a) -> DList ([k],a) -> DList ([k],a))
              -> trie k a
              -> [([k],a)]
genericToList tolist add tr0 = DL.toList (walk DL.empty tr0)
 where
   walk prefix node =
      let (v,m) = tParts node
          below = DL.concat [ walk (prefix `DL.snoc` c) child
                            | (c,child) <- tolist m ]
       in if hasValue v then add (DL.toList prefix, unwrap v) below else below

-- | @O(n m)@. Builds a map from a list of key-value pairs. When a key appears
-- several times, the pair that comes last in the list determines the value.
--
-- > fromList = fromListWith const
fromList :: (Alt (St trie) a, Trie trie k) => [([k],a)] -> trie k a
fromList = fromListWith (\new _old -> new)

-- | @O(n m)@. Like 'fromList', but duplicate keys are resolved with the given
-- function. The pairs are inserted from left to right, and each repeated
-- value is combined with the one already stored as @f new old@. Equivalently
-- (ignoring performance), the function is folded from the right over the
-- values in reverse list order. For example:
--
-- > fromListWith (-) [("a",1),("a",2),("a",3),("a",4)]
-- >    == fromList [("a",4-(3-(2-1)))]
-- >    == fromList [("a",2)]
fromListWith :: (Alt (St trie) a, Trie trie k)
             => (a -> a -> a) -> [([k],a)] -> trie k a
fromListWith f = foldl' step empty
 where
   step tr kv = uncurry (insertWith f) kv tr

-- | @O(n m)@. Strict counterpart of 'fromListWith': duplicate values are
-- combined with 'insertWith'', which applies the function strictly.
fromListWith' :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
              => (a -> a -> a) -> [([k],a)] -> trie k a
fromListWith' f = foldl' step empty
 where
   step tr kv = uncurry (insertWith' f) kv tr

-- | @O(n m)@. Like 'fromListWith', but the combining function also receives
-- the key whose values are being merged.
fromListWithKey :: (Alt (St trie) a, Trie trie k)
                => ([k] -> a -> a -> a) -> [([k],a)] -> trie k a
fromListWithKey f = foldl' step empty
 where
   step tr (k,v) = insertWith (f k) k v tr

-- | @O(n m)@. Strict counterpart of 'fromListWithKey': values are merged
-- through 'insertWith'', which applies the combining function strictly.
fromListWithKey' :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
                 => ([k] -> a -> a -> a) -> [([k],a)] -> trie k a
fromListWithKey' f = foldl' step empty
 where
   step tr (k,v) = insertWith' (f k) k v tr

-- * Ordering ops

-- | @O(m)@. Splits off the smallest key of the map together with its value,
-- returning the remaining map alongside. An empty map yields 'Nothing' and is
-- returned unchanged.
minView :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
        => trie k a -> (Maybe ([k], a), trie k a)
-- Following the smallest child, the first node holding a value is the
-- minimum: a key precedes every key it is a prefix of.
minView tr = minMaxView (\t -> hasValue (tVal t))
                        (\m -> fst (Map.minViewWithKey m))
                        tr

-- | @O(m)@. Splits off the largest key of the map together with its value,
-- returning the remaining map alongside. An empty map yields 'Nothing' and is
-- returned unchanged.
maxView :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
        => trie k a -> (Maybe ([k], a), trie k a)
-- Following the largest child, the maximum sits at the first node that has
-- no children: any key with extensions is smaller than them.
maxView tr = minMaxView (\t -> Map.null (tMap t))
                        (\m -> fst (Map.maxViewWithKey m))
                        tr

-- | Shared implementation of 'minView' and 'maxView'. Walks down the child
-- chosen by @mapView@ until @isWanted@ accepts a node, removes that node's
-- value, and rebuilds the path back to the root. Branches left empty by the
-- removal are deleted from their parent's map.
minMaxView :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
           => (trie k a -> Bool)
           -> (CMap trie k a -> Maybe (k, trie k a))
           -> trie k a
           -> (Maybe ([k], a), trie k a)
minMaxView isWanted mapView tr_
   | null tr_  = (Nothing, tr_)
   | otherwise = first Just (extract tr_)
 where
   extract tr
      | isWanted tr = (([], unwrap v), mkTrie altEmpty m)
      | otherwise   =
           let (k, child)      = fromJust (mapView m)
               (found, child') = extract child
               -- Prune the branch entirely if nothing remains below it.
               m' | null child' = Map.delete k m
                  | otherwise   = Map.adjust (const child') k m
            in (first (k:) found, mkTrie v m')
    where
      (v, m) = tParts tr

-- | @O(m)@. The smallest key of the map and its value, or 'Nothing' for an
-- empty map. Equivalent to @fst . minView@, without rebuilding the map.
findMin :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
        => trie k a -> Maybe ([k], a)
findMin tr = findMinMax (\t -> hasValue (tVal t))
                        (\m -> fst (Map.minViewWithKey m))
                        tr

-- | @O(m)@. Like 'fst' composed with 'maxView'. 'Just' the maximal key in the
-- map and its associated value, or 'Nothing' if the map is empty.
--
-- The search repeatedly follows the largest child until it reaches a node
-- with no children, since any key is smaller than its extensions.
findMax :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
        => trie k a -> Maybe ([k], a)
findMax = findMinMax (Map.null . tMap) (fst . Map.maxViewWithKey)

-- | Shared implementation of 'findMin' and 'findMax': descend through the
-- child selected by @mapView@, collecting the key, until @isWanted@ accepts
-- a node, then return that key with the node's value.
findMinMax :: (Boolable ((St trie) a), Trie trie k)
           => (trie k a -> Bool)
           -> (CMap trie k a -> Maybe (k, trie k a))
           -> trie k a
           -> Maybe ([k], a)
findMinMax isWanted mapView tr_
   | null tr_  = Nothing
   | otherwise = Just (descend DL.empty tr_)
 where
   descend acc tr
      | isWanted tr = (DL.toList acc, unwrap (tVal tr))
      | otherwise   = let (k, child) = fromJust (mapView (tMap tr))
                       in descend (acc `DL.snoc` k) child

-- | @O(m)@. The map with its smallest key removed; an empty map is returned
-- unchanged. Equivalent to @snd . minView@.
deleteMin :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
          => trie k a -> trie k a
deleteMin tr = snd (minView tr)

-- | @O(m)@. The map with its largest key removed; an empty map is returned
-- unchanged. Equivalent to @snd . maxView@.
deleteMax :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
          => trie k a -> trie k a
deleteMax tr = snd (maxView tr)

-- | @O(min(m,s))@. Partitions the map around the given key. The first map
-- holds the keys smaller than it, the second the keys greater than it; the
-- given key itself is in neither (see 'splitLookup' to retrieve its value).
split :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
      => [k] -> trie k a -> (trie k a, trie k a)
split xs tr = (lesser, greater)
 where
   (lesser, _, greater) = splitLookup xs tr

-- | @O(min(m,s))@. Like 'split', but the value stored at the given key itself
-- (if any) is returned as the middle component of the triple.
splitLookup :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
            => [k]
            -> trie k a
            -> (trie k a, St trie a, trie k a)
-- Every key other than the empty one extends it, so all of them are greater.
splitLookup []     tr = (empty, tVal tr, mkTrie altEmpty (tMap tr))
splitLookup (x:xs) tr = maybe missing found subTr
 where
   (v, m)          = tParts tr
   (ml, subTr, mg) = Map.splitLookup x m
   -- No branch for x: the value at this node (a proper prefix of the key)
   -- belongs to the lesser side.
   missing         = (mkTrie v ml, altEmpty, mkTrie altEmpty mg)
   found tr'       =
      let (tl, v', tg) = splitLookup xs tr'
          -- Re-attach a non-empty half of the x branch to its side.
          keep t mp    = if null t then mp else Map.insert x t mp
       in (mkTrie v (keep tl ml), v', mkTrie altEmpty (keep tg mg))

-- | @O(m)@. The greatest key of the map that is strictly smaller than the
-- given key, together with its value. 'Nothing' if the map is empty or no
-- key precedes the given one.
findPredecessor :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
                => [k] -> trie k a -> Maybe ([k], a)
findPredecessor _   tr | null tr = Nothing
findPredecessor xs_ tr_          = go xs_ tr_
 where
   -- Nothing precedes the empty key.
   go []     _  = Nothing
   go (x:xs) tr = deeper <|> earlier
    where
      (v,m) = tParts tr
      -- First look inside the branch for x itself: when searching for "foo",
      -- a key such as "fob" beats anything outside the 'f' branch.
      deeper  = fmap (first (x:)) (Map.lookup x m >>= go xs)
      -- Otherwise take the largest key of the nearest smaller branch (say
      -- "bar" in a 'b' branch). If there is no smaller branch, fall back to
      -- the key ending at this node, if any.
      earlier =
         case Map.findPredecessor x m of
              Just (best,btr) -> fmap (first (best:)) (findMax btr)
              Nothing
                 | hasValue v -> Just ([], unwrap v)
                 | otherwise  -> Nothing

-- | @O(m)@. The smallest key of the map that is strictly greater than the
-- given key, together with its value. 'Nothing' if the map is empty or no
-- key follows the given one.
findSuccessor :: forall trie k a .
                 (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
              => [k] -> trie k a -> Maybe ([k], a)
findSuccessor _   tr | null tr = Nothing
findSuccessor xs_ tr_          = go xs_ tr_
 where
   go :: (Boolable ((St trie) a), Trie trie k, OrdMap (TMap trie) k)
      => [k] -> trie k a -> Maybe ([k], a)
   -- Once the key is exhausted, every key below this node is greater; the
   -- smallest lives in the smallest child branch.
   go [] tr =
      fst (Map.minViewWithKey (tMap tr)) >>= \(k,t) ->
         fmap (first (k:)) (findMin t)

   -- Prefer a successor inside the branch for x; otherwise take the minimum
   -- of the nearest larger branch.
   go (x:xs) tr =
      let m              = tMap tr
          viaSameBranch  = Map.lookup x m >>= go xs
          viaLaterBranch =
             Map.findSuccessor x m >>= \(best,btr) ->
                fmap (first (best:)) (findMin btr)
       in fmap (first (x:)) viaSameBranch <|> viaLaterBranch

-- * Trie-only operations

-- | @O(s)@. Restricts the map to the keys that start with the given key,
-- keeping those keys intact. For example:
--
-- > lookupPrefix "ab" (fromList [("a",1),("ab",2),("ac",3),("abc",4)])
-- >    == fromList [("ab",2),("abc",4)]
lookupPrefix :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
             => [k] -> trie k a -> trie k a
lookupPrefix []     tr = tr
lookupPrefix (x:xs) tr = maybe empty narrow (Map.lookup x (tMap tr))
 where
   -- Re-root the surviving sub-trie under x, unless nothing survived.
   narrow child =
      let sub = lookupPrefix xs child
       in if null sub then sub else mkTrie altEmpty (Map.singleton x sub)

-- | @O(s)@. Puts the given key in front of every key of the map, by wrapping
-- the trie in a chain of single-child nodes. For example:
--
-- > addPrefix "xa" (fromList [("a",1),("b",2)])
-- >    == fromList [("xaa",1),("xab",2)]
addPrefix :: (Alt (St trie) a, Trie trie k)
          => [k] -> trie k a -> trie k a
addPrefix ks tr = foldr (\k t -> mkTrie altEmpty (Map.singleton k t)) tr ks

-- | @O(s)@. Restricts the map to the keys that start with the given key and
-- strips that key from the front of each of them. If no key starts with the
-- given key, the result is empty. For example:
--
-- > deletePrefix "a" (fromList [("a",1),("ab",2),("ac",3)])
-- >    == fromList [("",1),("b",2),("c",3)]
--
-- One use is to avoid expensive I/O. Suppose you need the value for a string
-- but only hold a prefix of it, and fetching the rest is costly. Call
-- 'deletePrefix' with the part you have first: if the result is empty, the
-- full string cannot be a key of the map, and the fetch can be skipped.
deletePrefix :: (Alt (St trie) a, Trie trie k)
             => [k] -> trie k a -> trie k a
deletePrefix []     tr = tr
deletePrefix (x:xs) tr = maybe empty (deletePrefix xs) (Map.lookup x (tMap tr))

-- | @O(s)@. Deletes every key of which the given key is a prefix: the given
-- key itself, if present, and all of its extensions. All other keys are kept.
-- For example:
--
-- > deleteSuffixes "ab" (fromList $ zip ["a","ab","ac","b","abc"] [1..])
-- >    == fromList [("a",1),("ac",3),("b",4)]
deleteSuffixes :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
               => [k] -> trie k a -> trie k a
-- Every key extends the empty key, so nothing survives.
deleteSuffixes []     _  = empty
deleteSuffixes (x:xs) tr =
   let (v,m) = tParts tr
    in case Map.lookup x m of
            -- No key continues with x here: nothing to delete.
            Nothing  -> tr
            -- Recurse, then drop the x branch entirely if it became empty.
            Just tr' -> let tr'' = deleteSuffixes xs tr'
                         in if null tr''
                               then mkTrie v (Map.delete x      m)
                               else mkTrie v (Map.insert x tr'' m)

-- | @O(m)@. Splits off the longest common prefix of all keys in the map.
-- Returns a triple of:
--
-- * that prefix,
--
-- * the value stored at the prefix itself, if any,
--
-- * the map of all remaining keys with the prefix removed from each of them
--   (the prefix's own value is not included here).
--
-- Examples:
--
-- > splitPrefix (fromList [("a",1),("b",2)])
-- >    == ("", Nothing, fromList [("a",1),("b",2)])
-- > splitPrefix (fromList [("a",1),("ab",2),("ac",3)])
-- >    == ("a", Just 1, fromList [("b",2),("c",3)])
-- > splitPrefix (fromList [("a",1),("ab",2)])
-- >    == ("a", Just 1, fromList [("b",2)])
splitPrefix :: forall trie k a .
               (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
            => trie k a -> ([k], St trie a, trie k a)
splitPrefix = go DL.empty
 where
   go :: (Alt (St trie) a, Boolable ((St trie) a), Trie trie k)
      => DL.DList k -> trie k a -> ([k], St trie a, trie k a)
   go xs tr =
      case Map.singletonView (tMap tr) of
           -- Only descend through value-less nodes with a single child. A node
           -- holding a value is itself a key, so the common prefix cannot
           -- extend past it (previously such values were silently dropped).
           Just (x,tr') | not (hasValue (tVal tr)) -> go (xs `DL.snoc` x) tr'
           _ -> let (v,m) = tParts tr
                 in (DL.toList xs, v, mkTrie altEmpty m)

-- | @O(m)@. The sub-tries directly below the longest common prefix of all
-- keys, each associated with the key element that distinguishes it. A map
-- with fewer than two keys yields an empty map. Examples:
--
-- > children (fromList [("a",1),("abc",2),("abcd",3)])
-- >    == Map.fromList [('b',fromList [("c",2),("cd",3)])]
-- > children (fromList [("b",1),("c",2)])
-- >    == Map.fromList [('b',fromList [("",1)]),('c',fromList [("",2)])]
children :: (Boolable ((St trie) a), Trie trie k)
         => trie k a -> CMap trie k a
children tr
   -- A node holding a value ends the common prefix.
   | hasValue v = m
   -- A value-less node with exactly one child is still inside the prefix.
   | otherwise  = case Map.singletonView m of
                       Just (_, only) -> children only
                       Nothing        -> m
 where
   (v,m) = tParts tr

-- | @O(1)@. The sub-tries directly below the root of the trie, each
-- associated with the first key element that leads to it. Unlike 'children',
-- this does not skip past a common prefix: a non-empty map whose keys share a
-- non-empty prefix yields a single-element map. Only a map whose sole key (if
-- any) is the empty list yields an empty map. The value stored at the empty
-- key, if any, is not part of the result.
--
-- If the longest common prefix of all keys in the trie is the empty list, this
-- function is equivalent to 'children'.
--
-- Examples:
--
-- > children1 (fromList [("abc",1),("abcd",2)])
-- >    == Map.fromList [('a',fromList [("bc",1),("bcd",2)])]
-- > children1 (fromList [("b",1),("c",2)])
-- >    == Map.fromList [('b',fromList [("",1)]),('c',fromList [("",2)])]
children1 :: (Alt (St trie) a, Trie trie k)
          => trie k a -> CMap trie k a
children1 = tMap

-- * Visualization

-- | @O(n m)@. Like 'showTrie', but uses the given function to display the
-- elements of the map.
--
-- Each node is rendered as its element (via the given function), then a
-- space, then its children, each as @-> key subtree@. The first child
-- continues on the same line. Every further child starts on a new line,
-- indented so that its @->@ lines up under the first child's.
--
-- NOTE(review): this was previously documented as "Still undefined" although
-- an implementation follows; presumably the exact output format is not meant
-- to be relied upon. Confirm.
showTrieWith :: (Show k, Trie trie k)
             => ((St trie) a -> ShowS) -> trie k a -> ShowS
showTrieWith = go 0
 where
   -- indent is the column at which the current node's rendering starts.
   -- Children after the first receive the flag True (via the
   -- False : repeat True list) and are therefore preceded by a newline and
   -- i spaces, where i is the column just after "<element> ". A child's own
   -- subtree starts after "-> " (3), the key (lk) and a space (1), hence
   -- i + lk + 4.
   go indent f tr =
      let (v,m) = tParts tr
          sv    = f v
          lv    = length (sv [])
       in sv . showChar ' '
        . (foldr (.) id . zipWith (flip ($)) (False : repeat True) $
              map (\(k,t) -> \b -> let sk = shows k
                                       lk = length (sk [])
                                       i  = indent + lv + 1
                                    in (if b
                                           then showChar '\n'
                                              . showString (replicate i ' ')
                                           else id)
                                     . showString "-> "
                                     . sk . showChar ' '
                                     . go (i + lk + 4) f t)
                  (Map.toListKV m))
