{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Data.HashCons.Internal where

import Control.Exception
import Control.Monad (when)
import Data.Hashable (Hashable, hash, hashWithSalt)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import GHC.Base (compareInt#, Int#, IO (..), anyToAddr#, addr2Int#)
import GHC.Exts (Any, Addr#, unsafeCoerce#)
import System.IO.Unsafe (unsafeDupablePerformIO)
import Text.ParserCombinators.ReadPrec (step)
import Text.Read (Read(..), lexP, parens, prec)
import Text.Read.Lex (Lexeme (Ident))

import Debug.Trace

-- | A value paired with its precomputed hash, stored behind an 'IORef'.
--
-- The cached hash lets unequal values be told apart cheaply, and the
-- 'IORef' lets a comparison swap in an equal value so that equal
-- 'HashCons' values can come to share a single heap object (see
-- 'compareAndSubstitute').
--
-- WARNING: Do not use this type to wrap types whose 'Eq' or 'Ord' instances
-- allow distinguishable values to compare as equal; doing so results in
-- nondeterminism, or even visible mutation of semantically immutable values
-- at runtime.
data HashCons a = HashConsC
  { _hashCons_hash :: {-# UNPACK #-} !Int       -- ^ Precomputed hash of the value
  , _hashCons_ref  :: {-# UNPACK #-} !(IORef a) -- ^ Mutable cell holding the value
  }

-- | Bidirectional pattern for 'HashCons'.
--
-- Matching reads the value currently stored in the reference (via
-- 'unHashCons'); constructing hashes the value and allocates a fresh
-- reference (via 'hashCons').
pattern HashCons :: Hashable a => a -> HashCons a
pattern HashCons x <- (unHashCons -> x)
  where
    HashCons x = hashCons x

-- Every 'HashCons' value matches the pattern above, so declare it a complete
-- set; without this, matching on it triggers spurious incomplete-pattern
-- warnings.
{-# COMPLETE HashCons #-}

-- | Wrap a value in a new 'HashCons', caching its hash and placing the value
-- in a freshly allocated 'IORef'.
hashCons :: Hashable a => a -> HashCons a
hashCons value = HashConsC valueHash cell
  where
    valueHash = hash value
    -- Using the dupable variant is acceptable: if the allocation were ever
    -- performed twice, each reference would hold the same value.
    cell = unsafeDupablePerformIO (newIORef value)
{-# INLINE hashCons #-}

-- | Read the value currently stored in a 'HashCons'.
--
-- Comparisons may replace the stored value with one that compares equal,
-- so the result compares equal to (but is not necessarily the same heap
-- object as) the value originally wrapped.
unHashCons :: HashCons a -> a
unHashCons hc = unsafeDupablePerformIO (readIORef (_hashCons_ref hc))
{-# INLINE unHashCons #-}

-- | Renders a 'HashCons' as @hashCons x@, with @x@ shown at argument
-- precedence, matching the form parsed by the 'Read' instance.
instance Show a => Show (HashCons a) where
  showsPrec d hc = showParen needsParens rendered
    where
      appPrec = 10 :: Int
      needsParens = d > appPrec
      rendered =
        showString "hashCons " . showsPrec (appPrec + 1) (unHashCons hc)

-- | Parses the @hashCons x@ form produced by the 'Show' instance; the parsed
-- value is hashed and wrapped in a fresh 'HashCons'.
instance (Read a, Hashable a) => Read (HashCons a) where
  readPrec = parens . prec 10 $ do
    Ident "hashCons" <- lexP
    hashCons <$> step readPrec

-- | Equality with cheap short-circuits:
--
-- * the same reference is trivially equal;
-- * differing cached hashes are taken to mean unequal values (this assumes
--   the element type's 'hash' agrees with its '==');
-- * otherwise the stored values are compared via 'compareAndSubstitute',
--   which, when they are equal, may also make both references share one of
--   the two values.
instance Eq a => Eq (HashCons a) where
  HashConsC h1 ref1 == HashConsC h2 ref2
    | ref1 == ref2 = True
    | h1 /= h2 = False
    | otherwise = compareAndSubstitute (==) True ref1 ref2
  {-# INLINE (==) #-}

-- | NOTE: This instance orders by hash first, and only secondarily by the
-- 'Ord' instance of @a@, to improve performance.  The resulting order is
-- therefore generally not the order of the underlying values.
instance Ord a => Ord (HashCons a) where
  compare (HashConsC h1 ref1) (HashConsC h2 ref2) = compare h1 h2 <> byValue
    where
      -- Only demanded when the hashes tie: '<>' on 'Ordering' ignores its
      -- right operand unless the left one is 'EQ'.
      byValue
        | ref1 == ref2 = EQ
        | otherwise = compareAndSubstitute compare EQ ref1 ref2
  {-# INLINE compare #-}

-- | Hashes using the cached hash, so the stored value is never read.
instance Eq a => Hashable (HashCons a) where
  hashWithSalt salt = hashWithSalt salt . _hashCons_hash
  {-# INLINE hashWithSalt #-}

-- | Compare the values held by two 'IORef's with @cmp@ and return the result.
--
-- If the result equals @eq@ (the comparator's "equal" answer), one reference
-- is overwritten with the other's value, preferring the value whose heap
-- address is lower.  This is not expected to be totally stable, but it
-- should be /somewhat/ stable, and it pushes us towards coalescing more
-- values.  Without it, given @a@, @b@ and @c@ that are equal but distinct,
-- repeatedly comparing @b == a@ and @b == c@ (but never @a == c@) could make
-- @b@'s value flap between that of @a@ and @c@, paying the worst-case
-- equality cost every time and never settling on one representation.  With
-- it, we should settle on a single value unless we get extremely unlucky
-- with the way addresses move around.
compareAndSubstitute
  :: Eq r
  => (a -> a -> r) -- ^ comparator applied to the two stored values
  -> r             -- ^ comparator result that means "equal"
  -> IORef a
  -> IORef a
  -> r
compareAndSubstitute cmp eq ref1 ref2 = unsafeDupablePerformIO $ do
  a1 <- readIORef ref1
  a2 <- readIORef ref2
  let result = a1 `cmp` a2
  when (result == eq) $ do
    -- These are usually already forced by the comparator, but they may still
    -- be thunks (e.g. if it returns @eq@ unconditionally), and anyToAddr#
    -- must be given an evaluated value.  Use the values 'evaluate' returns
    -- rather than a1/a2 themselves: those bindings may still point at the
    -- updated thunk (an indirection), which would make us compare and store
    -- indirection addresses instead of the values' own addresses.
    v1 <- evaluate a1
    v2 <- evaluate a2
    -- NOTE: There is a race condition here: the addresses could change in
    -- between when they are read.  However, since either (or neither) swap
    -- is fine, we are OK with this only working "most" of the time (which we
    -- expect to be a very high fraction).
    addrOrder <- IO $ \s0 ->
      case anyToAddr# (unsafeCoerce# v1 :: Any) s0 of
        (# s1, addr1 #) -> case anyToAddr# (unsafeCoerce# v2 :: Any) s1 of
          (# s2, addr2 #) ->
            (# s2, addr2Int# addr1 `compareInt#` addr2Int# addr2 #)
    case addrOrder of
      LT -> writeIORef ref2 v1
      GT -> writeIORef ref1 v2
      EQ -> pure ()
  pure result
{-# INLINE compareAndSubstitute #-}