{-# LINE 1 "OpenSSL/DER.hsc" #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE CApiFFI #-}
-- | Encoding and decoding of RSA keys using the ASN.1 DER format.
module OpenSSL.DER
    ( toDERPub
    , fromDERPub
    , toDERPriv
    , fromDERPriv
    )
    where


{-# LINE 15 "OpenSSL/DER.hsc" #-}
import           OpenSSL.RSA                (RSA, RSAKey, RSAKeyPair, RSAPubKey,
                                             absorbRSAPtr, withRSAPtr)

import           Data.ByteString            (ByteString)
import qualified Data.ByteString            as B  (useAsCStringLen)
import qualified Data.ByteString.Internal   as BI (createAndTrim)
import           Foreign.Ptr                (Ptr, nullPtr, castPtr)
import           Foreign.C.Types            (CLong(..), CInt(..))
import           Foreign.Marshal.Alloc      (alloca)
import           Foreign.Storable           (poke)
import           GHC.Word                   (Word8)
import           System.IO.Unsafe           (unsafePerformIO)

-- | Shape shared by the OpenSSL @d2i_RSA*@ decoders imported below: an
-- optional destination (this module always passes null), a pointer to the
-- input-buffer pointer, and the input length.  The result is the decoded
-- key, or null when decoding fails.
type CDecodeFun
  =  Ptr (Ptr RSA)
  -> Ptr (Ptr Word8)
  -> CLong
  -> IO (Ptr RSA)

-- | Shape shared by the OpenSSL @i2d_RSA*@ encoders imported below: the key
-- and a pointer to the output-buffer pointer (null to only query the needed
-- size).  The result is the encoded length (negative on failure, per the
-- OpenSSL documentation).
type CEncodeFun
  =  Ptr RSA
  -> Ptr (Ptr Word8)
  -> IO CInt

-- Decoders: DER bytes -> RSA*.  Declared 'unsafe' because they are short
-- calls that never call back into Haskell.
foreign import capi unsafe "HsOpenSSL.h d2i_RSAPublicKey"
  _fromDERPub :: CDecodeFun

foreign import capi unsafe "HsOpenSSL.h d2i_RSAPrivateKey"
  _fromDERPriv :: CDecodeFun

-- Encoders: RSA* -> DER bytes.
foreign import capi unsafe "HsOpenSSL.h i2d_RSAPublicKey"
  _toDERPub :: CEncodeFun

foreign import capi unsafe "HsOpenSSL.h i2d_RSAPrivateKey"
  _toDERPriv :: CEncodeFun

-- | Generate a function that decodes a key from ASN.1 DER format.
--
-- The resulting function returns 'Nothing' when the decoder rejects the
-- input (returns a null pointer).  It also returns 'Nothing' when the input
-- is longer than @'maxBound' :: 'CLong'@.  The d2i_* length parameter is a C
-- @long@, which is only 32 bits wide on some 64-bit platforms (e.g.
-- Windows), so converting a larger 'Int' with 'fromIntegral' would silently
-- truncate it.
makeDecodeFun :: RSAKey k => CDecodeFun -> ByteString -> Maybe k
makeDecodeFun fun bs = unsafePerformIO . B.useAsCStringLen bs $ \(cs, len) ->
  if toInteger len > toInteger (maxBound :: CLong)
    then return Nothing
    else
      -- The d2i_* functions read through a pointer-to-pointer (and advance
      -- it past the parsed data), so the buffer address is first stored in
      -- a temporary cell.
      alloca $ \csPtr -> do
        poke csPtr cs
        -- When you pass a null pointer as the first argument, the decoder
        -- allocates the RSA structure by itself.  It will be freed whenever
        -- the Haskell object is garbage collected, as they are stored in
        -- ForeignPtrs internally.
        --
        -- The d2i_* functions take a @const unsigned char **@, whereas a
        -- CString is a @char *@.  CDecodeFun is therefore declared over
        -- Ptr (Ptr Word8) so the C compiler does not complain about the
        -- pointer conversion, and the cast is done here instead.
        rsaPtr <- fun nullPtr (castPtr csPtr) (fromIntegral len)
        if rsaPtr == nullPtr
          then return Nothing
          else absorbRSAPtr rsaPtr

-- | Generate a function that encodes a key in ASN.1 DER format.
--
-- The encoder is invoked twice:
--
-- * first with a null output pointer, to learn the required buffer size;
-- * then with a buffer of that size.
--
-- A negative result from either call signals an encoding failure.  It is
-- reported as an 'IOError' (raised when the resulting 'ByteString' is
-- forced) rather than being used as a buffer length.
makeEncodeFun :: RSAKey k => CEncodeFun -> k -> ByteString
makeEncodeFun fun k = unsafePerformIO $ do
  -- When you pass a null pointer to this function, it will only compute the
  -- required buffer size.  See https://www.openssl.org/docs/faq.html#PROG3
  requiredSize <- checked =<< withRSAPtr k (\key -> fun key nullPtr)
  -- It's too sad BI.createAndTrim is considered internal, as it does a great
  -- job here: it allocates requiredSize bytes and trims the result to the
  -- length actually written.
  BI.createAndTrim requiredSize $ \buf ->
    alloca $ \pbuf ->
      withRSAPtr k $ \key -> do
        -- i2d_* writes at *pbuf and advances it.  Only the returned length
        -- is needed, so the advanced pointer is simply discarded.
        poke pbuf buf
        checked =<< fun key pbuf
  where
    -- Convert an i2d_* result to a length, failing loudly on error codes.
    checked :: CInt -> IO Int
    checked n
      | n < 0     = ioError (userError "OpenSSL.DER: DER encoding of RSA key failed")
      | otherwise = return (fromIntegral n)

-- | Serialise the public part of an RSA key to ASN.1 DER, using
-- @i2d_RSAPublicKey@.
toDERPub :: RSAKey k
         => k          -- ^ Either an 'RSAPubKey' or an 'RSAKeyPair'; both
                       --   carry the public components needed here.
         -> ByteString -- ^ The DER-encoded public key
toDERPub key = makeEncodeFun _toDERPub key

-- | Parse a DER-encoded RSA public key, using @d2i_RSAPublicKey@.
--
-- Yields 'Nothing' when the bytes cannot be decoded.
fromDERPub :: ByteString -> Maybe RSAPubKey
fromDERPub der = makeDecodeFun _fromDERPub der

-- | Serialise an RSA key pair (including its private components) to
-- ASN.1 DER, using @i2d_RSAPrivateKey@.
toDERPriv :: RSAKeyPair -> ByteString
toDERPriv pair = makeEncodeFun _toDERPriv pair

-- | Parse a DER-encoded RSA private key, using @d2i_RSAPrivateKey@.
-- Yields 'Nothing' when the bytes cannot be decoded.
fromDERPriv :: RSAKey k
            => ByteString -- ^ The DER-encoded private key
            -> Maybe k    -- ^ May be requested as either 'RSAPubKey' or
                          --   'RSAKeyPair', since a private key holds enough
                          --   information for both.
fromDERPriv der = makeDecodeFun _fromDERPriv der