{-# LANGUAGE MultiParamTypeClasses #-}
module Data.Containers(
  DataMap(..),
  Set,Map,ascList,

  member,delete,touch,insert,singleton,fromList
  )
  where

import SimpleH
import qualified Data.Set as S
import qualified Data.Map as M
import Data.Map (Map)
import Data.Set (Set)

-- | Finite key/value containers accessed through a per-key lens.
--
-- The container type @c@ determines both its key type and its value type.
-- In the instances of this module, reading a slot yields 'Nothing' for an
-- absent key; writing 'Nothing' removes the key and writing @'Just' v@
-- binds it to @v@.
class Monoid c => DataMap c key val | c -> key val where
  -- | Lens onto the optional value stored under @key@.
  at :: key -> Lens' c (Maybe val)
-- | Containers whose elements can be tagged with an index of type @idx@;
-- the index type is determined by the container type.
class Indexed t idx | t -> idx where
  -- | Pair every element with its index.
  withKeys :: t x -> t (idx,x)
-- | Test whether a key is bound.
--
-- The key's optional value is read through 'at' and converted to a 'Bool'
-- by @yb _maybe@ (presumably 'True' exactly when the slot is 'Just' —
-- confirm against SimpleH).
member :: DataMap m k a => k -> m -> Bool
member key container = yb _maybe (by (at key) container)
-- | Remove a key (and its value) by writing 'Nothing' into its slot.
delete :: DataMap m k a => k -> m -> m
delete key container = (at key %- Nothing) container
-- | Bind a key to a value by writing @'Just' val@ into its slot,
-- replacing whatever was stored there before.
insert :: DataMap m k a => k -> a -> m -> m
insert key val container = (at key %- Just val) container
-- | Bind a key to 'zero'.
--
-- NOTE(review): this is 'insert' with 'zero', so it overwrites any value
-- already stored under the key. If "touch" is meant to leave existing
-- bindings alone (create-if-missing), it should only insert when the key
-- is absent — confirm against callers.
touch :: (Monoid a, DataMap m k a) => k -> m -> m
touch key container = insert key zero container
-- | A container holding exactly one binding, obtained by inserting it into
-- the empty container 'zero'.
singleton :: DataMap m k a => k -> a -> m
singleton key val = insert key val zero
-- | Build a container from key/value pairs by composing one 'insert' per
-- pair and applying the result to the empty container 'zero'.
--
-- NOTE(review): for duplicate keys, which pair wins depends on the order in
-- which 'compose' applies the functions. If it applies the right-most one
-- first (like ordinary composition), the FIRST pair for a key wins, which
-- is the opposite of 'M.fromList'. Confirm.
fromList :: DataMap m k a => [(k,a)] -> m
fromList pairs = compose [insert k v | (k,v) <- pairs] zero

-- | A set viewed as a map whose values are 'Void': a key is only ever
-- present or absent. The inner lens works on membership ('Bool'), and
-- '_maybe' adapts it to the @Maybe Void@ slot the class requires.
instance Ord a => DataMap (Set a) a Void where
  at k = lens (S.member k) setMember . _maybe
    -- choose between inserting and deleting k according to the Bool
    where setMember s present = bool (S.insert k) (S.delete k) present s
-- | Reading a slot uses 'M.lookup'. Writing uses 'M.alter', so writing
-- 'Nothing' removes the key and writing @'Just' v@ stores @v@.
instance Ord k => DataMap (Map k a) k a where
  at k = lens (M.lookup k) store
    where store m new = M.alter (const new) k m
  
-- | Sets combine by union; the empty set is the identity.
instance Ord a => Semigroup (Set a) where
  s + t = S.union s t
instance Ord a => Monoid (Set a) where
  zero = S.empty
-- | Sets with 'S.intersection' as multiplication and the singleton
-- @{zero}@ as 'one'.
--
-- NOTE(review): @{zero}@ is not an identity for 'S.intersection':
-- @one * s@ keeps at most the element @zero@ rather than all of @s@. The
-- multiplicative-identity law therefore does not hold. The @Monoid a@
-- constraint and @one = {zero}@ would fit an element-wise-sum product
-- instead; confirm which was intended before relying on ring laws.
instance (Ord a,Monoid a) => Ring (Set a) where
  one = singleton zero zero
  (*) = S.intersection
-- | Maps over elements with 'S.mapMonotonic', which neither re-sorts nor
-- removes duplicates. The result is only a valid 'Set' when the mapped
-- function is strictly monotonic. This was presumably chosen because
-- 'map' cannot demand @Ord@ on the result element type.
instance Functor Set where
  map f s = S.mapMonotonic f s
-- | Combines all elements with '+' in a right fold starting from 'zero'.
instance Foldable Set where
  fold s = S.foldr (+) zero s

-- | Maps combine by 'M.union', which is left-biased: on a key collision the
-- left-hand value is kept. The empty map is the identity.
instance Ord k => Semigroup (Map k a) where
  m + n = M.union m n
instance Ord k => Monoid (Map k a) where
  zero = M.empty
-- | Maps over the values, leaving the keys untouched.
instance Functor (Map k) where
  map f m = M.map f m
-- | Combines all values with '+' in a right fold starting from 'zero'.
instance Foldable (Map k) where
  fold m = M.foldr (+) zero m
-- | Sequences the values by going through the ascending association list
-- ('ascList'), presumably viewed as @Compose [] ((,) k)@ via '_Compose'.
-- The map is then rebuilt with 'M.fromAscList' (inside 'ascList'); the
-- traversal keeps the list order, so the ascending-keys precondition still
-- holds. The 'Eq' constraint is required by 'ascList'.
instance Eq k => Traversable (Map k) where sequence = (ascList._Compose) sequence
-- | Pairs every value with its own key.
instance Indexed (Map k) k where
  withKeys m = M.mapWithKey (\key val -> (key,val)) m

-- | Isomorphism between a 'Map' and its association list in ascending key
-- order.
--
-- Reading goes through 'M.toAscList'; writing goes through 'M.fromAscList'.
-- 'M.fromAscList' does NOT check its precondition: the list written back
-- must be sorted by key (with equal keys adjacent), otherwise the resulting
-- map is malformed.
--
-- Only the rebuilt key type @k'@ needs 'Eq' (for 'M.fromAscList');
-- 'M.toAscList' places no constraint on @k@, so no 'Eq' is demanded there.
ascList :: Eq k' => Iso [(k,a)] [(k',a')] (Map k a) (Map k' a')
ascList = iso M.toAscList M.fromAscList

-- | A pair of maps meant to be mutually inverse: the first sends keys of
-- type @a@ to values of type @b@, the second sends them back.
--
-- NOTE(review): the derived 'Semigroup' / 'Monoid' instances come from the
-- underlying pair, presumably combining the two maps independently (each
-- with its own 'Map' union). Combining two bimaps that bind the same key
-- or value differently can then leave the two directions inconsistent.
-- Confirm whether callers rely on '+' for 'Bimap'.
newtype Bimap a b = Bimap (Map a b,Map b a)
                  deriving (Semigroup,Monoid)
-- | View a @Bimap a b@ as a @Bimap b a@ by swapping its two maps. Both
-- directions of the iso are the same component swap.
_inverse :: Iso (Bimap a b) (Bimap c d) (Bimap b a) (Bimap d c)
_inverse = iso flipPair flipPair
  where flipPair :: Bimap x y -> Bimap y x
        flipPair (Bimap (fwd,bwd)) = Bimap (bwd,fwd)

-- | A 'Bimap' indexed by its first (forward) key type.
--
-- Reading looks the key up in the forward map. Writing @b'@ under key @a@
-- updates both maps so they stay mutually inverse:
--
--   * forward map: unbind whichever key was previously bound to @b'@, then
--     bind (or unbind) @a@;
--   * backward map: unbind the value previously bound to @a@, then bind
--     @b'@ back to @a@.
--
-- Without the first forward-map step, setting @a := x@ while another key
-- @a2@ was already bound to @x@ kept @a2 -> x@ in the forward map even
-- though the backward map now sent @x@ to @a@, breaking the bijection.
instance (Ord a,Ord b) => DataMap (Bimap a b) a b where
  at a = lens lookup setAt
    where lookup (Bimap (ma,_)) = ma^.at a
          setAt (Bimap (ma,mb)) b' = Bimap (ma',mb')
            where -- value currently bound to a; its reverse entry goes stale
                  b = ma^.at a
                  -- key currently bound to the new value; its forward entry goes stale
                  a' = maybe Nothing (\x -> mb^.at x) b'
                  ma' = (at a %- b') (maybe id delete a' ma)
                  mb' = maybe id (flip insert a) b' (maybe id delete b mb)
-- | Backward view of a 'Bimap', keyed by its second type.
--
-- The flipped bimap is converted (through '_Flip' and '_inverse', applied
-- in reverse via 'from') into a 'Bimap' whose forward direction is the
-- original backward map. The ordinary 'Bimap' 'at' is then used on it, so
-- updates keep whatever consistency guarantees that instance provides.
instance (Ord b,Ord a) => DataMap (Flip Bimap b a) b a where
  at k = from (_inverse._Flip).at k