module Jose.Internal.Crypto
    ( hmacSign
    , hmacVerify
    , rsaSign
    , rsaVerify
    , rsaEncrypt
    , rsaDecrypt
    , ecVerify
    , encryptPayload
    , decryptPayload
    , generateCmkAndIV
    , pad
    , unpad
    )
where
import           Control.Applicative
import           Crypto.Cipher.Types (AuthTag(..))
import           Crypto.Number.Serialize (os2ip)
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.RSA as RSA
import qualified Crypto.PubKey.RSA.PKCS15 as PKCS15
import qualified Crypto.PubKey.RSA.OAEP as OAEP
import           Crypto.Random (CPRG, cprgGenerate)
import qualified Crypto.Cipher.AES as AES
import           Crypto.PubKey.HashDescr
import           Crypto.MAC.HMAC (hmac)
import           Data.Byteable (constEqBytes)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.Serialize as Serialize
import qualified Data.Text as T
import           Data.Word (Word64, Word8)
import           Jose.Jwa
import           Jose.Types (JwtError(..))
-- | Compute the raw HMAC of a message for one of the JWS HMAC algorithms
-- (HS256, HS384 or HS512). Any other algorithm yields 'BadAlgorithm'.
hmacSign :: JwsAlg      -- ^ HMAC algorithm to use
         -> ByteString  -- ^ HMAC key
         -> ByteString  -- ^ The message to sign
         -> Either JwtError ByteString  -- ^ The raw HMAC of the message
hmacSign a k m = do
    hash <- maybe (Left $ BadAlgorithm $ T.pack $ "Not an HMAC algorithm: " ++ show a) return $ lookup a hmacHashes
    let hashF     = hashFunction hash
        -- HMAC (RFC 2104) pads or hashes the key to the hash function's
        -- *block* size: 64 bytes for SHA-256 but 128 bytes for SHA-384 and
        -- SHA-512. Using 64 for every hash makes HS384/HS512 MACs
        -- non-standard. A digest longer than 32 bytes identifies the
        -- 128-byte-block SHA-384/SHA-512 family.
        blockSize = if B.length (hashF B.empty) > 32 then 128 else 64
    return $ hmac hashF blockSize k m
-- | Check a JWS HMAC signature by recomputing it with 'hmacSign' and
-- comparing the result against the supplied signature with 'constEqBytes'.
-- An unsupported algorithm counts as a verification failure.
hmacVerify :: JwsAlg      -- ^ HMAC algorithm
           -> ByteString  -- ^ HMAC key
           -> ByteString  -- ^ The signed message
           -> ByteString  -- ^ The signature to check
           -> Bool        -- ^ Whether the signature is valid
hmacVerify a key msg sig = case hmacSign a key msg of
    Left _         -> False
    Right expected -> expected `constEqBytes` sig
-- | Produce a PKCS #1 v1.5 RSA signature for RS256, RS384 or RS512.
-- A non-RSA algorithm gives 'BadAlgorithm'; any failure reported by
-- 'PKCS15.sign' is mapped to 'BadCrypto'.
rsaSign :: Maybe RSA.Blinder  -- ^ Optional blinder, passed through to 'PKCS15.sign'
        -> JwsAlg             -- ^ Signature algorithm
        -> RSA.PrivateKey     -- ^ Private key to sign with
        -> ByteString         -- ^ The message to sign
        -> Either JwtError ByteString    -- ^ The signature
rsaSign blinder a key msg =
    lookupRSAHash a >>= \hash ->
        case PKCS15.sign blinder hash key msg of
            Left _          -> Left BadCrypto
            Right signature -> Right signature
-- | Verify a PKCS #1 v1.5 RSA signature. Returns 'False' both for a
-- non-RSA algorithm and for a signature that does not verify.
rsaVerify :: JwsAlg        -- ^ Signature algorithm
          -> RSA.PublicKey -- ^ Public key to verify with
          -> ByteString    -- ^ The signed message
          -> ByteString    -- ^ The signature to check
          -> Bool          -- ^ Whether the signature is valid
