{-# LANGUAGE OverloadedStrings #-}

module Data.ByteString.Base64
  ( encode
  , decode
  
  , encodeURL
  , decodeURL
  ) where

import Data.Word
import Data.Bits (shiftL, shiftR, (.&.), (.|.))

import qualified Data.ByteString.Internal as Internal
import qualified Data.ByteString.Char8 as B8
import Data.ByteString.BaseN

import Foreign.Ptr (plusPtr, Ptr)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Storable (peek, poke)
import System.IO.Unsafe (unsafePerformIO)

-- | Place a 6-bit group into slot @off@ of a 'Word32' accumulator, counting
-- slots from the right: @pack6 0 w buf@ fills the right-most six bits.
--
-- Bits already set in @buf@ are kept; the slot is OR-ed in, not overwritten.
-- Only the low six bits of @word@ are used, so an out-of-range value cannot
-- corrupt the neighbouring slots (the same mask 'unpack6' applies on read).
pack6 :: Int -> Word8 -> Word32 -> Word32
pack6 off word buf = buf .|. (fromIntegral (word .&. 0x3f) `shiftL` (off * 6))

-- | Extract the 6-bit group in slot @off@ of a 'Word32', counting slots from
-- the right: @unpack6 0 w@ returns the right-most six bits of @w@.
unpack6 :: Int -> Word32 -> Word8
unpack6 off buf =
  let shifted = buf `shiftR` (off * 6)
   in fromIntegral (shifted .&. 0x3f)

-- | Pack up to three bytes into a 'Word32' with 'pack8': the first byte goes
-- to slot 2, the second to slot 1, the third to slot 0. The fold starts
-- from 0 and only the bytes present are packed; elements past the third are
-- ignored.
pack8x3 :: [Word8] -> Word32
pack8x3 bytes = foldr ($) 0 (zipWith pack8 [2, 1, 0] bytes)

-- | Split a 'Word32' into three bytes with 'unpack8', reading slot 2 first,
-- then slot 1, then slot 0.
unpack8x3 :: Word32 -> [Word8]
unpack8x3 w = [unpack8 slot w | slot <- [2, 1, 0]]

-- | Pack up to four sextets into a 'Word32' with 'pack6': the first element
-- goes to slot 3 (left-most), the last to slot 0. A shorter list leaves the
-- low slots at zero; elements past the fourth are ignored.
pack6x4 :: [Word8] -> Word32
pack6x4 sextets = foldr ($) 0 (zipWith pack6 [3, 2, 1, 0] sextets)

-- | Split a 'Word32' into four sextets with 'unpack6', slot 3 (left-most)
-- first and slot 0 last.
unpack6x4 :: Word32 -> [Word8]
unpack6x4 w = [unpack6 slot w | slot <- [3, 2, 1, 0]]

-- | Encode bytes as Base64 using the standard alphabet (@+@ and @/@ for
-- values 62 and 63), padding the output with @=@.
encode :: B8.ByteString -> B8.ByteString
encode bytes = encodeAlphabet alphabet bytes

-- | Decode Base64 written with the standard alphabet (@+@ and @/@).
-- The input must be @=@-padded to a multiple of four characters; malformed
-- input is reported as 'Left' with a description.
decode :: B8.ByteString -> Either String B8.ByteString
decode encoded = decodeAlphabet alphabet encoded

-- | Encode bytes as Base64 using the URL-safe alphabet (@-@ and @_@ in place
-- of @+@ and @/@), padding the output with @=@.
encodeURL :: B8.ByteString -> B8.ByteString
encodeURL bytes = encodeAlphabet alphabetURL bytes

-- | Decode Base64 written with the URL-safe alphabet (@-@ and @_@).
-- The input must be @=@-padded to a multiple of four characters; malformed
-- input is reported as 'Left' with a description.
decodeURL :: B8.ByteString -> Either String B8.ByteString
decodeURL encoded = decodeAlphabet alphabetURL encoded

-- | Decode @=@-padded Base64 written with the given alphabet.
--
-- Yields 'Left' when:
--
-- * the input length is not a multiple of four;
-- * a character is rejected by 'decodeWord';
-- * the final group ends in more than two @=@ characters.
--
-- The empty input decodes to the empty string.
--
-- NOTE(review): @onend@ always reads a full four-byte group and ignores its
-- remainder argument. The input length here is always a multiple of four,
-- so this relies on 'byChunkErr' handing the last group to @onend@ rather
-- than to @onchunk@. Confirm against its definition in
-- "Data.ByteString.BaseN".
decodeAlphabet :: Enc -> B8.ByteString -> Either String B8.ByteString
decodeAlphabet enc src
  | leftover /= 0 =
      Left "ByteString wrong length for valid Base64 encoding padding"
  | outLen <= 0 = Right B8.empty
  | otherwise = unsafePerformIO $ byChunkErr 4 outLen onchunk onend src
  where
    (groups, leftover) = B8.length src `divMod` 4

    -- Three bytes per group, before any padding is accounted for. @onend@
    -- reports the final length as this minus the bytes padding removes.
    outLen = groups * 3

    -- Read the four source characters of the group at @sp@.
    readGroup sp = mapM (peek . plusPtr sp) [0 .. 3]

    -- Decode up to four characters into three bytes. Fails with the first
    -- 'Left' produced by 'decodeWord'.
    decodeChars chars = unpack8x3 . pack6x4 <$> mapM (decodeWord enc) chars

    -- Number of meaningful output bytes in a final group, given how many
    -- trailing '=' characters it has.
    validBytes :: Int -> Either String Int
    validBytes pad
      | pad >= 0 && pad <= 2 = Right (3 - pad)
      | otherwise = Left $ "Invalid amount of padding: " ++ show pad

    -- A group with no padding always produces exactly three bytes.
    onchunk :: Ptr Word8 -> Ptr Word8 -> IO (Either String Int)
    onchunk sp dp =
      readGroup sp >>= either (return . Left) emit . decodeChars
      where
        emit bytes = Right 3 <$ pokeN dp 3 bytes

    -- The final group: count trailing '=' (0x3d), decode only the real
    -- characters, write the meaningful bytes, and report the total length.
    onend :: Ptr Word8 -> Ptr Word8 -> Int -> IO (Either String Int)
    onend sp dp _ = do
      chars <- readGroup sp
      let pad = length (takeWhile (== 0x3d) (reverse chars))
          checked =
            (,) <$> validBytes pad <*> decodeChars (take (4 - pad) chars)
      case checked of
        Left err -> return (Left err)
        Right (count, bytes) -> do
          pokeN dp count bytes
          return (Right (outLen - (3 - count)))


-- | The standard Base64 alphabet: @A-Z@, @a-z@, @0-9@, then @+@ and @/@.
alphabet :: Enc
alphabet = mkEnc standardChars
  where
    standardChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

-- | The URL- and filename-safe Base64 alphabet. It is identical to the
-- standard one except that @-@ and @_@ replace @+@ and @/@.
alphabetURL :: Enc
alphabetURL = mkEnc urlSafeChars
  where
    urlSafeChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"

-- | Encode a byte string as @=@-padded Base64 using the given alphabet.
--
-- The output is always @4 * ceiling (n / 3)@ bytes long. 'byChunk' is
-- handed @onchunk@ for complete three-byte groups and @onend@ for the
-- trailing zero, one or two bytes.
--
-- NOTE(review): 'unsafePerformIO' is presumably safe because 'byChunk'
-- writes only into a buffer it allocates itself and @src@ is only read.
-- Confirm in "Data.ByteString.BaseN".
encodeAlphabet :: Enc -> B8.ByteString -> B8.ByteString
encodeAlphabet enc src =
  unsafePerformIO $ byChunk 3 dlen onchunk onend src
  where
    slen = B8.length src
    dlen = ((slen + 2) `div` 3) * 4

    -- Read @n@ consecutive source bytes starting at @sp@.
    readBytes sp n = traverse (peek . (sp `plusPtr`)) [0 .. n - 1]

    -- Map up to three input bytes to four alphabet characters. 'pack8x3'
    -- only packs the bytes it is given, starting from 0.
    encodeGroup bytes = map (encodeWord enc) (unpack6x4 (pack8x3 bytes))

    onchunk sp dp = do
      encoded <- encodeGroup <$> readBytes sp 3
      pokeN dp 4 encoded
      return 4

    -- @rest@ trailing bytes carry @rest + 1@ significant sextets (8 or 16
    -- bits); the rest of the four-character group is '=' (0x3d). Building
    -- the group as one list avoids indexing with the partial '!!'.
    onend _ _ 0 = return ()
    onend sp dp rest = do
      encoded <- encodeGroup <$> readBytes sp rest
      pokeN dp 4 (take (rest + 1) encoded ++ replicate (3 - rest) 0x3d)