{-# LANGUAGE CPP                      #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module OpenSSL.EVP.Base64
    ( 
      encodeBase64
    , encodeBase64BS
    , encodeBase64LBS
      
    , decodeBase64
    , decodeBase64BS
    , decodeBase64LBS
    )
    where
import           Control.Exception (assert)
import           Data.ByteString.Internal (createAndTrim)
import           Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import qualified Data.ByteString.Lazy.Internal as L8Internal
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy.Char8 as L8
import           Data.List
#if MIN_VERSION_base(4,5,0)
import           Foreign.C.Types (CChar(..), CInt(..))
#else
import           Foreign.C.Types (CChar, CInt)
#endif
import           Foreign.Ptr (Ptr, castPtr)
import           System.IO.Unsafe (unsafePerformIO)
-- | Pull strict chunks off the front of a lazy 'L8.ByteString' until the
-- chunks collected so far hold at least @minLen@ bytes in total, or the
-- lazy input is exhausted.
--
-- The pair carries the chunks gathered so far (in input order) and the
-- not-yet-consumed remainder; the result has the same shape.
nextBlock :: Int -> ([B8.ByteString], L8.ByteString) -> ([B8.ByteString], L8.ByteString)
nextBlock minLen acc@(pieces, rest)
    | gathered >= minLen = acc
    | otherwise
        = case rest of
            -- Nothing left to take: return what we have, even if short.
            L8Internal.Empty          -> acc
            L8Internal.Chunk c others -> nextBlock minLen (pieces ++ [c], others)
    where
      -- Total number of bytes across all chunks collected so far.
      gathered = foldl' (\ total piece -> total + B8.length piece) 0 pieces
-- | Raw FFI binding to the C symbol @EVP_EncodeBlock@.
--
-- As used by 'encodeBlock' in this module, the first argument is the
-- destination buffer, the second is the source buffer, the third is the
-- number of source bytes, and the returned 'CInt' is taken as the number
-- of bytes written to the destination.
foreign import ccall unsafe "EVP_EncodeBlock"
        _EncodeBlock :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt
-- | Base64-encode a strict 'B8.ByteString' in one call to '_EncodeBlock'.
--
-- The output buffer is allocated with room for four bytes per started
-- group of three input bytes plus one extra byte, and is then trimmed to
-- the length reported by the C routine.
--
-- 'unsafePerformIO' is used because the action only reads @inBS@ and
-- writes into a freshly allocated buffer.
encodeBlock :: B8.ByteString -> B8.ByteString
encodeBlock inBS
    = unsafePerformIO $
      unsafeUseAsCStringLen inBS $ \ (srcPtr, srcLen) ->
      createAndTrim outCapacity $ \ dstPtr -> do
        written <- _EncodeBlock (castPtr dstPtr) srcPtr (fromIntegral srcLen)
        return (fromIntegral written)
    where
      -- Upper bound on the encoded size, including one spare byte.
      outCapacity = (B8.length inBS `div` 3 + 1) * 4 + 1
{-# DEPRECATED encodeBase64 "Use encodeBase64BS or encodeBase64LBS instead." #-}
-- | Base64-encode a 'String' by packing it into a lazy Char8
-- 'L8.ByteString', running 'encodeBase64LBS', and unpacking the result.
--
-- NOTE(review): the round trip goes through Char8 packing, so characters
-- that do not fit in a single byte are presumably not preserved; confirm
-- against the bytestring documentation.
encodeBase64 :: String -> String
encodeBase64 str = L8.unpack (encodeBase64LBS (L8.pack str))
-- | Base64-encode a strict 'B8.ByteString'. The whole input is encoded
-- in a single step by 'encodeBlock'.
encodeBase64BS :: B8.ByteString -> B8.ByteString
encodeBase64BS strictInput = encodeBlock strictInput
-- | Base64-encode a lazy 'L8.ByteString' incrementally.
--
-- Each step gathers at least three bytes from the input (fewer only when
-- the input runs out), encodes the largest prefix whose length is a
-- multiple of three, and pushes any spare bytes back in front of the
-- unread input. Only the final step may encode fewer than three bytes.
encodeBase64LBS :: L8.ByteString -> L8.ByteString
encodeBase64LBS inLBS
    | L8.null inLBS = L8.empty
    | otherwise     = L8.fromChunks [encodeBlock whole] `L8.append` encodeBase64LBS pending
    where
      (gathered, unread) = nextBlock 3 ([], inLBS)
      joined             = B8.concat gathered
      joinedLen          = B8.length joined
      -- A short block can only occur at the very end of the input, so it
      -- is encoded as-is; otherwise keep a multiple of three bytes.
      (whole, spare)
          | joinedLen < 3 = (joined, B8.empty)
          | otherwise     = B8.splitAt (joinedLen - joinedLen `mod` 3) joined
      -- Input still to be encoded: spare bytes first, then the rest.
      pending
          | B8.null spare = unread
          | otherwise     = L8.fromChunks [spare] `L8.append` unread
-- | Raw FFI binding to the C symbol @EVP_DecodeBlock@.
--
-- As used by 'decodeBlock' in this module, the first argument is the
-- destination buffer, the second is the source buffer, the third is the
-- number of source bytes, and the returned 'CInt' is taken as a byte
-- count for the decoded output.
foreign import ccall unsafe "EVP_DecodeBlock"
        _DecodeBlock :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt
-- | Decode a strict base64 'B8.ByteString' in one call to '_DecodeBlock'.
--
-- The input length is expected to be a multiple of four; this is checked
-- with 'assert', which is only active when assertions are enabled.
--
-- The output buffer is as large as the input and is trimmed to the
-- length reported by the C routine minus the number of @\'=\'@ characters
-- in the input. According to the OpenSSL manual the reported length
-- includes the zero bytes that padding decodes to, and a negative value
-- signals an error.
--
-- Throws an 'IOError' (via 'fail') when the C routine reports failure,
-- instead of handing a negative length to 'createAndTrim'.
decodeBlock :: B8.ByteString -> B8.ByteString
decodeBlock inBS
    = assert (B8.length inBS `mod` 4 == 0) $
      unsafePerformIO $
      unsafeUseAsCStringLen inBS $ \ (inBuf, inLen) ->
      createAndTrim (B8.length inBS) $ \ outBuf ->
      _DecodeBlock (castPtr outBuf) inBuf (fromIntegral inLen)
           >>= \ outLen ->
               if outLen < 0 then
                   fail "OpenSSL.EVP.Base64.decodeBlock: EVP_DecodeBlock failed (invalid base64 input)"
               else
                   -- Clamp so that an input with more '=' characters than
                   -- decoded bytes cannot produce a negative length.
                   return (max 0 (fromIntegral outLen - paddingLen))
    where
      -- Number of '=' characters anywhere in the input.
      paddingLen :: Int
      paddingLen = B8.count '=' inBS
{-# DEPRECATED decodeBase64 "Use decodeBase64BS or decodeBase64LBS instead." #-}
-- | Decode a base64 'String' by packing it into a lazy Char8
-- 'L8.ByteString', running 'decodeBase64LBS', and unpacking the result.
decodeBase64 :: String -> String
decodeBase64 str = L8.unpack (decodeBase64LBS (L8.pack str))
-- | Decode a strict base64 'B8.ByteString'. The whole input is decoded
-- in a single step by 'decodeBlock', which expects the input length to
-- be a multiple of four.
decodeBase64BS :: B8.ByteString -> B8.ByteString
decodeBase64BS strictInput = decodeBlock strictInput
-- | Decode a lazy base64 'L8.ByteString' incrementally.
--
-- Each step gathers at least four bytes from the input, hands the
-- largest prefix whose length is a multiple of four to 'decodeBlock',
-- and pushes any spare bytes back in front of the unread input.
--
-- If the input ends with fewer than four bytes left over (that is, the
-- total length is not a multiple of four), an 'error' is raised when
-- that part of the result is demanded. Without this explicit check, a
-- build with assertions disabled would split off an empty block and
-- recurse on the same leftover forever.
decodeBase64LBS :: L8.ByteString -> L8.ByteString
decodeBase64LBS inLBS
    | L8.null inLBS = L8.empty
    | blockLen' < 4
        -- 'nextBlock' only returns a short block once the input is
        -- exhausted, so no further bytes can complete this group.
        = error "OpenSSL.EVP.Base64.decodeBase64LBS: input length is not a multiple of 4"
    | otherwise
        = L8.fromChunks [decodeBlock block] `L8.append` decodeBase64LBS remain
    where
      (blockParts', remain') = nextBlock 4 ([], inLBS)
      block'                 = B8.concat blockParts'
      blockLen'              = B8.length block'
      -- Keep a multiple of four bytes for this step.
      (block, leftover)      = B8.splitAt (blockLen' - blockLen' `mod` 4) block'
      -- Input still to be decoded: leftover bytes first, then the rest.
      remain
          | B8.null leftover = remain'
          | otherwise        = L8.fromChunks [leftover] `L8.append` remain'