-- | Classes abstracting over map-like containers ('Maybe', 'IntMap', 'M.Map',
-- and their 'Compose' / 'Product' combinations), plus newtype wrappers
-- selecting union, intersection, or symmetric-difference 'Semigroup's.
module Data.Map.Class where

import Control.Applicative
import Control.Arrow
import Control.Monad
import Data.Either.Both
import Data.Filtrable
import Data.Function (on)
import Data.Functor.Classes
import Data.Functor.Compose
import Data.Functor.Identity
import Data.Functor.Product
import Data.IntMap (IntMap)
import qualified Data.IntMap as Int
import qualified Data.Map as M
import qualified Data.Map.Merge.Lazy as M
import Data.Monoid (Last (..))
import Util ((∘), (∘∘), compose2)

-- | A traversable container indexed by @'Key' map@ whose only primitive
-- operation, 'adjustA', rewrites the value at a key in place; it offers no
-- way to add or remove entries.
class Traversable map => StaticMap map where
    -- | The type of keys indexing the container.
    type Key map
    -- | Apply an effectful update to the value stored at the given key.
    -- Every instance in this module leaves the container unchanged (and runs
    -- no effect of the update) when the key is absent.
    adjustA :: Applicative p => (a -> p a) -> Key map -> map a -> p (map a)

