module Crypto.JWT
(
JWT(..)
, claimAud
, claimExp
, claimIat
, claimIss
, claimJti
, claimNbf
, claimSub
, unregisteredClaims
, addClaim
, createJWSJWT
, validateJWSJWT
, ClaimsSet(..)
, emptyClaimsSet
, Audience(..)
, StringOrURI
, fromString
, fromURI
, getString
, getURI
, NumericDate(..)
) where
import Control.Applicative
import Control.Monad
import Data.Bifunctor
import Data.Maybe

import Control.Lens hiding ((.=))
import Data.Aeson
import qualified Data.ByteString.Lazy as BSL
import qualified Data.HashMap.Strict as M
import qualified Data.Text as T
import Data.Time
import Data.Time.Clock.POSIX
import Network.URI (parseURI, uriToString)

import Crypto.JOSE
import Crypto.JOSE.Types
-- | A JWT \"StringOrURI\" value: either arbitrary text or a parsed URI.
--
-- The constructors are not exported; build values with 'fromString' or
-- 'fromURI' and inspect them with 'getString' or 'getURI'.  When decoding
-- from JSON, any string containing a @:@ is treated as a URI; all other
-- strings become 'Arbitrary'.
data StringOrURI = Arbitrary T.Text | OrURI URI deriving (Eq, Show)
-- | Build a 'StringOrURI' from text.
--
-- If the text is accepted by 'parseURI' the result holds the parsed URI;
-- otherwise the text is kept verbatim as an arbitrary string.
--
-- NOTE(review): text that contains a @:@ but is rejected by 'parseURI'
-- ends up as an arbitrary string here.  This module's 'FromJSON' instance
-- instead tries to decode any @:@-containing string as a URI, so such a
-- value may fail to survive an encode/decode round trip.  Verify against
-- the FromJSON URI instance in use.
fromString :: T.Text -> StringOrURI
fromString s =
  case parseURI (T.unpack s) of
    Just uri -> OrURI uri
    Nothing  -> Arbitrary s
-- | Wrap an already-parsed URI.  Unlike 'fromString', no parsing occurs.
fromURI :: URI -> StringOrURI
fromURI uri = OrURI uri
-- | The text of an arbitrary-string 'StringOrURI', or 'Nothing' when the
-- value holds a URI.
getString :: StringOrURI -> Maybe T.Text
getString v = case v of
  Arbitrary s -> Just s
  OrURI _     -> Nothing
-- | The URI held by a 'StringOrURI', or 'Nothing' when the value is
-- arbitrary text.
getURI :: StringOrURI -> Maybe URI
getURI v = case v of
  OrURI uri   -> Just uri
  Arbitrary _ -> Nothing
-- | Decodes a JSON string.  A string containing @:@ is decoded with the
-- 'FromJSON' instance for 'URI', and the whole parse fails if that instance
-- rejects it.  Any other string is kept as arbitrary text.
instance FromJSON StringOrURI where
  parseJSON = withText "StringOrURI" classify
    where
      classify s
        | T.any (':' ==) s = OrURI <$> parseJSON (String s)
        | otherwise        = pure (Arbitrary s)
-- | Serialises a 'StringOrURI' as a JSON string.
--
-- URIs are rendered with @'uriToString' id@ rather than 'show'.  The 'Show'
-- instance for 'URI' masks any password in the userinfo component
-- (@http:\/\/user:pw\@host@ is shown as @http:\/\/user:...\@host@).  That
-- would silently alter the claim value and break round-tripping.
instance ToJSON StringOrURI where
  toJSON (Arbitrary s) = toJSON s
  toJSON (OrURI uri)   = toJSON (uriToString id uri "")
-- | A JWT NumericDate.  Its JSON form is a (possibly fractional) number of
-- seconds since the POSIX epoch, as produced and consumed by the
-- 'ToJSON' / 'FromJSON' instances below.
newtype NumericDate = NumericDate UTCTime deriving (Eq, Show)
-- | Decodes a JSON number as seconds since the POSIX epoch.
--
-- NOTE(review): the scientific package documents that 'toRational' on a
-- value with a very large exponent (e.g. @1e1000000000@) can exhaust memory.
-- 'FromCompact' decodes claims before any signature check, so this input is
-- attacker-controlled.  Consider bounding the exponent (e.g. with
-- @Data.Scientific.toBoundedRealFloat@).  TODO confirm and fix.
instance FromJSON NumericDate where
  parseJSON = withScientific "NumericDate" toDate
    where
      toDate n =
        let secs = fromRational (toRational n)
        in pure (NumericDate (posixSecondsToUTCTime secs))
