{-# LANGUAGE TemplateHaskell, TypeFamilies, KindSignatures, FlexibleContexts, FlexibleInstances, UndecidableInstances, PatternGuards, MultiParamTypeClasses, TypeOperators #-}

module Data.TrieMap.MultiRec.UnionMap () where

import Data.TrieMap.MultiRec.Class
-- import Data.TrieMap.MultiRec.Eq
-- import Data.TrieMap.MultiRec.Base
-- import Data.TrieMap.Applicative
import Data.TrieMap.TrieKey
-- import Data.TrieMap.Rep
-- import Data.TrieMap.Rep.TH
-- import Data.TrieMap.MultiRec.TH
-- import qualified Data.TrieMap.Regular.Base as Reg

import Control.Applicative
-- import Control.Arrow
import Control.Monad

-- import Data.Maybe
-- import Data.Monoid
-- import Data.Foldable
import Generics.MultiRec

import Prelude hiding (foldr)

-- | Map representation used for keys built from a sum of two functors
-- @f :+: g@: a pair, joined by ':&:', of the 'HTrieMapT' for @f@ (left
-- component) and the 'HTrieMapT' for @g@ (right component), both at the same
-- @phi@, @r@, @ix@ and value type @a@.
data UnionMap (phi :: * -> *) f g (r :: * -> *) ix a = HTrieMapT phi f r ix a :&: HTrieMapT phi g r ix a
-- Associates the sum functor @f :+: g@ with 'UnionMap' as its map type.  The
-- trailing comment on the next line is an earlier formulation kept for
-- reference.
type instance HTrieMapT phi (f :+: g) = UnionMap phi f g--HTrieMap phi (f r) :*: HTrieMap phi (g r)
-- type instance HTrieMap phi ((f :+: g) r) = HTrieMapH phi (f :+: g) r

-- type instance RepH (UnionMap phi f g r ix) = (Reg.:*:) (RepH (HTrieMapH phi f r ix)) (RepH (HTrieMapH phi g r ix))
-- type instance Rep (UnionMap phi f g r ix a) = RepH (UnionMap phi f g r ix) (Rep a)

-- -- $(genRepH [d|
--     instance (ReprH (HTrieMapH phi f r ix), ReprH (HTrieMapH phi g r ix)) => ReprH (UnionMap phi f g r ix) where
-- 	toRepH (m1 :&: m2) = (Reg.:*:) (toRepH m1) (toRepH m2)
-- 	fromRepH ((Reg.:*:) m1 m2) = fromRepH m1 :&: fromRepH m2
-- 	|])

-- | Trie operations for sum keys @f :+: g@, backed by a 'UnionMap' holding one
-- sub-map per summand.  Keys tagged 'L' are routed to the first sub-map and
-- keys tagged 'R' to the second.  Folds, splits and neighbour lookups in this
-- instance treat every 'L' entry as coming before every 'R' entry, and
-- positions reported for the second sub-map are shifted by the size of the
-- first.
instance (HTrieKeyT phi f (HTrieMapT phi f), HTrieKeyT phi g (HTrieMapT phi g)) => HTrieKeyT phi (f :+: g) (UnionMap phi f g) where
    -- Pair of two empty sub-maps.  'emptyH' is applied to @pf@ elsewhere in
    -- this instance, so 'liftM2' here hands the same argument to both halves.
    emptyH = liftM2 (:&:) emptyH emptyH

    -- Empty exactly when both halves are empty.
    nullH pf (left :&: right) = nullH pf left && nullH pf right

    -- Total size is the sum of the two halves' sizes.
    sizeH pf s (left :&: right) = sizeH pf s left + sizeH pf s right

    -- Dispatch on the key's tag and look up in the matching half.
    lookupH pf key (left :&: right) = case key of
        L k -> lookupH pf k left
        R k -> lookupH pf k right

    -- Indexed lookup returning (lower neighbour, match, upper neighbour).
    -- For an 'L' key the upper neighbour falls back to the first entry of the
    -- right half; for an 'R' key the lower neighbour is combined with the last
    -- entry of the left half and all indices are shifted by the left size.
    -- NOTE(review): which operand '<|>' prefers depends on the Alternative
    -- instances of the bound containers (defined outside this file) - confirm.
    lookupIxH pf s key (left :&: right) = case key of
            L k -> case onKey L (lookupIxH pf s k left) of
                (below, hit, above) ->
                    ( below
                    , hit
                    , above <|> fmap (onKeyA R . onIndexA (+ leftSize)) (firstOf right)
                    )
            R k -> case onIndex (leftSize +) (onKey R (lookupIxH pf s k right)) of
                (below, hit, above) ->
                    ( fmap (onKeyA L) (lastOf left) <|> below
                    , hit
                    , above
                    )
      where
        leftSize = sizeH pf s left
        -- Every entry of @m@ reported at index 0, collected through 'aboutH'.
        firstOf m = aboutH pf (\kk v -> return (Asc 0 kk v)) m
        -- Every entry of @m@ reported at index (size of m - size of the value).
        lastOf m = aboutH pf (\kk v -> return (Asc (sizeH pf s m - s v) kk v)) m

    -- Positional lookup: indices below the left half's size are resolved in
    -- the left half, the rest in the right half after subtracting that size.
    -- Neighbours across the boundary are filled in as in 'lookupIxH'.
    assocAtH pf s i (left :&: right) =
        if i < leftSize
            then case onKey L (assocAtH pf s i left) of
                (below, hit, above) ->
                    ( below
                    , hit
                    , above <|> fmap (onKeyA R . onIndexA (+ leftSize)) (firstOf right)
                    )
            else case onKey R (onIndex (leftSize +) (assocAtH pf s (i - leftSize) right)) of
                (below, hit, above) ->
                    ( fmap (onKeyA L) (lastOf left) <|> below
                    , hit
                    , above
                    )
      where
        leftSize = sizeH pf s left
        firstOf m = aboutH pf (\kk v -> return (Asc 0 kk v)) m
        lastOf m = aboutH pf (\kk v -> return (Asc (sizeH pf s m - s v) kk v)) m

    -- Disabled method, kept for reference:
{-	updateAtH pf s r f i (m1 :&: m2)
		| not r && i >= lastIx m1
			= m1 :&: updateAtH pf s r (\ i' -> f (i' + s1) . R) (i - s1) m2
		| i < s1
			= updateAtH pf s r (\ i' -> f i' . L) i m1 :&: m2
		| otherwise
			= m1 :&: updateAtH pf s r (\ i' -> f (i' + s1) . R) (i - s1) m2
		where	s1 = sizeH pf s m1
			lastIx m = case extractMaxH pf s (\ _ v -> (v, Just v)) m of
				Last (Just (v, _)) -> sizeH pf s m - s v
				_			-> sizeH pf s m-}

    -- Alter the entry in whichever half the key's tag selects; the other half
    -- is returned untouched.
    alterH pf s f key (left :&: right) = case key of
        L k -> alterH pf s f k left :&: right
        R k -> left :&: alterH pf s f k right

    -- As 'alterH', re-pairing the updated half inside the returned functor.
    alterLookupH pf s f key (left :&: right) = case key of
        L k -> fmap (:&: right) (alterLookupH pf s f k left)
        R k -> fmap (left :&:) (alterLookupH pf s f k right)

    -- Traverse the left half (keys re-tagged with 'L') and then the right half
    -- (keys re-tagged with 'R'), rebuilding the pair.
    traverseWithKeyH pf s f (left :&: right) =
        (:&:)
            <$> traverseWithKeyH pf s (f . L) left
            <*> traverseWithKeyH pf s (f . R) right

    -- Fold the right half into the seed first, then fold the left half over
    -- that result.
    foldWithKeyH pf f (left :&: right) =
        \z -> foldWithKeyH pf (f . L) left (foldWithKeyH pf (f . R) right z)

    -- Fold the left half into the seed first, then the right half.
    foldlWithKeyH pf f (left :&: right) =
        \z -> foldlWithKeyH pf (f . R) right (foldlWithKeyH pf (f . L) left z)

    -- Partition each half independently and pair up the corresponding parts.
    mapEitherH pf s1 s2 f (left :&: right) =
        case mapEitherH pf s1 s2 (f . L) left of
            (leftA, leftB) -> case mapEitherH pf s1 s2 (f . R) right of
                (rightA, rightB) -> (leftA :&: rightA, leftB :&: rightB)

    -- Split around a key.  For an 'L' key the whole right half lands on the
    -- greater side; for an 'R' key the whole left half lands on the lesser
    -- side.
    splitLookupH pf s f key (left :&: right) = case key of
        L k -> case splitLookupH pf s f k left of
            (lesser, hit, greater) -> (lesser :&: emptyH pf, hit, greater :&: right)
        R k -> case splitLookupH pf s f k right of
            (lesser, hit, greater) -> (left :&: lesser, hit, emptyH pf :&: greater)

    -- Set-like combinators work component-wise, re-tagging keys for @f@.
    unionH pf s f (a1 :&: a2) (b1 :&: b2) =
        unionH pf s (f . L) a1 b1 :&: unionH pf s (f . R) a2 b2

    isectH pf s f (a1 :&: a2) (b1 :&: b2) =
        isectH pf s (f . L) a1 b1 :&: isectH pf s (f . R) a2 b2

    diffH pf s f (a1 :&: a2) (b1 :&: b2) =
        diffH pf s (f . L) a1 b1 :&: diffH pf s (f . R) a2 b2

    -- Extraction candidates from the left half combined via '<|>' with those
    -- from the right half; each result is re-paired with the untouched half.
    extractH pf s f (left :&: right) =
        fmap (fmap (:&: right)) (extractH pf s (f . L) left)
            <|> fmap (fmap (left :&:)) (extractH pf s (f . R) right)

    -- Disabled methods, kept for reference:
