{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module: Capnp.TraversalLimit
-- Description: Support for managing message traversal limits.
--
-- This module is used to mitigate several pitfalls with the capnproto format,
-- which could potentially lead to denial of service vulnerabilities.
--
-- In particular, although the spec forbids it, it is possible to encode
-- objects in which many pointers point to the same place, or which even
-- contain cycles. A naive traversal could therefore perform a great deal of
-- computation for a message that is very small on the wire.
--
-- Accordingly, most implementations of the format keep track of how many bytes
-- of a message have been accessed, and start signaling errors after a certain
-- value (the "traversal limit") has been reached. The Haskell implementation is
-- no exception; this module implements that logic. We provide a monad
-- transformer and mtl-style type class to track the limit; reading from the
-- message happens inside of this monad.
module Capnp.TraversalLimit
  ( MonadLimit (..),
    LimitT,
    runLimitT,
    evalLimitT,
    execLimitT,
    defaultLimit,
  )
where

-- RWST, ReaderT, WriterT and the lazy StateT are imported only to define
-- 'MonadLimit' instances for them (see the end of this module).

import Capnp.Bits (WordCount)
import Capnp.Errors (Error (TraversalLimitError))
import Control.Monad (when)
import Control.Monad.Catch (MonadCatch (catch), MonadThrow (throwM))
import Control.Monad.Fail (MonadFail (..))
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Primitive (PrimMonad (primitive), PrimState)
import Control.Monad.RWS (RWST)
import Control.Monad.Reader (ReaderT)
import qualified Control.Monad.State.Lazy as LazyState
import Control.Monad.State.Strict
  ( MonadState,
    StateT,
    evalStateT,
    execStateT,
    get,
    put,
    runStateT,
  )
import Control.Monad.Trans.Class (MonadTrans (lift))
import Control.Monad.Writer (WriterT)
import Prelude hiding (fail)

-- | Monads that carry a traversal budget, in the style of the mtl classes.
-- The message-reading parts of the library use 'invoice' to pay for the
-- data they access.
class (Monad m) => MonadLimit m where
  -- | Charge @n@ words against the remaining traversal limit. Implementations
  -- signal an error once the limit is exhausted.
  invoice :: WordCount -> m ()

-- | Monad transformer implementing 'MonadLimit'. The remaining traversal
-- budget (a 'WordCount') is kept as strict 'StateT' state.
--
-- The underlying monad must implement 'MonadThrow': 'invoice' calls
-- @'throwM' 'TraversalLimitError'@ when a charge exceeds the remaining limit.
newtype LimitT m a = LimitT (StateT WordCount m a)
  -- The instances are reused from the wrapped StateT via
  -- GeneralizedNewtypeDeriving (enabled at the top of this module).
  deriving (Functor, Applicative, Monad)

-- | Run a 'LimitT' with the given starting traversal limit. Returns the
-- computation's result together with the limit that remains afterwards.
runLimitT :: MonadThrow m => WordCount -> LimitT m a -> m (a, WordCount)
runLimitT limit (LimitT stateT) = runStateT stateT limit

-- | Run a 'LimitT' with the given starting traversal limit and return only
-- the computation's result. The remaining limit is discarded.
evalLimitT :: MonadThrow m => WordCount -> LimitT m a -> m a
evalLimitT limit (LimitT stateT) = evalStateT stateT limit

-- | Run a 'LimitT' with the given starting traversal limit and return only
-- the limit that remains afterwards. The computation's result is discarded.
execLimitT :: MonadThrow m => WordCount -> LimitT m a -> m WordCount
execLimitT limit (LimitT stateT) = execStateT stateT limit

-- | A sensible default traversal limit: 64 MiB, expressed in 8-byte words
-- (i.e. 8 Mi words).
defaultLimit :: WordCount
defaultLimit = (64 * 1024 * 1024) `div` 8

------ Instances of mtl type classes for 'LimitT'.

-- | Exceptions are thrown in the underlying monad.
instance MonadThrow m => MonadThrow (LimitT m) where
  throwM = lift . throwM

-- | Catching is delegated to the 'MonadCatch' instance of the wrapped
-- 'StateT'.
--
-- NOTE(review): with the @exceptions@ instance for 'StateT', the handler
-- presumably resumes from the budget in effect before the failed action, so
-- charges made by that action are not kept. Confirm this is acceptable for
-- the traversal-limit guarantees.
instance MonadCatch m => MonadCatch (LimitT m) where
  catch (LimitT action) handler =
    LimitT $ catch action $ \e ->
      -- Unwrap the handler's LimitT so it runs in the same StateT layer.
      let LimitT recovery = handler e in recovery

-- | @'invoice' deduct@ subtracts @deduct@ words from the remaining budget.
-- It throws 'TraversalLimitError' (via 'throwM') if the budget is smaller
-- than the charge. A charge equal to the remaining budget succeeds and
-- leaves the budget at zero.
instance MonadThrow m => MonadLimit (LimitT m) where
  invoice deduct = LimitT $ do
    limit <- get
    when (limit < deduct) $ throwM TraversalLimitError
    -- Store the new budget already evaluated, so the state never holds an
    -- unevaluated subtraction.
    put $! limit - deduct

-- | Lift through the wrapped 'StateT', leaving the traversal budget unchanged.
instance MonadTrans LimitT where
  lift = LimitT . lift

-- | Pass 'MonadState' operations through to the underlying monad's state.
-- The traversal budget is hidden inside 'LimitT' and is not affected.
instance MonadState s m => MonadState s (LimitT m) where
  get = lift get
  put = lift . put

-- | Primitive state-token operations run in the underlying monad, and
-- 'LimitT' shares its 'PrimState'.
--
-- The original context also required @s ~ PrimState m@ for a type variable
-- @s@ that appears nowhere else. That constraint is always satisfiable and
-- has been dropped.
instance PrimMonad m => PrimMonad (LimitT m) where
  type PrimState (LimitT m) = PrimState m
  primitive = lift . primitive

-- | Pattern-match failures are reported by the underlying monad's 'fail'.
instance MonadFail m => MonadFail (LimitT m) where
  fail = lift . fail

-- | Run 'IO' actions in the underlying monad.
instance MonadIO m => MonadIO (LimitT m) where
  liftIO = lift . liftIO

------ Instances of 'MonadLimit' for standard monad transformers

-- | Charges are forwarded to the traversal limit of the underlying monad.
instance MonadLimit m => MonadLimit (StateT s m) where
  invoice = lift . invoice

-- | Charges are forwarded to the traversal limit of the underlying monad.
-- No instance signature is given: an unqualified 'StateT' here would name
-- the strict transformer, not the lazy one.
instance MonadLimit m => MonadLimit (LazyState.StateT s m) where
  invoice = lift . invoice

-- | Charges are forwarded to the traversal limit of the underlying monad.
-- @Monoid w@ is required by 'lift' for 'WriterT'.
instance (Monoid w, MonadLimit m) => MonadLimit (WriterT w m) where
  invoice = lift . invoice

-- | Charges are forwarded to the traversal limit of the underlying monad.
instance MonadLimit m => MonadLimit (ReaderT r m) where
  invoice = lift . invoice

-- | Charges are forwarded to the traversal limit of the underlying monad.
-- @Monoid w@ is required by 'lift' for 'RWST'.
instance (Monoid w, MonadLimit m) => MonadLimit (RWST r w s m) where
  invoice = lift . invoice