{-# LANGUAGE PatternGuards, FlexibleInstances, TemplateHaskell, TypeOperators, TypeFamilies, MultiParamTypeClasses, FlexibleContexts, UndecidableInstances #-}

module Data.TrieMap.Regular.UnionMap() where

import Data.TrieMap.Regular.Class
import Data.TrieMap.Regular.Base
-- import Data.TrieMap.Regular.TH
import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative
-- import Data.TrieMap.Rep
-- import Data.TrieMap.Rep.TH

import Control.Applicative
-- import Control.Arrow
import Control.Monad

-- import Data.Either
-- import Data.Monoid

-- import Generics.MultiRec.Base
-- | Trie for keys of a sum shape @f :+: g@.  The left component stores the
-- entries whose keys were tagged 'L'; the right component stores those
-- tagged 'R'.
data UnionMap tl tr k a = tl k a :&: tr k a

-- The trie for a sum of functors pairs the tries of the two summands, and a
-- fully applied sum key reuses that same trie.
type instance TrieMapT (f1 :+: f2) = UnionMap (TrieMapT f1) (TrieMapT f2)
type instance TrieMap ((f1 :+: f2) k) = TrieMapT (f1 :+: f2) k

-- type instance RepT (UnionMap m1 m2 k) = RepT (m1 k) :*: RepT (m2 k)
-- type instance Rep (UnionMap f g k a) = RepT (UnionMap f g k) (Rep a)
-- 
-- -- $(genRepT [d|
--    instance (ReprT (m1 k), ReprT (m2 k)) => ReprT (UnionMap m1 m2 k) where
-- 	toRepT (m1 :&: m2) = toRepT m1 :*: toRepT m2
-- 	fromRepT (m1 :*: m2) = fromRepT m1 :&: fromRepT m2 |])

-- | Keys of type @(f :+: g) k@ are stored in a 'UnionMap'.  Every method
-- below forwards to the identically named @…T@ method of the 'TrieKeyT'
-- class.
-- NOTE(review): there is no @sizeM@ binding here.  Confirm that the class
-- supplies a default for it, or that it is not a method of 'TrieKey'.
instance (TrieKeyT f m1, TrieKeyT g m2, TrieKey k (TrieMap k)) => TrieKey ((f :+: g) k) (UnionMap m1 m2 k) where
	-- construction
	emptyM = emptyT
	fromListM = fromListT
	fromAscListM = fromAscListT
	fromDistAscListM = fromDistAscListT
	-- queries
	nullM = nullT
	lookupM = lookupT
	lookupIxM = lookupIxT
	assocAtM = assocAtT
	isSubmapM = isSubmapT
	-- single-key updates
	alterM = alterT
	alterLookupM = alterLookupT
	-- traversals and folds
	traverseWithKeyM = traverseWithKeyT
	foldWithKeyM = foldWithKeyT
	foldlWithKeyM = foldlWithKeyT
	-- splitting and extraction
	mapEitherM = mapEitherT
	splitLookupM = splitLookupT
	extractM = extractT
	-- whole-map combination
	unionM = unionT
	isectM = isectT
	diffM = diffT


-- | Trie operations for sum-shaped keys.  Each operation dispatches on the
-- 'L' / 'R' tag of the key: 'L' keys live in the left component and 'R' keys
-- in the right.  All 'L' keys are treated as ordered before all 'R' keys.
-- The index offsets in 'lookupIxT' / 'assocAtT' and the split and fold
-- orders below are all built on that ordering.
instance (TrieKeyT f m1, TrieKeyT g m2) => TrieKeyT (f :+: g) (UnionMap m1 m2) where
	emptyT = emptyT :&: emptyT
	nullT (lm :&: rm) = nullT lm && nullT rm
	-- s appears to measure the weight of a single value; a map's size is the
	-- sum of its components' sizes.
	sizeT s (lm :&: rm) = sizeT s lm + sizeT s rm
	lookupT key (lm :&: rm) = case key of
		L kl -> lookupT kl lm
		R kr -> lookupT kr rm
	-- Right-component indices are shifted by the size of the left component.
	-- When a neighbour is missing inside one component, it is taken from the
	-- adjacent extreme of the other component: the left's maximum for a lower
	-- bound, and the right's minimum for an upper bound.
	-- NOTE(review): this relies on the <|> of the bound types preferring the
	-- nearer neighbour (left operand for the upper bound, right operand for
	-- the lower bound).  Confirm against Data.TrieMap.Applicative.
	lookupIxT s key (lm :&: rm) = case key of
		L kl -> case onKey L (lookupIxT s kl lm) of
			(lb, x, ub) -> (lb, x, ub <|> fmap (onKeyA R . onIndexA (leftSize +)) (getMin rm))
		R kr -> case onIndex (leftSize +) (onKey R (lookupIxT s kr rm)) of
			(lb, x, ub) -> (fmap (onKeyA L) (getMax lm) <|> lb, x, ub)
		where	leftSize = sizeT s lm
			getMin = aboutT (return .: Asc 0)
			getMax m = aboutT (\ kx a -> return (Asc (sizeT s m - s a) kx a)) m
	-- Indices below the left component's size address the left map; all
	-- others are rebased into the right map and shifted back on the way out.
	assocAtT s i (lm :&: rm) = if i < leftSize
		then onKey L (assocAtT s i lm)
		else onKey R (onIndex (leftSize +) (assocAtT s (i - leftSize) rm))
		where leftSize = sizeT s lm
	-- Disabled: an updateAtT implementation kept for reference.
{-	updateAtT s r f i (m1 :&: m2)
		| not r, i >= maxIx m1
				= m1 :&: updateAtT s r (\ i' -> f (i' + s1) . R) (i - s1) m2
		| i < s1	= updateAtT s r (\ i' -> f i' . L) i m1 :&: m2
		| otherwise	= m1 :&: updateAtT s r (\ i' -> f (i' + s1) . R) (i - s1) m2
		where 	s1 = sizeT s m1
			maxIx m = maybe (sizeT s m) fst $ getLast (extractMaxT s (\ _ v -> (sizeT s m - s v, Just v)) m)-}
	alterT s f key (lm :&: rm) = case key of
		L kl -> alterT s f kl lm :&: rm
		R kr -> lm :&: alterT s f kr rm
	alterLookupT s f key (lm :&: rm) = case key of
		L kl -> (:&: rm) <$> alterLookupT s f kl lm
		R kr -> (lm :&:) <$> alterLookupT s f kr rm
	-- Effects of the left component run before those of the right.
	traverseWithKeyT s f (lm :&: rm) =
		fmap (:&:) (traverseWithKeyT s (f . L) lm) <*> traverseWithKeyT s (f . R) rm
	-- Presumably a right fold: the right component is folded into the seed
	-- first, so that the 'L' entries end up outermost.
	foldWithKeyT f (lm :&: rm) = foldWithKeyT (f . L) lm . foldWithKeyT (f . R) rm
	-- Left fold: the left component is consumed first.
	foldlWithKeyT f (lm :&: rm) = foldlWithKeyT (f . R) rm . foldlWithKeyT (f . L) lm
	mapEitherT s1 s2 f (lm :&: rm) = case mapEitherT s1 s2 (f . L) lm of
		(lmL, lmR) -> case mapEitherT s1 s2 (f . R) rm of
			(rmL, rmR) -> (lmL :&: rmL, lmR :&: rmR)
	-- Splitting at an 'L' key sends the whole right component to the "above"
	-- part.  Splitting at an 'R' key sends the whole left component to the
	-- "below" part.
	splitLookupT s f key (lm :&: rm) = case key of
		L kl -> case splitLookupT s f kl lm of
			(below, hit, above) -> (below :&: emptyT, hit, above :&: rm)
		R kr -> case splitLookupT s f kr rm of
			(below, hit, above) -> (lm :&: below, hit, emptyT :&: above)
	unionT s f (l1 :&: r1) (l2 :&: r2) = unionT s (f . L) l1 l2 :&: unionT s (f . R) r1 r2
	isectT s f (l1 :&: r1) (l2 :&: r2) = isectT s (f . L) l1 l2 :&: isectT s (f . R) r1 r2
	diffT s f (l1 :&: r1) (l2 :&: r2) = diffT s (f . L) l1 l2 :&: diffT s (f . R) r1 r2
	-- Both components are offered for extraction.  The Alternative instance
	-- of the result decides which candidate wins.
	extractT s f (lm :&: rm) =
		fmap (fmap (:&: rm)) (extractT s (f . L) lm) <|>
		fmap (fmap (lm :&:)) (extractT s (f . R) rm)
	-- Disabled: extractMinT / extractMaxT / alterMinT / alterMaxT kept for reference.
