{-# LANGUAGE MultiParamTypeClasses, ViewPatterns, ScopedTypeVariables #-}
module Data.Containers(
  -- * The basic data class
  DataMap(..),Indexed(..),OrderedMap(..),Container(..),
  
  lookup,resides,member,delete,touch,insert,singleton,singleton',fromAList,fromKList,(#),(#?),
  cached,

  -- * Map instances
  -- ** Sets and maps
  Set,Map,c'setOf,c'set,c'mapOf,c'map,
  
  -- ** Bimaps
  Bimap(..),toMap,keysSet,

  -- ** Relations
  Relation(..),i'Relation,i'domains,i'ranges,l'domain,l'range,link,(*>>>)
  )
  where

import Definitive.Base
import qualified Data.Set as S
import qualified Data.Map as M
import Data.Map (Map)
import Data.Set (Set)
import Control.Concurrent.MVar

-- |Map-like containers @m@ with keys @k@ and values @a@. Both @k@ and @a@
-- are determined by @m@ through the functional dependency.
class Monoid m => DataMap m k a | m -> k a where
  -- |A lens onto the optional value stored at a key. 'delete' writes
  -- 'Nothing' through it and 'insert' writes a 'Just'.
  at :: k -> Lens' m (Maybe a)
-- |Containers @f@ whose elements can be paired with an index of type @i@
-- (@i@ is determined by @f@).
class Indexed f i | f -> i where
  -- |An isomorphism between a container and the same container with every
  -- element tagged by its index (see the list and 'Map' instances).
  keyed :: Iso (f (i,a)) (f (i,b)) (f a) (f b) 
-- |Containers that can report a size; every instance in this module
-- delegates 'weight' to the underlying structure's size function.
class Container c where weight :: c a -> Int

-- |Lists are indexed by position, counting from 0.
instance Indexed [] Int where
  keyed = iso number (map snd)
    where number xs = zip [0..] xs
-- |The weight of a list is whatever 'size' reports for it.
instance Container [] where
  weight = size

-- |The weight of a set is its number of elements.
instance Container Set where
  weight = S.size

-- |The weight of a map is its number of bindings.
instance Container (Map k) where
  weight = M.size
-- |Maps that can be converted to and from a list of key/value pairs.
-- The primed parameters describe the (possibly different) map type on the
-- way back, which allows type-changing updates through the isomorphism.
class OrderedMap m k a m' k' a' | m -> k a, m' -> k' a' where
  -- |The association list of a map. The instances in this module produce it
  -- in ascending key order and expect ascending key order when rebuilding.
  ascList :: Iso [(k,a)] [(k',a')] m m'

-- |Restrict a value to be a 'Set' whose elements satisfy the given
-- constraint. The argument only guides type inference and is ignored.
c'setOf :: Constraint a -> Constraint (Set a)
c'setOf = const id
-- |Restrict a value to be a 'Map' whose keys and values satisfy the given
-- constraints. Both arguments only guide type inference and are ignored.
c'mapOf :: Constraint a -> Constraint b -> Constraint (Map a b)
c'mapOf _ = const id
-- |Restrict a value to be a 'Set' of any element type.
c'set :: Constraint (Set a)
c'set = id
-- |Restrict a value to be a 'Map' of any key and value types.
c'map :: Constraint (Map a b)
c'map = id

-- |A lens onto the presence of a key in a container whose values carry no
-- information ('Void'), such as a 'Set'. The optional value from 'at' is
-- converted to a 'Bool' through @i'maybe@ (defined outside this module;
-- presumably 'Just' corresponds to 'True').
member :: DataMap m k Void => k -> Lens' m Bool
member k = at k.from i'maybe

