module OpenSSL.EVP.Cipher
    ( Cipher
    , EVP_CIPHER 
    , withCipherPtr 
    , getCipherByName
    , getCipherNames
    , cipherIvLength 
    , CipherCtx 
    , EVP_CIPHER_CTX 
    , newCtx 
    , withCipherCtxPtr 
    , CryptoMode(..)
    , cipherStrictly 
    , cipherLazily 
    , cipher
    , cipherBS
    , cipherLBS
    , cipherStrictLBS
    )
    where
import           Control.Exception(bracket_)
import           Control.Monad
import           Data.ByteString.Internal (createAndTrim)
import           Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy.Char8 as L8
import qualified Data.ByteString.Lazy.Internal as L8Internal
import           Foreign
import           Foreign.C
import           OpenSSL.Objects
import           OpenSSL.Utils
import           System.IO.Unsafe
-- | An OpenSSL cipher algorithm, wrapping a raw @EVP_CIPHER@ pointer.
-- This is a plain 'Ptr' (not a 'ForeignPtr'), so no finalizer is attached.
newtype Cipher     = Cipher (Ptr EVP_CIPHER)
-- | Opaque tag type for OpenSSL's @EVP_CIPHER@ struct; only used behind 'Ptr'.
data    EVP_CIPHER
-- Look up a cipher by name. 'getCipherByName' treats a null result as
-- "no such cipher".
foreign import ccall unsafe "EVP_get_cipherbyname"
        _get_cipherbyname :: CString -> IO (Ptr EVP_CIPHER)
-- IV length of a cipher, via the C helper HsOpenSSL_EVP_CIPHER_iv_length
-- (presumably a wrapper around the EVP_CIPHER_iv_length macro — confirm).
-- Imported as a pure function.
foreign import ccall unsafe "HsOpenSSL_EVP_CIPHER_iv_length"
        _iv_length :: Ptr EVP_CIPHER -> CInt
-- | Run an action with the raw @EVP_CIPHER@ pointer held by a 'Cipher'.
withCipherPtr :: Cipher -> (Ptr EVP_CIPHER -> IO a) -> IO a
withCipherPtr c action = case c of
  Cipher rawPtr -> action rawPtr
-- | Look up a cipher by its OpenSSL name.
-- Returns 'Nothing' when OpenSSL hands back a null pointer.
getCipherByName :: String -> IO (Maybe Cipher)
getCipherByName name = withCString name lookupName
  where
    lookupName cName = fmap wrap (_get_cipherbyname cName)
    wrap p
      | p == nullPtr = Nothing
      | otherwise    = Just (Cipher p)
-- | List the names of the cipher methods known to OpenSSL, as reported by
-- 'getObjNames' for 'CipherMethodType' (with its flag argument set to 'True').
getCipherNames :: IO [String]
getCipherNames = CipherMethodType `getObjNames` True
-- | IV length of a cipher, converted from the C 'CInt' to 'Int'.
cipherIvLength :: Cipher -> Int
cipherIvLength = \ (Cipher p) -> fromIntegral (_iv_length p)
-- | A cipher context wrapping a 'ForeignPtr' to an @EVP_CIPHER_CTX@.
-- 'newCtx' allocates one with a cleanup finalizer attached, while
-- 'cipherStrictLBS' internally wraps a scope-allocated context
-- (via 'newForeignPtr_', i.e. without a finalizer).
newtype CipherCtx      = CipherCtx (ForeignPtr EVP_CIPHER_CTX)
-- | Opaque tag type for OpenSSL's @EVP_CIPHER_CTX@ struct.
data    EVP_CIPHER_CTX
foreign import ccall unsafe "EVP_CIPHER_CTX_init"
        _ctx_init :: Ptr EVP_CIPHER_CTX -> IO ()
-- Function-pointer form of EVP_CIPHER_CTX_cleanup, attached as a ForeignPtr
-- finalizer by 'newCtx'.
foreign import ccall unsafe "&EVP_CIPHER_CTX_cleanup"
        _ctx_cleanup :: FunPtr (Ptr EVP_CIPHER_CTX -> IO ())
-- Direct-call form of EVP_CIPHER_CTX_cleanup, used with 'bracket_' in
-- 'cipherStrictLBS'.
foreign import ccall unsafe "EVP_CIPHER_CTX_cleanup"
        _ctx_cleanup' :: Ptr EVP_CIPHER_CTX -> IO ()
-- Block size of the context's cipher, via a C helper.
-- NOTE(review): imported as a pure function although it reads mutable
-- context state; presumably only meaningful after EVP_CipherInit — confirm.
foreign import ccall unsafe "HsOpenSSL_EVP_CIPHER_CTX_block_size"
        _ctx_block_size :: Ptr EVP_CIPHER_CTX -> CInt
