{-# LANGUAGE ExistentialQuantification, ScopedTypeVariables,
             GeneralizedNewtypeDeriving, DeriveDataTypeable
  #-}

module Network.DBus.Value (
  Endianness(..),

  DValue(..), DBasicTypedValue(..),

  ObjectPath, mkObjectPath, getPath, -- constructor not exported

  DString, mkDString, mkDString0, getString,

  Variant(..), fromVariant,

  Bytes, Serializer, runSerializer, advanceBy, padTo,
  Deserializer, runDeserializer, skipTo, deserializeAs
) where

import Control.Monad (when, forM_, liftM2, liftM3, liftM4, liftM5)
import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int (Int16, Int32, Int64)
import Data.Char (chr, ord)
import Foreign (Ptr, alloca, castPtr, peek, poke)
import System.IO.Unsafe (unsafePerformIO)

import qualified Control.Monad.State as S
import qualified Control.Monad.Reader as R
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString.UTF8 as BS8
import qualified Data.Map as M
import qualified Data.Typeable as T

import Data.Binary.Get
import Data.Binary.Put

import Network.DBus.Type (DBasicType(..), DType(..), Signature(..))

-- | Byte order used when (de)serializing multi-byte values.
data Endianness
  = LittleEndian
  | BigEndian
  deriving (Show)

-- | A number of bytes. Used in this module for stream positions, alignment
-- boundaries and padding sizes.
type Bytes = Int

-- | Select between two alternatives by byte order: the first argument is
-- returned for 'BigEndian', the second for 'LittleEndian'.
bigLittle :: a -> a -> Endianness -> a
bigLittle big _      BigEndian    = big
bigLittle _   little LittleEndian = little

-- | Number of bytes needed to move @position@ forward to the next multiple
-- of @alignment@; 0 when it is already aligned.
padding :: Int -> Int -> Int
padding position alignment
  | remainder > 0 = alignment - remainder
  | otherwise     = 0
  where
    remainder = position `mod` alignment

--
-- Serialization fu
--

-- | A serialization action: the byte order comes from the reader layer, the
-- state layer counts the bytes accounted for so far, and output goes through
-- 'PutM'.
type Serializer = R.ReaderT Endianness (S.StateT Bytes PutM) ()

-- | Run a 'Serializer' with the given byte order, starting the byte counter
-- at 0, and return the bytes it wrote.
runSerializer :: Endianness -> Serializer -> LBS.ByteString
runSerializer endianness action =
  runPut $ S.evalStateT (R.runReaderT action endianness) 0

-- | Add @n@ to the serializer's byte counter. Nothing is written.
advanceBy :: Bytes -> Serializer
advanceBy n = R.lift (S.modify (\position -> position + n))

-- | Write zero bytes until the byte counter is a multiple of @alignment@,
-- and advance the counter by the number of bytes written.
padTo :: Bytes -> Serializer
padTo alignment = do
  position <- R.lift S.get
  let gap = padding position alignment
  when (gap > 0) $
    R.lift (S.lift (mapM_ (const (putWord8 0)) [1 .. gap]))
  advanceBy gap

-- | Build a serializer for a fixed-width value from a big-endian and a
-- little-endian writer (in that order).
--
-- The value is first padded to its own 'alignment'; afterwards the byte
-- counter is advanced by that same 'alignment', i.e. the encoded width is
-- taken to be equal to the alignment.
basicSerializer :: DValue a => (a -> Put) -> (a -> Put) -> a -> Serializer
basicSerializer b l thing = do
  let width = alignment thing
  padTo width
  endianness <- R.ask
  let write = bigLittle b l endianness
  R.lift (S.lift (write thing))
  advanceBy width

--
-- Deserialization fu
--

-- | A deserialization action: a 'Get' parser with the byte order available
-- from the reader layer.
type Deserializer a = R.ReaderT Endianness Get a

-- | Run a 'Deserializer' with the given byte order over a lazy byte string.
runDeserializer :: Endianness -> Deserializer a -> LBS.ByteString -> a
runDeserializer endianness action input =
  runGet (R.runReaderT action endianness) input

-- | Build a deserializer for a fixed-width value from a big-endian and a
-- little-endian reader (in that order). Input is first skipped up to the
-- value's 'alignment'.
basicDeserializer :: forall a. DValue a => Get a -> Get a -> Deserializer a
basicDeserializer b l = do
  skipTo (alignment (undefined :: a))
  endianness <- R.ask
  R.lift (bigLittle b l endianness)

-- | Consume padding so that the number of bytes read so far becomes a
-- multiple of @alignment@.
--
-- Every padding byte must be zero. A non-zero byte makes the parser fail
-- with "bad padding" via 'fail' (the same mechanism 'getNull' and the 'Bool'
-- deserializer use), so the failure is reported by 'Get' rather than thrown
-- with 'error' from inside the parser.
skipTo :: Bytes -> Deserializer ()
skipTo alignment = do
  pos <- R.lift $ fromIntegral `fmap` bytesRead
  let jump = padding pos alignment
  when (jump > 0) $ do
      bytes <- R.lift (getByteString jump)
      when (BS.any (/= 0) bytes) $ fail "bad padding"

--
-- Value classes
--

-- | Types that can be marshalled to and from the wire format implemented in
-- this module.
class (Eq a, Show a, T.Typeable a) => DValue a where
    -- | The D-Bus type of a value. Callers in this module frequently pass
    -- @undefined@ here, so implementations must not inspect the argument.
    dtype :: a -> DType
    -- | Alignment boundary, in bytes, for values of this type. Also called
    -- with @undefined@, so the argument must not be inspected.
    alignment :: a -> Bytes
    -- | Write a value.
    serializer :: a -> Serializer
    -- | Read a value.
    deserializer :: Deserializer a

-- | Values whose 'dtype' is a basic type; these are the types accepted as
-- 'M.Map' keys by the container instances in this module.
class (DValue a, Ord a) => DBasicTypedValue a where
    -- | The basic type of a value. The default unwraps the 'DBasicType'
    -- constructor from 'dtype'; if 'dtype' returns any other constructor the
    -- lazy pattern fails when the result is demanded.
    dbasictype :: a -> DBasicType
    dbasictype x = let DBasicType t = dtype x in t

--
-- ObjectPath type wrapping Strings with mkObjectPath ensuring that the various
-- rules about object paths are respected.
--

-- | An object path. Values are meant to be built with 'mkObjectPath', which
-- performs the validation.
newtype ObjectPath = MkObjectPath String
  deriving (Eq, Ord, T.Typeable)

-- | Shown exactly like the underlying 'String' (i.e. as a quoted literal).
instance Show ObjectPath where
  show (MkObjectPath path) = shows path ""

-- | Extract the underlying path string.
getPath :: ObjectPath -> String
getPath objectPath = case objectPath of
  MkObjectPath path -> path

-- | Split a list at every occurrence of a separator element. The separators
-- are dropped; adjacent, leading or trailing separators produce empty
-- pieces, and the empty list yields a single empty piece.
splitOn :: Eq a => a -> [a] -> [[a]]
splitOn c = go
  where
    go input = case break (== c) input of
      (piece, [])       -> [piece]
      (piece, _ : rest) -> piece : go rest

-- | Validate a string and wrap it as an 'ObjectPath'.
--
-- The checks, in order: the string is non-empty; it starts with @/@; it only
-- contains ASCII letters, digits, @_@ and @/@; and, unless it is exactly
-- @"/"@, no @/@-separated element after the leading slash is empty.
-- Violations are reported with 'fail' in the caller's monad.
mkObjectPath :: Monad m => String -> m ObjectPath
mkObjectPath path = case path of
    [] -> fail "object paths must be non-empty"
    (lead : rest) -> do
      when (lead /= '/') $ fail "object paths must begin with /"
      when (not (all (`elem` allowed) path)) $
          fail ("object paths may only contain chars from '" ++ allowed ++ "'")
      when (hasEmptyElement rest) $
          fail "no element of an object path may be the empty string"
      return (MkObjectPath path)
  where
    allowed = ['a'..'z'] ++ ['A'..'Z'] ++ ['0'..'9'] ++ "_/"
    -- The bare root path "/" has nothing after the slash and is accepted.
    hasEmptyElement rest = not (null rest) && any null (splitOn '/' rest)

--
-- DString type wrapping Strings with mkDString ensuring there are no null
-- bytes.
--

-- | A string value; 'mkDString' and 'mkDString0' produce values without NUL
-- characters. 'getString' returns the wrapped 'String'.
newtype DString = MkDString { getString :: String }
  deriving (Eq, Ord, T.Typeable)

-- | Shown exactly like the wrapped 'String'.
instance Show DString where
  show (MkDString str) = show str

-- | Wrap a string as a 'DString', reporting a failure (via 'fail') if it
-- contains a NUL character.
mkDString :: Monad m => String -> m DString
mkDString str = do
  when (any (== '\NUL') str) $
    fail "null bytes not allowed in D-Bus strings"
  return (MkDString str)

-- | Total variant of 'mkDString': NUL characters are silently dropped.
mkDString0 :: String -> DString
mkDString0 str = MkDString [c | c <- str, c /= '\NUL']

--
-- DBasicTypedValue instances
--

-- | Object paths share the wire encoding of 'DString' (see the serializer,
-- which simply re-wraps the path).
--
-- The deserializer validates what it reads with 'mkObjectPath' instead of
-- wrapping it blindly, so a malformed path on the wire makes the parser fail
-- rather than producing an 'ObjectPath' that violates the type's rules.
instance DBasicTypedValue ObjectPath where
instance DValue ObjectPath where
  dtype _ = DBasicType DTypeObjectPath
  alignment _ = 4
  serializer (MkObjectPath s) = serializer (MkDString s)
  deserializer = do
    str <- deserializer :: Deserializer DString
    mkObjectPath (getString str)

-- | Booleans are encoded as a 'Word32' holding 1 ('True') or 0 ('False');
-- any other value fails to deserialize.
instance DBasicTypedValue Bool
instance DValue Bool where
  dtype = const (DBasicType DTypeBoolean)
  alignment = const 4
  serializer flag = serializer encoded
    where encoded :: Word32
          encoded | flag      = 1
                  | otherwise = 0
  deserializer = (deserializer :: Deserializer Word32) >>= decode
    where decode 1 = return True
          decode 0 = return False
          decode _ = fail "bad boolean value"

-- A 'Char' is marshalled as a single byte.
-- Original author's note: useful, but you can't do [Char] without getting
-- DTypeString.
-- FIXME:UTF-8 -- this is a bit of a lie: 'ord' is narrowed to 'Word8', so
-- characters above code point 255 are truncated.
instance DBasicTypedValue Char
instance DValue Char where
  dtype = const (DBasicType DTypeByte)
  alignment = const 1
  serializer c = serializer (fromIntegral (ord c) :: Word8)
  deserializer = do
    byte <- deserializer :: Deserializer Word8
    return (chr (fromIntegral byte))

-- | Bytes: written and read directly, with no padding.
instance DBasicTypedValue Word8
instance DValue Word8 where
  dtype = const (DBasicType DTypeByte)
  alignment = const 1
  serializer byte = do
    R.lift (S.lift (putWord8 byte))
    advanceBy 1
  deserializer = R.lift getWord8

-- | Unsigned 16-bit integers, aligned to 2 bytes.
instance DBasicTypedValue Word16
instance DValue Word16 where
  dtype = const (DBasicType DTypeUInt16)
  alignment = const 2
  serializer w = basicSerializer putWord16be putWord16le w
  deserializer = basicDeserializer getWord16be getWord16le

-- | Unsigned 32-bit integers, aligned to 4 bytes.
instance DBasicTypedValue Word32
instance DValue Word32 where
  dtype = const (DBasicType DTypeUInt32)
  alignment = const 4
  serializer w = basicSerializer putWord32be putWord32le w
  deserializer = basicDeserializer getWord32be getWord32le

-- | Unsigned 64-bit integers, aligned to 8 bytes.
instance DBasicTypedValue Word64
instance DValue Word64 where
  dtype = const (DBasicType DTypeUInt64)
  alignment = const 8
  serializer w = basicSerializer putWord64be putWord64le w
  deserializer = basicDeserializer getWord64be getWord64le

-- | Signed 16-bit integers, aligned to 2 bytes. They are converted with
-- 'fromIntegral' to and from 'Word16' and use the unsigned writers/readers.
instance DBasicTypedValue Int16
instance DValue Int16 where
  dtype = const (DBasicType DTypeInt16)
  alignment = const 2
  serializer = basicSerializer (\n -> putWord16be (fromIntegral n))
                               (\n -> putWord16le (fromIntegral n))
  deserializer = basicDeserializer (fmap fromIntegral getWord16be)
                                   (fmap fromIntegral getWord16le)

-- | Signed 32-bit integers, aligned to 4 bytes. They are converted with
-- 'fromIntegral' to and from 'Word32' and use the unsigned writers/readers.
instance DBasicTypedValue Int32
instance DValue Int32 where
  dtype = const (DBasicType DTypeInt32)
  alignment = const 4
  serializer = basicSerializer (\n -> putWord32be (fromIntegral n))
                               (\n -> putWord32le (fromIntegral n))
  deserializer = basicDeserializer (fmap fromIntegral getWord32be)
                                   (fmap fromIntegral getWord32le)

-- | Signed 64-bit integers, aligned to 8 bytes. They are converted with
-- 'fromIntegral' to and from 'Word64' and use the unsigned writers/readers.
instance DBasicTypedValue Int64
instance DValue Int64 where
  dtype = const (DBasicType DTypeInt64)
  alignment = const 8
  serializer = basicSerializer (\n -> putWord64be (fromIntegral n))
                               (\n -> putWord64le (fromIntegral n))
  deserializer = basicDeserializer (fmap fromIntegral getWord64be)
                                   (fmap fromIntegral getWord64le)

-- Nasty: the bits of a 'Double' are reinterpreted by writing it to a
-- temporary buffer and reading the same memory back as a 'Word64'.
-- 'unsafePerformIO' wraps the buffer, which is allocated, used and released
-- inside this one call.

doubleToWord64 :: Double -> Word64
doubleToWord64 d = unsafePerformIO $ alloca $ \buffer -> do
  poke buffer d
  peek (castPtr buffer)

-- | Inverse of 'doubleToWord64': reinterpret the bits of a 'Word64' as a
-- 'Double' by round-tripping through a temporary buffer.
word64ToDouble :: Word64 -> Double
word64ToDouble w = unsafePerformIO $ alloca $ \buffer -> do
  poke buffer w
  peek (castPtr buffer)

-- | Doubles are aligned to 8 bytes and marshalled as the 'Word64' holding
-- their bit pattern (see 'doubleToWord64' / 'word64ToDouble').
instance DBasicTypedValue Double
instance DValue Double where
  dtype = const (DBasicType DTypeDouble)
  alignment = const 8
  serializer = basicSerializer (\d -> putWord64be (doubleToWord64 d))
                               (\d -> putWord64le (doubleToWord64 d))
  deserializer = basicDeserializer (fmap word64ToDouble getWord64be)
                                   (fmap word64ToDouble getWord64le)

-- | Consume one byte and fail unless it is zero.
getNull :: R.ReaderT Endianness Get ()
getNull = do
   byte <- deserializer :: Deserializer Word8
   when (byte /= 0) $ fail ("expecting null byte, got " ++ show byte)

-- | Strings are written as a 'Word32' byte count, the UTF-8 encoded bytes
-- and a terminating zero byte (not included in the count).
--
-- The deserializer builds its result with 'mkDString' rather than the raw
-- constructor, so a payload containing a NUL makes the parser fail instead
-- of yielding a 'DString' that breaks the no-NUL invariant.
instance DBasicTypedValue DString where
instance DValue DString where
  dtype _ = DBasicType DTypeString
  alignment _ = 4
  serializer (MkDString s) = do
    let bytes = BS8.fromString s
    serializer (fromIntegral (BS.length bytes) :: Word32)
    mapM_ serializer . BS.unpack $ bytes
    serializer (0 :: Word8)
  deserializer = do
    l <- fromIntegral `fmap` (deserializer :: Deserializer Word32)
    s <- S.lift $ getByteString l
    getNull
    mkDString (BS8.toString s)

-- | Signatures are written as a one-byte length, the characters of
-- @show signature@ as single bytes, and a terminating zero byte.
--
-- The deserializer parses the text with 'reads' and reports a malformed
-- signature through 'fail'. (The partial 'read' would instead throw from
-- pure code, and only when the result is eventually forced.)
instance DBasicTypedValue Signature where
instance DValue Signature where
  dtype _ = DBasicType DTypeSignature
  alignment _ = 1
  serializer s = do
    -- XXX: duplication of String serializer
    let s' = show s
    let l = length s'
    when (l > 255) $ fail "signatures must be no more than 255 bytes long."
    serializer (fromIntegral l :: Word8)
    let bytes = map (fromIntegral . ord) s' :: [Word8]
    mapM_ serializer bytes
    serializer (0 :: Word8)
  deserializer = do
    -- XXX: duplication of String deserializer
    l <- fromIntegral `fmap` S.lift getWord8
    raw <- S.lift (getByteString l)
    getNull
    let text = map (chr . fromIntegral) (BS.unpack raw)
    -- Same acceptance rule as 'read': exactly one parse that consumes the
    -- whole input, ignoring trailing white space.
    case [sig | (sig, rest) <- reads text, ("", "") <- lex rest] of
      [sig] -> return sig
      _     -> fail ("bad signature: " ++ show text)

-- | A value of any 'DValue' type, with the concrete type hidden behind an
-- existential. Use 'fromVariant' to recover a value of a known type.
-- Because the field type is existentially quantified, @unVariant@ cannot be
-- applied as an ordinary selector function; pattern match on 'Variant'.
data Variant = forall v. DValue v => Variant { unVariant :: v }
    deriving T.Typeable

-- | Extract the wrapped value if it has the requested type; 'Nothing' when
-- the types differ.
fromVariant :: T.Typeable a => Variant -> Maybe a
fromVariant variant = case variant of
  Variant inner -> T.cast inner

-- | Renders as @Variant (value) {- dtype -}@, parenthesised as a whole when
-- the surrounding precedence is above 0.
instance Show Variant where
    showsPrec prec (Variant v) =
        showParen (prec > 0) $ foldr (.) id
            [ showString "Variant "
            , showParen True (showsPrec (prec + 1) v)
            , showString " {- "
            , showsPrec 0 (dtype v)
            , showString " -}"
            ]

-- | Two variants are equal when they wrap the same type and the wrapped
-- values are equal.
instance Eq Variant where
    Variant x == Variant y =
        case T.cast y of
          Just y' -> x == y'
          Nothing -> False

-- | A variant is written as the single-type signature of its payload
-- followed by the payload itself. Reading expects a signature holding
-- exactly one type and then reads a value of that type.
instance DValue Variant where
  dtype = const DTypeVariant
  alignment = const 1
  serializer (Variant v) = serializer (Signature [dtype v]) >> serializer v
  deserializer = do
    Signature [t] <- deserializer
    case mkDummy t of
      Dummy (_ :: payload) ->
        fmap Variant (deserializer :: Deserializer payload)

-- | Existential wrapper around a value of some 'DValue' type. 'mkDummy' uses
-- it to turn a runtime 'DType' into a type that can be pattern matched on;
-- the wrapped value is only a type carrier (often @undefined@).
data Dummy = forall a. DValue a => Dummy a
-- | Like 'Dummy', restricted to 'DBasicTypedValue' types; see 'mkBasicDummy'.
data BasicDummy = forall a. DBasicTypedValue a => BasicDummy a

-- | Map a basic type tag to a 'BasicDummy' carrying the corresponding
-- Haskell type. The wrapped value is @undefined@ and must never be forced.
mkBasicDummy :: DBasicType -> BasicDummy
mkBasicDummy DTypeByte       = BasicDummy (undefined :: Word8)
mkBasicDummy DTypeBoolean    = BasicDummy (undefined :: Bool)
mkBasicDummy DTypeInt16      = BasicDummy (undefined :: Int16)
mkBasicDummy DTypeInt32      = BasicDummy (undefined :: Int32)
mkBasicDummy DTypeInt64      = BasicDummy (undefined :: Int64)
mkBasicDummy DTypeUInt16     = BasicDummy (undefined :: Word16)
mkBasicDummy DTypeUInt32     = BasicDummy (undefined :: Word32)
mkBasicDummy DTypeUInt64     = BasicDummy (undefined :: Word64)
mkBasicDummy DTypeDouble     = BasicDummy (undefined :: Double)
mkBasicDummy DTypeString     = BasicDummy (undefined :: DString)
mkBasicDummy DTypeObjectPath = BasicDummy (undefined :: ObjectPath)
mkBasicDummy DTypeSignature  = BasicDummy (undefined :: Signature)

-- | Map a 'DType' to a 'Dummy' carrying the corresponding Haskell type, so
-- that callers can select a 'deserializer' by pattern matching on it.
--
-- Arrays of dict entries become 'M.Map's, other arrays become lists, and
-- structs of two to five members become tuples. Other struct sizes, and a
-- dict entry outside an array, are errors.
mkDummy :: DType -> Dummy
mkDummy (DBasicType bt) =
    case mkBasicDummy bt of BasicDummy x -> Dummy x
mkDummy DTypeVariant = Dummy (Variant "ouroburos")
-- The existentials below have to be opened with 'case': writing
-- @let Dummy x = ...@ is rejected by GHC ("My brain just exploded").
-- Naming the hidden types needs pattern type signatures.
mkDummy (DTypeArray (DTypeDictEntry keybt valt)) =
    case mkBasicDummy keybt of
      BasicDummy (_ :: k) -> case mkDummy valt of
        Dummy (_ :: v) -> Dummy (M.empty :: M.Map k v)
mkDummy (DTypeArray u) =
    case mkDummy u of Dummy element -> Dummy [element]
mkDummy (DTypeStruct u us) =
    case map mkDummy (u : us) of
      [] -> error "world ended, this isn't possible."
      [Dummy _] -> error "Shit shit we don't support 1-element structs"
      [Dummy a, Dummy b] -> Dummy (a, b)
      [Dummy a, Dummy b, Dummy c] -> Dummy (a, b, c)
      [Dummy a, Dummy b, Dummy c, Dummy d] -> Dummy (a, b, c, d)
      [Dummy a, Dummy b, Dummy c, Dummy d, Dummy e] -> Dummy (a, b, c, d, e)
      _ -> error "FIXME: add more cases to mkDummy DTypeStruct when more tuple instances are added"
mkDummy (DTypeDictEntry _ _) =
    error "DTypeDictEntry is only valid as the element type of DTypeArray"

-- | Read one value for each type in the signature, in order, wrapping each
-- result in a 'Variant'.
deserializeAs :: Signature -> Deserializer [Variant]
deserializeAs signature = case signature of
    Signature types -> mapM readAs types
  where
    -- 'mkDummy' supplies the Haskell type used to pick the deserializer.
    readAs ty = case mkDummy ty of
        Dummy (_ :: elt) -> fmap Variant (deserializer :: Deserializer elt)

-- A labelled replacement for "undefined": if the placeholder is ever
-- evaluated, the error message ("DValue " followed by the label) says where
-- it came from, which helps track the bug down.
undefinedAt :: String -> a
undefinedAt label = error ("DValue " ++ label)

-- Instances for containers

-- | Maps are typed as arrays of dict entries and marshalled through their
-- association list ('M.toList' / 'M.fromList'), i.e. via the list and pair
-- instances.
instance (DBasicTypedValue k, DValue v) => DValue (M.Map k v) where
    dtype _ = DTypeArray (DTypeDictEntry keyType valueType)
      where keyType   = dbasictype (undefinedAt "Map key" :: k)
            valueType = dtype      (undefinedAt "Map value" :: v)
    alignment _ = 4
    serializer entries = serializer (M.toList entries)
    deserializer = fmap M.fromList deserializer

-- | Lists are arrays: a 'Word32' byte length, padding up to the element
-- alignment, then the elements back to back. The length counts the element
-- bytes only, not the padding that follows the length field.
instance DValue a => DValue [a] where
    dtype _ = DTypeArray (dtype (undefinedAt "[]" :: a))
    alignment _ = 4
    serializer vs = do
      padTo 4
      start <- R.lift S.get
      endianness <- R.ask
      let elemAlignment = alignment (undefined :: a)
          -- Sub-serializers start their counter where the array body will
          -- start (just after the 4-byte length), so that alignment inside
          -- the body is computed against the real stream position.
          leadIn  = advanceBy (start + 4) >> padTo elemAlignment
          render  = runSerializer endianness
          padLen  = LBS.length (render leadIn)
          payload = render (leadIn >> mapM_ serializer vs)
          dataLen = LBS.length payload - padLen
      when (dataLen > (2^(26::Int))) $
        fail "arrays cannot be longer than 2^26 :-/"
      serializer (fromIntegral dataLen :: Word32)
      -- 'payload' already begins with the padding after the length field.
      R.lift (S.lift (putLazyByteString payload))
      advanceBy (fromIntegral (LBS.length payload))
    deserializer = do
      total <- fromIntegral `fmap` (deserializer :: Deserializer Word32)
      skipTo (alignment (undefined :: a))
      begin <- fromIntegral `fmap` R.lift bytesRead
      collect total begin
      where collect :: Int64 -> Int64 -> Deserializer [a]
            -- remaining: element bytes still to be consumed;
            -- here: stream position before the next element.
            collect remaining here
              | remaining < 0  = fail "too few bytes in array"
              | remaining == 0 = return []
              | otherwise = do
                  item <- deserializer
                  there <- fromIntegral `fmap` R.lift bytesRead
                  rest <- collect (remaining - (there - here)) there
                  return (item : rest)

-- | Pairs are two-member structs: aligned to 8 bytes, members in order.
instance (DValue a, DValue b) => DValue (a, b) where
    dtype _ = DTypeStruct firstType [secondType]
      where firstType  = dtype (undefinedAt "(a,)" :: a)
            secondType = dtype (undefinedAt "(,b)" :: b)
    alignment _ = 8
    serializer (x, y) = do
      padTo 8
      serializer x
      serializer y
    deserializer = do
      skipTo 8
      x <- deserializer
      y <- deserializer
      return (x, y)

-- | Triples are three-member structs: aligned to 8 bytes, members in order.
instance (DValue a, DValue b, DValue c) => DValue (a, b, c) where
    dtype _ = DTypeStruct firstType restTypes
      where firstType = dtype (undefinedAt "(a,,)" :: a)
            restTypes = [ dtype (undefinedAt "(,b,)" :: b)
                        , dtype (undefinedAt "(,,c)" :: c)
                        ]
    alignment _ = 8
    serializer (x, y, z) = do
      padTo 8
      serializer x
      serializer y
      serializer z
    deserializer = do
      skipTo 8
      x <- deserializer
      y <- deserializer
      z <- deserializer
      return (x, y, z)

-- | 4-tuples are four-member structs: aligned to 8 bytes, members in order.
instance (DValue a, DValue b, DValue c, DValue d) =>
         DValue (a, b, c, d) where
    dtype _ = DTypeStruct firstType restTypes
      where firstType = dtype (undefinedAt "(a,,,)" :: a)
            restTypes = [ dtype (undefinedAt "(,b,,)" :: b)
                        , dtype (undefinedAt "(,,c,)" :: c)
                        , dtype (undefinedAt "(,,,d)" :: d)
                        ]
    alignment _ = 8
    serializer (x, y, z, p) = do
      padTo 8
      serializer x
      serializer y
      serializer z
      serializer p
    deserializer = do
      skipTo 8
      x <- deserializer
      y <- deserializer
      z <- deserializer
      p <- deserializer
      return (x, y, z, p)

-- | 5-tuples are five-member structs: aligned to 8 bytes, members in order.
instance (DValue a, DValue b, DValue c, DValue d, DValue e) =>
         DValue (a, b, c, d, e) where
    dtype _ = DTypeStruct firstType restTypes
      where firstType = dtype (undefinedAt "(a,,,,)" :: a)
            restTypes = [ dtype (undefinedAt "(,b,,,)" :: b)
                        , dtype (undefinedAt "(,,c,,)" :: c)
                        , dtype (undefinedAt "(,,,d,)" :: d)
                        , dtype (undefinedAt "(,,,,e)" :: e)
                        ]
    alignment _ = 8
    serializer (x, y, z, p, q) = do
      padTo 8
      serializer x
      serializer y
      serializer z
      serializer p
      serializer q
    deserializer = do
      skipTo 8
      x <- deserializer
      y <- deserializer
      z <- deserializer
      p <- deserializer
      q <- deserializer
      return (x, y, z, p, q)

-- vim: sts=2 sw=2 et