{-# LANGUAGE
  GADTs
  #-}

module Data.Trie.Pred.Unified.Fast
  ( FUPTrie (..)
  , lookup
  , merge
  , areDisjoint
  ) where

import Prelude hiding (lookup)
import Data.List.NonEmpty hiding (map)
import Data.List.NonEmpty as NE hiding (map)



-- | A fast, unified, predicative trie.
--
-- Literal labels and predicate labels share the single label type @t@ (the
-- \"unified\" part), and a branch that runs to a single leaf without any
-- further branching is stored compactly by its own constructor, 'FURest'.
data FUPTrie t x where
  -- | A non-branching tail: the full remaining path of labels (its head is
  -- this node's own label) together with the value stored at its end.
  FURest
    :: NonEmpty t
    -> x
    -> FUPTrie t x
  -- | A node matched by comparing its label with '==', holding an optional
  -- value of its own and at least one child.
  FUMore
    :: t
    -> Maybe x
    -> NonEmpty (FUPTrie t x)
    -> FUPTrie t x
  -- | A predicate node: a label, a predicate over path elements yielding a
  -- hidden (existentially-typed) result @r@, an optional value needing that
  -- @r@, and children whose values also need that @r@.
  FUPred
    :: t
    -> (t -> Maybe r)
    -> Maybe (r -> x)
    -> [FUPTrie t (r -> x)]
    -> FUPTrie t x


-- | Merge two tries, letting the second argument win on conflicts.
--
-- When both tries start with the same root label they are combined:
--
-- * a value stored at the same path in both tries is taken from the second
--   argument, while a value present in only one of them is kept;
-- * child branches are unioned, and children sharing a root label are merged
--   recursively (again preferring the second argument);
-- * predicate nodes ('FUPred') hide their result type and cannot be combined
--   with anything, so if either root is a predicate node the second argument
--   replaces the first wholesale.
--
-- When the root labels differ the two tries cannot be represented as a single
-- 'FUPTrie', so the first argument is returned unchanged.
merge :: (Eq t) => FUPTrie t x -> FUPTrie t x -> FUPTrie t x
merge xx yy
  | rootOf xx == rootOf yy = sameRoot xx yy
  | otherwise              = xx
  where
    -- The label at the root of a trie.
    rootOf :: FUPTrie a b -> a
    rootOf (FURest (l :| _) _) = l
    rootOf (FUMore l _ _)      = l
    rootOf (FUPred l _ _ _)    = l

    -- Combine two tries already known to share their root label; the second
    -- argument wins on conflicts.
    sameRoot :: (Eq a) => FUPTrie a b -> FUPTrie a b -> FUPTrie a b
    sameRoot (FURest tss@(_ :| ts) x) new@(FURest pss@(p :| ps) y)
      -- Identical paths: keep the compact leaf, with the newer value.
      | tss == pss = new
      | otherwise  = case (ts, ps) of
          -- One path ends at the shared root: its value moves onto a new
          -- branching node and the other path's tail hangs below it.
          ([], q : qs)     -> FUMore p (Just x) (FURest (q :| qs) y :| [])
          (s : ss, [])     -> FUMore p (Just y) (FURest (s :| ss) x :| [])
          -- Both paths continue: merge the tails, or keep them side by side.
          (s : ss, q : qs) ->
            FUMore p Nothing
              (insertChild merge (FURest (q :| qs) y) (FURest (s :| ss) x :| []))
          -- Unreachable: equal roots and no tails means tss == pss.
          ([], [])         -> new
    sameRoot (FUMore _ mx xs) (FUMore p my ys) =
      -- Keep the older value only when the newer node has none.
      FUMore p (maybe mx Just my) (insertAll xs (NE.toList ys))
    sameRoot (FURest (_ :| ts) x) (FUMore p my ys) =
      case ts of
        []     -> FUMore p (maybe (Just x) Just my) ys
        -- The older tail goes among the newer children; they win on conflict.
        s : ss -> FUMore p my (insertChild (flip merge) (FURest (s :| ss) x) ys)
    sameRoot (FUMore t mx xs) (FURest (_ :| ps) y) =
      case ps of
        []     -> FUMore t (Just y) xs
        q : qs -> FUMore t mx (insertChild merge (FURest (q :| qs) y) xs)
    -- Any pairing involving a predicate node: its existentially-typed result
    -- cannot be reconciled with another node, so the newer trie replaces it.
    sameRoot _ new = new

    -- Insert every trie of the list among the siblings, newer tries winning.
    insertAll
      :: (Eq a)
      => NonEmpty (FUPTrie a b)
      -> [FUPTrie a b]
      -> NonEmpty (FUPTrie a b)
    insertAll acc []       = acc
    insertAll acc (c : cs) = insertAll (insertChild merge c acc) cs

    -- Insert a trie among siblings. The first sibling with the same root
    -- label is combined with it as @combine sibling new@; when there is no
    -- such sibling the trie is appended at the end.
    insertChild
      :: (Eq a)
      => (FUPTrie a b -> FUPTrie a b -> FUPTrie a b)
      -> FUPTrie a b
      -> NonEmpty (FUPTrie a b)
      -> NonEmpty (FUPTrie a b)
    insertChild combine new (c :| cs)
      | rootOf c == rootOf new = combine c new :| cs
      | otherwise = case cs of
          []     -> c :| [new]
          d : ds -> NE.cons c (insertChild combine new (d :| ds))


-- | Whether two tries can sit side by side as siblings, i.e. whether their
-- root labels differ. Tries whose root labels are equal overlap and would
-- have to be merged instead.
areDisjoint :: (Eq t) => FUPTrie t x -> FUPTrie t x -> Bool
areDisjoint xx yy = labelOf xx /= labelOf yy
  where
    -- The label at the root of a trie: the head of an 'FURest' path, or the
    -- label stored directly in an 'FUMore' / 'FUPred' node.
    labelOf :: FUPTrie s y -> s
    labelOf (FURest (l :| _) _) = l
    labelOf (FUMore l _ _)      = l
    labelOf (FUPred l _ _ _)    = l


-- | Look up the value stored under a non-empty path of labels.
--
-- * An 'FURest' branch matches only when the whole query equals its stored
--   path.
-- * An 'FUMore' node matches when the first query element equals its label
--   (via '==').
-- * An 'FUPred' node ignores its label. It applies its predicate to the first
--   query element and hands the resulting value to whatever is found further
--   down.
--
-- Among sibling branches, the first one that yields a value wins.
lookup :: Eq t => NonEmpty t -> FUPTrie t x -> Maybe x
lookup path (FURest stored x)
  | path == stored = Just x
  | otherwise      = Nothing
lookup (t :| ts) (FUMore label mx children)
  | t == label = descend
  | otherwise  = Nothing
  where
    -- Either the query ends here, or we continue into the children.
    descend = case NE.nonEmpty ts of
      Nothing   -> mx
      Just rest -> getFirst (NE.toList (fmap (lookup rest) children))
lookup (t :| ts) (FUPred _ predicate mrx children) = do
  r <- predicate t
  case NE.nonEmpty ts of
    Nothing   -> fmap ($ r) mrx
    Just rest -> fmap ($ r) (getFirst (map (lookup rest) children))


-- | The first 'Just' in a list of 'Maybe's, or 'Nothing' when there is none.
-- The list is only inspected up to (and including) the first 'Just'.
getFirst :: [Maybe a] -> Maybe a
getFirst = foldr keep Nothing
  where
    keep (Just found) _    = Just found
    keep Nothing      rest = rest