-- |Read the value stored at a key, if any.
lookup :: DataMap m k a => k -> m -> Maybe a
lookup k = (^.at k)
-- |Test whether a key has a value in the container, i.e. whether 'lookup'
-- yields a non-empty result.
resides :: DataMap m k a => k -> m -> Bool
resides k m = nonempty (lookup k m)
-- |Remove a key by writing 'Nothing' at its position.
delete :: DataMap m k a => k -> m -> m
delete k m = m & at k %- Nothing
-- |Store a value at a key by writing a 'Just' at its position.
insert :: DataMap m k a => k -> a -> m -> m
insert k a m = m & at k %- Just a
-- |Insert every binding of an association list into a container. The
-- individual insertions are combined with 'compose'.
(#) :: DataMap m k a => m -> [(k,a)] -> m
m # ks = compose (uncurry insert <$> ks) m
-- |Bind a key to the neutral value 'zero'.
touch :: (Monoid a, DataMap m k a) => k -> m -> m
touch k m = insert k zero m
-- |The container holding exactly one binding, built by inserting it into
-- the empty container 'zero'.
singleton :: DataMap m k a => k -> a -> m
singleton k a = insert k a zero
-- |The container holding exactly one key, bound to the neutral value.
singleton' :: (Monoid a,DataMap m k a) => k -> m
singleton' k = insert k zero zero
-- |Build a container from an association list, starting from the empty
-- container 'zero'.
fromAList :: DataMap m k a => [(k,a)] -> m
fromAList l = zero # l
-- |Build a container from a list of keys, each bound to the neutral value.
fromKList :: (Monoid a,DataMap m k a) => [k] -> m
fromKList l = compose [touch k | k <- l] zero

-- |A set is a map from its elements to 'Void'. The underlying lens reads
-- membership with 'S.member' and writes it by choosing between 'S.insert'
-- and 'S.delete' according to the new 'Bool'; @i'maybe@ then presents that
-- 'Bool' as the @Maybe Void@ required by 'at'.
-- NOTE(review): which branch of @bool@ is taken for 'True' depends on the
-- prelude's argument order — confirm against Definitive.Base.
instance Ord a => DataMap (Set a) a Void where
  at k = lens (S.member k) (flip (bool (S.insert k) (S.delete k))).i'maybe
-- |The ascending list of a set, each element paired with a 'Void' value.
-- Rebuilding uses 'S.fromAscList', so the list written back must already
-- be in ascending order.
instance Eq b => OrderedMap (Set a) a Void (Set b) b Void where
  ascList = iso S.toAscList S.fromAscList . mapping (i'_.commuted)
-- |Ordinary maps: reading is 'M.lookup'; writing 'Nothing' removes the
-- key and writing @Just v@ (re)binds it.
instance Ord k => DataMap (Map k a) k a where
  at k = lens getAt setAt
    where getAt m = M.lookup k m
          setAt m Nothing = M.delete k m
          setAt m (Just v) = M.insert k v m
-- |Association lists as maps, needing only equality on keys.
instance Eq k => DataMap [(k,a)] k a where
  -- Reading folds over every pair and keeps the values whose key matches;
  -- how several matches are merged is decided by the 'Maybe' monoid of the
  -- prelude (presumably the first match wins — TODO confirm).
  at k = lens (foldMap (\(k',a) -> a <$ guard (k==k'))) g
    where -- Writing 'Nothing' drops every pair bound to the key.
          g l Nothing = [(k',a) | (k',a) <- l, k' /= k ]
          -- Writing a value prepends a new pair; older pairs for the same
          -- key are left in the list.
          g l (Just a) = (k,a) : l
-- |The ascending association list of a map. Rebuilding uses
-- 'M.fromAscList', which does not check its precondition: the list written
-- back must already be sorted by key.
instance Eq k' => OrderedMap (Map k a) k a (Map k' a') k' a' where 
  ascList = iso M.toAscList M.fromAscList
  
-- Algebraic structure of sets: union as sum, the empty set as zero,
-- difference as subtraction and intersection as product.
instance Ord a => Semigroup (Set a) where (+) = S.union
instance Ord a => Monoid (Set a) where zero = S.empty
instance Ord a => Disjonctive (Set a) where (-) = S.difference
instance Ord a => Semiring (Set a) where (*) = S.intersection
-- NOTE(review): 'S.mapMonotonic' is only valid for strictly increasing
-- functions; mapping any other function yields a malformed set. It is used
-- here because 'map' cannot carry an 'Ord' constraint.
instance Functor Set where map = S.mapMonotonic
-- |Folding a set sums its elements with '+', starting from 'zero'.
instance Foldable Set where fold = S.foldr (+) zero

-- Algebraic structure of maps: 'M.union' as sum, the empty map as zero and
-- 'M.difference' as subtraction.
instance Ord k => Semigroup (Map k a) where (+) = M.union
instance Ord k => Monoid (Map k a) where zero = M.empty
instance Ord k => Disjonctive (Map k a) where (-) = M.difference
-- NOTE(review): the product is a union that sums the values of shared
-- keys, whereas the product of sets is an intersection — confirm that this
-- asymmetry is intended.
instance (Ord k,Semigroup a) => Semiring (Map k a) where (*) = M.unionWith (+)
instance Functor (Map k) where map = M.map
-- |Folding a map sums its values with '+', starting from 'zero'.
instance Foldable (Map k) where fold = M.foldr (+) zero
-- |Sequencing goes through the ascending association list, so the keys are
-- left untouched and the map is rebuilt from the sequenced values.
instance Eq k => Traversable (Map k) where sequence = (ascList.i'Compose) sequence
-- |A map is indexed by its keys: every value gets paired with its key.
instance Indexed (Map k) k where keyed = iso (M.mapWithKey (,)) (map snd)

-- NOTE(review): 'pure' is left undefined — a finite map cannot bind every
-- possible key — so evaluating it crashes. Only '<*>' is usable.
instance Ord k => Unit (Zip (Map k)) where
  pure = undefined
-- |Zipping two maps applies each function to the argument stored under the
-- same key; keys that are missing from either side are dropped.
instance Ord k => Applicative (Zip (Map k)) where
  Zip fs <*> Zip xs = Zip (M.intersectionWith ($) fs xs)

-- |An invertible map, stored as a forward map together with its reverse.
-- The algebraic instances are derived from the underlying pair of maps.
newtype Bimap a b = Bimap (Map a b,Map b a)
                  deriving (Show,Semigroup,Monoid,Disjonctive,Semiring)
-- |Commuting a bimap swaps the forward and reverse maps.
instance Commutative Bimap where
  commute (Bimap (b,a)) = Bimap (a,b)

-- |Bimaps are maps on their forward direction. Writing at a key updates
-- both directions so that they stay inverses of each other.
instance (Ord a,Ord b) => DataMap (Bimap a b) a b where
  at a = lens t'lookup setAt
    where -- Reading only consults the forward map.
          t'lookup ma = toMap ma^.at a
          -- Forward map: if the new value is already the image of some
          -- key, that key is removed first; then the key is bound to the
          -- new value (or unbound when the new value is 'Nothing').
          -- Reverse map: the previous value of the key (if any) is
          -- removed, then the new value (if any) is bound back to the key.
          setAt (Bimap (ma,mb)) b' = Bimap (
            maybe id delete (b' >>= \b'' -> mb^.at b'') ma & at a %- b',
            mb & maybe id delete b >>> maybe id (flip insert a) b')
            where b = ma^.at a 
-- |A flipped bimap is a map on the reverse direction: the wrapper is
-- removed, the bimap is commuted, and the forward instance is reused.
instance (Ord b,Ord a) => DataMap (Flip Bimap b a) b a where
  at k = from (commuted.i'Flip).at k
-- |The ascending association list of the forward map of a bimap.
--
-- When rebuilding, the forward map is created from the list as given (it
-- must be ascending in its keys, like for 'Map'). The reverse map is built
-- from the swapped pairs with 'M.fromList': those pairs are ordered by
-- key, not by value, so an ascending-list constructor would produce a
-- malformed reverse map.
instance (Ord a,Ord b,Ord c,Ord d) => OrderedMap (Bimap a b) a b (Bimap c d) c d where
  ascList = iso (toMap >>> \m -> m^.ascList)
                (\l -> Bimap (l^..ascList,M.fromList [(b,a) | (a,b) <- l]))
-- |The forward direction of a bimap, as an ordinary map.
toMap :: Bimap a b -> Map a b
toMap (Bimap (forward,_)) = forward

-- |The set of keys of a map: its ascending association list has every
-- value replaced by 'zero' and is then read back as a 'Set'.
keysSet :: (Eq k,OrderedMap m k a m k a) => m -> Set k
keysSet m = keysOnly^.from ascList
  where keysOnly = map (second (const zero)) (m^.ascList)

-- |A binary relation between @a@ and @b@, stored in both directions: a map
-- from each @a@ to the set of related @b@s, and a map from each @b@ to the
-- set of related @a@s.
newtype Relation a b = Relation (Map a (Set b),Map b (Set a))
                     deriving (Show,Eq,Ord)
-- |The sum of two relations unions them, direction by direction; sets
-- bound to the same key are summed.
instance (Ord a,Ord b) => Semigroup (Relation a b) where
  Relation (x,x') + Relation (y,y') = Relation (M.unionWith (+) x y,M.unionWith (+) x' y')
-- The neutral relation is derived from the underlying pair of maps.
deriving instance (Ord a,Ord b) => Monoid (Relation a b)
-- |The product of two relations zips them, direction by direction,
-- multiplying the sets found under the same key.
instance (Ord a,Ord b) => Semiring (Relation a b) where
  Relation (x,x') * Relation (y,y') = Relation (zipWith (*) x y,zipWith (*) x' y')
-- |The isomorphism between a relation and its underlying pair of maps.
-- It does not check that the two maps are inverses of each other.
i'Relation :: Iso (Relation a b) (Relation c d) (Map a (Set b),Map b (Set a)) (Map c (Set d),Map d (Set c))
i'Relation = iso Relation unwrap
  where unwrap (Relation maps) = maps
-- |Commuting a relation swaps its two directions.
instance Commutative Relation where
  commute (Relation (forward,backward)) = Relation (backward,forward)

-- |Define a Relation from its ranges, i.e. from the map sending each left
-- element to the set of right elements it is related to. Reading the
-- ranges out of a relation is a projection; building a relation from them
-- recomputes the reverse map by inserting every pair.
i'ranges :: (Ord c,Ord d) => Iso (Map a (Set b)) (Map c (Set d)) (Relation a b) (Relation c d)
i'ranges = iso (\(Relation (rs,_)) -> rs) fromRanges
  where -- For every key @a@ and every @b@ in its range, add @a@ to the set
        -- stored under @b@ in the reverse map (starting from the empty map).
        fromRanges rs = Relation (rs,compose (rs^.keyed <&> \ (a,bs) -> compose $ bs <&> \b ->
                                              at b%~Just . touch a . fold) zero)
-- |Define a Relation from its domains, i.e. from the map sending each
-- right element to the set of left elements related to it. Implemented by
-- commuting the relation and reusing 'i'ranges'.
i'domains :: (Ord c,Ord d) => Iso (Map b (Set a)) (Map d (Set c)) (Relation a b) (Relation c d)
i'domains = commuted.i'ranges

-- |A relation is a map from each left element to the set of its related
-- right elements. Writing a new range at a key keeps the reverse map in
-- sync, and neither direction ever stores an empty set: a key whose set
-- becomes empty is removed, so that structurally equal relations compare
-- equal under the derived 'Eq' and 'Ord'.
instance (Ord k,Ord a) => DataMap (Relation k a) k (Set a) where
  at a = lens (\(Relation (rs,_)) -> rs^.at a) setRan
    where -- A missing range ('Nothing') is treated as the empty set.
          setRan (Relation (rs,ds)) (fold -> ran) = Relation (
            rs & at a %- if empty ran then Nothing else Just ran,
            adjust ds)
            where oldRan = fold $ rs^.at a
                  -- Elements leaving the range forget about the key; going
                  -- through 'may' drops their entry once its set is empty.
                  -- Elements entering the range learn about the key.
                  adjust = compose ((oldRan-ran) <&> \b -> at b.from may.member a %- False)
                           >>> compose ((ran-oldRan) <&> \b -> at b %~ Just . touch a . fold)

-- |Identify an optional container with a plain one: an empty container
-- corresponds to 'Nothing', and 'Nothing' is read back as 'zero'.
may :: (Monoid (f b),Foldable f) => Iso (Maybe (f a)) (Maybe (f b)) (f a) (f b)
may = iso hideEmpty (maybe zero id)
  where hideEmpty c = if empty c then Nothing else Just c

-- |A lens onto the set of right elements related to a given left element.
-- An element with no relations is seen as the empty set, and writing the
-- empty set removes its entry.
l'domain :: (Ord a,Ord b) => a -> Lens' (Relation a b) (Set b)
l'domain a = at a.from may
-- |A lens onto the set of left elements related to a given right element,
-- obtained by commuting the relation and reusing 'l'domain'.
l'range :: (Ord a,Ord b) => b -> Lens' (Relation a b) (Set a)
l'range b = commuted . l'domain b

-- |A lens onto whether two elements are related: the membership of the
-- right element in the set that 'l'domain' gives for the left element.
link :: (Ord a,Ord b) => a -> b -> Lens' (Relation a b) Bool
link a b = l'domain a.member b

-- |Relate every listed pair of elements within a relation.
(#?) :: (Ord a,Ord b) => Relation a b -> [(a,b)] -> Relation a b
r #? ls = compose (ls <&> \(a,b) -> link a b %- True) r

-- |Memoize a function: results are kept in a map shared by every call of
-- the returned function, so each argument is applied to @f@ at most once.
--
-- The cache lives in an 'MVar' and is reached through @thunk@ (the
-- prelude's bridge from IO to pure values). It is updated with
-- 'modifyMVar', so that an exception raised while looking the key up —
-- or an asynchronous one — restores the cache instead of leaving the
-- 'MVar' empty and blocking every later call.
cached :: forall a b. Ord a => (a -> b) -> a -> b
cached f = \a -> g a^.thunk
  where -- The lookup is forced inside the action given to 'modifyMVar';
        -- the result @f a@ itself stays lazy and is computed on demand,
        -- outside the critical section.
        g a = vm `seq` modifyMVar vm (\m ->
          case m^.at a of
            Just b -> return (m,b)
            Nothing -> let b = f a in return (insert a b m,b))
        -- The cache does not depend on the argument, so it is created once
        -- per memoized function.
        vm = newMVar (zero :: Map a b)^.thunk

-- |Compose two relations: every left element ends up related to all the
-- elements that the second relation associates with its former range.
(*>>>) :: (Ord a,Ord b,Ord c) => Relation a b -> Relation b c -> Relation a c
r *>>> r' = r & i'ranges %~ map through
  where through bs = foldMap (\b -> r'^.l'domain b) bs