{-# LANGUAGE IncoherentInstances, PatternGuards, MultiParamTypeClasses, UndecidableInstances #-}

module TrieMap.RadixTrie where

import Control.Applicative

import Data.Maybe
import Data.Monoid
import Data.Foldable
import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as Seq
import Data.Traversable

import TrieMap.Algebraic
import TrieMap.Applicative
import TrieMap.MapTypes
import TrieMap.TrieAlgebraic

import Prelude hiding (foldr)

-- | The size of an edge is the count cached in its first field (see 'edge',
-- which fills it in from the sizes of the value and the children).
instance Sized (Edge k m a) where
	getSize e = case e of
		Edge n _ _ _	-> n

-- | Radix-trie operations for list keys.  A 'RadixTrie' wraps an optional
-- root 'Edge' ('Radix Nothing' is the empty trie); almost every method
-- unwraps the root and delegates to the matching @*Edge@ function below.
instance TrieKeyT [] RadixTrie where
	-- Lexicographic order: compare element-wise with 'compareKey'; a proper
	-- prefix sorts before any of its extensions.
	compareKeyT (a:as) (b:bs) = compareKey a b `mappend` compareKeyT as bs
	compareKeyT [] (_:_) = LT
	compareKeyT (_:_) [] = GT
	compareKeyT [] [] = EQ
	-- Construction, inspection and normalization.
	emptyT = Radix Nothing
	nullT (Radix m) = isNothing m
	sizeT (Radix m) = getSize m
	getSingleT (Radix m) = m >>= getSingleEdge
	guardNullT (Radix m) = m >>= guardNullEdge >>= return . Radix . Just
	-- Alteration of an empty trie can only create a single-entry edge.
	alterLookupT f ks (Radix Nothing) = (Radix . single ks) <$> f Nothing
	alterLookupT f ks (Radix (Just e)) = Radix <$> alterLookupEdge f ks e
	lookupT ks (Radix m) = m >>= lookupEdge ks
	-- Traversals over the (optional) root edge.
	foldWithKeyT f z (Radix m) = foldr (foldEdge f) z m
	mapAppT f (Radix m) = Radix <$> traverse (mapAppEdge f) m
	mapMaybeT f (Radix m) = Radix (m >>= mapMaybeEdge f)
	mapEitherT f (Radix m) = radBoth (maybe (Nothing, Nothing) (mapEitherEdge f) m)
		where	 radBoth (e1, e2) = (Radix e1, Radix e2)
	-- Bulk construction.  The combining function given here keeps its first
	-- value argument; 'groupAscHeads'' only consults it for repeated keys.
	fromDistAscListT = fromAscListT (\ _ x _ -> x)
	fromAscListT _ [] = Radix Nothing
	fromAscListT f (x:xs) = Radix (Just (groupAscHeads' f x xs))
	fromListT f xs = Radix (groupHeads f xs)
	-- Ordered queries.
	splitLookupT _ _ (Radix Nothing) = (emptyT, Nothing, emptyT)
	splitLookupT f k (Radix (Just e)) = case splitLookupEdge f k e of
		(eL, ans, eR)	-> (Radix eL, ans, Radix eR)
	isSubmapT (<=) (Radix m1) (Radix m2) = isSubmapAlg (isSubEdge (<=)) m1 m2
	-- Extremal entries; the remainder edge is re-wrapped in 'Radix'.
	getMinT (Radix m) = fmap (Radix <$>) (m >>= getMinEdge)
	getMaxT (Radix m) = fmap (Radix <$>) (m >>= getMaxEdge)
	updateMinT _ (Radix Nothing) = (False, Radix Nothing)
	updateMinT f (Radix (Just e)) = Radix <$> updateMinEdge f e
	updateMaxT _ (Radix Nothing) = (False, Radix Nothing)
	updateMaxT f (Radix (Just e)) = Radix <$> updateMaxEdge f e
	-- Set-like combinations of two tries, via the Maybe-level helpers.
	unionT f (Radix m1) (Radix m2) = Radix (unionMaybe (unionEdge f) m1 m2)
	intersectT f (Radix m1) (Radix m2) = Radix (intersectMaybe (intersectEdge f) m1 m2)
	differenceT f (Radix m1) (Radix m2) = Radix (differenceMaybe (differenceEdge f) m1 m2)

-- | List keys are tries of their elements: every 'TrieKey' method forwards
-- to the corresponding 'TrieKeyT' method of the @[] / RadixTrie@ instance.
instance TrieKey k m => TrieKey [k] (RadixTrie k m) where
	-- Construction and inspection
	emptyAlg = emptyT
	nullAlg = nullT
	sizeAlg = sizeT
	getSingleAlg = getSingleT
	guardNullAlg = guardNullT
	-- Ordering and point queries
	compareKey = compareKeyT
	lookupAlg = lookupT
	alterLookupAlg = alterLookupT
	splitLookupAlg = splitLookupT
	-- Traversals
	foldWithKeyAlg = foldWithKeyT
	mapAppAlg = mapAppT
	mapMaybeAlg = mapMaybeT
	mapEitherAlg = mapEitherT
	-- Set-like operations
	unionMaybeAlg = unionT
	intersectAlg = intersectT
	differenceAlg = differenceT
	isSubmapAlg = isSubmapT
	-- Extremal entries
	getMinAlg = getMinT
	getMaxAlg = getMaxT
	updateMinAlg = updateMinT
	updateMaxAlg = updateMaxT

-- | A childless edge labelled @ks@ holding the given value, or 'Nothing'
-- when there is no value to store.
single :: (Sized a, TrieKey k m) => [k] -> Maybe a -> MEdge k m a
single ks mv = case mv of
	Nothing	-> Nothing
	Just v	-> Just (Edge (getSize v) ks (Just v) emptyAlg)

-- | Smart constructor for 'Edge': fills in the cached size as the size of
-- the edge's own (optional) value plus the size of its children.
edge :: (Sized a, TrieKey k m) => [k] -> Maybe a -> m (Edge k m a) -> Edge k m a
edge ks v ts = Edge total ks v ts where
	total = getSize v + getSize ts

-- | If the edge holds exactly one entry, return its full key and value.
-- A valued node qualifies only when it has no children; a valueless node
-- qualifies when it has one child that itself holds a single entry.
getSingleEdge :: TrieKey k m => Edge k m a -> Maybe ([k], a)
getSingleEdge (Edge _ ks mv ts) = case mv of
	Just v
		| nullAlg ts	-> Just (ks, v)
		| otherwise	-> Nothing
	Nothing	-> getSingleAlg ts >>= \ (l, child) ->
		fmap (\ (rest, v) -> (ks ++ l : rest, v)) (getSingleEdge child)

-- | Normalize an edge.  A valueless edge with no children becomes 'Nothing'.
-- A valueless edge with exactly one child is merged with that child by
-- concatenating the labels.  Any other edge is returned unchanged.
guardNullEdge :: TrieKey k m => Edge k m a -> MEdge k m a
guardNullEdge e@(Edge s ks v ts) = case v of
	Just _	-> Just e
	Nothing
		| nullAlg ts	-> Nothing
		| Just (l, Edge _ ls v' ts') <- getSingleAlg ts	-> Just (Edge s (ks ++ l:ls) v' ts')
		| otherwise	-> Just e

-- | Alter, and observe, the entry at key @ks0@ (relative to the start of the
-- edge's label) inside edge @e@.
--
-- @f@ receives the current value ('Nothing' if absent) and returns a result
-- to report plus the new value ('Nothing' to delete).  The output pairs that
-- report with the rebuilt edge, which is 'Nothing' if the edge became empty.
--
-- @procEdge i ks ls@ walks the remaining key @ks@ and remaining label @ls@
-- in lockstep; @i@ counts matched elements, so @take i ls0@ is their common
-- prefix.
alterLookupEdge :: (Eq k, TrieKey k m, Sized a) => (Maybe a -> (b, Maybe a)) -> [k] -> Edge k m a -> (b, MEdge k m a)
alterLookupEdge f ks0 e@(Edge s ls0 v0 ts) = procEdge 0 ks0 ls0 where
	-- Never matches; only forces the counter so it does not build thunks.
	procEdge i _ _ | i `seq` False = undefined
	-- Key and label diverge after @i@ elements.  On insertion, split the edge
	-- at the common prefix into a valueless node with two children: the new
	-- entry and the old edge's remainder.
	procEdge i (k:ks) (l:ls)
		| k == l	= procEdge (i+1) ks ls
		| otherwise	= breakEdge <$> f Nothing where
			breakEdge Nothing	= Just e
			breakEdge (Just v)	= let sV = getSize v in
				Just (Edge (sV + s) (take i ls0) Nothing 
					(fromListAlg (\ _ v _ -> v) [(k, Edge sV ks (Just v) emptyAlg), (l, Edge s ls v0 ts)]))
	-- The key ends inside the label.  On insertion, the new value sits at the
	-- key with the old edge's remainder as its only child.
	procEdge _ [] (l:ls) = splitEdge <$> f Nothing where
		splitEdge Nothing = Just e
		splitEdge (Just v) = let sV = getSize v in
			Just (Edge (sV + s) ks0 (Just v) (singletonAlg l (Edge s ls v0 ts)))
	-- The label is consumed: alter child @k@ (creating it from @f Nothing@
	-- if absent), then recompute the size and re-normalize this edge.
	procEdge _(k:ks) [] = (guardNullEdge . edge ls0 v0) <$> alterLookupAlg g k ts where
		g Nothing = fmap (\ v -> Edge (getSize v) ks (Just v) emptyAlg) <$> f Nothing 
		g (Just e) = alterLookupEdge f ks e
	-- Exact match: apply @f@ to this edge's own value.
	procEdge _ [] [] = fmap (\ v -> guardNullEdge $ edge ls0 v ts) (f v0)

-- | Look up the value at key @ks@, measured from the start of the edge's
-- label.
lookupEdge :: (Eq k, TrieKey k m) => [k] -> Edge k m a -> Maybe a
lookupEdge ks (Edge _ ls v ts) = go ks ls where
	-- Key and label end together: this node's (possibly absent) value.
	go [] [] = v
	-- Label consumed: continue in the child named by the next key element.
	go (k:rest) [] = lookupAlg k ts >>= lookupEdge rest
	-- Both continue: they must agree element by element.
	go (k:rest) (l:more)
		| k == l	= go rest more
	-- Divergence, or the key ends inside the label.
	go _ _ = Nothing

-- | Right fold over every entry of an edge, passing each entry's full key.
-- The edge's own value, if any, is combined outside the fold of its
-- children.
foldEdge :: TrieKey k m => ([k] -> a -> b -> b) -> Edge k m a -> b -> b
foldEdge f (Edge _ ks v ts) z = maybe rest (\ x -> f ks x rest) v where
	-- Fold over the children; each child's path is prefixed with @ks ++ [l]@.
	rest = foldWithKeyAlg (\ l -> foldEdge (f . (ks ++) . (l:))) z ts

-- | Applicative traversal of an edge.  Each value is transformed with its
-- full key; the node's own value's effect comes before its children's.
mapAppEdge :: (TrieKey k m, Applicative f, Sized b) => ([k] -> a -> f b) -> Edge k m a -> f (Edge k m b)
mapAppEdge f (Edge _ ks v ts) = fmap (edge ks) v' <*> ts' where
	v' = traverse (f ks) v
	ts' = mapAppAlg (\ l -> mapAppEdge (f . (ks ++) . (l:))) ts

-- | Map and filter an edge in one pass.  Values for which @f@ (given the full
-- key) returns 'Nothing' are removed; the result is re-normalized and is
-- 'Nothing' if no entries remain.
mapMaybeEdge :: (TrieKey k m, Sized b) => ([k] -> a -> Maybe b) -> Edge k m a -> MEdge k m b
mapMaybeEdge f (Edge _ ks v ts) = guardNullEdge (edge ks kept children) where
	kept = v >>= f ks
	children = mapMaybeAlg (\ l -> mapMaybeEdge (f . (ks ++) . (l:))) ts

-- | Split an edge with @f@, which sends each value (given its full key) to
-- the left result, the right result, both or neither.  Each side is
-- re-normalized with 'guardNullEdge' and is 'Nothing' when it ends up empty.
mapEitherEdge :: (TrieKey k m, Sized b, Sized c) => ([k] -> a -> (Maybe b, Maybe c)) -> Edge k m a -> 
	(MEdge k m b, MEdge k m c)
mapEitherEdge f (Edge _ ks v ts) = (guardNullEdge (edge ks vL tsL), guardNullEdge (edge ks vR tsR))
	where
		-- Split this node's own value, if any.
		(vL, vR) = maybe (Nothing, Nothing) (f ks) v
		-- Split every child, extending the key passed to @f@ with the child's path.
		(tsL, tsR) = mapEitherAlg (\ l -> mapEitherEdge (\ ls -> f (ks ++ l:ls))) ts

-- | Build an edge from a non-empty list of (key, value) pairs, given as a
-- head pair and a tail.  The pairs are expected in ascending key order.
--
-- Repeated keys are merged with @f@.  At each node, 'group0' folds the values
-- whose remaining key is empty as @f key later earlier@: the value appearing
-- later in the list is the first value argument.
--
-- NOTE(review): a pair whose remaining key is empty but which follows a pair
-- with a non-empty remaining key is not matched by 'group1' and fails at
-- runtime.  This relies on the ascending-order precondition.
groupAscHeads' :: (Eq k, TrieKey k m, Sized a) => ([k] -> a -> a -> a) -> ([k], a) -> [([k], a)] -> Edge k m a
groupAscHeads' f (ks, v) [] = Edge (getSize v) ks (Just v) emptyAlg
groupAscHeads' f x xs = group0 Nothing (x:xs) where
	-- group0 acc ps: fold leading pairs with an empty key into this node's
	-- value @acc@.  @group0 Nothing []@ cannot arise: the list starts
	-- non-empty, and group0 only recurses after producing a 'Just'.
	group0 v0 (([], v):xs) = group0 (Just (maybe v (f [] v) v0)) xs
	group0 (Just v0) [] = Edge (getSize v0) [] (Just v0) emptyAlg
	-- First pair with a non-empty key: group the rest by first element.
	group0 v0 ((k:ks, v):xs) = group1 Seq.empty k (ks, v) Seq.empty xs where
		-- group1 ts k vk vs ps: @ts@ holds finished (element, child edge)
		-- pairs in input order; @k@ is the current first element, @vk@ the
		-- first remainder under @k@ and @vs@ the later remainders.
		group1 ts k vk vs ((l:ls, v):xs)
			| k == l	= group1 ts k vk (vs |> (ls, v)) xs
			| otherwise	= group1 (ts |> (k, groupAscHeads' (f . (k:)) vk (toList vs))) l (ls, v) Seq.empty xs
		-- End of input.  With no node value and a single child, the child is
		-- merged into this edge by prepending @k@ to its label (path
		-- compression); otherwise build a valueless-or-valued node whose
		-- children are passed to 'fromDistAscListAlg'.
		group1 ts k v vs []
			| Nothing <- v0, Seq.null ts, Edge s xs vX tsX <- groupAscHeads' (f . (k:)) v (toList vs)
				= Edge s (k:xs) vX tsX
			| otherwise
				= edge [] v0 (fromDistAscListAlg (toList ts ++ [(k, groupAscHeads' (f . (k:)) v (toList vs))]))

-- | Build an edge from an unsorted list of (key, value) pairs; 'Nothing' for
-- an empty list.
--
-- Pairs are bucketed by their first key element: singleton 'Elem' lists are
-- merged by 'fromListAlg', concatenating on collision.  Each bucket is then
-- built recursively with that element stripped.  Values whose key is
-- exhausted at this level become the node's value; 'foldr' combines them
-- right-to-left as @f [] v acc@, where @acc@ already combines the values
-- after @v@.  The result is re-normalized with 'guardNullEdge'.
--
-- NOTE(review): 'groupAscHeads'' passes the later duplicate as the first
-- value argument of @f@, whereas here the earlier one comes first.  Inside
-- buckets the order also depends on which argument of the combining
-- function 'fromListAlg' treats as the new element (@x ++ y@ below).
-- Confirm that duplicate keys combine consistently with the ascending
-- builders.
groupHeads :: (Eq k, TrieKey k m, Sized a) => ([k] -> a -> a -> a) -> [([k], a)] -> MEdge k m a
groupHeads _ [] = Nothing
groupHeads f xs = guardNullEdge $ edge [] v0 (mapMaybeAlg (\ k (Elem xs) -> groupHeads (f . (k:)) xs) $
		fromListAlg (\ _ (Elem x) (Elem y) -> Elem (x ++ y)) [(k, Elem [(ks, v)]) | (k, ks, v) <- ts])
	where	(v0, ts) = let	proc ([], v) (v0, ts) = (Just (maybe v (f [] v) v0), ts)
				proc (k:ks, v) (v0, ts) = (v0, (k, ks, v):ts)
				in foldr proc (Nothing, []) xs

-- | Map every value of an edge with a function that also receives the
-- value's full key.  Sizes are recomputed through 'edge'.
mapEdge :: (Sized b, TrieKey k m) => ([k] -> a -> b) -> Edge k m a -> Edge k m b
mapEdge f (Edge _ ks v ts) = edge ks (f ks <$> v) children where
	children = mapWithKeyAlg (\ l -> mapEdge (f . (ks ++) . (l:))) ts

-- | Split an edge around key @ks@ into the entries below the key, an answer
-- from @f@ applied to the value stored exactly at @ks@ (if any), and the
-- entries above the key.  @f@ itself splits that value into a lower part,
-- the answer and an upper part.
splitLookupEdge :: (Sized a, TrieKey k m) => (a -> (Maybe a, Maybe b, Maybe a)) -> [k] -> Edge k m a -> 
	(MEdge k m a, Maybe b, MEdge k m a)
splitLookupEdge f ks e@(Edge s ls v ts) = walk ks ls where
	-- The whole edge lies above / below the key.
	allGreater = (Nothing, Nothing, Just e)
	allLess = (Just e, Nothing, Nothing)
	-- The key ends exactly at this node.
	walk [] [] = maybe allGreater splitHere v
	-- The key ends inside the label: every entry extends it, hence is larger.
	walk [] (_:_) = allGreater
	-- The label is consumed: split child @k@.  This node's own value is a
	-- proper prefix of the key, so it goes to the lower side.
	walk (k:rest) [] = case splitLookupAlg (splitLookupEdge f rest) k ts of
		(tsL, ans, tsR)	-> (guardNullEdge (edge ls v tsL), ans, guardNullEdge (edge ls Nothing tsR))
	-- Both continue: the first differing element decides the side.
	walk (k:rest) (l:more) = case compareKey k l of
		LT	-> allGreater
		GT	-> allLess
		EQ	-> walk rest more
	-- Split the value @x@ stored exactly at the key.
	splitHere x = case f x of
		(vL, ans, vR)	-> (single ls vL, ans, guardNullEdge (edge ls vR ts))

-- | @isSubEdge (<=) eK eL@ tests whether every key stored under @eK@ is also
-- stored under @eL@, with each pair of values related by @(<=)@.  Both edges
-- must be rooted at the same trie position.
--
-- The function is total: a diverging label, or a left key with no
-- counterpart child on the right, means @eK@ holds a key that @eL@ lacks,
-- so the answer is 'False' rather than a pattern-match failure.
isSubEdge :: (TrieKey k m, Sized a, Sized b) => (a -> b -> Bool) -> Edge k m a -> Edge k m b -> Bool
isSubEdge (<=) (Edge sK ks vK tsK) (Edge sL ls vL tsL) = procEdge ks ls where
	-- Labels agree so far: keep walking.
	procEdge (k:ks) (l:ls)
		| k == l	= procEdge ks ls
	-- The right label ended first: the rest of the left edge must fit inside
	-- the right's child @k@.
	procEdge (k:ks) []
		| Just e' <- lookupAlg k tsL
			= isSubEdge (<=) (Edge sK ks vK tsK) e'
	-- The left label ended first.  The right edge has no value here and
	-- continues only through child @l@, so the left edge may hold no value
	-- and at most that one child, which must fit in the right's remainder.
	procEdge [] (l:ls) = isNothing vK && case getSingleAlg tsK of
		Just (k, e')	-> k == l && isSubEdge (<=) e' (Edge sL ls vL tsL)
		Nothing	-> nullAlg tsK
	-- Same node: compare the values, then the children pairwise.
	procEdge [] [] = isSubmapAlg (<=) vK vL && isSubmapAlg (isSubEdge (<=)) tsK tsL
	-- Diverging labels, or no right child to descend into.
	procEdge _ _ = False

-- | Remove the entry with the least key from an edge.
--
-- Returns that entry's full key and value, together with the remaining edge
-- ('Nothing' if nothing remains).  An edge's own value precedes every key of
-- its children, so it is the minimum whenever present.  Otherwise the
-- minimum is taken from the least child.  The remainder is re-normalized
-- with 'guardNullEdge', so a valueless node left with a single child is
-- merged with it, as 'updateMinEdge' does.
getMinEdge :: (TrieKey k m, Sized a) => Edge k m a -> Maybe (([k], a), MEdge k m a)
getMinEdge (Edge s ks (Just v) ts) = Just ((ks, v), guardNullEdge (Edge (s - getSize v) ks Nothing ts))
getMinEdge (Edge _ ks Nothing ts) = do
	((l, e'), ts') <- getMinAlg ts
	((ls, v), e'') <- getMinEdge e'
	-- Replace the least child with what is left of it, or drop it (@ts'@)
	-- if it became empty.
	let ts'' = maybe ts' (\ eMin -> snd (updateMinAlg (\ _ _ -> (False, Just eMin)) ts)) e''
	return ((ks ++ l:ls, v), guardNullEdge (edge ks Nothing ts''))

-- | Remove the entry with the greatest key from an edge.
--
-- Returns that entry's full key and value, together with the remaining edge
-- ('Nothing' if nothing remains).  A childless edge's own value is its
-- maximum.  Otherwise the maximum is found in the greatest child (via
-- 'getMaxAlg').  In that case the edge's own value @v0@ is smaller than
-- every key below it, so it must be kept in the remainder.  The remainder
-- is re-normalized with 'guardNullEdge'.
getMaxEdge :: (TrieKey k m, Sized a) => Edge k m a -> Maybe (([k], a), MEdge k m a)
getMaxEdge (Edge _ ks v0 ts)
	| nullAlg ts = fmap (\ v -> ((ks, v), Nothing)) v0
	| otherwise	= do
		((l, e'), ts') <- getMaxAlg ts
		((ls, v), e'') <- getMaxEdge e'
		-- Replace the greatest child with what is left of it, or drop it
		-- (@ts'@) if it became empty.
		let ts'' = maybe ts' (\ eMax -> snd (updateMaxAlg (\ _ _ -> (False, Just eMax)) ts)) e''
		-- Keep this node's own value v0 rather than discarding it.
		return ((ks ++ l:ls, v), guardNullEdge (edge ks v0 ts''))

-- | Update the entry with the least key in an edge.  @f@ gets the entry's
-- full key and value and returns a 'Bool' plus the new value ('Nothing'
-- deletes the entry).  The rebuilt edge is re-normalized.
updateMinEdge :: (TrieKey k m, Sized a) => ([k] -> a -> (Bool, Maybe a)) -> Edge k m a -> (Bool, MEdge k m a)
updateMinEdge f (Edge _ ks mv ts) = maybe viaChildren viaValue mv where
	-- The edge's own value, when present, is its minimum.
	viaValue x = (\ v' -> guardNullEdge (edge ks v' ts)) <$> f ks x
	-- Otherwise the minimum lives in the least child.
	viaChildren = (guardNullEdge . edge ks Nothing) <$> updateMinAlg descend ts
	descend l = updateMinEdge (\ ls -> f (ks ++ l:ls))

-- | Update the entry with the greatest key in an edge.  @f@ gets the entry's
-- full key and value and returns a 'Bool' plus the new value ('Nothing'
-- deletes the entry); the rebuilt edge is re-normalized.
--
-- A childless edge's own value is its maximum.  Otherwise the maximum lies
-- in the greatest child, so the recursion goes through 'updateMaxAlg' and
-- 'updateMaxEdge', mirroring 'updateMinEdge'.  Using the Min variants here
-- would update the wrong entry.
updateMaxEdge :: (TrieKey k m, Sized a) => ([k] -> a -> (Bool, Maybe a)) -> Edge k m a -> (Bool, MEdge k m a)
updateMaxEdge f (Edge _ ks (Just v) ts)
	| nullAlg ts = fmap (\ v -> guardNullEdge (edge ks v ts)) (f ks v)
updateMaxEdge f (Edge _ ks v ts) = fmap (guardNullEdge . edge ks v) (updateMaxAlg g ts) where
	g l = updateMaxEdge (\ ls -> f (ks ++ l:ls))

-- | Union of two edges rooted at the same trie position.  Where both edges
-- store a value under the same key, the values are merged by @f@ (given the
-- full key) through 'unionMaybe'.  The result is re-normalized.
--
-- @procEdge i ks ls@ walks both labels in lockstep; @i@ counts matched
-- elements, so @take i ks0@ is the common prefix.
unionEdge :: (TrieKey k m, Sized a) => ([k] -> a -> a -> Maybe a) -> Edge k m a -> Edge k m a -> MEdge k m a
unionEdge f (Edge sK ks0 vK tsK) (Edge sL ls0 vL tsL) = procEdge 0 ks0 ls0 where
	-- Never matches; only forces the counter so it does not build thunks.
	procEdge i _ _ | i `seq` False = undefined
	-- Labels diverge: new valueless node at the common prefix with both
	-- remainders as children.  Its size is the sum of both edges' sizes.
	procEdge i (k:ks) (l:ls)
		| k == l	= procEdge (i+1) ks ls
		| otherwise	= Just (Edge (sK + sL) (take i ks0) Nothing 
					(insertAlg k (Edge sK ks vK tsK) $ singletonAlg l (Edge sL ls vL tsL)))
	-- The right label is a proper prefix of the left: keep the right's node
	-- and insert / merge the left remainder into its child @k@.
	procEdge _ (k:ks) [] = guardNullEdge $ edge ls0 vL $ alterAlg g k tsL where
		g Nothing = Just (Edge sK ks vK tsK)
		g (Just e) = unionEdge (\ ks' -> f (ls0 ++ k:ks')) (Edge sK ks vK tsK) e
	-- The left label is a proper prefix of the right: the mirror image,
	-- keeping the left edge as @f@'s first value argument.
	procEdge _ [] (l:ls) = guardNullEdge $ edge ks0 vK $ alterAlg g l tsK where
		g Nothing = Just (Edge sL ls vL tsL)
		g (Just e) = unionEdge (\ ls' -> f (ks0 ++ l:ls')) e (Edge sL ls vL tsL)
	-- Same node: merge the values and union the children pairwise.
	procEdge _ [] [] = guardNullEdge $ edge ks0 (unionMaybe	(f ks0) vK vL) $
		unionMaybeAlg (\ x -> unionEdge (\ xs -> f (ks0 ++ x:xs))) tsK tsL

-- | Intersection of two edges rooted at the same trie position.  Only keys
-- reachable in both edges are considered.  Their values are combined by @f@
-- (given the full key) through 'intersectMaybe'.  The result is 'Nothing'
-- if nothing survives.
intersectEdge :: (TrieKey k m, Sized c) => ([k] -> a -> b -> Maybe c) -> Edge k m a -> Edge k m b -> MEdge k m c
intersectEdge f (Edge sK ks0 vK tsK) (Edge sL ls0 vL tsL) = walk ks0 ls0 where
	-- Both labels end here: combine the values and intersect the children.
	walk [] [] = guardNullEdge $ edge ks0 (intersectMaybe (f ks0) vK vL)
		(intersectAlg (\ hd -> intersectEdge (\ tl -> f (ks0 ++ hd:tl))) tsK tsL)
	-- The right label is a proper prefix of the left: descend into the
	-- right's child @k@, then restore the consumed prefix on the label.
	walk (k:ks) [] = do
		child <- lookupAlg k tsL
		Edge sX xs vX tsX <- intersectEdge (\ tl -> f (ls0 ++ k:tl)) (Edge sK ks vK tsK) child
		Just (Edge sX (ls0 ++ k:xs) vX tsX)
	-- The left label is a proper prefix of the right: the mirror image.
	walk [] (l:ls) = do
		child <- lookupAlg l tsK
		Edge sX xs vX tsX <- intersectEdge (\ tl -> f (ks0 ++ l:tl)) child (Edge sL ls vL tsL)
		Just (Edge sX (ks0 ++ l:xs) vX tsX)
	-- Both labels continue: any disagreement means no common keys.
	walk (k:ks) (l:ls)
		| k /= l	= Nothing
		| otherwise	= walk ks ls

-- | Difference of two edges rooted at the same trie position.  Every entry
-- of the left edge is kept, except where the right edge also stores a value
-- under the same key.  There, @f@ (given the full key) decides via
-- 'differenceMaybe' whether a (possibly changed) left value survives.  The
-- result is 'Nothing' if nothing remains.
differenceEdge :: (TrieKey k m, Sized a) => ([k] -> a -> b -> Maybe a) -> Edge k m a -> Edge k m b -> MEdge k m a
differenceEdge f e@(Edge sK ks0 vK tsK) (Edge sL ls0 vL tsL) = procEdge ks0 ls0 where
	-- Labels agree so far: keep walking.
	procEdge (k:ks) (l:ls)
		| k == l	= procEdge ks ls
	-- The right label ended first: only the right's child @k@ can overlap the
	-- left edge; restore the consumed prefix on the result's label.
	procEdge (k:ks) []
		| Just e' <- lookupAlg k tsL
			= do	Edge sX xs vX tsX <- differenceEdge (\ ks' -> f (ls0 ++ k:ks')) (Edge sK ks vK tsK) e'
				return (Edge sX (ls0 ++ k:xs) vX tsX)
	-- The left label ended first: subtract the right's remainder from the
	-- left's child @l@; the left's own value has no counterpart and stays.
	procEdge [] (l:ls) = guardNullEdge $ edge ks0 vK (alterAlg (>>= g) l tsK) where
		g e = differenceEdge (\ ls' -> f (ks0 ++ l:ls')) e (Edge sL ls vL tsL)
	-- Same node: subtract the values and subtract child by child, keeping
	-- left entries that have no right counterpart.
	procEdge [] [] = guardNullEdge $ edge ks0 (differenceMaybe (f ks0) vK vL) $
		differenceAlg (\ x -> differenceEdge (\ xs -> f (ks0 ++ x:xs))) tsK tsL
	-- Diverging labels, or no matching right child: nothing to subtract.
	procEdge _ _ = Just e