module Data.TrieMap.Regular.UnionMap() where
import Data.TrieMap.Regular.Class
import Data.TrieMap.Regular.Base
import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative
import Control.Applicative
import Control.Monad
-- | Trie for sum-shaped keys: the first component holds the entries whose
-- keys are built with 'L', the second those built with 'R'.
data UnionMap m1 m2 k a
  = m1 k a :&: m2 k a
-- | A sum functor @f :+: g@ is represented by pairing the tries of its two
-- summands.
type instance TrieMapT (f :+: g) =
  UnionMap (TrieMapT f) (TrieMapT g)

-- | A key of type @(f :+: g) r@ uses the functor-level trie for @f :+: g@,
-- applied to the argument type @r@.
type instance TrieMap ((f :+: g) r) =
  TrieMapT (f :+: g) r
-- | Lifts the functor-level 'TrieKeyT' operations of 'UnionMap' to the
-- key-level 'TrieKey' interface for keys of type @(f :+: g) k@. Every method
-- is a direct alias of its @...T@ counterpart; nothing is computed here.
--
-- NOTE(review): 'sizeT' is implemented for 'UnionMap', yet no size method is
-- bound in this instance; confirm that 'TrieKey' provides a suitable default.
instance (TrieKeyT f m1, TrieKeyT g m2, TrieKey k (TrieMap k))
    => TrieKey ((f :+: g) k) (UnionMap m1 m2 k) where
  -- construction
  emptyM           = emptyT
  fromListM        = fromListT
  fromAscListM     = fromAscListT
  fromDistAscListM = fromDistAscListT

  -- queries
  nullM     = nullT
  lookupM   = lookupT
  lookupIxM = lookupIxT
  assocAtM  = assocAtT
  isSubmapM = isSubmapT

  -- single-key updates, splitting and extraction
  alterM       = alterT
  alterLookupM = alterLookupT
  splitLookupM = splitLookupT
  extractM     = extractT

  -- whole-map traversals, folds and partitioning
  traverseWithKeyM = traverseWithKeyT
  foldWithKeyM     = foldWithKeyT
  foldlWithKeyM    = foldlWithKeyT
  mapEitherM       = mapEitherT

  -- binary set operations
  unionM = unionT
  isectM = isectT
  diffM  = diffT
