{-# LANGUAGE DeriveFunctor       #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Data.Trie.Pseudo where

import           Control.Applicative
import           Control.Arrow             (second)
import Control.Monad (replicateM)
import           Data.Foldable             hiding (all)
import           Data.List                 (intercalate)
import           Data.List.NonEmpty        (NonEmpty (..), fromList, toList)
import qualified Data.List.NonEmpty        as NE
import           Data.Maybe                (fromMaybe)
import           Data.Monoid
import qualified Data.Semigroup            as S
import           Prelude                   hiding (foldl, foldr, foldr1, lookup,
                                            map)



-- TODO: difference
-- | A non-empty rose tree with explicit emptiness, used as a trie whose keys
-- are non-empty paths of tags @t@ and whose stored values have type @a@.
data PseudoTrie t a
  = More (t, Maybe a) (NonEmpty (PseudoTrie t a))
    -- ^ An interior node: its tag, the optional value stored at the path
    -- ending at this node, and its (non-empty) list of children.
  | Rest (NonEmpty t) a
    -- ^ A shortcut leaf: the remaining path of tags and the value stored at
    -- the end of that path.
  | Nil
    -- ^ The empty trie.
  deriving (Eq, Show, Functor)

-- | Overwriting instance: @a <> b@ delegates to @'merge' a b@, which
-- overwrites the left operand with the right operand's contents.
--
-- Since base 4.11 'Semigroup' is a superclass of 'Monoid', so this instance
-- is required for the 'Monoid' instance below to compile.
instance (Eq t) => S.Semigroup (PseudoTrie t a) where
  (<>) = merge

-- | 'Nil' is the identity; 'mappend' is the same overwriting 'merge'.
instance (Eq t) => Monoid (PseudoTrie t a) where
  mempty  = Nil
  -- canonical definition, avoids -Wnoncanonical-monoid-instances
  mappend = (S.<>)

-- | Depth first, in post-order: the children of a 'More' node are folded
-- left to right, and the node's own value (if any) comes after all of the
-- values found beneath it.
instance Foldable (PseudoTrie t) where
  foldr f z trie = case trie of
    Nil      -> z
    Rest _ v -> f v z
    More (_, mv) kids ->
      -- the node's own value sits closest to the initial accumulator,
      -- so it is produced last
      let base = maybe z (\v -> f v z) mv
      in  foldr (\kid acc -> foldr f acc kid) base kids

-- | Whether the root tag of the trie equals the given tag. 'Nil' has no root
-- tag and therefore never matches.
beginsWith :: (Eq t) => PseudoTrie t a -> t -> Bool
beginsWith trie tag = case trie of
  Nil             -> False
  Rest (h :| _) _ -> h == tag
  More (h, _) _   -> h == tag

-- | Set the value stored at a path, or clear it by assigning @Nothing@.
--
-- Clearing provides a form of deletion, but doesn't clean up like @prune@:
-- a cleared 'More' node keeps its children, and clearing an exact 'Rest'
-- match turns it into 'Nil'.
--
-- A non-empty trie has a single root tag, so a path whose first tag differs
-- from the root's cannot be stored; in that case the trie is returned
-- unchanged.
assign :: (Eq t) => NonEmpty t -> Maybe a -> PseudoTrie t a -> PseudoTrie t a
assign ts (Just x) Nil = Rest ts x
assign _  Nothing  Nil = Nil
assign tss@(t:|ts) mx ys@(Rest pss@(p:|ps) y)
  | tss == pss = case mx of
      Just x  -> Rest pss x
      Nothing -> Nil
  | t == p = case (ts, ps) of
      -- the new path ends at the shared head; the old suffix becomes a child
      ([], p':ps') -> More (t, mx) $ Rest (p' :| ps') y :| []
      -- the old path ends at the shared head; the new suffix becomes a child
      (t':ts', []) -> case mx of
        Just x  -> More (p, Just y) $ Rest (t' :| ts') x :| []
        Nothing -> ys
      (t':ts', p':ps')
        -- still sharing a prefix: descend one level
        | t' == p'  -> More (t, Nothing) $
                         assign (t' :| ts') mx (Rest (p' :| ps') y) :| []
        -- the paths diverge here: keep both suffixes as siblings
        | otherwise -> case mx of
            Nothing -> ys
            Just x  -> More (t, Nothing) $
                         Rest (p' :| ps') y :| [Rest (t' :| ts') x]
      -- unreachable: equal heads with both tails empty means tss == pss
      ([], []) -> ys
  | otherwise = ys
assign (t:|ts) mx y@(More (p, my) ys)
  | t == p = case ts of
      [] -> More (p, mx) ys
      (t':ts') ->
        let rest      = t' :| ts'
            matches c = c `beginsWith` t'
        in case find matches ys of
             -- a child continues the path: update only that child, so the
             -- value is neither lost nor copied into unrelated (e.g. 'Nil')
             -- children
             Just _  -> More (p, my) $
                          fmap (\c -> if matches c then assign rest mx c else c) ys
             -- no child continues the path: add a new leaf (unless clearing)
             Nothing -> case mx of
               Nothing -> y
               Just x  -> case ys of
                 c :| cs -> More (p, my) (c :| (cs ++ [Rest rest x]))
  | otherwise = y

-- | Overwrite the LHS point-wise with the RHS's contents.
--
-- Where both tries hold a value at the same path, the RHS value wins. Values
-- present on only one side are kept. Children that share a root tag are
-- merged recursively. RHS children with a new root tag are appended after
-- the LHS children.
--
-- A non-empty trie has a single root tag, so when the two roots differ the
-- LHS is returned unchanged.
merge :: (Eq t) => PseudoTrie t a -> PseudoTrie t a -> PseudoTrie t a
merge lhs rhs = case (lhs, rhs) of
    (Nil, _) -> rhs
    (_, Nil) -> lhs
    (Rest tss@(t:|ts) x, Rest pss@(p:|ps) y)
      | tss == pss -> Rest pss y
      | t == p -> case (ts, ps) of
          ([], p':ps') -> More (t, Just x) $ Rest (p' :| ps') y :| []
          (t':ts', []) -> More (t, Just y) $ Rest (t' :| ts') x :| []
          -- same next tag: merged recursively; different: kept as siblings
          (t':ts', p':ps') -> More (t, Nothing) $
            mergeChildren (Rest (t' :| ts') x :| []) [Rest (p' :| ps') y]
          -- unreachable: covered by the tss == pss guard
          ([], []) -> Rest pss y
      | otherwise -> lhs
    (More (t, mx) xs, More (p, my) ys)
      -- keep the LHS value when the RHS has none at this node
      | t == p -> More (p, my <|> mx) $ mergeChildren xs (NE.toList ys)
      | otherwise -> lhs
    (More (t, mx) xs, Rest (p:|ps) y)
      | t == p -> case ps of
          [] -> More (t, Just y) xs
          (p':ps') -> More (t, mx) $ mergeChildren xs [Rest (p' :| ps') y]
      | otherwise -> lhs
    (Rest (t:|ts) x, More (p, my) ys)
      | t == p -> case ts of
          -- the RHS value wins; the LHS value survives only if RHS has none
          [] -> More (p, my <|> Just x) ys
          (t':ts') -> More (p, my) $
            mergeChildren (Rest (t' :| ts') x :| []) (NE.toList ys)
      | otherwise -> lhs
  where
    -- Fold RHS subtrees, one at a time, into a non-empty list of LHS children.
    mergeChildren (k :| ks) new =
      case foldl' (flip insertChild) (k : ks) new of
        c : cs -> c :| cs
        []     -> k :| ks -- unreachable: insertChild never shrinks a list

    -- Merge one RHS subtree into the first LHS child with the same root tag,
    -- or append it when there is none. 'Nil' contributes nothing.
    insertChild Nil ks = ks
    insertChild c [] = [c]
    insertChild c (k : ks)
      | sameRoot k c = merge k c : ks
      | otherwise    = k : insertChild c ks

    -- Root tags are compared regardless of constructor ('More' or 'Rest').
    sameRoot a b = case (rootTag a, rootTag b) of
      (Just r, Just s) -> r == s
      _                -> False

    rootTag Nil               = Nothing
    rootTag (Rest (r :| _) _) = Just r
    rootTag (More (r, _) _)   = Just r


-- | Graft @input@ beneath the path @ts@, then 'merge' the grafted trie into
-- @container@ (the grafted trie is the right-hand argument of 'merge').
add :: (Eq t) => NonEmpty t -> PseudoTrie t a -> PseudoTrie t a -> PseudoTrie t a
add ts input container = merge container grafted
  where
    -- one value-less 'More' node per tag of @ts@, outermost tag first,
    -- with @input@ as the single child of the innermost node
    grafted = foldr (\tag inner -> More (tag, Nothing) (inner :| [])) input (NE.toList ts)


-- | List every stored value with its full path, in pre-order. A 'More'
-- node's own value comes before the values of its children, and children
-- are visited left to right.
toAssocs :: PseudoTrie t a -> [(NonEmpty t, a)]
toAssocs trie = walk [] trie []
  where
    walk :: [t] -> PseudoTrie t a -> [(NonEmpty t, a)] -> [(NonEmpty t, a)]
    walk _ Nil rest = rest
    walk prefix (Rest suffix v) rest = (prependAll prefix suffix, v) : rest
    walk prefix (More (tag, mv) kids) rest =
      let here  = prefix ++ [tag]
          below = foldr (walk here) rest (NE.toList kids)
      in  case mv of
            Nothing -> below
            Just v  -> (prependAll prefix (tag :| []), v) : below

    -- Prefix a non-empty path with a plain list of tags (total, unlike
    -- going through a list and 'NE.fromList').
    prependAll :: [t] -> NonEmpty t -> NonEmpty t
    prependAll [] path       = path
    prependAll (q : qs) path = q :| (qs ++ NE.toList path)

-- | Build a trie by 'assign'ing every association, starting from 'Nil'.
--
-- The list is folded from the right, so for a repeated path the earliest
-- association is applied last and wins. Paths that 'assign' cannot place
-- (a first tag different from the root already built) are not stored.
fromAssocs :: (Eq t) => [(NonEmpty t, a)] -> PseudoTrie t a
fromAssocs assocs = foldr insertOne Nil assocs
  where
    insertOne entry acc = uncurry assign (second Just entry) acc

-- | Find the value stored at exactly the given path, if any.
lookup :: (Eq t) => NonEmpty t -> PseudoTrie t a -> Maybe a
lookup _ Nil = Nothing
lookup key (Rest path v)
  | key == path = Just v
  | otherwise   = Nothing
lookup (k:|ks) (More (tag, here) kids)
  | k == tag  = descend ks
  | otherwise = Nothing
  where
    -- path exhausted at this node: its own (optional) value is the answer
    descend [] = here
    -- otherwise continue in the first child whose root tag is the next tag
    descend (k':ks') = do
      kid <- find (startsWithTag k') kids
      lookup (k' :| ks') kid

    startsWithTag s trie = case trie of
      Nil             -> False
      More (r, _) _   -> s == r
      Rest (r :| _) _ -> s == r

-- | Simple test on the heads of two tries. They are disjoint unless both
-- have a root tag and those tags are equal. Only the tag matters, not the
-- constructor: a 'More' and a 'Rest' with the same root tag overlap. 'Nil'
-- has no root tag and is disjoint from everything.
areDisjoint :: (Eq t) => PseudoTrie t a -> PseudoTrie t a -> Bool
areDisjoint x y = case (rootTag x, rootTag y) of
  (Just t, Just p) -> not (t == p)
  _                -> True
  where
    rootTag Nil               = Nothing
    rootTag (Rest (r :| _) _) = Just r
    rootTag (More (r, _) _)   = Just r

-- | The meet of two @PseudoTrie@s. Only paths holding a value in both tries
-- survive, and their values are combined with @f@ (LHS value first).
--
-- Empty sub-results are dropped. A node whose children all vanish becomes a
-- one-tag 'Rest' if it still carries a value, or 'Nil' otherwise. There is
-- no implicit root, so tries with different root tags intersect to 'Nil'.
intersectionWith :: (Eq t) =>
                    (a -> b -> c)
                 -> PseudoTrie t a
                 -> PseudoTrie t b
                 -> PseudoTrie t c
intersectionWith f lhs rhs = case (lhs, rhs) of
    (_, Nil) -> Nil
    (Nil, _) -> Nil
    (Rest tss x, Rest pss y)
      | tss == pss -> Rest pss (f x y)
      | otherwise  -> Nil
    (More (t, mx) xs, More (p, my) ys)
      | t == p -> node p (f <$> mx <*> my)
                    [ intersectionWith f x' y'
                    | x' <- NE.toList xs, y' <- NE.toList ys ]
      | otherwise -> Nil
    (More (t, mx) xs, Rest (p :| ps) y)
      | t == p -> case ps of
          [] -> node p (f <$> mx <*> Just y) []
          (p' : ps') -> node p Nothing
            [ intersectionWith f x' (Rest (p' :| ps') y) | x' <- NE.toList xs ]
      | otherwise -> Nil
    (Rest (t :| ts) x, More (p, my) ys)
      | t == p -> case ts of
          [] -> node t (f <$> Just x <*> my) []
          (t' : ts') -> node t Nothing
            [ intersectionWith f (Rest (t' :| ts') x) y' | y' <- NE.toList ys ]
      | otherwise -> Nil
  where
    -- Build a node from its tag, optional value and candidate children:
    -- 'Nil' children are discarded and childless nodes are collapsed.
    node tag mv kids = case (mv, filter nonEmptyTrie kids) of
      (Nothing, []) -> Nil
      (Just v, [])  -> Rest (tag :| []) v
      (_, k : ks)   -> More (tag, mv) (k :| ks)

    nonEmptyTrie Nil = False
    nonEmptyTrie _   = True

-- difference :: Eq t =>
--               PseudoTrie t a
--            -> PseudoTrie t a
--            -> PseudoTrie t a


-- | Needless intermediary elements are turned into shortcuts, and @Nil@s in
-- subtrees are removed.
--
-- Children are pruned first. Then:
--
-- * a value-less node left with no children becomes 'Nil';
-- * a value-less node left with a single 'Rest' child is fused into one
--   longer 'Rest';
-- * a node with a value and no remaining children becomes a one-tag 'Rest'.
prune :: PseudoTrie t a -> PseudoTrie t a
prune Nil = Nil
prune leaf@(Rest _ _) = leaf
prune (More (t, mv) kids) =
  case (mv, survivors) of
    (Nothing, [])          -> Nil
    (Nothing, [Rest ts x]) -> Rest (t :| NE.toList ts) x
    (Just x, [])           -> Rest (t :| []) x
    (_, k : ks)            -> More (t, mv) (k :| ks)
  where
    survivors = filter isLive (fmap prune (NE.toList kids))

    isLive Nil = False
    isLive _   = True