{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
module Data.StableMemo.Internal (Ref (..), Strong (..), memo) where

import Data.HashTable.IO (BasicHashTable)
import qualified Data.HashTable.IO as HashTable
import Data.Proxy
import System.IO.Unsafe (unsafePerformIO)
import System.Mem.StableName
import System.Mem.Weak (Weak)
import qualified System.Mem.Weak as Weak
import Unsafe.Coerce (unsafeCoerce)

#if MIN_VERSION_base(4,10,0)
import GHC.Types (Any)
#else
import GHC.Prim (Any)
#endif

-- | Composition of two type constructors: a value of type @(f <<< g) a@ is
-- just an @f (g a)@ behind a newtype. 'MemoTable' uses it to pair a
-- reference type with the result type constructor in a single parameter.
newtype (<<<) f g a = O
  { unO :: f (g a)  -- ^ Remove the composition wrapper.
  }

-- | Hash table keyed by the 'StableName' of a @key x@ value and holding a
-- @val x@ value. Both type indices are erased to 'Any' so that a single
-- table can hold entries created at different index types.
--
-- Invariant: the key and the value of an entry were created at the same
-- type index (the one erased to 'Any'). Code reading the table relies on
-- this when it casts a value back with an unchecked coercion.
type SNMap key val = BasicHashTable (StableName (key Any)) (val Any)

-- | A memo table: maps the stable name of an argument @f x@ to a @ref@
-- reference to the result @g x@, with the two composed via '<<<'.
type MemoTable ref f g = SNMap f ((<<<) ref g)

-- | Reference types a memo table can store its results behind.
class Ref ref where
  -- | @mkRef key val fin@ creates a reference to @val@, associated with
  -- @key@, with @fin@ attached as its finalizer.
  mkRef :: key -> val -> IO () -> IO (ref val)

  -- | Read the referenced value; 'Nothing' means it is no longer available.
  deRef :: ref val -> IO (Maybe val)

  -- | Finalize the reference. Both instances in this module delegate to
  -- 'Weak.finalize'.
  finalize :: ref val -> IO ()

-- | Weak references: the value is held through 'Weak.mkWeak' keyed on the
-- argument. Per "System.Mem.Weak", the reference therefore keeps the value
-- reachable only while the key is alive.
instance Ref Weak where
  mkRef key val fin = Weak.mkWeak key val (Just fin)
  deRef ref = Weak.deRefWeak ref
  finalize ref = Weak.finalize ref

-- | A reference that holds its value directly. The accompanying 'Weak'
-- pointer is only used to carry the finalizer handed to 'mkRef'.
data Strong a = Strong
  a          -- ^ the value itself, held strongly
  !(Weak a)  -- ^ weak pointer on the value that carries the finalizer

-- | Strong references ignore the key. The value is stored directly, and the
-- finalizer rides on a weak pointer whose key is the value itself.
instance Ref Strong where
  mkRef _ val fin =
    fmap (Strong val) (Weak.mkWeakPtr val (Just fin))
  deRef (Strong val _) = return (Just val)
  finalize (Strong _ ptr) = Weak.finalize ptr

-- | Per-entry finalizer: deletes the entry for @sn@ from the memo table.
-- If the weak pointer to the table no longer yields it, there is nothing
-- to delete and this does nothing.
finalizer :: StableName (f Any) -> Weak (MemoTable ref f g) -> IO ()
finalizer sn weakTbl =
  Weak.deRefWeak weakTbl
    >>= maybe (return ()) (\tbl -> HashTable.delete tbl sn)

-- | Erase the type index of an @f a@ to 'Any' with an unchecked cast. The
-- value must be cast back to its original index before use at that type.
unsafeToAny :: f a -> f Any
unsafeToAny value = unsafeCoerce value

-- | Restore a type index previously erased to 'Any' with an unchecked cast.
-- Only sound when @a@ is the index the value was erased from.
unsafeFromAny :: f Any -> f a
unsafeFromAny erased = unsafeCoerce erased

-- | Memoized application of @f@, backed by the memo table @tbl@.
--
-- The argument is forced to WHNF (bang pattern) and identified by its
-- 'StableName'. On a hit whose reference still yields a value, that value
-- is cast back and returned. Otherwise @f x@ is computed, stored behind a
-- fresh reference created by 'mkRef' keyed on @x@ (with 'finalizer' as its
-- finalizer), inserted into the table, and returned.
--
-- @weakTbl@ is a weak pointer to @tbl@ that is handed to the per-entry
-- finalizers, presumably so that they do not keep the table alive.
memo' :: Ref ref =>
         Proxy ref -> (forall a. f a -> g a) -> MemoTable ref f g ->
         Weak (MemoTable ref f g) -> f b -> g b
memo' _ f tbl weakTbl !x = unsafePerformIO $ do
  -- 'makeStableName' does not evaluate its argument. If the coercion
  -- application is left suspended (e.g. 'unsafeToAny' not inlined at -O0),
  -- the thunk would be named instead of @x@, and every call would miss the
  -- cache. '$!' makes the name refer to the evaluated @x@ itself.
  sn <- makeStableName $! unsafeToAny x
  lkp <- HashTable.lookup tbl sn
  case lkp of
    Nothing -> notFound sn
    Just (O w) -> do
      maybeVal <- deRef w
      case maybeVal of
        -- The stored reference no longer yields a value: recompute.
        Nothing -> notFound sn
        Just val -> return $ unsafeFromAny val
  where notFound sn = do
          let y = f x
          weak <- mkRef x (unsafeToAny y) $ finalizer sn weakTbl
          HashTable.insert tbl sn $ O weak
          return y

-- | Finalizer for a whole memo table: calls 'finalize' on the reference
-- stored in every entry of the table.
tableFinalizer :: Ref ref => MemoTable ref f g -> IO ()
tableFinalizer tbl = HashTable.mapM_ finalizeEntry tbl
  where
    -- An entry is a (stable name, composed reference) pair.
    finalizeEntry entry = finalize (unO (snd entry))

-- | Build a memoized version of @f@ whose results are stored behind
-- references of type @ref@ (selected by the 'Proxy').
--
-- Each application @memo p f@ is intended to allocate its own memo table,
-- together with a weak pointer to that table. The weak pointer's finalizer
-- ('tableFinalizer') finalizes every stored reference once the table is
-- collected.
memo :: Ref ref => Proxy (ref :: * -> *) -> (forall a. f a -> g a) -> f b -> g b
memo p f =
  let (tbl, weak) = unsafePerformIO $ do
        tbl' <- HashTable.new
        -- Key the weak pointer and its finalizer on the table just
        -- allocated, not on the outer lazily bound @tbl@. That name is an
        -- unevaluated selector into this very 'unsafePerformIO' result, so
        -- using it ties a knot and makes the weak key a thunk rather than
        -- the table.
        weak' <- Weak.mkWeakPtr tbl' . Just $ tableFinalizer tbl'
        return (tbl', weak')
  in memo' p f tbl weak
-- 'memo' allocates mutable state with 'unsafePerformIO'. Following the
-- "System.IO.Unsafe" guidance, keep it from being inlined so that this
-- allocation is not duplicated or floated into a shared constant at call
-- sites.
{-# NOINLINE memo #-}