{-# LANGUAGE BangPatterns #-}
module Net.IPv4.ByteString.Char8
  ( encode
  , decode
  , builder
  , parser
  ) where

import Net.IPv4
import Control.Monad
import qualified Net.IPv4.Text as IPv4Text
import qualified Data.Attoparsec.ByteString.Char8 as AB
import qualified Data.ByteString.Builder as Builder
import Net.Internal (rightToMaybe)
import Data.Text.Encoding (encodeUtf8, decodeUtf8')
import Data.ByteString (ByteString,replicate)
import Data.Attoparsec.ByteString.Char8 (Parser)
import Data.ByteString.Builder (Builder)
import Data.ByteString.Char8 as BC8

import Data.ByteString.Internal as I
import Data.Bits
import Foreign.Ptr
import Foreign.Storable
import Data.Word
import Data.ByteString.Unsafe as BSU
import Data.Monoid

-- | Render an 'IPv4' address in dotted-decimal notation, e.g.
--   @192.168.0.1@. The bytes are written straight into a pre-sized
--   buffer by 'toBSPreAllocated'; no intermediate 'Text' is built.
encode :: IPv4 -> ByteString
encode = toBSPreAllocated

-- | Render an 'IPv4' in dotted-decimal form directly into a buffer, with no
--   intermediate 'Text' or list. The buffer is sized for the longest
--   possible result, @"255.255.255.255"@ (15 bytes), and
--   'I.unsafeCreateUptoN' keeps only the bytes actually written.
--   Octets are emitted most significant first.
toBSPreAllocated :: IPv4 -> ByteString
toBSPreAllocated (IPv4 !w) = I.unsafeCreateUptoN 15 $ \base -> do
  n1 <- putOctet base (octetAt 24)
  n2 <- putDotOctet (base `plusPtr` n1) (octetAt 16)
  n3 <- putDotOctet (base `plusPtr` (n1 + n2)) (octetAt 8)
  n4 <- putDotOctet (base `plusPtr` (n1 + n2 + n3)) (octetAt 0)
  return (n1 + n2 + n3 + n4)
  where
    -- The octet that starts @s@ bits above the least significant end of
    -- the address; truncating to 'Word8' drops the higher octets.
    octetAt :: Int -> Word8
    octetAt s = fromIntegral (shiftR w s)

    -- Write a '.' (ASCII 46) followed by an octet; returns the number of
    -- bytes written, i.e. one more than the octet's digit count.
    putDotOctet :: Ptr Word8 -> Word8 -> IO Int
    putDotOctet !p !x = do
      poke p (46 :: Word8)
      digits <- putOctet (p `plusPtr` 1) x
      return (digits + 1)

    -- Write an octet in decimal with no leading zeros and return how many
    -- digits were written (1, 2 or 3). Two- and three-digit values are
    -- copied from the lookup tables, where value @x@ begins at offset
    -- @width * x@.
    putOctet :: Ptr Word8 -> Word8 -> IO Int
    putOctet !p !x
      | x >= 100 = do
          let i = 3 * fromIntegral x
          pokeByteOff p 0 (BSU.unsafeIndex threeDigits i)
          pokeByteOff p 1 (BSU.unsafeIndex threeDigits (i + 1))
          pokeByteOff p 2 (BSU.unsafeIndex threeDigits (i + 2))
          return 3
      | x >= 10 = do
          let i = 2 * fromIntegral x
          pokeByteOff p 0 (BSU.unsafeIndex twoDigits i)
          pokeByteOff p 1 (BSU.unsafeIndex twoDigits (i + 1))
          return 2
      | otherwise = do
          -- Single digit: '0' is ASCII 48.
          poke p (x + 48)
          return 1

-- | Lookup table of two ASCII digits for every value 0 to 99, concatenated
--   in ascending order, so the digits of @n@ sit at offsets @2 * n@ and
--   @2 * n + 1@.
twoDigits :: ByteString
twoDigits = BC8.pack
  [ digit
  | tens  <- ['0' .. '9']
  , units <- ['0' .. '9']
  , digit <- [tens, units]
  ]

-- | Lookup table of three ASCII digits for every value 100 to 255, in
--   ascending order. It is preceded by 300 zero bytes that stand in for
--   the unused slots of 0 to 99, so the digits of @n@ start at offset
--   @3 * n@ with no subtraction at lookup time.
threeDigits :: ByteString
threeDigits =
  Data.ByteString.replicate (3 * 100) 0
    <> BC8.pack [ ch | n <- [100 .. 255 :: Int], ch <- show n ]

-- | Decode a dotted-decimal address. The input must be valid UTF-8; it is
--   decoded to 'Text' and passed to 'IPv4Text.decode'. Yields 'Nothing'
--   if either the UTF-8 decoding or the address parse fails.
--
--   TODO: parse the bytes directly instead of going through 'Text'.
decode :: ByteString -> Maybe IPv4
decode bytes = do
  txt <- rightToMaybe (decodeUtf8' bytes)
  IPv4Text.decode txt

-- | A 'Builder' that emits the same dotted-decimal bytes as 'encode'.
builder :: IPv4 -> Builder
builder addr = Builder.byteString (encode addr)

-- | Attoparsec parser for a dotted-decimal address: four decimal octets
--   separated by @'.'@. Each octet is one or more ASCII digits whose value
--   must be between 0 and 255 (leading zeros are accepted). Input after
--   the fourth octet is left unconsumed.
parser :: Parser IPv4
parser = fromOctets'
  <$> octet
  <*  AB.char '.'
  <*> octet
  <*  AB.char '.'
  <*> octet
  <*  AB.char '.'
  <*> octet
  where
    -- One or more ASCII digits, range-checked before conversion to the
    -- target type. The running value is clamped at 256 in an 'Int', so a
    -- long digit string can neither overflow nor wrap back into range.
    -- 'AB.decimal' would accumulate directly in the target type, and for
    -- a bounded type e.g. "4294967297" would wrap to 1 and be accepted.
    octet :: Num a => Parser a
    octet = do
      digits <- AB.takeWhile1 AB.isDigit
      let step acc c = min 256 (acc * 10 + (fromEnum c - fromEnum '0'))
          value = BC8.foldl' step 0 digits :: Int
      if value > 255
        then fail "All octets in an ip address must be between 0 and 255"
        else return (fromIntegral value)