-- | Encodes the date as a JSON number of seconds since the POSIX epoch,
-- keeping any fractional part.
instance ToJSON NumericDate where
  toJSON (NumericDate t) = Number secs
    where
      secs = fromRational (toRational (utcTimeToPOSIXSeconds t))
-- | The audience claim.  'General' corresponds to a JSON array of
-- 'StringOrURI' values and 'Special' to a single value (see the JSON
-- instances below).
data Audience = General [StringOrURI] | Special StringOrURI deriving (Eq, Show)
-- | Accepts either a JSON array ('General') or a single value ('Special').
-- The array form is tried first.
instance FromJSON Audience where
  parseJSON v = asList <|> asSingle
    where
      asList   = General <$> parseJSON v
      asSingle = Special <$> parseJSON v
-- | 'General' is encoded as a JSON array and 'Special' as a bare value.
instance ToJSON Audience where
  toJSON aud = case aud of
    General xs -> toJSON xs
    Special x  -> toJSON x
-- | A JWT claims set.  It holds the registered claims as typed optional
-- fields, plus a map of all other (unregistered) claims.
--
-- The JSON instances strip registered claim names from
-- '_unregisteredClaims' on both decode and encode, so that map never
-- duplicates a typed field on the wire.
data ClaimsSet = ClaimsSet
  { _claimIss :: Maybe StringOrURI
    -- ^ the @\"iss\"@ claim
  , _claimSub :: Maybe StringOrURI
    -- ^ the @\"sub\"@ claim
  , _claimAud :: Maybe Audience
    -- ^ the @\"aud\"@ claim
  , _claimExp :: Maybe NumericDate
    -- ^ the @\"exp\"@ claim
  , _claimNbf :: Maybe NumericDate
    -- ^ the @\"nbf\"@ claim
  , _claimIat :: Maybe NumericDate
    -- ^ the @\"iat\"@ claim
  , _claimJti :: Maybe T.Text
    -- ^ the @\"jti\"@ claim
  , _unregisteredClaims :: M.HashMap T.Text Value
    -- ^ every other claim, keyed by name
  }
  deriving (Eq, Show)

-- Generates the lenses claimIss, claimSub, claimAud, claimExp, claimNbf,
-- claimIat, claimJti and unregisteredClaims exported by this module.
makeLenses ''ClaimsSet
-- | A claims set with no registered claims and no unregistered claims.
emptyClaimsSet :: ClaimsSet
emptyClaimsSet = ClaimsSet
  { _claimIss = Nothing
  , _claimSub = Nothing
  , _claimAud = Nothing
  , _claimExp = Nothing
  , _claimNbf = Nothing
  , _claimIat = Nothing
  , _claimJti = Nothing
  , _unregisteredClaims = M.empty
  }
