{-# LANGUAGE PatternGuards, UnboxedTuples, TypeFamilies, PatternGuards, ViewPatterns #-}
{-# OPTIONS -funbox-strict-fields #-}
module Data.TrieMap.UnionMap () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Sized

import Control.Applicative

-- | Smart constructor pairing a map over @k1@ with a map over @k2@.
--
-- When both components are null the result is 'Empty'; otherwise the
-- result is a 'Union' node whose size field is the sum of the two
-- component sizes as measured by @sizeM s@.
union :: (TrieKey k1, TrieKey k2) => Sized a -> TrieMap k1 a -> TrieMap k2 a -> TrieMap (Either k1 k2) a
union s m1 m2
  | nullM m1, nullM m2 = Empty
  | otherwise = Union (sizeM s m1 + sizeM s m2) m1 m2

-- | Build a map from a key and an optional value: 'Nothing' gives 'Empty',
-- @Just a@ gives the singleton map binding the key to @a@.
singletonMaybe :: (TrieKey k1, TrieKey k2) => Sized a -> Either k1 k2 -> Maybe a -> TrieMap (Either k1 k2) a
singletonMaybe s k a = case a of
  Nothing -> Empty
  Just v -> singletonM s k v

-- | Singleton map for a 'Left' key: the size field is @s a@, the left
-- component is the singleton of the underlying @k1@ map, and the right
-- component is empty.
singletonL :: (TrieKey k1, TrieKey k2) => Sized a -> k1 -> a -> TrieMap (Either k1 k2) a
singletonL s k a = Union size lefts emptyM
  where
    size = s a
    lefts = singletonM s k a

-- | Singleton map for a 'Right' key: the size field is @s a@, the left
-- component is empty, and the right component is the singleton of the
-- underlying @k2@ map.
singletonR :: (TrieKey k1, TrieKey k2) => Sized a -> k2 -> a -> TrieMap (Either k1 k2) a
singletonR s k a = Union size emptyM rights
  where
    size = s a
    rights = singletonM s k a

-- | Trie maps keyed by @Either k1 k2@, stored as one component map for the
-- 'Left' keys and one for the 'Right' keys.  Every operation dispatches to
-- the component maps and reassembles the result with 'union'.
instance (TrieKey k1, TrieKey k2) => TrieKey (Either k1 k2) where
  -- 'Empty' is the map with no entries.  'Union' carries a strict size
  -- field (the value 'sizeM' reports) followed by the 'Left'-keyed and
  -- 'Right'-keyed components.
  data TrieMap (Either k1 k2) a
    = Empty
    | Union !Int (TrieMap k1 a) (TrieMap k2 a)

  emptyM = Empty

  -- Pick the singleton builder matching the side of the key.
  singletonM s key = case key of
    Left kl -> singletonL s kl
    Right kr -> singletonR s kr

  nullM m = case m of
    Empty -> True
    Union{} -> False

  -- The size is read from the node; the 'Sized' argument is not consulted.
  sizeM _ m = case m of
    Empty -> 0
    Union n _ _ -> n

  lookupM _ Empty = Nothing
  lookupM key (Union _ lefts rights) = case key of
    Left kl -> lookupM kl lefts
    Right kr -> lookupM kr rights

  -- Altering an empty map can only insert; otherwise alter the component
  -- selected by the key and rebuild with 'union' (which yields 'Empty'
  -- again if both components end up null).
  alterM s f key Empty = singletonMaybe s key (f Nothing)
  alterM s f (Left kl) (Union _ lefts rights) = union s (alterM s f kl lefts) rights
  alterM s f (Right kr) (Union _ lefts rights) = union s lefts (alterM s f kr rights)

  alterLookupM s f key m = case m of
    Empty -> onUnboxed (singletonMaybe s key) f Nothing
    Union _ lefts rights -> case key of
      Left kl -> onUnboxed (\lefts' -> union s lefts' rights) (alterLookupM s f kl) lefts
      Right kr -> onUnboxed (\rights' -> union s lefts rights') (alterLookupM s f kr) rights

  -- Effects of the left component are sequenced before those of the right.
  traverseWithKeyM _ _ Empty = pure Empty
  traverseWithKeyM s f (Union _ lefts rights) = union s <$> visitL <*> visitR
    where
      visitL = traverseWithKeyM s (f . Left) lefts
      visitR = traverseWithKeyM s (f . Right) rights

  -- The right component is folded first and its result is fed to the fold
  -- over the left component.
  foldWithKeyM _ Empty = id
  foldWithKeyM f (Union _ lefts rights) =
    \z -> foldWithKeyM (f . Left) lefts (foldWithKeyM (f . Right) rights z)

  -- The left component is folded first and its result is fed to the fold
  -- over the right component.
  foldlWithKeyM _ Empty = id
  foldlWithKeyM f (Union _ lefts rights) =
    \z -> foldlWithKeyM (f . Right) rights (foldlWithKeyM (f . Left) lefts z)

  mapMaybeM _ _ Empty = Empty
  mapMaybeM s f (Union _ lefts rights) = union s keptL keptR
    where
      keptL = mapMaybeM s (f . Left) lefts
      keptR = mapMaybeM s (f . Right) rights

  -- Partition each component, then pair up the two first halves (sized
  -- with @s1@) and the two second halves (sized with @s2@).
  mapEitherM _ _ _ Empty = (# Empty, Empty #)
  mapEitherM s1 s2 f (Union _ lefts rights) =
    case mapEitherM s1 s2 (f . Left) lefts of
      (# leftsA, leftsB #) -> case mapEitherM s1 s2 (f . Right) rights of
        (# rightsA, rightsB #) ->
          (# union s1 leftsA rightsA, union s2 leftsB rightsB #)

  -- Extraction from the left component is offered first and combined with
  -- extraction from the right component via '<|>'; in each case the
  -- untouched component is re-attached to the remainder.
  extractM _ _ Empty = empty
  extractM s f (Union _ lefts rights) =
    (fmap (\lefts' -> union s lefts' rights) <$> extractM s (f . Left) lefts)
      <|> (fmap (\rights' -> union s lefts rights') <$> extractM s (f . Right) rights)

  -- Splitting at a 'Left' key places the whole right component in the
  -- third result; splitting at a 'Right' key places the whole left
  -- component in the first result.
  splitLookupM _ _ _ Empty = (# emptyM, Nothing, emptyM #)
  splitLookupM s f key (Union _ lefts rights) = case key of
    Left kl -> case splitLookupM s f kl lefts of
      (# before, hit, after #) ->
        (# union s before emptyM, hit, union s after rights #)
    Right kr -> case splitLookupM s f kr rights of
      (# before, hit, after #) ->
        (# union s lefts before, hit, union s emptyM after #)

  unionM _ _ Empty m2 = m2
  unionM _ _ m1 Empty = m1
  unionM s f (Union _ m11 m12) (Union _ m21 m22) = union s mergedL mergedR
    where
      mergedL = unionM s (f . Left) m11 m21
      mergedR = unionM s (f . Right) m12 m22

  isectM s f m1 m2 = case m1 of
    Empty -> Empty
    Union _ m11 m12 -> case m2 of
      Empty -> Empty
      Union _ m21 m22 ->
        union s (isectM s (f . Left) m11 m21) (isectM s (f . Right) m12 m22)

  diffM s f m1 m2 = case m1 of
    Empty -> Empty
    Union _ m11 m12 -> case m2 of
      Empty -> m1
      Union _ m21 m22 ->
        union s (diffM s (f . Left) m11 m21) (diffM s (f . Right) m12 m22)

  -- An empty map is a submap of anything; a 'Union' node is never a submap
  -- of 'Empty'; otherwise both components must be submaps.
  isSubmapM le m1 m2 = case m1 of
    Empty -> True
    Union _ m11 m12 -> case m2 of
      Empty -> False
      Union _ m21 m22 -> isSubmapM le m11 m21 && isSubmapM le m12 m22

  -- The list builders split the input by key side with 'partEithers',
  -- build each component with the corresponding builder, and join them.
  fromListM s f kvs =
    onPair (union s) (fromListM s (f . Left)) (fromListM s (f . Right)) (partEithers kvs)

  fromAscListM s f kvs =
    onPair (union s) (fromAscListM s (f . Left)) (fromAscListM s (f . Right)) (partEithers kvs)

  fromDistAscListM s kvs =
    onPair (union s) (fromDistAscListM s) (fromDistAscListM s) (partEithers kvs)

-- | Apply one function to each half of a pair and combine the two results:
-- @onPair f g h (a, b) == f (g a) (h b)@.
onPair :: (c -> d -> e) -> (a -> c) -> (b -> d) -> (a, b) -> e
onPair combine onFst onSnd pair = case pair of
  (x, y) -> combine (onFst x) (onSnd y)

-- | Split an association list keyed by 'Either' into the entries with
-- 'Left' keys and the entries with 'Right' keys, unwrapping the keys.
-- The relative order of entries within each output list is the order in
-- which they appear in the input.
partEithers :: [(Either a b, x)] -> ([(a, x)], [(b, x)])
partEithers = foldr part ([], [])
  where
    -- The accumulator is matched with an irrefutable pattern so that each
    -- step can return its pair (and the new cons cell) without first
    -- forcing the fold over the rest of the list.  A strict match here
    -- would make the whole input be traversed before anything is produced.
    part (Left x, z) ~(xs, ys) = ((x, z) : xs, ys)
    part (Right y, z) ~(xs, ys) = (xs, (y, z) : ys)