{-# LANGUAGE ScopedTypeVariables, LambdaCase #-}
module Data.Serialize (
  -- * You'll need this
  module Language.Parser,
  
  -- * Serialization
  Serializable(..),Builder,bytesBuilder,chunkBuilder,serialize,serial,
  -- ** Convenience functions
  word8,Word8,Word32,Word64,Either3(..),
  ) where

import Definitive
import Language.Parser hiding (uncons)
import Data.ByteString.Lazy.Builder
import qualified Data.ByteString as BSS
import qualified Data.ByteString.Lazy as BS
import Data.ByteString.Unsafe 
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import qualified Data.Monoid as M
import System.Endian
import Data.Bits (shiftR,shiftL)
import qualified Data.ByteString.Lazy.UTF8 as UTF8 

-- | Types with a binary wire format.
--
-- 'serializable' is a parser over lazy 'Bytes' and 'encode' renders a value
-- into a 'Builder'. The two are presumably meant to be inverses of each
-- other ('serial' combines them into a single traversal).
class Serializable t where
  serializable :: Parser Bytes t
  encode :: t -> Builder

-- | Encode a value to lazy 'Bytes' by running the 'Builder' produced by
-- its 'encode' method.
serialize :: Serializable t => t -> Bytes
serialize value = toLazyByteString (encode value)

-- | Traversal between raw 'Bytes' and a decoded value.
--
-- Reading runs 'serializable' over the bytes and keeps only the FIRST parse
-- result (the 'foldr' with @const . Right@). The parser yields
-- (remaining input, value) pairs, and 'map snd' keeps just the value, so
-- any unconsumed trailing input is silently dropped. If the parse yields
-- no result, the original bytes come back unchanged in the 'Left' branch.
-- Presumably this means the traversal has no focus; confirm against
-- Definitive's 'prism'.
--
-- Writing ignores the old bytes ('const') and re-'serialize's the new value.
serial :: (Serializable t,Serializable t') => Traversal t t' Bytes Bytes
serial = prism (serializable^.from parser & \f a -> map snd (foldr (const . Right) (Left a) (f a))) (const serialize)

-- | Isomorphism between lazy 'Bytes' and a 'Builder' that emits exactly
-- those bytes. The reverse direction runs the builder.
bytesBuilder :: Bytes:<->:Builder
bytesBuilder = iso lazyByteString toLazyByteString
-- | Isomorphism between a strict 'Chunk' and a 'Builder'. The reverse
-- direction runs the builder and converts the lazy result back to a 'Chunk'.
chunkBuilder :: Chunk:<->:Builder
chunkBuilder = iso byteString (\b -> by chunk (toLazyByteString b))

-- Orphan instances placing the fixed-width words in Definitive's
-- 'Semigroup' / 'Monoid' classes. The bodies are empty, so every method
-- comes from the classes' defaults.
-- NOTE(review): presumably those defaults are the 'Num' operations
-- (addition / 0); confirm against Definitive.
instance Semigroup Word8
instance Monoid Word8
instance Semigroup Word32
instance Monoid Word32
instance Semigroup Word64
instance Monoid Word64

-- Bridge the bytestring 'Builder' (whose monoid lives in base's
-- "Data.Monoid") into Definitive's classes. Builders can then be joined
-- with '+' and folded with 'foldMap', with 'zero' as the empty builder.
instance Semigroup Builder where
  lhs + rhs = M.mappend lhs rhs
instance Monoid Builder where
  zero = M.mempty

-- | Run a pointer-consuming action over the bytes of a strict 'Chunk' and
-- return its result as a pure value.
--
-- The pointer comes from 'unsafeUseAsCString', so it aliases the chunk's
-- buffer without copying. The callback must only read through it and must
-- not keep it after returning.
-- The IO action is escaped into a pure value with 'thunk'.
-- NOTE(review): presumably an unsafePerformIO-style escape. It is only
-- sound because the action is a pure read; within this module the only
-- caller is 'storable', which passes 'peek'.
withChunk :: Chunk -> (Ptr b -> IO a) -> a
withChunk b f = unsafeUseAsCString b (\ptr -> f (castPtr ptr)) ^. thunk

-- | Parse a fixed-size 'Storable' value by reinterpreting the next
-- @sizeOf@ bytes of input in host byte order.
--
-- Fails with 'zero' when fewer than @sizeOf@ bytes remain. Otherwise it
-- consumes exactly that many bytes and leaves the rest as the new input.
-- NOTE(review): the 'peek' goes through a pointer at an arbitrary offset
-- into the input buffer, so the read may be unaligned. That is fine on x86;
-- confirm it on strict-alignment targets.
storable :: forall a. Storable a => Parser Bytes a
storable = step^.parser
  where step input =
          let (prefix,rest) = BS.splitAt (fromIntegral width) input
              bytes = prefix^.chunk
              value = withChunk bytes peek :: a
              -- 'sizeOf' never forces its argument, so this knot is safe
              width = sizeOf value
          in if BSS.length bytes < width then zero else pure (rest,value)
  
-- | Characters are written as UTF-8. Parsing decodes one character from the
-- front of the input with 'UTF8.uncons' and fails with 'zero' when the
-- input is empty.
instance Serializable Char where
  encode = charUtf8
  serializable = do
    next <- gets UTF8.uncons
    case next of
      Nothing -> zero
      Just (c,rest) -> do
        put rest
        return c
-- | A single byte, written and read verbatim.
instance Serializable Word8 where
  encode = word8
  serializable = storable
-- | 32-bit words go on the wire big-endian. 'storable' reads them in host
-- order, so 'fromBE32' converts back.
instance Serializable Word32 where
  encode = word32BE
  serializable = map fromBE32 storable
-- | 64-bit words, big-endian on the wire (see the 'Word32' instance).
instance Serializable Word64 where
  encode = word64BE
  serializable = map fromBE64 storable
-- | Variable-length little-endian 'Int': a 'Word8' byte count followed by
-- that many bytes, least-significant first.
--
-- Non-negative values use the minimal number of bytes; zero is just a @0@
-- count. Negative values used to be written as @0@, because
-- @takeWhile (>0)@ stopped at once. They are now written as their full
-- two's-complement representation ('sizeOf' bytes). The decoder is
-- unchanged: summing the shifted bytes wraps modulo the word size and
-- reconstructs the sign bit. Encodings of non-negative values are
-- byte-for-byte the same as before.
instance Serializable Int where
  encode n = encode (size bytes :: Word8) + foldMap (encode . w8) bytes
    where shifts = iterate (`shiftR`8) n
          -- arithmetic 'shiftR' never reaches 0 for a negative Int, so
          -- negatives get a fixed width instead of a 0-terminated prefix
          bytes | n >= 0 = takeWhile (>0) shifts
                | otherwise = take (sizeOf n) shifts
          w8 = fromIntegral :: Int -> Word8
  serializable = serializable >>= \n -> do
    bytes <- sequence (serializable <$ [1..n :: Word8])
    return $ sum (zipWith shiftL (map (fromIntegral :: Word8 -> Int) bytes) [0,8..])
-- | Arbitrary-precision 'Integer': a length (as an 'Int') followed by the
-- magnitude's bytes, least-significant first.
--
-- The sign used to be dropped (only the magnitude was written). It is now
-- carried by the sign of the length: a negative length means a negative
-- value. Non-negative Integers keep exactly their previous encoding.
instance Serializable Integer where
  encode n = encode len + foldMap (word8 . fromIntegral) (take s l)
    where l = iterate (`shiftR`8) (if n>=0 then n else (-n))
          s = length (takeWhile (/=0) l)
          len = if n>=0 then s else (-s)
  serializable = do
    len <- serializable
    digits <- doTimes (if len>=0 then len else (-len)) serializable
    let mag = sum (zipWith (\sh b -> fromIntegral (b :: Word8) `shiftL` sh) [0,8..] digits)
    return (if len>=0 then mag else (-mag))
-- | 'Maybe' is a one-byte tag (@0@ = 'Nothing', @1@ = 'Just') followed,
-- for 'Just', by the payload.
--
-- An unknown tag now makes the parser fail with 'zero', matching the
-- ':+:' and 'Either3' instances. Previously it aborted the whole program
-- through 'error', which callers of a parser cannot recover from.
instance Serializable a => Serializable (Maybe a) where
  encode (Just a) = word8 1 + encode a
  encode Nothing = word8 0
  serializable = serializable >>= \w -> case w :: Word8 of
    0 -> return Nothing
    1 -> Just<$>serializable
    _ -> zero
-- | Lists are length-prefixed: the element count (as an 'Int'), then each
-- element in order.
instance Serializable a => Serializable [a] where
  encode xs = header + body
    where header = encode (length xs)
          body = foldMap encode xs
  serializable = do
    count <- serializable
    doTimes count serializable
-- | Maps are written as the list obtained from @m^.keyed@ and rebuilt
-- with 'fromList'. Presumably that list holds (key, value) pairs; confirm
-- against Definitive's 'keyed'.
instance (Ord k,Serializable k,Serializable a) => Serializable (Map k a) where
  encode m = encode (toList (m^.keyed))
  serializable = map fromList serializable
-- | Bimaps are written through their forward 'Map' (same list layout as the
-- 'Map' instance) and rebuilt with 'fromList'.
instance (Ord k,Ord a,Serializable k,Serializable a) => Serializable (Bimap k a) where
  encode m = encode (toList (toMap m^.keyed))
  serializable = map fromList serializable
-- | Sets are written as a list of their elements. On the way back each
-- element is paired with 'zero' before 'fromList'.
-- NOTE(review): this implies Definitive's 'Set' is built from
-- (element, unit-like) pairs; confirm.
instance (Ord a,Serializable a) => Serializable (Set a) where
  encode s = encode (toList s)
  serializable = map (\xs -> fromList [(x,zero) | x <- xs]) serializable
-- Standalone-derived 'Serializable' instance for 'Range'.
-- NOTE(review): deriving a user-defined class like this requires
-- StandaloneDeriving plus GeneralizedNewtypeDeriving or DeriveAnyClass,
-- none of which appear in this file's LANGUAGE pragma. Presumably they are
-- enabled in the cabal file, and 'Range' is a newtype reusing its
-- underlying representation's encoding; confirm.
deriving instance Serializable a => Serializable (Range a)
-- | Pairs: the first component's encoding immediately followed by the
-- second's, with no separator or tag.
instance (Serializable a,Serializable b) => Serializable (a:*:b) where
  encode (x,y) = encode x + encode y
  serializable = pure (,) <*> serializable <*> serializable
-- | Triples: the components' encodings concatenated left to right.
instance (Serializable a,Serializable b,Serializable c) => Serializable (a,b,c) where
  encode (x,y,z) = encode x + encode y + encode z
  serializable = pure (,,) <*> serializable <*> serializable <*> serializable
-- | 4-tuples: the components' encodings concatenated left to right.
instance (Serializable a,Serializable b,Serializable c,Serializable d) => Serializable (a,b,c,d) where
  encode (w,x,y,z) = encode w + encode x + encode y + encode z
  serializable = pure (,,,) <*> serializable <*> serializable <*> serializable <*> serializable
-- | 5-tuples: the components' encodings concatenated left to right.
instance (Serializable a,Serializable b,Serializable c,Serializable d,Serializable e) => Serializable (a,b,c,d,e) where
  encode (v,w,x,y,z) = encode v + encode w + encode x + encode y + encode z
  serializable = pure (,,,,) <*> serializable <*> serializable <*> serializable <*> serializable <*> serializable
-- | Binary sums: a one-byte tag (@0@ = 'Left', @1@ = 'Right') followed by
-- the payload. Any other tag makes the parser fail with 'zero'.
instance (Serializable a,Serializable b) => Serializable (a:+:b) where
  encode = \case
    Left a -> word8 0 + encode a
    Right b -> word8 1 + encode b
  serializable = do
    tag <- storable
    case tag :: Word8 of
      0 -> map Left serializable
      1 -> map Right serializable
      _ -> zero

-- | A three-way sum type, the ternary counterpart of ':+:'.
data Either3 a b c
  = Alt3l'1 a
  | Alt3l'2 b
  | Alt3l'3 c

-- | Encoded as a one-byte tag (@0@, @1@ or @2@, in constructor order)
-- followed by the payload. Any other tag makes the parser fail with 'zero'.
instance (Serializable a,Serializable b,Serializable c) => Serializable (Either3 a b c) where
  encode = \case
    Alt3l'1 a -> word8 0 + encode a
    Alt3l'2 b -> word8 1 + encode b
    Alt3l'3 c -> word8 2 + encode c
  serializable = do
    tag <- storable
    case tag :: Word8 of
      0 -> map Alt3l'1 serializable
      1 -> map Alt3l'2 serializable
      2 -> map Alt3l'3 serializable
      _ -> zero