{-# LANGUAGE
  GADTs
  #-}

module Data.Trie.Pred.Unified.Tail
  ( NUPTrie (..)
  , lookup
  , merge
  , areDisjoint
  ) where

import Prelude hiding (lookup)
import Data.List.NonEmpty hiding (map)
import Data.List.NonEmpty as NE hiding (map)



-- | A trie over paths of tokens of type @t@ that yields results of type @x@.
--
-- Every node carries a tag token; the two constructors differ in how a node
-- decides whether it accepts the current path token during 'lookup'.
data NUPTrie t x where
  -- | Literal node. In 'lookup' its tag must equal the current path token.
  -- Fields: the tag, the value stored at this node (if any), and the child
  -- tries that continue the path.
  NUMore
    :: t
    -> Maybe x
    -> [NUPTrie t x]
    -> NUPTrie t x
  -- | Predicate node. In 'lookup' the tag is ignored; the current path token
  -- is handed to the predicate, and its result (of a hidden type @r@) is
  -- applied to the function found here or further down. Fields: the tag,
  -- the predicate, the optional function stored here, and the child tries,
  -- which hold functions still awaiting that @r@.
  NUPred
    :: t
    -> (t -> Maybe r)
    -> Maybe (r -> x)
    -> [NUPTrie t (r -> x)]
    -> NUPTrie t x


-- | Merge two tries whose roots are compared by their tag token.
--
-- When the roots carry the same tag the right-hand trie wins ("overwrites
-- when similar"); when the tags differ the left-hand trie is returned
-- unchanged ("leaves untouched when not").
--
-- Two literal ('NUMore') roots with equal tags are merged structurally:
--
--   * the right-hand stored value is kept when present, otherwise the
--     left-hand one survives (a missing value never erases an existing one);
--   * the children of both sides are combined so that every pair of
--     children sharing a tag is merged recursively, with children coming
--     later (i.e. from the right-hand side) taking precedence.
--
-- Any combination involving a predicate ('NUPred') root is not merged
-- structurally: the hidden result types of two predicates need not agree,
-- so the right-hand trie simply replaces the left-hand one.
merge :: (Eq t) => NUPTrie t x -> NUPTrie t x -> NUPTrie t x
merge xx@(NUMore t mx xs) (NUMore p my ys)
  | t == p    = NUMore p (maybe mx Just my) $ foldr insertChild [] (xs ++ ys)
  | otherwise = xx
  where
    -- Add one child to an already-combined child list: merge it into the
    -- first child with the same tag if there is one, otherwise prepend it
    -- so the original relative order is preserved.
    insertChild :: (Eq s) => NUPTrie s y -> [NUPTrie s y] -> [NUPTrie s y]
    insertChild a acc = case mergeIntoMatch a acc of
      Just acc' -> acc'
      Nothing   -> a : acc

    -- Merge @a@ into the first list element sharing its tag. The whole list
    -- is searched (not just its head), so non-adjacent duplicates collapse.
    -- 'merge' is right-biased, so the element already in the list wins.
    mergeIntoMatch
      :: (Eq s) => NUPTrie s y -> [NUPTrie s y] -> Maybe [NUPTrie s y]
    mergeIntoMatch _ [] = Nothing
    mergeIntoMatch a (b : bs)
      | tagOf a == tagOf b = Just (merge a b : bs)
      | otherwise          = (b :) <$> mergeIntoMatch a bs

    -- The tag token at the root of a node, whichever constructor it uses.
    tagOf :: NUPTrie s y -> s
    tagOf (NUMore tok _ _)   = tok
    tagOf (NUPred tok _ _ _) = tok
merge xx@(NUPred t _ _ _) yy@(NUPred p _ _ _)
  | t == p    = yy
  | otherwise = xx
merge xx@(NUMore t _ _) yy@(NUPred p _ _ _)
  | t == p    = yy -- predicate children are incompatible with literal ones
  | otherwise = xx
merge xx@(NUPred t _ _ _) yy@(NUMore p _ _)
  | t == p    = yy
  | otherwise = xx


-- | Decide whether two trie nodes are disjoint, i.e. whether their root tag
-- tokens differ.
--
-- Nodes that share a tag overlap (and are therefore candidates for
-- 'merge'), regardless of whether each one is a literal ('NUMore') or a
-- predicate ('NUPred') node. Returns 'True' only when the tags differ.
areDisjoint :: (Eq t) => NUPTrie t x -> NUPTrie t x -> Bool
areDisjoint a b = rootTag a /= rootTag b
  where
    -- The tag token at the root of a node, whichever constructor it uses.
    rootTag :: NUPTrie s y -> s
    rootTag (NUMore tok _ _)   = tok
    rootTag (NUPred tok _ _ _) = tok


-- | Look up a non-empty path of tokens, consuming one token per trie level.
--
-- A 'NUMore' node accepts the head token only if it equals the node's tag.
-- A 'NUPred' node ignores its tag and feeds the head token to its predicate;
-- on success the predicate's result is applied to whatever function is found
-- at or below that node. Once the path is used up, the value stored at the
-- current node is returned; otherwise the rest of the path is tried against
-- each child in order and the first hit wins.
lookup :: Eq t => NonEmpty t -> NUPTrie t x -> Maybe x
lookup (t :| ts) trie =
  case trie of
    NUMore t' mx xs
      | t == t'   -> descend ts mx xs
      | otherwise -> Nothing
    NUPred _ p mrx xrs ->
      p t >>= \r -> ($ r) <$> descend ts mrx xrs
  where
    -- Either stop at this node (path exhausted) or search the children with
    -- the remaining path. Polymorphic in the payload so it serves both the
    -- literal case (payload @x@) and the predicate case (payload @r -> x@).
    descend :: Eq s => [s] -> Maybe y -> [NUPTrie s y] -> Maybe y
    descend []            here _    = here
    descend (next : rest) _    kids = getFirst (map (lookup (next :| rest)) kids)


-- | Return the first 'Just' in a list of 'Maybe's, or 'Nothing' if there is
-- none. Evaluation stops at the first hit, so later elements are not forced.
getFirst :: [Maybe a] -> Maybe a
getFirst = foldr keepFirst Nothing
  where
    -- A hit short-circuits the fold; a miss defers to the rest of the list.
    keepFirst found@(Just _) _    = found
    keepFirst Nothing        rest = rest