{-# LANGUAGE UndecidableInstances, OverlappingInstances, ScopedTypeVariables, GADTs, PatternSignatures,
    GeneralizedNewtypeDeriving, DeriveDataTypeable #-}
module Happstack.Data.Serialize
    ( Serialize(..), Version(..), Migrate(..), Mode(..), Contained, contain, extension,
      safeGet, safePut, serialize, deserialize, collectVersions,
      Object(objectType), mkObject, deserializeObject, parseObject,
      module Happstack.Data.Proxy
    ) where

import Control.Monad.Identity
import Data.Int()
import Foreign
import qualified Data.ByteString.Lazy.Char8 as L
import qualified Data.ByteString.Char8 as B

import Happstack.Data.Migrate
import Happstack.Data.Proxy

import Data.Typeable
import qualified Data.Map as M
import qualified Data.Map as Map
import qualified Data.IntMap as IntMap
import qualified Data.Set as Set

import Data.Binary     as B
import Data.Binary.Put as B
import Data.Binary.Get as B

--------------------------------------------------------------
-- Core types
--------------------------------------------------------------

-- | Wrapper marking a value as the raw, version-less encoder or decoder of a
-- type.  Neither the constructor nor 'unsafeUnPack' is exported, so outside
-- this module such values can only be built with 'contain'.
--
-- Declared as a @newtype@: the wrapper exists purely for type-level tagging,
-- so it should not cost a runtime box or an extra pointer indirection.
newtype Contained a = Contained {unsafeUnPack :: a}

-- | Lifts the provided value into 'Contained'.
contain :: a -> Contained a
contain = Contained

-- | Existential link from a type @a@ to the older type it was migrated from.
-- The older type is hidden, but the packed constraints still allow decoding
-- it and 'migrate'-ing the result forward to @a@.
data Previous a = forall b. (Serialize b, Migrate b a) => Previous (Proxy b)

-- | Packs a proxy for the previous type into a 'Previous'.
mkPrevious :: forall a b. (Serialize b, Migrate b a) => Proxy b -> Previous a
mkPrevious prox@Proxy = Previous prox

-- | Builds a 'Mode' declaring a new version of a type: the given version
-- number, plus the older type (carried by the proxy) that it migrates from.
-- Because 'VersionId' derives 'Num', an integer literal can be used as the
-- version, e.g. @extension 1 (Proxy :: Proxy OldState)@.
extension :: forall a b. (Serialize b, Migrate b a) => VersionId a -> Proxy b -> Mode a
extension vs = Versioned vs . Just . mkPrevious

-- | Version number of a type; the type parameter is a phantom tag.
-- On the wire it is encoded exactly like the underlying 'Int'.
newtype VersionId a = VersionId {unVersion :: Int} deriving (Num,Read,Show,Eq)
instance Binary (VersionId a) where
    put (VersionId n) = put n
    get = fmap VersionId get


-- | How values of a type are tagged when serialized with 'safePut'.
data Mode a = Primitive -- ^ Data layout won't change. Used for types like Int and Char.
            | Versioned (VersionId a) (Maybe (Previous a))
              -- ^ Current version number and, if this type extends an older
              -- one, the link to that previous type.  'safePut' writes the
              -- version number ahead of the value; 'safeGet' reads it back to
              -- choose the decoder.

-- | The Version type class describes whether a type is primitive (its layout
-- never changes and no version tag is stored) or versioned, possibly
-- extending an older type.  For a user-defined type that does not extend any
-- other, the default instance suffices, e.g. @instance Version MyType@, which
-- gives it version id 0 and no previous type.
class Version a where
    mode :: Mode a
    mode = Versioned 0 Nothing

-- | Types with a version-aware binary encoding.
--
-- 'getCopy' and 'putCopy' handle only the value itself; the version tag (for
-- types whose 'mode' is 'Versioned') is written by 'safePut' and read by
-- 'safeGet'.  Outside this module the 'Contained' results cannot be
-- unwrapped, so callers go through 'safePut' / 'safeGet'.
class (Typeable a, Version a) => Serialize a where
    -- | Decoder for the current version of @a@, without the version tag.
    getCopy :: Contained (Get a)
    -- | Encoder for the current version of @a@, without the version tag.
    putCopy :: a -> Contained Put

--------------------------------------------------------------
-- Implementation
--------------------------------------------------------------

-- | Reads the version tag for @a@ (only when @a@ is versioned) and returns
-- the decoder to use for the value that follows.  Primitive types carry no
-- tag, so their 'getCopy' is returned without consuming any input.
getSafeGet :: forall a. Serialize a => Get (Get a)
getSafeGet =
    case mode :: Mode a of
      Versioned wanted previous ->
          fmap (safeGetVersioned wanted previous) get
      Primitive ->
          return (unsafeUnPack getCopy)

-- | Writes the version tag for @a@ (only when @a@ is versioned) and returns
-- the encoder for the value itself.
getSafePut :: forall a. Serialize a => PutM (a -> Put)
getSafePut =
    case mode :: Mode a of
      Versioned vs _ -> B.put vs >> encoder
      Primitive      -> encoder
  where
    encoder :: PutM (a -> Put)
    encoder = return (unsafeUnPack . putCopy)


-- | Equivalent of 'Data.Binary.put' for instances of 'Serialize'.
-- Writes the type's version tag (if it is versioned), then the value.
safePut :: forall a. Serialize a => a -> Put
safePut val = getSafePut >>= \putValue -> putValue val

-- | Equivalent of 'Data.Binary.get' for instances of 'Serialize'.
-- Reads the type's version tag (if it is versioned), then decodes the value,
-- migrating it from an older version when the stored tag is lower.
safeGet :: forall a. Serialize a => Get a
safeGet = getSafeGet >>= id

-- | Builds the decoder for a versioned type @b@, given the version tag read
-- from the input (only its number matters; its phantom type is unrelated).
--
--   * stored == wanted: decode with @b@'s own 'getCopy'.
--   * stored <  wanted: decode using the previous type's 'mode' (recursively,
--     so several migration steps chain), then 'migrate' the result to @b@.
--   * stored >  wanted: the data is newer than this code; calls 'error'.
--
-- Also calls 'error' when an older version is needed but no previous type is
-- recorded, or when the previous type is marked 'Primitive'.
safeGetVersioned :: forall a b. (Serialize b) => VersionId b -> Maybe (Previous b) -> VersionId a -> B.Get b
safeGetVersioned wantedVersion mbPrevious storedVersion
    = case compareVersions storedVersion wantedVersion of
        GT -> error $ "Version tag too large: " ++ show (wantedVersion,storedVersion) ++ " (" ++ tStr ++ ")"
        EQ -> unsafeUnPack getCopy
        LT -> case mbPrevious of
                Nothing -> error $ "No previous version (" ++ tStr ++ ")"
                -- Bind the hidden previous type as f.
                Just (Previous (_ :: Proxy f) :: Previous b)
                    -- 'mode' here is f's mode: the ':: B.Get f' annotation on
                    -- the recursive call below fixes its type.
                    -> case mode of
                         Primitive -> error $ "Previous version marked as a Primitive (" ++ tStr ++ ")"
                         Versioned wantedVersion' mbPrevious'
                             -> do old <- safeGetVersioned wantedVersion' mbPrevious' storedVersion :: B.Get f
                                   -- Migrate f -> b is available from the 'Previous' constructor.
                                   return $ migrate old
    -- Name of b for error messages; typeOf only inspects the type, so the
    -- 'error' placeholder is never evaluated.
    where tStr = show (typeOf (error "huh?" :: b))

