{-# LANGUAGE FlexibleContexts, PatternGuards, UndecidableInstances, TypeFamilies, MultiParamTypeClasses #-}

module Data.TrieMap.UnionMap () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Regular.Class
-- import Data.TrieMap.Regular.TH
import Data.TrieMap.Applicative

import Control.Applicative
-- import Control.Arrow

-- import Data.Monoid

-- | A map keyed by @Either k1 k2@, split into two sub-maps: the first field
-- (of type @m1 a@) holds the bindings for @Left@ keys and the second field
-- (a @'TrieMap' k2 a@) holds the bindings for @Right@ keys.
data UMap m1 k2 a = (:&:) (m1 a) (TrieMap k2 a)

-- Both the transformer-style and the concrete trie for 'Either' keys are 'UMap'.
type instance TrieMapT (Either a) = UMap (TrieMap a)
type instance TrieMap (Either a b) = UMap (TrieMap a) b

-- | 'TrieKey' operations for 'Either'-keyed maps.  Every method simply
-- delegates to the corresponding 'TrieKeyT' method of the @UMap@ instance.
instance (TrieKey a m, TrieKey b (TrieMap b)) => TrieKey (Either a b) (UMap m b) where
	emptyM = emptyT
	nullM = nullT
	-- 'sizeT' is implemented by the TrieKeyT instance; wire it up like every
	-- other method so size queries on Either-keyed maps (which the indexed
	-- operations of enclosing maps depend on) go through it rather than
	-- whatever the class provides by default, if anything.
	sizeM = sizeT
	lookupM = lookupT
	lookupIxM = lookupIxT
	assocAtM = assocAtT
	alterM = alterT
	alterLookupM = alterLookupT
	traverseWithKeyM = traverseWithKeyT
	foldWithKeyM = foldWithKeyT
	foldlWithKeyM = foldlWithKeyT
	mapEitherM = mapEitherT
	splitLookupM = splitLookupT
	unionM = unionT
	isectM = isectT
	diffM = diffT
	extractM = extractT
	isSubmapM = isSubmapT
	fromListM = fromListT
	fromAscListM = fromAscListT
	fromDistAscListM = fromDistAscListT

-- | 'TrieKeyT' operations for 'Either' keys.  A @Left k@ key lives in the
-- first sub-map and a @Right k@ key in the second.  Because every @Left@ key
-- orders before every @Right@ key, the ordered/indexed operations treat the
-- first sub-map as preceding the second.  Absolute positions inside the
-- second sub-map are therefore offset by the size of the first.
instance TrieKey k1 m1 => TrieKeyT (Either k1) (UMap m1) where
	emptyT = emptyM :&: emptyM
	nullT (m1 :&: m2) = nullM m1 && nullM m2
	sizeT s (m1 :&: m2) = sizeM s m1 + sizeM s m2
	lookupT k (m1 :&: m2) = either (`lookupM` m1) (`lookupM` m2) k
	-- Positions reported for a 'Right' key must be absolute, so the result of
	-- looking it up in the second sub-map is shifted by the size of the first
	-- (exactly as 'assocAtT' does).  A 'Left' key's upper neighbour may be the
	-- first binding of the second sub-map, and a 'Right' key's lower neighbour
	-- may be the last binding of the first sub-map; those candidates are
	-- combined with the sub-map's own result via '<|>'.
	lookupIxT s k (m1 :&: m2) = let s1 = sizeM s m1 in case k of
		Left k1 | (lb, x, ub) <- onKey Left (lookupIxM s k1 m1)
				-> (lb, x, ub <|> aboutM (\ k' -> return . Asc s1 (Right k')) m2)
		Right k2 | (lb, x, ub) <- onKey Right (onIndex (s1 +) (lookupIxM s k2 m2))
				-> (aboutM (\ k' a -> return (Asc (s1 - s a) (Left k') a)) m1 <|> lb, x, ub)
	-- Indices below the size of the first sub-map select from it; the rest are
	-- rebased into the second sub-map and the results shifted back.
	assocAtT s i (m1 :&: m2)
		| i < s1, (lb, x, ub) <- onKey Left (assocAtM s i m1)
			= (lb, x, ub <|> aboutM (\ k -> return . Asc s1 (Right k)) m2)
		| (lb, x, ub) <- onKey Right (onIndex (s1 +) (assocAtM s (i - s1) m2))
			= (aboutM (\ k a -> return (Asc (s1 - s a) (Left k) a)) m1 <|> lb, x, ub)
		where s1 = sizeM s m1
