{-# LANGUAGE BangPatterns #-}
module Agda.Utils.Trie
  ( Trie(..)
  , empty, singleton, everyPrefix, insert, insertWith, union, unionWith
  , adjust, delete
  , toList, toAscList, toListOrderedBy
  , lookup, member, lookupPath, lookupTrie
  , mapSubTries, filter
  , valueAt
  ) where
import Prelude hiding (null, lookup, filter)
import Data.Function
import Data.Foldable (Foldable)
import qualified Data.Maybe as Lazy
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.List as List
import qualified Agda.Utils.Maybe.Strict as Strict
import Agda.Utils.Null
import Agda.Utils.Lens
-- | A trie: a finite map whose keys are lists @[k]@.
--
--   Each node holds an optional value for the path ending at that node,
--   plus the subtries reached by extending the path by one more key.
--   Both constructor fields are strict.
--
--   The derived instances are structural.
--
--   * 'Foldable' visits a node's own value before the values of its
--     subtries, and visits subtries in ascending key order.
--   * 'Eq' and 'Show' also see child nodes that hold no values. Two tries
--     with the same entries can therefore compare unequal if one of them
--     carries such empty branches.
data Trie k v = Trie
  !(Strict.Maybe v)      -- ^ Value at the path ending at this node, if any.
  !(Map k (Trie k v))    -- ^ Subtries, keyed by the next path element.
  deriving (Show, Eq, Functor, Foldable)
-- | The empty trie has no root value and no subtries.
instance Null (Trie k v) where
  empty = Trie Strict.Nothing Map.empty
  -- Only the root node is inspected. A trie with any subtrie at all,
  -- even one that holds no values, is not considered null.
  null (Trie rootValue subtries) = null rootValue && null subtries
-- | Build a trie holding @val@ at @path@.
--
--   When @markPrefixes@ is 'True', @val@ is also stored at every proper
--   prefix of @path@, including the empty path. Otherwise the
--   intermediate nodes carry no value.
singletonOrEveryPrefix :: Bool -> [k] -> v -> Trie k v
singletonOrEveryPrefix markPrefixes path !val = go path
  where
    -- Value placed on the intermediate nodes along the path.
    prefixVal
      | markPrefixes = Strict.Just val
      | otherwise    = Strict.Nothing
    -- Walk the path, creating one single-child node per key.
    go []       = Trie (Strict.Just val) Map.empty
    go (k : ks) = Trie prefixVal (Map.singleton k (go ks))
-- | The trie containing exactly one entry: @v@ at path @ks@.
singleton :: [k] -> v -> Trie k v
singleton ks v = singletonOrEveryPrefix False ks v
-- | The trie holding @v@ at path @ks@ and at every prefix of @ks@,
--   including the empty path.
everyPrefix :: [k] -> v -> Trie k v
everyPrefix ks v = singletonOrEveryPrefix True ks v
-- | Union of two tries.
--
--   Clashing values are resolved by keeping the one that 'unionWith'
--   receives from the first trie.
union :: (Ord k) => Trie k v -> Trie k v -> Trie k v
union s t = unionWith (\ l _ -> l) s t
-- | Union of two tries.
--
--   The root values are merged with 'Strict.unionMaybeWith', using the
--   given combining function and passing the first trie's value first.
--   Subtries under a common key are merged recursively.
unionWith :: (Ord k) => (v -> v -> v) -> Trie k v -> Trie k v -> Trie k v
unionWith combine (Trie leftVal leftKids) (Trie rightVal rightKids) =
  Trie mergedVal mergedKids
  where
    mergedVal  = Strict.unionMaybeWith combine leftVal rightVal
    mergedKids = Map.unionWith (unionWith combine) leftKids rightKids
-- | Insert @v@ at path @k@.
--
--   This is a 'union' with a singleton trie placed on the left, so the new
--   value takes precedence over an existing one.
insert :: (Ord k) => [k] -> v -> Trie k v -> Trie k v
insert k v = union (singleton k v)
-- | Insert @v@ at path @k@.
--
--   If a value is already present, it is combined with @v@ using @f@. The
--   new value comes from the left-hand trie passed to 'unionWith'.
insertWith :: (Ord k) => (v -> v -> v) -> [k] -> v -> Trie k v -> Trie k v
insertWith f k v = unionWith f (singleton k v)
-- | Remove the value stored at exactly @path@, if any.
--
--   Values stored further down that path are kept.
delete :: Ord k => [k] -> Trie k v -> Trie k v
delete path = adjust path (\ _ -> Strict.Nothing)
-- | Apply @f@ to the (optional) value at @path@.
--
--   The subtrie below @path@ is left intact. If @path@ does not exist in
--   the trie, the trie is returned unchanged, so @f@ cannot create new
--   entries.
--
--   A subtrie that becomes 'null' as a result (no value and no children)
--   is removed from its parent's map. This happens, for example, when
--   'delete' clears the last value on a branch. Without this pruning, a
--   trie holding no values could still be non-'null' and unequal to
--   'empty' under the structural 'Eq'.
adjust ::
  Ord k =>
  [k] -> (Strict.Maybe v -> Strict.Maybe v) -> Trie k v -> Trie k v
adjust path f t@(Trie v ts) =
  case path of
    -- Reached the target node: adjust its value and keep its children.
    []                                 -> Trie (f v) ts
    -- A subtrie exists for the next key: adjust it recursively. Drop it
    -- if nothing is left in it, so no dangling empty branch remains.
    k : ks | Just s <- Map.lookup k ts ->
      let s' = adjust ks f s
      in  Trie v $ if null s' then Map.delete k ts else Map.insert k s' ts
    -- The path leaves the trie: there is nothing to adjust.
    _ -> t
-- | All (path, value) entries of the trie. This is the same as 'toAscList'.
toList :: Ord k => Trie k v -> [([k],v)]
toList t = toAscList t
-- | All (path, value) entries of the trie, in ascending order of paths.
--
--   A node's own value (at the empty relative path) comes before the
--   entries of its subtries. Subtries are visited in ascending key order.
toAscList :: Ord k => Trie k v -> [([k],v)]
toAscList (Trie here children) = rootEntry ++ childEntries
  where
    rootEntry    = Strict.maybeToList (fmap (\ v -> ([], v)) here)
    childEntries =
      concatMap
        (\ (k, sub) -> map (\ (ks, v) -> (k : ks, v)) (toAscList sub))
        (Map.toAscList children)
-- | Like 'toAscList', but the order of subtries at each level is not
--   determined by key.
--
--   Instead, subtries are ordered by @cmp@ applied to the value at each
--   subtrie's root:
--
--   * Subtries without a root value come first.
--   * Ties keep ascending key order, because 'List.sortBy' is stable.
toListOrderedBy :: Ord k => (v -> v -> Ordering) -> Trie k v -> [([k], v)]
toListOrderedBy cmp (Trie here children) =
  Strict.maybeToList (fmap (\ v -> ([], v)) here) ++
  concatMap visit
    (List.sortBy (compareRoots `on` (rootValue . snd)) (Map.toAscList children))
  where
    -- Prefix every entry of a subtrie with the key leading to it.
    visit (k, sub) = [ (k : ks, v) | (ks, v) <- toListOrderedBy cmp sub ]
    -- Value stored at the root of a subtrie.
    rootValue (Trie r _) = r
    -- Missing root values sort before present ones.
    compareRoots a b = case (a, b) of
      (Strict.Nothing, Strict.Just{})  -> LT
      (Strict.Just{},  Strict.Nothing) -> GT
      (Strict.Nothing, Strict.Nothing) -> EQ
      (Strict.Just x,  Strict.Just y)  -> cmp x y
-- | Build a trie of the same shape by applying @f@ to every subtrie,
--   including the whole trie itself.
--
--   The result of @f@ on the subtrie rooted at a node becomes the value
--   stored at that node in the output ('Nothing' leaves it empty).
--
--   The children map is traversed with 'fmap'. The 'Functor' instance of
--   'Map' is lazy in the values, so each child result is computed only
--   when it is demanded.
mapSubTries :: Ord k => (Trie k u -> Maybe v) -> Trie k u -> Trie k v
mapSubTries f t@(Trie _ ts) = Trie (Strict.toStrict (f t)) (fmap (mapSubTries f) ts)
-- | The value stored at exactly the given path, if any.
lookup :: Ord k => [k] -> Trie k v -> Maybe v
lookup path (Trie here children) = case path of
  []     -> Strict.toLazy here
  k : ks -> Map.lookup k children >>= lookup ks
-- | Whether a value is stored at exactly the given path.
member :: Ord k => [k] -> Trie k v -> Bool
member ks = Lazy.isJust . lookup ks
-- | The values stored along the given path.
--
--   Values are listed from the empty prefix to the longest prefix of the
--   path that is present in the trie.
lookupPath :: Ord k => [k] -> Trie k v -> [v]
lookupPath path (Trie here children) = Strict.maybeToList here ++ deeper
  where
    deeper = case path of
      []     -> []
      k : ks -> maybe [] (lookupPath ks) (Map.lookup k children)
-- | The subtrie rooted at the given path, or 'empty' if that path is
--   not present.
lookupTrie :: Ord k => [k] -> Trie k v -> Trie k v
lookupTrie path t = case path of
  []     -> t
  k : ks -> case t of
    Trie _ children -> case Map.lookup k children of
      Just sub -> lookupTrie ks sub
      Nothing  -> empty
-- | Keep only the values that satisfy the predicate.
--
--   Filtering is applied recursively. Subtries left without any values
--   are removed from their parent.
filter :: Ord k => (v -> Bool) -> Trie k v -> Trie k v
filter p (Trie mv ts) = Trie kept (Map.mapMaybe keepNonEmpty ts)
  where
    -- The root value survives only if it satisfies the predicate.
    kept
      | Strict.Just v <- mv, p v = mv
      | otherwise                = Strict.Nothing
    -- Filter a subtrie, dropping it if nothing remains.
    keepNonEmpty sub
      | null sub' = Nothing
      | otherwise = Just sub'
      where
        sub' = filter p sub
-- | Lens focusing on the value at the given path.
--
--   * Viewing gives 'lookup'.
--   * Setting @Just v@ 'insert's @v@ at the path.
--   * Setting 'Nothing' 'delete's the value at the path.
valueAt :: Ord k => [k] -> Lens' (Maybe v) (Trie k v)
valueAt path f t = fmap update (f (lookup path t))
  where
    update Nothing  = delete path t
    update (Just v) = insert path v t