{-# LANGUAGE MultiParamTypeClasses, UndecidableInstances, FlexibleContexts, StandaloneDeriving, PatternGuards #-}

module TrieMap.RadixTrie (RadixTrie) where

import Control.Applicative hiding (Alternative(..))
import Control.Monad
import Data.Foldable
import Data.Traversable
import Data.Monoid
import Data.Maybe
import Data.Ord
import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as Seq

import TrieMap.MapTypes
import TrieMap.TrieAlgebraic
import TrieMap.Applicative

import Prelude hiding (null, foldr, all)

-- | Structural equality on edges: equal labels, equal optional values, and
-- child tries whose association lists (via 'assocsAlg') are equal.
instance (Eq k, Eq v, TrieKey k m) => Eq (Edge k m v) where
	Edge lblA valA kidsA == Edge lblB valB kidsB
		| lblA /= lblB	= False
		| valA /= valB	= False
		| otherwise	= assocsAlg kidsA == assocsAlg kidsB

-- | Lexicographic ordering on edges: first by label, then by optional value,
-- then by the association lists of the child tries.
instance (Ord k, Ord v, TrieKey k m) => Ord (Edge k m v) where
	compare (Edge lblA valA kidsA) (Edge lblB valB kidsB) = mconcat
		[ compare lblA lblB
		, compare valA valB
		, comparing assocsAlg kidsA kidsB
		]

-- Standalone-derived instances for the 'RadixTrie' wrapper.  The wrapper
-- holds an optional root 'Edge', so these rely on the hand-written 'Edge'
-- instances in this module.
deriving instance (Eq k, Eq v, TrieKey k m) => Eq (RadixTrie k m v)
deriving instance (Ord k, Ord v, TrieKey k m) => Ord (RadixTrie k m v)
deriving instance (Show k, Show v, Functor m, Show (m String)) => Show (RadixTrie k m v)

-- | Debug rendering of an edge: its label, its optional value, and the child
-- trie with every child edge already rendered to a 'String'.
instance (Show k, Show v, Functor m, Show (m String)) => Show (Edge k m v) where
	show (Edge lbl val kids) = unwords ["Edge", show lbl, show val, show (show <$> kids)]

-- | Radix-trie instance for list keys.  A 'RadixTrie' is either empty
-- ('Radix' 'Nothing') or holds a single root 'Edge'.  Each method unwraps the
-- optional root and delegates to the matching @*Edge@ helper in this module.
instance (Ord k, TrieKey k m) => TrieKey [k] (RadixTrie k m) where
	emptyAlg = Radix Nothing
	nullAlg = isNothing . unRad
	getSingleAlg (Radix e) = e >>= getSingleEdge
	-- Re-normalise the root edge; 'Nothing' when nothing remains in it.
	guardNullAlg (Radix e) = fmap (Radix . Just) (e >>= guardNullEdge)
	lookupAlg ks = unRad >=> lookupEdge (==) ks
-- 	sizeAlg (Radix e) = maybe 0 sizeEdge e
	-- An empty trie can only grow into a single childless edge; otherwise the
	-- existing root edge is altered.
	alterLookupAlg f k = fmap Radix .
		maybe (fmap (maybeSingleEdge k) $ f Nothing)
			(alterLookupEdge (==) f k) . unRad
	foldWithKeyAlg f z = foldr (flip (foldWithKeyEdge f)) z . unRad
	mapMaybeAlg f (Radix e) = Radix (e >>= mapMaybeEdge f)
	-- The partitioning function is never consulted for an empty trie.
	mapEitherAlg _ (Radix Nothing) = (emptyAlg, emptyAlg)
	mapEitherAlg f (Radix (Just e)) = (Radix e1, Radix e2)
		where	(e1, e2) = mapEitherEdge f e
-- 	mapMaybeAlg f (Radix e) = (Radix . join) <$> traverse (mapAppMaybeEdge f) e
	mapAppAlg f = fmap Radix . traverse (mapAppEdge f) . unRad
	unionMaybeAlg f (Radix e1) (Radix e2) = Radix (unionMaybe (unionMaybeEdge f) e1 e2)
	intersectAlg f (Radix e1) (Radix e2) = Radix (intersectMaybe (intersectEdge f) e1 e2)
	differenceAlg f (Radix e1) (Radix e2) = Radix (differenceMaybe (differenceEdge f) e1 e2)

	getMinAlg (Radix e) = fmap (fmap Radix . getMinEdge) e
	getMaxAlg (Radix e) = fmap (fmap Radix . getMaxEdge) e
-- 	updateMinAlg f (Radix e) = Radix $ e >>= updateMinEdge f
-- 	updateMaxAlg f (Radix e) = Radix $ e >>= updateMaxEdge f

	fromListAlg f xs = Radix (edgeFromList f xs)
	fromAscListAlg f xs = Radix (edgeFromAscList f xs)
	-- Distinct keys: the combiner simply keeps its first value argument.
	fromDistAscListAlg = fromAscListAlg (\ _ v _ -> v)

	isSubmapAlg (<=) (Radix e1) (Radix e2) = isSubmapAlg (isSubmapEdge (<=)) e1 e2

	valid (Radix e) = maybe True validEdge e

	splitLookupAlg _ _ (Radix Nothing) = (Radix Nothing, Nothing, Radix Nothing)
	splitLookupAlg f k (Radix (Just e)) = case splitEdge f k e of
		(eL, ans, eR)	-> (Radix eL, ans, Radix eR)

-- sizeEdge :: Edge k m v -> Int
-- sizeEdge (Edge n _ _ _) = n

-- edge :: TrieKey k m => [k] -> Maybe v -> m (Edge k m v) -> Edge k m v
-- edge ks v ts = Edge (maybe id (const (+1)) v $ foldl' (\ n e -> n + sizeEdge e) 0 ts) ks v ts

-- | Look up a full key under an edge, comparing key elements with @eq@.
-- The key must match the edge label element by element.  Any elements left
-- over after the label continue the search in the child trie.
lookupEdge :: TrieKey k m => (k -> k -> Bool) -> [k] -> Edge k m v -> Maybe v
lookupEdge eq query (Edge label val kids) = walk query label
	where
		walk (q:qs) (l:ls)
			| q `eq` l	= walk qs ls
			| otherwise	= Nothing
		walk (q:qs) []	= lookupAlg q kids >>= lookupEdge eq qs
		walk [] []	= val
		walk [] (_:_)	= Nothing

-- | Build an edge from an unordered association list.  Duplicate keys are
-- combined with @f@, which also receives the full key.  The root is
-- normalised with 'guardNullEdge'.
edgeFromList :: (Eq k, TrieKey k m) => ([k] -> v -> v -> v) -> [([k], v)] -> MEdge k m v
edgeFromList f assocs = guardNullEdge (Edge [] rootVal children)
	where
		-- Peel off the values stored at the empty key.  Bucket every other
		-- entry by its first element, keeping the remaining suffix.  @acc@
		-- holds the combination of empty-key values found later in the list.
		(rootVal, buckets) = foldr step (Nothing, []) assocs
		step ([], val) (acc, rest) = (Just (maybe val (\ w -> f [] w val) acc), rest)
		step (c:cs, val) (acc, rest) = (acc, (c, [(cs, val)]) : rest)
		-- Buckets sharing a head element have their suffix lists concatenated;
		-- each merged bucket is then built recursively into a child edge.
		children = mapMaybeAlg (\ c -> edgeFromList (f . (c:))) (fromListAlg (\ _ xs ys -> ys ++ xs) buckets)

-- | Build an edge from a key-sorted association list, combining equal keys
-- with @f@ (grouping is done by 'groupHead').  'Nothing' for an empty list.
edgeFromAscList :: (Eq k, TrieKey k m) => ([k] -> v -> v -> v) -> [([k], v)] -> MEdge k m v
edgeFromAscList _ [] = Nothing
edgeFromAscList f assocs = Just (collapse (groupHead f assocs))
	where
		-- No value at the empty key and exactly one branch: fold the branch's
		-- head element into the label instead of adding a trie level.
		collapse (Nothing, [(c, ~(Edge rest val kids))]) = Edge (c:rest) val kids
		collapse (rootVal, branches) = Edge [] rootVal (fromDistAscListAlg branches)

-- | Group a key-sorted association list for 'edgeFromAscList'.
--
-- Returns two things:
--
-- * the value(s) stored at the empty key, combined with @f []@;
-- * for each run of consecutive entries sharing the same first key element,
--   that element paired with an edge built (via 'edgeFromAscList') from the
--   remaining suffixes.
--
-- Only adjacent entries are grouped, so the input must be sorted.  An
-- empty-key entry appearing after a non-empty one reaches 'error'.
--
-- NOTE(review): several empty-key values are combined as
-- @f [] earlier later@ (earlier list position first), whereas
-- 'edgeFromList' combines them as @f [] later earlier@.  Confirm which
-- argument order callers of 'fromAscListAlg' expect.
groupHead :: (Eq k, TrieKey k m) => ([k] -> v -> v -> v) -> [([k], v)] -> (Maybe v, [(k, Edge k m v)])
groupHead f (([], v):xs) = case groupHead f xs of
	(v', ans)	-> (Just $ maybe v (f [] v) v', ans)
-- Non-empty first key: collect suffixes in a Seq while the head element stays
-- equal to @k0@, then emit one child edge per run.
groupHead f ((k:ks, v):xs) = (Nothing, groupHead' k (Seq.singleton (ks, v)) xs) where
	groupHead' k0 xs ((k:ks, v):ys)
		| k == k0	= groupHead' k0 (xs |> (ks, v)) ys
		| otherwise	= (k0, fromJust $ edgeFromAscList (f . (k0:)) (toList xs)):groupHead' k (Seq.singleton (ks, v)) ys
	-- fromJust is safe: the accumulated Seq starts as a singleton and only
	-- grows, and edgeFromAscList returns Nothing only for an empty list.
	groupHead' k0 xs [] = [(k0, fromJust $ edgeFromAscList (f . (k0:)) (toList xs))]
	groupHead' _ _ _ = error "Violation of ascending invariant!"
groupHead _ [] = (Nothing, [])
 {-guardNullEdge $ Edge [] v0 $ mapMaybeAlg (\ k -> edgeFromAscList (f . (k:))) $ fromAscListAlg (const (flip (++))) ys
	where	part ([], v) (v0, ys) = (Just $ maybe v (flip (f []) v) v0, ys)
		part (k:ks, v) (v0, ys) = (v0, (k, [(ks, v)]):ys)
		(v0, ys) = foldr part (Nothing, []) xs-}

-- | Wrap an optional value as a childless edge labelled @ks@.
maybeSingleEdge :: TrieKey k m => [k] -> Maybe v -> MEdge k m v
maybeSingleEdge ks mv = do
	v <- mv
	return (Edge ks (Just v) emptyAlg)

-- | If the edge stores a single value, return its full key and the value.
-- An edge with its own value must have no children; a valueless edge must
-- have exactly one child (per 'getSingleAlg'), which is checked recursively.
getSingleEdge :: (TrieKey k m) => Edge k m v -> Maybe ([k], v)
getSingleEdge (Edge ks mv ts) = case mv of
	Just v
		| nullAlg ts	-> Just (ks, v)
		| otherwise	-> Nothing
	Nothing -> do
		(x, child) <- getSingleAlg ts
		(rest, v) <- getSingleEdge child
		return (ks ++ x : rest, v)

{-# INLINE guardNullEdge #-}
-- | Normalise an edge.  A valueless edge with no children becomes 'Nothing'.
-- A valueless edge with a single child is merged with that child, so the
-- label keeps growing instead of adding a trie level.  Any other edge is
-- kept as-is.
guardNullEdge :: TrieKey k m => Edge k m v -> MEdge k m v
guardNullEdge edge@(Edge label Nothing kids)
	| nullAlg kids	= Nothing
	| otherwise	= case getSingleAlg kids of
		Just (x, Edge rest val grandkids) -> Just (Edge (label ++ x : rest) val grandkids)
		_ -> Just edge
guardNullEdge edge = Just edge

-- | Alter the value at key @ks0@ under an edge, also returning an extra
-- result.  @f@ receives the current value (or 'Nothing') and returns the
-- extra result paired with the new value; 'Nothing' deletes.  Key elements
-- are compared with the supplied equality.  The edge is split, extended or
-- re-normalised with 'guardNullEdge' as needed.
alterLookupEdge :: (TrieKey k m) => (k -> k -> Bool) ->
	(Maybe v -> (a, Maybe v)) -> [k] -> Edge k m v -> (a, MEdge k m v)
alterLookupEdge (==) f ks0 e@(Edge ls0 v ts) = procEdge 0 ks0 ls0 where
	-- Forces the matched-prefix length @i@ at every step; the guard never fires.
	procEdge i _ _ | i `seq` False = undefined
	-- Key and label diverge after @i@ elements, so the key is absent.  If @f@
	-- inserts, branch at the common prefix (take i ks0) with two children:
	-- the new single-value edge and the old edge minus that prefix.
	procEdge i (k:ks) (l:ls)
		| k == l	= procEdge (i+1) ks ls
		| otherwise	= fmap (Just . g) $ f Nothing
		where	g Nothing = e
			g (Just v') = Edge (take i ks0) Nothing $
						fromListAlg' [(k, Edge ks (Just v') emptyAlg), (l, Edge ls v ts)]
	-- Label exhausted: continue in the child at @k@, creating it if @f@ inserts.
	procEdge i (k:ks) [] = proc (alterLookupAlg g k ts) where
		g Nothing = maybeSingleEdge ks <$> f Nothing
		g (Just e') = alterLookupEdge (==) f ks e'
		proc = fmap (guardNullEdge . Edge ls0 v)
	-- Key exhausted inside the label, so the key is absent.  If @f@ inserts,
	-- the new value gets an edge labelled @ks0@ with the old edge as its child.
	procEdge i [] (l:ls) = fmap (Just . g) $ f Nothing
		where	g Nothing = e
			g (Just v') = Edge ks0 (Just v') $ insertAlg l (Edge ls v ts) emptyAlg
	-- Exact match: apply @f@ to the stored value.
	procEdge i [] [] = (ans, guardNullEdge (Edge ks0 fv ts))
		where	(ans, fv) = f v

-- | Right fold over every (full key, value) pair under the edge.  The edge's
-- own value comes first, followed by the child trie's entries in
-- 'foldWithKeyAlg' order.
foldWithKeyEdge :: TrieKey k m => ([k] -> v -> x -> x) -> x -> Edge k m v -> x
foldWithKeyEdge f z (Edge ks v ts) = maybe fromKids (\ val -> f ks val fromKids) v
	where
		fromKids = foldWithKeyAlg step z ts
		step c child acc = foldWithKeyEdge (\ rest -> f (ks ++ c : rest)) acc child

-- | Filter-map every value under the edge; @f@ receives the full key.  The
-- result is re-normalised with 'guardNullEdge'.
mapMaybeEdge :: (TrieKey k m) => ([k] -> v -> Maybe w) -> Edge k m v -> MEdge k m w
mapMaybeEdge f (Edge ks v ts) = guardNullEdge (Edge ks newVal newKids)
	where
		newVal = v >>= f ks
		newKids = mapMaybeAlg (\ c -> mapMaybeEdge (\ rest -> f (ks ++ c : rest))) ts

-- | Partition the edge's values with @f@, which receives the full key.
-- 'Left' results go to the first edge and 'Right' results to the second.
-- Both sides are re-normalised with 'guardNullEdge'.
mapEitherEdge :: TrieKey k m => ([k] -> a -> Either b c) -> Edge k m a -> (MEdge k m b, MEdge k m c)
mapEitherEdge f (Edge ks v ts) = (guardNullEdge (Edge ks onLeft kidsL), guardNullEdge (Edge ks onRight kidsR))
	where
		(onLeft, onRight) = case v of
			Nothing -> (Nothing, Nothing)
			Just x -> either (\ l -> (Just l, Nothing)) (\ r -> (Nothing, Just r)) (f ks x)
		bothSides = mapWithKeyAlg (\ h -> mapEitherEdge (\ rest -> f (ks ++ h : rest))) ts
		kidsL = mapMaybeAlg (\ _ -> fst) bothSides
		kidsR = mapMaybeAlg (\ _ -> snd) bothSides

-- | Traverse every value under the edge with an Applicative action @f@,
-- which receives the full key.  The labels are kept, and the result is not
-- re-normalised.
mapAppEdge :: (Applicative f, TrieKey k m) => ([k] -> v -> f w) -> Edge k m v -> f (Edge k m w)
mapAppEdge f (Edge ks v ts) = liftA2 (Edge ks) newVal newKids
	where
		newVal = traverse (f ks) v
		newKids = mapAppAlg (\ c -> mapAppEdge (\ rest -> f (ks ++ c : rest))) ts

-- | Union of two edges occupying the same trie position.  Where both edges
-- store a value at the same key, @f@ receives the full key and both values
-- and may return 'Nothing' to drop the key.
unionMaybeEdge :: (Eq k, TrieKey k m) => ([k] -> v -> v -> Maybe v) -> Edge k m v -> Edge k m v -> MEdge k m v
unionMaybeEdge f (Edge ks0 vK tsK) (Edge ls0 vL tsL) = procEdge 0 ks0 ls0 where
	-- Forces the matched-prefix length @i@ at every step; the guard never fires.
	procEdge i _ _ | i `seq` False = undefined
	-- Labels diverge after @i@ elements: create a valueless edge labelled with
	-- the common prefix, holding both edges (suffixes only) as its children.
	procEdge i (k:ks) (l:ls)
		| k == l	= procEdge (i+1) ks ls
		| otherwise	= Just $ Edge (take i ks0) Nothing $ fromListAlg' [(k, Edge ks vK tsK), (l, Edge ls vL tsL)]
	-- First label is a proper prefix of the second: merge the second edge
	-- (minus the shared prefix) into the first edge's child at @l@.
	procEdge _ [] (l:ls) = guardNullEdge $ Edge ks0 vK $ alterAlg g l tsK
	  where	g Nothing = Just (Edge ls vL tsL)
		g (Just e') = unionMaybeEdge (\ ls' -> f (ks0 ++ l:ls')) e' (Edge ls vL tsL)
	-- Symmetric case: the second label is a proper prefix of the first.
	procEdge _ (k:ks) [] = guardNullEdge $ Edge ls0 vL $ alterAlg g k tsL 
	  where	g Nothing = Just $ Edge ks vK tsK
		g (Just e') = unionMaybeEdge (\ ks' -> f (ls0 ++ k:ks')) (Edge ks vK tsK) e'
	-- Identical labels: combine the values with @f@ and union the child tries.
	procEdge _ [] [] = guardNullEdge $ Edge ks0 (unionMaybe (f ks0) vK vL) $
		unionMaybeAlg (\ x -> unionMaybeEdge (\ xs -> f (ks0 ++ x:xs))) tsK tsL

-- | Intersection of two edges occupying the same trie position.  Only keys
-- stored under both edges can survive; @f@ receives the full key and both
-- values and may return 'Nothing' to drop the key.
intersectEdge :: (Eq k, TrieKey k m) => ([k] -> a -> b -> Maybe c) -> Edge k m a -> Edge k m b -> MEdge k m c
intersectEdge f (Edge ks0 vK tsK) (Edge ls0 vL tsL) = procEdge ks0 ls0 where
	-- Labels diverge: no key can be common to both edges.
	procEdge (k:ks) (l:ls)
		| k == l	= procEdge ks ls
		| otherwise	= Nothing
	-- Second label is a proper prefix of the first.  Intersect the first edge's
	-- remainder with the second edge's child at @k@, then restore the full label.
	procEdge (k:ks) [] = do
		e' <- lookupAlg k tsL
		Edge xs vX tsX <- intersectEdge (\ ks' -> f (ls0 ++ k:ks')) (Edge ks vK tsK) e'
		return (Edge (ls0 ++ k:xs) vX tsX)
	-- Symmetric case: the first label is a proper prefix of the second.
	procEdge [] (l:ls) = do
		e' <- lookupAlg l tsK
		Edge xs vX tsX <- intersectEdge (\ ls' -> f (ks0 ++ l:ls')) e' (Edge ls vL tsL)
		return (Edge (ks0 ++ l:xs) vX tsX)
	-- Identical labels: intersect the values and the child tries.
	procEdge [] [] = guardNullEdge $ Edge ks0 (intersectMaybe (f ks0) vK vL) $
		intersectAlg (\ x -> intersectEdge (\ xs -> f (ks0 ++ x:xs))) tsK tsL

-- | Difference of two edges occupying the same trie position: entries of the
-- first edge minus those of the second.  Where a key is in both, @f@
-- receives the full key and both values and returns what, if anything, of
-- the first value survives.
differenceEdge :: (Eq k, TrieKey k m) => ([k] -> v -> w -> Maybe v) -> Edge k m v -> Edge k m w -> MEdge k m v
differenceEdge f e@(Edge ks0 vK tsK) (Edge ls0 vL tsL) = procEdge ks0 ls0 where
	procEdge (k:ks) (l:ls)
		| k == l	= procEdge ks ls
	-- Second label is a proper prefix of the first: subtract the second edge's
	-- child at @k@ from the first edge's remainder.  With no such child, this
	-- equation's guard fails and the final equation keeps @e@ unchanged.
	procEdge (k:ks) []
		| Just e' <- lookupAlg k tsL
		= do	Edge xs vX tsX <- differenceEdge (\ ks' -> f (ls0 ++ k:ks')) (Edge ks vK tsK) e'
			return (Edge (ls0 ++ k:xs) vX tsX)
	-- First label is a proper prefix of the second: subtract the second edge
	-- (minus the shared prefix) from the first edge's child at @l@, if any.
	procEdge [] (l:ls) = guardNullEdge $ Edge ks0 vK $ alterAlg g l tsK
	  where	g Nothing = Nothing
	  	g (Just e') = differenceEdge (\ ls' -> f (ks0 ++ l:ls')) e' (Edge ls vL tsL)
	-- Identical labels: subtract the values and the child tries.
	procEdge [] [] = guardNullEdge $ Edge ks0 (differenceMaybe (f ks0) vK vL) $
		differenceAlg (\ x -> differenceEdge (\ xs -> f (ks0 ++ x:xs))) tsK tsL
	-- Diverging labels, or no matching child: nothing to subtract.
	procEdge _ _ = Just e

-- | Split off the least (full key, value) pair of an edge, together with the
-- remaining edge ('Nothing' if it becomes empty).  The edge's own value,
-- when present, is the least.  Otherwise the least entry of the least child
-- (per 'getMinAlg') is taken.  If 'getMinAlg' finds no child either, this
-- reaches 'error'.
getMinEdge :: TrieKey k m => Edge k m v -> (([k], v), MEdge k m v)
getMinEdge (Edge ks (Just v) ts) = ((ks, v), guardNullEdge $ Edge ks Nothing ts)
-- If the child emptied out, use @ts'@ (the child removed).  Otherwise write the
-- shrunken child back over the least entry via 'updateMinAlg'.
getMinEdge (Edge ks _ ts) 
	| Just ((l, e), ts') <- getMinAlg ts, ((ls, v), e') <- getMinEdge e
		= ((ks ++ l:ls, v), guardNullEdge $ Edge ks Nothing $ maybe ts' (\ e' -> snd $ updateMinAlg (\ _ _ -> (False, Just e')) ts) e')
getMinEdge _ = error "Uncompacted edge"

-- | Split off the greatest (full key, value) pair of an edge, together with
-- the remaining edge ('Nothing' if it becomes empty).  If the edge has
-- children, the greatest entry of the greatest child (per 'getMaxAlg') is
-- taken.  Otherwise the edge's own value is taken.  An edge with neither
-- reaches 'error'.
getMaxEdge :: TrieKey k m => Edge k m v -> (([k], v), MEdge k m v)
-- If the child emptied out, use @ts'@ (the child removed).  Otherwise write the
-- shrunken child back over the greatest entry via 'updateMaxAlg'.
getMaxEdge (Edge ks v0 ts)
	| Just ((l, e), ts') <- getMaxAlg ts, ((ls, v), e') <- getMaxEdge e
		= ((ks ++ l:ls, v), guardNullEdge $ Edge ks v0 $ maybe ts' (\ e' -> snd $ updateMaxAlg (\ _ _ -> (False, Just e')) ts) e')
getMaxEdge (Edge ks (Just v) ts) = ((ks, v), guardNullEdge $ Edge ks Nothing ts)
getMaxEdge _ = error "Uncompacted edge"

-- | Apply @f@ (given the full key) to the least value under the edge.  That is
-- the edge's own value if present; otherwise the search descends via
-- 'updateMinAlg'.  @f@ may delete the value by returning 'Nothing'.  The
-- 'Bool' component comes back unchanged from @f@ or 'updateMinAlg'.  The
-- edge is re-normalised with 'guardNullEdge'.
updateMinEdge :: TrieKey k m => ([k] -> v -> (Bool, Maybe v)) -> Edge k m v -> (Bool, MEdge k m v)
updateMinEdge f (Edge ks mv ts) = case mv of
		Just v -> fmap (\ v' -> guardNullEdge (Edge ks v' ts)) (f ks v)
		Nothing -> fmap (guardNullEdge . Edge ks Nothing) (updateMinAlg descend ts)
	where
		descend l = updateMinEdge (\ ls -> f (ks ++ l : ls))

-- | Apply @f@ (given the full key) to the greatest value under the edge.  The
-- edge's own value is used only when there are no children; otherwise the
-- search descends via 'updateMaxAlg'.  @f@ may delete the value by returning
-- 'Nothing'.  The edge is re-normalised with 'guardNullEdge'.
updateMaxEdge :: TrieKey k m => ([k] -> v -> (Bool, Maybe v)) -> Edge k m v -> (Bool, MEdge k m v)
updateMaxEdge f (Edge ks mv ts)
	| Just v <- mv, nullAlg ts	= fmap (\ v' -> guardNullEdge (Edge ks v' ts)) (f ks v)
	| otherwise	= fmap (guardNullEdge . Edge ks mv) (updateMaxAlg descend ts)
	where
		descend l = updateMaxEdge (\ ls -> f (ks ++ l : ls))

-- | @isSubmapEdge leq e1 e2@ checks that every key stored under @e1@ is also
-- stored under @e2@ (edges at the same trie position), with @leq@ holding
-- for each pair of values at a shared key.
isSubmapEdge :: TrieKey k m => (a -> b -> Bool) -> Edge k m a -> Edge k m b -> Bool
isSubmapEdge leq (Edge ks vK tsK) (Edge ls vL tsL) = go ks ls
	where
		kidsOk = isSubmapAlg (isSubmapEdge leq) tsK tsL
		go (k:ks') (l:ls')
			| k == l	= go ks' ls'
			| otherwise	= False
		go (k:ks') [] = maybe False (isSubmapEdge leq (Edge ks' vK tsK)) (lookupAlg k tsL)
		go [] [] = case (vK, vL) of
			(Nothing, _) -> kidsOk
			(Just x, Just y) -> leq x y && kidsOk
			_ -> False
		go [] (_:_) = False
-- | Check the compression invariant recursively.  A valueless edge must have
-- neither zero children ('nullAlg') nor exactly one ('getSingleAlg').  The
-- child trie must be 'valid', and every child edge must itself satisfy
-- 'validEdge'.
validEdge :: TrieKey k m => Edge k m v -> Bool
validEdge (Edge _ v m)
	| isNothing v, nullAlg m	= False
	| isNothing v, isJust (getSingleAlg m)	= False
	| otherwise	= valid m && all validEdge m


-- | Split an edge around the key @ks0@, using the lexicographic order on
-- @[k]@.  Returns three parts:
--
-- * the entries strictly below @ks0@;
-- * the result of splitting the value stored exactly at @ks0@ (if any) with
--   @f@, which yields a lesser part, an extracted result and a greater part;
-- * the entries strictly above @ks0@.
--
-- Both halves are re-normalised with 'guardNullEdge'.
--
-- A proper prefix sorts before its extensions, so an edge's own value
-- precedes every value in its child trie (as 'getMinEdge' also assumes).
splitEdge :: (Ord k, TrieKey k m) => (a -> (Maybe a, Maybe b, Maybe a)) -> [k] -> Edge k m a -> (MEdge k m a, Maybe b, MEdge k m a)
splitEdge f ks0 e@(Edge ls0 v ts) = procEdge ks0 ls0 where
	-- Labels diverge: the whole edge lies on one side of ks0 and is returned
	-- unchanged, keeping its own label ls0 (not the query key ks0).
	procEdge (k:ks) (l:ls) = case compare k l of
		LT	-> (Nothing, Nothing, Just e)
		EQ	-> procEdge ks ls
		GT	-> (Just e, Nothing, Nothing)
	-- The label is a proper prefix of ks0, so the edge's own value (at ls0) is
	-- smaller than ks0 and belongs on the left.  The child trie is split at k.
	procEdge (k:ks) [] = case splitLookupAlg (splitEdge f ks) k ts of
		(tsL, ans, tsR)	-> (guardNullEdge $ Edge ls0 v tsL, ans, guardNullEdge $ Edge ls0 Nothing tsR)
	-- ks0 is a proper prefix of the label: every entry on the edge is larger.
	procEdge [] (_:_) = (Nothing, Nothing, Just e)
	-- Exact match: split the stored value; all children are larger than ks0.
	procEdge [] []
		| Just x <- v, (vL, ans, vR) <- f x
			= (fmap (\ v' -> Edge ls0 (Just v') emptyAlg) vL, ans,
				guardNullEdge $ Edge ls0 vR ts)
		| otherwise = (Nothing, Nothing, Just e)