module Data.Map.Class where

import Control.Applicative
import Control.Arrow
import Control.Monad
import Data.Either.Both
import Data.Filtrable
import Data.Function (on)
import Data.Functor.Classes
import Data.Functor.Compose
import Data.Functor.Identity
import Data.Functor.Product
import Data.IntMap (IntMap)
import qualified Data.IntMap as Int
import qualified Data.Map as M
import qualified Data.Map.Merge.Lazy as M
import Data.Monoid (Last (..))
import Util ((), (∘∘), compose2)

-- | Containers that support updating the value stored under a key.
class Traversable map => StaticMap map where
    -- | The type used to index into @map@.
    type Key map
    -- | Run an effectful update on the value stored under the given key and
    -- rebuild the container around the result. The signature offers no way
    -- to insert or delete an entry; the instances in this module leave the
    -- container unchanged (and run no effect) when the key is absent.
    adjustA :: Applicative p => (a -> p a) -> Key map -> map a -> p (map a)

-- | Containers whose set of keys may change: entries can be inserted,
-- deleted, merged and filtered.
class (Filtrable map, StaticMap map) => Map map where
    -- | Look at the (possibly absent) value under a key and replace it with
    -- the (possibly absent) result of the given function.
    alterF :: Functor f => (Maybe a -> f (Maybe a)) -> Key map -> map a -> f (map a)
    -- | Combine two containers key by key. The callback receives the key and
    -- an 'Either'' saying whether the value came from the left container, the
    -- right container, or both; returning 'Nothing' drops the key.
    mergeA :: Applicative p => (Key map -> Either' a b -> p (Maybe c)) -> map a -> map b -> p (map c)
    -- | Traverse with keys, dropping every entry for which the callback
    -- yields 'Nothing'.
    mapMaybeWithKeyA :: Applicative p => (Key map -> a -> p (Maybe b)) -> map a -> p (map b)
    -- | Traverse with keys, sending 'Left' results to the first container and
    -- 'Right' results to the second.
    mapEitherWithKeyA :: Applicative p => (Key map -> a -> p (Either b c)) -> map a -> p (map b, map c)
    -- Default: traverse once, keeping every result, then split the results
    -- purely. A single traversal runs the callback's effects exactly once
    -- per entry, matching the specialised 'IntMap' and 'M.Map' instances
    -- (two separate traversals would run every effect twice).
    mapEitherWithKeyA f = fmap partitionEithers . mapMaybeWithKeyA (\ k a -> Just <$> f k a)

-- | An implementation of 'adjustA' in terms of 'alterF': the update is
-- lifted over 'Maybe' with 'traverse', so an absent key stays absent.
defaultAdjustA :: (Map map, Applicative p) => (a -> p a) -> Key map -> map a -> p (map a)
defaultAdjustA = alterF . traverse

-- | Filtering an 'IntMap' delegates to the container's own 'Int.mapMaybe'.
instance Filtrable IntMap where
    mapMaybe f m = Int.mapMaybe f m

-- | Filtering a 'M.Map' delegates to the container's own 'M.mapMaybe'.
instance Filtrable (M.Map key) where
    mapMaybe f m = M.mapMaybe f m

-- | 'Maybe' viewed as a container with at most one entry, keyed by @()@.
instance StaticMap Maybe where
    type Key Maybe = ()
    adjustA f () ma = traverse f ma

-- | 'IntMap' is keyed by 'Int'; adjustment goes through 'defaultAdjustA'.
instance StaticMap IntMap where
    type Key IntMap = Int
    adjustA f k = defaultAdjustA f k

-- | 'M.Map' is keyed by its own key type; adjustment goes through
-- 'defaultAdjustA'.
instance Ord key => StaticMap (M.Map key) where
    type Key (M.Map key) = key
    adjustA f k = defaultAdjustA f k

-- | A container of containers, keyed by a pair: the outer key selects the
-- inner container, the inner key selects the value inside it.
instance (StaticMap m, StaticMap n) => StaticMap (Compose m n) where
    type Key (Compose m n) = (Key m, Key n)
    adjustA f (i, j) (Compose outer) = Compose <$> adjustA (adjustA f j) i outer

-- | A pair of containers, keyed by 'Either': 'Left' addresses the first
-- component and 'Right' the second.
instance (StaticMap m, StaticMap n) => StaticMap (Product m n) where
    type Key (Product m n) = Either (Key m) (Key n)
    adjustA f k (Pair a b) = either inFirst inSecond k
      where
        -- Update the first component, keeping the second as it was.
        inFirst  i = (\ a' -> Pair a' b) <$> adjustA f i a
        -- Update the second component, keeping the first as it was.
        inSecond j = Pair a <$> adjustA f j b

-- | 'Maybe' as a container with at most one entry under the key @()@.
instance Map Maybe where
    -- The whole container is the optional value, so altering is just
    -- applying the function.
    alterF f () ma = f ma
    -- Pair the two optional values up, then let the callback decide.
    mergeA f ma mb = mapMaybeWithKeyA f (fromMaybes ma mb)
    -- Traverse the optional value and flatten the nested 'Maybe'.
    mapMaybeWithKeyA f ma = join <$> traverse (f ()) ma

-- | 'IntMap' operations built on the container's own primitives.
instance Map IntMap where
    alterF f k m = Int.alterF f k m
    -- First tag every key purely with where its value(s) came from, then run
    -- the effectful callback over the tagged map.
    mergeA f xs ys = mapMaybeWithKeyA f tagged
      where
        tagged = Int.mergeWithKey (\ _ a b -> Just (Both a b)) (fmap JustLeft) (fmap JustRight) xs ys
    mapMaybeWithKeyA f m = catMaybes <$> Int.traverseWithKey f m
    mapEitherWithKeyA f m = partitionEithers <$> Int.traverseWithKey f m

-- | 'M.Map' operations built on the container's own primitives and the
-- merge tactics from "Data.Map.Merge.Lazy".
instance Ord key => Map (M.Map key) where
    alterF f k m = M.alterF f k m
    mergeA f xs ys = M.mergeA onlyLeft onlyRight inBoth xs ys
      where
        -- Keys present only in the left map.
        onlyLeft  = M.traverseMaybeMissing (\ k a -> f k (JustLeft a))
        -- Keys present only in the right map.
        onlyRight = M.traverseMaybeMissing (\ k b -> f k (JustRight b))
        -- Keys present in both maps.
        inBoth    = M.zipWithMaybeAMatched (\ k a b -> f k (Both a b))
    mapMaybeWithKeyA f m = catMaybes <$> M.traverseWithKey f m
    mapEitherWithKeyA f m = partitionEithers <$> M.traverseWithKey f m

-- | A container of containers, keyed by a pair of outer and inner key.
instance (Map m, Map n) => Map (Compose m n) where
    -- If the outer key is absent, the function is still run on 'Nothing' for
    -- its effects but its result is discarded: no inner container is created,
    -- so an insertion under a missing outer key changes nothing. If the outer
    -- key is present, the inner container is always kept, even when the
    -- alteration removes its entry.
    alterF f (i, j) (Compose outer) = Compose <$> alterF atOuter i outer
      where
        atOuter = maybe (Nothing <$ f Nothing) (fmap Just . alterF f j)
    -- Merge the outer containers; for each outer key, an inner container that
    -- exists on one side only is filtered through the callback, and a pair of
    -- inner containers is merged recursively. The resulting inner container
    -- is always kept ('fmap' 'Just').
    mergeA f (Compose xs) (Compose ys) = Compose <$> mergeA atOuter xs ys
      where
        atOuter i = fmap Just . either'
            (mapMaybeWithKeyA $ \ j -> f (i, j) . JustLeft)
            (mapMaybeWithKeyA $ \ j -> f (i, j) . JustRight)
            (mergeA $ \ j -> f (i, j))
    -- Filter every inner container; the outer entries themselves are always
    -- kept.
    mapMaybeWithKeyA f (Compose outer) =
        Compose <$> mapMaybeWithKeyA (\ i -> fmap Just . mapMaybeWithKeyA (\ j -> f (i, j))) outer

-- | A pair of containers keyed by 'Either'; every operation is routed to the
-- component the key addresses, or applied to both components in turn.
instance (Map m, Map n) => Map (Product m n) where
    alterF f k (Pair a b) = either inFirst inSecond k
      where
        -- Alter the first component, keeping the second as it was.
        inFirst  i = (\ a' -> Pair a' b) <$> alterF f i a
        -- Alter the second component, keeping the first as it was.
        inSecond j = Pair a <$> alterF f j b
    mergeA f (Pair xa xb) (Pair ya yb) =
        Pair <$> mergeA (\ i -> f (Left i)) xa ya
             <*> mergeA (\ j -> f (Right j)) xb yb
    mapMaybeWithKeyA f (Pair a b) =
        Pair <$> mapMaybeWithKeyA (\ i -> f (Left i)) a
             <*> mapMaybeWithKeyA (\ j -> f (Right j)) b

infix 9 !?
-- | Look up the value stored under a key, if any.
--
-- Implemented with 'adjustA' in the writer-like applicative @(,) ('Last' a)@:
-- the visited value is recorded and put back unchanged.
(!?) :: StaticMap map => map a -> Key map -> Maybe a
m !? k = getLast (fst (adjustA (\ a -> (Last (Just a), a)) k m))

-- | Pure version of 'mapMaybeWithKeyA': transform entries with access to
-- their keys, dropping those mapped to 'Nothing'.
mapMaybeWithKey :: Map map => (Key map -> a -> Maybe b) -> map a -> map b
mapMaybeWithKey f m = runIdentity (mapMaybeWithKeyA (\ k a -> Identity (f k a)) m)

-- | Pure version of 'mapEitherWithKeyA': split entries into two containers
-- according to whether the function returns 'Left' or 'Right'.
mapEitherWithKey :: Map map => (Key map -> a -> Either b c) -> map a -> (map b, map c)
mapEitherWithKey f m = runIdentity (mapEitherWithKeyA (\ k a -> Identity (f k a)) m)

-- | Like 'adjustA', but also return the value that was stored under the key
-- before the update ('Nothing' if the update function was never called).
adjustLookupA :: (StaticMap map, Applicative p) => (a -> p a) -> Key map -> map a -> p (Maybe a, map a)
adjustLookupA f k m =
    -- Run 'adjustA' in the composition of a 'Last' writer (recording the old
    -- value) with the caller's applicative (performing the update).
    let (seen, updated) = getCompose (adjustA (\ a -> Compose (Last (Just a), f a)) k m)
    in  (,) (getLast seen) <$> updated

-- | Pure version of 'mergeA': combine two containers key by key, dropping
-- the keys for which the function returns 'Nothing'.
merge :: Map map => (Key map -> Either' a b -> Maybe c) -> map a -> map b -> map c
merge f xs ys = runIdentity (mergeA (\ k ab -> Identity (f k ab)) xs ys)

-- | Wrapper around a container selecting a union-style 'Semigroup': the
-- instance in this module merges with @either' id id (<>)@, i.e. values found
-- on one side are kept and values found on both sides are combined.
-- Comparison, reading and showing are inherited from the wrapped container
-- via @deriving newtype@.
newtype Union map a = Union { unUnion :: map a }
  deriving (Functor, Foldable, Traversable)
  deriving newtype (Eq, Ord, Read, Show, Eq1, Ord1, Read1, Show1)

-- | Filtering is delegated to the wrapped container.
instance Filtrable map => Filtrable (Union map) where
    mapMaybe f (Union m) = Union (mapMaybe f m)

-- | Keys and adjustment are those of the wrapped container.
instance StaticMap map => StaticMap (Union map) where
    type Key (Union map) = Key map
    adjustA f k (Union m) = Union <$> adjustA f k m

-- | Every operation unwraps, delegates to the wrapped container and rewraps.
instance Map map => Map (Union map) where
    alterF f k (Union m) = Union <$> alterF f k m
    mergeA f (Union xs) (Union ys) = Union <$> mergeA f xs ys
    mapMaybeWithKeyA f (Union m) = Union <$> mapMaybeWithKeyA f m

-- | Union of two containers: a value found on one side only is kept as is,
-- values found on both sides are combined with the element '<>'.
instance (Map map, Semigroup a) => Semigroup (Union map a) where
    -- The callback ignores the key and always answers 'Just', so no key is
    -- ever dropped.
    Union xs <> Union ys =
        Union (runIdentity (mergeA (\ _ -> Identity . Just . either' id id (<>)) xs ys))

-- | The empty map is the identity of union.
instance (Ord k, Semigroup a) => Monoid (Union (M.Map k) a) where
    mempty = Union mempty

-- | The empty map is the identity of union.
instance Semigroup a => Monoid (Union IntMap a) where
    mempty = Union mempty

-- | Wrapper around a container selecting an intersection-style 'Semigroup'
-- (see the instance in this module, which combines the two sides through
-- @liftA2 (<>)@ over 'Maybe'). Comparison, reading and showing are inherited
-- from the wrapped container via @deriving newtype@.
newtype Intersection map a = Intersection { unIntersection :: map a }
  deriving (Functor, Foldable, Traversable)
  deriving newtype (Eq, Ord, Read, Show, Eq1, Ord1, Read1, Show1)

-- | Filtering is delegated to the wrapped container.
instance Filtrable map => Filtrable (Intersection map) where
    mapMaybe f (Intersection m) = Intersection (mapMaybe f m)

-- | Keys and adjustment are those of the wrapped container.
instance StaticMap map => StaticMap (Intersection map) where
    type Key (Intersection map) = Key map
    adjustA f k (Intersection m) = Intersection <$> adjustA f k m

-- | Every operation unwraps, delegates to the wrapped container and rewraps.
instance Map map => Map (Intersection map) where
    alterF f k (Intersection m) = Intersection <$> alterF f k m
    mergeA f (Intersection xs) (Intersection ys) = Intersection <$> mergeA f xs ys
    mapMaybeWithKeyA f (Intersection m) = Intersection <$> mapMaybeWithKeyA f m

-- | Intersection of two containers: only keys with a value on both sides
-- survive, and those values are combined with the element '<>'.
instance (Map map, Semigroup a) => Semigroup (Intersection map a) where
    -- 'toMaybes' splits the tagged value into a pair of optional sides;
    -- 'liftA2' over 'Maybe' then yields 'Nothing' (dropping the key) unless
    -- both sides are present.
    Intersection xs <> Intersection ys =
        Intersection (merge (\ _ -> uncurry (liftA2 (<>)) . toMaybes) xs ys)

-- | Wrapper around a container selecting a symmetric-difference 'Semigroup':
-- the instance in this module keeps values found on exactly one side and
-- drops keys found on both. Comparison, reading and showing are inherited
-- from the wrapped container via @deriving newtype@.
newtype SymmetricDifference map a = SymmetricDifference { unSymmetricDifference :: map a }
  deriving (Functor, Foldable, Traversable)
  deriving newtype (Eq, Ord, Read, Show, Eq1, Ord1, Read1, Show1)

-- | Filtering is delegated to the wrapped container.
instance Filtrable map => Filtrable (SymmetricDifference map) where
    mapMaybe f (SymmetricDifference m) = SymmetricDifference (mapMaybe f m)

-- | Keys and adjustment are those of the wrapped container.
instance StaticMap map => StaticMap (SymmetricDifference map) where
    type Key (SymmetricDifference map) = Key map
    adjustA f k (SymmetricDifference m) = SymmetricDifference <$> adjustA f k m

-- | Every operation unwraps, delegates to the wrapped container and rewraps.
instance Map map => Map (SymmetricDifference map) where
    alterF f k (SymmetricDifference m) = SymmetricDifference <$> alterF f k m
    mergeA f (SymmetricDifference xs) (SymmetricDifference ys) = SymmetricDifference <$> mergeA f xs ys
    mapMaybeWithKeyA f (SymmetricDifference m) = SymmetricDifference <$> mapMaybeWithKeyA f m

-- | Symmetric difference of two containers: a value found on exactly one
-- side is kept, a key found on both sides is dropped.
instance Map map => Semigroup (SymmetricDifference map a) where
    SymmetricDifference xs <> SymmetricDifference ys =
        SymmetricDifference (merge (\ _ -> either' Just Just (\ _ _ -> Nothing)) xs ys)

-- | The empty map is the identity of symmetric difference.
instance Ord k => Monoid (SymmetricDifference (M.Map k) a) where
    mempty = SymmetricDifference mempty

-- | The empty map is the identity of symmetric difference.
instance Monoid (SymmetricDifference IntMap a) where
    mempty = SymmetricDifference mempty