module Data.ByteString.Base64
( encode
, decode
, encodeURL
, decodeURL
) where
import Data.Word
import Data.Bits (shiftL, shiftR, (.&.), (.|.))
import qualified Data.ByteString.Internal as Internal
import qualified Data.ByteString.Char8 as B8
import Data.ByteString.BaseN
import Foreign.Ptr (plusPtr, Ptr)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Storable (peek, poke)
import System.IO.Unsafe (unsafePerformIO)
-- | OR a value into an accumulator at sextet slot @off@, i.e. shifted left
-- by @6 * off@ bits (slot 0 is the least significant six bits).
-- The value is not masked, so callers are expected to pass values below 64.
pack6 :: Int -> Word8 -> Word32 -> Word32
pack6 slot sextet acc = (.|.) acc shifted
  where
    shifted = shiftL (fromIntegral sextet) (6 * slot)
-- | Extract the six-bit group stored at slot @off@ of a packed word
-- (slot 0 is the least significant six bits).
unpack6 :: Int -> Word32 -> Word8
unpack6 slot packed =
  let sextet = (packed `shiftR` (slot * 6)) .&. 0x3f
   in fromIntegral sextet
-- | Combine up to three bytes into one word via 'pack8'. The first byte is
-- paired with slot 2, the second with slot 1 and the third with slot 0.
-- Any bytes beyond the third are ignored; missing bytes contribute nothing.
pack8x3 :: [Word8] -> Word32
pack8x3 bytes = foldr step 0 (zip [2, 1, 0] bytes)
  where
    step (slot, byte) acc = pack8 slot byte acc
-- | Split a packed word back into three bytes via 'unpack8', reading
-- slots 2, 1 and 0 in that order (the inverse layout of 'pack8x3').
unpack8x3 :: Word32 -> [Word8]
unpack8x3 w = [unpack8 slot w | slot <- [2, 1, 0]]
-- | Combine up to four sextets into one word via 'pack6'. The first value
-- goes to slot 3 (most significant), the fourth to slot 0. Extra values
-- are ignored; missing ones leave their slots zero.
pack6x4 :: [Word8] -> Word32
pack6x4 sextets = foldr place 0 (zip [3, 2, 1, 0] sextets)
  where
    place (slot, s) acc = pack6 slot s acc
