{-# LANGUAGE OverloadedStrings #-} module Data.ByteString.Base32 ( encode , decode , encodeHex , decodeHex ) where import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString as BS import qualified Data.ByteString.Internal as Internal import Data.ByteString.BaseN import Data.Word import Data.Bits (shiftL, shiftR, (.|.), (.&.)) import Foreign.Ptr (plusPtr, Ptr) import Foreign.Storable (peek, poke) import System.IO.Unsafe (unsafePerformIO) pack5 :: Int -> Word8 -> Word64 -> Word64 pack5 off word buf = buf .|. (fromIntegral word `shiftL` (off * 5)) pack8x5 :: [Word8] -> Word64 pack8x5 bs = foldr (uncurry pack8) 0 (zip [4,3..0] bs) unpack8x5 :: Word64 -> [Word8] unpack8x5 w = map (`unpack8` w) [4,3..0] unpack5 :: Int -> Word64 -> Word8 unpack5 off buf = fromIntegral $ buf `shiftR` (5 * off) .&. 0x1f pack5x8 :: [Word8] -> Word64 pack5x8 bs = foldr (uncurry pack5) 0 (zip [7,6..0] bs) unpack5x8 :: Word64 -> [Word8] unpack5x8 bs = map (`unpack5` bs) [7,6..0] padCeilN :: Int -> Int -> Int padCeilN n x | remd == 0 = x | otherwise = (x - remd) + n where mask = n - 1 remd = x .&. mask encodeHex :: B8.ByteString -> B8.ByteString encodeHex = encodeAlphabet alphabetHex decodeHex :: B8.ByteString -> Either String B8.ByteString decodeHex = decodeAlphabet alphabetHex encode :: B8.ByteString -> B8.ByteString encode = encodeAlphabet alphabet decode :: B8.ByteString -> Either String B8.ByteString decode = decodeAlphabet alphabet encodeAlphabet :: Enc -> B8.ByteString -> B8.ByteString encodeAlphabet enc src@(Internal.PS sfp soff slen) = unsafePerformIO $ byChunk 5 dlen onchunk onend src where (d, m) = (slen * 8) `divMod` 5 dlen = padCeilN 8 (d + if m == 0 then 0 else 1) onchunk sp dp = do words <- unpack5x8 . pack8x5 <$> traverse peek (map (sp `plusPtr`) [0..4]) pokeN dp 8 $ map (encodeWord enc) words return 8 onend sp dp 0 = return () onend sp dp rem = do words <- unpack5x8 . pack8x5 <$> sequence [ peek sp , if rem > 1 then peek $ sp `plusPtr` 1 else return 0 , if rem > 2 then peek $ sp `plusPtr` 2 else return 0 , if rem > 3 then peek $ sp `plusPtr` 3 else return 0 , if rem > 4 then peek $ sp `plusPtr` 4 else return 0 ] let encoded = map (encodeWord enc) words poke dp (encoded !! 0) poke (dp `plusPtr` 1) (encoded !! 1) poke (dp `plusPtr` 2) (if rem < 2 then 0x3d else encoded !! 2) poke (dp `plusPtr` 3) (if rem < 2 then 0x3d else encoded !! 3) poke (dp `plusPtr` 4) (if rem < 3 then 0x3d else encoded !! 4) poke (dp `plusPtr` 5) (if rem < 4 then 0x3d else encoded !! 5) poke (dp `plusPtr` 6) (if rem < 4 then 0x3d else encoded !! 6) poke (dp `plusPtr` 7) (0x3d :: Word8) alphabet :: Enc alphabet = mkEnc "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" alphabetHex :: Enc alphabetHex = mkEnc "0123456789ABCDEFGHIJKLMNOPQRSTUV" decodeAlphabet :: Enc -> B8.ByteString -> Either String B8.ByteString decodeAlphabet enc src@(Internal.PS sfp soff slen) | slen `mod` 8 /= 0 = Left "ByteString wrong length for valid Base32 encoding" | slen == 0 = Right BS.empty | otherwise = unsafePerformIO $ byChunkErr 8 dlen onchunk onend src where (d, m) = (slen * 5) `divMod` 8 dlen = d + if m == 0 then 0 else 1 -- Maps number of padding bytes in source buffer to number of valid bytes -- to write to destination buffer. validBytes :: Int -> Either String Int validBytes 0 = Right 5 validBytes 1 = Right 4 validBytes 3 = Right 3 validBytes 4 = Right 2 validBytes 6 = Right 1 validBytes n = Left $ "Invalid amount of padding: " ++ show n decodeBytes = mapM $ decodeWord enc onchunk :: Ptr Word8 -> Ptr Word8 -> IO (Either String Int) onchunk sp dp = do words <- traverse (peek . (sp `plusPtr`)) [0..7] case (unpack8x5 . pack5x8) <$> decodeBytes words of Left err -> return $ Left err Right decoded -> do pokeN dp 5 decoded return $ Right 5 onend :: Ptr Word8 -> Ptr Word8 -> Int -> IO (Either String Int) onend sp dp rem = do words <- traverse (peek . (sp `plusPtr`)) [0..7] let npad = length $ takeWhile (==0x3d) $ reverse words decode = (unpack8x5 . pack5x8) <$> decodeBytes (take (8-npad) words) case (,) <$> validBytes npad <*> decode of Left err -> return $ Left err Right (n, decoded) -> do pokeN dp n decoded return $ Right $ dlen - (5 - n)