{-# LANGUAGE AllowAmbiguousTypes #-}

-- | Utilities for effectfully memoizing other, more effectful functions.
module Data.OpenApi.Compare.Memo
  ( MonadMemo,
    MemoState,
    runMemo,
    modifyMemoNonce,
    KnotTier (..),
    unknot,
    memoWithKnot,
    memoTaggedWithKnot,
  )
where

import Control.Monad.State
import Data.Dynamic
import qualified Data.Map as M
import Data.Tagged
import qualified Data.TypeRepMap as T
import Data.Void
import Type.Reflection

-- | The memoization status recorded for a single key:
--
-- * 'Finished' holds the value the computation produced;
-- * 'Started' marks a key whose computation is currently running and has not
--   (yet) been re-entered;
-- * 'TyingKnot' marks a running computation that has been re-entered; it
--   stores the knot datum created for it, wrapped in a 'Dynamic'.
data Progress a = Finished a | Started | TyingKnot Dynamic

-- | A single memo table, mapping keys of type @k@ to the 'Progress' of values
-- of type @v@. The type index is the pair @(k, v)@, which lets tables for
-- different key/value type pairs be stored side by side in one
-- 'T.TypeRepMap'.
data MemoMap a where
  MemoMap :: !(M.Map k (Progress v)) -> MemoMap (k, v)

-- | The complete memoization state: a caller-chosen nonce of type @s@ (read
-- and updated by 'modifyMemoNonce') together with every memo table, each one
-- looked up by its @(key, value)@ type pair.
data MemoState s = MemoState s (T.TypeRepMap MemoMap)

-- | An effectful memoization monad: any 'MonadState' whose state is a
-- 'MemoState' carrying a nonce of type @s@.
type MonadMemo s m = MonadState (MemoState s) m

-- | Fetch the progress recorded for @key@ in the memo table belonging to the
-- @(k, v)@ type pair. Yields 'Nothing' if there is no table for that type
-- pair yet, or if the table has no entry for the key.
memoStateLookup ::
  forall k v s.
  (Typeable k, Typeable v, Ord k) =>
  k ->
  MemoState s ->
  Maybe (Progress v)
memoStateLookup key (MemoState _ tables) =
  T.lookup @(k, v) tables >>= fromTable
  where
    -- The explicit signature gives the GADT match a rigid result type.
    fromTable :: MemoMap (k, v) -> Maybe (Progress v)
    fromTable (MemoMap entries) = M.lookup key entries

-- | Record @progress@ for @key@ in the memo table belonging to the @(k, v)@
-- type pair. If no table exists for that pair yet, a new one is created. Any
-- previous entry for the key is overwritten. The nonce is carried over
-- unchanged.
memoStateInsert ::
  forall k v s.
  (Typeable k, Typeable v, Ord k) =>
  k ->
  Progress v ->
  MemoState s ->
  MemoState s
memoStateInsert key progress (MemoState nonce tables) =
  MemoState nonce (T.insert (MemoMap updated) tables)
  where
    -- The table currently stored for (k, v), or an empty one.
    existing :: M.Map k (Progress v)
    existing = maybe M.empty unwrap (T.lookup @(k, v) tables)

    updated :: M.Map k (Progress v)
    updated = M.insert key progress existing

    -- Typed unwrapper so the GADT match has a rigid result type.
    unwrap :: MemoMap (k, v) -> M.Map k (Progress v)
    unwrap (MemoMap entries) = entries

-- | Apply @update@ to the nonce stored in the memo state, leaving every memo
-- table untouched. Returns the nonce as it was /before/ the update.
modifyMemoNonce :: MonadMemo s m => (s -> s) -> m s
modifyMemoNonce update =
  get >>= \(MemoState previous tables) ->
    previous <$ put (MemoState (update previous) tables)

-- | Run a memoized computation. The state starts from the given nonce and an
-- empty set of memo tables. The final memo state is discarded, and only the
-- computation's result is returned.
runMemo :: Monad m => s -> StateT (MemoState s) m a -> m a
runMemo nonce action = evalStateT action (MemoState nonce T.empty)

-- | A description of how to effectfully tie knots in type @v@, using the @m@
-- monad, and by sharing some @d@ data among the recursive instances.
data KnotTier v d m = KnotTier
  { onKnotFound :: m d
  -- ^ Create the datum that will be connected to this knot. 'memoWithKnot'
  -- runs this the first time a key is re-entered while its computation is
  -- still in progress.
  , onKnotUsed :: d -> m v
  -- ^ What the knot looks like as a value to the inner computations. It is
  -- used for every re-entry of an in-progress key.
  , tieKnot :: d -> v -> m v
  -- ^ Once we're done and back outside, tie the knot using the datum and the
  -- value that the outer computation produced.
  }

-- | A 'KnotTier' that forbids recursion. Running its 'onKnotFound' action
-- fails via 'error' with the message @Recursion detected@. Since the knot
-- datum type is 'Void', 'onKnotUsed' and 'tieKnot' can never actually be
-- applied, and both are 'absurd'.
unknot :: KnotTier v Void m
unknot = KnotTier (error "Recursion detected") absurd absurd

-- | Run a potentially recursive computation. The provided key will be used to
-- refer to the result of this computation. If, while the computation is
-- running, another attempt to run the computation with the same key is made,
-- we run the tying-the-knot procedure described by the 'KnotTier'.
--
-- If another attempt to run the computation with the same key is made
-- /after we're done/, we will return the memoized value.
--
-- What happens depends on the progress recorded for the key:
--
-- * nothing recorded: mark the key 'Started', run the computation, then call
--   'tieKnot' if a knot was created meanwhile, and record the outcome as
--   'Finished';
-- * 'Started': this is the first re-entry, so create a datum with
--   'onKnotFound', record it as 'TyingKnot', and return 'onKnotUsed' of it;
-- * 'TyingKnot': a later re-entry, so return 'onKnotUsed' of the recorded
--   datum;
-- * 'Finished': return the memoized value.
--
-- A datum recorded under a different type than @d@ is reported with 'error'.
memoWithKnot ::
  forall k v d m s.
  (Typeable k, Typeable v, Typeable d, Ord k, MonadMemo s m) =>
  KnotTier v d m ->
  -- | the computation to memoize
  m v ->
  -- | key for memoization
  k ->
  m v
memoWithKnot tier f k = do
  entry <- lookupEntry
  case entry of
    Nothing -> computeFresh
    Just Started -> startKnot
    Just (TyingKnot dyn) ->
      maybe
        (mismatch "Type mismatch when examining the knot of " dyn)
        (onKnotUsed tier)
        (fromDynamic dyn)
    Just (Finished done) -> pure done
  where
    -- Progress currently recorded for this key.
    lookupEntry :: m (Maybe (Progress v))
    lookupEntry = gets (memoStateLookup @k @v k)

    -- Overwrite the progress recorded for this key.
    record :: Progress v -> m ()
    record = modify . memoStateInsert @k @v k

    -- Report a recorded knot datum whose dynamic type is not @d@.
    mismatch :: String -> Dynamic -> m v
    mismatch prefix dyn =
      error $
        prefix
          <> show (typeRep @(k -> v))
          <> ": expected "
          <> show (typeRep @d)
          <> ", got "
          <> show (dynTypeRep dyn)

    -- First re-entry of an in-progress key: create and remember the knot.
    startKnot :: m v
    startKnot = do
      knot <- onKnotFound tier
      record (TyingKnot (toDyn knot))
      onKnotUsed tier knot

    -- First visit: run the computation, tying any knot created during it.
    computeFresh :: m v
    computeFresh = do
      record Started
      result <- f
      after <- lookupEntry
      final <- case after of
        Just (TyingKnot dyn) ->
          maybe
            (mismatch "Type mismatch when tying the knot of " dyn)
            (\knot -> tieKnot tier knot result)
            (fromDynamic dyn)
        Just (Finished _) ->
          error $ "Unexpected Finished when memoizing " <> show (typeRep @(k -> v))
        -- 'Just Started': no knot was needed. 'Nothing' would normally be an
        -- error, but the underlying monad can refuse to remember memoization
        -- state.
        _ -> pure result
      record (Finished final)
      pure final

-- | Disambiguate memoized computations with an arbitrary tag.
--
-- Behaves like 'memoWithKnot', except that the key is wrapped as
-- @'Tagged' t k@. Computations that share a key type @k@ but use different
-- tags @t@ therefore land in separate memo tables. The tag is supplied by
-- type application and may be of any kind. 'Typeable' evidence for that kind
-- is recovered with 'withTypeable', so that @'Tagged' t k@ is 'Typeable'.
memoTaggedWithKnot ::
  forall t k v d m s.
  ( Typeable t
  , Typeable k
  , Typeable v
  , Typeable d
  , Ord k
  , MonadMemo s m
  ) =>
  KnotTier v d m ->
  m v ->
  k ->
  m v
memoTaggedWithKnot tier f k =
  withTypeable tagKind (memoWithKnot tier f taggedKey)
  where
    -- Runtime representation of the kind of the tag @t@.
    tagKind = typeRepKind (typeRep @t)
    taggedKey = Tagged @t k