module Data.Comp.Multi.Mapping
    ( Numbered (..)
    , unNumbered
    , number
    , HTraversable ()
    , Mapping (..)
    , lookupNumMap) where
import Data.Comp.Multi.HFunctor
import Data.Comp.Multi.HTraversable
import Control.Monad.State
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
-- | A value of type @f ix@ tagged with an 'Int'. 'number' uses the tag to
-- give each position of a single term layer its own number.
data Numbered f ix = Numbered Int (f ix)
-- | Strip the 'Int' tag from a 'Numbered' value, returning the wrapped value.
unNumbered :: Numbered a :-> a
unNumbered numbered = case numbered of
  Numbered _ payload -> payload
-- | Tag every subterm of one layer of @f@ with a distinct 'Int'. Positions
-- are numbered consecutively from 0, in the order 'hmapM' visits them.
number :: HTraversable f => f a :-> f (Numbered a)
number x = evalState (hmapM tag x) 0
  where
    tag b = do
      n <- get
      -- Control.Monad.State is the lazy State monad: force the counter so
      -- that each step does not store another unevaluated (+1) thunk.
      put $! n + 1
      return (Numbered n b)
-- | A finite mapping whose keys have type @k i@ (for any index @i@) and whose
-- values have a uniform type. The functional dependency @m -> k@ lets the
-- mapping type determine the key functor.
class Mapping m (k :: * -> *) | m -> k where
    infixr 0 &
    infix 1 |->

    -- | Combine two mappings. The 'NumMap' instance in this module is
    -- left-biased on key collisions; other instances should be checked
    -- individually.
    (&) :: m e -> m e -> m e

    -- | A mapping containing exactly one binding, from the key to the value.
    (|->) :: k i -> e -> m e

    -- | The mapping with no bindings.
    empty :: m e

    -- | Pair up the values of two mappings key by key. The first two
    -- arguments are defaults. The 'NumMap' instance uses them for keys bound
    -- in only one of the two mappings.
    prodMap :: a -> b -> m a -> m b -> m (a, b)

    -- | Look up a key, returning the supplied default when it is not bound.
    findWithDefault :: e -> k i -> m e -> e
-- | A finite map keyed by 'Int' tags. The parameter @k@ is phantom: it does
-- not occur in the representation. It only names the key functor in the
-- @'Mapping' ('NumMap' k) ('Numbered' k)@ instance.
newtype NumMap (k :: * -> *) a = NumMap (IntMap a) deriving Functor
-- | Look up an 'Int' tag in a 'NumMap', returning the given default when the
-- tag has no binding.
lookupNumMap :: a -> Int -> NumMap t a -> a
lookupNumMap dflt key numMap = case numMap of
  NumMap im -> IntMap.findWithDefault dflt key im
-- | 'NumMap' is a 'Mapping' keyed by 'Numbered' values. Only the 'Int' tag of
-- a key is used; the wrapped payload is ignored.
instance Mapping (NumMap k) (Numbered k) where
    -- Left-biased union: when both maps bind a tag, the binding from m1 wins.
    NumMap m1 & NumMap m2 = NumMap (IntMap.union m1 m2)
    Numbered k _ |-> v = NumMap $ IntMap.singleton k v
    empty = NumMap IntMap.empty
    findWithDefault d (Numbered i _) m = lookupNumMap d i m
    -- Tags bound in both maps are paired. A tag bound in only one map is
    -- paired with the default (p or q) for the missing side.
    prodMap p q (NumMap mp) (NumMap mq) =
        NumMap $ IntMap.mergeWithKey both (IntMap.map (, q)) (IntMap.map (p,)) mp mq
      where
        -- Fresh binder names so the defaults p and q above are not shadowed.
        both _ x y = Just (x, y)