-- | Orders two version ids by their numeric value, ignoring their phantom types.
compareVersions :: VersionId a -> VersionId b -> Ordering
compareVersions (VersionId x) (VersionId y) = x `compare` y

-- | Pure version of 'safePut': encodes a value, including its version tag,
-- to a lazy ByteString.
serialize :: Serialize a => a -> L.ByteString
serialize val = runPut (safePut val)

-- | Pure version of 'safeGet': decodes a value (migrating from older versions
-- as needed) from the front of a lazy ByteString and returns it together with
-- the unconsumed remainder.
deserialize :: Serialize a => L.ByteString -> (a, L.ByteString)
deserialize bs = dropOffset (runGetState safeGet bs 0)
  where
    dropOffset (val, rest, _) = (val, rest)

-- | Lists the type names of @a@ and of each older type it extends, newest
-- first, by following the 'Previous' links in each type's 'mode'.  The chain
-- stops at a 'Primitive' type or at a version with no previous type.
collectVersions :: forall a . (Typeable a, Version a) => Proxy a -> [L.ByteString]
collectVersions prox =
    case mode :: Mode a of
      Versioned _ (Just (Previous prev)) -> thisType : collectVersions prev
      _                                  -> [thisType]
  where
    thisType = L.pack (show (typeOf (unProxy prox)))

