{-# LANGUAGE TypeFamilies, KindSignatures, FlexibleContexts, FlexibleInstances, UndecidableInstances, PatternGuards, MultiParamTypeClasses, TypeOperators #-}

module Data.TrieMap.MultiRec.UnionMap where

import Data.TrieMap.MultiRec.Class
import Data.TrieMap.MultiRec.Eq
import Data.TrieMap.Applicative
import Data.TrieMap.TrieKey

import Control.Applicative
import Control.Arrow
import Control.Monad

import Data.Maybe
import Data.Foldable
import Generics.MultiRec

import Prelude hiding (foldr)

-- | Trie for keys of a sum functor @f :+: g@: a pair of sub-tries, the
-- first holding the entries whose keys are built with 'L' and the second
-- holding those built with 'R'.
data UnionMap (phi :: * -> *) m1 m2 (r :: * -> *) (a :: * -> *) ix = m1 r a ix :&: m2 r a ix
-- The trie transformer for @f :+: g@ pairs the transformers of the two
-- summands; conceptually this is @HTrieMap phi (f r) :*: HTrieMap phi (g r)@.
type instance HTrieMapT phi (f :+: g) = UnionMap phi (HTrieMapT phi f) (HTrieMapT phi g)
-- A trie keyed directly by @(f :+: g) r@ is the transformer applied to @r@.
type instance HTrieMap phi ((f :+: g) r) = HTrieMapT phi (f :+: g) r

-- | Trie operations for keys of a sum functor @f :+: g@.
--
-- A key @L k@ is stored in the left component of a 'UnionMap' and a key
-- @R k@ in the right component. Single-key operations therefore dispatch on
-- the key constructor. Bulk operations run componentwise and re-wrap the
-- keys they hand to callbacks in 'L' or 'R'.
--
-- The positional operations ('lookupIxT', 'assocAtT', 'updateAtT') number
-- all left-component entries before all right-component entries. An index
-- into the right component is offset by the size of the left component.
instance (HTrieKeyT phi f m1, HTrieKeyT phi g m2) => HTrieKeyT phi (f :+: g) (UnionMap phi m1 m2) where
    -- An empty union pairs two empty components.
    emptyT pf = emptyT pf :&: emptyT pf

    nullT pf (ml :&: mr) = nullT pf ml && nullT pf mr

    sizeT s (ml :&: mr) = sizeT s ml + sizeT s mr

    lookupT pf key (ml :&: mr) = case key of
        L k -> lookupT pf k ml
        R k -> lookupT pf k mr

    -- Indices found in the right component are shifted past the left one.
    lookupIxT pf s key (ml :&: mr) = case key of
        L k -> lookupIxT pf s k ml
        R k -> fmap (first (offset +)) (lookupIxT pf s k mr)
      where
        offset = sizeT s ml

    -- Index i addresses the left component when it is below that component's
    -- size; otherwise it addresses entry (i - size) of the right component.
    assocAtT pf s i (ml :&: mr)
        | i < leftSize = case assocAtT pf s i ml of
            (i', k, a) -> (i', L k, a)
        | otherwise = case assocAtT pf s (i - leftSize) mr of
            (i', k, a) -> (i' + leftSize, R k, a)
      where
        leftSize = sizeT s ml

    -- The callback always receives the index relative to the whole union.
    updateAtT pf s f i (ml :&: mr) =
        if i < leftSize
            then updateAtT pf s (\ j -> f j . L) i ml :&: mr
            else ml :&: updateAtT pf s (\ j -> f (leftSize + j) . R) (i - leftSize) mr
      where
        leftSize = sizeT s ml

    alterT pf s f key (ml :&: mr) = case key of
        L k -> alterT pf s f k ml :&: mr
        R k -> ml :&: alterT pf s f k mr

    traverseWithKeyT pf s f (ml :&: mr) =
        liftA2 (:&:) (traverseWithKeyT pf s (f . L) ml) (traverseWithKeyT pf s (f . R) mr)

    -- Right fold: the right component is folded into the seed first, so the
    -- left component's entries end up outermost.
    foldWithKeyT pf f (ml :&: mr) =
        foldWithKeyT pf (f . L) ml . foldWithKeyT pf (f . R) mr

    -- Left fold: the left component is folded first, then the right one.
    foldlWithKeyT pf f (ml :&: mr) =
        foldlWithKeyT pf (f . R) mr . foldlWithKeyT pf (f . L) ml

    mapEitherT pf s1 s2 f (ml :&: mr) =
        case mapEitherT pf s1 s2 (f . L) ml of
            (mlL, mlR) -> case mapEitherT pf s1 s2 (f . R) mr of
                (mrL, mrR) -> (mlL :&: mrL, mlR :&: mrR)

    -- Splitting at an 'L' key leaves nothing from the right component in the
    -- lower part; splitting at an 'R' key leaves all of the left component there.
    splitLookupT pf s f key (ml :&: mr) = case key of
        L k -> case splitLookupT pf s f k ml of
            (lo, x, hi) -> (lo :&: emptyT pf, x, hi :&: mr)
        R k -> case splitLookupT pf s f k mr of
            (lo, x, hi) -> (ml :&: lo, x, emptyT pf :&: hi)

    unionT pf s f (la :&: ra) (lb :&: rb) =
        unionT pf s (f . L) la lb :&: unionT pf s (f . R) ra rb

    isectT pf s f (la :&: ra) (lb :&: rb) =
        isectT pf s (f . L) la lb :&: isectT pf s (f . R) ra rb

    diffT pf s f (la :&: ra) (lb :&: rb) =
        diffT pf s (f . L) la lb :&: diffT pf s (f . R) ra rb

    -- Try the left component, then the right, combining with 'mplus'.
    extractMinT pf s (ml :&: mr) =
        (extractMinT pf s ml >>= \ ((k, v), ml') -> return ((L k, v), ml' :&: mr))
            `mplus` (extractMinT pf s mr >>= \ ((k, v), mr') -> return ((R k, v), ml :&: mr'))

    -- NOTE(review): like 'extractMinT', this tries the left component before
    -- the right one and combines the two attempts with 'mplus'. The result is
    -- the overall maximum only if the caller's monad has an 'mplus' that
    -- prefers its right operand. With a first-success 'mplus' (e.g. 'Maybe'),
    -- it would return the left component's maximum whenever that component is
    -- non-empty, whereas 'alterMaxT' below goes to the right component first.
    -- TODO confirm which monad callers instantiate this with.
    extractMaxT pf s (ml :&: mr) =
        (extractMaxT pf s ml >>= \ ((k, v), ml') -> return ((L k, v), ml' :&: mr))
            `mplus` (extractMaxT pf s mr >>= \ ((k, v), mr') -> return ((R k, v), ml :&: mr'))

    -- Alter the left component's minimum unless that component is empty.
    alterMinT pf s f (ml :&: mr) =
        if nullT pf ml
            then ml :&: alterMinT pf s (f . R) mr
            else alterMinT pf s (f . L) ml :&: mr

    -- Alter the right component's maximum unless that component is empty.
    alterMaxT pf s f (ml :&: mr) =
        if nullT pf mr
            then alterMaxT pf s (f . L) ml :&: mr
            else ml :&: alterMaxT pf s (f . R) mr

    -- Submap iff each component is a submap of its counterpart.
    isSubmapT pf le (la :&: ra) (lb :&: rb) =
        isSubmapT pf le la lb && isSubmapT pf le ra rb

    -- List construction splits the input by key constructor and builds each
    -- component from its share of the entries.
    fromListT pf s f xs = case breakEither xs of
        (ls, rs) -> fromListT pf s (f . L) ls :&: fromListT pf s (f . R) rs

    fromAscListT pf s f xs = case breakEither xs of
        (ls, rs) -> fromAscListT pf s (f . L) ls :&: fromAscListT pf s (f . R) rs

    fromDistAscListT pf s xs = case breakEither xs of
        (ls, rs) -> fromDistAscListT pf s ls :&: fromDistAscListT pf s rs

-- | Split an association list keyed by a sum functor into two lists:
--
-- * the entries whose keys are built with 'L', with the constructor stripped;
-- * the entries whose keys are built with 'R', with the constructor stripped.
--
-- Each output list keeps the relative order its entries had in the input.
--
-- The accumulator pair is matched lazily (@~(ls, rs)@), as
-- 'Data.List.partition' does. A strict tuple match inside 'foldr' would force
-- the whole remaining fold before the first element could be produced. That
-- costs stack proportional to the list length and never returns on an
-- infinite list. Results for finite, fully defined lists are the same either
-- way.
breakEither :: [((f :+: g) r ix, a)] -> ([(f r ix, a)], [(g r ix, a)])
breakEither = foldr step ([], [])
  where
    step (L k, a) ~(ls, rs) = ((k, a) : ls, rs)
    step (R k, a) ~(ls, rs) = (ls, (k, a) : rs)

-- | A trie keyed by @(f :+: g) r@ is the 'UnionMap' transformer applied to
-- @r@. Every method here forwards unchanged to the corresponding
-- 'HTrieKeyT' method.
instance ( HTrieKeyT phi f m1, m1 ~ HTrieMapT phi f
         , HTrieKeyT phi g m2, m2 ~ HTrieMapT phi g
         , HTrieKey phi r (HTrieMap phi r)
         ) => HTrieKey phi ((f :+: g) r) (UnionMap phi m1 m2 r) where
    -- Construction
    emptyH           = emptyT
    fromListH        = fromListT
    fromAscListH     = fromAscListT
    fromDistAscListH = fromDistAscListT

    -- Size and point queries
    nullH     = nullT
    sizeH     = sizeT
    lookupH   = lookupT
    lookupIxH = lookupIxT
    assocAtH  = assocAtT
    isSubmapH = isSubmapT

    -- Single-entry updates
    updateAtH = updateAtT
    alterH    = alterT
    alterMinH = alterMinT
    alterMaxH = alterMaxT

    -- Extremal extraction
    extractMinH = extractMinT
    extractMaxH = extractMaxT

    -- Traversals and folds
    traverseWithKeyH = traverseWithKeyT
    foldWithKeyH     = foldWithKeyT
    foldlWithKeyH    = foldlWithKeyT

    -- Whole-map transformations and set operations
    mapEitherH   = mapEitherT
    splitLookupH = splitLookupT
    unionH       = unionT
    isectH       = isectT
    diffH        = diffT