--   This Source Code Form is subject to the terms of the Mozilla Public
--   License, v. 2.0. If a copy of the MPL was not distributed with this
--   file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# OPTIONS_HADDOCK show-extensions #-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | JWT decoding definition
--   
--   __This module can be considered internal to the library.__
--   Users should never need to implement the `Decode` typeclass or use any of the exported functions or types directly.
--   You only need to know about the `Decode` typeclass if you want to write functions that are polymorphic in the payload type.
--
--   If you want to extend the types supported by the library, see "Libjwt.Classes"
module Libjwt.Decoding
  ( DecodeResult(..)
  , hoistResult
  , ClaimDecoder(..)
  , Decode(..)
  , getOrEmpty
  , decodeClaimOrThrow
  , decodeClaimProxied
  , Decodable
  , JwtIO
  )
where

import           Libjwt.Classes
import           Libjwt.Exceptions              ( MissingClaim(..) )
import           Libjwt.FFI.Jwt
import           Libjwt.JsonByteString
import           Libjwt.NumericDate

import           Control.Applicative            ( Alternative )
import           Control.Monad                  ( (<=<) )

import           Control.Monad.Catch            ( throwM )

import           Control.Monad.Trans.Maybe

import           Data.ByteString                ( ByteString )

import           Data.Coerce
import           Data.Kind                      ( Constraint )
import           Data.Maybe                     ( fromMaybe )
import           Data.Proxy

-- | Outcome of decoding a single claim.
--
--   Wraps a 'JwtIO' action that yields 'Nothing' when the claim could not be
--   decoded. Every instance is borrowed from @'MaybeT' 'JwtIO'@, so a failed
--   step short-circuits the rest of a '>>=' chain and '<|>' falls back to the
--   next alternative.
newtype DecodeResult t = Result
  { getOptional :: JwtIO (Maybe t)
    -- ^ Run the decoding action; 'Nothing' signals that decoding failed.
  }
  deriving (Functor, Applicative) via (MaybeT JwtIO)
  deriving (Monad, Alternative) via (MaybeT JwtIO)

-- | Lift an already-computed 'Maybe' into 'DecodeResult':
--   @'Just' x@ succeeds with @x@ and 'Nothing' fails.
hoistResult :: Maybe a -> DecodeResult a
hoistResult outcome = Result (pure outcome)

-- | Run a 'DecodeResult', substituting 'mempty' when decoding has failed.
getOrEmpty :: (Monoid a) => DecodeResult a -> JwtIO a
getOrEmpty = fmap (fromMaybe mempty) . getOptional

-- | 'decodeClaim' with the target type fixed by a proxy argument instead of
--   by the result type. The proxy value itself is never inspected.
decodeClaimProxied
  :: (ClaimDecoder t) => String -> proxy t -> JwtT -> DecodeResult t
decodeClaimProxied claimName _ jwt = decodeClaim claimName jwt
name

-- | Decode a claim inside 'JwtIO', throwing @'Missing' name@ (a
--   'MissingClaim') via 'throwM' when decoding fails.
decodeClaimOrThrow :: (ClaimDecoder t) => String -> proxy t -> JwtT -> JwtIO t
decodeClaimOrThrow name p jwt = do
  decoded <- getOptional (decodeClaimProxied name p jwt)
  case decoded of
    Just value -> pure value
    Nothing    -> throwM (Missing name)

-- | Decoding strategy for a claim type (used promoted, at the type level).
--
--   * 'Native': the type has a dedicated @ClaimDecoder'@ instance in this module.
--   * 'Derived': the type is decoded as its 'JwtRep' representation and then
--     converted with 'unRep'.
data DecoderType
  = Native
  | Derived

-- | Chooses the 'DecoderType' used to decode a claim of type @a@.
--
--   Equations of a closed type family are tried top to bottom, so their order
--   matters: 'String' (i.e. @[Char]@) must come before the @[a]@ equation so that
--   strings are 'Derived' rather than decoded as JSON arrays, and the wildcard
--   equation must remain last.
type family DecoderDef a :: DecoderType where
  DecoderDef ByteString     = 'Native
  DecoderDef Bool           = 'Native
  DecoderDef Int            = 'Native
  DecoderDef NumericDate    = 'Native
  DecoderDef JsonByteString = 'Native
  -- String = [Char]: must precede the [a] equation below.
  DecoderDef String         = 'Derived
  -- Lists are decoded natively as JSON arrays of JsonParser elements.
  DecoderDef [a]            = 'Native
  -- Anything else goes through its JwtRep representation.
  DecoderDef _              = 'Derived

