-- | Assignment of unique IDs to values.
-- Inspired by the 'intern' package.

{-# LANGUAGE RecordWildCards, ScopedTypeVariables, BangPatterns, MagicHash, RoleAnnotations #-}
module Data.Label(Label, unsafeMkLabel, labelNum, label, find) where

import Data.IORef
import System.IO.Unsafe
import qualified Data.Map.Strict as Map
import Data.Map.Strict(Map)
import qualified Data.DynamicArray as DynamicArray
import Data.DynamicArray(Array)
import Data.Typeable
import GHC.Exts
import GHC.Int
import Unsafe.Coerce

-- | A value of type @a@ which has been given a unique ID.
--
-- Labels are compared by their ID only (the derived 'Eq' and 'Ord'
-- instances look at 'labelNum'), never by the underlying value.
newtype Label a =
  Label {
    -- | The unique ID of a label.
    labelNum :: Int32 }
  deriving (Eq, Ord, Show)

-- The role is nominal so that 'Data.Coerce.coerce' cannot turn a @Label A@
-- into a @Label B@: an ID is only meaningful at the type it was created at,
-- and 'find' unsafely coerces the stored value back to @a@.
type role Label nominal

-- | Construct a @'Label' a@ from its unique ID, which must be the 'labelNum' of
-- an already existing 'Label' of the same type @a@. Extremely unsafe!
-- Passing any other number may make a later 'find' crash or return a value
-- coerced to the wrong type.
unsafeMkLabel :: Int32 -> Label a
unsafeMkLabel = Label

-- The global cache of labels, shared by values of every type.
-- NOINLINE ensures the 'unsafePerformIO' below runs once, so every use of
-- 'cachesRef' refers to the same 'IORef'.
{-# NOINLINE cachesRef #-}
cachesRef :: IORef Caches
cachesRef =
  unsafePerformIO . newIORef $
    Caches
      { caches_nextId = 0
      , caches_from   = Map.empty
      , caches_to     = DynamicArray.newArray
      }

-- The state stored in 'cachesRef'. All fields are strict, so forcing a
-- 'Caches' value also forces each field to WHNF.
data Caches =
  Caches {
    -- The next id number to assign.
    caches_nextId :: {-# UNPACK #-} !Int32,
    -- For each type (keyed by its 'TypeRep'), a map from values of that
    -- type to their ids. The per-type map is stored type-erased as
    -- @Cache Any@ and coerced back with 'fromAnyCache'.
    caches_from   :: !(Map TypeRep (Cache Any)),
    -- The reverse map from ids to values, stored type-erased as 'Any'.
    caches_to     :: !(Array Any) }

-- A per-type map from values to their ids.
type Cache a = Map a Int32

-- Atomically apply an update function to the global caches and return its
-- result. The update is computed against a snapshot *before* the
-- compare-and-swap, and is only committed if no new ids were assigned in the
-- meantime; otherwise it is recomputed against the fresh state.
atomicModifyCaches :: (Caches -> (Caches, a)) -> IO a
atomicModifyCaches f = do
  -- N.B. atomicModifyIORef' ref f evaluates f ref *after* doing the
  -- compare-and-swap. This causes bad things to happen when 'label'
  -- is used reentrantly (i.e. the Ord instance itself calls label).
  -- This function only lets the swap happen if caches_nextId didn't
  -- change (i.e., no new values were inserted).
  !caches <- readIORef cachesRef
  -- First compute the update. The bangs force it here, so any reentrant
  -- calls to 'label' made while evaluating f happen outside the swap.
  let !(!caches', !x) = f caches
  -- Now see if anyone else updated the cache in between
  -- (can happen if f called 'label', or in a concurrent setting).
  committed <- atomicModifyIORef' cachesRef $ \cachesNow ->
    if caches_nextId cachesNow == caches_nextId caches
      then (caches', True)
      else (cachesNow, False)
  if committed
    then return x
    -- Lost the race: redo the update against the new state.
    else atomicModifyCaches f

-- Versions of unsafeCoerce with slightly more type checking: these only
-- convert between a per-type 'Cache' and its type-erased form. Converting
-- back is only safe at the type whose 'TypeRep' the cache is stored under.
toAnyCache :: Cache a -> Cache Any
toAnyCache = unsafeCoerce

fromAnyCache :: Cache Any -> Cache a
fromAnyCache = unsafeCoerce

-- Type-erase a value so it can be stored in 'caches_to'.
toAny :: a -> Any
toAny = unsafeCoerce

-- Recover a value erased by 'toAny'. Only safe when @a@ is the type the
-- value had before it was erased.
fromAny :: Any -> a
fromAny = unsafeCoerce

-- | Assign a label to a value.
--
-- Values of the same type that are equal according to their 'Ord' instance
-- always receive the same label. A value seen for the first time is given
-- the next unused ID from the global cache.
--
-- This uses 'unsafeDupablePerformIO'. If the computation runs more than
-- once, the value is still only inserted once, because it is looked up
-- again inside 'atomicModifyCaches' before inserting.
{-# NOINLINE label #-}
label :: forall a. (Typeable a, Ord a) => a -> Label a
label x =
  unsafeDupablePerformIO $ do
    -- Common case: label is already there.
    snapshot <- readIORef cachesRef
    case tryFind snapshot of
      Just l -> return l
      -- Rare case: label was not there. Look again inside the atomic
      -- update, since it may have been added after the snapshot was read
      -- (by another thread, or by a reentrant call from the Ord instance).
      Nothing ->
        atomicModifyCaches $ \current ->
          maybe (insert current) (\l -> (current, l)) (tryFind current)
  where
    ty :: TypeRep
    ty = typeOf x

    -- Look up x in the per-type cache for its type, if there is one.
    tryFind :: Caches -> Maybe (Label a)
    tryFind caches =
      Label <$> (Map.lookup ty (caches_from caches) >>= Map.lookup x . fromAnyCache)

    -- Give x the next unused ID, recording it in both directions.
    insert :: Caches -> (Caches, Label a)
    insert caches@Caches{..}
      -- The ID counter is an Int32. Once it has wrapped round to a
      -- negative number, refuse to hand out any more labels.
      | n < 0 = error "label overflow"
      | otherwise =
          ( caches
              { caches_nextId = n + 1
              , caches_from   =
                  Map.insert ty (toAnyCache (Map.insert x n cache)) caches_from
              , caches_to     =
                  DynamicArray.updateWithDefault unassigned (fromIntegral n) (toAny x) caches_to
              }
          , Label n )
      where
        n :: Int32
        n = caches_nextId

        cache :: Cache a
        cache = fromAnyCache (Map.findWithDefault Map.empty ty caches_from)

        -- Filler for array slots that have no value yet; replaces a bare
        -- 'undefined' so a bad lookup reports where it came from.
        unassigned :: Any
        unassigned = error "Data.Label.label: unassigned label slot"

-- | Recover the underlying value from a label.
--
-- N.B. the ID must be forced before 'findWorker' reads the cache. Otherwise,
-- in a call of the form
--
-- >  find (label x)
--
-- the cache would be read before the lazy @label x@ had run and inserted
-- @x@, so the lookup would miss. Matching the 'I32#' constructor here
-- forces the ID in the caller and passes it on unboxed.
find :: Label a -> a
find (Label !(I32# n#)) = findWorker n#

-- Look up the value stored under the given ID in the global cache.
--
-- The ID is taken unboxed. An 'Int#' cannot be an unevaluated thunk, so the
-- caller has necessarily evaluated it before the cache is read (see the
-- note on 'find'). NOINLINE keeps the 'unsafeDupablePerformIO' call from
-- being inlined into callers.
{-# NOINLINE findWorker #-}
findWorker :: Int# -> a
findWorker n# =
  unsafeDupablePerformIO $ do
    arr <- caches_to <$> readIORef cachesRef
    -- Evaluate the lookup now (to WHNF) rather than returning a thunk.
    return $! fromAny (DynamicArray.getWithDefault missing (fromIntegral n) arr)
  where
    n :: Int32
    n = I32# n#

    -- Default passed to getWithDefault for IDs with no stored value. This is
    -- presumably only reachable for forged labels (e.g. via 'unsafeMkLabel');
    -- it replaces a bare 'undefined' with a descriptive message.
    missing :: Any
    missing = error ("Data.Label.find: unknown label " ++ show n)