-- | Map-like containers that additionally support inserting and deleting
-- entries, key-wise merging, and effectful filtering.
class (Filtrable map, StaticMap map) => Map map where
    -- | Insert, update or delete at a key. The function receives 'Nothing'
    -- for an absent key; for the 'IntMap' and 'M.Map' instances this is the
    -- containers' own @alterF@, where returning 'Nothing' deletes the key.
    alterF :: Functor f => (Maybe a -> f (Maybe a)) -> Key map -> map a -> f (map a)
    -- | Merge two maps key by key. The function sees 'JustLeft', 'JustRight'
    -- or 'Both' depending on which inputs hold the key, and returns 'Nothing'
    -- to leave the key out of the result.
    mergeA :: Applicative p => (Key map -> Either' a b -> p (Maybe c)) -> map a -> map b -> p (map c)
    -- | Effectfully map over every entry with its key, dropping entries
    -- mapped to 'Nothing'.
    mapMaybeWithKeyA :: Applicative p => (Key map -> a -> p (Maybe b)) -> map a -> p (map b)
    -- | Effectfully map over every entry with its key, splitting the results
    -- into a map of the 'Left's and a map of the 'Right's.
    mapEitherWithKeyA :: Applicative p => (Key map -> a -> p (Either b c)) -> map a -> p (map b, map c)
    -- Default: one 'mapMaybeWithKeyA' pass keeping every result, then a pure
    -- 'partitionEithers'. A single pass matters: running two passes (one per
    -- side) would perform each entry's effect twice.
    mapEitherWithKeyA f = fmap partitionEithers . mapMaybeWithKeyA (\ k -> fmap Just . f k)

-- | 'adjustA' derived from 'alterF': 'traverse' over the looked-up 'Maybe'
-- updates a present value and keeps an absent key absent.
defaultAdjustA :: (Map map, Applicative p) => (a -> p a) -> Key map -> map a -> p (map a)
defaultAdjustA f = alterF (traverse f)

-- NOTE(review): orphan instances (neither the class nor the type is defined
-- in this module).
instance Filtrable IntMap where
    mapMaybe = Int.mapMaybe

instance Filtrable (M.Map key) where
    mapMaybe = M.mapMaybe

-- 'Maybe' viewed as a map with at most one entry, at key '()'.
instance StaticMap Maybe where
    type Key Maybe = ()
    adjustA f () = traverse f

instance StaticMap IntMap where
    type Key IntMap = Int
    adjustA = defaultAdjustA

instance Ord key => StaticMap (M.Map key) where
    type Key (M.Map key) = key
    adjustA = defaultAdjustA

-- Nested maps, addressed by a pair (outer key, inner key).
instance (StaticMap m, StaticMap n) => StaticMap (Compose m n) where
    type Key (Compose m n) = (Key m, Key n)
    adjustA f (i, j) = fmap Compose . adjustA (adjustA f j) i . getCompose

-- Two maps side by side: 'Left' keys address the first, 'Right' keys the second.
instance (StaticMap m, StaticMap n) => StaticMap (Product m n) where
    type Key (Product m n) = Either (Key m) (Key n)
    adjustA f k (Pair a b) = case k of
        Left i -> flip Pair b <$> adjustA f i a
        Right j -> Pair a <$> adjustA f j b

instance Map Maybe where
    alterF f () = f
    mergeA f = mapMaybeWithKeyA f ∘∘ fromMaybes
    mapMaybeWithKeyA f = fmap join . traverse (f ())

instance Map IntMap where
    alterF = Int.alterF
    -- Tag each key by which side(s) it came from, then let 'f' decide.
    mergeA f = mapMaybeWithKeyA f ∘∘ Int.mergeWithKey (pure $ Just ∘∘ Both) (fmap JustLeft) (fmap JustRight)
    mapMaybeWithKeyA f = fmap catMaybes . Int.traverseWithKey f
    mapEitherWithKeyA f = fmap partitionEithers . Int.traverseWithKey f

instance Ord key => Map (M.Map key) where
    alterF = M.alterF
    mergeA f = M.mergeA
        (M.traverseMaybeMissing $ \ k a -> f k (JustLeft a))
        (M.traverseMaybeMissing $ \ k b -> f k (JustRight b))
        (M.zipWithMaybeAMatched $ \ k a b -> f k (Both a b))
    mapMaybeWithKeyA f = fmap catMaybes . M.traverseWithKey f
    mapEitherWithKeyA f = fmap partitionEithers . M.traverseWithKey f

instance (Map m, Map n) => Map (Compose m n) where
    -- NOTE(review): when the outer key @i@ is absent, @f Nothing@ is run only
    -- for its effects and its result is discarded, so this cannot insert under
    -- a new outer key (the class offers no way to build an empty inner map).
    -- When @i@ is present the outer entry is always kept, even if the inner
    -- map becomes empty.
    alterF f (i, j) = fmap Compose . alterF (maybe (Nothing <$ f Nothing) (fmap Just . alterF f j)) i . getCompose
    -- Merge the outer maps; for each outer key, copy/filter a one-sided inner
    -- map or merge two inner maps. Outer entries are always kept ('Just').
    mergeA f = fmap Compose ∘∘ compose2 (mergeA $ \ i -> fmap Just ∘ either'
        (mapMaybeWithKeyA $ \ j -> f (i, j) ∘ JustLeft)
        (mapMaybeWithKeyA $ \ j -> f (i, j) ∘ JustRight)
        (mergeA $ \ j -> f (i, j))) getCompose getCompose
    mapMaybeWithKeyA f = fmap Compose . mapMaybeWithKeyA (\ i -> fmap Just . mapMaybeWithKeyA (\ j -> f (i, j))) . getCompose

instance (Map m, Map n) => Map (Product m n) where
    alterF f k (Pair a b) = case k of
        Left i -> flip Pair b <$> alterF f i a
        Right j -> Pair a <$> alterF f j b
    mergeA f (Pair a₀ b₀) (Pair a₁ b₁) = Pair <$> mergeA (f . Left) a₀ a₁ <*> mergeA (f . Right) b₀ b₁
    mapMaybeWithKeyA f (Pair a b) = Pair <$> mapMaybeWithKeyA (f . Left) a <*> mapMaybeWithKeyA (f . Right) b

infix 9 !?

-- | Look up the value at a key. Implemented by running 'adjustA' with an
-- identity update in a writer-like tuple applicative that records the value
-- it visits.
(!?) :: StaticMap map => map a -> Key map -> Maybe a
as !? k = getLast . fst $ adjustA ((,) <$> Last . Just <*> id) k as

-- | Pure 'mapMaybeWithKeyA'.
mapMaybeWithKey :: Map map => (Key map -> a -> Maybe b) -> map a -> map b
mapMaybeWithKey f = runIdentity . mapMaybeWithKeyA (pure ∘∘ f)

-- | Pure 'mapEitherWithKeyA'.
mapEitherWithKey :: Map map => (Key map -> a -> Either b c) -> map a -> (map b, map c)
mapEitherWithKey f = runIdentity . mapEitherWithKeyA (pure ∘∘ f)

-- | Like 'adjustA', but also return the value stored at the key /before/ the
-- update ('Nothing' if the key is absent).
adjustLookupA :: (StaticMap map, Applicative p) => (a -> p a) -> Key map -> map a -> p (Maybe a, map a)
adjustLookupA f = sequenceA ∘∘ (getLast *** id <<< getCompose) ∘∘ adjustA (\ a -> Compose (pure a, f a))

-- | Pure 'mergeA'.
merge :: Map map => (Key map -> Either' a b -> Maybe c) -> map a -> map b -> map c
merge f = runIdentity ∘∘ mergeA (Identity ∘∘ f)

-- | A map whose 'Semigroup' keeps every key of either operand, combining the
-- values at shared keys with '<>'.
newtype Union map a = Union { unUnion :: map a }
    deriving (Functor, Foldable, Traversable)
    deriving newtype (Eq, Ord, Read, Show, Eq1, Ord1, Read1, Show1)

instance Filtrable map => Filtrable (Union map) where
    mapMaybe f = Union . mapMaybe f . unUnion

instance StaticMap map => StaticMap (Union map) where
    type Key (Union map) = Key map
    adjustA f k = fmap Union . adjustA f k . unUnion

instance Map map => Map (Union map) where
    alterF f k = fmap Union . alterF f k . unUnion
    mergeA f = fmap Union ∘∘ compose2 (mergeA f) unUnion unUnion
    mapMaybeWithKeyA f = fmap Union . mapMaybeWithKeyA f . unUnion
    mapEitherWithKeyA f = fmap (Union *** Union) . mapEitherWithKeyA f . unUnion

instance (Map map, Semigroup a) => Semigroup (Union map a) where
    (<>) = Union ∘∘ merge (pure $ Just ∘ either' id id (<>)) `on` unUnion

instance (Ord k, Semigroup a) => Monoid (Union (M.Map k) a) where
    mempty = Union M.empty

instance Semigroup a => Monoid (Union IntMap a) where
    mempty = Union Int.empty

-- | A map whose 'Semigroup' keeps only keys present in both operands,
-- combining their values with '<>'.
newtype Intersection map a = Intersection { unIntersection :: map a }
    deriving (Functor, Foldable, Traversable)
    deriving newtype (Eq, Ord, Read, Show, Eq1, Ord1, Read1, Show1)

instance Filtrable map => Filtrable (Intersection map) where
    mapMaybe f = Intersection . mapMaybe f . unIntersection

instance StaticMap map => StaticMap (Intersection map) where
    type Key (Intersection map) = Key map
    adjustA f k = fmap Intersection . adjustA f k . unIntersection

instance Map map => Map (Intersection map) where
    alterF f k = fmap Intersection . alterF f k . unIntersection
    mergeA f = fmap Intersection ∘∘ compose2 (mergeA f) unIntersection unIntersection
    mapMaybeWithKeyA f = fmap Intersection . mapMaybeWithKeyA f . unIntersection
    mapEitherWithKeyA f = fmap (Intersection *** Intersection) . mapEitherWithKeyA f . unIntersection

instance (Map map, Semigroup a) => Semigroup (Intersection map a) where
    (<>) = Intersection ∘∘ merge (pure $ (uncurry . liftA2) (<>) ∘ toMaybes) `on` unIntersection

-- | A map whose 'Semigroup' keeps only keys present in exactly one operand.
newtype SymmetricDifference map a = SymmetricDifference { unSymmetricDifference :: map a }
    deriving (Functor, Foldable, Traversable)
    deriving newtype (Eq, Ord, Read, Show, Eq1, Ord1, Read1, Show1)

instance Filtrable map => Filtrable (SymmetricDifference map) where
    mapMaybe f = SymmetricDifference . mapMaybe f . unSymmetricDifference

instance StaticMap map => StaticMap (SymmetricDifference map) where
    type Key (SymmetricDifference map) = Key map
    adjustA f k = fmap SymmetricDifference . adjustA f k . unSymmetricDifference

instance Map map => Map (SymmetricDifference map) where
    alterF f k = fmap SymmetricDifference . alterF f k . unSymmetricDifference
    mergeA f = fmap SymmetricDifference ∘∘ compose2 (mergeA f) unSymmetricDifference unSymmetricDifference
    mapMaybeWithKeyA f = fmap SymmetricDifference . mapMaybeWithKeyA f . unSymmetricDifference
    mapEitherWithKeyA f = fmap (SymmetricDifference *** SymmetricDifference) . mapEitherWithKeyA f . unSymmetricDifference

instance Map map => Semigroup (SymmetricDifference map a) where
    (<>) = SymmetricDifference ∘∘ merge (pure $ either' Just Just (\ _ _ -> Nothing)) `on` unSymmetricDifference

instance Ord k => Monoid (SymmetricDifference (M.Map k) a) where
    mempty = SymmetricDifference M.empty

instance Monoid (SymmetricDifference IntMap a) where
    mempty = SymmetricDifference Int.empty