{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP                        #-}
module Raaz.Core.Encode.Base16
       ( Base16
       , fromBase16, showBase16
       ) where
import Data.Char
import Data.Bits
import Data.String
import Data.ByteString as B
import Data.ByteString.Char8 as C8
import Data.ByteString.Internal (c2w)
import Data.ByteString.Unsafe(unsafeIndex)
import Data.Monoid
import Data.Word
import Prelude
import Raaz.Core.Encode.Internal
-- | Hexadecimal (base16) encoding of binary data.
--
-- The wrapped 'ByteString' holds the raw, already-decoded bytes. The
-- hexadecimal text is produced on demand by 'toByteString' (lowercase
-- digits). It is parsed by 'fromByteString' and 'fromString', which accept
-- either letter case. The derived 'Semigroup'/'Monoid' instances operate on
-- the underlying raw bytes (generalised newtype deriving).
newtype Base16 = Base16 {unBase16 :: ByteString}
-- Since base-4.11 Semigroup is a superclass of Monoid, so both have to be
-- derived there; older base versions only need Monoid.
#if MIN_VERSION_base(4,11,0)
                 deriving (Eq, Semigroup, Monoid)
#else
                 deriving (Eq, Monoid)
#endif
-- | The external representation of a 'Base16' value is its hexadecimal text.
--
-- 'fromByteString' is the checked decoder. It yields 'Nothing' when the
-- input has odd length or contains a character that is not a hex digit.
--
-- 'unsafeFromByteString' only checks the length and calls 'error' on odd
-- input. A non-hex character makes decoding fail with an error later, when
-- the bytes are built.
instance Encodable Base16 where
  toByteString (Base16 raw) = hex raw

  fromByteString bs
    | even (B.length bs) && C8.all isHexDigit bs = Just (Base16 (unsafeFromHex bs))
    | otherwise                                  = Nothing

  unsafeFromByteString bs
    | odd (B.length bs) = error "base16 encoding is always of even size"
    | otherwise         = Base16 (unsafeFromHex bs)
-- | Shows the lowercase hexadecimal text of the value, with no surrounding
-- quotes.
instance Show Base16 where
  show b16 = C8.unpack (toByteString b16)
-- | Parse a hexadecimal literal.
--
-- Whitespace (as judged by 'isSpace') and @:@ separators are skipped, so
-- @"de:ad be:ef"@ is a valid literal. Every remaining character must be an
-- ASCII hex digit (either case); anything else is rejected with 'error'.
-- Odd-length input is rejected by 'unsafeFromByteString'.
--
-- Filtering and validation happen on the 'String' itself. Packing first
-- would truncate each character to its low byte, so non-Latin-1 characters
-- could silently turn into hex digits or spaces. For example, U+0130 would
-- become the digit 0.
instance IsString Base16 where
  fromString = unsafeFromByteString . C8.pack . Prelude.map checkHex
             . Prelude.filter (not . useless)
    where useless c = isSpace c || c == ':'
          checkHex c
            | isHexDigit c = c
            | otherwise    = error ("bad base16 character: " ++ show c)
-- | 'Base16' as a 'Format'. Encoding just wraps the raw bytes and decoding
-- unwraps them. The hexadecimal text only appears through the 'Encodable'
-- instance.
instance Format Base16 where
  encodeByteString raw = Base16 raw
  {-# INLINE encodeByteString #-}
  decodeFormat (Base16 raw) = raw
  {-# INLINE decodeFormat #-}
-- | Encode every input byte as two hexadecimal characters (lowercase, via
-- 'hexDigit'), high nibble first. The output is exactly twice as long as the
-- input.
hex :: ByteString -> ByteString
hex input = fst (B.unfoldrN outLen step 0)
  where
    outLen = 2 * B.length input
    -- Output position @pos@ reads source byte @pos `quot` 2@: even positions
    -- emit its high nibble, odd positions its low nibble.
    step pos = Just (hexDigit nibble, pos + 1)
      where
        byte = unsafeIndex input (pos `quot` 2)
        nibble
          | even pos  = top4 byte
          | otherwise = bot4 byte
-- | ASCII code of the lowercase hexadecimal digit for a nibble. Intended for
-- inputs in @0..15@, as produced by 'top4' and 'bot4'.
hexDigit :: Word8 -> Word8
hexDigit nibble
  | nibble >= 10 = nibble - 10 + c2w 'a'
  | otherwise    = nibble + c2w '0'
-- | High (most significant) four bits of a byte, moved to the low end.
top4 :: Word8 -> Word8
top4 = (`shiftR` 4)
-- | Low (least significant) four bits of a byte; the high bits are cleared.
bot4 :: Word8 -> Word8
bot4 = (.&. 0x0F)
-- | Decode hexadecimal text (either letter case) into the bytes it denotes.
--
-- This function does no up-front validation:
--
-- * The caller must ensure the input has even length. With odd length the
--   final character is ignored.
-- * A character that is not a hex digit makes construction of the result
--   fail with @error "bad base16 character"@.
unsafeFromHex :: ByteString -> ByteString
unsafeFromHex src = fst (B.unfoldrN outLen decodePair 0)
  where
    outLen = B.length src `quot` 2
    -- Output byte k is built from source characters 2k (high nibble) and
    -- 2k+1 (low nibble).
    decodePair k = Just (hiNib `shiftL` 4 .|. loNib, k + 1)
      where
        hiNib = nibbleOf (unsafeIndex src (2 * k))
        loNib = nibbleOf (unsafeIndex src (2 * k + 1))
    nibbleOf c
      | between '0' '9' = c - c2w '0'
      | between 'a' 'f' = c - c2w 'a' + 10
      | between 'A' 'F' = c - c2w 'A' + 10
      | otherwise       = error "bad base16 character"
      where between from to = c2w from <= c && c <= c2w to
-- | Decode a hexadecimal 'String' into any 'Encodable' type.
--
-- Parsing goes through the 'IsString' instance of 'Base16', so whitespace
-- and @:@ separators are skipped (e.g. @"de:ad be:ef"@). Malformed hex makes
-- this call 'error', and the result is built with 'unsafeFromByteString'.
-- Prefer 'fromByteString' for input that may be invalid.
fromBase16 :: Encodable a => String -> a
fromBase16 str = unsafeFromByteString (unBase16 (fromString str))
-- | Render any 'Encodable' value as lowercase hexadecimal text, without
-- separators.
showBase16 :: Encodable a => a -> String
showBase16 x = show (Base16 (toByteString x))