-- 	updateAtM s r i (m1 :&: m2)
	-- Only the sub-map that owns the key is touched.
	alterT s f k (m1 :&: m2) = case k of
		Left k	-> alterM s f k m1 :&: m2
		Right k	-> m1 :&: alterM s f k m2
	alterLookupT s f k (m1 :&: m2) = case k of
		Left k	-> fmap (:&: m2) (alterLookupM s f k m1)
		Right k	-> fmap (m1 :&:) (alterLookupM s f k m2)
	traverseWithKeyT s f (m1 :&: m2) = (:&:) <$> traverseWithKeyM s (f . Left) m1 <*> traverseWithKeyM s (f . Right) m2
	-- Right fold: the second sub-map is folded first so the first sub-map's
	-- bindings end up outermost (i.e. in key order).
	foldWithKeyT f (m1 :&: m2) = foldWithKeyM (f . Left) m1 . foldWithKeyM (f . Right) m2
	-- Left fold: the first sub-map is consumed first.
	foldlWithKeyT f (m1 :&: m2) = foldlWithKeyM (f . Right) m2 . foldlWithKeyM (f . Left) m1
	mapEitherT s1 s2 f (m1 :&: m2) = (m1L :&: m2L, m1R :&: m2R)
		where	(m1L, m1R) = mapEitherM s1 s2 (f . Left) m1
			(m2L, m2R) = mapEitherM s1 s2 (f . Right) m2
-- 	extractMinT s f (m1 :&: m2) = second (:&: m2) <$> extractMinM s (f . Left) m1 <|>
-- 		second (m1 :&:) <$> extractMinM s (f . Right) m2
-- 	extractMaxT s f (m1 :&: m2) = second (:&: m2) <$> extractMaxM s (f . Left) m1 <|>
-- 		second (m1 :&:) <$> extractMaxM s (f . Right) m2
	-- Try the first sub-map, then the second, combining the candidates with '<|>'.
	extractT s f (m1 :&: m2) = fmap (:&: m2) <$> extractM s (f . Left) m1 <|>
		fmap (m1 :&:) <$> extractM s (f . Right) m2
	-- Splitting at a 'Left' key leaves the whole second sub-map on the right;
	-- splitting at a 'Right' key leaves the whole first sub-map on the left.
	splitLookupT s f k (m1 :&: m2) = case k of
		Left k | (m1L, x, m1R) <- splitLookupM s f k m1
			-> (m1L :&: emptyM, x, m1R :&: m2)
		Right k | (m2L, x, m2R) <- splitLookupM s f k m2
			-> (m1 :&: m2L, x, emptyM :&: m2R)
	unionT s f (m11 :&: m12) (m21 :&: m22)
		= unionM s (f . Left) m11 m21 :&: unionM s (f . Right) m12 m22
	isectT s f (m11 :&: m12) (m21 :&: m22)
		= isectM s (f . Left) m11 m21 :&: isectM s (f . Right) m12 m22
	diffT s f (m11 :&: m12) (m21 :&: m22)
		= diffM s (f . Left) m11 m21 :&: diffM s (f . Right) m12 m22
	isSubmapT (<=) (m11 :&: m12) (m21 :&: m22) = isSubmapM (<=) m11 m21 && isSubmapM (<=) m12 m22
	-- List constructors partition the input by constructor; partitioning keeps
	-- the relative order, so ascending input stays ascending in each half.
	fromListT s f xs = case partEithers xs of
		(ys, zs) -> fromListM s (f . Left) ys :&: fromListM s (f . Right) zs
	fromAscListT s f xs = case partEithers xs of
		(ys, zs) -> fromAscListM s (f . Left) ys :&: fromAscListM s (f . Right) zs
	fromDistAscListT s xs = case partEithers xs of
		(ys, zs) -> fromDistAscListM s ys :&: fromDistAscListM s zs
	
-- | Split 'Either'-keyed bindings into those with a 'Left' key and those with
-- a 'Right' key, stripping the constructor from each key.  Within each output
-- list the bindings keep their relative input order.
partEithers :: [(Either a b, x)] -> ([(a, x)], [(b, x)])
partEithers [] = ([], [])
partEithers ((key, val) : rest) = case key of
	Left a -> case partEithers rest of
		(ls, rs) -> ((a, val) : ls, rs)
	Right b -> case partEithers rest of
		(ls, rs) -> (ls, (b, val) : rs)

--   aboutMinM :: TrieKey k (TrieMap k) => (k -> a -> x) -> TrieMap k a -> First x
--   aboutMinM f m = fst <$> extractMinM (const 0) (\ k a -> (f k a, Nothing)) m
-- 
--   aboutMaxM :: TrieKey k (TrieMap k) => (k -> a -> x) -> TrieMap k a -> Last x
--   aboutMaxM f m = fst <$> extractMaxM (const 0) (\ k a -> (f k a, Nothing)) m