{-# OPTIONS -fglasgow-exts -fallow-undecidable-instances -fallow-overlapping-instances #-}
module HAppS.Data.Serialize
    ( Serialize(..), Version(..), Migrate(..), Mode(..), Contained, contain, extension,
      safeGet, safePut, serialize, deserialize, collectVersions,
      Object(objectType), mkObject, deserializeObject, parseObject,
      module HAppS.Data.Proxy
    ) where

import Control.Monad.Identity
import Data.Int()
import Foreign
import qualified Data.ByteString.Lazy.Char8 as L
import qualified Data.ByteString.Char8 as B

import HAppS.Data.Migrate
import HAppS.Data.Proxy

import Data.Typeable
import qualified Data.Map as M
import qualified Data.Map as Map
import qualified Data.IntMap as IntMap
import qualified Data.Set as Set

import Data.Binary     as B
import Data.Binary.Put as B
import Data.Binary.Get as B

--------------------------------------------------------------
-- Core types
--------------------------------------------------------------

-- | Opaque wrapper around the parser or writer produced by 'getCopy' and
-- 'putCopy'.  This module exports only the type (not the constructor or
-- 'unsafeUnPack'), so outside code can build one with 'contain' but cannot
-- unwrap it.
--
-- Declared as a @newtype@: the wrapper has exactly one field and exists only
-- for the type distinction, so it does not need a runtime representation of
-- its own.
newtype Contained a = Contained {unsafeUnPack :: a}

-- | Wrap a value in 'Contained'.
contain :: a -> Contained a
contain = Contained

-- | Existentially packaged reference to the version of a type that precedes
-- @a@.  The hidden type @b@ must itself be serializable and must provide a
-- 'Migrate' instance into @a@.  Only a 'Proxy' for @b@ is stored, never a
-- value of that type.
data Previous a = forall b. (Serialize b, Migrate b a) => Previous (Proxy b)

-- | Package a proxy for an older type @b@ as the 'Previous' version of @a@.
-- The incoming proxy is matched against its constructor and a fresh one is
-- stored in the result.
mkPrevious :: forall a b. (Serialize b, Migrate b a) => Proxy b -> Previous a
mkPrevious older
    = case older of
        Proxy -> Previous (Proxy :: Proxy b)

-- | Build the 'Mode' for a type that extends an older one: a 'Versioned'
-- mode carrying the given version tag and a link to the previous type
-- identified by the proxy.
extension :: forall a b. (Serialize b, Migrate b a) => VersionId a -> Proxy b -> Mode a
extension current older = Versioned current . Just $ mkPrevious older

-- | Version tag of a type.  The parameter @a@ is a phantom: it only records
-- which type the tag belongs to; the tag itself is a plain 'Int'.
newtype VersionId a
    = VersionId { unVersion :: Int }
    deriving (Num, Read, Show, Eq)

-- | A tag is encoded exactly like the 'Int' it wraps.
instance Binary (VersionId a) where
    get = fmap VersionId get
    put (VersionId n) = put n


-- | Describes how a type is framed when serialized: either with no version
-- tag at all, or with a version tag and an optional link to the previous
-- version of the type.
data Mode a = Primitive -- ^ Data layout won't change. Used for types like Int and Char.
            | Versioned (VersionId a) (Maybe (Previous a))
            -- ^ Current version tag, plus the previous version of the type
            -- (if any) that stored data may be migrated from.

-- | Associates a serialization 'Mode' with a type.
class Version a where
    -- | The mode used when reading and writing @a@.  Unless an instance
    -- overrides it, a type is versioned with tag 0 and has no previous
    -- version.
    mode :: Mode a
    mode = Versioned (VersionId 0) Nothing

-- | Types that can be written to and read from a byte stream.  The
-- 'Typeable' and 'Version' superclasses supply the type representation and
-- the versioning 'Mode' used elsewhere in this module.
class (Typeable a, Version a) => Serialize a where
    -- | Parser for the payload of the current version of the type, wrapped
    -- in 'Contained'.  'safeGet' is the entry point that also handles the
    -- version tag.
    getCopy :: Contained (Get a)
    -- | Writer for the payload of a value, wrapped in 'Contained'.
    -- 'safePut' is the entry point that also emits the version tag.
    putCopy :: a -> Contained Put

--------------------------------------------------------------
-- Implementation
--------------------------------------------------------------

-- | Read whatever header the type's 'Mode' calls for and return the parser
-- for the payload.  Primitive types have no header, so their 'getCopy'
-- parser is returned directly.  Versioned types have a stored version tag,
-- which is read here and handed to 'safeGetVersioned'.
--
-- Splitting header and payload lets a caller read the header once and run
-- the returned parser several times.
getSafeGet :: forall a. Serialize a => Get (Get a)
getSafeGet
    = case (mode :: Mode a) of
        Primitive -> return (unsafeUnPack getCopy)
        Versioned current previous
            -> liftM (safeGetVersioned current previous) get

-- | Write whatever header the type's 'Mode' calls for and return the writer
-- for the payload.  Primitive types emit nothing here; versioned types emit
-- their current version tag.  The returned writer is the type's 'putCopy'.
getSafePut :: forall a. Serialize a => PutM (a -> Put)
getSafePut
    = case (mode :: Mode a) of
        Primitive       -> return writeBody
        Versioned tag _ -> B.put tag >> return writeBody
    where
      writeBody :: a -> Put
      writeBody = unsafeUnPack . putCopy


-- | Serialize a single value: emit the header via 'getSafePut', then apply
-- the writer it returns to the value.
safePut :: forall a. Serialize a => a -> Put
safePut val = getSafePut >>= \writeBody -> writeBody val

-- | Deserialize a single value: consume the header via 'getSafeGet', then
-- run the payload parser it returns.
safeGet :: forall a. Serialize a => Get a
safeGet = do readBody <- getSafeGet
             readBody

-- | Build the parser for a value of type @b@ once the stored version tag has
-- been read.
--
-- * @wantedVersion@ is the tag of the version being parsed into, and
--   @mbPrevious@ is the preceding version of that type, if it declares one.
--
-- * @storedVersion@ is the tag found in the input.  Its phantom type is left
--   unconstrained because only the underlying number is compared.
--
-- If the tags are equal, the type's own 'getCopy' parser is used.  If the
-- stored tag is smaller, the value is parsed as the previous version (which
-- may recurse further back) and converted with 'migrate'.
--
-- The resulting parser fails through the 'Get' monad's 'fail' when the
-- stored tag is larger than the wanted one, when an older tag is stored but
-- no previous version is declared, or when the previous version is marked
-- 'Primitive'.  Using 'fail' rather than 'error' keeps the failure inside
-- the parser instead of returning a bottom action.
safeGetVersioned :: forall a b. (Serialize b) => VersionId b -> Maybe (Previous b) -> VersionId a -> B.Get b
safeGetVersioned wantedVersion mbPrevious storedVersion
    = case compareVersions storedVersion wantedVersion of
        GT -> fail $ "Version tag too large: " ++ show (wantedVersion,storedVersion) ++ " (" ++ tStr ++ ")"
        EQ -> unsafeUnPack getCopy
        LT -> case mbPrevious of
                Nothing -> fail $ "No previous version (" ++ tStr ++ ")"
                Just (Previous (_ :: Proxy f) :: Previous b)
                    -> case mode of
                         Primitive -> fail $ "Previous version marked as a Primitive (" ++ tStr ++ ")"
                         -- Named differently from the outer arguments on
                         -- purpose: those describe type b, these describe the
                         -- previous type f.
                         Versioned prevVersion prevPrevious
                             -> do old <- safeGetVersioned prevVersion prevPrevious storedVersion :: B.Get f
                                   return $ migrate old
    where -- Rendered name of the target type, used in the failure messages.
          -- The argument is a placeholder that only fixes the type; typeOf
          -- is expected not to evaluate it.
          tStr = show (typeOf (error "huh?" :: b))

-- | Order two version tags by their numeric value.  The tags may belong to
-- different types; the phantom parameters play no part in the comparison.
compareVersions :: VersionId a -> VersionId b -> Ordering
compareVersions (VersionId lhs) (VersionId rhs) = lhs `compare` rhs


-- | Encode a value as a lazy 'L.ByteString' by running 'safePut', which
-- includes the version tag when the type is versioned.
serialize :: Serialize a => a -> L.ByteString
serialize = runPut . safePut

-- | Decode one value from the front of the input with 'safeGet'.  Returns
-- the value together with the input left over after it; the byte offset
-- reported by 'runGetState' is discarded.
deserialize :: Serialize a => L.ByteString -> (a, L.ByteString)
deserialize input
    = dropOffset (runGetState safeGet input 0)
    where dropOffset (value, leftover, _consumed) = (value, leftover)

-- Version lookups
-- | List the rendered type of @a@ followed by the rendered types of every
-- previous version reachable through its 'Mode'.  A primitive type, or a
-- versioned type without a previous version, yields a single entry.
collectVersions :: forall a . (Typeable a, Version a) => Proxy a -> [L.ByteString]
collectVersions proxy
    = case (mode :: Mode a) of
        Versioned _ (Just (Previous older)) -> thisType : collectVersions older
        _                                   -> [thisType]
    where thisType = L.pack (show (typeOf (unProxy proxy)))

--------------------------------------------------------------
-- Instances
--------------------------------------------------------------
{-
instance (ToString a,FromString a, Typeable a) => Version a
instance (ToString a,FromString a, Typeable a,Version a) => Serialize a where
    getCopy = contain (defaultDecodeXmlStringM =<< get)
    putCopy = contain . put . defaultEncodeXmlString
-}
-- Built-in numeric types.  Each is 'Primitive' (no version tag is written)
-- and is read and written with its 'Binary' instance.
instance Version Int where
    mode = Primitive
instance Serialize Int where
    getCopy   = contain get
    putCopy n = contain (put n)

instance Version Integer where
    mode = Primitive
instance Serialize Integer where
    getCopy   = contain get
    putCopy n = contain (put n)

instance Version Float where
    mode = Primitive
instance Serialize Float where
    getCopy   = contain get
    putCopy x = contain (put x)

instance Version Double where
    mode = Primitive
instance Serialize Double where
    getCopy   = contain get
    putCopy x = contain (put x)
-- Lazy and strict byte strings and single characters.  Each is 'Primitive'
-- and is read and written with its 'Binary' instance.
instance Version L.ByteString where
    mode = Primitive
instance Serialize L.ByteString where
    getCopy    = contain get
    putCopy bs = contain (put bs)

instance Version B.ByteString where
    mode = Primitive
instance Serialize B.ByteString where
    getCopy    = contain get
    putCopy bs = contain (put bs)

instance Version Char where
    mode = Primitive
instance Serialize Char where
    getCopy   = contain get
    putCopy c = contain (put c)
-- Fixed-width unsigned integers.  Each is 'Primitive' and is read and
-- written with its 'Binary' instance.
instance Version Word8 where
    mode = Primitive
instance Serialize Word8 where
    getCopy   = contain get
    putCopy w = contain (put w)

instance Version Word16 where
    mode = Primitive
instance Serialize Word16 where
    getCopy   = contain get
    putCopy w = contain (put w)

instance Version Word32 where
    mode = Primitive
instance Serialize Word32 where
    getCopy   = contain get
    putCopy w = contain (put w)

instance Version Word64 where
    mode = Primitive
instance Serialize Word64 where
    getCopy   = contain get
    putCopy w = contain (put w)
-- 'Ordering' is 'Primitive' and is read and written with its 'Binary'
-- instance.
instance Version Ordering where
    mode = Primitive
instance Serialize Ordering where
    getCopy   = contain get
    putCopy o = contain (put o)
-- Fixed-width signed integers.  Each is 'Primitive' and is read and written
-- with its 'Binary' instance.
instance Version Int8 where
    mode = Primitive
instance Serialize Int8 where
    getCopy   = contain get
    putCopy i = contain (put i)

instance Version Int16 where
    mode = Primitive
instance Serialize Int16 where
    getCopy   = contain get
    putCopy i = contain (put i)

instance Version Int32 where
    mode = Primitive
instance Serialize Int32 where
    getCopy   = contain get
    putCopy i = contain (put i)

instance Version Int64 where
    mode = Primitive
instance Serialize Int64 where
    getCopy   = contain get
    putCopy i = contain (put i)
-- Unit and 'Bool'.  Both are 'Primitive' and are read and written with
-- their 'Binary' instances.
instance Version () where
    mode = Primitive
instance Serialize () where
    getCopy   = contain get
    putCopy u = contain (put u)

instance Version Bool where
    mode = Primitive
instance Serialize Bool where
    getCopy   = contain get
    putCopy b = contain (put b)
-- | 'Either' itself carries no version tag; its contents go through
-- 'safePut' / 'safeGet' and so keep their own versioning.
instance Version (Either a b) where
    mode = Primitive

-- | Layout: a 'Bool' flag ('True' for 'Right', 'False' for 'Left') followed
-- by the contained value.
instance (Serialize a, Serialize b) => Serialize (Either a b) where
    getCopy = contain (get >>= decodeSide)
        where decodeSide tag
                  | tag       = fmap Right safeGet
                  | otherwise = fmap Left safeGet
    putCopy val = case val of
                    Right r -> contain (put True >> safePut r)
                    Left l  -> contain (put False >> safePut l)
-- | Pairs carry no version tag of their own.
instance Version (a,b) where
    mode = Primitive

-- | Layout: first component, then second, each via 'safePut' / 'safeGet'.
instance (Serialize a, Serialize b) => Serialize (a,b) where
    getCopy = contain $
              do x <- safeGet
                 y <- safeGet
                 return (x, y)
    putCopy (x, y) = contain $
                     do safePut x
                        safePut y
-- | Triples carry no version tag of their own.
instance Version (a,b,c) where
    mode = Primitive

-- | Reads three components in order.  Writing emits the first component and
-- then the remaining two as a pair; pairs are 'Primitive', so no extra tag
-- is written between the components.
instance (Serialize a, Serialize b, Serialize c) => Serialize (a,b,c) where
    getCopy = contain $
              do x <- safeGet
                 y <- safeGet
                 z <- safeGet
                 return (x, y, z)
    putCopy (x, y, z) = contain $
                        do safePut x
                           safePut (y, z)
-- | 4-tuples carry no version tag of their own.
instance Version (a,b,c,d) where
    mode = Primitive

-- | Reads four components in order.  Writing emits the first component and
-- then the remaining three as a triple, which is itself 'Primitive'.
instance (Serialize a, Serialize b, Serialize c, Serialize d) => Serialize (a,b,c,d) where
    getCopy = contain $
              do w <- safeGet
                 x <- safeGet
                 y <- safeGet
                 z <- safeGet
                 return (w, x, y, z)
    putCopy (w, x, y, z) = contain $
                           do safePut w
                              safePut (x, y, z)
-- | 5-tuples carry no version tag of their own.
instance Version (a,b,c,d,e) where
    mode = Primitive

-- | Reads five components in order.  Writing emits the first component and
-- then the remaining four as a 4-tuple, which is itself 'Primitive'.
instance (Serialize a, Serialize b, Serialize c, Serialize d, Serialize e) => Serialize (a,b,c,d,e) where
    getCopy = contain $
              do v <- safeGet
                 w <- safeGet
                 x <- safeGet
                 y <- safeGet
                 z <- safeGet
                 return (v, w, x, y, z)
    putCopy (v, w, x, y, z) = contain $
                              do safePut v
                                 safePut (w, x, y, z)

-- | Proxies carry no version tag.
instance Version (Proxy a) where
    mode = Primitive

-- | A proxy holds no data: writing emits nothing and reading consumes
-- nothing.
instance Typeable a => Serialize (Proxy a) where
    getCopy = contain (return Proxy)
    putCopy proxy = case proxy of
                      Proxy -> contain (return ())

-- | Lists carry no version tag of their own.
instance Version [a] where
    mode = Primitive

-- | Layout: the element count, then the element header (if the element type
-- has one) written once, then the payload of every element.  The header is
-- shared because 'getSafePut' / 'getSafeGet' are run a single time and the
-- writer or parser they return is reused for each element.
instance Serialize a => Serialize [a] where
    getCopy = contain $
              do count <- get
                 readElem <- getSafeGet
                 replicateM count readElem
    putCopy items
        = contain $
          do put (length items)
             writeElem <- getSafePut
             mapM_ writeElem items

-- | 'Maybe' carries no version tag of its own.
instance Version (Maybe a) where
    mode = Primitive

-- | Layout: a 'Bool' flag, followed by the value when the flag is 'True'.
instance Serialize a => Serialize (Maybe a) where
    getCopy = contain (get >>= decodeRest)
        where decodeRest present
                  | present   = fmap Just safeGet
                  | otherwise = return Nothing
    putCopy val = case val of
                    Just x  -> contain (put True >> safePut x)
                    Nothing -> contain (put False)

-- | Sets carry no version tag of their own.
instance Version (Set.Set a) where
    mode = Primitive

-- | A set is stored as the list returned by 'Set.toList'.  Reading rebuilds
-- it with 'Set.fromAscList', so the stored list is taken to be in ascending
-- order for the current 'Ord' instance.
instance (Serialize a, Ord a) => Serialize (Set.Set a) where
    getCopy = contain $
              do elems <- safeGet
                 return (Set.fromAscList elems)
    putCopy set = contain (safePut (Set.toList set))

-- | Maps carry no version tag of their own.
instance Version (Map.Map a b) where
    mode = Primitive

-- | A map is stored as the association list returned by 'Map.toList'.
-- Reading rebuilds it with 'Map.fromAscList', so the stored keys are taken
-- to be in ascending order for the current 'Ord' instance.
instance (Serialize a,Serialize b, Ord a) => Serialize (Map.Map a b) where
    getCopy = contain $
              do assocs <- safeGet
                 return (Map.fromAscList assocs)
    putCopy table = contain (safePut (Map.toList table))

-- | IntMaps carry no version tag of their own.
instance Version (IntMap.IntMap a) where
    mode = Primitive

-- | An IntMap is stored as the association list returned by
-- 'IntMap.toList'.  Reading rebuilds it with 'IntMap.fromAscList', so the
-- stored keys are taken to be in ascending order.
instance (Serialize a) => Serialize (IntMap.IntMap a) where
    getCopy = contain $
              do assocs <- safeGet
                 return (IntMap.fromAscList assocs)
    putCopy table = contain (safePut (IntMap.toList table))


--------------------------------------------------------------
-- Object serialization
--------------------------------------------------------------



-- | 'deserialize' specialised to 'Object': decode one object from the front
-- of the input and return it with the remaining bytes.
deserializeObject :: L.ByteString -> (Object, L.ByteString)
deserializeObject input = deserialize input


-- | Decode the payload of an 'Object' as the type the caller asks for.
-- The type name recorded in the object is compared with the rendered type
-- of the requested result; if they differ, 'error' is called with both
-- names.
parseObject :: Serialize a => Object -> a
parseObject (Object storedType payload)
    | storedType == expectedType = value
    | otherwise = error $ "Failed to parse object of type '" ++ storedType ++ "'. Expected type '" ++ expectedType ++ "'"
    where value = runGet safeGet payload
          -- Derived from the result's type; used both for the check and in
          -- the error message.
          expectedType = show (typeOf value)

-- | Package a value as an 'Object': its rendered type name together with
-- its serialized bytes.
mkObject :: Serialize a => a -> Object
mkObject obj = Object typeName payload
    where typeName = show (typeOf obj)
          payload  = serialize obj

-- | A serialized value paired with the name of its type.
data Object = Object { objectType :: String
                       -- ^ Rendered type of the value ('mkObject' stores
                       -- @show (typeOf value)@ here).
                     , objectData :: L.ByteString
                       -- ^ The serialized bytes of the value.
                     }  deriving (Typeable,Show)

-- | Uses the class default 'mode': versioned with tag 0 and no previous
-- version.
instance Version Object

-- | The payload is the pair (type name, data) encoded with the 'Binary'
-- instance for tuples.
instance Serialize Object where
    getCopy = contain $
              get >>= \(typeName, payload) -> return (Object typeName payload)
    putCopy (Object typeName payload) = contain (put (typeName, payload))


--------------------------------------------------------------
-- Xml serialization
--------------------------------------------------------------

{-
defaultEncodeXmlString :: ToString a => a -> String
defaultEncodeXmlString x = toString x

defaultDecodeXmlStringM :: forall a. (FromString a,Typeable a) => String -> Get a
defaultDecodeXmlStringM str
    = case fromString Flexible str of
        Identity v -> return v
{-
        Nothing -> fail ("defaultDecodeXmlStringM: parsing failed "++show (typeOf (undefined::a))++" @ "++
                         show (take 256 str))
        Just v  -> return v
-}
-}