{-# LANGUAGE CPP                      #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE CApiFFI                  #-}
-- |An interface to Base64 codec.
module OpenSSL.EVP.Base64
    ( -- * Encoding
      encodeBase64
    , encodeBase64BS
    , encodeBase64LBS

      -- * Decoding
    , decodeBase64
    , decodeBase64BS
    , decodeBase64LBS
    )
    where
import           Control.Exception (assert)
import           Data.ByteString.Internal (createAndTrim)
import           Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import qualified Data.ByteString.Lazy.Internal as L8Internal
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy.Char8 as L8
import           Data.List
#if MIN_VERSION_base(4,5,0)
import           Foreign.C.Types (CChar(..), CInt(..))
#else
import           Foreign.C.Types (CChar, CInt)
#endif
import           Foreign.Ptr (Ptr, castPtr)
import           System.IO.Unsafe (unsafePerformIO)


-- | @'nextBlock' minLen (acc, src)@ moves chunks from the front of the lazy
-- input @src@ onto the end of @acc@ until @acc@ holds at least @minLen@
-- bytes in total, or until @src@ has no chunks left. It returns the grown
-- accumulator together with the unconsumed rest of the input.
--
-- The lazy codecs use it to gather enough bytes for at least one whole
-- Base64 unit: encoding asks for 3 bytes, decoding asks for 4 characters.
-- The caller concatenates the gathered chunks, splits the result at the
-- largest multiple of the unit size, and pushes the leftover bytes back in
-- front of the remaining input.
--
-- A result holding fewer than @minLen@ bytes therefore means the input was
-- exhausted.
nextBlock :: Int -> ([B8.ByteString], L8.ByteString) -> ([B8.ByteString], L8.ByteString)
nextBlock minLen acc@(parts, rest)
    | gathered >= minLen = acc
    | otherwise          =
        case rest of
          L8Internal.Chunk c cs -> nextBlock minLen (parts ++ [c], cs)
          L8Internal.Empty      -> acc
    where
      -- Total number of bytes collected so far.
      gathered = foldl' (\ n c -> n + B8.length c) 0 parts


{- encode -------------------------------------------------------------------- -}

-- Binding to OpenSSL's @EVP_EncodeBlock@ from @openssl/evp.h@.
-- 'encodeBlock' passes the output buffer first, the input buffer second and
-- the input length third. It uses the returned 'CInt' as the length to trim
-- the output buffer to.
foreign import capi unsafe "openssl/evp.h EVP_EncodeBlock"
        _EncodeBlock :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt


-- | Encodes one strict chunk to Base64 with a single call to
-- @EVP_EncodeBlock@.
--
-- The output buffer has room for 4 characters per started group of 3 input
-- bytes, plus one byte for the trailing NUL. It is then trimmed to the
-- length OpenSSL reports.
encodeBlock :: B8.ByteString -> B8.ByteString
encodeBlock inBS = unsafePerformIO (unsafeUseAsCStringLen inBS encodeInto)
    where
      encodeInto (src, srcLen) =
          createAndTrim bufSize $ \ dst -> do
              written <- _EncodeBlock (castPtr dst) src (fromIntegral srcLen)
              return (fromIntegral written)

      -- One extra 4-character group is reserved; it is unused when the
      -- input length is a multiple of 3. The final +1 is for the '\0'.
      bufSize = 4 * (1 + B8.length inBS `div` 3) + 1


-- |@'encodeBase64' str@ lazily encodes a stream of data to Base64. The
-- string doesn't have to be finite. It is converted to bytes with
-- 'L8.pack', so every character must lie in the range U+0000 to U+00FF.
{-# DEPRECATED encodeBase64 "Use encodeBase64BS or encodeBase64LBS instead." #-}
encodeBase64 :: String -> String
encodeBase64 str = L8.unpack (encodeBase64LBS (L8.pack str))

-- |@'encodeBase64BS' bs@ strictly encodes a chunk of data to Base64. The
-- whole chunk is handed to 'encodeBlock' at once.
encodeBase64BS :: B8.ByteString -> B8.ByteString
encodeBase64BS bs = encodeBlock bs

-- |@'encodeBase64LBS' lbs@ lazily encodes a stream of data to Base64.
-- The string doesn't have to be finite.
--
-- Each step does three things:
--
--   * gathers at least 3 input bytes, or whatever is left at the end of
--     the stream;
--   * encodes the longest prefix whose length is a multiple of 3;
--   * pushes the remaining bytes back in front of the rest of the input.
--
-- Only the final step may encode a group shorter than 3 bytes.
encodeBase64LBS :: L8.ByteString -> L8.ByteString
encodeBase64LBS inLBS
    | L8.null inLBS = L8.empty
    | otherwise     = L8.fromChunks [encodeBlock now] `L8.append` encodeBase64LBS later
    where
      (gathered, rest) = nextBlock 3 ([], inLBS)
      joined           = B8.concat gathered
      joinedLen        = B8.length joined

      -- nextBlock only returns fewer than 3 bytes once the input is
      -- exhausted, so such a tail is encoded as-is.
      (now, carry)
          | joinedLen < 3 = (joined, B8.empty)
          | otherwise     = B8.splitAt (joinedLen - joinedLen `mod` 3) joined

      -- Bytes that did not fill a whole 3-byte group go back in front of
      -- the unconsumed input.
      later
          | B8.null carry = rest
          | otherwise     = L8.fromChunks [carry] `L8.append` rest


{- decode -------------------------------------------------------------------- -}

-- Binding to OpenSSL's @EVP_DecodeBlock@ from @openssl/evp.h@.
-- 'decodeBlock' passes the output buffer first, the input buffer second and
-- the input length third. It derives the number of output bytes to keep
-- from the returned 'CInt'.
-- NOTE(review): OpenSSL documents a negative return value for malformed
-- input; confirm every caller checks for it before using the result as a
-- length.
foreign import capi unsafe "openssl/evp.h EVP_DecodeBlock"
        _DecodeBlock :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt


-- | Decodes one strict chunk of Base64 text with a single call to
-- @EVP_DecodeBlock@.
--
-- The length reported by @EVP_DecodeBlock@ is taken to include the bytes
-- produced for @=@ padding characters. The number of @=@ characters in the
-- input is therefore subtracted from it. The output buffer is allocated
-- with the input's length, which exceeds the 3/4 ratio of decoded bytes.
--
-- Two cases throw an 'IOError' when the result is evaluated:
--
--   * OpenSSL rejects the input by returning a negative length;
--   * the adjusted length would be negative (e.g. an input made only of @=@).
--
-- In both cases the error is raised instead of passing a negative size on
-- to 'createAndTrim'.
decodeBlock :: B8.ByteString -> B8.ByteString
decodeBlock inBS
    = assert (B8.length inBS `mod` 4 == 0) $
      unsafePerformIO $
      unsafeUseAsCStringLen inBS $ \ (inBuf, inLen) ->
      createAndTrim (B8.length inBS) $ \ outBuf -> do
          outLen <- _DecodeBlock (castPtr outBuf) inBuf (fromIntegral inLen)
          let decodedLen = fromIntegral outLen - paddingLen
          -- A negative value must never reach createAndTrim, which would
          -- try to allocate a buffer of that size.
          if outLen < 0 || decodedLen < 0
              then ioError (userError "OpenSSL.EVP.Base64.decodeBlock: invalid Base64 input")
              else return decodedLen
    where
      -- Number of '=' padding characters in the input.
      paddingLen :: Int
      paddingLen = B8.count '=' inBS

-- |@'decodeBase64' str@ lazily decodes a stream of data from Base64. The
-- string doesn't have to be finite. The decoded bytes are turned back into
-- a 'String' with 'L8.unpack', one 'Char' per byte.
{-# DEPRECATED decodeBase64 "Use decodeBase64BS or decodeBase64LBS instead." #-}
decodeBase64 :: String -> String
decodeBase64 str = L8.unpack (decodeBase64LBS (L8.pack str))

-- |@'decodeBase64BS' bs@ strictly decodes a chunk of data from Base64. The
-- whole chunk is handed to 'decodeBlock' at once.
decodeBase64BS :: B8.ByteString -> B8.ByteString
decodeBase64BS bs = decodeBlock bs

-- |@'decodeBase64LBS' lbs@ lazily decodes a stream of data from
-- Base64. The string doesn't have to be finite.
--
-- Each step does three things:
--
--   * gathers at least 4 input characters, or whatever is left at the end
--     of the stream;
--   * decodes the longest prefix whose length is a multiple of 4;
--   * pushes the remaining characters back in front of the rest of the
--     input.
--
-- A tail of fewer than 4 characters can only occur at the end of the
-- stream. It is passed to 'decodeBlock' unsplit, the same way
-- 'encodeBase64LBS' treats a short final group. Whether that tail decodes
-- or is rejected is up to 'decodeBlock'.
decodeBase64LBS :: L8.ByteString -> L8.ByteString
decodeBase64LBS inLBS
    | L8.null inLBS = L8.empty
    | otherwise
        = let (blockParts', remain' ) = nextBlock 4 ([], inLBS)
              block'                  = B8.concat blockParts'
              blockLen'               = B8.length block'
              (block      , leftover) = if blockLen' < 4 then
                                            -- The last remnant: nextBlock only
                                            -- returns fewer than 4 bytes once the
                                            -- input is exhausted. Splitting it
                                            -- would give an empty block and push
                                            -- the same remnant back, so the
                                            -- recursion below would never end.
                                            (block', B8.empty)
                                        else
                                            B8.splitAt (blockLen' - blockLen' `mod` 4) block'
              remain                  = if B8.null leftover then
                                            remain'
                                        else
                                            L8.fromChunks [leftover] `L8.append` remain'
              decodedBlock            = decodeBlock block
              decodedRemain           = decodeBase64LBS remain
          in
            L8.fromChunks [decodedBlock] `L8.append` decodedRemain