{-# LINE 1 "src/Hookup/OpenSSL.hsc" #-}
{-# Language CApiFFI #-}
{-|
Module      : Hookup.OpenSSL
Description : Hack into the internals of OpenSSL to add missing functionality
Copyright   : (c) Eric Mertens, 2016
License     : ISC
Maintainer  : emertens@gmail.com
-}






{-# LINE 17 "src/Hookup/OpenSSL.hsc" #-}

module Hookup.OpenSSL (withDefaultPassword, installVerification, getPubKeyDer) where

import           Control.Exception (bracket, bracket_)
import           Control.Monad (unless)
import           Foreign.C (CStringLen, CString(..), CSize(..), CUInt(..), CInt(..), withCStringLen, CChar(..))
import           Foreign.Ptr (FunPtr, Ptr, castPtr, nullPtr, nullFunPtr)
import           Foreign.StablePtr (StablePtr, deRefStablePtr, castPtrToStablePtr)
import           Foreign.Marshal (with)
import           OpenSSL.Session (SSLContext, SSLContext_, withContext)
import           OpenSSL.X509 (withX509Ptr, X509, X509_)
import           Data.ByteString (ByteString)
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString.Unsafe as Unsafe

------------------------------------------------------------------------
-- Bindings to password callback
------------------------------------------------------------------------

-- | Binding to the C helper @hookup_new_userdata@. It is called with a
-- password buffer and its length and returns an opaque pointer that is
-- later registered with OpenSSL as password-callback userdata.
--
-- NOTE(review): the helper is defined outside this module; presumably it
-- copies the password into freshly allocated memory -- confirm against the
-- C sources.
foreign import ccall unsafe "hookup_new_userdata"
  hookup_new_userdata :: CString -> CInt -> IO (Ptr ())

-- | Binding to the C helper @hookup_free_userdata@, used in this module to
-- release a pointer previously returned by 'hookup_new_userdata'.
foreign import ccall unsafe "hookup_free_userdata"
  hookup_free_userdata :: Ptr () -> IO ()

-- | Address of the C function @hookup_pem_passwd_cb@ (note the leading @&@:
-- this imports the function pointer, not a callable Haskell wrapper). It is
-- installed as the context's default password callback.
foreign import ccall "&hookup_pem_passwd_cb"
  hookup_pem_passwd_cb :: FunPtr PemPasswdCb

-- | Haskell view of OpenSSL's PEM password callback type. The arguments
-- follow the C prototype below: output buffer, buffer size, read/write
-- flag and the userdata pointer.
--
-- int pem_passwd_cb(char *buf, int size, int rwflag, void *userdata);
type PemPasswdCb = Ptr CChar -> CInt -> CInt -> Ptr () -> IO CInt

-- | Set the default password callback of a context. Passing 'nullFunPtr'
-- is how this module clears the callback again.
--
-- void SSL_CTX_set_default_passwd_cb(SSL_CTX *ctx, pem_password_cb *cb);
foreign import ccall unsafe "SSL_CTX_set_default_passwd_cb"
  sslCtxSetDefaultPasswdCb :: Ptr SSLContext_ -> FunPtr PemPasswdCb -> IO ()

-- | Set the userdata pointer that accompanies the default password
-- callback. The pointee type is left polymorphic because the C parameter
-- is a plain @void *@.
--
-- void SSL_CTX_set_default_passwd_cb_userdata(SSL_CTX *ctx, void *u);
foreign import ccall unsafe "SSL_CTX_set_default_passwd_cb_userdata"
  sslCtxSetDefaultPasswdCbUserdata ::
    Ptr SSLContext_ -> Ptr a -> IO ()

-- | Run an action while a default PEM password callback is installed on
-- the given context.
--
-- The optional password is passed to @hookup_new_userdata@; the resulting
-- pointer is registered, together with @hookup_pem_passwd_cb@, as the
-- context's default password callback and userdata. When the action
-- finishes (normally or by exception) the callback and userdata are reset
-- to null and the userdata is released with @hookup_free_userdata@.
withDefaultPassword :: SSLContext -> Maybe ByteString -> IO a -> IO a
withDefaultPassword ctx mbBs m =
  withCPassword mbBs $ \ptr len ->
  bracket (hookup_new_userdata ptr len) hookup_free_userdata $ \ud ->
  bracket_ (setup hookup_pem_passwd_cb ud) (setup nullFunPtr nullPtr) m

  where
  -- Expose the password as a buffer and a length. A missing password is
  -- passed on as a null pointer with length -1.
  withCPassword :: Maybe ByteString -> (CString -> CInt -> IO b) -> IO b
  withCPassword Nothing   k = k nullPtr (-1)
  withCPassword (Just bs) k =
    -- The buffer is borrowed, not copied; it is only valid while 'k' runs.
    Unsafe.unsafeUseAsCStringLen bs $ \(ptr, len) -> k ptr (fromIntegral len)

  -- Set both the callback and its userdata on the context. The signature
  -- keeps this polymorphic in the userdata pointee, since it is used both
  -- with the real userdata and with 'nullPtr'.
  setup :: FunPtr PemPasswdCb -> Ptr b -> IO ()
  setup cb ud =
    withContext ctx $ \ctxPtr ->
    do sslCtxSetDefaultPasswdCb         ctxPtr cb
       sslCtxSetDefaultPasswdCbUserdata ctxPtr ud

------------------------------------------------------------------------
-- Bindings to hostname verification interface
------------------------------------------------------------------------

-- | Uninhabited tag type for pointers to OpenSSL's @X509_VERIFY_PARAM@.
data X509_VERIFY_PARAM_
-- | Uninhabited tag type for pointers to OpenSSL's @X509_PUBKEY@; the CTYPE
-- pragma names the corresponding C type and the header that declares it.
data {-# CTYPE "openssl/ssl.h" "X509_PUBKEY" #-} X509_PUBKEY_
-- | Local tag type for pointers to OpenSSL's @X509@, carrying a CTYPE
-- pragma. Pointers to the imported 'X509_' are cast to this type before
-- being passed to the @capi@ import in this module.
data {-# CTYPE "openssl/ssl.h" "X509" #-} X509__

-- | Fetch the verification parameters of a context.
--
-- NOTE(review): OpenSSL's @get0@ naming convention suggests the returned
-- pointer is borrowed from the context and must not be freed by the
-- caller -- confirm against the OpenSSL documentation.
--
-- X509_VERIFY_PARAM *SSL_CTX_get0_param(SSL_CTX *ctx);
foreign import ccall unsafe "SSL_CTX_get0_param"
  sslGet0Param ::
    Ptr SSLContext_ {- ^ ctx -} ->
    IO (Ptr X509_VERIFY_PARAM_)

-- | Set the host-name checking flags on a set of verification parameters.
--
-- void X509_VERIFY_PARAM_set_hostflags(X509_VERIFY_PARAM *param, unsigned int flags);
foreign import ccall unsafe "X509_VERIFY_PARAM_set_hostflags"
  x509VerifyParamSetHostflags ::
    Ptr X509_VERIFY_PARAM_ {- ^ param -} ->
    CUInt                  {- ^ flags -} ->
    IO ()

-- | Set the expected host name on a set of verification parameters. The
-- name is passed as a pointer plus an explicit length.
--
-- int X509_VERIFY_PARAM_set1_host(X509_VERIFY_PARAM *param, const char *name, size_t namelen);
foreign import ccall unsafe "X509_VERIFY_PARAM_set1_host"
  x509VerifyParamSet1Host ::
    Ptr X509_VERIFY_PARAM_ {- ^ param                -} ->
    CString                {- ^ name                 -} ->
    CSize                  {- ^ namelen              -} ->
    IO CInt                {- ^ 1 success, 0 failure -}

-- | Get the @X509_PUBKEY@ structure of a certificate.
--
-- Unlike the other bindings in this module this one is imported with
-- @capi@ and an explicit header.
-- NOTE(review): presumably because @X509_get_X509_PUBKEY@ is a macro in
-- some OpenSSL versions -- confirm.
--
-- X509_PUBKEY *X509_get_X509_PUBKEY(X509 *x);
foreign import capi unsafe "openssl/x509.h X509_get_X509_PUBKEY"
  x509getX509Pubkey ::
    Ptr X509__ -> IO (Ptr X509_PUBKEY_)

-- | DER-encode an @X509_PUBKEY@. The second argument is a pointer to the
-- output cursor; this module also calls it with 'nullPtr' to obtain the
-- encoded length before allocating the buffer.
--
-- int i2d_X509_PUBKEY(X509_PUBKEY *p, unsigned char **ppout);
foreign import ccall unsafe "i2d_X509_PUBKEY"
  i2dX509Pubkey ::
    Ptr X509_PUBKEY_ ->
    Ptr CString ->
    IO CInt

-- | Return the DER encoding of a certificate's public key, i.e. the output
-- of @i2d_X509_PUBKEY@ applied to the certificate's @X509_PUBKEY@.
--
-- Fails with an 'IOError' if OpenSSL reports an encoding error.
getPubKeyDer :: X509 -> IO ByteString
getPubKeyDer x509 =
  withX509Ptr x509 $ \x509ptr ->
  do pubkey <- x509getX509Pubkey (castPtr x509ptr)

     -- First pass: a null output pointer only asks for the encoded length.
     len <- fromIntegral <$> i2dX509Pubkey pubkey nullPtr

     -- i2d functions signal errors with a negative result; do not hand
     -- that to the allocator as a buffer size.
     unless (len >= 0) (fail "Unable to encode public key")

     -- Second pass: encode into the freshly allocated buffer. i2d advances
     -- the cursor it is given, so it gets a scratch copy of the pointer.
     B.create len $ \bsPtr ->
        with (castPtr bsPtr) $ \ptrPtr ->
           () <$ i2dX509Pubkey pubkey ptrPtr


-- | Add hostname checking to the certificate verification step.
-- Partial wildcard matching is disabled.
--
-- Fails with an 'IOError' if OpenSSL refuses the host name.
installVerification :: SSLContext -> String {- ^ hostname -} -> IO ()
installVerification ctx host =
  withContext ctx     $ \ctxPtr ->
  withCStringLen host $ \(ptr,len) ->
    do param <- sslGet0Param ctxPtr
       -- The literal below is what hsc2hs substituted for the flag
       -- constant; 4 is OpenSSL's X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS.
       x509VerifyParamSetHostflags param
         (4)
{-# LINE 132 "src/Hookup/OpenSSL.hsc" #-}
       success <- x509VerifyParamSet1Host param ptr (fromIntegral len)
       unless (success == 1) (fail "Unable to set verification host")