--------------------------------------------------------------
-- Instances
--------------------------------------------------------------

-- Primitive types: encoded with their plain 'Binary' instance, no version tag.
instance Version Int where
    mode = Primitive
instance Serialize Int where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Integer where
    mode = Primitive
instance Serialize Integer where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Float where
    mode = Primitive
instance Serialize Float where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Double where
    mode = Primitive
instance Serialize Double where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version L.ByteString where
    mode = Primitive
instance Serialize L.ByteString where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version B.ByteString where
    mode = Primitive
instance Serialize B.ByteString where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Char where
    mode = Primitive
instance Serialize Char where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Word8 where
    mode = Primitive
instance Serialize Word8 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Word16 where
    mode = Primitive
instance Serialize Word16 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Word32 where
    mode = Primitive
instance Serialize Word32 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Word64 where
    mode = Primitive
instance Serialize Word64 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Ordering where
    mode = Primitive
instance Serialize Ordering where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Int8 where
    mode = Primitive
instance Serialize Int8 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Int16 where
    mode = Primitive
instance Serialize Int16 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Int32 where
    mode = Primitive
instance Serialize Int32 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Int64 where
    mode = Primitive
instance Serialize Int64 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version () where
    mode = Primitive
instance Serialize () where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Bool where
    mode = Primitive
instance Serialize Bool where
    getCopy = contain get
    putCopy x = contain (put x)
-- | Encoded as a 'Bool' tag ('True' for 'Right') followed by the payload.
instance Version (Either a b) where mode = Primitive
instance (Serialize a, Serialize b) => Serialize (Either a b) where
    getCopy = contain $ get >>= \tag ->
                  if tag then fmap Right safeGet else fmap Left safeGet
    putCopy (Left a)  = contain (put False >> safePut a)
    putCopy (Right a) = contain (put True >> safePut a)
-- Tuples are stored as their components in order, with no tag of their own
-- ('Primitive').  Each encoder writes its components individually, mirroring
-- its decoder, so the format of an n-tuple does not silently depend on how a
-- smaller tuple instance happens to be encoded.
instance Version (a,b) where mode = Primitive
instance (Serialize a, Serialize b) => Serialize (a,b) where
    getCopy = contain $ liftM2 (,) safeGet safeGet
    putCopy (a,b) = contain $ safePut a >> safePut b
instance Version (a,b,c) where mode = Primitive
instance (Serialize a, Serialize b, Serialize c) => Serialize (a,b,c) where
    getCopy = contain $ liftM3 (,,) safeGet safeGet safeGet
    putCopy (a,b,c) = contain $ safePut a >> safePut b >> safePut c
instance Version (a,b,c,d) where mode = Primitive
instance (Serialize a, Serialize b, Serialize c, Serialize d) => Serialize (a,b,c,d) where
    getCopy = contain $ liftM4 (,,,) safeGet safeGet safeGet safeGet
    putCopy (a,b,c,d) = contain $ safePut a >> safePut b >> safePut c >> safePut d
