{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE TypeApplications #-}
module Data.ByteString.Base32.Internal.Loop
( innerLoop
, decodeLoop
) where

import Data.Bits
import Data.ByteString.Internal (ByteString(..))
import Data.ByteString.Base32.Internal.Utils
import Data.Text (Text)
import qualified Data.Text as T

import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Storable

import GHC.Exts
import GHC.Word


-- ------------------------------------------------------------------------ --
-- Encoding loops

-- | Encode the input five bytes at a time into eight base32 characters.
--
-- Each step reads one 5-byte group. The first four bytes are read with
-- 'peekWord32BE' and the fifth as a plain byte. The step splits the group's
-- 40 bits into eight 5-bit indices, most significant first, and maps each
-- index through @lut@. It then stores the eight resulting bytes with a
-- single 'Word64' write and advances the input by 5 and the output by 8.
--
-- Once fewer than five input bytes remain, the current output and input
-- cursors are handed to @finish@, which deals with the tail.
--
-- NOTE(review): the packed word holds the first character in its least
-- significant byte, so the store puts the characters in order only on a
-- little-endian host.
innerLoop
    :: Addr#
    -- ^ alphabet table, indexed by a 5-bit value
    -> Ptr Word64
    -- ^ output cursor
    -> Ptr Word8
    -- ^ input cursor
    -> Ptr Word8
    -- ^ end of the input
    -> (Ptr Word8 -> Ptr Word8 -> IO ByteString)
    -- ^ continuation for the remaining (fewer than five byte) tail
    -> IO ByteString
innerLoop !lut !dptr !sptr !end finish = loop dptr sptr
  where
    -- Alphabet character for the low five bits of an index.
    sym i = w64 (aix (fromIntegral i .&. 0x1f) lut)
    {-# INLINE sym #-}

    loop !out !inp
      -- Fewer than five input bytes left: let the caller finish up.
      | plusPtr inp 4 >= end = finish (castPtr out) inp
      | otherwise = do
        hi <- peekWord32BE (castPtr inp)
        lo <- peek @Word8 (plusPtr inp 4)

        -- The whole 40-bit group, with the first input byte in the top bits.
        let !grp = ((fromIntegral hi `unsafeShiftL` 8) .|. w64 lo) :: Word64

            !c0 = sym (grp `unsafeShiftR` 35)
            !c1 = sym (grp `unsafeShiftR` 30)
            !c2 = sym (grp `unsafeShiftR` 25)
            !c3 = sym (grp `unsafeShiftR` 20)
            !c4 = sym (grp `unsafeShiftR` 15)
            !c5 = sym (grp `unsafeShiftR` 10)
            !c6 = sym (grp `unsafeShiftR` 5)
            !c7 = sym grp

            -- Character k goes into byte k of the word (c0 in the lowest).
            !chars = (c7 `unsafeShiftL` 56)
                 .|. (c6 `unsafeShiftL` 48)
                 .|. (c5 `unsafeShiftL` 40)
                 .|. (c4 `unsafeShiftL` 32)
                 .|. (c3 `unsafeShiftL` 24)
                 .|. (c2 `unsafeShiftL` 16)
                 .|. (c1 `unsafeShiftL` 8)
                 .|. c0

        poke out chars
        loop (plusPtr out 8) (plusPtr inp 5)
{-# INLINE innerLoop #-}

-- ------------------------------------------------------------------------ --
-- Decoding loops

-- | Decode base32 input, eight characters at a time, into the buffer
-- behind @dfp@.
--
-- Each input byte is mapped through @lut@, and two table values act as
-- sentinels:
--
-- * @0x63@ marks padding. Padding where it is not allowed is reported as
--   \"invalid padding\".
-- * @0xff@ marks a character outside the alphabet, reported as
--   \"invalid character\".
--
-- Any other table value is used directly as a 5-bit quantity.
--
-- Every chunk except the last must contain no padding and decodes to exactly
-- five bytes. The last chunk (the first one for which @src + 8 >= end@) may
-- end in padding. It decodes to 1, 2, 3, 4 or 5 bytes, or to an error when
-- the number of data characters before the padding is not a valid tail.
--
-- Errors are returned as @Left@ with a byte offset measured from @sptr@.
-- Within a chunk, padding problems are checked before invalid characters.
--
-- NOTE(review): the final chunk always reads eight bytes. This assumes the
-- caller guarantees that @end - sptr@ is a non-zero multiple of 8.
--
-- NOTE(review): the result is built as @PS dfp 0 n@, where @n@ counts the
-- bytes decoded from @dptr@ onward. This presumably requires @dptr@ to be the
-- start of @dfp@'s buffer; confirm against callers.
decodeLoop
    :: Addr#
    -- ^ decoding table, indexed by input byte
    -> ForeignPtr Word8
    -- ^ output buffer the resulting 'ByteString' is built over
    -> Ptr Word8
    -- ^ output cursor (first decoded byte)
    -> Ptr Word64
    -- ^ start of the encoded input
    -> Ptr Word8
    -- ^ end of the encoded input
    -> IO (Either Text ByteString)
decodeLoop !lut !dfp !dptr !sptr !end = go dptr sptr
  where
    -- Table lookup. NOTE(review): presumably 'aix' takes a 'Word8', so
    -- 'fromIntegral' keeps only the low byte of the index. The bulk path in
    -- 'go' relies on this when it passes unmasked shifts of a 64-bit word.
    lix a = w64 (aix (fromIntegral a) lut)

    -- "invalid character" error at byte offset @p - sptr@.
    err :: Ptr Word64 -> IO (Either Text ByteString)
    err p = return . Left . T.pack
      $ "invalid character at offset: "
      ++ show (p `minusPtr` sptr)

    -- "invalid padding" error at byte offset @p - sptr@.
    padErr :: Ptr Word64 -> IO (Either Text ByteString)
    padErr p =  return . Left . T.pack
      $ "invalid padding at offset: "
      ++ show (p `minusPtr` sptr)

    -- Read one input byte and map it through the table.
    look :: Ptr Word8 -> IO Word64
    look !p = lix . w64 <$> peek @Word8 p

    -- Main loop. The last chunk is read byte by byte and goes to
    -- 'finalChunk', which accepts padding. Every earlier chunk is read as
    -- one word and goes to 'decodeChunk'.
    go !dst !src
      | plusPtr src 8 >= end = do

        let src' = castPtr src

        a <- look src'
        b <- look (plusPtr src' 1)
        c <- look (plusPtr src' 2)
        d <- look (plusPtr src' 3)
        e <- look (plusPtr src' 4)
        f <- look (plusPtr src' 5)
        g <- look (plusPtr src' 6)
        h <- look (plusPtr src' 7)
        finalChunk dst src a b c d e f g h

      | otherwise = do
        -- Read all eight characters at once. Via 'peekWord64BE', the first
        -- character sits in the most significant byte of @t@.
        !t <- peekWord64BE src

        let a = lix (unsafeShiftR t 56)
            b = lix (unsafeShiftR t 48)
            c = lix (unsafeShiftR t 40)
            d = lix (unsafeShiftR t 32)
            e = lix (unsafeShiftR t 24)
            f = lix (unsafeShiftR t 16)
            g = lix (unsafeShiftR t 8)
            h = lix t

        decodeChunk dst src a b c d e f g h

    -- Last chunk: the table values for characters 0..7 are a..h.
    -- Positions 0 and 1 must always hold data characters.
    finalChunk !dst !src !a !b !c !d !e !f !g !h
      | a == 0x63 = padErr src
      | b == 0x63 = padErr (plusPtr src 1)
      | a == 0xff = err src
      | b == 0xff = err (plusPtr src 1)
      | c == 0xff = err (plusPtr src 2)
      | d == 0xff = err (plusPtr src 3)
      | e == 0xff = err (plusPtr src 4)
      | f == 0xff = err (plusPtr src 5)
      | g == 0xff = err (plusPtr src 6)
      | h == 0xff = err (plusPtr src 7)
      | otherwise = do

        -- Candidate output bytes, assuming no padding. Bytes that depend on
        -- padded positions are computed from the 0x63 sentinel, but they are
        -- never counted in the returned length. Low bits of the last data
        -- character that do not fill an output byte are dropped without
        -- checking that they are zero.
        let !o1 = (fromIntegral a `unsafeShiftL` 3) .|. (fromIntegral b `unsafeShiftR` 2)
            !o2 = (fromIntegral b `unsafeShiftL` 6)
              .|. (fromIntegral c `unsafeShiftL` 1)
              .|. (fromIntegral d `unsafeShiftR` 4)
            !o3 = (fromIntegral d `unsafeShiftL` 4) .|. (fromIntegral e `unsafeShiftR` 1)
            !o4 = (fromIntegral e `unsafeShiftL` 7)
              .|. (fromIntegral f `unsafeShiftL` 2)
              .|. (fromIntegral g `unsafeShiftR` 3)
            !o5 = (fromIntegral g `unsafeShiftL` 5) .|. fromIntegral h

        -- o2 is stored even when only one byte is reported. Some branches
        -- below likewise store one byte past the reported length, so the
        -- buffer needs room for five bytes at @dst@.
        poke @Word8 dst o1
        poke @Word8 (plusPtr dst 1) o2

        -- The alternatives are tried in order. Each "valid tail" pattern is
        -- followed by a catch-all for padding at that position.
        case (c,d,e,f,g,h) of
          -- 2 data characters + 6 padding: 1 byte.
          (0x63,0x63,0x63,0x63,0x63,0x63) ->
            return (Right (PS dfp 0 (1 + minusPtr dst dptr)))
          -- Padding at position 2 not followed by a full padding run.
          (0x63,_,_,_,_,_) -> padErr (plusPtr src 3)
          -- 3 data characters + 5 padding: not a valid tail.
          (_,0x63,0x63,0x63,0x63,0x63) -> padErr (plusPtr src 3)
          -- Padding at position 3 not followed by a full padding run.
          (_,0x63,_,_,_,_) -> padErr (plusPtr src 4)
          -- 4 data characters + 4 padding: 2 bytes (o3 stored, not counted).
          (_,_,0x63,0x63,0x63,0x63) -> do
            poke @Word8 (plusPtr dst 2) o3
            return (Right (PS dfp 0 (2 + minusPtr dst dptr)))
          -- Padding at position 4 not followed by a full padding run.
          (_,_,0x63,_,_,_) -> padErr (plusPtr src 5)
          -- 5 data characters + 3 padding: 3 bytes (o4 stored, not counted).
          (_,_,_,0x63,0x63,0x63) -> do
            poke @Word8 (plusPtr dst 2) o3
            poke @Word8 (plusPtr dst 3) o4
            return (Right (PS dfp 0 (3 + minusPtr dst dptr)))
          -- Padding at position 5 not followed by a full padding run.
          (_,_,_,0x63,_,_) -> padErr (plusPtr src 6)
          -- 6 data characters + 2 padding: not a valid tail.
          (_,_,_,_,0x63,0x63) -> padErr (plusPtr src 6)
          -- Padding at position 6 followed by a data character.
          (_,_,_,_,0x63,_) -> padErr (plusPtr src 7)
          -- 7 data characters + 1 padding: 4 bytes (o5 stored, not counted).
          (_,_,_,_,_,0x63) -> do
            poke @Word8 (plusPtr dst 2) o3
            poke @Word8 (plusPtr dst 3) o4
            poke @Word8 (plusPtr dst 4) o5
            return (Right (PS dfp 0 (4 + minusPtr dst dptr)))
          -- No padding: 5 bytes.
          (_,_,_,_,_,_) -> do
            poke @Word8 (plusPtr dst 2) o3
            poke @Word8 (plusPtr dst 3) o4
            poke @Word8 (plusPtr dst 4) o5
            return (Right (PS dfp 0 (5 + minusPtr dst dptr)))

    -- Any chunk but the last: padding is never allowed here. All padding
    -- checks run before the invalid-character checks.
    decodeChunk !dst !src !a !b !c !d !e !f !g !h
      | a == 0x63 = padErr src
      | b == 0x63 = padErr (plusPtr src 1)
      | c == 0x63 = padErr (plusPtr src 2)
      | d == 0x63 = padErr (plusPtr src 3)
      | e == 0x63 = padErr (plusPtr src 4)
      | f == 0x63 = padErr (plusPtr src 5)
      | g == 0x63 = padErr (plusPtr src 6)
      | h == 0x63 = padErr (plusPtr src 7)
      | a == 0xff = err src
      | b == 0xff = err (plusPtr src 1)
      | c == 0xff = err (plusPtr src 2)
      | d == 0xff = err (plusPtr src 3)
      | e == 0xff = err (plusPtr src 4)
      | f == 0xff = err (plusPtr src 5)
      | g == 0xff = err (plusPtr src 6)
      | h == 0xff = err (plusPtr src 7)
      | otherwise = do

        -- Pack the eight 5-bit values into a 40-bit word, first one on top.
        let !w = ((unsafeShiftL a 35)
              .|. (unsafeShiftL b 30)
              .|. (unsafeShiftL c 25)
              .|. (unsafeShiftL d 20)
              .|. (unsafeShiftL e 15)
              .|. (unsafeShiftL f 10)
              .|. (unsafeShiftL g 5)
              .|. h) :: Word64

        -- Write the top four bytes as one byte-swapped 32-bit store, then
        -- the last byte. NOTE(review): byteSwap32 is applied
        -- unconditionally, so the bytes come out in order only on a
        -- little-endian host.
        poke @Word32 (castPtr dst) (byteSwap32 (fromIntegral (unsafeShiftR w 8)))
        poke @Word8 (plusPtr dst 4) (fromIntegral w)
        go (plusPtr dst 5) (plusPtr src 8)