rsaVerify a key msg sig = either (const False) verifyWith (lookupRSAHash a)
  where
    verifyWith hash = PKCS15.verify hash key msg sig
-- | Verify a JWS ECDSA signature (ES256, ES384 or ES512).
--
-- The signature must be the fixed-length concatenation @R || S@ required by
-- RFC 7518 section 3.4 (64, 96 or 132 bytes respectively). Each half is
-- converted to an integer with 'os2ip'. Any other algorithm, or a signature
-- of the wrong length, is rejected.
ecVerify :: JwsAlg          -- ^ Signature algorithm
         -> ECDSA.PublicKey -- ^ Public key to verify with
         -> ByteString      -- ^ The signed message
         -> ByteString      -- ^ The signature to check
         -> Bool            -- ^ Whether the signature is valid
ecVerify a key msg sig = case (lookupECHash a, coordinateSize) of
    (Just hash, Just n) | B.length sig == 2 * n ->
        let (r, s) = B.splitAt n sig
        in  ECDSA.verify hash key (ECDSA.Signature (os2ip r) (os2ip s)) msg
    _ -> False
  where
    -- Byte length of each of R and S. Without an exact length check,
    -- splitting an arbitrary-length string in half lets leading zero bytes
    -- be prepended to both halves, so many distinct byte strings would
    -- verify as the same signature.
    coordinateSize = case a of
        ES256 -> Just 32
        ES384 -> Just 48
        ES512 -> Just 66
        _     -> Nothing
-- | Hash descriptions for the JWS HMAC algorithms, keyed by algorithm.
hmacHashes :: [(JwsAlg, HashDescr)]
hmacHashes = zip [HS256, HS384, HS512] [hashDescrSHA256, hashDescrSHA384, hashDescrSHA512]
-- | The hash function used by each JWS ECDSA algorithm, or 'Nothing' for a
-- non-ECDSA algorithm.
lookupECHash :: JwsAlg -> Maybe HashFunction
lookupECHash ES256 = Just (hashFunction hashDescrSHA256)
lookupECHash ES384 = Just (hashFunction hashDescrSHA384)
lookupECHash ES512 = Just (hashFunction hashDescrSHA512)
lookupECHash _     = Nothing
-- | The hash description used by each JWS RSA algorithm. Any other
-- algorithm yields 'BadAlgorithm' naming the rejected value.
lookupRSAHash :: JwsAlg -> Either JwtError HashDescr
lookupRSAHash RS256 = Right hashDescrSHA256
lookupRSAHash RS384 = Right hashDescrSHA384
lookupRSAHash RS512 = Right hashDescrSHA512
lookupRSAHash other = Left . BadAlgorithm . T.pack $ "Not an RSA algorithm: " ++ show other
-- | Generate a random content key and IV sized for the given content
-- encryption algorithm (see 'keySize' and 'ivSize'). The key is drawn from
-- the generator first, then the IV, and the final generator is returned.
generateCmkAndIV :: CPRG g
                 => g   -- ^ The random number generator
                 -> Enc -- ^ The content encryption algorithm
                 -> (B.ByteString, B.ByteString, g) -- ^ The key, the IV and the updated generator
generateCmkAndIV g e =
    let (key, g1) = cprgGenerate (keySize e) g
        (iv,  g2) = cprgGenerate (ivSize e) g1
    in  (key, iv, g2)
-- | Content key length in bytes for each content encryption algorithm.
-- The CBC-HMAC keys are twice the AES key size because they are split into
-- a MAC key and an encryption key.
keySize :: Enc -> Int
keySize e = case e of
    A128GCM       -> 16
    A256GCM       -> 32
    A128CBC_HS256 -> 32
    A256CBC_HS512 -> 64
-- | IV length in bytes: 12 for the GCM modes, 16 (one AES block) otherwise.
ivSize :: Enc -> Int
ivSize e = case e of
    A128GCM -> 12
    A256GCM -> 12
    _       -> 16