instance Version (a,b,c,d,e) where mode = Primitive
instance (Serialize a, Serialize b, Serialize c, Serialize d, Serialize e) => Serialize (a,b,c,d,e) where
    getCopy = contain $ liftM5 (,,,,) safeGet safeGet safeGet safeGet safeGet
    putCopy (a,b,c,d,e) = contain $ safePut a >> safePut b >> safePut c >> safePut d >> safePut e

-- | A proxy carries no data, so it is encoded as zero bytes.
instance Version (Proxy a) where mode = Primitive
instance Typeable a => Serialize (Proxy a) where
    getCopy = contain (return Proxy)
    putCopy p = case p of
                  Proxy -> contain (return ())

-- | Stored as the element count, then the element type's version tag if it
-- has one (written once, even for an empty list), then every element without
-- a tag of its own.
instance Version [a] where mode = Primitive
instance Serialize a => Serialize [a] where
    getCopy = contain $ get >>= \count -> getSafeGet >>= replicateM count
    putCopy xs = contain $
        put (length xs) >> (getSafePut >>= \putElem -> mapM_ putElem xs)

-- | Encoded as a 'Bool' presence flag, followed by the payload for 'Just'.
instance Version (Maybe a) where mode = Primitive
instance Serialize a => Serialize (Maybe a) where
    getCopy = contain $ get >>= \present ->
                  if present then fmap Just safeGet else return Nothing
    putCopy Nothing  = contain (put False)
    putCopy (Just a) = contain (put True >> safePut a)

-- Sets and maps are stored as their ascending element / key-value lists and
-- rebuilt with the fromAscList constructors, which assume that ordering.
-- NOTE(review): if a migrated element or key type orders differently from its
-- previous version, that precondition may not hold -- confirm.
instance Version (Set.Set a) where mode = Primitive
instance (Serialize a, Ord a) => Serialize (Set.Set a) where
    getCopy = contain (liftM Set.fromAscList safeGet)
    putCopy s = contain (safePut (Set.toAscList s))

instance Version (Map.Map a b) where mode = Primitive
instance (Serialize a,Serialize b, Ord a) => Serialize (Map.Map a b) where
    getCopy = contain (liftM Map.fromAscList safeGet)
    putCopy m = contain (safePut (Map.toAscList m))

instance Version (IntMap.IntMap a) where mode = Primitive
instance (Serialize a) => Serialize (IntMap.IntMap a) where
    getCopy = contain (liftM IntMap.fromAscList safeGet)
    putCopy m = contain (safePut (IntMap.toAscList m))


--------------------------------------------------------------
-- Object serialization
--------------------------------------------------------------


-- | 'deserialize' specialized to 'Object': returns the decoded 'Object' and
-- the remaining input.
deserializeObject :: L.ByteString -> (Object, L.ByteString)
deserializeObject bs = deserialize bs

-- | Attempts to convert an 'Object' back into its base type.
-- The stored type name is compared with the name of the expected type first;
-- on a mismatch 'error' is called and the payload is not decoded.
parseObject :: Serialize a => Object -> a
parseObject (Object objType objData)
    | objType == resType = res
    | otherwise          = error $ "Failed to parse object of type '" ++ objType ++ "'. Expected type '" ++ resType ++ "'"
  where
    res     = runGet safeGet objData
    resType = show (typeOf res)

-- | Wraps a value in an 'Object': the 'show'n name of its type together with
-- its 'serialize'd bytes.
mkObject :: Serialize a => a -> Object
mkObject obj = Object (show (typeOf obj)) (serialize obj)

-- | Uniform container for any serialized data: the 'show'n name of the
-- value's type together with the value serialized to a lazy ByteString.
data Object = Object { objectType :: String
                     , objectData :: L.ByteString
                     }  deriving (Typeable,Show)

-- Object uses the default 'Version' instance (version 0, no previous type).
instance Version Object
instance Serialize Object where
    getCopy = contain (get >>= \(objType, objData) -> return (Object objType objData))
    putCopy (Object objType objData) = contain (put (objType, objData))