{-# LANGUAGE DeriveAnyClass     #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric      #-}
{-# LANGUAGE DeriveLift         #-}
{-# LANGUAGE ExplicitForAll     #-}
{-# LANGUAGE TypeFamilies       #-}
-- | An ordered map: a 'Data.Map.Map' together with a key ordering.
--
-- A 'Map' either remembers an explicit order for its keys (which 'toList',
-- 'keys', 'elems' and the 'Foldable' instance follow) or is marked as sorted,
-- in which case keys are reported in ascending order. 'sort', 'fromMap' and
-- the @unordered*@ constructors produce sorted 'Map's.
module Dhall.Map
    ( -- * Type
      Map

      -- * Construction
    , empty
    , singleton
    , fromList
    , fromListWithKey
    , fromMap

      -- * Constructing unordered 'Map's
    , unorderedSingleton
    , unorderedFromList

      -- * Sorting
    , sort
    , isSorted

      -- * Insertion
    , insert
    , insertWith

      -- * Deletion/Update
    , delete
    , filter
    , restrictKeys
    , withoutKeys
    , mapMaybe

      -- * Query
    , lookup
    , member
    , uncons
    , size

      -- * Combine
    , union
    , unionWith
    , outerJoin
    , intersection
    , intersectionWith
    , difference

      -- * Traversals
    , mapWithKey
    , traverseWithKey
    , unorderedTraverseWithKey
    , unorderedTraverseWithKey_
    , foldMapWithKey

      -- * Conversions
    , toList
    , toAscList
    , toMap
    , keys
    , keysSet
    , elems
    ) where
import Control.Applicative        ((<|>))
import Control.DeepSeq            (NFData)
import Data.Data                  (Data)
import GHC.Generics               (Generic)
import Instances.TH.Lift          ()
import Language.Haskell.TH.Syntax (Lift)
import Prelude                    hiding (filter, lookup)
import qualified Data.List
import qualified Data.Map
import qualified Data.Set
import qualified GHC.Exts
import qualified Prelude
-- | A 'Data.Map.Map' paired with a record of how its keys are ordered.
--
-- Invariant maintained by this module: when the ordering is @Original ks@,
-- @ks@ lists every key of the inner map exactly once and nothing else.
-- 'toList', 'elems' and 'uncons' rely on this, because they index the inner
-- map with the partial 'Data.Map.!'.
data Map k v = Map (Data.Map.Map k v) (Keys k)
    deriving (Data, Generic, Lift, NFData)
-- | How the keys of a 'Map' are ordered.
data Keys a
    = Sorted
    -- ^ Keys are reported in ascending order, i.e. the order of the inner
    -- 'Data.Map.Map'
    | Original [a]
    -- ^ Keys are reported in the order of this list
    deriving (Data, Generic, Lift, NFData)
-- | Two 'Map's are equal when they have the same number of entries and
-- 'toList' yields the same key-value pairs in the same order, so key order
-- is significant.
instance (Ord k, Eq v) => Eq (Map k v) where
  left == right = sameSize && samePairs
    where
      -- Cheap size check first, so differently-sized maps skip the list walk
      sameSize  = Data.Map.size (toMap left) == Data.Map.size (toMap right)
      samePairs = toList left == toList right
  {-# INLINABLE (==) #-}
-- | 'Map's are ordered lexicographically by the key-value pairs that 'toList'
-- produces, so key order matters here just as it does for equality.
instance (Ord k, Ord v) => Ord (Map k v) where
  compare left right = toList left `compare` toList right
  {-# INLINABLE compare #-}
-- | Applies the function to every value; the key ordering is kept unchanged.
instance Functor (Map k) where
  fmap f (Map inner order) = Map (f <$> inner) order
  {-# INLINABLE fmap #-}
-- | Folds over the values in the key order of the 'Map'. 'length' and 'null'
-- are answered directly from the inner 'Data.Map.Map'.
instance Ord k => Foldable (Map k) where
  foldr step acc m = case m of
    Map inner Sorted -> foldr step acc inner
    _                -> foldr step acc (elems m)
  {-# INLINABLE foldr #-}
  length = size
  {-# INLINABLE length #-}
  null (Map inner _) = Data.Map.null inner
  {-# INLINABLE null #-}
-- | Traverses the values in the key order of the 'Map', ignoring the keys
-- themselves.
instance Ord k => Traversable (Map k) where
  traverse f = traverseWithKey (const f)
  {-# INLINABLE traverse #-}
-- | Combines two 'Map's with the left-biased 'union'.
instance Ord k => Semigroup (Map k v) where
    m1 <> m2 = union m1 m2
    {-# INLINABLE (<>) #-}
-- | The identity is the empty 'Map'.
instance Ord k => Monoid (Map k v) where
    mempty = Map Data.Map.empty noKeys
      where
        -- An explicit (empty) key order rather than @Sorted@, so that a union
        -- with another explicitly-ordered 'Map' keeps that order
        noKeys = Original []
    {-# INLINABLE mempty #-}
-- | Renders as @fromList [(k, v), ...]@ with the pairs in key order, wrapped
-- in parentheses when shown at application precedence or above.
instance (Show k, Show v, Ord k) => Show (Map k v) where
    showsPrec precedence m = showParen (precedence > 10) body
      where
        body = showString "fromList " . showsPrec 11 (toList m)
-- | List conversions delegate to this module's 'toList' and 'fromList', so
-- they respect the key order of the 'Map'.
instance Ord k => GHC.Exts.IsList (Map k v) where
    type Item (Map k v) = (k, v)
    toList = Dhall.Map.toList
    fromList = Dhall.Map.fromList
-- | The 'Map' with no entries, carrying an explicit (empty) key order.
empty :: Ord k => Map k v
empty = Map Data.Map.empty (Original [])
-- | A one-entry 'Map' whose explicit key order consists of that single key.
singleton :: k -> v -> Map k v
singleton k v = Map (Data.Map.singleton k v) (Original [k])
{-# INLINABLE singleton #-}
-- | Build a 'Map' from key-value pairs.
--
-- Keys are ordered by their first occurrence in the input. When a key repeats,
-- the stored value is the one 'Data.Map.fromList' keeps (the last one).
fromList :: Ord k => [(k, v)] -> Map k v
fromList kvs = Map (Data.Map.fromList kvs) order
  where
    order = Original (nubOrd [ k | (k, _) <- kvs ])
{-# INLINABLE fromList #-}
-- | Like 'fromList', but values for repeated keys are merged with the given
-- function, which 'Data.Map.fromListWithKey' calls as @f key new old@.
-- Keys are ordered by their first occurrence in the input.
fromListWithKey :: Ord k => (k -> v -> v -> v) -> [(k, v)] -> Map k v
fromListWithKey f kvs = Map merged (Original firstSeen)
  where
    merged    = Data.Map.fromListWithKey f kvs
    firstSeen = nubOrd (fst <$> kvs)
{-# INLINABLE fromListWithKey #-}
-- | Wrap a 'Data.Map.Map' as a 'Map' marked as sorted, so its keys are
-- reported in ascending order (the order of the 'Data.Map.Map' itself).
fromMap :: Data.Map.Map k v -> Map k v
fromMap m = Map m Sorted
-- Every other exported function in this module is INLINABLE; keep this one
-- consistent so it can be inlined and specialised across modules too.
{-# INLINABLE fromMap #-}
-- | Remove duplicates from a list, keeping the first occurrence of each
-- element and the relative order of the survivors. Elements already seen are
-- tracked in a 'Data.Set.Set', and the output is produced lazily.
nubOrd :: Ord k => [k] -> [k]
nubOrd = loop Data.Set.empty
  where
    loop _ [] = []
    loop seen (x:rest)
        | Data.Set.notMember x seen = x : loop (Data.Set.insert x seen) rest
        | otherwise                 = loop seen rest
{-# INLINABLE nubOrd #-}
-- | Like 'singleton', but the result is marked as sorted rather than carrying
-- an explicit key order.
unorderedSingleton :: k -> v -> Map k v
unorderedSingleton k v = Map (Data.Map.singleton k v) Sorted
{-# INLINABLE unorderedSingleton #-}
-- | Like 'fromList', but the input order is discarded: the result is marked as
-- sorted. Repeated keys keep the value 'Data.Map.fromList' keeps (the last).
unorderedFromList :: Ord k => [(k, v)] -> Map k v
unorderedFromList kvs = Map (Data.Map.fromList kvs) Sorted
{-# INLINABLE unorderedFromList #-}
-- | Drop any explicit key order; the result reports keys in ascending order.
sort :: Map k v -> Map k v
sort (Map inner _discardedOrder) = Map inner Sorted
{-# INLINABLE sort #-}
-- | Whether the key order of the 'Map' coincides with ascending key order.
-- A 'Map' marked as sorted always qualifies. Otherwise the explicit order is
-- compared with the keys of the inner 'Data.Map.Map'.
isSorted :: Eq k => Map k v -> Bool
isSorted (Map inner order) = case order of
    Sorted      -> True
    Original ks -> Data.Map.keys inner == ks
{-# INLINABLE isSorted #-}
-- | Insert a key-value pair, replacing any value already stored under the key.
--
-- With an explicit key order, a key that was absent is placed at the front of
-- that order, and a key that was present keeps its position.
insert :: Ord k => k -> v -> Map k v -> Map k v
insert k v (Map inner order) = case order of
    Sorted      -> Map (Data.Map.insert k v inner) Sorted
    Original ks ->
        let (previous, inner') =
                Data.Map.insertLookupWithKey (\_ new _ -> new) k v inner
            -- Only a brand-new key extends the order
            ks' = maybe (k : ks) (const ks) previous
        in  Map inner' (Original ks')
{-# INLINABLE insert #-}
-- | Insert a key-value pair. If the key is already present, the stored value
-- becomes @f new old@.
--
-- With an explicit key order, a key that was absent is placed at the front of
-- that order, and a key that was present keeps its position.
insertWith :: Ord k => (v -> v -> v) -> k -> v -> Map k v -> Map k v
insertWith f k v (Map inner order) = case order of
    Sorted      -> Map (Data.Map.insertWith f k v inner) Sorted
    Original ks ->
        let (previous, inner') =
                Data.Map.insertLookupWithKey (const f) k v inner
            -- Only a brand-new key extends the order
            ks' = maybe (k : ks) (const ks) previous
        in  Map inner' (Original ks')
{-# INLINABLE insertWith #-}
-- | Remove a key and its value. Deleting an absent key leaves the entries
-- unchanged.
delete :: Ord k => k -> Map k v -> Map k v
delete k (Map inner order) = Map (Data.Map.delete k inner) (dropKey order)
  where
    dropKey Sorted        = Sorted
    dropKey (Original ks) = Original (Data.List.delete k ks)
{-# INLINABLE delete #-}
-- | Keep only the entries whose value satisfies the predicate. Surviving keys
-- keep their relative order.
filter :: Ord k => (a -> Bool) -> Map k a -> Map k a
filter predicate (Map inner order) = Map kept (filterKeys survives order)
  where
    kept       = Data.Map.filter predicate inner
    survives k = Data.Map.member k kept
{-# INLINABLE filter #-}
-- | Keep only the entries whose key belongs to the given set, preserving the
-- relative key order.
restrictKeys :: Ord k => Map k a -> Data.Set.Set k -> Map k a
restrictKeys (Map inner order) wanted =
    Map (Data.Map.restrictKeys inner wanted)
        (filterKeys (`Data.Set.member` wanted) order)
{-# INLINABLE restrictKeys #-}
-- | Drop every entry whose key belongs to the given set, preserving the
-- relative order of the remaining keys.
withoutKeys :: Ord k => Map k a -> Data.Set.Set k -> Map k a
withoutKeys (Map inner order) unwanted =
    Map (Data.Map.withoutKeys inner unwanted)
        (filterKeys (`Data.Set.notMember` unwanted) order)
{-# INLINABLE withoutKeys #-}
-- | Apply a function to every value, dropping the entries for which it returns
-- 'Nothing'. Surviving keys keep their relative order.
mapMaybe :: Ord k => (a -> Maybe b) -> Map k a -> Map k b
mapMaybe f (Map inner order) =
    Map kept (filterKeys (`Data.Map.member` kept) order)
  where
    kept = Data.Map.mapMaybe f inner
{-# INLINABLE mapMaybe #-}
-- | The value stored under a key, if any.
lookup :: Ord k => k -> Map k v -> Maybe v
lookup k = Data.Map.lookup k . toMap
{-# INLINABLE lookup #-}
-- | Split off the first entry in key order. Returns its key, its value and the
-- remaining 'Map', or 'Nothing' when the 'Map' is empty.
uncons :: Ord k => Map k v -> Maybe (k, v, Map k v)
uncons (Map inner order) = case order of
    Sorted -> do
        ((k, v), rest) <- Data.Map.minViewWithKey inner
        pure (k, v, Map rest Sorted)
    Original [] ->
        Nothing
    Original (k : ks) ->
        Just (k, inner Data.Map.! k, Map (Data.Map.delete k inner) (Original ks))
{-# INLINABLE uncons #-}
-- | Whether the key is present.
member :: Ord k => k -> Map k v -> Bool
member k = Data.Map.member k . toMap
{-# INLINABLE member #-}
-- | Number of entries, as reported by the inner 'Data.Map.Map'.
size :: Map k v -> Int
size = Data.Map.size . toMap
{-# INLINABLE size #-}
-- | Left-biased union: on a key collision the left value is kept.
--
-- When both arguments carry an explicit key order, the result lists the left
-- keys first, followed by the right-only keys in their original order. If
-- either argument is marked as sorted, so is the result.
union :: Ord k => Map k v -> Map k v -> Map k v
union (Map mL ksL) (Map mR ksR) = Map (Data.Map.union mL mR) (mergeKeys ksL ksR)
  where
    mergeKeys (Original l) (Original r) =
        Original (l ++ Prelude.filter (`Data.Map.notMember` mL) r)
    mergeKeys _ _ = Sorted
{-# INLINABLE union #-}
-- | Union where colliding values are merged as @combine left right@.
--
-- When both arguments carry an explicit key order, the result lists the left
-- keys first, followed by the right-only keys in their original order. If
-- either argument is marked as sorted, so is the result.
unionWith :: Ord k => (v -> v -> v) -> Map k v -> Map k v -> Map k v
unionWith combine (Map mL ksL) (Map mR ksR) =
    Map (Data.Map.unionWith combine mL mR) (mergeKeys ksL ksR)
  where
    mergeKeys (Original l) (Original r) =
        Original (l ++ Prelude.filter (`Data.Map.notMember` mL) r)
    mergeKeys _ _ = Sorted
{-# INLINABLE unionWith #-}
-- | Full outer join of two 'Map's.
--
-- Left-only keys are mapped with @fa@, right-only keys with @fb@, and keys
-- present on both sides with @fab@. When both arguments carry an explicit key
-- order, the left keys come first, followed by the right-only keys. If either
-- argument is marked as sorted, so is the result.
outerJoin
    :: Ord k
    => (a -> c)
    -> (b -> c)
    -> (k -> a -> b -> c)
    -> Map k a
    -> Map k b
    -> Map k c
outerJoin fa fb fab (Map ma ksA) (Map mb ksB) = Map joined (joinKeys ksA ksB)
  where
    joined = Data.Map.mergeWithKey inBoth (fmap fa) (fmap fb) ma mb

    -- Keys present on both sides are always kept
    inBoth k a b = Just (fab k a b)

    joinKeys (Original l) (Original r) =
        Original (l ++ Prelude.filter (`Data.Map.notMember` ma) r)
    joinKeys _ _ = Sorted
{-# INLINABLE outerJoin #-}
-- | Entries of the left 'Map' whose keys also occur in the right one. Values
-- and key order come from the left 'Map'.
intersection :: Ord k => Map k a -> Map k b -> Map k a
intersection (Map mL ksL) (Map mR _) =
    Map common (filterKeys (`Data.Map.member` common) ksL)
  where
    common = Data.Map.intersection mL mR
{-# INLINABLE intersection #-}
-- | Keys present in both 'Map's, with values merged as @combine left right@.
-- The key order comes from the left 'Map'.
intersectionWith :: Ord k => (a -> b -> c) -> Map k a -> Map k b -> Map k c
intersectionWith combine (Map mL ksL) (Map mR _) =
    Map common (filterKeys (`Data.Map.member` common) ksL)
  where
    common = Data.Map.intersectionWith combine mL mR
{-# INLINABLE intersectionWith #-}
-- | Entries of the left 'Map' whose keys do not occur in the right one. The
-- left key order is preserved.
difference :: Ord k => Map k a -> Map k b -> Map k a
difference (Map mL ksL) (Map mR _) =
    Map remaining (filterKeys (`Data.Map.notMember` mR) ksL)
  where
    remaining = Data.Map.difference mL mR
{-# INLINABLE difference #-}
-- | Map every key-value pair to a monoid and combine the results in key order.
foldMapWithKey :: (Monoid m, Ord k) => (k -> a -> m) -> Map k a -> m
foldMapWithKey f kvMap = case kvMap of
    Map inner Sorted -> Data.Map.foldMapWithKey f inner
    _                -> foldMap (uncurry f) (toList kvMap)
{-# INLINABLE foldMapWithKey #-}
-- | Transform every value with access to its key. The key order is unchanged.
mapWithKey :: (k -> a -> b) -> Map k a -> Map k b
mapWithKey f (Map inner order) = Map (Data.Map.mapWithKey f inner) order
{-# INLINABLE mapWithKey #-}
-- | Run an effect for every key-value pair in the key order of the 'Map', and
-- collect the results into a 'Map' with the same key ordering.
traverseWithKey
    :: Ord k => Applicative f => (k -> a -> f b) -> Map k a -> f (Map k b)
traverseWithKey f (Map m Sorted) =
    fmap (\m' -> Map m' Sorted) (Data.Map.traverseWithKey f m)
traverseWithKey f m@(Map _ ks) =
    -- The explicit key order is already duplicate-free, so reuse it rather
    -- than rebuilding via 'fromList', which would run 'nubOrd' over it again.
    fmap (\kvs -> Map (Data.Map.fromList kvs) ks) (traverse f' (toList m))
  where
    f' (k, a) = fmap ((,) k) (f k a)
{-# INLINABLE traverseWithKey #-}
-- | Like 'traverseWithKey', but effects run in the traversal order of the
-- inner 'Data.Map.Map' rather than the key order of the 'Map'. The result
-- keeps the key ordering of the input.
unorderedTraverseWithKey
    :: Ord k => Applicative f => (k -> a -> f b) -> Map k a -> f (Map k b)
unorderedTraverseWithKey f (Map inner order) =
    flip Map order <$> Data.Map.traverseWithKey f inner
{-# INLINABLE unorderedTraverseWithKey #-}
-- | Run an effect for every entry, in the fold order of the inner
-- 'Data.Map.Map', discarding the results.
unorderedTraverseWithKey_
    :: Ord k => Applicative f => (k -> a -> f ()) -> Map k a -> f ()
unorderedTraverseWithKey_ f (Map inner _) =
    Data.Map.foldlWithKey' step (pure ()) inner
  where
    step acc k v = acc *> f k v
{-# INLINABLE unorderedTraverseWithKey_ #-}
-- | Key-value pairs in the key order of the 'Map'.
toList :: Ord k => Map k v -> [(k, v)]
toList (Map inner order) = case order of
    Sorted      -> Data.Map.toList inner
    Original ks -> [ (k, inner Data.Map.! k) | k <- ks ]
{-# INLINABLE toList #-}
-- | Key-value pairs in ascending key order, ignoring any explicit key order.
toAscList :: Map k v -> [(k, v)]
toAscList = Data.Map.toAscList . toMap
{-# INLINABLE toAscList #-}
-- | The underlying 'Data.Map.Map', with the key ordering discarded.
toMap :: Map k v -> Data.Map.Map k v
toMap kvMap = case kvMap of
    Map inner _ -> inner
{-# INLINABLE toMap #-}
-- | Keys in the key order of the 'Map'.
keys :: Map k v -> [k]
keys (Map inner order) = case order of
    Sorted      -> Data.Map.keys inner
    Original ks -> ks
{-# INLINABLE keys #-}
-- | Values in the key order of the 'Map'.
elems :: Ord k => Map k v -> [v]
elems (Map inner order) = case order of
    Sorted      -> Data.Map.elems inner
    Original ks -> map (inner Data.Map.!) ks
{-# INLINABLE elems #-}
-- | The set of keys of the inner 'Data.Map.Map'.
keysSet :: Map k v -> Data.Set.Set k
keysSet = Data.Map.keysSet . toMap
{-# INLINABLE keysSet #-}
-- | Filter a key ordering with a predicate. A sorted ordering is returned
-- unchanged because it lists no keys of its own.
filterKeys :: (a -> Bool) -> Keys a -> Keys a
filterKeys keep order = case order of
    Sorted      -> Sorted
    Original ks -> Original [ k | k <- ks, keep k ]
{-# INLINABLE filterKeys #-}