-- | Low-level definition of claims decoding: how to read one named claim of
--   type @t@.
--
--   A single catch-all instance in this module dispatches on 'DecoderDef'.
class ClaimDecoder t where
  -- | Given the claim name and a pointer to /jwt_t/, try to decode the value
  --   of type @t@.
  decodeClaim
    :: String          -- ^ claim name
    -> JwtT            -- ^ pointer to /jwt_t/
    -> DecodeResult t

-- | Every type is a 'ClaimDecoder': its strategy is computed by 'DecoderDef'
--   and the call is forwarded to the matching @ClaimDecoder'@ instance.
instance (DecoderDef a ~ ty, ClaimDecoder' ty a) => ClaimDecoder a where
  decodeClaim claimName = decodeClaim' (Proxy :: Proxy ty) claimName

-- | Strategy-indexed implementation behind 'ClaimDecoder'. The proxy only
--   carries the type-level tag @ty@; the instances in this module never
--   inspect its value.
class ClaimDecoder' (ty :: DecoderType) t where
  decodeClaim'
    :: proxy ty        -- ^ selects the decoding strategy
    -> String          -- ^ claim name
    -> JwtT            -- ^ pointer to /jwt_t/
    -> DecodeResult t

instance ClaimDecoder' 'Native ByteString where
  -- Raw claim value as returned by 'getGrant'.
  decodeClaim' _ name jwt = Result (getGrant name jwt)
  {-# INLINE decodeClaim' #-}

instance ClaimDecoder' 'Native Bool where
  -- Boolean claim value as returned by 'getGrantBool'.
  decodeClaim' _ name jwt = Result (getGrantBool name jwt)
  {-# INLINE decodeClaim' #-}

instance ClaimDecoder' 'Native Int where
  -- Integer claim value as returned by 'getGrantInt'.
  decodeClaim' _ name jwt = Result (getGrantInt name jwt)
  {-# INLINE decodeClaim' #-}

instance ClaimDecoder' 'Native NumericDate where
  -- 'getGrantInt64' yields @JwtIO (Maybe Int64)@; this relies on 'NumericDate'
  -- being 'Coercible' from that 'Int64', so the result is re-typed with 'coerce'.
  decodeClaim' _ name jwt = coerce (getGrantInt64 name jwt)
  {-# INLINE decodeClaim' #-}

instance ClaimDecoder' 'Native JsonByteString where
  -- Fetch the claim in JSON form via 'getGrantAsJson' and wrap it with
  -- 'jsonFromStrict'.
  decodeClaim' _ name jwt = jsonFromStrict <$> Result (getGrantAsJson name jwt)
  {-# INLINE decodeClaim' #-}

-- | Build a decoder from a JSON parser. The claim is first fetched as a
--   'JsonByteString' by the native decoder, then @parse@ runs on it. Decoding
--   fails if the claim cannot be fetched or @parse@ yields 'Nothing'.
fromJsonNative
  :: (JsonByteString -> JwtIO (Maybe a)) -> String -> JwtT -> DecodeResult a
fromJsonNative parse name jwt = do
  claimJson <- decodeClaim' (Proxy :: Proxy 'Native) name jwt
  Result (parse claimJson)
{-# INLINE fromJsonNative #-}

instance JsonParser a => ClaimDecoder' 'Native [a] where
  -- The claim's JSON is passed to 'unsafeMapTokenizedJsonArray' with
  -- 'jsonParser' applied to each element; 'sequence' makes a single
  -- unparsable element fail the whole claim.
  decodeClaim' _ = fromJsonNative $ \claimJson ->
    (>>= sequence)
      <$> unsafeMapTokenizedJsonArray jsonParser (toJsonStrict claimJson)
  {-# INLINE decodeClaim' #-}

instance (JwtRep b a, DecoderDef b ~ ty, ClaimDecoder' ty b) => ClaimDecoder' 'Derived a where
  -- Decode the representation type @b@ with its own strategy, then convert it
  -- with 'unRep'; a failed conversion fails the claim.
  decodeClaim' _ name jwt =
    decodeClaim' (Proxy :: Proxy ty) name jwt >>= hoistResult . unRep

-- | Constraint satisfied by every claim type @t@ for which the decoder chosen
--   by 'DecoderDef' exists.
type family Decodable t :: Constraint where
  Decodable claim = ClaimDecoder' (DecoderDef claim) claim

-- | Definition of claims decoding.
--
--   Users will most likely only need this class to write functions that are polymorphic in the payload type.
class Decode c where
  -- | Construct an action that decodes the value of type @c@, given a pointer to /jwt_t/. The action may fail.
  decode :: JwtT -> JwtIO c