{-# LANGUAGE TypeOperators, MultiParamTypeClasses, FunctionalDependencies, UndecidableInstances, PatternGuards #-}

module TrieMap.TrieAlgebraic (TrieKey (..), ProdMap (..), UnionMap(..), RadixTrie(..), Edge (..), Ordered (..), unionMaybe, intersectMaybe, differenceMaybe, mapWithKeyAlg, assocsAlg, insertAlg, alterAlg, fromListAlg') where

import Data.Traversable
import Data.Foldable
import Data.Either
import Data.Sequence (Seq)
import Data.Maybe
import Data.Monoid
import Data.IntMap (IntMap)
import Data.Map (Map)
import qualified Data.Sequence as Seq
import qualified Data.IntMap as IMap
import qualified Data.Map as Map

import Control.Monad
import Control.Applicative hiding (Alternative(..))

import GHC.Exts (build)

import TrieMap.Applicative
-- import TrieMap.Algebraic (Ordered (..))
import TrieMap.MapTypes
import Prelude hiding (foldr, foldl, all, any)

-- | Wraps a key that is stored as-is in an ordinary ordered 'Map' (using the
-- wrapped key's 'Ord' instance) instead of being decomposed into a trie.
newtype Ordered k = Ord {unOrd :: k}
	deriving (Eq, Ord)

-- | Shows exactly like the wrapped key, without any constructor.
instance Show k => Show (Ordered k) where
	showsPrec d (Ord k) = showsPrec d k
	show (Ord k) = show k

-- | Applies the function to the wrapped key.
instance Functor Ordered where
	fmap f = Ord . f . unOrd

-- | TrieKey defines a bijection between map types and algebraic key types.
--
-- The functional dependencies @a -> m@ and @m -> a@ tie each key type to
-- exactly one map type and back.  'lookupAlg', 'guardNullAlg', 'sizeAlg',
-- 'fromListAlg', 'fromAscListAlg', 'fromDistAscListAlg', 'updateMinAlg',
-- 'updateMaxAlg' and 'valid' have defaults below; every other method must be
-- supplied by each instance.
class (Eq a, Foldable m, Traversable m) => TrieKey a m | a -> m, m -> a where
	-- | The map with no bindings.
	emptyAlg :: m v
	-- | Whether the map has no bindings.
	nullAlg :: m v -> Bool
	-- | Number of bindings.
	sizeAlg :: m v -> Int
	-- | @Just (k, v)@ if the map holds exactly one binding, else 'Nothing'.
	getSingleAlg :: m v -> Maybe (a, v)
	-- | 'Nothing' for an empty map, @Just m@ otherwise.
	guardNullAlg :: m v -> Maybe (m v)
	-- | Runs the function on the value currently bound to the key ('Nothing'
	-- if absent).  Its 'Maybe' result becomes the new value ('Nothing'
	-- removes the binding) and its @b@ result is returned with the new map.
	alterLookupAlg :: (Maybe v -> (b, Maybe v)) -> a -> m v -> (b, m v)
	-- | The value bound to the key, if any.
	lookupAlg :: a -> m v -> Maybe v
	-- | Right fold over all bindings, passing each key with its value.
	foldWithKeyAlg :: (a -> v -> x -> x) -> x -> m v -> x
	-- | Applicative traversal of all values, with access to their keys.
	mapAppAlg :: Applicative f => (a -> v -> f w) -> m v -> f (m w)
	-- | Keeps the bindings for which the function returns 'Just'.
	mapMaybeAlg :: (a -> v -> Maybe w) -> m v -> m w
	-- | Splits the map by the function's result: 'Left' results go to the
	-- first map, 'Right' results to the second.
	mapEitherAlg :: (a -> v -> Either x y) -> m v -> (m x, m y)
	-- | Union.  At keys bound in both maps the function chooses the value;
	-- 'Nothing' drops the binding.
	unionMaybeAlg :: (a -> v -> v -> Maybe v) -> m v -> m v -> m v
	-- | Intersection.  The function combines the two values; 'Nothing' drops
	-- the binding.
	intersectAlg :: (a -> v -> w -> Maybe x) -> m v -> m w -> m x
	-- | Difference.  Bindings of the first map whose key is absent from the
	-- second are kept; otherwise the function decides ('Nothing' drops).
	differenceAlg :: (a -> v -> w -> Maybe v) -> m v -> m w -> m v
	-- | Builds a map from bindings given in ascending key order with distinct
	-- keys.
	fromDistAscListAlg :: [(a, v)] -> m v
	-- | Builds a map from bindings given in ascending key order; bindings with
	-- equal keys are combined with the function.
	fromAscListAlg :: (a -> v -> v -> v) -> [(a, v)] -> m v
	-- | Builds a map from bindings in any order; bindings with equal keys are
	-- combined with the function.
	fromListAlg :: (a -> v -> v -> v) -> [(a, v)] -> m v
	-- | The minimum binding and the rest of the map, or 'Nothing' if empty.
	getMinAlg :: m v -> Maybe ((a, v), m v)
	-- | The maximum binding and the rest of the map, or 'Nothing' if empty.
	getMaxAlg :: m v -> Maybe ((a, v), m v)
	-- | Runs the function on the minimum binding.  It returns a flag and the
	-- replacement value ('Nothing' removes the binding); the result pairs the
	-- flag with the updated map, or is @(False, m)@ when @m@ is empty.
	updateMinAlg :: (a -> v -> (Bool, Maybe v)) -> m v -> (Bool, m v)
	-- | As 'updateMinAlg', but for the maximum binding.
	updateMaxAlg :: (a -> v -> (Bool, Maybe v)) -> m v -> (Bool, m v)
	-- | Checks the instance's structural invariants.
	valid :: m v -> Bool
	-- | Whether every binding of the first map has its key bound in the second
	-- map, with the two values related by the predicate.
	isSubmapAlg :: (v -> w -> Bool) -> m v -> m w -> Bool
	-- | Splits the map around the key into the bindings below it, a middle
	-- result, and the bindings above it.  If the key is bound, the function
	-- receives its value and returns (value to keep in the lower half, middle
	-- result, value to keep in the upper half); otherwise the middle result
	-- is 'Nothing'.
	splitLookupAlg :: (v -> (Maybe v, Maybe x, Maybe v)) -> a -> m v -> (m v, Maybe x, m v)

	-- Default: an alteration that leaves the value unchanged and reports it.
	lookupAlg k = fst . alterLookupAlg (\ v -> (v, v)) k
	guardNullAlg m
		| nullAlg m	= Nothing
		| otherwise	= Just m
	-- Default: a right fold, so duplicate keys combine as
	-- @f k v1 (f k v2 v3)@, with the earlier value as the first value argument.
	-- NOTE(review): the Int and Map instances delegate to containers'
	-- @fromListWithKey@, which passes the later value first; confirm which
	-- order callers rely on.
	fromListAlg f = foldr (\ (k, v) -> alterAlg (Just . maybe v (f k v)) k) emptyAlg
	-- Default: merge runs of equal keys, then defer to 'fromDistAscListAlg'.
	-- Duplicates combine left-nested, @f k (f k v1 v2) v3@ (earlier first).
	fromAscListAlg _ [] = emptyAlg
	fromAscListAlg f ((k, v):xs) = fromDistAscListAlg (distinct k v xs) where
		distinct k v ((k', v'):xs)
			| k == k'	= distinct k (f k v v') xs
			| otherwise	= (k, v):distinct k' v' xs
		distinct k v [] = [(k, v)]
	-- Default: keys are distinct, so the combining function never matters.
	fromDistAscListAlg = fromListAlg'
	-- Default: count elements with a strict left fold.
	sizeAlg = foldl' (\ n _ -> n + 1) 0

	-- Defaults: view the extreme binding, then either write the replacement
	-- back at the same key or keep the view's remainder (binding removed).
	updateMinAlg f m = maybe (False, m) (\ ((k, v), m') -> maybe m' (\ v' -> insertAlg k v' m) <$> f k v) (getMinAlg m)
	updateMaxAlg f m = maybe (False, m) (\ ((k, v), m') -> maybe m' (\ v' -> insertAlg k v' m) <$> f k v) (getMaxAlg m)
	-- Default: no invariants to check beyond forcing the map to WHNF.
	valid = (`seq` True)

-- | Builds a map from bindings in any order.  For duplicate keys, the
-- combining function passed to 'fromListAlg' always returns its first value
-- argument.
fromListAlg' :: TrieKey k m => [(k, v)] -> m v
fromListAlg' = fromListAlg (\ _ x _ -> x)

-- | A map holding exactly one binding.
singletonAlg :: TrieKey k m => k -> v -> m v
singletonAlg key val = alterAlg (\ _ -> Just val) key emptyAlg

-- | Maps every value, with access to its key.  Implemented as 'mapAppAlg'
-- run in 'Id' (presumably the identity applicative from TrieMap.Applicative).
mapWithKeyAlg :: TrieKey k m => (k -> v -> w) -> m v -> m w
mapWithKeyAlg f = unId . mapAppAlg (\ k -> Id . f k)

-- mapMaybeWithKeyAlg :: TrieKey k m => (k -> v -> Maybe w) -> m v -> m w
-- mapMaybeWithKeyAlg f m = unId (mapAppMaybeAlg (\ k v -> Id (f k v)) m)

-- | Binds the key to the value, replacing any existing binding.
insertAlg :: TrieKey k m => k -> v -> m v -> m v
insertAlg k v m = alterAlg (\ _ -> Just v) k m

-- | Replaces the value at the key with the function's result applied to the
-- current value ('Nothing' if absent); a 'Nothing' result removes it.
alterAlg :: TrieKey k m => (Maybe v -> Maybe v) -> k -> m v -> m v
alterAlg f k m = snd (alterLookupAlg (\ old -> ((), f old)) k m)

-- alterLookupAlg :: TrieKey k m => (Maybe a -> (b, Maybe a)) -> k -> m a -> (b, m a)
-- alterLookupAlg f = unId .: alterAppAlg (Id . f)

-- | 'foldWithKeyAlg' with the keys ignored.
foldrAlg :: TrieKey k m => (a -> b -> b) -> b -> m a -> b
foldrAlg f = foldWithKeyAlg (\ _ -> f)

-- | Union of two optional values: when both are present the function
-- decides the result (possibly 'Nothing'); otherwise the present one wins.
unionMaybe :: (a -> a -> Maybe a) -> Maybe a -> Maybe a -> Maybe a
unionMaybe f mx my = case (mx, my) of
	(Just x, Just y)	-> f x y
	(Nothing, _)	-> my
	(_, Nothing)	-> mx

-- | Intersection of two optional values: the function is applied only when
-- both are present; otherwise the result is 'Nothing'.
intersectMaybe :: (a -> b -> Maybe c) -> Maybe a -> Maybe b -> Maybe c
intersectMaybe f mx my = do
	x <- mx
	y <- my
	f x y

-- | Difference of two optional values: an absent first value stays absent,
-- a present one survives an absent second value, and when both are present
-- the function decides.
differenceMaybe :: (a -> b -> Maybe a) -> Maybe a -> Maybe b -> Maybe a
differenceMaybe f mx my = mx >>= \ x -> maybe (Just x) (f x) my

-- | The payload of a 'Left', or 'Nothing' for a 'Right'.  The first argument
-- is ignored (presumably so this can be passed where a key-taking function
-- such as 'mapMaybeAlg's argument is expected — confirm with callers).
filterLeft :: a -> Either b c -> Maybe b
filterLeft _ = either Just (const Nothing)

-- | The payload of a 'Right', or 'Nothing' for a 'Left'.  The first argument
-- is ignored (presumably so this can be passed where a key-taking function
-- such as 'mapMaybeAlg's argument is expected — confirm with callers).
filterRight :: a -> Either b c -> Maybe c
filterRight _ = either (const Nothing) Just

-- | All bindings as a list, in 'foldWithKeyAlg' order.  The list is produced
-- with 'build' so it can take part in foldr/build fusion.
{-# INLINE assocsAlg #-}
assocsAlg :: TrieKey k m => m a -> [(k, a)]
assocsAlg m = build (\ cons nil -> foldWithKeyAlg (\ key val rest -> cons (key, val) rest) nil m)

-- | Pair keys: an outer map on the first component whose values are inner
-- maps on the second component.
--
-- Invariant (checked by 'valid'): no inner map is empty.  Every operation
-- that can shrink an inner map therefore passes it through 'guardNullAlg'
-- so emptied inner maps are removed from the outer map.
instance (Eq a1, Eq a2, TrieKey a1 m1, TrieKey a2 m2) => TrieKey (a1, a2) (m1 `ProdMap` m2) where
	emptyAlg = PMap emptyAlg
	nullAlg (PMap m) = nullAlg m
	-- Total size is the sum of the inner maps' sizes.
	sizeAlg (PMap m) = foldl' (\ n m -> n + sizeAlg m) 0 m
	getSingleAlg (PMap m) = do
		(k1, m') <- getSingleAlg m
		(k2, v) <- getSingleAlg m'
		return ((k1, k2), v)
	-- Alter inside the inner map for k1 (starting from an empty one if k1 is
	-- absent) and drop that inner map if the alteration emptied it.
	alterLookupAlg f (k1, k2) (PMap m) = PMap <$> alterLookupAlg g k1 m
		where g = fmap guardNullAlg . alterLookupAlg f k2 . fromMaybe emptyAlg
	lookupAlg (k1, k2) (PMap m) = lookupAlg k1 m >>= lookupAlg k2
	foldWithKeyAlg f z (PMap m) = foldWithKeyAlg (\ k1 -> flip (foldWithKeyAlg (\ k2 -> f (k1, k2)))) z m
	mapAppAlg f (PMap m) =
		PMap <$> mapAppAlg (\ k1 -> mapAppAlg (\ k2 -> f (k1, k2))) m
	mapMaybeAlg f (PMap m) =
		PMap $ mapMaybeAlg (\ k1 -> guardNullAlg . mapMaybeAlg (\ k2 -> f (k1, k2))) m
	-- Split every inner map, then discard inner maps that came out empty on
	-- either side so both results keep the no-empty-inner-map invariant.
	mapEitherAlg f (PMap m) =
		(PMap (mapMaybeAlg (\ _ -> guardNullAlg . fst) m'), PMap (mapMaybeAlg (\ _ -> guardNullAlg . snd) m'))
		where m' = mapWithKeyAlg (\ k1 -> mapEitherAlg (\ k2 -> f (k1, k2))) m
	unionMaybeAlg f (PMap m1) (PMap m2) =
		PMap (unionMaybeAlg (\ k1 -> guardNullAlg .: unionMaybeAlg (\ k2 -> f (k1, k2))) m1 m2)
	intersectAlg f (PMap m1) (PMap m2) =
		PMap (intersectAlg (\ k1 -> guardNullAlg .: intersectAlg (\ k2 -> f (k1, k2))) m1 m2)
	differenceAlg f (PMap m1) (PMap m2) =
		PMap (differenceAlg (\ k1 -> guardNullAlg .: differenceAlg (\ k2 -> f (k1, k2))) m1 m2)
	fromListAlg f xs = PMap $ mapWithKeyAlg (\ k1 -> fromListAlg (\ k2 -> f (k1, k2))) $
		fromListAlg (const (++)) [(k1, [(k2, v)]) | ((k1, k2), v) <- xs]
	-- Ascending pairs group into consecutive runs with equal first components.
	fromDistAscListAlg xs = PMap $ fromDistAscListAlg [(k1, fromDistAscListAlg ys) | (k1, ys) <- breakFst xs]
	fromAscListAlg f xs = PMap $ fromDistAscListAlg [(k1, fromAscListAlg (\ k2 -> f (k1, k2)) ys) | (k1, ys) <- breakFst xs]
	getMinAlg (PMap m) = do
		((k1, m'), m1') <- getMinAlg m
		((k2, v), m2') <- getMinAlg m'
		return (((k1, k2), v), PMap (maybe m1' (\ m2' -> insertAlg k1 m2' m) (guardNullAlg m2')))
	getMaxAlg (PMap m) = do
		((k1, m'), m1') <- getMaxAlg m
		((k2, v), m2') <- getMaxAlg m'
		return (((k1, k2), v), PMap (maybe m1' (\ m2' -> insertAlg k1 m2' m) (guardNullAlg m2')))
	updateMinAlg f (PMap m) =
		PMap <$> updateMinAlg (\ k1 -> guardNullAlg <.> updateMinAlg (\ k2 -> f (k1, k2))) m
	updateMaxAlg f (PMap m) =
		PMap <$> updateMaxAlg (\ k1 -> guardNullAlg <.> updateMaxAlg (\ k2 -> f (k1, k2))) m
	isSubmapAlg (<=) (PMap m1) (PMap m2) =
		isSubmapAlg (isSubmapAlg (<=)) m1 m2

	splitLookupAlg f (k1, k2) (PMap m) = case splitLookupAlg g k1 m of
			(mL, ans, mR)	-> (PMap mL, ans, PMap mR)
		where g m' = case splitLookupAlg f k2 m' of
			(mL, ans, mR)	-> (guardNullAlg mL, ans, guardNullAlg mR)

	valid (PMap m) = valid m && all valid m && not (any nullAlg m)

-- | Groups consecutive bindings whose keys share the same first component,
-- pairing that component with the run's (second component, value) list in
-- original order.  Non-adjacent runs with equal first components stay
-- separate groups.
breakFst :: (Eq k1, Eq k2) => [((k1, k2), v)] -> [(k1, [(k2, v)])]
breakFst [] = []
breakFst (((k1, k2), x) : rest) = (k1, (k2, x) : map dropFst run) : breakFst others
	where
		(run, others) = span (\ ((k', _), _) -> k1 == k') rest
		dropFst ((_, k2'), x') = (k2', x')

-- | 'Either' keys: bindings with 'Left' keys live in the first component map
-- and bindings with 'Right' keys in the second.  The min/max and split
-- methods treat every 'Left' key as smaller than every 'Right' key.
instance (TrieKey a1 m1, TrieKey a2 m2) => TrieKey (Either a1 a2) (m1 `UnionMap` m2) where
	emptyAlg = emptyAlg :+: emptyAlg
	nullAlg (m1 :+: m2) = nullAlg m1 && nullAlg m2
	sizeAlg (m1 :+: m2) = sizeAlg m1 + sizeAlg m2
	-- A singleton needs one side empty and the other side a singleton.  The
	-- other side's 'getSingleAlg' returning 'Nothing' is not enough: it also
	-- returns 'Nothing' when that side holds two or more bindings.
	getSingleAlg (m1 :+: m2)
		| nullAlg m2	= fmap (\ (k, v) -> (Left k, v)) (getSingleAlg m1)
		| nullAlg m1	= fmap (\ (k, v) -> (Right k, v)) (getSingleAlg m2)
		| otherwise	= Nothing
	alterLookupAlg f (Left k) (m1 :+: m2) =
		fmap (:+: m2) $ alterLookupAlg f k m1
	alterLookupAlg f (Right k) (m1 :+: m2) =
		fmap (m1 :+:) $ alterLookupAlg f k m2
	lookupAlg k (m1 :+: m2) = either (`lookupAlg` m1) (`lookupAlg` m2) k
	foldWithKeyAlg f z (m1 :+: m2) = foldWithKeyAlg (f . Left) (foldWithKeyAlg (f . Right) z m2) m1
	mapAppAlg f (m1 :+: m2) =
		liftA2 (:+:) (mapAppAlg (f . Left) m1) (mapAppAlg (f . Right) m2)
	mapMaybeAlg f (m1 :+: m2) = mapMaybeAlg (f . Left) m1 :+: mapMaybeAlg (f . Right) m2
	mapEitherAlg f (m1 :+: m2) = (m1L :+: m2L, m1R :+: m2R)
		where
			(m1L, m1R) = mapEitherAlg (f . Left) m1
			(m2L, m2R) = mapEitherAlg (f . Right) m2
	unionMaybeAlg f (m11 :+: m12) (m21 :+: m22)
		= unionMaybeAlg (f . Left) m11 m21 :+: unionMaybeAlg (f . Right) m12 m22
	intersectAlg f (m11 :+: m12) (m21 :+: m22)
		= intersectAlg (f . Left) m11 m21 :+: intersectAlg (f . Right) m12 m22
	differenceAlg f (m11 :+: m12) (m21 :+: m22)
		= differenceAlg (f . Left) m11 m21 :+: differenceAlg (f . Right) m12 m22
	fromListAlg f xs = fromListAlg (f . Left) ys :+: fromListAlg (f . Right) zs
		where (ys, zs) = partitionEithers (map pullEither xs)
	fromAscListAlg f xs = fromAscListAlg (f . Left) ys :+: fromAscListAlg (f . Right) zs
		where (ys, zs) = partitionEithers (map pullEither xs)
	fromDistAscListAlg xs = fromDistAscListAlg ys :+: fromDistAscListAlg zs
		where (ys, zs) = partitionEithers (map pullEither xs)
	-- 'Left' keys come first: the minimum is in m1 unless m1 is empty.
	getMinAlg (m1 :+: m2)
		| Just ((k, v), m1') <- getMinAlg m1
			= Just ((Left k, v), m1' :+: m2)
		| Just ((k, v), m2') <- getMinAlg m2
			= Just ((Right k, v), m1 :+: m2')
	getMinAlg _ = Nothing
	-- 'Right' keys come last: the maximum is in m2 unless m2 is empty.
	getMaxAlg (m1 :+: m2)
		| Just ((k, v), m2') <- getMaxAlg m2
			= Just ((Right k, v), m1 :+: m2')
		| Just ((k, v), m1') <- getMaxAlg m1
			= Just ((Left k, v), m1' :+: m2)
	getMaxAlg _ = Nothing
	updateMinAlg f (m1 :+: m2)
		| nullAlg m1	= fmap (m1 :+:) (updateMinAlg (f . Right) m2)
		| otherwise	= fmap (:+: m2) (updateMinAlg (f . Left) m1)
	updateMaxAlg f (m1 :+: m2)
		| nullAlg m2	= fmap (:+: m2) (updateMaxAlg (f . Left) m1)
		| otherwise	= fmap (m1 :+:) (updateMaxAlg (f . Right) m2)
	isSubmapAlg (<=) (m11 :+: m12) (m21 :+: m22) =
		isSubmapAlg (<=) m11 m21 && isSubmapAlg (<=) m12 m22
	valid (m1 :+: m2) = valid m1 && valid m2
	-- Splitting at a 'Left' key puts all of m2 above it; splitting at a
	-- 'Right' key puts all of m1 below it.
	splitLookupAlg f (Left k) (m1 :+: m2) = case splitLookupAlg f k m1 of
		(m1L, ans, m1R)	-> (m1L :+: emptyAlg, ans, m1R :+: m2)
	splitLookupAlg f (Right k) (m1 :+: m2) = case splitLookupAlg f k m2 of
		(m2L, ans, m2R)	-> (m1 :+: m2L, ans, emptyAlg :+: m2R)

-- | Moves the 'Either' tag from a binding's key onto the whole binding, so a
-- list of bindings can be split by key side with 'partitionEithers'.
pullEither :: (Either k1 k2, v) -> Either (k1, v) (k2, v)
pullEither (k, v) = either (\ l -> Left (l, v)) (\ r -> Right (r, v)) k

-- | 'Int' keys are stored directly in an 'IntMap'.
instance TrieKey Int IntMap where
	emptyAlg = IMap.empty
	nullAlg = IMap.null
	sizeAlg = IMap.size
	-- 'IMap.size' is O(n) on an IntMap, so do not ask for the size just to
	-- detect a singleton; matching the association list suffices.
	getSingleAlg m = case IMap.toList m of
		[(k, v)]	-> Just (k, v)
		_		-> Nothing
	lookupAlg = IMap.lookup
	alterLookupAlg f k m = fmap (\ v' -> IMap.alter (const v') k m) (f x)
		where x = IMap.lookup k m
	foldWithKeyAlg = IMap.foldWithKey
	mapAppAlg = sequenceA .: IMap.mapWithKey
	mapMaybeAlg = IMap.mapMaybeWithKey
	mapEitherAlg = IMap.mapEitherWithKey
	-- Values are tagged 'Left' before the union; a collision becomes
	-- @Right (f k v1 v2)@, and the final 'IMap.mapMaybe' keeps untouched
	-- values and drops collisions for which @f@ returned 'Nothing'.
	unionMaybeAlg f m1 m2 = IMap.mapMaybe (either Just id) (IMap.unionWithKey g (fmap Left m1) (fmap Left m2)) where
		g k (Left v1) (Left v2) = Right (f k v1 v2)
		g k (Right v) _ = Right v
		g k _ (Right v) = Right v
	-- Same tagging trick: 'Right' marks a binding for which @f@ produced a
	-- result; anything still tagged 'Left' is dropped at the end.
	intersectAlg f m1 m2 = IMap.mapMaybe (either (const Nothing) Just) $ IMap.intersectionWithKey g (fmap Left m1) m2 where
		g k (Left x) = maybe (Left x) Right . f k x
		g _ (Right x) = const (Right x)
	differenceAlg = IMap.differenceWithKey
	fromListAlg = IMap.fromListWithKey
	fromAscListAlg = IMap.fromAscListWithKey
	fromDistAscListAlg = IMap.fromDistinctAscList
	getMinAlg = IMap.minViewWithKey
	getMaxAlg = IMap.maxViewWithKey
	-- A replacement value is written back with 'IMap.insert' at the extreme
	-- key (rather than via 'IMap.updateMin'/'IMap.updateMax', whose argument
	-- type differs between containers releases); 'Nothing' keeps the view's
	-- remainder, i.e. removes the binding.
	updateMinAlg f m = case IMap.minViewWithKey m of
		Just ((k, v), m')	-> let (ans, mv) = f k v in (ans, maybe m' (\ v' -> IMap.insert k v' m) mv)
		_			-> (False, m)
	updateMaxAlg f m = case IMap.maxViewWithKey m of
		Just ((k, v), m')	-> let (ans, mv) = f k v in (ans, maybe m' (\ v' -> IMap.insert k v' m) mv)
		_			-> (False, m)
	isSubmapAlg = IMap.isSubmapOfBy
	splitLookupAlg f k m = case IMap.splitLookup k m of
		(mL, Nothing, mR)	-> (mL, Nothing, mR)
		(mL, Just v, mR) -> case f v of
			(vL, ans, vR)	-> (maybe mL (flip (IMap.insert k) mL) vL, ans, maybe mR (flip (IMap.insert k) mR) vR)

-- | 'Ordered' keys are stored in an ordinary 'Map' keyed by the wrapped
-- value; the 'Ord' wrapper is stripped on the way in and restored on the way
-- out.
instance Ord k => TrieKey (Ordered k) (Map k) where
	emptyAlg = Map.empty
	nullAlg = Map.null
	sizeAlg = Map.size
	getSingleAlg m
		| Map.size m == 1, (k, v) <- Map.findMin m
			= Just (Ord k, v)
	-- Empty and multi-binding maps are not singletons.
	getSingleAlg _ = Nothing
	lookupAlg = Map.lookup . unOrd
	alterLookupAlg f (Ord k) m = fmap (\ v -> Map.alter (const v) k m) (f x)
		where x = Map.lookup k m
	foldWithKeyAlg f = Map.foldWithKey (f . Ord)
	mapAppAlg f = sequenceA . Map.mapWithKey (f . Ord)
	mapMaybeAlg f = Map.mapMaybeWithKey (f . Ord)
	mapEitherAlg f = Map.mapEitherWithKey (f . Ord)
	-- Values are tagged 'Left' before the union; a collision becomes
	-- @Right (f k v1 v2)@, and the final 'Map.mapMaybe' keeps untouched
	-- values and drops collisions for which @f@ returned 'Nothing'.
	unionMaybeAlg f m1 m2 = Map.mapMaybe (either Just id) (Map.unionWithKey g (fmap Left m1) (fmap Left m2)) where
		g k (Left v1) (Left v2) = Right (f (Ord k) v1 v2)
		g k (Right v) _ = Right v
		g k _ (Right v) = Right v
	intersectAlg f = Map.mapMaybe id .: Map.intersectionWithKey (f . Ord)
	differenceAlg f = Map.differenceWithKey (f . Ord)
	fromListAlg f xs = Map.fromListWithKey (f . Ord) [(k, v) | (Ord k, v) <- xs]
	fromAscListAlg f xs = Map.fromAscListWithKey (f . Ord) [(k, v) | (Ord k, v) <- xs]
	fromDistAscListAlg xs = Map.fromDistinctAscList [(k, v) | (Ord k, v) <- xs]
	getMinAlg m = do
		~(~(k, v), m') <- Map.minViewWithKey m
		return ((Ord k, v), m')
	getMaxAlg m = do
		~(~(k, v), m') <- Map.maxViewWithKey m
		return ((Ord k, v), m')
	updateMinAlg f m
		| Map.null m	= (False, m)
		| otherwise	= case Map.findMin m of
			(k, v)	-> let (ans, v') = f (Ord k) v in (ans, Map.updateMin (const v') m)
	-- The function must see the maximum binding, i.e. the one being updated.
	updateMaxAlg f m
		| Map.null m	= (False, m)
		| otherwise	= case Map.findMax m of
			(k, v)	-> let (ans, v') = f (Ord k) v in (ans, Map.updateMax (const v') m)
	isSubmapAlg = Map.isSubmapOfBy
	splitLookupAlg f (Ord k) m = case Map.splitLookup k m of
		(mL, Nothing, mR)	-> (mL, Nothing, mR)
		(mL, Just v, mR) -> case f v of
			(vL, ans, vR) -> (maybe mL (flip (Map.insert k) mL) vL, ans, maybe mR (flip (Map.insert k) mR) vR)

-- | Unit keys: a map keyed by @()@ holds at most one value, so a 'Maybe' is
-- the whole map.
instance TrieKey () Maybe where
	emptyAlg = Nothing
	nullAlg Nothing = True
	nullAlg (Just _) = False
	sizeAlg Nothing = 0
	sizeAlg (Just _) = 1
	getSingleAlg mv = do
		v <- mv
		return ((), v)
	lookupAlg _ mv = mv
	alterLookupAlg f _ mv = f mv
	foldWithKeyAlg f z mv = maybe z (\ v -> f () v z) mv
	mapAppAlg _ Nothing = pure Nothing
	mapAppAlg f (Just v) = fmap Just (f () v)
	mapMaybeAlg f mv = mv >>= f ()
	mapEitherAlg f mv = case mv of
		Nothing	-> (Nothing, Nothing)
		Just v	-> either (\ l -> (Just l, Nothing)) (\ r -> (Nothing, Just r)) (f () v)
	unionMaybeAlg f m1 m2 = unionMaybe (f ()) m1 m2
	intersectAlg f m1 m2 = intersectMaybe (f ()) m1 m2
	differenceAlg f m1 m2 = differenceMaybe (f ()) m1 m2
	-- NOTE(review): for values [v, v1, v2] (in list order) this yields
	-- @f () v1 (f () v2 v)@, which matches neither the class default nor
	-- containers' fromListWith order when @f@ is not commutative; confirm
	-- the intended order.
	fromListAlg _ [] = Nothing
	fromListAlg f ((_, v):xs) = Just (foldr (\ p acc -> f () (snd p) acc) v xs)
	fromAscListAlg f xs = fromListAlg f xs
	getMinAlg mv = fmap (\ v -> (((), v), Nothing)) mv
	getMaxAlg mv = fmap (\ v -> (((), v), Nothing)) mv
	updateMinAlg f mv = case mv of
		Nothing	-> (False, Nothing)
		Just v	-> f () v
	updateMaxAlg f mv = case mv of
		Nothing	-> (False, Nothing)
		Just v	-> f () v
	isSubmapAlg le mx my = case (mx, my) of
		(Nothing, _)	-> True
		(_, Nothing)	-> False
		(Just x, Just y)	-> le x y
	splitLookupAlg f _ mv = maybe (Nothing, Nothing, Nothing) f mv

-- | Applies a function to the first component of a pair, leaving the second
-- untouched.
first :: (a -> c) -> (a, b) -> (c, b)
first f p = case p of
	(x, y) -> (f x, y)