-- | Encrypt data with RSA, using PKCS #1 v1.5 ('RSA1_5') or OAEP with
-- SHA-1 ('RSA_OAEP').
--
-- The result type leaves no room for an error, so an encryption failure
-- (e.g. input too long for the key) is raised with 'error' when the
-- ciphertext is forced. The message names the underlying RSA error, rather
-- than surfacing as an opaque irrefutable-pattern failure.
rsaEncrypt :: CPRG g
           => g                  -- ^ Random number generator
           -> JweAlg             -- ^ Key encryption algorithm
           -> RSA.PublicKey      -- ^ Public key to encrypt with
           -> B.ByteString       -- ^ Data to encrypt
           -> (B.ByteString, g)  -- ^ The ciphertext and the updated generator
rsaEncrypt gen a pubKey content = (either failed id result, g')
  where
    (result, g') = case a of
        RSA1_5   -> PKCS15.encrypt gen pubKey content
        RSA_OAEP -> OAEP.encrypt gen oaepParams pubKey content
    failed err = error $ "Jose.Internal.Crypto.rsaEncrypt: RSA encryption failed: " ++ show err
-- | Decrypt an RSA-encrypted JWE key with PKCS #1 v1.5 ('RSA1_5') or OAEP
-- with SHA-1 ('RSA_OAEP'). Any decryption failure is reported as 'BadCrypto'.
--
-- NOTE(review): for 'RSA1_5', a failure that callers can distinguish from a
-- later MAC failure may act as a decryption oracle. RFC 7516 section 11.5
-- recommends substituting a random CEK instead. Confirm how callers handle
-- 'BadCrypto'.
rsaDecrypt :: Maybe RSA.Blinder             -- ^ Optional blinder, passed to the decrypt function
           -> JweAlg                        -- ^ Key encryption algorithm
           -> RSA.PrivateKey                -- ^ Private key to decrypt with
           -> B.ByteString                  -- ^ The encrypted JWE key
           -> Either JwtError B.ByteString  -- ^ The decrypted key
rsaDecrypt blinder a rsaKey jweKey =
    case decryptWith rsaKey jweKey of
        Left _          -> Left BadCrypto
        Right plaintext -> Right plaintext
  where
    decryptWith = case a of
        RSA1_5   -> PKCS15.decrypt blinder
        RSA_OAEP -> OAEP.decrypt blinder oaepParams
-- | Default OAEP parameters using SHA-1, as used for 'RSA_OAEP'.
oaepParams :: OAEP.OAEPParams
oaepParams = OAEP.defaultOAEPParams sha1
  where
    sha1 = hashFunction hashDescrSHA1
-- | Decrypt and authenticate a JWE payload.
--
-- A content key whose length does not match 'keySize' is rejected with
-- 'BadCrypto'. For the CBC-HMAC modes the authentication tag is verified
-- /before/ the ciphertext is decrypted or its padding examined. Any tampered
-- ciphertext therefore yields 'BadSignature', and padding errors cannot be
-- used as an oracle (RFC 7518 section 5.2.2.2 requires this order).
decryptPayload :: Enc        -- ^ Content encryption algorithm
               -> ByteString -- ^ Content encryption key
               -> ByteString -- ^ Initialization vector
               -> ByteString -- ^ Additional authenticated data
               -> ByteString -- ^ Authentication tag from the JWE
               -> ByteString -- ^ Ciphertext
               -> Either JwtError ByteString
decryptPayload e cek iv aad sig ct
    -- AES.initAES chooses the AES variant from the key length, so a
    -- mis-sized key would otherwise select the wrong cipher or fail.
    | B.length cek /= keySize e = Left BadCrypto
    | otherwise = case e of
        A128GCM       -> decryptGCM
        A256GCM       -> decryptGCM
        A128CBC_HS256 -> decryptCBC 16 hashDescrSHA256
        A256CBC_HS512 -> decryptCBC 32 hashDescrSHA512
  where
    decryptGCM =
        let (plaintext, tag) = AES.decryptGCM (AES.initAES cek) iv aad ct
        in  if tag == AuthTag sig then Right plaintext else Left BadSignature
    -- l: tag length in bytes; h: HMAC hash.
    decryptCBC l h
        | expected /= AuthTag sig = Left BadSignature
        -- CBC needs a one-block IV and a non-empty, block-aligned ciphertext.
        | B.length iv /= 16 || B.null ct || B.length ct `mod` 16 /= 0 = Left BadCrypto
        | otherwise = unpad $ AES.decryptCBC (AES.initAES encKey) iv ct
      where
        (macKey, encKey) = B.splitAt (B.length cek `div` 2) cek
        al       = fromIntegral (B.length aad) * 8 :: Word64
        expected = authTag l h macKey $ B.concat [aad, iv, ct, Serialize.encode al]
-- | Encrypt a JWE payload, returning the ciphertext and authentication tag.
--
-- The GCM modes use 'AES.encryptGCM' with the whole key. The CBC-HMAC modes
-- split the key in half (MAC key first, AES key second) and encrypt the
-- padded message with AES-CBC. They then tag
-- @aad || iv || ciphertext || (AAD length in bits as a Word64)@ with a
-- truncated HMAC.
encryptPayload :: Enc                   -- ^ Content encryption algorithm
               -> ByteString            -- ^ Content encryption key
               -> ByteString            -- ^ Initialization vector
               -> ByteString            -- ^ Additional authenticated data
               -> ByteString            -- ^ The message to encrypt
               -> (ByteString, AuthTag) -- ^ Ciphertext and authentication tag
encryptPayload e cek iv aad msg = case e of
    A128GCM       -> gcm
    A256GCM       -> gcm
    A128CBC_HS256 -> cbcWith 16 hashDescrSHA256
    A256CBC_HS512 -> cbcWith 32 hashDescrSHA512
  where
    gcm = AES.encryptGCM (AES.initAES cek) iv aad msg
    cbcWith tagLen hd = (ciphertext, authTag tagLen hd macKey macInput)
    (macKey, encKey) = B.splitAt (B.length cek `div` 2) cek
    ciphertext = AES.encryptCBC (AES.initAES encKey) iv (pad msg)
    aadBits    = fromIntegral (B.length aad) * 8 :: Word64
    macInput   = B.concat [aad, iv, ciphertext, Serialize.encode aadBits]
-- | Compute a truncated HMAC authentication tag: the HMAC of the message
-- under the given hash and key, cut to the first @l@ bytes. Used by the
-- CBC-HMAC content encryption modes.
authTag :: Int -> HashDescr -> ByteString -> ByteString -> AuthTag
authTag l h k m = AuthTag $ B.take l $ hmac hashF blockSize k m
  where
    hashF     = hashFunction h
    -- HMAC must use the hash's block size: 64 bytes for SHA-256 but 128 for
    -- SHA-384/SHA-512 (needed by A256CBC-HS512). A digest longer than 32
    -- bytes identifies the 128-byte-block family.
    blockSize = if B.length (hashF B.empty) > 32 then 128 else 64
-- | Remove and validate PKCS #7 padding for a 16-byte block size.
--
-- The final byte gives the padding length @n@. It must be between 1 and 16
-- and no longer than the input, and the last @n@ bytes must all equal @n@.
-- Anything else, including empty input, yields 'BadCrypto'.
unpad :: ByteString -> Either JwtError ByteString
unpad bs
    | B.null bs                                  = Left BadCrypto
    | padLen == 0 || padLen > 16 || padLen > len = Left BadCrypto
    | B.any (/= padByte) padding                 = Left BadCrypto
    | otherwise                                  = Right pt
  where
    len     = B.length bs
    -- Only evaluated once the input is known to be non-empty.
    padByte = B.last bs
    padLen  = fromIntegral padByte :: Int
    (pt, padding) = B.splitAt (len - padLen) bs
-- | Apply PKCS #7 padding for a 16-byte block size. The function appends
-- @n@ copies of the byte @n@, where @n@ (1..16) brings the length up to the
-- next multiple of 16. Input that is already block-aligned gains a full
-- block of padding.
pad :: ByteString -> ByteString
pad bs = B.append bs padding
  where
    lastBlockSize = B.length bs `mod` 16
    padByte       = fromIntegral $ 16 - lastBlockSize :: Word8
    padding       = B.replicate (fromIntegral padByte) padByte