-- | /Nonces/ prevent replaying of requests. This module provides a
-- nonce validation function which remembers previously seen nonces
-- (per key) for as long as they are fresh.
module Network.Hawk.Server.Nonce
  ( nonceOpts
  , nonceOptsReq
  ) where

import Data.IORef
import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as Q
import Data.HashSet (HashSet)
import qualified Data.HashSet as S
import Data.Time.Clock.POSIX
import Data.Time.Clock (NominalDiffTime)
import Data.Hashable (Hashable)
import Data.Foldable (toList)
import Network.Hawk.Types (Key)
import Network.Hawk.Server (AuthOpts(..), AuthReqOpts(..), def, Nonce, NonceFunc)

-- | Build an 'AuthOpts' whose nonce check remembers every accepted nonce
-- while it could still be valid. The same @skew@ is handed to the
-- resulting 'AuthOpts' and to the nonce check; the third field is set
-- to @0@.
nonceOpts :: NominalDiffTime -> IO AuthOpts
nonceOpts skew = mkOpts <$> newIORef (Q.empty, S.empty)
  where
    -- Each call to 'nonceOpts' gets its own, initially empty, store.
    mkOpts store = AuthOpts (makeNonceFunc skew store) skew 0

-- | Like 'nonceOpts', but wraps the result in the default
-- 'AuthReqOpts' by overriding only its 'saOpts' field.
nonceOptsReq :: NominalDiffTime -> IO AuthReqOpts
nonceOptsReq skew = withOpts <$> nonceOpts skew
  where
    withOpts opts = def { saOpts = opts }

-- NOTE(review): orphan instance relying on Hashable's default methods;
-- presumably this needs a Generic instance for Key -- confirm, and
-- consider moving it next to the definition of Key.
instance Hashable Key

-- Maintain both a queue and set. Queue provides fast expiry of stale
-- nonces and the hash set provides a fast test for nonce existence.
type Store = (Seq (Key, Nonce, POSIXTime), HashSet (Key, Nonce))

-- | Build a 'NonceFunc' whose memory lives in @ref@. Each call reads the
-- current POSIX time and applies 'update' to the store with
-- @atomicModifyIORef'@, so concurrent calls do not lose updates. The
-- sign of @skew@ is ignored.
makeNonceFunc :: NominalDiffTime -> IORef Store -> NonceFunc
makeNonceFunc skew ref = \k t n -> do
  now <- getPOSIXTime
  atomicModifyIORef' ref (update now (abs skew) k n t)

-- | One nonce check as a pure store transition.
--
-- @update now skew k n t store@ returns the updated store together with
-- 'True' if the @(k, n)@ pair is accepted, or 'False' if it is rejected.
--
-- * Entries whose expiry time is earlier than @now@ are forgotten first.
--
-- * The pair is accepted when it is not currently remembered and the
--   request timestamp @t@ is no older than @now - 2 * skew@.
--
-- * An accepted pair is remembered until @now + 2 * skew@.
update
  :: POSIXTime        -- ^ current time
  -> NominalDiffTime  -- ^ clock skew; expected non-negative ('makeNonceFunc' passes @abs skew@)
  -> Key
  -> Nonce
  -> POSIXTime        -- ^ request timestamp
  -> Store
  -> (Store, Bool)
update now skew k n t (q, s) = q' `seq` s' `seq` ((q', s'), fresh)
  where
    -- Expiry times are appended in (roughly) non-decreasing order, so the
    -- expired entries form a prefix of the queue. If the clock steps
    -- backwards, some entries are merely kept longer than necessary.
    (expired, live) = Q.breakl (\(_, _, expiry) -> expiry >= now) q
    remembered =
      S.difference s (S.fromList [(k', n') | (k', n', _) <- toList expired])

    -- Membership is tested only after pruning, so the outcome does not
    -- depend on whether an earlier call happened to remove stale entries.
    fresh = not (S.member (k, n) remembered) && t + skew >= now - skew

    -- A request dated @t@ passes a @skew@-based timestamp check until
    -- @t + skew@. NOTE(review): assuming the server's timestamp check uses
    -- this same skew, accepted requests are dated at most @now + skew@, so
    -- remembering until @now + 2 * skew@ covers their whole validity
    -- window; expiring at @now + skew@ would allow a replay of a request
    -- from a client whose clock runs ahead. Deriving the expiry from @now@
    -- alone keeps the queue ordered for the prefix pruning above.
    (q', s')
      | fresh = (live |> (k, n, now + 2 * skew), S.insert (k, n) remembered)
      | otherwise = (live, remembered)