{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Capnp.TraversalLimit
( MonadLimit (..),
LimitT,
runLimitT,
evalLimitT,
execLimitT,
defaultLimit,
)
where
import Capnp.Bits (WordCount)
import Capnp.Errors (Error (TraversalLimitError))
import Control.Monad (when)
import Control.Monad.Catch (MonadCatch (catch), MonadThrow (throwM))
import Control.Monad.Fail (MonadFail (..))
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Primitive (PrimMonad (primitive), PrimState)
import Control.Monad.RWS (RWST)
import Control.Monad.Reader (ReaderT)
import qualified Control.Monad.State.Lazy as LazyState
import Control.Monad.State.Strict
( MonadState,
StateT,
evalStateT,
execStateT,
get,
put,
runStateT,
)
import Control.Monad.Trans.Class (MonadTrans (lift))
import Control.Monad.Writer (WriterT)
import Prelude hiding (fail)
-- | mtl-style class for monads that carry a traversal budget, measured in
-- 'WordCount's. Code that reads messages charges its work against the
-- budget through 'invoice'.
class Monad m => MonadLimit m where
  -- | @'invoice' n@ charges @n@ words against the remaining budget.
  -- The 'LimitT' instance throws 'TraversalLimitError' when the remaining
  -- budget is smaller than @n@.
  invoice :: WordCount -> m ()
-- | Monad transformer implementing 'MonadLimit'. The remaining traversal
-- budget is threaded through a strict 'StateT'.
--
-- 'invoice' is available whenever the underlying monad is a 'MonadThrow'.
-- A charge larger than the remaining budget throws 'TraversalLimitError'.
newtype LimitT m a = LimitT (StateT WordCount m a)
  deriving
    ( Functor,
      Applicative,
      Monad
    )
-- | Run a 'LimitT' computation starting from the given budget. Returns the
-- computation's result paired with the budget still left when it finished.
runLimitT :: MonadThrow m => WordCount -> LimitT m a -> m (a, WordCount)
runLimitT budget (LimitT action) = action `runStateT` budget
-- | Run a 'LimitT' computation starting from the given budget. Keeps only
-- the computation's result and discards the leftover budget.
evalLimitT :: MonadThrow m => WordCount -> LimitT m a -> m a
evalLimitT budget (LimitT action) = action `evalStateT` budget
-- | Run a 'LimitT' computation starting from the given budget. Keeps only
-- the budget still left at the end and discards the result.
execLimitT :: MonadThrow m => WordCount -> LimitT m a -> m WordCount
execLimitT budget (LimitT action) = action `execStateT` budget
-- | A reasonable default traversal budget: 64 MiB expressed in 8-byte
-- words, i.e. @8 * 1024 * 1024@ words.
defaultLimit :: WordCount
defaultLimit = budgetBytes `div` bytesPerWord
  where
    budgetBytes = 64 * 1024 * 1024
    bytesPerWord = 8
-- | Exceptions are thrown in the underlying monad.
instance MonadThrow m => MonadThrow (LimitT m) where
  throwM err = lift (throwM err)
-- | Catching is delegated to the 'MonadCatch' instance of the wrapped
-- 'StateT'. The handler's result is unwrapped back into that 'StateT'.
--
-- NOTE(review): this inherits 'StateT's catch semantics. Under those, the
-- handler presumably resumes from the budget as it stood *before* the
-- protected action ran, so words invoiced by an action that then failed
-- are not charged. Confirm that is acceptable for limit enforcement.
instance MonadCatch m => MonadCatch (LimitT m) where
  catch (LimitT action) handler = LimitT (action `catch` (unwrap . handler))
    where
      unwrap (LimitT inner) = inner
-- | Charges are deducted from the budget held in 'LimitT's state.
instance MonadThrow m => MonadLimit (LimitT m) where
  invoice cost = LimitT $ do
    available <- get
    -- Reject the charge before subtracting, so the budget never goes
    -- below zero.
    when (cost > available) $
      throwM TraversalLimitError
    put (available - cost)
-- | Lifting runs an underlying action through the wrapped 'StateT' and
-- leaves the remaining budget unchanged.
instance MonadTrans LimitT where
  lift action = LimitT (lift action)
-- | 'MonadState' is forwarded to the underlying monad. 'get' and 'put' on a
-- 'LimitT' act on the wrapped monad's state, never on the traversal budget,
-- which is reachable only through 'MonadLimit' and the run functions.
instance MonadState s m => MonadState s (LimitT m) where
  get = lift get
  put newState = lift (put newState)
-- | Primitive state-token operations run in the underlying monad, so a
-- 'LimitT' shares its primitive state type with the monad it wraps.
--
-- The context is just @PrimMonad m@. An extra equality such as
-- @s ~ PrimState m@ would bind a type variable that never appears in the
-- instance head, and it is not needed: the associated type below already
-- ties the two state types together.
instance PrimMonad m => PrimMonad (LimitT m) where
  type PrimState (LimitT m) = PrimState m
  primitive = lift . primitive
-- | Pattern-match failures are reported by the underlying monad's 'fail'.
instance MonadFail m => MonadFail (LimitT m) where
  fail msg = lift (fail msg)
-- | 'IO' actions are lifted through the underlying monad.
instance MonadIO m => MonadIO (LimitT m) where
  liftIO io = lift (liftIO io)
-- | Charges made inside a strict 'StateT' are passed to the underlying
-- monad's budget.
instance MonadLimit m => MonadLimit (StateT s m) where
  invoice cost = lift (invoice cost)
-- | Charges made inside a lazy 'LazyState.StateT' are passed to the
-- underlying monad's budget.
instance MonadLimit m => MonadLimit (LazyState.StateT s m) where
  invoice cost = lift (invoice cost)
-- | Charges made inside a 'WriterT' are passed to the underlying monad's
-- budget. The 'Monoid' constraint is what lifting through 'WriterT'
-- requires.
instance (Monoid w, MonadLimit m) => MonadLimit (WriterT w m) where
  invoice cost = lift (invoice cost)
-- | Charges made inside a 'ReaderT' are passed to the underlying monad's
-- budget.
instance (MonadLimit m) => MonadLimit (ReaderT r m) where
  invoice cost = lift (invoice cost)
-- | Charges made inside an 'RWST' are passed to the underlying monad's
-- budget. The 'Monoid' constraint is what lifting through 'RWST' requires.
instance (Monoid w, MonadLimit m) => MonadLimit (RWST r w s m) where
  invoice cost = lift (invoice cost)