{-|
Module      : Z.Data.Vector.Base64
Description : Base64 codec for bytes.
Copyright   : (c) Dong Han, 2017-2018
License     : BSD
Maintainer  : winterland1989@gmail.com
Stability   : experimental
Portability : non-portable

This module provides base64 encoding and decoding functions for 'V.Bytes'. It also provides 'T.Text' and 'B.Builder' variants of the encoder, and an exception type for the throwing decoder.

-}

module Z.Data.Vector.Base64
  (-- * Encoding & Decoding functions
    base64Encode
  , base64EncodeLength
  , base64EncodeText
  , base64EncodeBuilder
  , base64Decode
  , base64Decode'
  , base64DecodeLength
  , Base64DecodeException(..)
  -- * Internal C FFIs
  ,  hs_base64_encode, hs_base64_decode
  ) where

import           Control.Exception
import           Data.Word
import           Data.Bits                      (unsafeShiftL, unsafeShiftR, (.&.))
import           GHC.Stack
import           System.IO.Unsafe
import qualified Z.Data.Vector.Base         as V
import qualified Z.Data.Builder.Base        as B
import qualified Z.Data.Text.Base           as T
import           Z.Foreign

-- | Encode 'V.Bytes' using base64 encoding.
--
-- The encoded output is written by the C encoder into a freshly allocated
-- vector whose size is given by 'base64EncodeLength' of the input length.
base64Encode :: V.Bytes -> V.Bytes
{-# INLINABLE base64Encode #-}
base64Encode (V.PrimVector arr s l) = encoded
  where
    outLen = base64EncodeLength l
    -- allocate the output, then encode the input slice [s, s+l) into it
    -- starting at output offset 0
    (encoded, _) = unsafeDupablePerformIO $
        allocPrimVectorUnsafe outLen $ \ buf# ->
            withPrimArrayUnsafe arr $ \ parr _ ->
                hs_base64_encode buf# 0 parr s l

-- | Return the base64-encoded length of an input of @n@ bytes.
--
-- Every started group of 3 input bytes occupies 4 output characters, so the
-- result is always a multiple of 4.
base64EncodeLength :: Int -> Int
{-# INLINE base64EncodeLength #-}
base64EncodeLength n = groups `unsafeShiftL` 2
  where
    -- number of 3-byte input groups, rounded up
    groups = (n + 2) `quot` 3

-- | 'B.Builder' version of 'base64Encode'.
--
-- Reserves 'base64EncodeLength' bytes via 'B.writeN' and lets the C encoder
-- write directly into the builder's buffer at the offset it is handed.
base64EncodeBuilder :: V.Bytes -> B.Builder ()
{-# INLINE base64EncodeBuilder #-}
base64EncodeBuilder (V.PrimVector arr s l) = B.writeN outLen writeAt
  where
    outLen = base64EncodeLength l
    -- encode the input slice [s, s+l) into the builder buffer at offset off
    writeAt (MutablePrimArray mba#) off =
        withPrimArrayUnsafe arr $ \ parr _ ->
            hs_base64_encode mba# off parr s l

-- | Text version of 'base64Encode'.
--
-- The encoded bytes are wrapped as 'T.Text' without validation. This relies
-- on the encoder emitting only ASCII characters (the base64 alphabet), which
-- are valid UTF-8.
base64EncodeText :: V.Bytes -> T.Text
{-# INLINABLE base64EncodeText #-}
base64EncodeText bytes = T.Text (base64Encode bytes)

-- | Decode a base64 encoded string, returning 'Nothing' on illegal bytes or
-- incomplete input.
--
-- An empty input decodes to 'V.empty'. An input whose length is rejected by
-- 'base64DecodeLength' yields 'Nothing' without calling the decoder.
-- Otherwise a buffer of the upper-bound size is allocated, and the C decoder
-- reporting 0 bytes written is treated as failure.
base64Decode :: V.Bytes -> Maybe V.Bytes
{-# INLINABLE base64Decode #-}
base64Decode ba = case V.length ba of
    0 -> Just V.empty
    inputLen -> case base64DecodeLength inputLen of
        -1 -> Nothing
        decodeLen -> unsafeDupablePerformIO $ do
            (out, written) <- withPrimVectorUnsafe ba $ \ src# off len ->
                allocPrimArrayUnsafe decodeLen $ \ dst# ->
                    hs_base64_decode dst# src# off len
            -- the decoder signals failure by reporting zero bytes written
            if written == 0
                then pure Nothing
                else pure (Just (V.PrimVector out 0 written))

-- | Exception during base64 decoding.
--
-- Each constructor carries the offending input together with the
-- 'CallStack' supplied by the code that raised it.
data Base64DecodeException
    = -- | The input contains bytes that could not be decoded.
      IllegalBase64Bytes V.Bytes CallStack
    | -- | The input is incomplete.
      IncompleteBase64Bytes V.Bytes CallStack
  deriving Show

instance Exception Base64DecodeException

-- | Decode a base64 encoding string, throw 'Base64DecodeException' on error.
--
-- Throws 'IncompleteBase64Bytes' when the input length is one that no base64
-- string can have (as reported by 'base64DecodeLength' returning -1).
-- Throws 'IllegalBase64Bytes' for any other decoding failure. Both carry the
-- input and the caller's 'CallStack'.
base64Decode' :: HasCallStack => V.Bytes -> V.Bytes
{-# INLINABLE base64Decode' #-}
base64Decode' ba = case base64Decode ba of
    Just r -> r
    Nothing
        -- distinguish truncated input from bad bytes, so that the otherwise
        -- unused 'IncompleteBase64Bytes' constructor is actually reported
        | base64DecodeLength (V.length ba) == -1 ->
            throw (IncompleteBase64Bytes ba callStack)
        | otherwise ->
            throw (IllegalBase64Bytes ba callStack)

-- | Return an upper bound on the decoded length of a base64 input of @n@
-- bytes, or @-1@ if no base64 string can have that length.
--
-- Only lengths with @n mod 4 == 1@ are rejected. A single trailing character
-- cannot encode a whole byte, so such a length is impossible whether or not
-- the input is padded. Lengths that are 2 or 3 modulo 4 are accepted.
-- For accepted non-negative @n@ the result is @(n div 4) * 3 + 2@. This can
-- exceed the exact decoded length, e.g. when the input ends with padding.
base64DecodeLength :: Int -> Int
{-# INLINE base64DecodeLength #-}
base64DecodeLength n
    | n .&. 3 == 1 = -1
    | otherwise    = (n `unsafeShiftR` 2) * 3 + 2

--------------------------------------------------------------------------------

-- | Raw C base64 encoder.
--
-- Arguments, as used by the callers in this module: destination buffer,
-- destination offset, source array, source offset, source length. Callers
-- size the destination with 'base64EncodeLength' of the source length.
-- NOTE(review): argument meaning inferred from call sites; confirm against
-- the C implementation.
-- NOTE(review): keep this @unsafe@. Unlifted byte arrays (MBA# / BA#) passed
-- to the FFI are only guaranteed not to move during @unsafe@ calls.
foreign import ccall unsafe hs_base64_encode :: MBA# Word8 -> Int -> BA# Word8 -> Int -> Int -> IO ()
-- | Raw C base64 decoder.
--
-- Arguments, as used by 'base64Decode': destination buffer, source array,
-- source offset, source length. The result is used as the number of bytes
-- written to the destination, and callers treat 0 as a decoding failure.
-- The destination is sized with 'base64DecodeLength'.
-- NOTE(review): return-value semantics inferred from the caller; confirm
-- against the C implementation.
foreign import ccall unsafe hs_base64_decode :: MBA# Word8 -> BA# Word8 -> Int -> Int -> IO Int