module Database.PostgreSQL.Typed.Inet where
import Control.Monad (guard, liftM2, void)
import Data.Bits (shiftL, (.|.))
import Data.List (intercalate)
import Data.Maybe (fromJust, fromMaybe)
import Data.Word (Word16, Word32, Word8)
import Foreign.Marshal.Array (peekArray, withArray)
import Foreign.Ptr (castPtr)
import Foreign.Storable (peek)
import Numeric (readDec, readHex)
import System.IO.Unsafe (unsafeDupablePerformIO)
import qualified Text.ParserCombinators.ReadP as RP
import qualified Text.ParserCombinators.ReadPrec as RP (lift)
import Text.Read (Read (readPrec), readEither)

import qualified Data.ByteString.Char8 as BSC
import qualified Network.Socket as Net

import Database.PostgreSQL.Typed.Types
-- | A value of PostgreSQL's @inet@ or @cidr@ type: an IPv4 or IPv6 address
-- together with a netmask length in bits.
data PGInet 
  = PGInet
    { pgInetAddr :: !Net.HostAddress
      -- ^ IPv4 address. Its in-memory byte layout is the four octets in
      -- order (this is how 'bton32' builds it). Only 'PGInet' has this
      -- field; applying it to a 'PGInet6' is a runtime error.
    , pgInetMask :: !Word8
      -- ^ Netmask length in bits (present in both constructors). The full
      -- length, 32, is what 'Read' assumes when no @/n@ is given, and 'Show'
      -- omits the suffix for it. The type itself does not bound the value.
    }
  | PGInet6
    { pgInetAddr6 :: !Net.HostAddress6
      -- ^ IPv6 address as four 32-bit words. Each word packs two 16-bit
      -- groups, with the earlier group in the high half. Only 'PGInet6' has
      -- this field.
    , pgInetMask :: !Word8
      -- ^ Netmask length in bits. The full length, 128, is the parsing
      -- default and is shown without a suffix.
    }
  deriving (Eq)
-- | Extract the host address of an IPv4 or IPv6 socket address as a
-- single-host 'PGInet' (mask 32 or 128 respectively). Port, flow info and
-- scope ID are discarded. Any other 'Net.SockAddr' constructor yields
-- 'Nothing'.
sockAddrPGInet :: Net.SockAddr -> Maybe PGInet
sockAddrPGInet addr = case addr of
  Net.SockAddrInet _ host -> Just (PGInet host 32)
  Net.SockAddrInet6 _ _ host _ -> Just (PGInet6 host 128)
  _ -> Nothing
-- | Pack four octets into a 'Word32' whose bytes in memory are exactly
-- those octets, in the order given. This is the representation used for
-- the 'Net.HostAddress' stored in 'PGInet'.
--
-- 'withArray' allocates the scratch array and releases it within this call,
-- and nothing else can observe it. Running the action under
-- 'unsafeDupablePerformIO', even more than once, therefore has no visible
-- effect.
bton32 :: (Word8, Word8, Word8, Word8) -> Word32
bton32 (b1, b2, b3, b4) =
  let bytes = [b1, b2, b3, b4]
      loadWord ptr = peek (castPtr ptr)
  in unsafeDupablePerformIO (withArray bytes loadWord)
-- | Render a 'PGInet' as text: the address, followed by @/mask@ unless the
-- mask is the full address length (32 for IPv4, 128 for IPv6).
--
-- IPv4 addresses are formatted directly as dotted decimal. IPv6 addresses
-- use the numeric host form returned by 'Net.getNameInfo'.
instance Show PGInet where
  show inet = case inet of
      PGInet a m -> withMask 32 m (showV4 a)
      PGInet6 a m -> withMask 128 m (showV6 a)
    where
      -- Append "/m" only when it differs from the full-length mask.
      withMask :: Word8 -> Word8 -> String -> String
      withMask full m addr
        | m == full = addr
        | otherwise = addr ++ '/' : show m

      -- The HostAddress word holds the octets in memory order (the layout
      -- 'bton32' builds), so read its four bytes back and print them as
      -- decimals.
      --
      -- This replaces 'Net.inet_ntoa', which wraps a C function returning a
      -- shared static buffer. NOTE(review): inet_ntoa is also deprecated in
      -- network and removed in 3.0; confirm against the pinned version.
      --
      -- The scratch buffer here is private to the call, so running it under
      -- unsafeDupablePerformIO is harmless.
      showV4 :: Net.HostAddress -> String
      showV4 a = intercalate "." (map show octets)
        where
          octets :: [Word8]
          octets = unsafeDupablePerformIO $ withArray [a] (peekArray 4 . castPtr)

      -- Numeric-only lookup, so no name resolution is attempted. If no host
      -- string comes back, fail with a message naming the instance instead
      -- of a bare "fromJust: Nothing".
      showV6 :: Net.HostAddress6 -> String
      showV6 a =
        fromMaybe (error "Show PGInet: getNameInfo returned no numeric host") $
          fst $ unsafeDupablePerformIO $
            Net.getNameInfo [Net.NI_NUMERICHOST] True False (Net.SockAddrInet6 0 0 a 0)