-- | Allocate and initialise a fresh cipher context.
--
-- The memory comes from 'mallocForeignPtrBytes', and @EVP_CIPHER_CTX_cleanup@
-- is attached as a finalizer. NOTE(review): 140 is presumably
-- @sizeof(EVP_CIPHER_CTX)@ baked in by hsc2hs on the build machine — confirm.
newCtx :: IO CipherCtx
newCtx =
  mallocForeignPtrBytes 140 >>= \ fp ->
  withForeignPtr fp _ctx_init >>
  addForeignPtrFinalizer _ctx_cleanup fp >>
  return (CipherCtx fp)
-- | Run an action with the raw @EVP_CIPHER_CTX@ pointer, keeping the
-- underlying 'ForeignPtr' alive for the duration of the action.
withCipherCtxPtr :: CipherCtx -> (Ptr EVP_CIPHER_CTX -> IO a) -> IO a
withCipherCtxPtr (CipherCtx fp) act = withForeignPtr fp act
-- | Direction of a cipher operation.
data CryptoMode = Encrypt | Decrypt
-- Raw EVP entry points. Callers in this module treat any result other
-- than 1 as failure (via failIf / failIf_).
foreign import ccall unsafe "EVP_CipherInit"
        _CipherInit :: Ptr EVP_CIPHER_CTX -> Ptr EVP_CIPHER -> CString -> CString -> CInt -> IO CInt
-- Arguments: context, output buffer, output-length out-parameter,
-- input buffer, input length.
foreign import ccall unsafe "EVP_CipherUpdate"
        _CipherUpdate :: Ptr EVP_CIPHER_CTX -> Ptr CChar -> Ptr CInt -> Ptr CChar -> CInt -> IO CInt
-- Arguments: context, output buffer, output-length out-parameter.
foreign import ccall unsafe "EVP_CipherFinal"
        _CipherFinal :: Ptr EVP_CIPHER_CTX -> Ptr CChar -> Ptr CInt -> IO CInt
-- | The @enc@ flag passed to @EVP_CipherInit@: 1 to encrypt, 0 to decrypt.
cryptoModeToInt :: CryptoMode -> CInt
cryptoModeToInt mode = case mode of
  Encrypt -> 1
  Decrypt -> 0
-- | Allocate a fresh 'CipherCtx' and initialise it for the given cipher,
-- key, IV and direction.
--
-- The key and IV are converted to bytes one 'Char' per byte (as by
-- 'B8.pack'). This matches how 'cipher' converts its data with 'L8.pack'.
-- Characters above code point 255 are truncated to their low 8 bits.
--
-- Throws (via 'failIf_') if @EVP_CipherInit@ does not return 1.
--
-- NOTE(review): nothing here checks that @key@ and @iv@ are at least as long
-- as the cipher requires; confirm callers guarantee this.
cipherInit :: Cipher -> String -> String -> CryptoMode -> IO CipherCtx
cipherInit (Cipher c) key iv mode
    = do ctx <- newCtx
         withCipherCtxPtr ctx $ \ ctxPtr ->
             -- Not withCString: that encodes via the current locale (e.g.
             -- UTF-8), turning each non-ASCII Char into several bytes.
             B8.useAsCString (B8.pack key) $ \ keyPtr ->
                 B8.useAsCString (B8.pack iv) $ \ ivPtr ->
                     _CipherInit ctxPtr c keyPtr ivPtr (cryptoModeToInt mode)
                          >>= failIf_ (/= 1)
         return ctx
-- | Encrypt or decrypt a lazy ByteString eagerly, using a context that lives
-- only for the duration of the call.
--
-- The context is initialised and cleaned up around the work with 'bracket_'.
-- Every input chunk is processed before returning, followed by the final
-- block. The key and IV lengths are not checked here.
-- NOTE(review): 140 is presumably @sizeof(EVP_CIPHER_CTX)@ from hsc2hs — confirm.
cipherStrictLBS :: Cipher         -- ^ algorithm
                -> B8.ByteString  -- ^ key
                -> B8.ByteString  -- ^ IV
                -> CryptoMode     -- ^ encrypt or decrypt
                -> L8.ByteString  -- ^ input
                -> IO L8.ByteString