-- 	extractMinT s f (m1 :&: m2) = second (:&: m2) <$> extractMinT s (f . L) m1 <|>
-- 		second (m1 :&:) <$> extractMinT s (f . R) m2
-- 	extractMaxT s f (m1 :&: m2) = second (:&: m2) <$> extractMaxT s (f . L) m1 <|>
-- 		second (m1 :&:) <$> extractMaxT s (f . R) m2
-- 	alterMinT s f (m1 :&: m2)
-- 		| nullT m1	= m1 :&: alterMinT s (f . R) m2
-- 		| otherwise	= alterMinT s (f . L) m1 :&: m2
-- 	alterMaxT s f (m1 :&: m2)
-- 		| nullT m2	= alterMaxT s (f . L) m1 :&: m2
-- 		| otherwise	= m1 :&: alterMaxT s (f . R) m2
	isSubmapT le (l1 :&: r1) (l2 :&: r2) = isSubmapT le l1 l2 && isSubmapT le r1 r2
	-- The bulk builders split the input by tag with partEithers and build each
	-- component separately.  NOTE(review): the ascending variants assume that
	-- partEithers preserves relative order; confirm in its definition.
	fromListT s f assocs = case partEithers assocs of
		(lxs, rxs) -> fromListT s (f . L) lxs :&: fromListT s (f . R) rxs
	fromAscListT s f assocs = case partEithers assocs of
		(lxs, rxs) -> fromAscListT s (f . L) lxs :&: fromAscListT s (f . R) rxs
	fromDistAscListT s assocs = case partEithers assocs of
		(lxs, rxs) -> fromDistAscListT s lxs :&: fromDistAscListT s rxs