{-# LANGUAGE GADTs, ScopedTypeVariables, Rank2Types #-}

module Data.Ref.Map (
    Map
  , Name
    
  , empty      -- :: Map f
  , singleton  -- :: Name a -> f a -> Map f 
  , null       -- :: Map f -> Bool
  , size       -- :: Map f -> Int
  , member     -- :: Name a -> Map f -> Bool
  , (!)        -- :: Map f -> Name a -> f a
  , lookup     -- :: Name a -> Map f -> Maybe (f a)
  , insert     -- :: Ref a -> f a -> Map f -> Map f
  , delete     -- :: Name a -> Map f -> Map f
  , adjust     -- :: (f a -> f b) -> Name a -> Map f -> Map f
  , filter     -- :: (forall a. f a -> Bool) -> Map f -> Map f
  , hmap       -- :: (f a -> h a) -> Map f -> Map h
  , union        -- :: Map f -> Map f -> Map f
  , difference   -- :: Map f -> Map f -> Map f
  , intersection -- :: Map f -> Map f -> Map f
    
  , Entry(..)
  , toList     -- :: Map f -> [Entry f]
  , fromList   -- :: [Entry f] -> Map f
  ) where

import Control.Applicative ((<$>))
import Data.Ref
import Data.List (find, deleteBy)
import Data.Function (on)

import Unsafe.Coerce
import System.Mem.StableName

import Data.IntMap (IntMap)
import qualified Data.IntMap as M

import Prelude hiding (null, lookup, map, filter)

--------------------------------------------------------------------------------
-- * Reference indexed maps
--------------------------------------------------------------------------------

-- | Shorthand for stable names.
--
--   Kept as an unapplied synonym of 'StableName' so that it can be used
--   where a type constructor is expected, e.g. @HideType Name@.
type Name = StableName

-- | Wraps an @f a@ while forgetting the index type @a@, so that values of
--   different index types can be stored side by side in one list. Getting
--   the value back at a concrete type requires 'unsafeCoerce'.
data HideType f = forall a. Hide (f a)

-- | A reference indexed map. Useful for associating info with a reference.
--
--   Representation: the outer 'IntMap' is keyed by 'hashStableName' of each
--   name. Because distinct stable names may share a hash, every key holds a
--   list (a bucket) of entries, and a name is located inside its bucket with
--   'eqStableName'. Both the name and the value of an entry are stored
--   type-erased via 'HideType'; values are recovered with 'unsafeCoerce'.
--
--   Note: this is generally unsound when `f` is a GADT!
data Map f = Map (IntMap [(HideType Name, HideType f)])

--------------------------------------------------------------------------------
-- ** Construction

-- | Construct a map that holds no entries (and no hash buckets).
empty :: Map f
empty = Map (M.fromList [])

-- | Construct a map holding exactly one entry: the given name associated
--   with the given value.
singleton :: Name a -> f a -> Map f
singleton n v = Map (M.singleton bucketKey [(Hide n, Hide v)])
  where
    bucketKey = hashStableName n

--------------------------------------------------------------------------------
-- ** Basic interface

-- | Returns 'True' if the map holds no entries, 'False' otherwise.
--
--   The buckets themselves are inspected instead of only asking whether the
--   'IntMap' has keys: a bucket whose entries were all removed must not make
--   the map look non-empty. 'all' stops at the first non-empty bucket.
null :: Map f -> Bool
null (Map m) = all isEmptyBucket (M.elems m)
  where
    isEmptyBucket :: [e] -> Bool
    isEmptyBucket [] = True
    isEmptyBucket _  = False

-- | Returns the number of entries stored in this map.
--
--   Each 'IntMap' key holds a bucket of entries whose names share a hash,
--   so the entry count is the total length of all buckets, not the number
--   of keys.
size :: Map f -> Int
size (Map m) = M.foldl' (\total bucket -> total + length bucket) 0 m

-- | Returns 'True' if the name is present in the map, 'False' otherwise.
--
--   A matching hash alone is not enough, because distinct names may share
--   one. The bucket is therefore searched for an entry whose name is equal
--   to the given one. An empty or missing bucket yields 'False'.
member :: Name a -> Map f -> Bool
member n (Map m) = maybe False (any isThisName) (M.lookup (hashStableName n) m)
  where
    isThisName :: (HideType Name, v) -> Bool
    isThisName (Hide k, _) = eqStableName k n

-- | Partial lookup: returns the value associated with the name and calls
--   'error' when the name is not in the map. Prefer 'lookup' unless the
--   name is known to be present.
(!) :: Map f -> Name a -> f a
(!) m name =
  case lookup name m of
    Just v  -> v
    Nothing -> error "Data.Ref.Map.(!)"

-- | Finds the value associated with the name, or 'Nothing' if the name has
--   no value associated to it.
lookup :: Name a -> Map f -> Maybe (f a)
lookup n (Map m) = M.lookup (hashStableName n) m >>= find isThisName >>= reveal
  where
    -- Several names may share one hash bucket; select the entry for @n@.
    isThisName :: (HideType Name, t) -> Bool
    isThisName (Hide k, _) = eqStableName k n

    -- 'Hide' erased the stored value's index type. It is coerced back to the
    -- type requested by the caller, assuming it was stored at the same index
    -- type as @n@.
    reveal :: (t, HideType g) -> Maybe (g b)
    reveal (_, Hide v) = Just (unsafeCoerce v)

-- | Associates a reference with the specified value. If the map already
--   contains a mapping for the reference, the old value is replaced.
insert :: Ref a -> f a -> Map f -> Map f
insert (Ref n _) v (Map m) =
  Map (M.insertWith replaceOld (hashStableName n) [(Hide n, Hide v)] m)
  where
    -- 'M.insertWith' passes the new singleton bucket first and the existing
    -- bucket second. Entries for other names that merely share the hash are
    -- kept. A previous entry for this very name is dropped, so the old value
    -- is truly replaced and cannot resurface after a later 'delete'.
    replaceOld new old = new ++ [e | e <- old, not (isThisName e)]

    isThisName :: (HideType Name, x) -> Bool
    isThisName (Hide k, _) = eqStableName k n

-- | Removes the associated value of a name, if any is present in the map.
--
--   Only entries whose name is equal to the given one are removed; other
--   names sharing the same hash bucket are kept. A bucket that becomes empty
--   is dropped entirely, so 'null' and 'member' keep reporting correctly.
delete :: forall f a. Name a -> Map f -> Map f
delete n (Map m) = Map (M.update prune (hashStableName n) m)
  where
    prune :: [(HideType Name, HideType f)] -> Maybe [(HideType Name, HideType f)]
    prune bucket = case [e | e <- bucket, not (isThisName e)] of
      []   -> Nothing
      rest -> Just rest

    isThisName :: (HideType Name, v) -> Bool
    isThisName (Hide k, _) = eqStableName k n

-- | Updates the associated value of a name, if any is present in the map.
--   Other entries, including names sharing the same hash, are left as-is.
--
--   NOTE(review): @b@ is not tied to @a@. The updated @f b@ stays filed under
--   a @Name a@, and 'lookup' will later coerce it to @f a@. Callers should
--   only use this with @b ~ a@.
adjust :: forall f a b. (f a -> f b) -> Name a -> Map f -> Map f
adjust f n (Map m) = Map (M.adjust (\bucket -> [retouch e | e <- bucket]) (hashStableName n) m)
  where
    retouch :: (HideType Name, HideType f) -> (HideType Name, HideType f)
    retouch entry@(Hide k, Hide v)
      | not (eqStableName k n) = entry
      | otherwise              = (Hide k, Hide (f (unsafeCoerce v)))

-- | Keeps exactly the entries whose value satisfies the predicate and
--   removes all others. Buckets left without entries are dropped.
filter :: (forall a. f a -> Bool) -> Map f -> Map f
filter p (Map m) = Map (M.mapMaybe keep m)
  where
    -- Decide entry by entry. Names that only share a hash bucket must not
    -- decide each other's fate.
    keep bucket = case [e | e <- bucket, passes p e] of
      []   -> Nothing
      kept -> Just kept

    passes :: (forall a. g a -> Bool) -> (x, HideType g) -> Bool
    passes q (_, Hide v) = q v

--------------------------------------------------------------------------------
-- ** Traversal

-- | Map over the container types, keeping every name in place.
--
--   NOTE(review): the function is monomorphic in @a@, yet it is applied to
--   every entry after coercing that entry's value to @f a@. If the map holds
--   values of different index types, this is unsound. A rank-2 argument
--   @(forall a. f a -> h a)@, as 'filter' takes, would be safe but changes
--   the interface.
hmap :: forall f h a. (f a -> h a) -> Map f -> Map h
hmap f (Map m) = Map (M.map (fmap convert) m)
  where
    convert :: (HideType Name, HideType f) -> (HideType Name, HideType h)
    convert (k, payload) = (k, rewrap payload)

    rewrap :: HideType f -> HideType h
    rewrap (Hide x) = Hide (f (unsafeCoerce x))

--------------------------------------------------------------------------------
-- ** Combine

-- | Union of two maps (left biased).
--
--   When both maps contain the same name, the left value wins. Names that
--   merely share a hash bucket are all kept, so no entry is lost to a
--   hash collision.
union :: Map f -> Map f -> Map f
union (Map m) (Map n) = Map (M.unionWith preferLeft m n)
  where
    preferLeft ls rs = ls ++ [r | r <- rs, not (any (sameName r) ls)]

    sameName :: (HideType Name, x) -> (HideType Name, y) -> Bool
    sameName (Hide p, _) (Hide q, _) = eqStableName p q

-- | Difference of two maps: the entries of the first map whose name does
--   not occur in the second map.
--
--   Names are compared with 'eqStableName', not by hash, so an entry
--   survives even if the second map has a different name in the same
--   bucket. Buckets left empty are dropped.
difference :: Map f -> Map f -> Map f
difference (Map m) (Map n) = Map (M.differenceWith keepMissing m n)
  where
    keepMissing ls rs = case [l | l <- ls, not (any (sameName l) rs)] of
      []   -> Nothing
      kept -> Just kept

    sameName :: (HideType Name, x) -> (HideType Name, y) -> Bool
    sameName (Hide p, _) (Hide q, _) = eqStableName p q

-- | Intersection of two maps: the entries of the first map whose name also
--   occurs in the second map. Values are taken from the first map.
--
--   Names are compared with 'eqStableName', not by hash, so a mere hash
--   collision does not keep an entry. Buckets left empty are dropped.
intersection :: Map f -> Map f -> Map f
intersection (Map m) (Map n) = Map (M.mapMaybe id (M.intersectionWith keepShared m n))
  where
    keepShared ls rs = case [l | l <- ls, any (sameName l) rs] of
      []   -> Nothing
      kept -> Just kept

    sameName :: (HideType Name, x) -> (HideType Name, y) -> Bool
    sameName (Hide p, _) (Hide q, _) = eqStableName p q

--------------------------------------------------------------------------------
-- ** Lists

-- | One entry of a map: a stable name paired with the value stored for it.
--   The index type @a@ is hidden, so a list of entries may mix values of
--   different index types.
data Entry f where
  Entry :: Name a -> f a -> Entry f

-- | Fetches all the entries of a map. Entries are grouped by internal hash
--   bucket; the order carries no other meaning.
toList :: Map f -> [Entry f]
toList (Map m) = [toEntry e | bucket <- M.elems m, e <- bucket]
  where
    -- Name and value were hidden separately. The value is coerced to the
    -- name's index type so that they fit into one 'Entry'.
    toEntry :: (HideType Name, HideType g) -> Entry g
    toEntry (Hide k, Hide v) = Entry k (unsafeCoerce v)

-- | Constructs a map from a list of entries.
--
--   Entries whose names share a hash are kept side by side in one bucket
--   instead of overwriting each other. If the same name occurs more than
--   once, its last occurrence wins, matching the replace semantics of
--   'insert'.
fromList :: [Entry f] -> Map f
fromList = Map . M.fromListWith keepLatest . fmap toBucket
  where
    toBucket :: Entry g -> (M.Key, [(HideType Name, HideType g)])
    toBucket (Entry n v) = (hashStableName n, [(Hide n, Hide v)])

    -- 'M.fromListWith' passes the later bucket first and the accumulated one
    -- second; drop accumulated entries that the later bucket redefines.
    keepLatest new old = new ++ [e | e <- old, not (any (sameName e) new)]

    sameName :: (HideType Name, x) -> (HideType Name, y) -> Bool
    sameName (Hide p, _) (Hide q, _) = eqStableName p q

--------------------------------------------------------------------------------