module OpenSSL.EVP.Base64
    ( 
      encodeBase64
    , encodeBase64BS
    , encodeBase64LBS
      
    , decodeBase64
    , decodeBase64BS
    , decodeBase64LBS
    )
    where
import           Control.Exception hiding (block)
import qualified Data.ByteString.Char8 as B8
import           Data.ByteString.Internal (createAndTrim)
import qualified Data.ByteString.Lazy.Char8 as L8
import qualified Data.ByteString.Lazy.Internal as L8Internal
import           Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import           Data.List
import           Foreign
import           Foreign.C
import           System.IO.Unsafe (unsafePerformIO)
-- | Gather whole strict chunks from the front of a lazy 'L8.ByteString'.
--
-- Starting from the chunks already in the accumulator, chunks are taken from
-- the lazy input (appended after the existing ones, preserving input order)
-- until the gathered chunks hold at least @minLen@ bytes in total or the
-- lazy input runs out.  Chunks are never split, so the gathered total may
-- exceed @minLen@.
--
-- Returns the gathered chunks and the untouched rest of the lazy input.
nextBlock :: Int -> ([B8.ByteString], L8.ByteString) -> ([B8.ByteString], L8.ByteString)
nextBlock minLen (seed, input) = go (foldl' (+) 0 (map B8.length seed)) (reverse seed) input
  where
    -- @have@ is the byte count of the chunks in @acc@; @acc@ is kept in
    -- reverse order so each new chunk is consed on, and is reversed back
    -- before returning.
    go :: Int -> [B8.ByteString] -> L8.ByteString -> ([B8.ByteString], L8.ByteString)
    go have acc rest
      | have >= minLen = (reverse acc, rest)
    go _ acc L8Internal.Empty = (reverse acc, L8Internal.Empty)
    go have acc (L8Internal.Chunk c cs) = go (have + B8.length c) (c : acc) cs
-- | Raw FFI binding to OpenSSL's @EVP_EncodeBlock@.
--
-- Arguments, in the order 'encodeBlock' passes them: output buffer, input
-- buffer, input length in bytes.  'encodeBlock' treats the result as the
-- number of characters written to the output buffer.
foreign import ccall unsafe "EVP_EncodeBlock"
        _EncodeBlock :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt
-- | Base64-encode one strict 'B8.ByteString' with a single call to
-- @EVP_EncodeBlock@.
--
-- 'unsafePerformIO' is acceptable here because the result depends only on
-- the input bytes: the output buffer is freshly allocated by 'createAndTrim'
-- and the input buffer is only read (OpenSSL declares it @const@).
encodeBlock :: B8.ByteString -> B8.ByteString
encodeBlock input = unsafePerformIO (unsafeUseAsCStringLen input withInput)
  where
    -- Run the encoder into a buffer of 'capacity' bytes and keep only the
    -- prefix whose length the encoder reports.
    withInput (src, srcLen) =
        createAndTrim capacity $ \dst -> do
            written <- _EncodeBlock (castPtr dst) src (fromIntegral srcLen)
            return (fromIntegral written)

    -- Four output characters per (possibly partial) 3-byte group, plus one
    -- byte for the NUL terminator EVP_EncodeBlock appends.
    capacity = 4 * (B8.length input `div` 3 + 1) + 1
-- | Base64-encode a 'String'.
--
-- Each 'Char' becomes one byte via "Data.ByteString.Lazy.Char8", which keeps
-- only the low 8 bits of every character, so the result is only meaningful
-- for characters up to U+00FF.
encodeBase64 :: String -> String
encodeBase64 str = L8.unpack (encodeBase64LBS (L8.pack str))
-- | Base64-encode a strict 'B8.ByteString' in one step, with a single call
-- to 'encodeBlock'.
encodeBase64BS :: B8.ByteString -> B8.ByteString
encodeBase64BS bytes = encodeBlock bytes
-- | Base64-encode a lazy 'L8.ByteString', producing the output lazily.
--
-- The input is consumed in blocks whose length is a multiple of 3 bytes (any
-- excess bytes are carried over to the next step), so that only the final
-- block can need @=@ padding.  Each block is encoded by 'encodeBlock' and
-- the encoded pieces are concatenated.
encodeBase64LBS :: L8.ByteString -> L8.ByteString
encodeBase64LBS inLBS
    | L8.null inLBS = L8.empty
    | otherwise
        = let (blockParts', remain') = nextBlock 3 ([], inLBS)
              block'                 = B8.concat blockParts'
              blockLen'              = B8.length block'
              (block, leftover)      = if blockLen' < 3 then
                                           -- 'nextBlock' yields fewer than 3
                                           -- bytes only once the input is
                                           -- exhausted: this is the last
                                           -- remnant, encoded as a short
                                           -- (padded) block.
                                           (block', B8.empty)
                                       else
                                           -- Encode the longest prefix whose
                                           -- length is a multiple of 3 and
                                           -- carry the remainder over.
                                           B8.splitAt (blockLen' - (blockLen' `mod` 3)) block'
              remain                 = if B8.null leftover then
                                           remain'
                                       else
                                           L8.fromChunks [leftover] `L8.append` remain'
              encodedBlock           = encodeBlock block
              encodedRemain          = encodeBase64LBS remain
          in
            L8.fromChunks [encodedBlock] `L8.append` encodedRemain
-- | Raw FFI binding to OpenSSL's @EVP_DecodeBlock@.
--
-- Arguments, in the order 'decodeBlock' passes them: output buffer, input
-- buffer, input length in bytes.  'decodeBlock' treats the result as a byte
-- count that still includes the bytes produced for @=@ padding characters.
-- NOTE(review): the OpenSSL manual documents a return value of -1 for
-- malformed input; confirm every caller handles it.
foreign import ccall unsafe "EVP_DecodeBlock"
        _DecodeBlock :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt
-- | Decode one strict block of base64 text with a single call to
-- @EVP_DecodeBlock@.
--
-- The input is expected to consist of whole 4-character groups.  This is
-- checked with 'assert', which GHC drops when optimising unless
-- @-fno-ignore-asserts@ is given.
--
-- The length reported by @EVP_DecodeBlock@ still counts the bytes it
-- produces for @=@ padding characters, so the number of @=@ characters in
-- the input is subtracted to obtain the real decoded length.
--
-- Throws an 'ErrorCall' when @EVP_DecodeBlock@ rejects the input (negative
-- result) or the length left after removing padding is negative.
decodeBlock :: B8.ByteString -> B8.ByteString
decodeBlock inBS
    = assert (B8.length inBS `mod` 4 == 0) $
      unsafePerformIO $
      unsafeUseAsCStringLen inBS $ \ (inBuf, inLen) ->
      -- Decoding never yields more bytes than it consumes, so the input
      -- length is a safe capacity for the output buffer.
      createAndTrim (B8.length inBS) $ \ outBuf ->
      do rawLen <- _DecodeBlock (castPtr outBuf) inBuf (fromIntegral inLen)
         let decodedLen = fromIntegral rawLen - paddingLen
         -- A negative length must never reach createAndTrim: report
         -- malformed input explicitly instead of failing obscurely there.
         if rawLen < 0 || decodedLen < 0
           then throwIO (ErrorCall "OpenSSL.EVP.Base64.decodeBlock: invalid base64 input")
           else return decodedLen
    where
      -- Number of '=' padding characters in the input.
      paddingLen :: Int
      paddingLen = B8.count '=' inBS
-- | Decode a base64 'String'.
--
-- The input characters are narrowed to bytes, and every decoded byte comes
-- back as one 'Char' (code points 0 to 255), via "Data.ByteString.Lazy.Char8".
decodeBase64 :: String -> String
decodeBase64 encoded = L8.unpack (decodeBase64LBS (L8.pack encoded))
-- | Decode a strict base64 'B8.ByteString' in one step, with a single call
-- to 'decodeBlock'.  The input should therefore be a whole number of
-- 4-character groups.
decodeBase64BS :: B8.ByteString -> B8.ByteString
decodeBase64BS encoded = decodeBlock encoded
-- | Decode a lazy base64 'L8.ByteString', producing the output lazily.
--
-- The input is consumed in blocks whose length is a multiple of 4 (any
-- excess bytes are carried over to the next step), and each block is
-- decoded by 'decodeBlock'.  The total input length must be a multiple of 4.
-- A shorter trailing remnant raises an error instead of looping forever.
decodeBase64LBS :: L8.ByteString -> L8.ByteString
decodeBase64LBS inLBS
    | L8.null inLBS = L8.empty
    | otherwise
        = let (blockParts', remain') = nextBlock 4 ([], inLBS)
              block'                 = B8.concat blockParts'
              blockLen'              = B8.length block'
              -- 'nextBlock' yields fewer than 4 bytes only once the input is
              -- exhausted, so a short block means the input length was not a
              -- multiple of 4.  Splitting it would produce an empty block and
              -- re-queue the same bytes forever, so fail explicitly instead
              -- (the 'assert' still fires first when assertions are enabled).
              (block, leftover)      = assert (blockLen' >= 4) $
                                       if blockLen' < 4 then
                                           error "OpenSSL.EVP.Base64.decodeBase64LBS: input length is not a multiple of 4"
                                       else
                                           -- Decode the longest prefix whose
                                           -- length is a multiple of 4 and
                                           -- carry the remainder over.
                                           B8.splitAt (blockLen' - (blockLen' `mod` 4)) block'
              remain                 = if B8.null leftover then
                                           remain'
                                       else
                                           L8.fromChunks [leftover] `L8.append` remain'
              decodedBlock           = decodeBlock block
              decodedRemain          = decodeBase64LBS remain
          in
            L8.fromChunks [decodedBlock] `L8.append` decodedRemain