-- 	extractMinH pf s f (m1 :&: m2) = second (:&: m2) <$> extractMinH pf s (f . L) m1 <|>
-- 		second (m1 :&:) <$> extractMinH pf s (f . R) m2
-- 	extractMaxH pf s f (m1 :&: m2) = second (:&: m2) <$> extractMaxH pf s (f . L) m1 <|>
-- 		second (m1 :&:) <$> extractMaxH pf s (f . R) m2
-- 	alterMinH pf s f (m1 :&: m2)
-- 		| nullH pf m1	= m1 :&: alterMinH pf s (f . R) m2
-- 		| otherwise	= alterMinH pf s (f . L) m1 :&: m2
-- 	alterMaxH pf s f (m1 :&: m2)
-- 		| nullH pf m2	= alterMaxH pf s (f . L) m1 :&: m2
-- 		| otherwise	= m1 :&: alterMaxH pf s (f . R) m2

    -- Submap test holds when it holds for both halves under the comparison
    -- @le@.
    isSubmapH pf le (a1 :&: a2) (b1 :&: b2) =
        isSubmapH pf le a1 b1 && isSubmapH pf le a2 b2

    -- The list builders separate the input with 'breakEither' and build each
    -- half from its own portion.
    fromListH pf s f xs = case breakEither xs of
        (ls, rs) -> fromListH pf s (f . L) ls :&: fromListH pf s (f . R) rs

    fromAscListH pf s f xs = case breakEither xs of
        (ls, rs) -> fromAscListH pf s (f . L) ls :&: fromAscListH pf s (f . R) rs

    fromDistAscListH pf s xs = case breakEither xs of
        (ls, rs) -> fromDistAscListH pf s ls :&: fromDistAscListH pf s rs