-- | Trie operations for sum-functor keys.
--
-- A key @L k@ is routed to the left component and a key @R k@ to the right
-- one. Every left-routed entry is ordered before every right-routed entry.
-- Positional indices into the right component are therefore offset by the
-- size of the left component (see 'lookupIxT' and 'assocAtT').
instance (TrieKeyT f m1, TrieKeyT g m2) => TrieKeyT (f :+: g) (UnionMap m1 m2) where
  emptyT = emptyT :&: emptyT

  nullT (m1 :&: m2) = nullT m1 && nullT m2

  sizeT s (m1 :&: m2) = sizeT s m1 + sizeT s m2

  lookupT k (m1 :&: m2) = case k of
      L kL -> lookupT kL m1
      R kR -> lookupT kR m2

  -- Indexed lookup producing (lower neighbour, hit, upper neighbour).
  -- A left-side hit keeps its index. Its upper neighbour is combined with the
  -- minimum of the right component, rebased past the left size.
  -- A right-side hit is rebased past the left size. Its lower neighbour is
  -- combined with the maximum of the left component.
  -- NOTE(review): for the R case to be right, '<|>' on the lower bound must
  -- prefer its right operand (Last-like) and on the upper bound its left
  -- operand (First-like); confirm against Data.TrieMap.Applicative.
  lookupIxT s k (m1 :&: m2) = case k of
      L kL | (lb, x, ub) <- onKey L (lookupIxT s kL m1)
        -> (lb, x, ub <|> fmap (onKeyA R . onIndexA (sizeT s m1 +)) (getMin m2))
      R kR | (lb, x, ub) <- onIndex (sizeT s m1 +) (onKey R (lookupIxT s kR m2))
        -> (fmap (onKeyA L) (getMax m1) <|> lb, x, ub)
    where
      getMin = aboutT (return .: Asc 0)
      -- The last entry's index is the total size minus that entry's own size.
      getMax m = aboutT (\ kMax a -> return (Asc (sizeT s m - s a) kMax a)) m

  -- Positional access: indices below the left size address the left
  -- component. The rest are rebased into the right component, and the
  -- result's index is shifted back.
  assocAtT s i (m1 :&: m2)
    | i < s1    = onKey L (assocAtT s i m1)
    | otherwise = onKey R (onIndex (s1 +) (assocAtT s (i - s1) m2))
    where
      s1 = sizeT s m1

  alterT s f k (m1 :&: m2) = case k of
      L kL -> alterT s f kL m1 :&: m2
      R kR -> m1 :&: alterT s f kR m2

  alterLookupT s f k (m1 :&: m2) = case k of
      L kL -> fmap (:&: m2) (alterLookupT s f kL m1)
      R kR -> fmap (m1 :&:) (alterLookupT s f kR m2)

  -- The left component's traversal is sequenced before the right's.
  traverseWithKeyT s f (m1 :&: m2) =
    (:&:) <$> traverseWithKeyT s (f . L) m1 <*> traverseWithKeyT s (f . R) m2

  -- Folds the right component into the seed first, then the left, so that
  -- left-routed entries end up outermost.
  foldWithKeyT f (m1 :&: m2) =
    foldWithKeyT (f . L) m1 . foldWithKeyT (f . R) m2

  -- Consumes the left component first, then the right.
  foldlWithKeyT f (m1 :&: m2) =
    foldlWithKeyT (f . R) m2 . foldlWithKeyT (f . L) m1

  -- Partitions each component independently and pairs up the halves.
  mapEitherT s1 s2 f (m1 :&: m2) =
    case (mapEitherT s1 s2 (f . L) m1, mapEitherT s1 s2 (f . R) m2) of
      ((m1L, m1R), (m2L, m2R)) -> (m1L :&: m2L, m1R :&: m2R)

  -- Splitting at a left key puts nothing from the right component on the
  -- lower side. Splitting at a right key puts all of the left component there.
  splitLookupT s f k (m1 :&: m2) = case k of
      L kL -> case splitLookupT s f kL m1 of
        (m1L, ans, m1R) -> (m1L :&: emptyT, ans, m1R :&: m2)
      R kR -> case splitLookupT s f kR m2 of
        (m2L, ans, m2R) -> (m1 :&: m2L, ans, emptyT :&: m2R)

  -- Binary set operations work component-wise.
  unionT s f (m11 :&: m12) (m21 :&: m22) =
    unionT s (f . L) m11 m21 :&: unionT s (f . R) m12 m22
  isectT s f (m11 :&: m12) (m21 :&: m22) =
    isectT s (f . L) m11 m21 :&: isectT s (f . R) m12 m22
  diffT s f (m11 :&: m12) (m21 :&: m22) =
    diffT s (f . L) m11 m21 :&: diffT s (f . R) m12 m22

  -- Extracts from the left or the right component. The two candidates are
  -- combined with '<|>'; which one wins depends on that Alternative instance.
  extractT s f (m1 :&: m2) =
    (fmap (:&: m2) <$> extractT s (f . L) m1)
      <|> (fmap (m1 :&:) <$> extractT s (f . R) m2)

  -- A union map is a submap of another iff each component is.
  isSubmapT le (m11 :&: m12) (m21 :&: m22) =
    isSubmapT le m11 m21 && isSubmapT le m12 m22

  -- Bulk construction. 'partEithers' presumably separates the L-keyed
  -- entries from the R-keyed ones (confirm in Data.TrieMap.Regular.Base);
  -- each group builds its own component.
  fromListT s f xs = case partEithers xs of
      (ys, zs) -> fromListT s (f . L) ys :&: fromListT s (f . R) zs
  fromAscListT s f xs = case partEithers xs of
      (ys, zs) -> fromAscListT s (f . L) ys :&: fromAscListT s (f . R) zs
  fromDistAscListT s xs = case partEithers xs of
      (ys, zs) -> fromDistAscListT s ys :&: fromDistAscListT s zs