{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE UndecidableInstances       #-}
{- |
Module: Capnp.TraversalLimit
Description: Support for managing message traversal limits.

This module is used to mitigate several pitfalls with the capnproto format,
which could potentially lead to denial of service vulnerabilities.

In particular, although such objects are illegal according to the spec, it is
possible to encode objects which have many pointers pointing to the same place,
or even cycles. A naive traversal could therefore involve quite a lot of
computation for a message that is very small on the wire.

Accordingly, most implementations of the format keep track of how many bytes
of a message have been accessed, and start signaling errors after a certain
value (the "traversal limit") has been reached. The Haskell implementation is
no exception; this module implements that logic. We provide a monad
transformer and an mtl-style type class to track the limit; reading from the
message happens inside this monad.

-}
module Capnp.TraversalLimit
    ( MonadLimit(..)
    , LimitT
    , runLimitT
    , evalLimitT
    , execLimitT
    , defaultLimit
    ) where

import Prelude hiding (fail)

import Control.Monad              (when)
import Control.Monad.Catch        (MonadThrow(throwM))
import Control.Monad.Fail         (MonadFail (..))
import Control.Monad.IO.Class     (MonadIO (..))
import Control.Monad.Primitive    (PrimMonad(primitive), PrimState)
import Control.Monad.State.Strict
    (MonadState, StateT, evalStateT, execStateT, get, put, runStateT)
import Control.Monad.Trans.Class (MonadTrans(lift))

-- Just to define 'MonadLimit' instances:
import Control.Monad.RWS    (RWST)
import Control.Monad.Reader (ReaderT)
import Control.Monad.Writer (WriterT)

import qualified Control.Monad.State.Lazy as LazyState

import Capnp.Bits   (WordCount)
import Capnp.Errors (Error(TraversalLimitError))

-- | An mtl-style capability for tracking the traversal limit. The parts of
-- the library that actually read messages are written against this class.
class Monad m => MonadLimit m where
    -- | @'invoice' n@ charges @n@ words against the traversal limit,
    -- signaling an error once the limit has been exhausted.
    invoice :: WordCount -> m ()

-- | Monad transformer implementing 'MonadLimit'. The underlying monad
-- must implement 'MonadThrow'. 'invoice' calls @'throwM' 'TraversalLimitError'@
-- when the limit is exhausted.
--
-- Internally this is a strict 'StateT' whose state is the traversal limit
-- that remains; the monad instances are inherited from it via
-- GeneralizedNewtypeDeriving.
newtype LimitT m a = LimitT (StateT WordCount m a)
    deriving(Functor, Applicative, Monad)

-- | Run a 'LimitT' starting from the given traversal limit, returning the
-- value from the computation and the remaining traversal limit.
runLimitT :: MonadThrow m => WordCount -> LimitT m a -> m (a, WordCount)
runLimitT limit (LimitT stateT) = runStateT stateT limit

-- | Run a 'LimitT' starting from the given traversal limit, returning only
-- the value from the computation (the remaining limit is discarded).
evalLimitT :: MonadThrow m => WordCount -> LimitT m a -> m a
evalLimitT limit (LimitT stateT) = evalStateT stateT limit

-- | Run a 'LimitT' starting from the given traversal limit, returning only
-- the remaining traversal limit (the computation's value is discarded).
execLimitT :: MonadThrow m => WordCount -> LimitT m a -> m WordCount
execLimitT limit (LimitT stateT) = execStateT stateT limit

-- | A sensible default traversal limit. Currently 64 MiB.
--
-- The byte count is divided by 8 to express it as a 'WordCount'
-- (8 bytes per word).
defaultLimit :: WordCount
defaultLimit = (64 * 1024 * 1024) `div` 8

------ Instances of mtl type classes for 'LimitT'.

-- | Exceptions are thrown in the underlying monad.
instance MonadThrow m => MonadThrow (LimitT m) where
    throwM = lift . throwM

-- | Deducts from the remaining traversal limit, throwing
-- 'TraversalLimitError' (via 'throwM') when the charge exceeds what is left.
instance MonadThrow m => MonadLimit (LimitT m) where
    invoice deduct = LimitT $ do
        limit <- get
        -- Check before subtracting, so a charge larger than the remaining
        -- budget is rejected rather than applied.
        when (limit < deduct) $ throwM TraversalLimitError
        put (limit - deduct)

-- | Lifting goes through the wrapped 'StateT'.
instance MonadTrans LimitT where
    lift = LimitT . lift

-- | State operations are passed through to the underlying monad, so 'get'
-- and 'put' do not see or modify the traversal limit itself.
instance MonadState s m => MonadState s (LimitT m) where
    get = lift get
    put = lift . put

-- | Primitive state-token operations run in the underlying monad and share
-- its 'PrimState'.
instance (PrimMonad m, s ~ PrimState m) => PrimMonad (LimitT m) where
    type PrimState (LimitT m) = PrimState m
    primitive = lift . primitive

-- | 'fail' is delegated to the underlying monad.
instance MonadFail m => MonadFail (LimitT m) where
    fail = lift . fail

-- | IO actions are lifted through the underlying monad.
instance MonadIO m => MonadIO (LimitT m) where
    liftIO = lift . liftIO

------ Instances of 'MonadLimit' for standard monad transformers

-- | Charges are forwarded to the wrapped monad (strict 'StateT').
instance MonadLimit m => MonadLimit (StateT s m) where
    invoice = lift . invoice

-- | Charges are forwarded to the wrapped monad (lazy 'LazyState.StateT').
instance MonadLimit m => MonadLimit (LazyState.StateT s m) where
    invoice = lift . invoice

-- | Charges are forwarded to the wrapped monad.
instance (Monoid w, MonadLimit m) => MonadLimit (WriterT w m) where
    invoice = lift . invoice

-- | Charges are forwarded to the wrapped monad.
instance (MonadLimit m) => MonadLimit (ReaderT r m) where
    invoice = lift . invoice

-- | Charges are forwarded to the wrapped monad.
instance (Monoid w, MonadLimit m) => MonadLimit (RWST r w s m) where
    invoice = lift . invoice