module Data.Radius.StreamGet.Base (
  upacket, packet,

  header, attribute', vendorID, simpleVendorAttribute,

  code, bin128,

  atText, atString, atInteger, atIpV4,

  eof,
  ) where

import Control.Applicative ((<$>), pure, (<*>), (<*), (<|>), many)
import Control.Monad (guard)
import Data.ByteString (ByteString)
import Data.Word (Word8, Word32)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import Data.Serialize.Get
  (Get, getWord8, getWord16be, getWord32be,
   getBytes, isEmpty, runGet)

import Data.Radius.Scalar
  (Bin128, mayBin128, AtText (..), AtString (..), AtInteger (..), AtIpV4 (..))
import Data.Radius.Packet
  (Code, Header (Header), Packet (Packet), codeFromWord)
import qualified Data.Radius.Packet as Data
import Data.Radius.Attribute (NumberAbstract (..), Attribute' (..))
import qualified Data.Radius.Attribute as Attribute


-- | Read one byte and convert it to a 'Code' with 'codeFromWord'.
code :: Get Code
code = fmap codeFromWord getWord8

-- | Read a single byte; 'header' uses it as the second field of
--   'Header' (the packet id).
pktId :: Get Word8
pktId = getWord8

-- | Read exactly 16 bytes and convert them to a 'Bin128' with
--   'mayBin128'.  Fails if 'mayBin128' rejects the bytes.
bin128 :: Get Bin128
bin128 = do
  raw <- getBytes 16
  case mayBin128 raw of
    Nothing  -> fail "Illegal state: Bin128"
    Just b   -> pure b

-- | Read a packet header: the code byte, the packet id byte, a 16-bit
--   big-endian length and the 16-byte 'Bin128' value, in that order.
header :: Get Header
header = do
  c     <- code
  i     <- pktId
  len   <- getWord16be
  auth  <- bin128
  pure $ Header c i len auth

-- | Succeed only when no input remains; otherwise fail via 'guard'.
eof :: Get ()
eof = isEmpty >>= guard

-- | Read a whole packet: the 'header', followed by an attribute region
--   whose size is the header's length field minus the header size.
--
--   The attribute region is cut out as one chunk and parsed separately
--   with the supplied parser, which must consume it completely ('eof').
--   Fails when the length field is smaller than the header size, or
--   when the attribute parser fails.
packet :: Get a -> Get (Packet a)
packet getAttrs = do
  hdr <- header
  let bodyLen = fromIntegral (Data.pktLength hdr) - headerSize
  guard (bodyLen >= 0)
    <|> fail ("Parse error of header: Packet: invalid length: " ++ show bodyLen)
  body <- getBytes bodyLen
  case runGet (getAttrs <* eof) body of
    Left  err    -> fail $ "Parse error of attributes: Packet: " ++ err
    Right attrs  -> pure $ Packet hdr attrs
  where
    -- sizeof(code) + sizeof(pktId) + sizeof(pktLength) + sizeof(authenticator)
    headerSize :: Int
    headerSize = 20

-- | Read one byte and convert it to an attribute number with
--   'Attribute.fromWord'.
radiusNumber :: Get Attribute.Number
radiusNumber = fmap Attribute.fromWord getWord8

-- | Read a vendor ID as a 32-bit big-endian word.
vendorID :: Get Word32
vendorID = getWord32be

-- | Read one vendor sub-attribute laid out as a one-byte number, a
--   one-byte length and the value bytes.  Returns the number together
--   with the raw value.
--
--   The length byte counts the number and length bytes themselves, so
--   the value is @len - 2@ bytes long.  A length smaller than 2 is
--   malformed and is rejected explicitly instead of passing a negative
--   size to 'getBytes'.
simpleVendorAttribute :: Get (Word8, ByteString)
simpleVendorAttribute = do
  n    <-  getWord8
  len  <-  getWord8
  let vlen = fromIntegral len - 2 :: Int {- sizeof(number) + sizeof(attribute length) -}
  bs   <-  if vlen < 0
           then fail $ "Parse error of vendor attribute: invalid length: " ++ show len
           else getBytes vlen
  pure (n, bs)

-- {26, length, vendorID(45137), sub-attribute number, sub-attribute length, sub-attribute value}

-- | Read a single attribute: a one-byte number, a one-byte length and
--   the value bytes.
--
--   The length byte counts the number and length bytes themselves, so
--   the value is @len - 2@ bytes long.  A length smaller than 2 is
--   malformed and is rejected explicitly instead of passing a negative
--   size to 'getBytes'.
--
--   When the number is 'Attribute.VendorSpecific', the value bytes are
--   parsed with the supplied vendor parser; any other number yields a
--   'Standard' attribute carrying the raw value bytes.
attribute' :: Get (Attribute' v) -> Get (Attribute' v)
attribute' va = do
  n    <- radiusNumber
  len  <- getWord8
  let vlen = fromIntegral len - 2 :: Int {- sizeof(number) + sizeof(attribute length) -}
  bs   <- if vlen < 0
          then fail $ "Parse error of attribute: invalid length: " ++ show len
          else getBytes vlen
  case n of
    Attribute.VendorSpecific ->
      either (fail . ("Parse error of Vendor-Specific attribute: " ++)) pure
      $ runGet va bs
    _                     ->
      pure $ Attribute' (Standard n) bs

-- | Read a packet whose attribute region is zero or more attributes,
--   each parsed by 'attribute'' with the given vendor-specific parser.
upacket :: Get (Attribute' v) -> Get (Packet [Attribute' v])
upacket va = packet attrs
  where
    attrs = many (attribute' va)


-- | Read @len@ bytes and decode them as UTF-8 into an 'AtText'.
--
--   Fails when @len@ is negative or greater than 253, or when the
--   bytes are not valid UTF-8.
atText :: Int -> Get AtText
atText len
  | len < 0    =  fail $ "Get.atText: Positive length required: " ++ show len
  | len > 253  =  fail $ "Get.atText: Too long: " ++ show len
  | otherwise  =  do
      raw <- getBytes len
      case Text.decodeUtf8' raw of
        Left  err  -> fail $ "Get.atText: fail to decode UTF8: " ++ show err
        Right txt  -> pure . AtText $ Text.unpack txt

-- | Read @len@ raw bytes into an 'AtString'.
--
--   Fails when @len@ is negative or greater than 253.
atString :: Int -> Get AtString
atString len
  | len < 0    =  fail $ "Get.atString: Positive length required: " ++ show len
  | len > 253  =  fail $ "Get.atString: Too long: " ++ show len
  | otherwise  =  fmap AtString (getBytes len)

-- | Read a 32-bit big-endian word and wrap it as an 'AtInteger'.
atInteger :: Get AtInteger
atInteger = fmap AtInteger getWord32be

-- | Read a 32-bit big-endian word and wrap it as an 'AtIpV4'.
atIpV4 :: Get AtIpV4
atIpV4 = fmap AtIpV4 getWord32be