{-# LANGUAGE Rank2Types, PatternGuards, FlexibleContexts, TypeFamilies, UndecidableInstances, MultiParamTypeClasses #-}

module Data.TrieMap.Regular.RadixTrie where

import Data.TrieMap.Regular.Class
import Data.TrieMap.Regular.Base
import Data.TrieMap.Regular.Ord
import Data.TrieMap.Regular.Eq
import Data.TrieMap.Sized
import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative

import Control.Arrow
import Control.Applicative
import Control.Monad

import Data.Maybe
import Data.Monoid
import Data.Foldable
import Data.Traversable

import Prelude hiding (foldr, foldl)

-- | One edge (node) of the radix trie.  The constructor fields are, in order:
--
--   1. a cached size, stored strict and unpacked ('edgeSize' reads it back and
--      'edge' computes it from the value and the children);
--
--   2. the run of key elements labelling this edge;
--
--   3. the value stored at the end of that label, if any;
--
--   4. the child edges, held in a map of type @m k (Edge f m k a) ix@.
data Edge f (m :: * -> (* -> *) -> * -> *) k (a :: * -> *) ix = Edge {-# UNPACK #-} !Int [f k] (Maybe (a ix)) (m k (Edge f m k a) ix)
-- | An 'Edge' whose child map type is fixed to @TrieMapT f@.
type Edge' f k a ix = Edge f (TrieMapT f) k a ix
-- | An optional 'Edge' with an explicitly chosen child map type @m@.
type MEdge f k m a ix = Maybe (Edge f m k a ix)
-- | An optional 'Edge'' (child map type fixed to @TrieMapT f@).
type MEdge' f k a ix = Maybe (Edge' f k a ix)

-- type instance PF (Edge f m k a ix) = (K0 (L f k) :*: K0 (Maybe (a ix)) :*: L (K0 k :*: I0) :*: K0 Int)
-- type instance (RadixTrie f k a ix) = U0 :+: PF (Edge f m k a ix)

-- instance (TrieKeyT f m, m ~ TrieMapT f, TrieKey k (TrieMap k)) => Regular (Edge f m k a ix) where
-- 	from (Edge n ks v ts) = K0 (List ks) :*: K0 v :*: 

-- | A radix trie: a wrapper around an optional root edge.  The empty trie is
-- @Radix Nothing@ (see 'emptyT' in the 'TrieKeyT' instance).
newtype RadixTrie f k a ix = Radix (MEdge' f k a ix)
-- newtype K0 a b = K0 a

-- Keys built with the @L f@ functor are stored in a 'RadixTrie'; the first
-- instance is the higher-kinded form, the second the fully applied form.
type instance TrieMapT (L f) = RadixTrie f
type instance TrieMap (L f r) = RadixTrie f r
-- type instance TrieMap [k] = RadixTrie k (TrieMap k)

-- | Read back the size cached in the first field of an 'Edge'.
edgeSize :: Sized (Edge f m k a)
edgeSize e = case e of
	Edge cached _ _ _ -> cached

-- | Smart constructor for 'Edge': builds an edge from its label, optional
-- value and children, filling in the cached size as the size of the value
-- (0 when absent, measured with @s@ otherwise) plus the total size of the
-- children (measured with 'edgeSize').
edge :: (TrieKeyT f m, m ~ TrieMapT f, TrieKey k (TrieMap k)) => Sized a -> [f k] -> Maybe (a ix) -> m k (Edge f m k a) ix -> Edge f m k a ix
edge s ks v ts = Edge total ks v ts where
	total = valueSize + sizeT edgeSize ts
	valueSize = case v of
		Nothing	-> 0
		Just a	-> s a

-- | 'TrieKeyT' operations for list-like keys (@L f@) backed by a 'RadixTrie'.
-- Almost every method unwraps the 'Radix' newtype and the @List@ key wrapper,
-- handles the empty trie (@Nothing@) itself, and hands the root edge to the
-- matching @...E@ helper defined later in this module.  User callbacks that
-- expect an @L f k@ key are adapted with @(f . List)@ because the edge helpers
-- work on plain lists of key elements.
instance (OrdT f, TrieKeyT f m, m ~ TrieMapT f) => TrieKeyT (L f) (RadixTrie f) where
	emptyT = Radix Nothing
	nullT (Radix m) = isNothing m
	-- The size is cached in the root edge, so the value-size function is ignored.
	sizeT _ (Radix m) = maybe 0 edgeSize m
	lookupT (List ks) (Radix m) = m >>= lookupE ks
	-- Index lookups start counting from 0 at the root.
	lookupIxT s (List ks) (Radix m) = m >>= lookupIxE s 0 ks
	-- Partial: 'fromJust' fails if the trie is empty or 'assocAtE' yields Nothing.
	assocAtT s i (Radix m) = fromJust (do	(i', ks, v) <- m >>= assocAtE s i
						return (i', List ks, v))
	updateAtT s f i (Radix m) = Radix (m >>= updateAtE s (\ i' -> f i' . List) i)
	-- On an empty trie, altering can only insert (when @f Nothing@ is a Just).
	alterT s f (List ks) (Radix m) = Radix (maybe (singletonME s ks (f Nothing)) (alterE s f ks) m)
	traverseWithKeyT s f (Radix m) = Radix <$> traverse (traverseE s (f . List)) m
	-- Both folds use 'foldr' over the optional root edge: it applies the
	-- edge-level fold when a root exists and returns @z@ otherwise.
	foldWithKeyT f (Radix m) z = foldr (foldE (f . List)) z m
	foldlWithKeyT f (Radix m) z = foldr (foldlE (f . List)) z m
	mapEitherT s1 s2 f (Radix m) = (Radix *** Radix) (maybe (Nothing, Nothing) (mapEitherE s1 s2 (f . List)) m)
	splitLookupT s f (List ks) (Radix m) = Radix `sides` maybe (Nothing, Nothing, Nothing) (splitLookupE s f ks) m
	unionT s f (Radix m1) (Radix m2) = Radix (unionMaybe (unionE s (f . List)) m1 m2)
	isectT s f (Radix m1) (Radix m2) = Radix (isectMaybe (isectE s (f . List)) m1 m2)
	diffT s f (Radix m1) (Radix m2) = Radix (diffMaybe (diffE s (f . List)) m1 m2)
	-- The extracted key is re-wrapped with @List@ and the remainder with 'Radix'.
	extractMinT s (Radix m) = First m >>= liftM (first List *** Radix) . extractMinE s
	extractMaxT s (Radix m) = Last m >>= liftM (first List *** Radix) . extractMaxE s
	alterMinT s f (Radix m) = Radix (m >>= alterMinE s (f . List))
	alterMaxT s f (Radix m) = Radix (m >>= alterMaxE s (f . List))
	isSubmapT (<=) (Radix m1) (Radix m2) = subMaybe (isSubEdge (<=)) m1 m2
	-- The list comprehensions strip the @List@ wrapper from every key.
	fromListT s f xs = Radix (fromListE s (f . List) [(ks, a) | (List ks, a) <- xs])
	fromAscListT s f xs = Radix (fromAscListE s (f . List) [(ks, a) | (List ks, a) <- xs])

-- | 'TrieKey' instance for fully applied list-like keys.  Every @...M@ method
-- is defined as the same-named @...T@ method of the 'TrieKeyT' class, so this
-- instance adds no logic of its own.
instance (OrdT f, TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => TrieKey (L f k) (RadixTrie f k) where
	emptyM = emptyT
	nullM = nullT
	sizeM = sizeT
	lookupM = lookupT
	lookupIxM = lookupIxT
	assocAtM = assocAtT
	updateAtM = updateAtT
	alterM = alterT
	traverseWithKeyM = traverseWithKeyT
	foldWithKeyM = foldWithKeyT
	foldlWithKeyM = foldlWithKeyT
	mapEitherM = mapEitherT
	splitLookupM = splitLookupT
	unionM = unionT
	isectM = isectT
	diffM = diffT
	extractMinM = extractMinT
	extractMaxM = extractMaxT
	alterMinM = alterMinT
	alterMaxM = alterMaxT
	isSubmapM = isSubmapT
	fromListM = fromListT
	fromAscListM = fromAscListT
	-- NOTE(review): the 'TrieKeyT' instance for 'RadixTrie' in this file does not
	-- define 'fromDistAscListT'; presumably a class default is used -- confirm.
	fromDistAscListM = fromDistAscListT

-- instance (Ord k, TrieKey k m) => TrieKey [k] (RadixTrie k m) where
-- 	emptyM = Radix Nothing
-- 	nullM (Radix m) = isNothing m
-- 	lookupM ks (Radix m) = m >>= lookupE ks
-- 	alterM f ks (Radix m) = Radix (maybe (singletonME ks (f Nothing)) (alterE f ks) m)
-- 	traverseWithKeyM f (Radix m) = Radix <$> traverse (traverseE f) m
-- 	foldWithKeyM f (Radix m) z = foldr (foldE f) z m
-- 	mapEitherM f (Radix m) = (Radix *** Radix) (maybe (Nothing, Nothing) (mapEitherE f) m)
-- 	splitLookupM f ks (Radix m) = Radix `sides` maybe (Nothing, Nothing, Nothing) (splitLookupE f ks) m
-- 	unionM f (Radix m1) (Radix m2) = Radix (unionMaybe (unionE f) m1 m2)
-- 	isectM f (Radix m1) (Radix m2) = Radix (isectMaybe (isectE f) m1 m2)
-- 	diffM f (Radix m1) (Radix m2) = Radix (diffMaybe (diffE f) m1 m2)
-- 	extractMinM (Radix m) = First m >>= fmap (fmap Radix) . extractMinE
-- 	extractMaxM (Radix m) = Last m >>= fmap (fmap Radix) . extractMaxE
-- 	alterMinM f (Radix m) = Radix (m >>= alterMinE f)
-- 	alterMaxM f (Radix m) = Radix (m >>= alterMaxE f)
-- 	isSubmapM (<=) (Radix m1) (Radix m2) = subMaybe (isSubEdge (<=)) m1 m2
-- 	fromListM = Radix .: fromListE
-- 	fromAscListM = Radix .: fromAscListE

-- | Normalise an edge that may have become redundant after an update.
--
-- An edge that still holds a value is returned unchanged.  For an edge
-- without a value the children (listed with 'assocsT') decide the outcome:
--
--   * no children: the edge carries nothing, so the result is 'Nothing';
--
--   * exactly one child: the edge is merged with that child by concatenating
--     the two labels around the child's key element, and the merged edge is
--     compacted again;
--
--   * two or more children: the edge is a genuine branch point and is kept.
compact :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Edge' f k a ix -> MEdge' f k a ix
compact e@(Edge s ks Nothing ts) = case assocsT ts of
	[]	-> Nothing
	-- The pair is matched lazily; the `seq` then forces the child edge before
	-- the merged edge is built.  The merged edge reuses the child's size s'.
	[~(k, e'@(Edge s' ls v ts'))]
		-> e' `seq` compact (Edge s' (ks ++ k:ls) v ts')
	_	-> Just e
compact e = Just e

-- | Prepend a single key element to an edge's label.  The cached size, the
-- value and the children are carried over unchanged.
cons :: f k -> Edge' f k a ix -> Edge' f k a ix
cons l (Edge sz label val kids) = Edge sz (l : label) val kids

-- | Prepend a whole run of key elements to an edge's label.  The cached size,
-- the value and the children are carried over unchanged.
cat :: [f k] -> Edge' f k a ix -> Edge' f k a ix
cat prefix (Edge sz label val kids) = Edge sz (prefix ++ label) val kids

-- | Build the edge for a trie holding at most one association: 'Nothing' when
-- no value is supplied, otherwise a childless edge labelled with the whole
-- key whose cached size is the size of the value.
singletonME :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized a -> [f k] -> Maybe (a ix) -> MEdge' f k a ix
singletonME s ks mv = do
	v <- mv
	return (Edge (s v) ks (Just v) emptyT)

-- | Look a key up inside an edge.  The key is matched element by element
-- against the edge label (with 'eqT'); once the label is exhausted the search
-- either stops at this edge's value (key exhausted too) or continues in the
-- child selected by the next key element.  Any mismatch yields 'Nothing'.
lookupE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => [f k] -> Edge' f k a ix -> Maybe (a ix)
lookupE key (Edge _ label val children) = go key label where
	-- Key and label end together: the answer is the value stored here.
	go [] [] = val
	-- Label consumed, key not: descend into the child for the next element.
	go (k:rest) [] = lookupT k children >>= lookupE rest
	go (k:rest) (l:label')
		| k `eqT` l	= go rest label'
	-- Diverging elements, or a key that ends in the middle of the label.
	go _ _ = Nothing

-- | Alter the value stored under key @ks0@ inside an edge.  @f@ receives the
-- current value (if any) and returns the new one; @s@ measures values so that
-- cached sizes can be rebuilt.  The result is 'Nothing' when the edge becomes
-- empty.  The local @match@ walks the key and the edge label in parallel,
-- counting matched elements in @i@.
alterE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => 
	Sized a -> (Maybe (a ix) -> Maybe (a ix)) -> [f k] -> Edge' f k a ix -> MEdge' f k a ix
alterE s f ks0 e@(Edge sz ls0 v0 ts0) = match 0 ks0 ls0 where
	-- This guard never succeeds; it only forces the counter i on every call.
	match i _ _ | i `seq` False = undefined
	match i (k:ks) (l:ls)
		| k `eqT` l	= match (i+1) ks ls
		-- Key and label diverge and f wants to insert: split the edge at the
		-- common prefix (the first i label elements) into a value-less branch
		-- with two children, the new singleton and the remainder of this edge.
		| Just v <- f Nothing
				= Just (Edge (sz + s v) (take i ls0) Nothing 
					(fromListT edgeSize (const const) [(k, Edge (s v) ks (Just v) emptyT), 
						(l, Edge sz ls v0 ts0)]))
	-- Label exhausted, key not: alter inside the child map under element k,
	-- then rebuild the cached size with 'edge' and re-normalise with 'compact'.
	match _ (k:ks) [] = compact $ edge s ls0 v0 $ alterT edgeSize g k ts0 where
		g = maybe (singletonME s ks (f Nothing)) (alterE s f ks)
	-- Key ends inside the label and f wants to insert: the new edge is labelled
	-- with the key, holds the new value, and has the rest of this edge as its
	-- only child.
	match _ [] (l:ls)
		| Just v <- f Nothing
			= Just (Edge (sz + s v) ks0 (Just v) (singletonT edgeSize l (Edge sz ls v0 ts0)))
	-- Key and label end together: apply f to the value stored at this edge.
	match _ [] [] = compact (edge s ls0 (f v0) ts0)
	-- Reached when the key is absent and @f Nothing@ is Nothing: nothing changes.
	match _ _ _ = Just e

-- | Applicative traversal of every value in an edge.  The callback receives
-- the full key of each value: this edge's label for the value stored here,
-- and label ++ child element ++ sub-key for values in the children.  The
-- resulting edge is rebuilt with 'edge', so its cached size is recomputed
-- with @s@.
traverseE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k), Applicative t) => 
	Sized b -> ([f k] -> a ix -> t (b ix)) -> Edge' f k a ix -> t (Edge' f k b ix)
traverseE s f (Edge _ ks v ts) = liftA2 (edge s ks) valueT childrenT where
	-- Effect for the value stored at this edge (if any).
	valueT = traverse (f ks) v
	-- Effects for the children, run after the value's effect.
	childrenT = traverseWithKeyT edgeSize (\ l -> traverseE s (\ ls -> f (ks ++ l:ls))) ts

-- | Right fold over the associations of an edge.  The children are folded
-- first (starting from @z@); the value stored at this edge, if present, is
-- combined last, i.e. outermost.  The callback receives full keys.
foldE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => ([f k] -> a ix -> b -> b) -> Edge' f k a ix -> b -> b
foldE f (Edge _ ks v ts) z = maybe below (\ a -> f ks a below) v where
	-- Result of folding all children, with keys extended by this edge's label.
	below = foldWithKeyT (\ l -> foldE (\ ls -> f (ks ++ l:ls))) ts z

-- | Left fold over the associations of an edge.  The value stored at this
-- edge, if present, is combined into the accumulator first; the children are
-- then folded with the updated accumulator.  The callback receives full keys.
foldlE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => ([f k] -> b -> a ix -> b) -> Edge' f k a ix -> b -> b
foldlE f (Edge _ ks v ts) z = foldlWithKeyT (\ l acc child -> foldlE (\ ls -> f (ks ++ l:ls)) child acc) ts seed where
	-- Accumulator after taking this edge's own value into account.
	seed = maybe z (f ks z) v

-- | Partition an edge into two edges using @f@.  @f@ is applied (with the full
-- key) to the value stored here, and recursively to every child through
-- 'mapEitherT'; the two halves of the value and the two halves of the child
-- map are then reassembled into a left and a right edge.  Each result is
-- rebuilt with 'edge' (sizes measured with @s1@ and @s2@ respectively) and
-- passed through 'compact', so either side may come out as 'Nothing'.
mapEitherE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized b -> Sized c -> 
	EitherMap (EitherMap [f k] (a ix) (b ix) (c ix)) (Edge' f k a ix) (Edge' f k b ix) (Edge' f k c ix)
mapEitherE s1 s2 f (Edge _ ks v ts) = case (maybe (Nothing, Nothing) (f ks) v, mapEitherT edgeSize edgeSize 
					(\ l -> mapEitherE s1 s2 (\ ls -> f (ks ++ l:ls))) ts) of 
	((vL, vR), (tsL, tsR)) -> (compact (edge s1 ks vL tsL), compact (edge s2 ks vR tsR))

-- | Split an edge around the key @ks@.  The result is a triple: the part of
-- the edge placed before the key, whatever @f@ extracts from the value stored
-- exactly at the key, and the part placed after the key.  (The @LT@ branch
-- below puts the whole edge in the third slot when the key compares less than
-- the label, so the third slot is the "greater" side.)
--
-- @s@ measures values so that the cached sizes of the rebuilt edges are
-- correct; rebuilt edges are passed through 'compact'.
splitLookupE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized a -> SplitMap (a ix) x -> [f k] -> SplitMap (Edge' f k a ix) x
splitLookupE s f ks e@(Edge _ ls v ts) = match ks ls where
	-- Key and label are compared element by element; at the first difference
	-- the whole edge lies on one side of the key.
	match (k:ks) (l:ls) = case compareT k l of
		LT	-> (Nothing, Nothing, Just e)
		EQ	-> match ks ls
		GT	-> (Just e, Nothing, Nothing)
	-- Key and label end together.  Without a value here every association in
	-- the edge has a strictly longer key, so the edge goes to the greater side.
	-- With a value, f splits it; the children stay on the greater side.
	match [] [] = case v of
		Nothing	-> (Nothing, Nothing, Just e)
		Just v	-> compact `sides` case f v of
			(vL, x, vR) -> (edge s ls vL emptyT, x, edge s ls vR ts)
	-- The key is a proper prefix of the label: every association in this edge
	-- has a key that strictly extends the search key and is therefore greater,
	-- exactly as in the value-less case above.  (The original code returned the
	-- edge on the "less" side here.)
	match [] (_:_) = (Nothing, Nothing, Just e)
	-- Label exhausted, key not: the value stored here is a proper prefix of the
	-- key and stays on the lesser side; the children are split at element k.
	match (k:ks) [] = compact `sides` case splitLookupT edgeSize g k ts of
		(tsL, x, tsR)	-> (edge s ls v tsL, x, edge s ls Nothing tsR)
		where	g = splitLookupE s f ks

-- | Union of two edges.  @f@ combines two values stored under the same full
-- key; @s@ measures values for the cached sizes.  The local @match@ walks both
-- labels in parallel, counting the length of the common prefix in @i@.
unionE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized a -> UnionFunc (UnionFunc [f k] (a ix)) (Edge' f k a ix)
unionE s f (Edge szK ks0 vK tsK) (Edge szL ls0 vL tsL) = match 0 ks0 ls0 where
	-- This guard never succeeds; it only forces the counter i on every call.
	match i _ _ | i `seq` False = undefined
	match i (k:ks) (l:ls)
		| k `eqT` l	= match (i+1) ks ls
		-- The labels diverge: create a value-less branch labelled with the
		-- common prefix whose two children are the remainders of both edges.
		| otherwise	= Just (Edge (szK + szL) (take i ks0) Nothing 
					(fromListT edgeSize (const const) [(k, Edge szK ks vK tsK), (l, Edge szL ls vL tsL)]))
	-- The second label is a proper prefix of the first: push the remainder of
	-- the first edge into the second edge's children under element k, uniting
	-- it with any child already there.
	match _ (k:ks) [] = compact (edge s ls0 vL $ alterT edgeSize g k tsL) where
		g Nothing = Just (Edge szK ks vK tsK)
		g (Just e) = unionE s (\ ks' -> f (ls0 ++ k:ks')) (Edge szK ks vK tsK) e
	-- Mirror image: the first label is a proper prefix of the second.
	match _ [] (l:ls) = compact (edge s ks0 vK $ alterT edgeSize g l tsK) where
		g Nothing = Just (Edge szL ls vL tsL)
		g (Just e) = unionE s (\ ls' -> f (ks0 ++ l:ls')) e (Edge szL ls vL tsL)
	-- Identical labels: unite the two values and the two child maps.
	match _ [] [] = compact (edge s ks0 (unionMaybe (f ks0) vK vL) (unionT edgeSize g tsK tsL)) where
		g x = unionE s (\ xs -> f (ks0 ++ x:xs))

-- | Remove the association with the smallest key from an edge, returning the
-- key/value pair together with what is left of the edge ('Nothing' if it
-- became empty).
--
-- Two alternatives are combined with 'mplus' in the 'First' monad: the value
-- stored at this edge itself, and the minimum of the minimal child.
-- NOTE(review): this relies on 'mplus' for 'First' preferring the left
-- alternative, so the second one is only used when this edge has no value --
-- confirm against the instance, which is not defined in this file.
extractMinE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized a -> Edge' f k a ix -> First (([f k], a ix), MEdge' f k a ix)
extractMinE s (Edge _ ks v ts) = (do
	v <- First v
	return ((ks, v), compact (edge s ks Nothing ts))) `mplus` 
  (do	((x, e'), ts') <- extractMinT edgeSize ts
	((xs, v), e'') <- extractMinE s e'
	-- If the minimal child survives the extraction (e'' is a Just) it replaces
	-- the old minimal child in ts; otherwise the child map without it (ts') is used.
	return ((ks ++ x:xs, v), compact (edge s ks Nothing (maybe ts' (\ e'' -> alterMinT edgeSize (\ _ _ -> Just e'') ts) e''))))

-- | Remove the association with the largest key from an edge, returning the
-- key/value pair together with what is left of the edge ('Nothing' if it
-- became empty).
--
-- Two alternatives are combined with 'mplus' in the 'Last' monad: the value
-- stored at this edge itself, and the maximum of the maximal child.
-- NOTE(review): this relies on 'mplus' for 'Last' preferring the right
-- alternative, so the edge's own value is only extracted when there are no
-- children -- confirm against the instance, which is not defined in this file.
--
-- When the maximum comes from a child, the value stored at this edge must be
-- kept in the rebuilt edge.  The original code rebuilt the edge with 'Nothing'
-- (the inner binding shadowed the outer @v@), silently dropping that value.
extractMaxE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized a -> Edge' f k a ix -> Last (([f k], a ix), MEdge' f k a ix)
extractMaxE s (Edge _ ks v ts) = (do
	vHere <- Last v
	return ((ks, vHere), Nothing)) `mplus`
  (do	((x, e'), ts') <- extractMaxT edgeSize ts
	((xs, vMax), rest) <- extractMaxE s e'
	-- If the maximal child survives the extraction it replaces the old maximal
	-- child in ts; otherwise the child map without it (ts') is used.  The
	-- edge's own value v is preserved.
	return ((ks ++ x:xs, vMax), compact (edge s ks v (maybe ts' (\ e'' -> alterMaxT edgeSize (\ _ _ -> Just e'') ts) rest))))

-- | Alter (or delete, when the callback returns 'Nothing') the value of the
-- extreme association of an edge.  The callback receives the full key and the
-- current value.  Results are rebuilt with 'edge' and normalised with
-- 'compact', so an edge that becomes empty is returned as 'Nothing'.
alterMinE, alterMaxE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized a ->
	([f k] -> a ix -> Maybe (a ix)) -> Edge' f k a ix -> MEdge' f k a ix
-- A value stored at the edge itself is treated as the minimum; otherwise the
-- search continues in the child map via 'alterMinT'.
alterMinE s f (Edge _ ks (Just v) ts) = compact (edge s ks (f ks v) ts)
alterMinE s f (Edge _ ks Nothing ts) = compact (edge s ks Nothing (alterMinT edgeSize (\ x -> alterMinE s (\ xs -> f (ks ++ x:xs))) ts))

-- The value stored at the edge itself is only treated as the maximum when
-- there are no children; otherwise the search continues via 'alterMaxT'.
alterMaxE s f (Edge _ ks v ts)
	-- No children: the result is Nothing when there is no value here or when
	-- the callback deletes it.
	| nullT ts	= do	v' <- v >>= f ks
				return (Edge (s v') ks (Just v') ts)
	| otherwise	= compact (edge s ks v (alterMaxT edgeSize (\ x -> alterMaxE s (\ xs -> f (ks ++ x:xs))) ts))

-- | Intersection of two edges.  @f@ combines two values stored under the same
-- full key; @s@ measures result values for the cached sizes.  The result is
-- 'Nothing' when the two edges have no key in common.
--
-- The local @match@ walks both labels in parallel.  The original definition
-- had no clause for diverging labels, so intersecting two edges whose labels
-- differ at some position failed with a pattern-match error; such edges share
-- no keys, and the final clause now returns 'Nothing' for them.
isectE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized c ->
	IsectFunc (IsectFunc [f k] (a ix) (b ix) (c ix)) (Edge' f k a ix) (Edge' f k b ix) (Edge' f k c ix)
isectE s f (Edge szK ks vK tsK) (Edge szL ls vL tsL) = match ks ls where
	match (k:ks) (l:ls)
		| k `eqT` l	= match ks ls
	-- The second label is a proper prefix of the first: only the second edge's
	-- child under element k can overlap with the remainder of the first edge.
	-- The consumed prefix (second label plus k) is put back on the result.
	match (k:ks) [] = do	e' <- lookupT k tsL
				liftM (cat ls . cons k) (isectE s (\ ks' -> f (ls ++ k:ks')) (Edge szK ks vK tsK) e')
	-- Mirror image: the first label is a proper prefix of the second.
	match [] (l:ls) = do	e' <- lookupT l tsK
				liftM (cat ks . cons l) (isectE s (\ ls' -> f (ks ++ l:ls')) e' (Edge szL ls vL tsL))
	-- Identical labels: intersect the two values and the two child maps.
	match [] [] = compact (edge s ks (isectMaybe (f ks) vK vL) (isectT edgeSize g tsK tsL)) where
		g x = isectE s (\ xs -> f (ks ++ x:xs))
	-- The labels diverge: the edges have no key in common.
	match _ _ = Nothing

-- | Difference of two edges: the associations of the first edge, with those
-- that also occur in the second edge adjusted or removed by @f@.  @s@ measures
-- values for the cached sizes.  The result is 'Nothing' when nothing is left.
--
-- The local @match@ walks both labels in parallel.  The original definition
-- had no clause for diverging labels, nor for a failed child lookup in the
-- second edge, so both situations failed with a pattern-match error.  In both
-- the second edge shares no key with the first, so the final clause now
-- returns the first edge unchanged.
diffE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized a ->
	DiffFunc (DiffFunc [f k] (a ix) (b ix)) (Edge' f k a ix) (Edge' f k b ix)
diffE s f e@(Edge szK ks vK tsK) (Edge szL ls vL tsL) = match ks ls where
	match (k:ks) (l:ls)
		| k `eqT` l	= match ks ls
	-- The second label is a proper prefix of the first: only the second edge's
	-- child under element k can remove anything from the first edge.  The
	-- consumed prefix (second label plus k) is put back on the result.
	match (k:ks) []
		| Just e' <- lookupT k tsL
			= fmap (cat ls . cons k) (diffE s (\ ks' -> f (ls ++ k:ks')) (Edge szK ks vK tsK) e')
	-- The first label is a proper prefix of the second: subtract the remainder
	-- of the second edge from the first edge's child under element l (if any).
	match [] (l:ls) = compact (edge s ks vK (alterT edgeSize (>>= g) l tsK)) where
		g e' = diffE s (\ ls' -> f (ks ++ l:ls')) e' (Edge szL ls vL tsL)
	-- Identical labels: subtract the values and the child maps pairwise.
	match [] [] = compact (edge s ks (diffMaybe (f ks) vK vL) (diffT edgeSize g tsK tsL)) where
		g x = diffE s (\ xs -> f (ks ++ x:xs))
	-- Diverging labels, or no matching child in the second edge: nothing to
	-- subtract, so the first edge is returned as it is.
	match _ _ = Just e

-- | Submap test on edges: does every association of the first edge have a
-- counterpart in the second edge whose value is related to it by the given
-- comparison?
isSubEdge :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => LEq (a ix) (b ix) -> LEq (Edge' f k a ix) (Edge' f k b ix)
isSubEdge le (Edge szK ks0 vK tsK) (Edge _ ls0 vL tsL) = go ks0 ls0 where
	-- Identical labels: compare the values and the child maps.
	go [] [] = subMaybe le vK vL && isSubmapT (isSubEdge le) tsK tsL
	-- The second label is a proper prefix of the first: continue in the second
	-- edge's child for the next element, if there is one.
	go (k:ks) [] = case lookupT k tsL of
		Just e'	-> isSubEdge le (Edge szK ks vK tsK) e'
		Nothing	-> False
	go (k:ks) (l:ls)
		| k `eqT` l	= go ks ls
	-- Diverging labels, or a first label that ends inside the second one.
	go _ _ = False

-- | Split a list of (key, value) pairs by the first key element.
--
-- The first component of the result combines the values of all entries whose
-- key is empty ('Nothing' if there are none).  The second component groups the
-- remaining entries by their first key element, keeping the tail of each key;
-- only /adjacent/ entries with equal first elements end up in the same group.
-- The list is processed from the right.
filterer :: (k -> k -> Bool) -> (a -> a -> a) -> [([k], a)] -> (Maybe a, [(k, [([k], a)])])
filterer sameKey combine = foldr step (Nothing, []) where
	-- The accumulator is matched lazily so that the fold stays as lazy as a
	-- plain recursive definition.
	step ([], a) ~(here, groups) = (Just (maybe a (`combine` a) here), groups)
	step (k:ks, a) ~(here, groups) = (here, push k (ks, a) groups)
	-- Add an entry to the leading group when its key element matches that
	-- group's key; otherwise open a new group in front.
	push k entry groups = case groups of
		(k', entries) : rest
			| sameKey k k'	-> (k', entry : entries) : rest
		_	-> (k, [entry]) : groups

-- | Build an edge from an arbitrary list of (key, value) pairs; 'Nothing' for
-- the empty list.  @f@ combines values stored under the same key (it receives
-- the full key) and @s@ measures values for the cached sizes.
--
-- The list is split with 'filterer' into the values whose key is empty and
-- groups of entries sharing a first key element.
fromListE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => Sized a -> ([f k] -> a ix -> a ix -> a ix) -> [([f k], a ix)] -> MEdge' f k a ix
fromListE _ _ [] = Nothing
fromListE s f xs = case filterer eqT (f []) xs of
	-- No value here and a single group: the first element belongs on the label
	-- of the edge built from the rest.
	(Nothing, [(k, xs)]) -> cons k <$> fromListE s (f . (k:)) xs
	-- Otherwise this is a node with an empty label.  The groups are first
	-- collected into a map keyed by first element (merging groups with equal
	-- keys), and each merged group is then turned into a child edge.  The
	-- 'fromJust' relies on every group being non-empty.
	(v, xss) -> Just (edge s [] v (mapWithKeyT edgeSize (\ k (K0 xs) -> fromJust (fromListE s (f . (k:)) xs))
				(fromListT (const 1) (\ _ (K0 xs) (K0 ys) -> K0 (ys ++ xs)) [(k, K0 xs) | (k, xs) <- xss])))

-- | Build an edge from a list of (key, value) pairs; 'Nothing' for the empty
-- list.  @f@ combines values stored under the same key (it receives the full
-- key) and @s@ measures values for the cached sizes.
-- NOTE(review): the name and the use of 'fromDistAscListT' suggest the input
-- must be sorted by key in ascending order -- this is not checked here.
--
-- The list is split with 'filterer' into the values whose key is empty and
-- groups of adjacent entries sharing a first key element.
fromAscListE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) => 
	Sized a -> ([f k] -> a ix -> a ix -> a ix) -> [([f k], a ix)] -> MEdge' f k a ix
fromAscListE _ _ [] = Nothing
fromAscListE s f xs = case filterer eqT (f []) xs of
	-- No value here and a single group: the first element belongs on the label
	-- of the edge built from the rest.
	(Nothing, [(k, xs)]) -> cons k <$> fromAscListE s (f . (k:)) xs
	-- Otherwise this is a node with an empty label and one child per group.
	-- The 'fromJust' relies on every group being non-empty.
	(v, xss) -> Just (edge s [] v (fromDistAscListT edgeSize [(k, fromJust (fromAscListE s (f . (k:)) xs)) | (k, xs) <- xss]))

-- | Look a key up inside an edge and also report its index.  @i@ is the index
-- offset accumulated so far; it is forced on every call.  When the search
-- descends into a child, the offset grows by the size of the value stored at
-- this edge (measured with @s@) plus the offset reported by 'lookupIxT' for
-- that child.
lookupIxE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) =>
	Sized a -> Int -> [f k] -> Edge' f k a ix -> Maybe (Int, a ix)
lookupIxE s i ks0 (Edge _ ls0 v ts) = i `seq` go ks0 ls0 where
	-- Key and label end together: the value stored here sits at offset i.
	go [] [] = fmap ((,) i) v
	-- Label consumed, key not: continue in the child for the next element.
	go (k:ks) [] = do
		(iT, e') <- lookupIxT edgeSize k ts
		lookupIxE s (i + maybe 0 s v + iT) ks e'
	go (k:ks) (l:ls)
		| k `eqT` l	= go ks ls
	-- Diverging elements, or a key that ends in the middle of the label.
	go _ _ = Nothing

-- | Find the association covering index @i@ inside an edge.  The result holds
-- the index at which that association starts (relative to this edge), its
-- full key, and its value.  @s@ measures values.
-- NOTE(review): 'assocAtT' is assumed to return the start offset of the
-- selected child together with its key element and edge -- confirm against
-- the class definition, which is not in this file.
assocAtE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) =>
	Sized a -> Int -> Edge' f k a ix -> Maybe (Int, [f k], a ix)
-- No value at this edge: the index must fall inside one of the children.
assocAtE s i (Edge _ ks Nothing ts) = case assocAtT edgeSize i ts of
	(iT, l, e') -> do	(i', ls, v) <- assocAtE s (i - iT) e'
				return (iT + i', ks ++ l:ls, v)
assocAtE s i (Edge _ ks (Just v) ts)
	-- The value stored here occupies the first @sv@ indices of the edge.
	| i < sv	= return (0, ks, v)
	-- Otherwise skip the value and search the children; @sv@ is added back to
	-- the reported start index.
	| (iT, l, e') <- assocAtT edgeSize (i - sv) ts
		= do	(i', ls, v') <- assocAtE s ((i - sv) - iT) e'
			return (i' + iT + sv, ks ++ l:ls, v')
	where sv = s v

-- | Update (or delete, when the callback returns 'Nothing') the association
-- covering index @i@ inside an edge.  The callback receives the start index
-- of the association (relative to this edge), its full key and its value.
-- @s@ measures values; results are rebuilt with 'edge' and normalised with
-- 'compact'.
updateAtE :: (TrieKeyT f (TrieMapT f), TrieKey k (TrieMap k)) =>
	Sized a -> (Int -> [f k] -> a ix -> Maybe (a ix)) -> Int -> Edge' f k a ix -> MEdge' f k a ix
-- No value at this edge: the index must fall inside one of the children.  The
-- child's start offset iT is subtracted on the way down and added back to the
-- index handed to the callback.  (The bound @sz@ is unused.)
updateAtE s f i (Edge sz ks Nothing ts) = compact (edge s ks Nothing (updateAtT edgeSize g i ts)) where
	g iT l = updateAtE s (\ i' ls -> f (iT + i') (ks ++ l:ls)) (i - iT)
updateAtE s f i (Edge sz ks (Just v) ts)
	-- The value stored here occupies the first @sv@ indices of the edge.
	| i < sv	= compact (edge s ks (f 0 ks v) ts)
	-- Otherwise skip the value and update inside the children.
	| otherwise	= compact (edge s ks (Just v) (updateAtT edgeSize g (i - sv) ts))
	where	sv = s v
		g iT l = updateAtE s (\ i' ls -> f (sv + iT + i') (ks ++ l:ls)) (i - sv - iT)