{-# LANGUAGE PatternGuards, UnboxedTuples, TypeFamilies, PatternGuards, ViewPatterns, MagicHash #-}
{-# OPTIONS -funbox-strict-fields #-}
module Data.TrieMap.UnionMap () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Sized

import Control.Applicative
import Control.Monad

import GHC.Exts

-- | Reassemble an 'Either'-keyed trie from its 'Left' and 'Right' halves.
--
-- When both halves are null the result is 'Empty', so a 'Union' is never
-- built around two empty halves. Otherwise the result is a 'Union' whose
-- cached size is the sum of the two halves' sizes.
(&) :: (TrieKey k1, TrieKey k2, Sized a) => TrieMap k1 a -> TrieMap k2 a -> TrieMap (Either k1 k2) a
(&) lHalf rHalf = if nullM lHalf && nullM rHalf
	then Empty
	else Union (getSize# lHalf +# getSize# rHalf) lHalf rHalf

-- | Singleton trie for a 'Left' key.
--
-- Builds a 'Union' whose Left half is @singletonM k a@, whose Right half is
-- 'emptyM', and whose cached size is @getSize# a@.
singletonL :: (TrieKey k1, TrieKey k2, Sized a) => k1 -> a -> TrieMap (Either k1 k2) a
singletonL key val = Union (getSize# val) leftHalf emptyM
	where leftHalf = singletonM key val

-- | Singleton trie for a 'Right' key.
--
-- Builds a 'Union' whose Left half is 'emptyM', whose Right half is
-- @singletonM k a@, and whose cached size is @getSize# a@.
singletonR :: (TrieKey k1, TrieKey k2, Sized a) => k2 -> a -> TrieMap (Either k1 k2) a
singletonR key val = Union (getSize# val) emptyM rightHalf
	where rightHalf = singletonM key val

-- | 'TrieKey' instance for 'Either' keys.
--
-- Entries with 'Left' keys and entries with 'Right' keys live in two
-- separate sub-tries. A non-empty trie is a 'Union' holding a cached total
-- size plus both halves. Most results are rebuilt with '(&)', which yields
-- 'Empty' when both halves are null.
--
-- The code maintains the invariant that a 'Union' never has two empty
-- halves. 'nullM' relies on this, because it treats every 'Union' as
-- non-empty. The invariant is upheld by '(&)' and by the explicit null
-- checks in 'beforeM'/'afterM'.
instance (TrieKey k1, TrieKey k2) => TrieKey (Either k1 k2) where
	-- Either 'Empty', or a 'Union' of the cached total size, the Left-keyed
	-- half and the Right-keyed half.
	data TrieMap (Either k1 k2) a = Empty | Union Int# (TrieMap k1 a) (TrieMap k2 a)
	-- A hole sits in exactly one half.
	-- 'LHole' pairs a hole in the Left half with the untouched Right half.
	-- 'RHole' pairs the untouched Left half with a hole in the Right half.
	data Hole (Either k1 k2) a = 
		LHole (Hole k1 a) (TrieMap k2 a)
		| RHole (TrieMap k1 a) (Hole k2 a)

	emptyM = Empty
	
	-- Dispatch on which side of the 'Either' the key is on.
	singletonM = either singletonL singletonR
	
	-- Only 'Empty' is null (this relies on the Union invariant above).
	nullM Empty = True
	nullM _ = False
	
	-- The total size is cached in the 'Union' node.
	sizeM Empty = 0#
	sizeM (Union s _ _) = s
	
	-- Look only in the half that matches the key's side.
	lookupM k (Union _ m1 m2) = either (`lookupM` m1) (`lookupM` m2) k
	lookupM _ _ = Nothing

	-- The Left half's effects run before the Right half's (the order of '<*>').
	-- Keys are re-tagged with 'Left'/'Right' before reaching @f@.
	traverseWithKeyM f (Union _ m1 m2) = (&) <$> traverseWithKeyM (f . Left) m1 <*> traverseWithKeyM (f . Right) m2
	traverseWithKeyM _ _ = pure Empty

	-- Right fold: the Right half is folded into the seed first, and the Left
	-- half is folded on top of that. Left entries therefore come first in
	-- fold order.
	foldrWithKeyM f (Union _ m1 m2) = foldrWithKeyM (f . Left) m1 . foldrWithKeyM (f . Right) m2
	foldrWithKeyM _ _ = id

	-- Left fold: the Left half is folded first, then the Right half.
	foldlWithKeyM f (Union _ m1 m2) = foldlWithKeyM (f . Right) m2 . foldlWithKeyM (f . Left) m1
	foldlWithKeyM _ _ = id

	-- Map each half independently. '(&)' collapses the result to 'Empty'
	-- when both halves end up null.
	mapWithKeyM f (Union _ m1 m2) = mapWithKeyM (f . Left) m1 & mapWithKeyM (f . Right) m2
	mapWithKeyM _ _ = Empty

	mapMaybeM f (Union _ m1 m2) = mapMaybeM (f . Left) m1 & mapMaybeM (f . Right) m2
	mapMaybeM _ _ = Empty

	-- Split each half in two, then rejoin the two resulting trees pairwise.
	mapEitherM f (Union _ m1 m2)
	  | (# m1L, m1R #) <- mapEitherM (f . Left) m1,
	    (# m2L, m2R #) <- mapEitherM (f . Right) m2
	    	= (# m1L & m2L, m1R & m2R #)
	mapEitherM _ _ = (# Empty, Empty #)

	-- Union, intersection and difference combine the halves pairwise.
	-- 'Empty' operands are handled directly.
	unionM f (Union _ m11 m12) (Union _ m21 m22)
		= unionM (f . Left) m11 m21 & unionM (f . Right) m12 m22
	unionM _ Empty m2 = m2
	unionM _ m1 Empty = m1

	isectM _ Empty _ = Empty
	isectM _ _ Empty = Empty
	isectM f (Union _ m11 m12) (Union _ m21 m22)
		= isectM (f . Left) m11 m21 & isectM (f . Right) m12 m22

	diffM _ Empty _ = Empty
	diffM _ m1 Empty = m1
	diffM f (Union _ m11 m12) (Union _ m21 m22)
		= diffM (f . Left) m11 m21 & diffM (f . Right) m12 m22

	-- 'Empty' is a submap of anything, and a 'Union' is never a submap of
	-- 'Empty'. Two 'Union's are compared half by half.
	isSubmapM _ Empty _ = True
	isSubmapM (<=) (Union _ m11 m12) (Union _ m21 m22) = isSubmapM (<=) m11 m21 && isSubmapM (<=) m12 m22
	isSubmapM _ Union{} Empty = False

	-- Split the input by key side with 'partEithers', build each half, then
	-- join the halves with '(&)'. 'partEithers' keeps each side's entries in
	-- their original relative order. The ordering preconditions of
	-- 'fromAscListM'/'fromDistAscListM' therefore presumably carry over to
	-- each half.
	-- NOTE(review): confirm this against the 'TrieKey' class contract.
	fromListM f = onPair (&) (fromListM (f . Left)) (fromListM (f . Right)) . partEithers

	fromAscListM f = onPair (&) (fromAscListM (f . Left)) (fromAscListM (f . Right)) . partEithers

	fromDistAscListM = onPair (&) fromDistAscListM fromDistAscListM . partEithers

	-- A hole for the key in its own half, with the other half empty.
	singleHoleM (Left k) = LHole (singleHoleM k) emptyM
	singleHoleM (Right k) = RHole emptyM (singleHoleM k)
	
	-- Recover the key from whichever half holds the hole, re-tagged.
	keyM (LHole holeL _) = Left (keyM holeL)
	keyM (RHole _ holeR) = Right (keyM holeR)
	
	-- Keep the part of the trie before the hole.
	-- For an 'LHole', only the Left half's 'beforeM' part is kept and the
	-- Right half is dropped.
	-- For an 'RHole', the whole Left half is kept together with the Right
	-- half's 'beforeM' part.
	-- @a@ is passed straight through to the half's own 'beforeM'.
	-- The 'Union' built directly is guarded by a null check to keep the
	-- invariant.
	beforeM a (LHole holeL _) = let mL = beforeM a holeL in
		if nullM mL then Empty else Union (getSize# mL) mL emptyM
	beforeM a (RHole mL holeR) = mL & beforeM a holeR
	
	-- Mirror image of 'beforeM'.
	-- For an 'LHole', the whole Right half is kept after the Left half's
	-- 'afterM' part.
	-- For an 'RHole', the Left half is dropped.
	afterM a (LHole holeL mR) = afterM a holeL & mR
	afterM a (RHole _ holeR) = let mR = afterM a holeR in
		if nullM mR then Empty else Union (getSize# mR) emptyM mR
	
	-- Searching 'Empty' reports no value and returns a fresh single hole.
	-- Otherwise, search the half matching the key's side and wrap the
	-- resulting hole.
	-- NOTE(review): 'onUnboxed' is defined elsewhere. Presumably it applies
	-- the wrapper to the hole component of the unboxed result; confirm.
	searchM k Empty = (# Nothing, singleHoleM k #)
	searchM (Left k) (Union _ mL mR) = onUnboxed (`LHole` mR) (searchM k) mL
	searchM (Right k) (Union _ mL mR) = onUnboxed (RHole mL) (searchM k) mR
	
	-- Index by position across both halves. Positions below the Left half's
	-- size go to the Left half. Otherwise the Left size is subtracted and the
	-- Right half is indexed.
	-- NOTE(review): '(<#)' is used as a guard here, so it must return 'Bool'.
	-- That holds on older GHCs; newer ones return 'Int#'.
	-- NOTE(review): the bang on 'sL#' needs BangPatterns, which is not in
	-- this file's LANGUAGE pragma. Presumably it is enabled elsewhere (e.g.
	-- the .cabal file); confirm.
	indexM i# (Union _ mL mR)
		| i# <# sL#, (# i'#, v, holeL #) <- indexM i# mL
			= (# i'#, v, LHole holeL mR #)
		| (# i'#, v, holeR #) <- indexM (i# -# sL#) mR
			= (# i'#, v, RHole mL holeR #)
		where !sL# = getSize# mL
	-- Indexing an 'Empty' trie is an error.
	indexM _ _ = (# error err, error err, error err #) where
		err = "Error: empty trie"

	-- Try the Left half first, and fall back to the Right half via 'mplus'.
	extractHoleM (Union _ mL mR) = (do
		(v, holeL) <- extractHoleM mL
		return (v, LHole holeL mR)) `mplus` (do
		(v, holeR) <- extractHoleM mR
		return (v, RHole mL holeR))
	extractHoleM _ = mzero
	
	-- Fill or clear the hole within its half, then rejoin the halves with
	-- '(&)'. This yields 'Empty' if clearing leaves both halves null.
	assignM v (LHole holeL mR) = assignM v holeL & mR
	assignM v (RHole mL holeR) = mL & assignM v holeR

	clearM (LHole holeL mR) = clearM holeL & mR
	clearM (RHole mL holeR) = mL & clearM holeR

-- | Transform the two components of a pair independently, then merge the
-- two results.
--
-- @onPair f g h (x, y) = f (g x) (h y)@. The pair is forced by the match
-- once all four arguments are supplied.
onPair :: (c -> d -> e) -> (a -> c) -> (b -> d) -> (a, b) -> e
onPair combine onFst onSnd pair = case pair of
	(x, y) -> combine (onFst x) (onSnd y)

-- | Split a list of 'Either'-keyed associations into two lists: those whose
-- key is a 'Left' and those whose key is a 'Right'.
--
-- The relative order of entries is preserved within each side.
--
-- The accumulator is matched lazily (@~(xs, ys)@), as in
-- 'Data.Either.partitionEithers'. Each output list can therefore be
-- consumed incrementally, instead of only after the whole input spine has
-- been walked through O(n) nested recursion. A strict match would also
-- diverge on partial or infinite input.
partEithers :: [(Either a b, x)] -> ([(a, x)], [(b, x)])
partEithers = foldr part ([], []) where
	part (Left x, z) ~(xs, ys) = ((x, z) : xs, ys)
	part (Right y, z) ~(xs, ys) = (xs, (y, z) : ys)