-- |Identities for authenticating the two ends of a connection. A
-- 'PeerIdentity' holds only a public certificate and is used to
-- authenticate remote ends; a 'LocalIdentity' pairs a certificate
-- with its private key and is used to authenticate /to/ remote ends.
module Network.Secure.Identity
    (
      -- * Peer (public) identities
      PeerIdentity
    , readPeerIdentity
    , writePeerIdentity
      -- * Local (private) identities
    , LocalIdentity
    , readLocalIdentity
    , writeLocalIdentity
    , toPeerIdentity
    , newLocalIdentity
      -- * Accessors and conversions
    , piX509
    , liX509
    , liKey
    , fromX509
    ) where

import Control.Applicative ((<$>))
import Control.Exception (bracket)
import Control.Monad (when)
import Data.ByteString (ByteString, append, hPut)
import Data.ByteString.Char8 (pack, unpack)
import Data.Maybe (fromJust, isNothing)
import HSH
import OpenSSL.EVP.PKey (toKeyPair)
import OpenSSL.PEM (PemPasswordSupply(PwNone), readPrivateKey,
                    writePKCS8PrivateKey, readX509, writeX509)
import OpenSSL.RSA (RSAKeyPair)
import OpenSSL.Session (context, contextSetPrivateKey,
                        contextSetCertificate, contextCheckPrivateKey)
import OpenSSL.X509 (X509, compareX509)
import System.Directory (getTemporaryDirectory, removeFile)
import System.IO (openBinaryTempFile, hClose, hFlush)
import System.IO.Unsafe (unsafePerformIO)

-- |The public identity of a peer. This kind of identity can be used
-- to authenticate the remote ends of connections.
newtype PeerIdentity = PI {
      piX509 :: X509
      -- ^ The X.509 certificate that identifies the peer.
    }

-- |Serialize a 'PeerIdentity' for storage or transmission. The result
-- is the peer's certificate rendered as PEM text, packed into a
-- 'ByteString'.
writePeerIdentity :: PeerIdentity -> IO ByteString
writePeerIdentity (PI cert) = do
    pem <- writeX509 cert
    return (pack pem)

-- |Inverse of 'writePeerIdentity': decode a PEM certificate held in a
-- 'ByteString' and wrap it as a 'PeerIdentity'.
readPeerIdentity :: ByteString -> IO PeerIdentity
readPeerIdentity bytes = PI <$> readX509 (unpack bytes)

-- |Equality is defined via 'compare', so it always agrees with the
-- 'Ord' instance.
instance Eq PeerIdentity where
    x == y = case compare x y of
        EQ -> True
        _  -> False

-- |Orders peer identities by their certificates using OpenSSL's
-- 'compareX509'.
--
-- 'compareX509' is an 'IO' action; 'unsafePerformIO' treats it as a
-- pure function. That is only sound if the wrapped certificates are not
-- mutated after being placed in a 'PeerIdentity'.
-- NOTE(review): confirm no code mutates an X509 after wrapping it.
instance Ord PeerIdentity where
    compare (PI x) (PI y) = unsafePerformIO (compareX509 x y)

-- |Wrap a bare X.509 certificate as a 'PeerIdentity'.
fromX509 :: X509 -> PeerIdentity
fromX509 cert = PI cert

-- |A local identity. This kind of identity can be used to
-- authenticate /to/ remote ends of connections.
data LocalIdentity = LI {
      liX509 :: X509
      -- ^ The certificate whose public part is shared with peers
      -- (see 'toPeerIdentity').
    , liKey  :: RSAKeyPair
      -- ^ The RSA key pair; 'readLocalIdentity' checks that it
      -- matches 'liX509'.
    }

-- |Equality is defined via 'compare', so it always agrees with the
-- 'Ord' instance.
instance Eq LocalIdentity where
    x == y = case compare x y of
        EQ -> True
        _  -> False

