module Hasql.Core.Scanner where

import Hasql.Prelude
import Hasql.Core.Model
import Scanner (Scanner)
import qualified Scanner as A
import qualified Data.ByteString as B
import qualified Data.Vector as D
import qualified Hasql.Core.MessageTypePredicates as C
import qualified Hasql.Core.NoticeFieldTypes as E


-- |
-- Consume a single byte.
{-# INLINE word8 #-}
word8 :: Scanner Word8
word8 = A.anyWord8

-- |
-- Unsigned 16-bit number assembled from two bytes by 'numOfSize'.
{-# INLINE word16 #-}
word16 :: Scanner Word16
word16 = numOfSize 2

-- |
-- Unsigned 32-bit number assembled from four bytes by 'numOfSize'.
{-# INLINE word32 #-}
word32 :: Scanner Word32
word32 = numOfSize 4

-- |
-- Unsigned 64-bit number assembled from eight bytes by 'numOfSize'.
{-# INLINE word64 #-}
word64 :: Scanner Word64
word64 = numOfSize 8

-- |
-- Number assembled from the specified amount of bytes,
-- the first byte consumed being the most significant one.
{-# INLINE numOfSize #-}
numOfSize :: (Bits a, Num a) => Int -> Scanner a
numOfSize size =
  B.foldl' appendByte 0 <$> A.take size
  where
    -- Free up the low eight bits of the accumulator and place the next byte there.
    appendByte acc byte =
      shiftL acc 8 .|. fromIntegral byte

-- |
-- Signed 32-bit number: the value of 'word32' converted to 'Int32',
-- so values with the highest bit set come out negative.
{-# INLINE int32 #-}
int32 :: Scanner Int32
int32 = fromIntegral <$> word32

-- |
-- Read the message type byte followed by the payload length (see 'payloadLength')
-- and combine them using the provided continuation.
-- The payload itself is not consumed.
{-# INLINE messageTypeAndLength #-}
messageTypeAndLength :: (Word8 -> Word32 -> a) -> Scanner a
messageTypeAndLength cont = cont <$> word8 <*> payloadLength

-- |
-- Length of the message payload: the four-byte length field of the message
-- reduced by four, since in the PostgreSQL message format
-- that field counts its own size as well.
{-# INLINE payloadLength #-}
payloadLength :: (Integral a, Bits a) => Scanner a
payloadLength =
  subtract lengthFieldSize <$> numOfSize 4
  where
    lengthFieldSize = 4

-- |
-- Read the message type byte followed by the complete payload of the message
-- and combine them using the provided continuation.
{-# INLINE messageTypeAndPayload #-}
messageTypeAndPayload :: (Word8 -> ByteString -> a) -> Scanner a
messageTypeAndPayload cont =
  cont <$> word8 <*> payload
  where
    -- The length prefix determines how many bytes to take.
    payload =
      payloadLength >>= A.take

-- |
-- Integral number written in ASCII decimal digits.
--
-- Consumes the longest run of digit bytes.
-- When there are no digits at the current position,
-- nothing is consumed and the result is 0.
{-# INLINE asciiIntegral #-}
asciiIntegral :: Integral a => Scanner a
asciiIntegral =
  B.foldl' appendDigit 0 <$> A.takeWhile isDigitByte
  where
    -- 48 is the code of \'0\'. The subtraction is done on 'Word8',
    -- so bytes below 48 wrap around to large values and get rejected as well.
    isDigitByte byte =
      byte - 48 <= 9
    appendDigit !acc !byte =
      acc * 10 + fromIntegral byte - 48

-- |
-- Bytes up to the first zero byte, which is not included in the result.
-- The byte following the collected ones (the terminator) is consumed as well.
{-# INLINE nullTerminatedString #-}
nullTerminatedString :: Scanner ByteString
nullTerminatedString =
  do
    bytes <- A.takeWhile (/= 0)
    _terminator <- A.anyWord8
    return bytes

-- * Responses
-------------------------

-- |
-- Scanner of a single backend message.
--
-- Reads the type byte and the payload length and then dispatches on the type.
-- Messages of the recognized types get parsed into 'Just' a response.
-- For any other type the payload is skipped and 'Nothing' is produced.
{-# INLINE response #-}
response :: Scanner (Maybe Response)
response =
  do
    type_ <- word8
    bodyLength <- payloadLength
    bodyOf type_ bodyLength
  where
    bodyOf type_ bodyLength
      | C.dataRow type_ = Just <$> dataRowBody DataRowResponse
      | C.commandComplete type_ = Just <$> commandCompleteBody CommandCompleteResponse
      | C.readyForQuery type_ = Just <$> readyForQueryBody ReadyForQueryResponse
      -- The following three message types are handled without reading any body bytes.
      | C.parseComplete type_ = pure (Just ParseCompleteResponse)
      | C.bindComplete type_ = pure (Just BindCompleteResponse)
      | C.emptyQuery type_ = pure (Just EmptyQueryResponse)
      | C.notification type_ = Just <$> notificationBody NotificationResponse
      | C.error type_ = Just <$> errorResponseBody bodyLength ErrorResponse
      | C.authentication type_ = Just <$> authenticationBody AuthenticationResponse
      | C.parameterStatus type_ = Just <$> parameterStatusBody ParameterStatusResponse
      -- A message of no interest: skip its payload to stay in sync with the stream.
      | otherwise = A.take bodyLength $> Nothing

-- |
-- Body of the data row message:
-- a 16-bit amount of columns followed by that many values,
-- each parsed with 'sizedBytes'.
{-# INLINE dataRowBody #-}
dataRowBody :: (Vector (Maybe ByteString) -> result) -> Scanner result
dataRowBody result =
  result <$> columns
  where
    columns =
      do
        columnCount <- word16
        D.replicateM (fromIntegral columnCount) sizedBytes

-- |
-- Body of the command completion message:
-- a null-terminated command tag, from which the amount of affected rows is extracted.
--
-- The amount is the trailing run of ASCII digits of the tag,
-- e.g., 5 in @INSERT 0 5@ and 10 in @SELECT 10@.
-- Tags which carry no amount (e.g., @BEGIN@ or @CREATE TABLE@) produce 0.
--
-- The tag is always consumed in full, including its terminator,
-- regardless of its shape, so the scanner stays aligned with the next message.
{-# INLINE commandCompleteBody #-}
commandCompleteBody :: (Int -> result) -> Scanner result
commandCompleteBody result =
  do
    tag <- A.takeWhile (/= 0)
    -- Consume the null terminator of the tag.
    A.anyWord8
    return (result (trailingCount tag))
  where
    -- Only the digits at the very end of the tag are of interest.
    -- This skips the command name and, in case of INSERT, the OID.
    trailingCount tag =
      case B.spanEnd byteIsDigit tag of
        (_, digits) -> B.foldl' step 0 digits
    -- 48 is the code of \'0\'. The subtraction is done on 'Word8',
    -- so bytes below 48 wrap around to large values and get rejected as well.
    byteIsDigit byte =
      byte - 48 <= 9
    step !state !byte =
      state * 10 + fromIntegral byte - 48

-- |
-- Body of the ready-for-query message: a single transaction status byte.
--
-- Fails on any byte other than the three known ones.
{-# INLINE readyForQueryBody #-}
readyForQueryBody :: (TransactionStatus -> result) -> Scanner result
readyForQueryBody result =
  A.anyWord8 >>= interpret
  where
    -- The codes are those of the ASCII characters \'I\', \'T\' and \'E\'.
    interpret 73 = return (result IdleTransactionStatus)
    interpret 84 = return (result ActiveTransactionStatus)
    interpret 69 = return (result FailedTransactionStatus)
    interpret statusByte =
      fail (showString "Unexpected transaction status byte: " (show statusByte))

-- |
-- Body of the notification message:
-- a 32-bit number followed by two null-terminated strings
-- (in the PostgreSQL protocol these are the process ID of the notifying backend,
-- the channel name and the payload).
{-# INLINE notificationBody #-}
notificationBody :: (Word32 -> ByteString -> ByteString -> result) -> Scanner result
notificationBody result =
  result
    <$> word32
    <*> nullTerminatedString
    <*> nullTerminatedString

-- |
-- Body of the error message.
--
-- Given the length of the body in bytes,
-- reads the notice fields which occupy all of it except for the last byte,
-- then consumes that last byte.
-- The payloads of the code and message fields are passed to the continuation.
-- All other fields are ignored.
-- Fails if either of the two fields is missing.
{-# INLINE errorResponseBody #-}
errorResponseBody :: Int -> (ByteString -> ByteString -> result) -> Scanner result
errorResponseBody bodyLength result =
  do
    fields <- collect 0 Nothing Nothing
    -- Consume the byte which follows the fields.
    A.anyWord8
    case fields of
      (Just code, Just message) -> return (result code message)
      _ -> fail "Some of the required error fields are missing"
  where
    fieldsLength =
      bodyLength - 1
    collect !consumed code message
      | consumed < fieldsLength =
        do
          (type_, payload) <- noticeField (,)
          -- A field occupies its type byte, its payload and the null terminator.
          let consumed' = consumed + 2 + B.length payload
          if
            | type_ == E.code -> collect consumed' (Just payload) message
            | type_ == E.message -> collect consumed' code (Just payload)
            | True -> collect consumed' code message
      | otherwise =
        return (code, message)

-- |
-- Single field of a notice or an error:
-- a byte identifying the field type followed by a null-terminated payload.
{-# INLINE noticeField #-}
noticeField :: (Word8 -> ByteString -> result) -> Scanner result
noticeField result =
  result
    <$> word8
    <*> nullTerminatedString

-- |
-- Body of the authentication message: a 32-bit status code,
-- which in case of the MD5 password request is followed by four bytes of salt.
--
-- Fails on any status code other than 0, 3 and 5.
{-# INLINE authenticationBody #-}
authenticationBody :: (AuthenticationStatus -> result) -> Scanner result
authenticationBody result =
  word32 >>= interpret
  where
    interpret status =
      case status of
        0 -> return (result OkAuthenticationStatus)
        3 -> return (result NeedClearTextPasswordAuthenticationStatus)
        5 -> result . NeedMD5PasswordAuthenticationStatus <$> A.take 4
        _ -> fail ("Unsupported authentication status: " <> show status)

-- |
-- Body of the parameter status message: two null-terminated strings
-- (in the PostgreSQL protocol these are the name of the parameter and its value).
{-# INLINE parameterStatusBody #-}
parameterStatusBody :: (ByteString -> ByteString -> result) -> Scanner result
parameterStatusBody result =
  result
    <$> nullTerminatedString
    <*> nullTerminatedString

{-|
Length-prefixed column value of a data row.

The layout according to the PostgreSQL protocol documentation:

[Int32]
The length of the column value, in bytes (this count does not include itself).
Can be zero. As a special case, -1 indicates a NULL column value.
No value bytes follow in the NULL case.

[Byte/n/]
The value of the column, in the format indicated by the associated format code.
/n/ is the above length.

Produces 'Nothing' for NULL and 'Just' the value bytes otherwise.

The length is decoded as a signed 32-bit number on purpose:
converting 'Word32' directly to 'Int' maps the NULL marker to 4294967295
wherever 'Int' is wider than 32 bits, so the comparison with -1 would never succeed.
-}
{-# INLINE sizedBytes #-}
sizedBytes :: Scanner (Maybe ByteString)
sizedBytes =
  do
    size <- int32
    if size == -1
      then return Nothing
      else Just <$> A.take (fromIntegral size)