{-# LANGUAGE ScopedTypeVariables, TypeApplications #-}

module Network.AWS.CloudFront.SignedCookies.Crypto
  (
  -- * Reading the private key
    readPrivateKeyPemFile

  -- * Generating signatures
  , sign

  -- * Types
  , PrivateKey
  , ByteString

  ) where

import Network.AWS.CloudFront.SignedCookies.Crypto.Internal
import Network.AWS.CloudFront.SignedCookies.Types

-- asn1-encoding
import Data.ASN1.BinaryEncoding (DER (DER))
import Data.ASN1.Encoding (decodeASN1')
import Data.ASN1.Error (ASN1Error)

-- asn1-types
import Data.ASN1.Types (ASN1)

-- bytestring
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS

-- cryptonite
import Crypto.PubKey.RSA (PrivateKey)
import qualified Crypto.PubKey.RSA.PKCS15 as RSA
import Crypto.Hash.Algorithms (SHA1 (SHA1))

-- pem
import qualified Data.PEM as PEM

-- text
import qualified Data.Text as Text

-- | Construct the signature that will go into the
--   @CloudFront-Signature@ cookie.
--
--   The input is signed with RSASSA-PKCS1-v1_5 over a SHA-1 digest, via
--   'RSA.signSafer' with @'Just' 'SHA1'@. 'RSA.signSafer' draws random
--   numbers (its 'MonadRandom' constraint is satisfied by 'IO') for RSA
--   blinding, which is why this function runs in 'IO'.
--
--   If signing fails, the 'show'n RSA error is raised with 'fail'.
--   In 'IO' this throws a user 'IOError'.

sign
  :: PrivateKey     -- ^ The RSA private key that you read from the @.pem@ file
  -> ByteString     -- ^ The JSON representation of the 'Policy'
                    --   (see "Network.AWS.CloudFront.SignedCookies.Policy")
  -> IO ByteString  -- ^ The signature bytes exactly as returned by
                    --   'RSA.signSafer' (no encoding is applied here)

sign key bs =
  RSA.signSafer (Just SHA1) key bs >>= either (fail . show) pure

-- | Read an RSA private key from a @.pem@ file you downloaded from AWS.
--
--   The whole file is read strictly. It must contain exactly one PEM
--   section. That section's body is decoded as DER-encoded ASN.1 and then
--   converted to a key by 'rsaPrivateKeyFromASN1'.
--
--   Every failure is raised with 'fail', i.e. thrown as a user 'IOError'.
--   This covers unparseable PEM, zero or several PEM sections, invalid DER,
--   and ASN.1 that 'rsaPrivateKeyFromASN1' rejects.

readPrivateKeyPemFile
  :: PemFilePath    -- ^ The filesystem path of the @.pem@ file
  -> IO PrivateKey

readPrivateKeyPemFile (PemFilePath path) = do
  fileBytes <- BS.readFile (Text.unpack path)
  -- Only the file read is effectful; the decoding below is pure and stops
  -- at the first error, which is then surfaced in IO.
  either fail pure (decodePemKey fileBytes)

  where
    -- Decode the contents of a PEM file into an RSA private key.
    decodePemKey :: ByteString -> Either String PrivateKey
    decodePemKey fileBytes = do
      sections <- PEM.pemParseBS fileBytes
      section  <- exactlyOneSection sections
      asn1s    <- decodeDer (PEM.pemContent section)
      rsaPrivateKeyFromASN1 asn1s

    -- Insist on a single PEM section; anything else is ambiguous.
    exactlyOneSection :: [PEM.PEM] -> Either String PEM.PEM
    exactlyOneSection [x] = Right x
    exactlyOneSection xs  =
      Left ("Expected exactly 1 PEM section but found " ++ show (length xs))

    -- Decode DER bytes into ASN.1, rendering any decoder error as a String.
    decodeDer :: ByteString -> Either String [ASN1]
    decodeDer = either (Left . showAsn1Error) Right . decodeASN1' DER

    showAsn1Error :: ASN1Error -> String
    showAsn1Error = show