{-# LANGUAGE TypeOperators, TypeFamilies, MultiParamTypeClasses, FlexibleContexts, UndecidableInstances #-}

module Data.TrieMap.Regular.UnionMap() where

import Data.TrieMap.Regular.Class
import Data.TrieMap.Regular.Base
import Data.TrieMap.TrieKey

import Control.Applicative
import Control.Arrow
import Control.Monad

import Data.Either

-- import Generics.MultiRec.Base
-- | Trie for keys of a sum functor: the left component holds the entries
-- whose key is tagged with 'L', the right component those tagged with 'R'.
data UnionMap left right k (a :: * -> *) ix
    = left k a ix :&: right k a ix

-- A sum key @p :+: q@ is stored in a 'UnionMap' that pairs the tries of the
-- two summand functors; the concrete-key family simply defers to it.
type instance TrieMapT (p :+: q) = UnionMap (TrieMapT p) (TrieMapT q)
type instance TrieMap ((p :+: q) r) = TrieMapT (p :+: q) r

-- Functor-level trie operations for sum keys.  A 'UnionMap' keeps every
-- 'L'-tagged key in its left trie and every 'R'-tagged key in its right trie.
--
-- Ordering invariant relied on throughout: all 'L' keys order before all
-- 'R' keys.  Hence positional indices into the right trie are offset by the
-- size of the left trie, minimum-first operations consult the left trie
-- first, and maximum-first operations consult the right trie first.
instance (TrieKeyT f m1, TrieKeyT g m2) => TrieKeyT (f :+: g) (UnionMap m1 m2) where
    emptyT = emptyT :&: emptyT
    nullT (m1 :&: m2) = nullT m1 && nullT m2
    sizeT s (m1 :&: m2) = sizeT s m1 + sizeT s m2

    -- Point lookups dispatch on the key's constructor.
    lookupT k (m1 :&: m2) = case k of
        L k1 -> lookupT k1 m1
        R k2 -> lookupT k2 m2
    -- An 'R' key's index is shifted past every entry of the left trie.
    lookupIxT s k (m1 :&: m2) = case k of
        L k1 -> lookupIxT s k1 m1
        R k2 -> first (+ sizeT s m1) <$> lookupIxT s k2 m2

    -- Positional access: indices below the left trie's size address the left
    -- trie; the rest address the right trie after subtracting that size.
    assocAtT s i (m1 :&: m2)
        | i < s1    = case assocAtT s i m1 of
            (i', k, a) -> (i', L k, a)
        | otherwise = case assocAtT s (i - s1) m2 of
            (i', k, a) -> (i' + s1, R k, a)
      where s1 = sizeT s m1
    updateAtT s f i (m1 :&: m2)
        | i < s1    = updateAtT s (\ i' -> f i' . L) i m1 :&: m2
        | otherwise = m1 :&: updateAtT s (\ i' -> f (i' + s1) . R) (i - s1) m2
      where s1 = sizeT s m1

    alterT s f k (m1 :&: m2) = case k of
        L k1 -> alterT s f k1 m1 :&: m2
        R k2 -> m1 :&: alterT s f k2 m2

    -- Traversals and folds retag each sub-trie's keys with 'L' / 'R' and
    -- process the left trie's entries before the right trie's.
    traverseWithKeyT s f (m1 :&: m2) =
        (:&:) <$> traverseWithKeyT s (f . L) m1 <*> traverseWithKeyT s (f . R) m2
    foldWithKeyT f (m1 :&: m2) = foldWithKeyT (f . L) m1 . foldWithKeyT (f . R) m2
    foldlWithKeyT f (m1 :&: m2) = foldlWithKeyT (f . R) m2 . foldlWithKeyT (f . L) m1

    mapEitherT s1 s2 f (m1 :&: m2) =
        case (mapEitherT s1 s2 (f . L) m1, mapEitherT s1 s2 (f . R) m2) of
            ((m1L, m1R), (m2L, m2R)) -> (m1L :&: m2L, m1R :&: m2R)

    -- Splitting at an 'L' key leaves the entire right trie on the upper side;
    -- splitting at an 'R' key leaves the entire left trie on the lower side.
    splitLookupT s f k (m1 :&: m2) = case k of
        L k1 -> case splitLookupT s f k1 m1 of
            (m1L, ans, m1R) -> (m1L :&: emptyT, ans, m1R :&: m2)
        R k2 -> case splitLookupT s f k2 m2 of
            (m2L, ans, m2R) -> (m1 :&: m2L, ans, emptyT :&: m2R)

    unionT s f (m11 :&: m12) (m21 :&: m22) =
        unionT s (f . L) m11 m21 :&: unionT s (f . R) m12 m22
    isectT s f (m11 :&: m12) (m21 :&: m22) =
        isectT s (f . L) m11 m21 :&: isectT s (f . R) m12 m22
    diffT s f (m11 :&: m12) (m21 :&: m22) =
        diffT s (f . L) m11 m21 :&: diffT s (f . R) m12 m22

    -- Minimum-first: the left trie holds the smallest keys, so it is
    -- consulted first, with the right trie combined in via 'mplus'.
    extractMinT s (m1 :&: m2) =
        (do ((k, a), m1') <- extractMinT s m1
            return ((L k, a), m1' :&: m2))
        `mplus`
        (do ((k, a), m2') <- extractMinT s m2
            return ((R k, a), m1 :&: m2'))
    -- Maximum-first: the right trie holds the largest keys, so it must be
    -- consulted first, with the left trie combined in via 'mplus'.  (This
    -- mirrors 'alterMaxT', which only touches the left trie when the right
    -- one is null.)
    extractMaxT s (m1 :&: m2) =
        (do ((k, a), m2') <- extractMaxT s m2
            return ((R k, a), m1 :&: m2'))
        `mplus`
        (do ((k, a), m1') <- extractMaxT s m1
            return ((L k, a), m1' :&: m2))

    -- The minimum lives in the left trie unless it is empty; the maximum
    -- lives in the right trie unless it is empty.
    alterMinT s f (m1 :&: m2)
        | nullT m1  = m1 :&: alterMinT s (f . R) m2
        | otherwise = alterMinT s (f . L) m1 :&: m2
    alterMaxT s f (m1 :&: m2)
        | nullT m2  = alterMaxT s (f . L) m1 :&: m2
        | otherwise = m1 :&: alterMaxT s (f . R) m2

    -- Submap check is performed component-wise with the supplied value test.
    isSubmapT le (m11 :&: m12) (m21 :&: m22) =
        isSubmapT le m11 m21 && isSubmapT le m12 m22

    -- Builders split the input by constructor; the split preserves relative
    -- order, so ascending / distinct-ascending inputs stay so on each side.
    fromListT s f xs = case partEithers xs of
        (ys, zs) -> fromListT s (f . L) ys :&: fromListT s (f . R) zs
    fromAscListT s f xs = case partEithers xs of
        (ys, zs) -> fromAscListT s (f . L) ys :&: fromAscListT s (f . R) zs
    fromDistAscListT s xs = case partEithers xs of
        (ys, zs) -> fromDistAscListT s ys :&: fromDistAscListT s zs

-- | Split sum-keyed associations into the 'L'-tagged and 'R'-tagged ones,
-- stripping the tag.  Relative order within each side is kept.
partEithers :: [((f :+: g) r, a)] -> ([(f r, a)], [(g r, a)])
partEithers = foldr step ([], [])
  where
    -- Inspect the key's tag before the accumulated pair, as the original
    -- clause-per-constructor definition does.
    step (key, val) acc = case (key, acc) of
        (L k, (ls, rs)) -> ((k, val) : ls, rs)
        (R k, (ls, rs)) -> (ls, (k, val) : rs)

-- Concrete-key instance for sum keys: every method is the corresponding
-- functor-level 'TrieKeyT' method of 'UnionMap', used unchanged.
instance (TrieKeyT f mf, TrieKeyT g mg, TrieKey k (TrieMap k))
        => TrieKey ((f :+: g) k) (UnionMap mf mg k) where
    -- Construction and emptiness.
    emptyM = emptyT
    nullM = nullT
    fromListM = fromListT
    fromAscListM = fromAscListT
    fromDistAscListM = fromDistAscListT
    -- Lookups and positional access.
    lookupM = lookupT
    lookupIxM = lookupIxT
    assocAtM = assocAtT
    updateAtM = updateAtT
    -- Single-key updates, traversals and folds.
    alterM = alterT
    traverseWithKeyM = traverseWithKeyT
    foldWithKeyM = foldWithKeyT
    foldlWithKeyM = foldlWithKeyT
    mapEitherM = mapEitherT
    splitLookupM = splitLookupT
    -- Whole-map combinators.
    unionM = unionT
    isectM = isectT
    diffM = diffT
    isSubmapM = isSubmapT
    -- Extremal entries.
    extractMinM = extractMinT
    extractMaxM = extractMaxT
    alterMinM = alterMinT
    alterMaxM = alterMaxT