-- | Add an unregistered claim, replacing any existing claim with that name.
--
-- NOTE: a key equal to a registered claim name (e.g. @\"iss\"@) is stored,
-- but the 'ToJSON' instance filters it out.  Such a key therefore never
-- appears in the encoded claims set; set the typed field instead.
addClaim :: T.Text -> Value -> ClaimsSet -> ClaimsSet
addClaim k v claims = claims & unregisteredClaims %~ M.insert k v
-- | Remove the registered claim names from a claims map, leaving only the
-- unregistered claims.
filterUnregistered :: M.HashMap T.Text Value -> M.HashMap T.Text Value
filterUnregistered claims = foldr M.delete claims registered
  where
    registered :: [T.Text]
    registered = ["iss", "sub", "aud", "exp", "nbf", "iat", "jti"]
-- | Decodes a JSON object into a 'ClaimsSet'.  Each registered claim is
-- optional.  Every key that is not a registered claim name is kept in
-- '_unregisteredClaims'.
instance FromJSON ClaimsSet where
  parseJSON = withObject "JWT Claims Set" $ \o -> do
    iss       <- o .:? "iss"
    sub       <- o .:? "sub"
    aud       <- o .:? "aud"
    expiry    <- o .:? "exp"
    notBefore <- o .:? "nbf"
    issuedAt  <- o .:? "iat"
    jwtId     <- o .:? "jti"
    pure ClaimsSet
      { _claimIss = iss
      , _claimSub = sub
      , _claimAud = aud
      , _claimExp = expiry
      , _claimNbf = notBefore
      , _claimIat = issuedAt
      , _claimJti = jwtId
      , _unregisteredClaims = filterUnregistered o
      }
-- | Encodes a 'ClaimsSet' as a JSON object.  Absent registered claims are
-- omitted.  Unregistered entries whose names collide with a registered
-- claim name are dropped.
instance ToJSON ClaimsSet where
  toJSON (ClaimsSet iss sub aud expiry notBefore issuedAt jwtId extra) =
    object (registeredPairs ++ M.toList (filterUnregistered extra))
    where
      registeredPairs = concat
        [ optPair "iss" iss
        , optPair "sub" sub
        , optPair "aud" aud
        , optPair "exp" expiry
        , optPair "nbf" notBefore
        , optPair "iat" issuedAt
        , optPair "jti" jwtId
        ]
      -- zero pairs for Nothing, exactly one for Just
      optPair key = maybe [] (\val -> [key .= val])
-- | The cryptographic envelope of a 'JWT'.  Only the JWS (signed) form is
-- represented.
data JWTCrypto = JWTJWS JWS deriving (Eq, Show)
-- | Parses the compact serialisation as a JWS and wraps it in 'JWTJWS'.
instance FromCompact JWTCrypto where
  fromCompact parts = JWTJWS <$> fromCompact parts
-- | Delegates to the compact serialisation of the wrapped JWS.
instance ToCompact JWTCrypto where
  toCompact crypto = case crypto of
    JWTJWS inner -> toCompact inner
-- | A JSON Web Token: its cryptographic envelope together with the claims
-- set carried in the payload.
--
-- NOTE(review): the 'JWT' constructor is exported, so nothing guarantees
-- that 'jwtClaimsSet' matches the payload inside 'jwtCrypto' when a value
-- is built by hand.
data JWT = JWT
  { jwtCrypto :: JWTCrypto
    -- ^ the JWS representation of the token
  , jwtClaimsSet :: ClaimsSet
    -- ^ the claims set
  } deriving (Eq, Show)
-- | Parses the compact serialisation into the crypto envelope, then decodes
-- the JWS payload as a JSON 'ClaimsSet'.  A JSON decode failure is reported
-- as 'CompactDecodeError'.  The signature is not checked here; use
-- 'validateJWSJWT' for that.
instance FromCompact JWT where
  fromCompact parts = do
    crypto <- fromCompact parts
    case crypto of
      JWTJWS jws ->
        first CompactDecodeError (JWT crypto <$> eitherDecode (jwsPayload jws))
-- | Serialises only the crypto envelope; 'jwtClaimsSet' is not consulted.
instance ToCompact JWT where
  toCompact (JWT crypto _) = toCompact crypto
-- | Verify the JWS signature of a 'JWT' with the given algorithms, policy
-- and key.
--
-- Only the signature is checked.  The claims set is ignored entirely, so
-- the @exp@, @nbf@, @aud@ and @iss@ claims are not inspected.  Any such
-- checks are the caller's responsibility.
validateJWSJWT
  :: ValidationAlgorithms
  -> ValidationPolicy
  -> JWK
  -> JWT
  -> Bool
validateJWSJWT algs policy k jwt =
  case jwtCrypto jwt of
    JWTJWS jws -> verifyJWS algs policy k jws
-- | Encode a 'ClaimsSet' as JSON, sign it as a JWS payload with the given
-- key and header, and wrap the result as a 'JWT'.
--
-- The generator is threaded through 'signJWS', and the updated generator
-- is returned alongside the result.  The returned 'jwtClaimsSet' is the
-- input claims set unchanged.  Unregistered entries that use registered
-- claim names are absent from the signed payload, because the JSON
-- encoding drops them.
createJWSJWT
  :: CPRG g
  => g
  -> JWK
  -> JWSHeader
  -> ClaimsSet
  -> (Either Error JWT, g)
createJWSJWT g k h c =
  let (signed, g') = signJWS g (JWS payload []) h k
  in (wrap <$> signed, g')
  where
    payload  = Base64Octets (BSL.toStrict (encode c))
    wrap jws = JWT (JWTJWS jws) c