{-# LANGUAGE BangPatterns, UnboxedTuples, TupleSections, TypeFamilies, PatternGuards, MagicHash #-}

module Data.TrieMap.RadixTrie () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Sized
-- import Data.TrieMap.Applicative

import Control.Applicative
import Control.Monad

import Data.Maybe
import Data.Foldable (foldr, foldl)
import Data.Traversable

import GHC.Exts

import Prelude hiding (lookup, foldr, foldl)

-- | The value slot of a trie node: 'Empty' when no key ends here, otherwise
-- the complete key that ends here together with its value.
data Assoc k a
	= Empty
	| Assoc [k] a

-- | A radix-trie edge.  Fields, in order: the cached size of the whole
-- subtree, the key fragment labelling this edge, the value slot for the key
-- that ends right after that fragment, and the children keyed by the next
-- key element.
data Edge k a
	= Edge Int# [k] (Assoc k a) (TrieMap k (Edge k a))

-- | An edge that may be absent; 'Nothing' stands for the empty trie.
type MEdge k a = Maybe (Edge k a)

-- | An edge reports the subtree size cached in its first field.
instance Sized (Edge k a) where
	getSize# e = case e of
		Edge n# _ _ _ -> n#

-- | A value slot contributes the size of its value, or zero when empty.
instance Sized a => Sized (Assoc k a) where
	getSize# v = case v of
		Empty		-> 0#
		Assoc _ a	-> getSize# a

-- | The route from a hole back up to the root of the radix trie.
data Path k a
	-- The hole lies in the top-level edge.
	= Root
	-- One level of descent: the enclosing path, the label and value slot of
	-- the edge descended through, and the hole in that edge's child map.
	| Deep (Path k a) [k] (Assoc k a) (Hole k (Edge k a))

-- | Lists of keys are stored in a radix trie.  Each 'Edge' carries:
--
--   * a key fragment shared by everything below it;
--   * the value (if any) whose key ends exactly there;
--   * a child map keyed by the next key element.
--
-- A 'Hole' records the full key at the hole, the label of the edge at the
-- hole position, that edge's children, and the 'Path' back to the root.
instance TrieKey k =>  TrieKey [k] where
	newtype TrieMap [k] a = Radix (MEdge k a)
	data Hole [k] a = Hole [k] [k] (TrieMap k (Edge k a)) (Path k a)

	emptyM = Radix Nothing
	singletonM ks a = Radix (Just (Edge (getSize# a) ks (Assoc ks a) emptyM))
	nullM (Radix m) = isNothing m
	sizeM (Radix (Just e)) = getSize# e
	sizeM _ = 0#
	lookupM ks (Radix m) = m >>= lookup ks
	traverseWithKeyM f (Radix m) = Radix <$> traverse (traverseE f) m
	foldrWithKeyM f (Radix m) z = foldr (foldrE f) z m
	foldlWithKeyM f (Radix m) z = foldl (foldlE f) z m
	mapWithKeyM f (Radix m) = Radix (mapWithKeyE f <$> m)
	mapMaybeM f (Radix m) = Radix (m >>= mapMaybeE f)
	mapEitherM _ (Radix Nothing) = (# emptyM, emptyM #)
	mapEitherM f (Radix (Just m)) = both Radix Radix (mapEitherE f) m
	unionM f (Radix m1) (Radix m2) = Radix (unionMaybe (unionE f) m1 m2)
	isectM f (Radix m1) (Radix m2) = Radix (isectMaybe (isectE f) m1 m2)
	diffM f (Radix m1) (Radix m2) = Radix (diffMaybe (diffE f) m1 m2)
	isSubmapM (<=) (Radix m1) (Radix m2) = subMaybe (isSubmapE (<=)) m1 m2

	singleHoleM ks = Hole ks ks emptyM Root
	keyM (Hole ks _ _ _) = ks

	-- The traversals in this module (foldrE, indexE, extractHoleE) visit an
	-- edge's own value before its children, so a key sorts before every key
	-- it is a proper prefix of.  Consequently:
	--
	--   * the children at the hole all sort after the hole key, so 'beforeM'
	--     must drop them while 'afterM' keeps them;
	--   * a value stored on an ancestor edge (a 'Deep' frame) is a proper
	--     prefix of the hole key, so 'beforeM' keeps it while 'afterM' must
	--     drop it.
	--
	-- The optional argument is placed at the hole key in either result.
	beforeM a (Hole ks0 ks _ path) = before (compact (edge ks v emptyM)) path where
		v = case a of
			Nothing	-> Empty
			Just x	-> Assoc ks0 x
		before t Root = Radix t
		before e (Deep up ls w tHole) =
			before (compact $ edge ls w $ beforeM e tHole) up
	afterM a (Hole ks0 ks ts path) = after (compact (edge ks v ts)) path where
		v = case a of
			Nothing	-> Empty
			Just x	-> Assoc ks0 x
		after t Root = Radix t
		-- The ancestor's own value is a prefix of the hole key: drop it.
		after e (Deep up ls _ tHole) =
			after (compact $ edge ls Empty $ afterM e tHole) up

	searchM ks (Radix Nothing) = (# Nothing, singleHoleM ks #)
	-- searchE only sees key suffixes, so the full key is supplied here.
	searchM ks (Radix (Just e)) = case searchE ks e Root of
		(# v, holer #) -> (# v, holer ks #)

	indexM _ (Radix Nothing) = (# error err, error err, error err #)
		where err = "Error: trie map is empty"
	indexM i# (Radix (Just e)) = indexE i# e Root

	extractHoleM (Radix Nothing) = mzero
	extractHoleM (Radix (Just e)) = extractHoleE Root e

	assignM a (Hole ks0 ks ts path) = Radix $ rebuild (compact (edge ks (Assoc ks0 a) ts)) path

	clearM (Hole _ ks ts path) = Radix $ rebuild (compact (edge ks Empty ts)) path

-- | Zip an edge (or its absence) back up along a 'Path' to the root.
-- Every ancestor is re-compacted on the way up: a value-less edge left with
-- a single child is merged into that child, and a value-less edge with no
-- children disappears.
rebuild :: (TrieKey k, Sized a) => MEdge k a -> Path k a -> MEdge k a
rebuild e Root = e
rebuild e (Deep up ls w tHole) =
	rebuild (compact (edge ls w (fillHoleM e tHole))) up

-- | Prepend a key fragment to an edge's label; size, value slot and children
-- are left unchanged.
cat :: [k] -> Edge k a -> Edge k a
cat prefix (Edge sz label v ts) = Edge sz (prefix ++ label) v ts

-- | Prepend a single key element to an edge's label; size, value slot and
-- children are left unchanged.
cons :: k -> Edge k a -> Edge k a
cons x (Edge sz label v ts) = Edge sz (x : label) v ts

-- | Smart constructor for 'Edge'.  The cached size is the size of the value
-- slot plus the size of the child map.
edge :: (TrieKey k, Sized a) =>  [k] -> Assoc k a -> TrieMap k (Edge k a) -> Edge k a
edge ks v ts = Edge total# ks v ts where
	!total# = getSize# v +# getSize# ts

-- | Restore the radix-trie shape of an edge.  An edge holding a value is kept
-- as is.  A value-less edge is dropped when it has no children, and merged
-- into its child when it has exactly one.
compact :: TrieKey k => Edge k a -> MEdge k a
compact e@(Edge _ ks v ts) = case v of
	Assoc{}	-> Just e
	Empty	-> case assocsM ts of
		[]		-> Nothing
		[(l, child)]	-> compact (ks `cat` (l `cons` child))
		_		-> Just e

-- | Look up a key suffix below an edge.  The edge label must be consumed
-- exactly; any leftover key element selects a child to continue in.
lookup :: (Eq k, TrieKey k) => [k] -> Edge k a -> Maybe a
lookup ks (Edge _ ls v ts) = go ks ls where
	go (x:xs) (y:ys)
		| x == y	= go xs ys
		| otherwise	= Nothing
	go (x:xs) []	= lookupM x ts >>= lookup xs
	go [] []	= case v of
		Assoc _ a	-> Just a
		Empty		-> Nothing
	go [] (_:_)	= Nothing

-- | Run an effectful, key-aware action over a value slot; an empty slot stays
-- empty with no effect.
traverseA :: Applicative f => ([k] -> a -> f b) -> Assoc k a -> f (Assoc k b)
traverseA f v = case v of
	Assoc ks a	-> fmap (Assoc ks) (f ks a)
	Empty		-> pure Empty

-- | Effectful traversal of an edge: the value slot's effect runs first, then
-- the children's.  The size is re-derived by 'edge'.
traverseE :: (Applicative f, TrieKey k, Sized b) => ([k] -> a -> f b) -> Edge k a -> f (Edge k b)
traverseE f (Edge _ ks v ts) = edge ks <$> value <*> children where
	value		= traverseA f v
	children	= traverseM (traverseE f) ts

-- | Right-fold step for a value slot: apply @f@ to the key and value, or
-- pass the accumulator through when the slot is empty.
foldrA :: ([k] -> a -> b -> b) -> Assoc k a -> b -> b
foldrA f v = case v of
	Assoc ks a	-> f ks a
	Empty		-> id

-- | Left-fold step for a value slot: apply @f@ to the key, accumulator and
-- value, or return the accumulator when the slot is empty.
foldlA :: ([k] -> b -> a -> b) -> b -> Assoc k a -> b
foldlA f z v = case v of
	Assoc ks a	-> f ks z a
	Empty		-> z

-- | Right fold over an edge, visiting the edge's own value before its
-- children.
foldrE :: TrieKey k => ([k] -> a -> b -> b) -> Edge k a -> b -> b
foldrE f (Edge _ _ v ts) z = foldrA f v rest where
	rest = foldr (foldrE f) z ts

-- | Left fold over an edge: the edge's own value is folded in first, then
-- each child in turn.
foldlE :: TrieKey k => ([k] -> b -> a -> b) -> b -> Edge k a -> b 
foldlE f z (Edge _ _ v ts) = foldl (foldlE f) afterValue ts where
	afterValue = foldlA f z v

-- | Apply a key-aware function to a value slot, keeping its stored key.
mapWithKeyA :: ([k] -> a -> b) -> Assoc k a -> Assoc k b
mapWithKeyA f v = case v of
	Assoc ks a	-> Assoc ks (f ks a)
	Empty		-> Empty

-- | Map a key-aware function over every value below an edge.  The edge shape
-- is preserved and the cached size is recomputed by 'edge'.
mapWithKeyE :: (TrieKey k, Sized b) => ([k] -> a -> b) -> Edge k a -> Edge k b
mapWithKeyE f (Edge _ ks v ts) = edge ks v' ts' where
	v'	= mapWithKeyA f v
	ts'	= fmapM (mapWithKeyE f) ts

-- | Filter-and-map a value slot: a 'Nothing' result empties the slot.
mapMaybeA :: ([k] -> a -> Maybe b) -> Assoc k a -> Assoc k b
mapMaybeA f v = case v of
	Empty		-> Empty
	Assoc ks a	-> case f ks a of
		Just b	-> Assoc ks b
		Nothing	-> Empty

-- | Filter-and-map every value below an edge.  The result is re-compacted,
-- so it may shrink to 'Nothing'.
mapMaybeE :: (TrieKey k, Sized b) => ([k] -> a -> Maybe b) -> Edge k a -> MEdge k b
mapMaybeE f (Edge _ ks v ts) = compact (edge ks v' ts') where
	v'	= mapMaybeA f v
	ts'	= mapMaybeM (\ _ -> mapMaybeE f) ts

-- | Split a value slot in two.  Each side keeps the stored key when @f@
-- yields 'Just' for it, and is empty otherwise.
mapEitherA :: ([k] -> a -> (# Maybe b, Maybe c #)) -> Assoc k a -> (# Assoc k b, Assoc k c #)
mapEitherA f v = case v of
	Empty		-> (# Empty, Empty #)
	Assoc ks a	-> case f ks a of
		(# mb, mc #) -> (# maybe Empty (Assoc ks) mb, maybe Empty (Assoc ks) mc #)

-- | Split every value below an edge into two tries.  Each side is
-- re-compacted independently.
mapEitherE :: (TrieKey k, Sized b, Sized c) => ([k] -> a -> (# Maybe b, Maybe c #)) -> Edge k a ->
	(# MEdge k b, MEdge k c #)
mapEitherE f (Edge _ ks v ts) =
	case mapEitherA f v of
		(# va, vb #) -> case mapEitherM (\ _ -> mapEitherE f) ts of
			(# tsa, tsb #) ->
				(# compact (edge ks va tsa), compact (edge ks vb tsb) #)

-- | Merge two edges.  The labels are walked in lock-step, with three outcomes:
--
--   * they diverge after a common prefix: both remainders become siblings
--     under a new value-less edge labelled with that prefix;
--   * one label is a proper prefix of the other: the longer edge's remainder
--     is merged into the shorter edge's child map;
--   * they are equal: the value slots are combined with 'unionA' and the
--     children are merged recursively.
--
-- The left edge is always passed first to the recursive merges.
unionE :: (TrieKey k, Sized a) =>  ([k] -> a -> a -> Maybe a) -> Edge k a -> Edge k a -> MEdge k a
unionE f (Edge szK# ks0 vK tsK) (Edge szL# ls0 vL tsL) = go 0 ks0 ls0 where
	-- The left / right edge relabelled with the unmatched rest of its label.
	restK ks = Edge szK# ks vK tsK
	restL ls = Edge szL# ls vL tsL
	-- A fresh value-less edge over the first i shared label elements.
	fork i kids = Just (Edge (szK# +# szL#) (take i ks0) Empty (fromDistAscListM kids))
	go !i (k:ks) (l:ls) = case compare k l of
		EQ -> go (i + 1) ks ls
		LT -> fork i [(k, restK ks), (l, restL ls)]
		GT -> fork i [(l, restL ls), (k, restK ks)]
	go _ [] (l:ls) = compact (edge ks0 vK (alterM intoK l tsK)) where
		intoK = maybe (Just (restL ls)) (\ eK' -> unionE f eK' (restL ls))
	go _ (k:ks) [] = compact (edge ls0 vL (alterM intoL k tsL)) where
		intoL = maybe (Just (restK ks)) (\ eL' -> unionE f (restK ks) eL')
	go _ [] [] = compact (edge ls0 (unionA f vK vL) (unionM (\ _ -> unionE f) tsK tsL))

-- | Combine two value slots.  If both are full, @f@ decides (and may empty the
-- slot); otherwise whichever slot is full wins.
unionA :: ([k] -> a -> a -> Maybe a) -> Assoc k a -> Assoc k a -> Assoc k a
unionA f vK vL = case (vK, vL) of
	(Assoc ks a, Assoc _ b)	-> maybe Empty (Assoc ks) (f ks a b)
	(Empty, _)		-> vL
	(_, Empty)		-> vK

-- | Intersect two edges.  A mismatch in the labels means the result is empty.
-- When one label runs out first, the search continues inside the matching
-- child of that edge, and the consumed prefix is put back onto the result.
isectE :: (TrieKey k, Sized c) => ([k] -> a -> b -> Maybe c) -> Edge k a -> Edge k b -> MEdge k c
isectE f (Edge szK ks0 vK tsK) (Edge szL ls0 vL tsL) = go ks0 ls0 where
	go (k:ks) (l:ls)
		| k == l	= go ks ls
		| otherwise	= Nothing
	-- Right label exhausted: continue in the right edge's child at k.
	go (k:ks) [] = lookupM k tsL >>= \ eL' ->
		fmap (cat ls0 . cons k) (isectE f (Edge szK ks vK tsK) eL')
	-- Left label exhausted: continue in the left edge's child at l.
	go [] (l:ls) = lookupM l tsK >>= \ eK' ->
		fmap (cat ks0 . cons l) (isectE f eK' (Edge szL ls vL tsL))
	go [] [] = compact (edge ks0 (isectA f vK vL) (isectM (\ _ -> isectE f) tsK tsL))

-- | Intersect two value slots: only two full slots can produce a value, and
-- @f@ may still empty it.
isectA :: ([k] -> a -> b -> Maybe c) -> Assoc k a -> Assoc k b -> Assoc k c
isectA f va vb = case (va, vb) of
	(Assoc ks a, Assoc _ b)	-> maybe Empty (Assoc ks) (f ks a b)
	_			-> Empty

-- | Subtract the right edge from the left one.  When the right edge covers
-- none of the left edge's keys, the left edge is returned unchanged.
diffE :: (TrieKey k, Sized a) =>  ([k] -> a -> b -> Maybe a) -> Edge k a -> Edge k b -> MEdge k a
diffE f eK@(Edge szK ks0 vK tsK) (Edge szL ls0 vL tsL) = go ks0 ls0 where
	go (k:ks) (l:ls)
		| k == l	= go ks ls
		| otherwise	= Just eK
	-- Right label exhausted: only the right edge's child at k can overlap.
	go (k:ks) [] = case lookupM k tsL of
		Just eL'	-> fmap (cat ls0 . cons k) (diffE f (Edge szK ks vK tsK) eL')
		Nothing		-> Just eK
	-- Left label exhausted: subtract the right remainder from the child at l.
	go [] (l:ls) = compact (edge ks0 vK (alterM (>>= subtractL) l tsK)) where
		subtractL eK' = diffE f eK' (Edge szL ls vL tsL)
	go [] [] = compact (edge ks0 (diffA f vK vL) (diffM (\ _ -> diffE f) tsK tsL))

-- | Subtract one value slot from another.  When both are full, @f@ decides;
-- a full left slot with an empty right slot is kept as is.
diffA :: ([k] -> a -> b -> Maybe a) -> Assoc k a -> Assoc k b -> Assoc k a
diffA f va vb = case va of
	Empty		-> Empty
	Assoc ks a	-> case vb of
		Assoc _ b	-> maybe Empty (Assoc ks) (f ks a b)
		Empty		-> va

-- | Is every entry of the left edge present in the right edge, with values
-- related by the given comparison?  A left label that is a proper prefix of
-- the right label, or a diverging label, answers 'False'.
isSubmapE :: TrieKey k => LEq a b -> LEq (Edge k a) (Edge k b)
isSubmapE le (Edge szK ks0 vK tsK) (Edge _ ls0 vL tsL) = go ks0 ls0 where
	go (k:ks) (l:ls)	= k == l && go ks ls
	go (k:ks) []		= case lookupM k tsL of
		Just eL'	-> isSubmapE le (Edge szK ks vK tsK) eL'
		Nothing		-> False
	go [] []		= subA le vK vL && isSubmapM (isSubmapE le) tsK tsL
	go [] (_:_)		= False

-- | Submap test for value slots.  An empty left slot always passes; a full
-- left slot needs a full right slot related by the comparison.
subA :: LEq a b -> LEq (Assoc k a) (Assoc k b)
subA le va vb = case va of
	Empty		-> True
	Assoc _ a	-> case vb of
		Assoc _ b	-> le a b
		Empty		-> False

-- | Search for a key suffix starting at an edge, recording the descent in
-- @path@.  Returns the value stored under the key (if any), plus a function
-- that, given the full key, builds the 'Hole' where that key lives or would
-- be inserted.  The full key is supplied afterwards because only suffixes
-- are passed down the recursion.
searchE :: TrieKey k => [k] -> Edge k a -> Path k a -> (# Maybe a, [k] -> Hole [k] a #)
searchE ks0 (Edge sz ls0 v ts) path = match 0 ks0 ls0 where
	-- The first argument of match counts the label elements matched so far.
	-- Key and label end together: the hole is this edge itself.
	match !_ [] [] = (# assocToMaybe v, \ k0 -> Hole k0 ls0 ts path #)
	-- Label exhausted but key continues: descend into the child at k.  If
	-- there is none, the hole sits in the child map, labelled with the rest
	-- of the key and with no children of its own.
	match _ (k:ks) [] = case searchM k ts of
		(# Just e', tHole #) -> searchE ks e' (Deep path ls0 v tHole)
		(# Nothing, tHole #) -> (# Nothing, \ k0 -> Hole k0 ks emptyM (Deep path ls0 v tHole) #)
	-- Key exhausted inside the label: the hole takes the matched prefix as its
	-- label, and the unmatched rest of this edge becomes its only child.
	match i [] (l:ls) = (# Nothing, \ k0 -> Hole k0 (take i ls0) (singletonM l (Edge sz ls v ts)) path #)
	match i (k:ks) (l:ls)
		| k == l	= match (i+1) ks ls
		-- Key and label diverge after i elements: the hole becomes a sibling
		-- of the rest of this edge, under a value-less edge labelled with the
		-- shared prefix.
		| (# _, kHole #) <- searchM k (singletonM l (Edge sz ls v ts))
				= (# Nothing, \ k0 -> Hole k0 ks emptyM (Deep path (take i ls0) Empty kHole) #)

-- | The value in a slot, if the slot is full.
assocToMaybe :: Assoc k a -> Maybe a
assocToMaybe v = case v of
	Assoc _ a	-> Just a
	Empty		-> Nothing

-- | Locate the entry at position @i#@ (measured in sizes) below an edge.  The
-- edge's own value comes first, then its children.  Returns:
--
--   * the remaining offset inside the found value;
--   * the value itself;
--   * the hole at its key.
indexE :: (TrieKey k, Sized a) => Int# -> Edge k a -> Path k a -> (# Int#, a, Hole [k] a #)
indexE i# (Edge _ ks v ts) path = case v of
	Empty		-> descend i# Empty
	Assoc ks0 a	-> case getSize# a of
		sa#
			| i# <# sa#	-> (# i#, a, Hole ks0 ks ts path #)
			| otherwise	-> descend (i# -# sa#) v
	where
		-- Continue in the child map, remembering this edge on the path.
		descend j# w = case indexM j# ts of
			(# i'#, e, tHole #) -> indexE i'# e (Deep path ks w tHole)

-- | Enumerate every value below an edge, in order, each with its hole.  The
-- edge's own value comes before those of its children.
extractHoleE :: (TrieKey k, MonadPlus m) => Path k a -> Edge k a -> m (a, Hole [k] a)
extractHoleE path (Edge _ ks v ts) = case v of
	Empty		-> below
	Assoc ks0 a	-> return (a, Hole ks0 ks ts path) `mplus` below
	where
		below = extractHoleM ts >>= \ (e, tHole) ->
			extractHoleE (Deep path ks v tHole) e