{-# LANGUAGE TemplateHaskell, FlexibleContexts, TypeFamilies, MultiParamTypeClasses, PatternGuards #-}

module Data.TrieMap.RadixTrie () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Sized
import Data.TrieMap.Applicative
import Data.TrieMap.CPair
import Data.TrieMap.Regular.Class
-- import Data.TrieMap.Regular.TH

import Control.Applicative
import Control.Arrow
import Control.Monad

import Data.Maybe
import Data.Monoid
import Data.Foldable
import Data.Traversable

import Prelude hiding (lookup, foldr, foldl)

-- | One compressed edge of a radix trie.
--
-- @Edge sz ks v ts@ holds:
--
--   * @sz@: the cached total size of the edge, i.e. the size of @v@ (if
--     present) plus the sizes of all child edges (see 'edge');
--   * @ks@: the key fragment labelling this edge;
--   * @v@: the value stored at the end of @ks@, if any;
--   * @ts@: the child edges, keyed by the key element that follows @ks@.
data Edge k m a = Edge {-# UNPACK #-} !Int [k] (Maybe a) (m (Edge k m a))
-- | An edge whose child map is the trie map for the element type @k@.
type Edge' k a = Edge k (TrieMap k) a
-- | An optional edge; 'Nothing' stands for an empty subtrie.
type MEdge k m a = Maybe (Edge k m a)
type MEdge' k a = Maybe (Edge' k a)

-- | A trie keyed by lists of @k@: either empty or a single root edge.
newtype RadixTrie k a = Radix (MEdge' k a)

-- Lists are keyed by 'RadixTrie', both at the functor level ('TrieMapT')
-- and for concrete list keys ('TrieMap').
type instance TrieMapT [] = RadixTrie
type instance TrieMap [k] = RadixTrie k

-- | The size cached in an edge: the total size of its own value and all of
-- its descendants, as measured when the edge was built.
edgeSize :: Edge k m a -> Int
edgeSize e = case e of
	Edge n _ _ _ -> n

instance TrieKey k (TrieMap k) => TrieKey [k] (RadixTrie k) where
	-- Each method forwards to the matching TrieKeyT method for lists.
	-- NOTE(review): sizeM is not given here; confirm the class supplies it.
	-- construction
	emptyM = emptyT
	fromListM = fromListT
	fromAscListM = fromAscListT
	fromDistAscListM = fromDistAscListT
	-- queries
	nullM = nullT
	lookupM = lookupT
	lookupIxM = lookupIxT
	assocAtM = assocAtT
	isSubmapM = isSubmapT
	-- single-key updates
	alterM = alterT
	alterLookupM = alterLookupT
	-- traversals, folds and extraction
	traverseWithKeyM = traverseWithKeyT
	foldWithKeyM = foldWithKeyT
	foldlWithKeyM = foldlWithKeyT
	extractM = extractT
	-- splitting and combining
	mapEitherM = mapEitherT
	splitLookupM = splitLookupT
	unionM = unionT
	isectM = isectT
	diffM = diffT

instance TrieKeyT [] RadixTrie where
	-- A radix trie is an optional root edge; each method unwraps that root
	-- and delegates to the corresponding edge-level helper.
	emptyT = Radix Nothing
	nullT (Radix Nothing) = True
	nullT (Radix (Just _)) = False
	sizeT _ (Radix root) = case root of
		Nothing -> 0
		Just e -> edgeSize e
	lookupT ks (Radix root) = maybe Nothing (lookup ks) root
	alterT s f ks (Radix root) = Radix $ alter s f ks root
	alterLookupT s f ks (Radix root) = fmap Radix (alterLookupE s f ks root)
	traverseWithKeyT s f (Radix root) = fmap Radix (traverse (traverseE s f) root)
	extractT s f (Radix root) = case root of
		Nothing -> empty
		Just e -> (fmap Radix <.> extractE s f) e
	foldWithKeyT f (Radix root) z = maybe z (\ e -> foldE f e z) root
	foldlWithKeyT f (Radix root) z = maybe z (foldlE f z) root
	mapEitherT s1 s2 f (Radix root) =
		let (lefts, rights) = maybe (Nothing, Nothing) (mapEitherE s1 s2 f) root
		in (Radix lefts, Radix rights)
	unionT s f (Radix r1) (Radix r2) = Radix $ unionMaybe (unionE s f) r1 r2
	isectT s f (Radix r1) (Radix r2) = Radix $ isectMaybe (isectE s f) r1 r2
	diffT s f (Radix r1) (Radix r2) = Radix $ diffMaybe (diffE s f) r1 r2
	lookupIxT s ks (Radix root) = maybe (empty, empty, empty) (lookupIxE s 0 ks) root
	isSubmapT le (Radix r1) (Radix r2) = subMaybe (isSubmapE le) r1 r2
	splitLookupT s f ks (Radix root) =
		sides Radix (maybe (Nothing, Nothing, Nothing) (splitLookupE s f ks) root)
	assocAtT s i (Radix root) = maybe (empty, empty, empty) (assocAtE s 0 i) root
  
-- | Prepend a key fragment to an edge's label; the size, value and children
-- are left untouched.
cat :: [k] -> Edge' k a -> Edge' k a
cat prefix (Edge n label val kids) = Edge n (prefix ++ label) val kids

-- | Prepend a single key element to an edge's label; the size, value and
-- children are left untouched.
cons :: k -> Edge' k a -> Edge' k a
cons x (Edge n label val kids) = Edge n (x : label) val kids

-- | Smart constructor for 'Edge' that fills in the cached size: the size of
-- the optional value (measured with @s@) plus the total size of the children.
edge :: TrieKey k (TrieMap k) => Sized a -> [k] -> Maybe a -> TrieMap k (Edge' k a) -> Edge' k a
edge s ks v ts = Edge total ks v ts
	where
		total = valueSize + childrenSize
		valueSize = maybe 0 s v
		childrenSize = sizeM edgeSize ts

-- | A leaf edge labelled @ks@ holding the given value, or no edge at all when
-- the value is absent.
singleMaybe :: TrieKey k (TrieMap k) => Sized a -> [k] -> Maybe a -> MEdge' k a
singleMaybe s ks = fmap (\ x -> edge s ks (Just x) emptyM)

-- | Normalise an edge after an update.
--
-- An edge that holds a value is kept.  A valueless edge is dropped when it has
-- no children.  When it has exactly one child, it is merged into that child by
-- concatenating labels (the result is compacted again).  Otherwise it is kept.
compact :: TrieKey k (TrieMap k) => Edge' k a -> MEdge' k a
compact e@(Edge _ _ (Just _) _) = Just e
compact e@(Edge _ label Nothing kids) = case assocsM kids of
	[] -> Nothing
	[(l, child)] -> compact (label `cat` (l `cons` child))
	_ -> Just e

-- | Look up a key, given relative to the start of this edge.
--
-- The key must begin with the edge label.  If it ends exactly there, the
-- edge's own value is the answer; otherwise the search continues in the child
-- selected by the next key element.
lookup :: (Eq k, TrieKey k (TrieMap k)) => [k] -> Edge' k a -> Maybe a
lookup key (Edge _ label val kids) = go key label
	where
		go qs [] = case qs of
			[] -> val
			(q:qs') -> lookupM q kids >>= lookup qs'
		go (q:qs) (p:ps) | q == p = go qs ps
		go _ _ = Nothing

-- | Update the value at key @ks0@ (relative to this edge) with @f@.
--
-- @f@ receives the current value ('Nothing' if the key is absent).  A 'Just'
-- result stores that value; a 'Nothing' result removes it.  Edges left
-- without a value are tidied with 'compact'.
--
-- @match i ks ls@ walks the remaining key @ks@ and remaining label @ls@ in
-- lockstep; @i@ is the number of label elements matched so far.
alter :: (TrieKey k (TrieMap k)) => Sized a -> (Maybe a -> Maybe a) -> [k] -> MEdge' k a -> MEdge' k a
alter s f ks0 Nothing = singleMaybe s ks0 (f Nothing)
alter s f ks0 (Just e@(Edge sz ls0 v ts)) = match 0 ks0 ls0 where
	-- Strictness guard: always False, it only forces the counter @i@.
	match i _ _ | i `seq` False = undefined
	-- Key and label diverge after @i@ elements.  If @f@ creates a value,
	-- split at the common prefix (take i ls0) into two children in key
	-- order; otherwise the edge is unchanged.
	match i (k:ks) (l:ls) = case compare k l of
	      LT | Just v' <- f Nothing	
		      -> Just $ let sv = s v' in Edge (sv + sz) (take i ls0) Nothing (fromDistAscListM edgeSize
					[(k, Edge sv ks (Just v') emptyM), (l, Edge sz ls v ts)])
	      EQ	-> match (i+1) ks ls
	      GT | Just v' <- f Nothing
		      -> Just $ let sv = s v' in Edge (sv + sz) (take i ls0) Nothing (fromDistAscListM edgeSize
					[(l, Edge sz ls v ts), (k, Edge sv ks (Just v') emptyM)])
	      _	-> Just e
	-- Label exhausted: recurse into the child chosen by the next key element.
	match _ (k:ks) [] = compact $ edge s ls0 v (alterM edgeSize g k ts) where
		g = alter s f ks
	-- Key is a proper prefix of the label: a new value becomes a node above
	-- the existing edge, which stays as its only child.
	match _ [] (l:ls)
		| Just v' <- f Nothing
			= Just (Edge (s v' + sz) ks0 (Just v') (singletonM edgeSize l (Edge sz ls v ts)))
	-- Exact match: rewrite this edge's own value.
	match _ [] []
		= compact (edge s ls0 (f v) ts)
	-- Key is a prefix of the label but f Nothing is Nothing: no change.
	match _ _ _ = Just e

-- | Variant of 'alter' where @f@ returns its new value inside a 'CPair'.
--
-- The @z@ component produced by @f@ is carried through unchanged (the edge is
-- rebuilt with 'fmap' over the pair).  The edge is rebuilt from the
-- @Maybe a@ component.  The case analysis mirrors 'alter'.
alterLookupE :: TrieKey k (TrieMap k) => Sized a -> (Maybe a -> CPair z (Maybe a)) -> [k] -> MEdge' k a -> CPair z (MEdge' k a)
alterLookupE s f ks Nothing = singleMaybe s ks <$> f Nothing
alterLookupE s f ks0 (Just e@(Edge sz ls0 v0 ts0)) = match 0 ks0 ls0 where
      -- Strictness guard: always False, it only forces the counter @i@.
      match i _ _ | i `seq` False = undefined
      -- Divergence after @i@ label elements: split at the common prefix if
      -- @f Nothing@ yields a value, otherwise keep @e@.
      match i (k:ks) (l:ls) = case compare k l of
	      LT	-> fmap (Just . maybe e (\ v' -> let sv = s v' in Edge (sz + sv) (take i ls0) Nothing $
				      fromDistAscListM edgeSize [(k, Edge sv ks (Just v') emptyM), (l, Edge sz ls v0 ts0)]))
			      (f Nothing)
	      GT	-> fmap (Just . maybe e (\ v' -> let sv = s v' in Edge (sz + sv) (take i ls0) Nothing $
				      fromDistAscListM edgeSize [(l, Edge sz ls v0 ts0), (k, Edge sv ks (Just v') emptyM)]))
			      (f Nothing)
	      EQ	-> match (i+1) ks ls
      -- Label exhausted: recurse into the child chosen by the next key element.
      match _ (k:ks) [] = fmap (compact . edge s ls0 v0) (alterLookupM edgeSize g k ts0) where
	      g = alterLookupE s f ks
      -- Key is a proper prefix of the label: a new value becomes a node above
      -- the existing edge.
      match _ [] (l:ls) = fmap (Just . maybe e (\ v' -> let sv = s v' in Edge (sv + sz) ks0 (Just v') (singletonM edgeSize l (Edge sz ls v0 ts0))))
			      (f Nothing)
      -- Exact match: rewrite this edge's own value.
      match _ [] [] = fmap (\ v' -> compact $ edge s ls0 v' ts0) (f v0)

-- | Apply an effectful function to every value of the edge.
--
-- Each value is passed its key relative to the start of this edge.  The
-- edge's own value is traversed before its children.  Sizes are recomputed
-- with @s@.
traverseE :: (Applicative f, TrieKey k (TrieMap k)) => Sized b -> ([k] -> a -> f b) -> Edge' k a -> f (Edge' k b)
traverseE s f (Edge _ prefix val kids) = edge s prefix <$> here <*> below
	where
		here = traverse (f prefix) val
		below = traverseWithKeyM edgeSize descend kids
		-- keys under child @l@ are this label, then @l@, then the child's key
		descend l = traverseE s (f . (prefix ++) . (l :))

-- | Offer every value of the edge to @f@ (keys relative to this edge) and
-- combine the alternatives with '<|>'.
--
-- The edge's own value is tried before its children.  For each alternative,
-- the replacement value returned by @f@ is written back and the edge
-- re-compacted.
extractE :: (Alternative f, TrieKey k (TrieMap k)) => Sized a -> ([k] -> a -> f (CPair x (Maybe a))) -> Edge' k a -> f (CPair x (MEdge' k a))
extractE s f (Edge _ prefix val kids) = maybe fromKids (\ x -> fromHere x <|> fromKids) val
	where
		-- candidate from this edge's own value
		fromHere x = fmap (\ val' -> compact (edge s prefix val' kids)) <$> f prefix x
		-- candidates from the children, rebuilding this edge around them
		fromKids = fmap (compact . edge s prefix val) <$> extractM edgeSize descend kids
		descend l = extractE s (f . (prefix ++) . (l :))

-- | Run @f@ over the edge's entries (keys relative to this edge) through
-- 'extractE'.  Only @f@'s results are kept; every value is reported as
-- unchanged and sizes are ignored.
aboutE :: (Alternative f, TrieKey k (TrieMap k)) => ([k] -> a -> f x) -> Edge' k a -> f x
aboutE f = cpFst <.> extractE (const 0) tag
	where	tag ks a = fmap (\ r -> cP r Nothing) (f ks a)

-- | Right fold over the edge's entries, with keys relative to this edge.
-- The edge's own value comes first, followed by the children as folded by
-- 'foldWithKeyM'.
foldE :: TrieKey k (TrieMap k) => ([k] -> a -> b -> b) -> Edge' k a -> b -> b
foldE f (Edge _ prefix val kids) z = maybe rest (\ x -> f prefix x rest) val
	where
		rest = foldWithKeyM descend kids z
		descend l = foldE (f . (prefix ++) . (l :))

-- | Left fold over the edge's entries, with keys relative to this edge.  The
-- edge's own value is combined first, then the children via 'foldlWithKeyM'.
foldlE :: TrieKey k (TrieMap k) => ([k] -> b -> a -> b) -> b -> Edge' k a -> b 
foldlE f z (Edge _ prefix val kids) = foldlWithKeyM descend kids start
	where
		start = maybe z (f prefix z) val
		descend l = foldlE (f . (prefix ++) . (l :))

-- | Split an edge into two.  @f@ maps each value (keyed relative to this
-- edge) to an optional left value and an optional right value.  Both halves
-- are rebuilt with their own size functions and compacted.
mapEitherE :: TrieKey k (TrieMap k) => Sized b -> Sized c -> ([k] -> a -> (Maybe b, Maybe c)) -> Edge' k a ->
	(MEdge' k b, MEdge' k c)
mapEitherE s1 s2 f (Edge _ prefix val kids) = (compact left, compact right)
	where
		left = edge s1 prefix valL kidsL
		right = edge s2 prefix valR kidsR
		(valL, valR) = maybe (Nothing, Nothing) (f prefix) val
		(kidsL, kidsR) = mapEitherM edgeSize edgeSize descend kids
		descend l = mapEitherE s1 s2 (f . (prefix ++) . (l :))

-- | Union of two edges whose keys are relative to the same position.
--
-- Where both edges hold a value at the same key, @f@ (given that key) combines
-- them and may drop the result.  @match i ks ls@ walks both labels in
-- lockstep; @i@ is the length of the common prefix found so far.
unionE :: TrieKey k (TrieMap k) => Sized a -> ([k] -> a -> a -> Maybe a) -> Edge' k a -> Edge' k a -> MEdge' k a
unionE s f eK@(Edge szK ks0 vK tsK) eL@(Edge szL ls0 vL tsL) = match 0 ks0 ls0 where
	-- Strictness guard: always False, it only forces the counter @i@.
	match i _ _ | i `seq` False = undefined
	-- Labels diverge after @i@ elements: the keys are disjoint, so both edges
	-- become children of a new valueless edge on the common prefix.
	match i (k:ks) (l:ls) = case compare k l of
	      EQ -> match (i+1) ks ls
	      LT -> Just $ Edge (szK + szL) (take i ks0) Nothing (fromDistAscListM edgeSize 
		      [(k, Edge szK ks vK tsK), (l, Edge szL ls vL tsL)])
	      GT -> Just $ Edge (szK + szL) (take i ks0) Nothing (fromDistAscListM edgeSize 
		      [(l, Edge szL ls vL tsL), (k, Edge szK ks vK tsK)])
	-- First label is a prefix of the second: merge the rest of the second
	-- edge into the first edge's child @l@.
	match _ [] (l:ls) = compact (edge s ks0 vK (alterM edgeSize g l tsK)) where
		g (Just eK') = unionE s (\ ls' -> f (ks0 ++ l:ls')) eK' (Edge szL ls vL tsL)
		g Nothing = Just (Edge szL ls vL tsL)
	-- Second label is a prefix of the first: symmetric case; argument order
	-- to @f@ is preserved (first edge's value first).
	match _ (k:ks) [] = compact (edge s ls0 vL (alterM edgeSize g k tsL)) where
		g Nothing = Just (Edge szK ks vK tsK)
		g (Just eL') = unionE s (\ ks' -> f (ls0 ++ k:ks')) (Edge szK ks vK tsK) eL'
	-- Identical labels: combine the values and union the child maps.
	match _ [] [] = compact (edge s ls0 (unionMaybe (f ls0) vK vL) (unionM edgeSize g tsK tsL)) where
		g x = unionE s (\ xs -> f (ls0 ++ x:xs))

-- | Intersection of two edges whose keys are relative to the same position.
--
-- Only keys present in both survive.  @f@ (given the key) combines the two
-- values and may drop the result.
isectE :: TrieKey k (TrieMap k) => Sized c -> ([k] -> a -> b -> Maybe c) -> Edge' k a -> Edge' k b -> MEdge' k c
isectE s f (Edge szK ks0 vK tsK) (Edge szL ls0 vL tsL) = go ks0 ls0
	where
		go (k:ks) (l:ls)
			| k == l = go ks ls
			| otherwise = Nothing
		-- second label ends first: only its child k can overlap the first edge
		go (k:ks) [] = lookupM k tsL >>= \ eL' ->
			cat ls0 . cons k <$> isectE s (f . (ls0 ++) . (k :)) (Edge szK ks vK tsK) eL'
		-- first label ends first: only its child l can overlap the second edge
		go [] (l:ls) = lookupM l tsK >>= \ eK' ->
			cat ks0 . cons l <$> isectE s (f . (ks0 ++) . (l :)) eK' (Edge szL ls vL tsL)
		go [] [] = compact (edge s ks0 (isectMaybe (f ks0) vK vL) (isectM edgeSize descend tsK tsL))
		descend x = isectE s (f . (ks0 ++) . (x :))

-- | Difference of two edges whose keys are relative to the same position.
--
-- Entries of the first edge whose key is absent from the second are kept.
-- For keys present in both, the values are combined with @f@ through
-- 'diffMaybe'.
diffE :: TrieKey k (TrieMap k) => Sized a -> ([k] -> a -> b -> Maybe a) -> Edge' k a -> Edge' k b -> MEdge' k a
diffE s f eK@(Edge szK ks0 vK tsK) (Edge szL ls0 vL tsL) = go ks0 ls0
	where
		unchanged = Just eK
		go (k:ks) (l:ls)
			| k == l = go ks ls
			| otherwise = unchanged
		-- second label ends first: subtract its child k, if any
		go (k:ks) [] = case lookupM k tsL of
			Nothing -> unchanged
			Just eL' -> cat ls0 . cons k <$> diffE s (f . (ls0 ++) . (k :)) (Edge szK ks vK tsK) eL'
		-- first label ends first: only its child l can be affected
		go [] (l:ls) = compact (edge s ks0 vK (alterM edgeSize (>>= trimChild) l tsK))
			where trimChild eK' = diffE s (f . (ks0 ++) . (l :)) eK' (Edge szL ls vL tsL)
		go [] [] = compact (edge s ks0 (diffMaybe (f ks0) vK vL) (diffM edgeSize descend tsK tsL))
		descend x = diffE s (f . (ks0 ++) . (x :))

-- | Locate key @ks@ (relative to this edge) in an edge whose first entry has
-- absolute index @i@.
--
-- Returns a triple:
--
--   * the nearest entry below @ks@, built with 'Last';
--   * the entry at @ks@, if present;
--   * the nearest entry above @ks@, built with 'First'.
--
-- Every reported index is absolute.
lookupIxE :: TrieKey k (TrieMap k) => Sized a -> Int -> [k] -> Edge' k a -> IndexPos [k] a
lookupIxE s i ks e@(Edge sz ls v ts) = match ks ls where
	-- Key diverges from the label: it lies entirely below (LT) or above (GT)
	-- this edge, so the neighbour is this edge's minimum or maximum entry.
	match (k:ks) (l:ls) = case compare k l of
		LT	-> (empty, empty, aboutE (return .: Asc i) e)
		EQ	-> match ks ls
		GT	-> (aboutE (\ k a -> return (Asc (i + sz - s a) k a)) e, empty, empty)
	-- Label exhausted: look the next key element up among the children (indices
	-- shifted past this edge's own value).  Lower bound candidates are this
	-- edge's value, the maximum of the preceding child and the recursive lower
	-- bound; upper bound candidates are the recursive upper bound and the
	-- minimum of the following child.
	match (k:ks) [] = let sv = maybe 0 s v in case onIndex (i + sv +) (lookupIxM edgeSize k ts) of
		(lb, x, ub) -> let lookupX = do	Asc i' k' e' <- x
						return $ onKey (\ ks' -> ls ++ k':ks') $
							lookupIxE s i' ks e'
			in ((do v <- Last v
				return (Asc i ls v)) <|>
			    (do Asc iL kL eL <- lb
				aboutE (\ ksL vL -> return $ Asc (iL + edgeSize eL - s vL) (ls ++ kL:ksL) vL) eL) <|>
			    (do (lb', _, _) <- Last lookupX
				lb'),
			    (do (_, x', _) <- lookupX
				x'),
			    (do (_, _, ub') <- First lookupX
				ub') <|>
			    (do Asc iU kU eU <- ub
				aboutE (\ ksU -> return . Asc iU (ls ++ kU:ksU)) eU))
	-- Key is a proper prefix of the label: every entry of the edge is above it.
	match [] (l:ls) = (empty, empty, aboutE (return .: Asc i) e)
	-- Key ends exactly here: the exact entry is this edge's value; the upper
	-- bound is the minimum child entry, which starts after that value.
	match [] [] = (empty, Asc i ls <$> v, aboutM (\ x -> aboutE (\ xs -> return . Asc (i + maybe 0 s v) (ls ++ x:xs))) ts)

-- | Is every entry of the first edge present in the second, with values
-- related by @le@?  Keys are relative to the same position.
isSubmapE :: TrieKey k (TrieMap k) => LEq a b -> LEq (Edge' k a) (Edge' k b)
isSubmapE le (Edge szK ks vK tsK) (Edge _ ls vL tsL) = go ks ls
	where
		go (k:ks') (l:ls') | k == l = go ks' ls'
		-- second label ends first: the first edge must fit in its child k
		go (k:ks') [] = maybe False (isSubmapE le (Edge szK ks' vK tsK)) (lookupM k tsL)
		go [] [] = subMaybe le vK vL && isSubmapM (isSubmapE le) tsK tsL
		go _ _ = False

-- | Split an edge around key @ks@ (relative to this edge).
--
-- Returns the entries below the key, the result @f@ extracts at the key, and
-- the entries above it.  @f@ may also send pieces of the value at the key to
-- the left and right halves.
splitLookupE :: TrieKey k (TrieMap k) => Sized a -> (a -> (Maybe a, Maybe x, Maybe a)) -> [k] -> Edge' k a ->
	(MEdge' k a, Maybe x, MEdge' k a)
splitLookupE s f ks e@(Edge _ ls v ts) = go ks ls
	where
		go (k:rest) (l:lbl) = case compare k l of
			LT -> (Nothing, Nothing, Just e)
			EQ -> go rest lbl
			GT -> (Just e, Nothing, Nothing)
		-- label exhausted: this edge's value is below the key; split the children
		go (k:rest) [] = case splitLookupM edgeSize (splitLookupE s f rest) k ts of
			(tsL, x, tsR) -> (compact (edge s ls v tsL), x, compact (edge s ls Nothing tsR))
		-- key is a proper prefix of the label: everything here is above it
		go [] (_:_) = (Nothing, Nothing, Just e)
		-- exact match: children are all above the key
		go [] [] = (singleMaybe s ls vL, x, compact (edge s ls vR ts))
			where (vL, x, vR) = maybe (Nothing, Nothing, Nothing) f v

-- | Find the entry at position @i@ of an edge.
--
-- @i@ is counted from the start of the edge, and @i0@ is the absolute index
-- of the edge's first entry.  The edge's own value (if any) occupies local
-- positions @[0, sv)@, and the child map occupies the positions after it.
--
-- Returns a triple:
--
--   * the nearest entry before position @i@, built with 'Last';
--   * the entry at @i@;
--   * the nearest entry after position @i@, built with 'First'.
--
-- Indices in the result are absolute.
assocAtE :: TrieKey k (TrieMap k) => Sized a -> Int -> Int -> Edge' k a -> IndexPos [k] a
assocAtE _ i0 i _ | i0 `seq` i `seq` False = undefined
assocAtE s i0 i (Edge _ ks v ts) = let sv = maybe 0 s v in case assocAtM edgeSize (i - sv) ts of
	(lb, x, ub) -> let lookupX = do Asc i' l e' <- x
					-- i' is where e' starts within the child map, whose position 0
					-- is local position sv, so the position inside e' is
					-- (i - sv) - i'.  Omitting sv shifted lookups into the
					-- children of any edge carrying a value.
					return (onKey (\ ls -> ks ++ l:ls) (assocAtE s (i0 + sv + i') (i - sv - i') e'))
		in ((do	v <- Last v
			guard (i >= sv)
			return (Asc i0 ks v)) <|>
		      (do	Asc iL lL eL <- lb
				aboutE (\ ls vL -> return (Asc (i0 + iL + sv + edgeSize eL - s vL) (ks ++ lL:ls) vL)) eL) <|>
		      (do	(lb', _, _) <- Last lookupX
				lb'),
		      (do	v <- v
				guard (i >= 0 && i < sv)
				return (Asc i0 ks v)) <|> 
		      (do	(_, x', _) <- lookupX
				x'),
		      (do	(_, _, ub') <- First lookupX
				ub') <|>
		      (do	v <- First v
				guard (i < 0)
				return (Asc i0 ks v)))