cipherStrictLBS (Cipher c) key iv mode input =
    allocaBytes 140 $ \ ctxPtr ->
      bracket_ (_ctx_init ctxPtr) (_ctx_cleanup' ctxPtr) $
        unsafeUseAsCStringLen key $ \ (keyBuf, _) ->
          unsafeUseAsCStringLen iv $ \ (ivBuf, _) -> do
            rc <- _CipherInit ctxPtr c keyBuf ivBuf (cryptoModeToInt mode)
            failIf_ (/= 1) rc
            -- Borrowed pointer: no finalizer, since bracket_ handles cleanup.
            ctx <- CipherCtx `fmap` newForeignPtr_ ctxPtr
            bodies <- mapM (cipherUpdateBS ctx) (L8.toChunks input)
            trailer <- cipherFinalBS ctx
            return (L8.fromChunks (bodies ++ [trailer]))
-- | Feed one strict chunk through an initialised context and return whatever
-- output OpenSSL produces for it, which may be empty.
--
-- The output buffer is sized to @inLen + blockSize@ and then trimmed by
-- 'createAndTrim' to the length OpenSSL reports.
-- Throws (via 'failIf') if @EVP_CipherUpdate@ does not return 1.
cipherUpdateBS :: CipherCtx -> B8.ByteString -> IO B8.ByteString
cipherUpdateBS ctx inBS
    = withCipherCtxPtr ctx $ \ ctxPtr ->
      unsafeUseAsCStringLen inBS $ \ (inBuf, inLen) ->
      let blockSize = fromIntegral (_ctx_block_size ctxPtr)
          -- EVP_EncryptUpdate may write up to inLen + blockSize - 1 bytes,
          -- but EVP_DecryptUpdate (with padding) may write up to
          -- inLen + blockSize. Size for the larger so neither overflows.
          outCap    = inLen + blockSize
      in  createAndTrim outCap $ \ outBuf ->
          alloca $ \ outLenPtr ->
          _CipherUpdate ctxPtr (castPtr outBuf) outLenPtr inBuf (fromIntegral inLen)
              >>= failIf (/= 1)
              >>  liftM fromIntegral (peek outLenPtr)
-- | Finish the operation on a context and return any remaining output.
--
-- The output buffer is sized to the context's block size and trimmed to the
-- length OpenSSL reports.
-- Throws (via 'failIf') if @EVP_CipherFinal@ does not return 1.
cipherFinalBS :: CipherCtx -> IO B8.ByteString
cipherFinalBS ctx = withCipherCtxPtr ctx finish
  where
    finish ctxPtr =
      createAndTrim (fromIntegral (_ctx_block_size ctxPtr)) $ \ outBuf ->
        alloca $ \ outLenPtr -> do
          rc <- _CipherFinal ctxPtr (castPtr outBuf) outLenPtr
          _ <- failIf (/= 1) rc
          written <- peek outLenPtr
          return (fromIntegral written)
-- | Encrypt or decrypt a 'String', treating each 'Char' as one byte.
--
-- The input is packed with 'L8.pack' and the output unpacked with
-- 'L8.unpack'. The work is delegated to 'cipherLBS', so the output is
-- produced lazily.
cipher :: Cipher     -- ^ algorithm
       -> String     -- ^ key
       -> String     -- ^ IV
       -> CryptoMode -- ^ encrypt or decrypt
       -> String     -- ^ input
       -> IO String  -- ^ output
cipher c key iv mode input
    = fmap L8.unpack (cipherLBS c key iv mode (L8.pack input))
-- | Encrypt or decrypt a strict ByteString with a freshly initialised
-- context. The whole output, including the final block, is produced
-- before returning.
cipherBS :: Cipher           -- ^ algorithm
         -> String           -- ^ key
         -> String           -- ^ IV
         -> CryptoMode       -- ^ encrypt or decrypt
         -> B8.ByteString    -- ^ input
         -> IO B8.ByteString -- ^ output
cipherBS c key iv mode input
    = cipherInit c key iv mode >>= \ ctx -> cipherStrictly ctx input
-- | Encrypt or decrypt a lazy ByteString with a freshly initialised context.
-- Output is produced lazily by 'cipherLazily'.
cipherLBS :: Cipher           -- ^ algorithm
          -> String           -- ^ key
          -> String           -- ^ IV
          -> CryptoMode       -- ^ encrypt or decrypt
          -> L8.ByteString    -- ^ input
          -> IO L8.ByteString -- ^ output
cipherLBS c key iv mode input
    = cipherInit c key iv mode >>= flip cipherLazily input
-- | Run a strict ByteString through an initialised context: one update call
-- followed by the final call, with the two outputs concatenated in order.
cipherStrictly :: CipherCtx -> B8.ByteString -> IO B8.ByteString
cipherStrictly ctx input =
    liftM2 B8.append (cipherUpdateBS ctx input) (cipherFinalBS ctx)
-- | Run a lazy ByteString through an initialised context, producing the
-- output lazily.
--
-- Each input chunk is processed only when the matching part of the output is
-- demanded (via 'unsafeInterleaveIO'). 'cipherFinalBS' runs once the input is
-- exhausted. Because later chunks are processed on demand against the same
-- @ctx@, using @ctx@ elsewhere before the output is fully consumed would
-- interleave operations on it.
--
-- 'cipherUpdateBS' may return an empty piece, since OpenSSL can report zero
-- output bytes. Such pieces are dropped with 'L8Internal.chunk' so the result
-- keeps the lazy ByteString invariant that every chunk is non-empty.
cipherLazily :: CipherCtx -> L8.ByteString -> IO L8.ByteString
cipherLazily ctx L8Internal.Empty = do
  lastPiece <- cipherFinalBS ctx
  -- fromChunks already drops an empty final piece.
  return (L8.fromChunks [lastPiece])
cipherLazily ctx (L8Internal.Chunk x xs) = do
  y  <- cipherUpdateBS ctx x
  ys <- unsafeInterleaveIO (cipherLazily ctx xs)
  -- 'chunk' (unlike the raw 'Chunk' constructor) skips y when it is empty.
  return (L8Internal.chunk y ys)