{-# LANGUAGE MagicHash #-}

module Data.HashCons.Internal where

import Control.Monad (when)
import Data.Hashable (Hashable, hash, hashWithSalt)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import System.IO.Unsafe (unsafeDupablePerformIO)
import Text.ParserCombinators.ReadPrec (step)
import Text.Read (Read(..), lexP, parens, prec)
import Text.Read.Lex (Lexeme (Ident))
import GHC.Exts (reallyUnsafePtrEquality#)

-- | A value bundled with a hash that is computed once, up front, and a
-- mutable 'IORef' cell holding the value itself.
--
-- The 'Eq' and 'Ord' instances in this module may overwrite the left-hand
-- operand's cell with the right-hand operand's value when the two compare
-- equal but are not the same heap object, so that equal values end up shared.
--
-- WARNING: only wrap types whose 'Eq' / 'Ord' instances never equate values
-- that can be told apart. Otherwise that in-place substitution becomes
-- observable: a semantically immutable value may appear to change at
-- runtime, and results may depend on evaluation order.
data HashCons a = HashCons
  { _hashCons_hash :: {-# UNPACK #-} !Int
    -- ^ Hash of the wrapped value, fixed when the 'HashCons' is built
  , _hashCons_ref :: {-# UNPACK #-} !(IORef a)
    -- ^ Cell containing the wrapped value
  }

-- | Wrap a value in a 'HashCons'.
--
-- The hash is computed from the value, and a fresh 'IORef' holding the value
-- is allocated. Both fields are strict, so both happen as soon as the result
-- is forced. The allocation runs under 'unsafeDupablePerformIO'. If two
-- threads race to evaluate the same thunk, each may allocate its own cell,
-- but only one resulting 'HashCons' is ever observed.
hashCons :: Hashable a => a -> HashCons a
hashCons x = HashCons digest cell
  where
    digest = hash x
    cell = unsafeDupablePerformIO (newIORef x)

-- | Read back the wrapped value.
--
-- This is a pure read of the 'IORef' cell via 'unsafeDupablePerformIO'. The
-- 'Eq' / 'Ord' instances in this module only ever store a value that
-- compared equal to the old one into that cell. So, under the type's
-- documented precondition on 'Eq' / 'Ord', the result does not depend on
-- when the read happens.
unHashCons :: HashCons a -> a
unHashCons = unsafeDupablePerformIO . readIORef . _hashCons_ref

-- | Renders a 'HashCons' as @hashCons x@, where @x@ is the 'show'n wrapped
-- value at precedence 11. The whole thing is parenthesised when the
-- surrounding precedence exceeds that of function application (10).
instance Show a => Show (HashCons a) where
  showsPrec d hc = showParen (d > appPrec) shown
    where
      appPrec :: Int
      appPrec = 10
      shown = showString "hashCons " . showsPrec (succ appPrec) (unHashCons hc)

-- | Parses the form written by the 'Show' instance, @hashCons x@, optionally
-- wrapped in parentheses.
--
-- The argument is read one precedence level above function application. The
-- parsed value is wrapped with 'hashCons', so its hash is recomputed and it
-- gets a fresh cell.
instance (Read a, Hashable a) => Read (HashCons a) where
  readPrec = parens (prec 10 application)
    where
      -- A failed pattern bind in 'ReadPrec' fails the parse, so any lexeme
      -- other than the identifier @hashCons@ rejects this alternative.
      application = do
        Ident "hashCons" <- lexP
        hashCons <$> step readPrec

-- | Equality with cheap short-circuits.
--
-- Identical cells are equal. Differing hashes are unequal. Only when the
-- hashes agree are the wrapped values actually compared. If they are equal
-- but not the same heap object, the left operand's cell is overwritten with
-- the right operand's value. Later comparisons can then hit the
-- cell-identity fast path, and the duplicate can be collected.
instance Eq a => Eq (HashCons a) where
  HashCons h1 ref1 == HashCons h2 ref2 =
    ref1 == ref2 || (h1 == h2 && unsafeDupablePerformIO unifyIfEqual)
    where
      unifyIfEqual = do
        x <- readIORef ref1
        y <- readIORef ref2
        let same = x == y
        when same $
          -- 0# means "not known to be the same object"; storing an equal
          -- value is harmless even if they actually were identical.
          case reallyUnsafePtrEquality# x y of
            0# -> writeIORef ref1 y
            _ -> pure ()
        pure same

-- | NOTE: values are ordered by their precomputed hash first. The wrapped
-- values' own 'Ord' instance is consulted only when the hashes tie. This
-- keeps comparisons cheap, but it means the ordering generally differs from
-- the ordering of the unwrapped values.
--
-- As in the 'Eq' instance, when a hash tie turns out to be a genuine 'EQ'
-- between distinct heap objects, the left operand's cell is overwritten with
-- the right operand's value.
instance Ord a => Ord (HashCons a) where
  compare (HashCons h1 ref1) (HashCons h2 ref2) =
    case compare h1 h2 of
      EQ
        | ref1 == ref2 -> EQ
        | otherwise -> unsafeDupablePerformIO compareAndUnify
      byHash -> byHash
    where
      compareAndUnify = do
        x <- readIORef ref1
        y <- readIORef ref2
        let ordering = compare x y
        when (ordering == EQ) $
          -- 0# means "not known to be the same object"; storing an equal
          -- value is harmless even if they actually were identical.
          case reallyUnsafePtrEquality# x y of
            0# -> writeIORef ref1 y
            _ -> pure ()
        pure ordering

-- | Salts the stored precomputed hash. The wrapped value is neither read nor
-- re-hashed.
--
-- NOTE(review): the @Eq a@ context is presumably there to satisfy an 'Eq'
-- superclass on 'Hashable' (hashable >= 1.4) — confirm against the pinned
-- hashable version.
instance Eq a => Hashable (HashCons a) where
  hashWithSalt salt = hashWithSalt salt . _hashCons_hash