-- | Split a packed word into four sextets via 'unpack6', most significant
-- slot (3) first.
unpack6x4 :: Word32 -> [Word8]
unpack6x4 packed = [unpack6 slot packed | slot <- [3, 2, 1, 0]]
-- | Base64-encode using the standard alphabet (@+@ and @/@ as the last two
-- characters).
encode :: B8.ByteString -> B8.ByteString
encode input = encodeAlphabet alphabet input
-- | Decode Base64 written with the standard alphabet; 'Left' carries an
-- error message for malformed input.
decode :: B8.ByteString -> Either String B8.ByteString
decode input = decodeAlphabet alphabet input
-- | Base64-encode using the URL-safe alphabet (@-@ and @_@ as the last two
-- characters).
encodeURL :: B8.ByteString -> B8.ByteString
encodeURL input = encodeAlphabet alphabetURL input
-- | Decode Base64 written with the URL-safe alphabet; 'Left' carries an
-- error message for malformed input.
decodeURL :: B8.ByteString -> Either String B8.ByteString
decodeURL input = decodeAlphabet alphabetURL input
-- | Decode a Base64 'B8.ByteString' using the given alphabet.
--
-- Returns 'Left' with a message when the input length is not a multiple of
-- four, when 'decodeWord' rejects a character, or when the final group
-- carries an unsupported amount of @=@ (0x3d) padding (more than two).
-- Empty input decodes to an empty result.
decodeAlphabet :: Enc -> B8.ByteString -> Either String B8.ByteString
decodeAlphabet enc src@(Internal.PS _ _ slen)
  | groupRem /= 0 = Left "ByteString wrong length for valid Base64 encoding padding"
  | dlen <= 0 = Right B8.empty
  | otherwise =
      unsafePerformIO $ byChunkErr 4 dlen onchunk onend src
  where
    (groups, groupRem) = slen `divMod` 4
    -- Maximum output size: each 4-character group decodes to 3 bytes.
    dlen = groups * 3

    -- Read the four input characters of one group.
    peekGroup sp = traverse (peek . (sp `plusPtr`)) [0 .. 3]

    -- Map each character to its 6-bit value, then regroup into 3 bytes.
    decodeChars chars = unpack8x3 . pack6x4 <$> traverse (decodeWord enc) chars

    -- A full (unpadded) group always produces exactly 3 bytes.
    onchunk :: Ptr Word8 -> Ptr Word8 -> IO (Either String Int)
    onchunk sp dp = do
      chars <- peekGroup sp
      case decodeChars chars of
        Left err -> return $ Left err
        Right decoded -> do
          pokeN dp 3 decoded
          return $ Right 3

    -- Number of real bytes carried by a final group with n padding chars.
    validBytes :: Int -> Either String Int
    validBytes 0 = Right 3
    validBytes 1 = Right 2
    validBytes 2 = Right 1
    validBytes n = Left $ "Invalid amount of padding: " ++ show n

    -- Final group: strip trailing '=' padding, decode the remaining
    -- characters, and report the total decoded length.
    -- NOTE(review): the remainder argument is ignored and 4 bytes are always
    -- read; this assumes byChunkErr hands the last full group to onend —
    -- confirm against Data.ByteString.BaseN.
    onend :: Ptr Word8 -> Ptr Word8 -> Int -> IO (Either String Int)
    onend sp dp _ = do
      chars <- peekGroup sp
      let npad = length $ takeWhile (== 0x3d) $ reverse chars
          decoded = decodeChars (take (4 - npad) chars)
      case (,) <$> validBytes npad <*> decoded of
        Left err -> return $ Left err
        Right (n, bytes) -> do
          pokeN dp n bytes
          -- Padding shortens the output by (3 - n) bytes.
          return $ Right $ dlen - (3 - n)
-- | Standard Base64 alphabet: @A-Z@, @a-z@, @0-9@, then @+@ and @/@.
alphabet :: Enc
alphabet = mkEnc (['A' .. 'Z'] ++ ['a' .. 'z'] ++ ['0' .. '9'] ++ "+/")
-- | URL-safe Base64 alphabet: @A-Z@, @a-z@, @0-9@, then @-@ and @_@.
alphabetURL :: Enc
alphabetURL = mkEnc (['A' .. 'Z'] ++ ['a' .. 'z'] ++ ['0' .. '9'] ++ "-_")
-- | Base64-encode a 'B8.ByteString' with the given alphabet.
--
-- A trailing partial group of 1 or 2 bytes is padded with @=@ (0x3d), so
-- the output length is always a multiple of four.
encodeAlphabet :: Enc -> B8.ByteString -> B8.ByteString
encodeAlphabet enc src@(Internal.PS _ _ slen) =
  unsafePerformIO $ byChunk 3 dlen onchunk onend src
  where
    -- Every started 3-byte group yields 4 output characters.
    dlen = ((slen + 2) `div` 3) * 4

    -- Read @count@ source bytes, split them into four sextets, and map each
    -- sextet through the alphabet. Missing bytes are treated as zero.
    encodeGroup sp count =
      map (encodeWord enc) . unpack6x4 . pack8x3
        <$> traverse (peek . (sp `plusPtr`)) [0 .. count - 1]

    -- A full 3-byte group produces 4 alphabet characters.
    onchunk sp dp = do
      encoded <- encodeGroup sp 3
      pokeN dp 4 encoded
      return 4

    -- Trailing bytes: emit (remaining + 1) characters, then pad with '='
    -- up to 4 (remaining = 1 gives "xx==", remaining = 2 gives "xxx=").
    onend _ _ 0 = return ()
    onend sp dp remaining = do
      encoded <- encodeGroup sp remaining
      pokeN dp 4 $ take (remaining + 1) encoded ++ replicate (3 - remaining) 0x3d