-- |Lexicographic order: first by certificate (as for 'PeerIdentity'),
-- then by key pair when the certificates compare equal.
instance Ord LocalIdentity where
    compare (LI certA keyA) (LI certB keyB)
        | byCert /= EQ = byCert
        | otherwise    = compare keyA keyB
      where
        byCert = compare (PI certA) (PI certB)

-- |Serialize a 'LocalIdentity' to a 'ByteString' for storage.
--
-- The output is the PEM certificate followed by the PEM PKCS#8 private
-- key. No cipher is supplied, so the private key is written
-- unencrypted; treat the result as secret.
writeLocalIdentity :: LocalIdentity -> IO ByteString
writeLocalIdentity (LI cert key) =
    fmap pack ((++) <$> writeX509 cert <*> writePKCS8PrivateKey key Nothing)

-- |Read back a 'LocalIdentity' previously serialized with
-- 'writeLocalIdentity'.
--
-- The same PEM text is read twice: once for the X.509 certificate and
-- once for the private key. Fails in 'IO' (via 'fail') with
-- @"Bad private key"@ if the decoded key cannot be converted to an RSA
-- key pair, and with @"Cert and key don't match"@ if OpenSSL reports
-- that the key does not belong to the certificate.
readLocalIdentity :: ByteString -> IO LocalIdentity
readLocalIdentity b = do
    let s = unpack b
    cert <- readX509 s
    mkey <- toKeyPair <$> readPrivateKey s PwNone
    -- Pattern-match instead of isNothing/fromJust: the key is only
    -- ever used after we have proven it is present.
    case mkey of
        Nothing  -> fail "Bad private key"
        Just key -> do
            matches <- certMatchesKey cert key
            if matches
                then return (LI cert key)
                else fail "Cert and key don't match"

-- |Extract the public parts of a 'LocalIdentity' into a
-- 'PeerIdentity' suitable for sharing with peers. Only the
-- certificate is kept; the key pair is dropped. The resulting
-- 'PeerIdentity' lets peers verify you when you authenticate using
-- the corresponding 'LocalIdentity'.
toPeerIdentity :: LocalIdentity -> PeerIdentity
toPeerIdentity = PI . liX509

-- |Generate a new 'LocalIdentity', giving it an identifying name and
-- a validity period in days.
-- Note that this function may take quite a while to execute, as it is
-- generating key material for the identity.
--
-- This runs the external @openssl@ tool twice:
--
-- 1. It generates a 4096-bit RSA key and writes it to a temporary file.
--
-- 2. It creates a self-signed X.509 certificate from that file, with
--    subject @\/CN=@/commonName/ and validity of /days/ days.
--
-- The temporary file's handle is closed and the file removed
-- afterwards, even if an exception occurs.
newLocalIdentity :: String -> Int -> IO LocalIdentity
newLocalIdentity commonName days = bracket mkKeyFile rmKeyFile $ \(p,h) -> do
    key <- run genKey
    -- openssl reads the key back by path, so it must be flushed first.
    hPut h key >> hFlush h
    cert <- run $ genCert p
    readLocalIdentity $ append key cert
  where
    mkKeyFile = getTemporaryDirectory >>= flip openBinaryTempFile "key.pem"
    -- Close the handle before deleting the file. Otherwise the handle
    -- leaks, and some platforms refuse to remove a file that is open.
    rmKeyFile (p, h) = hClose h >> removeFile p
    genKey = "openssl genrsa 4096 2>/dev/null"
    genCert p = ("openssl", ["req", "-batch", "-new", "-x509",
                             "-key", p, "-nodes",
                             "-subj", "/CN=" ++ commonName,
                             "-days", show days])

-- |Check whether an RSA key pair belongs to a certificate. Both are
-- loaded into a fresh OpenSSL context, and OpenSSL is asked to check
-- the private key against the certificate.
certMatchesKey :: X509 -> RSAKeyPair -> IO Bool
certMatchesKey cert key = context >>= \ctx ->
    contextSetPrivateKey ctx key
        >> contextSetCertificate ctx cert
        >> contextCheckPrivateKey ctx