-- | Parse the text form of an @inet@ / @cidr@ value. The input is either a
-- dotted-quad IPv4 address or a colon-separated IPv6 address. IPv6 may use
-- @::@ zero compression and may end in a dotted quad. Either form may be
-- followed by @/mask@; a missing mask defaults to the full address length
-- (32 or 128).
--
-- Octets above 255, hex groups above ffff and masks longer than the address
-- are rejected.
instance Read PGInet where
  readPrec = RP.lift $ r4 RP.+++ r6 where
    -- Four dot-separated decimal octets.
    r4i = do
      o1 <- rdec
      _ <- RP.char '.'
      o2 <- rdec
      _ <- RP.char '.'
      o3 <- rdec
      _ <- RP.char '.'
      o4 <- rdec
      return (o1, o2, o3, o4)

    r4 = do
      q <- r4i
      m <- mask 32
      return $ PGInet (bton32 q) m

    -- A dotted quad at the end of an IPv6 address, standing for the final
    -- two 16-bit groups.
    r64 = do
      (b1, b2, b3, b4) <- r4i
      return [jb b1 b2, jb b3 b4]

    -- Left of any "::". Having just read a group, read the remaining n
    -- groups, each introduced by a colon. With exactly two groups left, a
    -- trailing dotted quad may supply them.
    r6l 0 = return []
    r6l 2 = colon >> (r6lc 2 RP.+++ r64)
    r6l n = colon >> r6lc n
    -- After the separating colon: either another group, or a second colon
    -- starting "::".
    r6lc n = r6lp n RP.+++ r6b n
    -- Exactly n groups, starting with a hex group.
    r6lp n = r6w (r6l (pred n))

    -- "::" with n groups still to fill. Read at most n - 1 explicit groups
    -- and left-pad with zeros, so "::" always stands for at least one zero
    -- group.
    r6b n = do
      colon
      r <- r6rp (pred n) RP.<++ return []
      let l = length r
      return $ replicate (n - l) 0 ++ r

    -- Right of "::": at most n further groups. `<++` keeps this greedy.
    r6r 0 = return []
    r6r n = (colon >> r6rp n) RP.<++ return []
    r6rp n
      -- No room for another group. Without this case, n went negative and
      -- any number of extra groups was accepted (e.g. "1:2:3:4:5:6:7::8").
      | n <= 0 = RP.pfail
      | n >= 2 = r6rc n RP.+++ r64
      | otherwise = r6rc n
    r6rc n = r6w (r6r (pred n))
    -- Prepend one hex group to the groups parsed by the argument.
    r6w = liftM2 (:) rhex

    r6 = do
      [w1, w2, w3, w4, w5, w6, w7, w8] <- r6lp 8 RP.<++ (colon >> r6b 8)
      m <- mask 128
      return $ PGInet6 (jw w1 w2, jw w3 w4, jw w5 w6, jw w7 w8) m
    colon = void $ RP.char ':'
    -- Optional "/n" with n <= m; defaults to m when absent.
    mask m = RP.option m $ do
      _ <- RP.char '/'
      n <- rdec
      guard (n <= m)
      return n
    -- Numbers are read as Integer and range-checked. Reading directly at
    -- Word8/Word16 would silently wrap, so "300" became 44.
    rdec :: RP.ReadP Word8
    rdec = do
      v <- RP.readS_to_P readDec
      guard (v <= toInteger (maxBound :: Word8))
      return (fromIntegral v)
    rhex :: RP.ReadP Word16
    rhex = do
      v <- RP.readS_to_P readHex
      guard (v <= toInteger (maxBound :: Word16))
      return (fromIntegral v)
    -- Join two 16-bit groups into one word, first group in the high half.
    jw :: Word16 -> Word16 -> Word32
    jw x y = fromIntegral x `shiftL` 16 .|. fromIntegral y
    -- Join two octets into one 16-bit group, first octet in the high half.
    jb :: Word8 -> Word8 -> Word16
    jb x y = fromIntegral x `shiftL` 8 .|. fromIntegral y
-- | Values of PostgreSQL's @inet@ type are represented as 'PGInet'.
instance PGType "inet" where
  type PGVal "inet" = PGInet
-- | Values of PostgreSQL's @cidr@ type share the 'PGInet' representation
-- with @inet@.
instance PGType "cidr" where
  type PGVal "cidr" = PGInet
-- | Send an @inet@ parameter as the text produced by the 'Show' instance.
instance PGParameter "inet" PGInet where
  pgEncode _ inet = BSC.pack (show inet)
-- | Send a @cidr@ parameter as the text produced by the 'Show' instance.
instance PGParameter "cidr" PGInet where
  pgEncode _ inet = BSC.pack (show inet)
-- | Decode an @inet@ column from its text form using the 'Read' instance.
-- This accepts exactly what 'read' accepts. On failure the error names the
-- type and shows the offending text, rather than the context-free
-- "Prelude.read: no parse".
instance PGColumn "inet" PGInet where
  pgDecode _ bs = either failure id (readEither s)
    where
      s = BSC.unpack bs
      failure e = error ("pgDecode inet: " ++ e ++ ": " ++ show s)
-- | Decode a @cidr@ column from its text form using the 'Read' instance.
-- This accepts exactly what 'read' accepts. On failure the error names the
-- type and shows the offending text, rather than the context-free
-- "Prelude.read: no parse".
instance PGColumn "cidr" PGInet where
  pgDecode _ bs = either failure id (readEither s)
    where
      s = BSC.unpack bs
      failure e = error ("pgDecode cidr: " ++ e ++ ": " ++ show s)