module Dropbox.Certificates 
  ( CertVerifierFunc
  , CertVerifier(..)
  , certVerifierInsecure
  , certVerifierFromPemFile
  , certVerifierFromRootCerts
  , certVerifierFromDbX509s
  ) where
import System.IO                       (withFile, IOMode(ReadMode))
import Data.ByteString                 (ByteString)
import qualified Data.ByteString       as B
import qualified Data.ByteString.Lazy  as LB
import qualified Data.ByteString.Char8 as B8
import           Data.PEM              (PEM(..))
import qualified Data.PEM              as PEM
import           Data.Certificate.X509 (X509)
import qualified Data.Certificate.X509 as X509
import           Data.Time.Clock       (UTCTime(utctDay), getCurrentTime)
import qualified Network.TLS           as TLS
import qualified Network.TLS.Extra     as TLSExtra
import Dropbox.Certificates.TH
-- | Root certificates compiled into this module by the 'x509File'
-- quasi-quoter from "Dropbox.Certificates.TH"; presumably it parses the
-- certificates in @trustedcerts.crt@ at compile time (confirm against the
-- quasi-quoter's definition, including how the path is resolved).
--
-- NOTE(review): quasi-quote syntax requires the QuasiQuotes extension, and
-- no LANGUAGE pragma is visible in this module — presumably it is enabled
-- in the cabal file; confirm.
dbX509s :: [X509]
dbX509s = [x509File|trustedcerts.crt|]
-- | A verifier that trusts only the root certificates compiled into this
-- module ('dbX509s'), checked by 'certVerifierFromRootCerts'.
certVerifierFromDbX509s :: CertVerifier
certVerifierFromDbX509s =
    CertVerifier
        { certVerifierName = "compiled in Dropbox certificates"
        , certVerifierFunc = certVerifierFromRootCerts dbX509s
        }
-- | A certificate check: given the domain being connected to and the
-- certificate chain presented by the peer, decide whether the chain is
-- acceptable.
type CertVerifierFunc =
    ByteString                     -- ^ domain name the chain must be valid for
    -> [X509]                      -- ^ certificate chain presented by the peer
    -> IO TLS.TLSCertificateUsage  -- ^ accept, or reject with a reason
-- | A certificate verification policy paired with a descriptive name.
-- Within this module the name is used only by the 'Show' instance, since
-- the function itself cannot be shown.
data CertVerifier = CertVerifier
    { certVerifierName :: String           -- ^ human-readable description of the policy
    , certVerifierFunc :: CertVerifierFunc -- ^ the check that is actually run
    }
-- | Shows only the verifier's name (the function cannot be shown).
--
-- Implemented via 'showsPrec' so that the output is parenthesised when it
-- appears as a constructor argument, e.g. @Just (CertVerifier "insecure")@
-- rather than the ambiguous @Just CertVerifier "insecure"@. At top level
-- the output is unchanged: @CertVerifier "insecure"@.
instance Show CertVerifier where
    showsPrec d (CertVerifier name _) =
        showParen (d > 10) $ showString "CertVerifier " . showsPrec 11 name
-- | A verifier that accepts every certificate chain for every domain
-- without examining it. This disables certificate verification entirely;
-- do not use it where the peer must be authenticated.
certVerifierInsecure :: CertVerifier
certVerifierInsecure =
    CertVerifier
        { certVerifierName = "insecure"
        , certVerifierFunc = \_domain _chain -> return TLS.CertificateUsageAccept
        }
-- | Collect all 'Right' values of a list in order, or return the first
-- 'Left' encountered.
--
-- This is exactly 'sequence' specialised to 'Either'. Elements after the
-- first 'Left' are never inspected, so the list may be undefined (or
-- infinite) past that point.
rightsOrFirstLeft :: [Either a b] -> Either a [b]
rightsOrFirstLeft = sequence
-- | Build a 'CertVerifier' whose trusted roots are the certificates stored
-- in a PEM file.
--
-- Every PEM section in the file is decoded as an X.509 certificate. The
-- result is 'Left' with an error message if the file is not valid PEM, or
-- if any section fails to decode (the first failure is reported).
-- Otherwise the result is a verifier named after the file path that checks
-- chains with 'certVerifierFromRootCerts'.
certVerifierFromPemFile :: FilePath -> IO (Either String CertVerifier)
certVerifierFromPemFile filePath = do
    contents <- withFile filePath ReadMode B.hGetContents
    -- Force the Either to WHNF here so that parsing and decoding are
    -- resolved inside this IO action rather than left as a thunk.
    return $! do
        sections <- PEM.pemParseBS contents
        roots    <- rightsOrFirstLeft (map decodeSection sections)
        Right $ CertVerifier
            { certVerifierName = "PEM file: " ++ show filePath
            , certVerifierFunc = certVerifierFromRootCerts roots
            }
    where
        decodeSection (PEM _ _ body) = X509.decodeCertificate (LB.fromChunks [body])
-- | Run verification steps left to right, stopping at the first step whose
-- result is not 'TLS.CertificateUsageAccept' and returning that result.
-- Later steps are not run. If every step accepts (or there are no steps),
-- the result is 'TLS.CertificateUsageAccept'.
certAll :: [IO TLS.TLSCertificateUsage] -> IO TLS.TLSCertificateUsage
certAll = foldr andThen (return TLS.CertificateUsageAccept)
    where
        andThen step remaining = do
            verdict <- step
            case verdict of
                TLS.CertificateUsageAccept -> remaining
                rejection                  -> return rejection
-- | Verify a peer's certificate chain against a fixed set of trusted roots.
--
-- Partially applied to the roots, this is a 'CertVerifierFunc'. Two checks
-- are run in order via 'certAll', and the first rejection is returned:
--
-- 1. 'TLSExtra.certificateVerifyDomain' must accept the chain for the
--    given domain.
--
-- 2. Starting at the head of the chain, each certificate visited must be
--    within its validity period (compared by calendar day, both bounds
--    inclusive). The walk accepts as soon as the current certificate is
--    signed by a root whose subject DN equals its issuer DN. Otherwise the
--    current certificate must verify against the next one in the chain,
--    and the walk continues from there. Running out of chain yields
--    'TLS.CertificateRejectUnknownCA'.
--
-- NOTE(review): intermediates are checked only for signature and validity
-- dates; nothing here inspects CA basic constraints or key usage. Confirm
-- whether that is covered elsewhere.
certVerifierFromRootCerts ::
    [X509]            -- ^ trusted root certificates
    -> ByteString     -- ^ domain name the chain must be valid for
    -> [X509]         -- ^ chain presented by the peer, walked from its head
                      --   (presumably the leaf certificate)
    -> IO TLS.TLSCertificateUsage
certVerifierFromRootCerts roots domain chain = do
        utcTime <- getCurrentTime
        let today = utctDay utcTime
        certAll
            [ return $ TLSExtra.certificateVerifyDomain (B8.unpack domain) chain
            , checkTrustChain today chain
            ]
    where
        -- Walk the chain from its head towards a trusted root.
        checkTrustChain _ [] =
            return $ TLS.CertificateUsageReject $ TLS.CertificateRejectOther "empty chain"
        checkTrustChain today (cert:rest)
            | not (isWithinValidity today cert) =
                return $ TLS.CertificateUsageReject TLS.CertificateRejectExpired
            | otherwise = do
                issuedByRoot <- fmap or $ mapM (cert `isIssuedBy`) roots
                if issuedByRoot
                    then return TLS.CertificateUsageAccept
                    else case rest of
                        [] -> return $ TLS.CertificateUsageReject TLS.CertificateRejectUnknownCA
                        (next:_) -> do
                            -- 'cert' must be signed by 'next' before the walk
                            -- may continue from 'next'.
                            nextOk <- TLSExtra.certificateVerifyAgainst cert next
                            if nextOk
                                then checkTrustChain today rest
                                else return $ TLS.CertificateUsageReject $ TLS.CertificateRejectOther "break in verification chain"

        -- A certificate counts as issued by 'issuer' only if the DNs line up
        -- AND the signature verifies; the DN comparison avoids attempting
        -- signature checks against unrelated roots.
        isIssuedBy :: X509 -> X509 -> IO Bool
        isIssuedBy c issuer =
            if subjectDN issuer == issuerDN c
                then TLSExtra.certificateVerifyAgainst c issuer
                else return False

        subjectDN c = X509.certSubjectDN $ X509.x509Cert c
        issuerDN c = X509.certIssuerDN $ X509.x509Cert c

        -- Validity is compared by calendar day only. Both bounds are
        -- inclusive, so a certificate is accepted on its notBefore day as
        -- well as on its notAfter day.
        isWithinValidity today cert =
            let ((notBeforeDay, _, _), (notAfterDay, _, _)) = X509.certValidity (X509.x509Cert cert)
            in notBeforeDay <= today && today <= notAfterDay