module Jose.Jwt
    ( module Jose.Types
    , encode
    , decode
    , decodeClaims
    )
where
import Control.Error
import Control.Monad.State.Strict
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import Crypto.PubKey.RSA (PrivateKey(..))
import Crypto.Random (CPRG)
import Data.Aeson (decodeStrict')
import Data.ByteString (ByteString)
import Data.List (find)
import Data.Maybe (fromJust)
import qualified Data.ByteString.Char8 as BC
import qualified Jose.Internal.Base64 as B64
import Jose.Types
import Jose.Jwk
import Jose.Jwa
import qualified Jose.Jws as Jws
import qualified Jose.Jwe as Jwe
-- | Encode a payload as a JWS or a JWE, depending on the algorithm.
--
-- For a 'Signed' algorithm no content encryption may be given; the key is
-- first checked with 'validateForJws' and then used to sign the payload.
-- For an 'Encrypted' algorithm a content encryption ('Enc') is required.
-- The random generator is threaded through the underlying encoder and the
-- updated generator is returned alongside the result.
encode :: (CPRG g)
       => g                               -- ^ Random number generator
       -> Jwk                             -- ^ The key to sign or encrypt with
       -> Alg                             -- ^ The JWS or JWE algorithm
       -> Maybe Enc                       -- ^ Content encryption (JWE only; must be 'Nothing' for a JWS)
       -> ByteString                      -- ^ The payload to encode
       -> (Either JwtError ByteString, g) -- ^ The encoded token and the updated generator
encode rng jwk alg enc msg = runState (runEitherT encoded) rng
  where
    encoded = case (alg, enc) of
        (Signed _, Just _)     -> left (BadAlgorithm "Enc cannot be set for a JWS")
        (Signed a, Nothing)    -> do
            hoistEither (validateForJws a jwk)
            hoistEither =<< state (\gen -> Jws.jwkEncode gen a jwk msg)
        (Encrypted _, Nothing) -> left (BadAlgorithm "Enc must be supplied for a JWE")
        (Encrypted a, Just e)  -> hoistEither =<< state (\gen -> Jwe.jwkEncode gen a e jwk msg)
-- | Decode a JWT (JWS or JWE) using a set of candidate keys.
--
-- The first dot-separated component is parsed as the header. The header
-- decides whether the token is handled as a JWS ('JwsH') or as a JWE, and
-- 'findKeys' narrows the key set to the keys matching that header. Every
-- matching key is then tried and the first successful decoding is returned.
--
-- Fails with @BadDots 2@ if the token has fewer than three components. Fails
-- with a 'KeyError' if no key matches the header or if none of the matching
-- keys decodes the token.
decode :: CPRG g
       => g                        -- ^ Random number generator (threaded through JWE decryption)
       -> [Jwk]                    -- ^ Candidate keys
       -> ByteString               -- ^ The encoded token
       -> (Either JwtError Jwt, g) -- ^ The decoded token and the updated generator
decode rng keySet jwt = flip runState rng $ runEitherT $ do
    -- Require at least three components; matching the list directly avoids
    -- the partial 'head' on the split result.
    hdrPart <- case BC.split '.' jwt of
        (h:_:_:_) -> return h
        _         -> left $ BadDots 2
    hdr <- B64.decode hdrPart >>= hoistEither . parseHeader
    ks  <- findKeys hdr keySet
    let decodeWith = case hdr of
            JwsH _ -> decodeWithJws
            _      -> decodeWithJwe
    -- Every candidate key is tried, even after one has succeeded (for JWE
    -- each attempt threads the generator); the first success wins.
    decodings <- mapM decodeWith ks
    case [d | Just d <- decodings] of
        (d:_) -> return d
        []    -> left $ KeyError "None of the keys was able to decode the JWT"
  where
    -- Try to decode the token as a JWS with one key. A decoding failure is
    -- mapped to 'Nothing' so the remaining keys can still be tried.
    decodeWithJws :: CPRG g => Jwk -> EitherT JwtError (State g) (Maybe Jwt)
    decodeWithJws k = either (const $ return Nothing) (return . Just . Jws) $ case k of
        RsaPublicJwk  kPub _ _ _ -> Jws.rsaDecode kPub jwt
        RsaPrivateJwk kPr  _ _ _ -> Jws.rsaDecode (private_pub kPr) jwt
        EcPublicJwk   kPub _ _ _ -> Jws.ecDecode kPub jwt
        EcPrivateJwk  kPr  _ _ _ -> Jws.ecDecode (ECDSA.toPublicKey kPr) jwt
        SymmetricJwk  kb   _ _ _ -> Jws.hmacDecode kb jwt
    -- Try to decrypt the token as a JWE with one key. Only RSA private keys
    -- are handled; 'Jwe.rsaDecode' takes and returns the generator, so it is
    -- threaded with 'state' (as in 'encode'). A decryption failure is mapped
    -- to 'Nothing'.
    decodeWithJwe :: CPRG g => Jwk -> EitherT JwtError (State g) (Maybe Jwt)
    decodeWithJwe k = case k of
        RsaPrivateJwk kPr _ _ _ -> do
            e <- state (\g -> Jwe.rsaDecode g kPr jwt)
            either (const $ return Nothing) (return . Just . Jwe) e
        _                       -> left $ KeyError "Not a JWE key (shouldn't happen)"
-- | Decode the header and claims of a three-part token without checking its
-- signature.
--
-- No key is involved, so the third component (the signature) is ignored.
-- Callers should not trust the returned claims unless the token is verified
-- separately.
--
-- Fails with @BadDots 2@ unless the token has exactly three dot-separated
-- components. Fails with 'BadClaims' if the second component does not decode
-- as 'JwtClaims' JSON.
decodeClaims :: ByteString
             -> Either JwtError (JwtHeader, JwtClaims)
decodeClaims jwt = case BC.split '.' jwt of
    -- Matching exactly three components replaces the length check plus the
    -- partial 'head' / 'tail' calls.
    [hdrPart, claimsPart, _] -> do
        hdr    <- B64.decode hdrPart >>= parseHeader
        claims <- B64.decode claimsPart >>= parseClaims
        return (hdr, claims)
    _ -> Left $ BadDots 2
  where
    parseClaims :: ByteString -> Either JwtError JwtClaims
    parseClaims bs = maybe (Left BadClaims) Right $ decodeStrict' bs
-- | Select the keys from a key set that match a JWT header.
--
-- JWS headers are matched with 'findMatchingJwsKeys' and JWE headers with
-- 'findMatchingJweKeys'. If no key matches, the result is a 'KeyError'
-- instead of an empty list.
findKeys :: Monad m => JwtHeader -> [Jwk] -> EitherT JwtError m [Jwk]
findKeys hdr jwks =
    case candidates of
        [] -> left $ KeyError "No suitable key was found to decode the JWT"
        _  -> return candidates
  where
    candidates = case hdr of
        JwsH h -> findMatchingJwsKeys jwks h
        JweH h -> findMatchingJweKeys jwks h