{-# LANGUAGE UndecidableInstances, OverlappingInstances, ScopedTypeVariables, GADTs,
    GeneralizedNewtypeDeriving, DeriveDataTypeable #-}
module Happstack.Data.Serialize
    ( Serialize(..), Version(..), Migrate(..), Mode(..), Contained, contain, extension,
      safeGet, safePut, getSafeGet, getSafePut, serialize, deserialize, collectVersions,
      Object(objectType), mkObject, deserializeObject, parseObject,
      module Happstack.Data.Proxy
    ) where

import Control.Monad.Identity
import Data.Int()
import Data.Ratio ((%), Ratio, denominator, numerator)
import Data.Time (Day(..), DiffTime, LocalTime(..), NominalDiffTime, TimeOfDay(..)
                 ,TimeZone(..), UTCTime(..), UniversalTime(..), ZonedTime(..))
import Data.Time.Clock.TAI (AbsoluteTime, taiEpoch, addAbsoluteTime, diffAbsoluteTime)
import Foreign
import qualified Data.ByteString.Lazy.Char8 as L
import qualified Data.ByteString.Char8 as B

import Happstack.Data.Migrate
import Happstack.Data.Proxy

import Data.Typeable
import qualified Data.Map as Map
import qualified Data.IntMap as IntMap
import qualified Data.Set as Set
import qualified Data.Text as ST
import qualified Data.Text.Encoding as ST
import qualified Data.Text.Lazy as LT
import qualified Data.Text.Lazy.Encoding as LT
import Data.Binary     as B
import Data.Binary.Put as B
import Data.Binary.Get as B

--------------------------------------------------------------
-- Core types
--------------------------------------------------------------

-- | Opaque wrapper around a value.  'getCopy' and 'putCopy' return their
-- results inside 'Contained'.  The module's export list exposes only the
-- type and 'contain', not the constructor or 'unsafeUnPack', so code outside
-- this module can build a 'Contained' but cannot unwrap one.
-- NOTE(review): presumably this is what pushes users towards 'safeGet' /
-- 'safePut' instead of calling 'getCopy' / 'putCopy' directly -- confirm.
data Contained a = Contained {unsafeUnPack :: a}

-- | Wraps a value in 'Contained'.  This is the only way for code outside
-- this module to construct a 'Contained'.
contain :: a -> Contained a
contain value = Contained { unsafeUnPack = value }

-- | Existential reference to the previous version of type @a@.  The hidden
-- type @b@ must itself be serializable and must have a 'Migrate' instance
-- to @a@; only a 'Proxy' for it is stored.
data Previous a = forall b. (Serialize b, Migrate b a) => Previous (Proxy b)

-- | Packs a proxy for the older type @b@ into a 'Previous' for @a@.
-- The proxy argument is pattern matched, so it is forced.
mkPrevious :: forall a b. (Serialize b, Migrate b a) => Proxy b -> Previous a
mkPrevious prox@Proxy = Previous prox

-- | Builds the 'Mode' for a type that is a new version of the type carried
-- by the given proxy, tagged with the given version number.  Because
-- 'VersionId' is an instance of 'Num', an integer literal can be used for
-- the version, e.g. @extension 1 (Proxy :: Proxy OldState)@.
extension :: forall a b. (Serialize b, Migrate b a) => VersionId a -> Proxy b -> Mode a
extension vs prox = Versioned vs previous
    where previous = Just (mkPrevious prox)

-- | Version number of type @a@, stored as an 'Int'.  The type parameter is
-- phantom: it only records which type the version belongs to.  'Num' is
-- derived so that integer literals can be used as version ids.
newtype VersionId a = VersionId {unVersion :: Int} deriving (Num,Read,Show,Eq)
-- | A version id is encoded exactly like the 'Int' it wraps.
instance Binary (VersionId a) where
    get = fmap VersionId get
    put (VersionId n) = put n


-- | Describes how a type takes part in versioning.  'getSafePut' writes no
-- version tag for 'Primitive' types and writes the 'VersionId' first for
-- 'Versioned' ones; 'getSafeGet' mirrors this when reading.
data Mode a = Primitive -- ^ Data layout won't change. Used for types like Int and Char.
            | Versioned (VersionId a) (Maybe (Previous a))
            -- ^ Current version number, plus the previous version of the
            -- type if there is one ('Nothing' for the first version).

-- | The Version type class is used to describe whether a type is fundamental
-- or if it is meant to extend another type.  For a user defined type that
-- does not extend any others, one can use the default instance of Version,
-- e.g. @instance Version MyType@, to define it as having a version id of 0
-- and no previous type.
class Version a where
    -- | The versioning mode of @a@.  Defaults to version 0 with no
    -- previous version.
    mode :: Mode a
    mode = Versioned 0 Nothing

-- | Types that can be written and read in the current version's layout.
-- Both methods return their result wrapped in 'Contained'; within this
-- module they are unwrapped by 'getSafeGet' / 'getSafePut', which add the
-- version handling described by the type's 'Version' instance.
class (Typeable a, Version a) => Serialize a where
    -- | Parser for the current layout of @a@ (no version tag).
    getCopy :: Contained (Get a)
    -- | Writer for the current layout of @a@ (no version tag).
    putCopy :: a -> Contained Put

--------------------------------------------------------------
-- Implementation
--------------------------------------------------------------

-- | Reads the version tag (if the type has one) and returns the parser to
-- use for the value itself.  'Primitive' types carry no tag, so their
-- 'getCopy' parser is returned without consuming input.  For 'Versioned'
-- types the stored version is read and handed to 'safeGetVersioned'.
getSafeGet :: forall a. Serialize a => Get (Get a)
getSafeGet =
    case (mode :: Mode a) of
      Versioned expected previous ->
          liftM (safeGetVersioned expected previous) get
      Primitive ->
          return (unsafeUnPack getCopy)

-- | Writes the version tag (if the type has one) and returns the function
-- to use for writing values.  'Primitive' types get no tag; 'Versioned'
-- types get their current 'VersionId' written first.  In both cases the
-- returned writer is the unwrapped 'putCopy'.
getSafePut :: forall a. Serialize a => PutM (a -> Put)
getSafePut =
    case (mode :: Mode a) of
      Primitive       -> return writer
      Versioned vid _ -> B.put vid >> return writer
    where writer :: a -> Put
          writer = unsafeUnPack . putCopy


-- | Version-aware counterpart of Data.Binary.put for instances of
-- 'Serialize': emits the version tag via 'getSafePut', then the value.
safePut :: forall a. Serialize a => a -> Put
safePut val = getSafePut >>= \writer -> writer val

-- | Version-aware counterpart of Data.Binary.get for instances of
-- 'Serialize': obtains the parser from 'getSafeGet' and then runs it.
safeGet :: forall a. Serialize a => Get a
safeGet = do parser <- getSafeGet
             parser

-- | Builds the parser for a value of type @b@ whose stored version tag is
-- @storedVersion@, given @b@'s current version and optional predecessor.
--
-- * Stored version newer than the wanted one: calls 'error'.
-- * Versions equal: the value is parsed with @b@'s own 'getCopy'.
-- * Stored version older: the predecessor type @f@ is looked up, the
--   function recurses with @f@'s version information, and the parsed old
--   value is converted with 'migrate'.  Calls 'error' if there is no
--   predecessor or if the predecessor is marked 'Primitive'.
--
-- The version ids are compared purely numerically ('compareVersions'), so
-- the phantom type of @storedVersion@ is irrelevant.
safeGetVersioned :: forall a b. (Serialize b) => VersionId b -> Maybe (Previous b) -> VersionId a -> B.Get b
safeGetVersioned wantedVersion mbPrevious storedVersion
    = case compareVersions storedVersion wantedVersion of
        GT -> error $ "Version tag too large: " ++ show (wantedVersion,storedVersion) ++ " (" ++ tStr ++ ")"
        EQ -> unsafeUnPack getCopy
        LT -> case mbPrevious of
                Nothing -> error $ "No previous version (" ++ tStr ++ ")"
                -- The pattern signature binds the hidden predecessor type as @f@.
                Just (Previous (_ :: Proxy f) :: Previous b)
                    -- 'mode' is resolved at type @f@ through the annotation on
                    -- the recursive call below.
                    -> case mode of
                         Primitive -> error $ "Previous version marked as a Primitive (" ++ tStr ++ ")"
                         Versioned wantedVersion' mbPrevious'
                             -> do old <- safeGetVersioned wantedVersion' mbPrevious' storedVersion :: B.Get f
                                   return $ migrate old
    -- Name of the target type, used only in the error messages above.
    where tStr = show (typeOf (error "huh?" :: b))

-- | Orders two version ids by their underlying 'Int', ignoring which types
-- they belong to.
compareVersions :: VersionId a -> VersionId b -> Ordering
compareVersions (VersionId lhs) (VersionId rhs) = compare lhs rhs

-- | Pure version of 'safePut': runs the writer and returns the resulting
-- lazy ByteString.
serialize :: Serialize a => a -> L.ByteString
serialize val = runPut (safePut val)

-- | Pure version of 'safeGet': parses a value of the expected type from
-- the front of the ByteString and returns it together with the unconsumed
-- remainder.  The offset reported by 'runGetState' is discarded.
deserialize :: Serialize a => L.ByteString -> (a, L.ByteString)
deserialize bs = dropOffset (runGetState safeGet bs 0)
    where dropOffset (val, rest, _offset) = (val, rest)

-- | Lists the type name (as rendered by 'show' on 'typeOf') of the proxied
-- type, followed by the names of every type in its chain of previous
-- versions, newest first.  A 'Primitive' type or a type without a previous
-- version yields a single-element list.
collectVersions :: forall a . (Typeable a, Version a) => Proxy a -> [L.ByteString]
collectVersions prox
    = case (mode :: Mode a) of
        Versioned _ (Just (Previous older)) -> thisType : collectVersions older
        _                                   -> [thisType]
    where thisType = L.pack (show (typeOf (unProxy prox)))

--------------------------------------------------------------
-- Instances
--------------------------------------------------------------

-- Fixed-layout base types.  Each is marked 'Primitive' (no version tag) and
-- is read and written with its Data.Binary instance.
instance Version Int where mode = Primitive
instance Serialize Int where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Integer where mode = Primitive
instance Serialize Integer where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Float where mode = Primitive
instance Serialize Float where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Double where mode = Primitive
instance Serialize Double where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version L.ByteString where mode = Primitive
instance Serialize L.ByteString where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version B.ByteString where mode = Primitive
instance Serialize B.ByteString where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Char where mode = Primitive
instance Serialize Char where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Word8 where mode = Primitive
instance Serialize Word8 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Word16 where mode = Primitive
instance Serialize Word16 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Word32 where mode = Primitive
instance Serialize Word32 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Word64 where mode = Primitive
instance Serialize Word64 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Ordering where mode = Primitive
instance Serialize Ordering where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Int8 where mode = Primitive
instance Serialize Int8 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Int16 where mode = Primitive
instance Serialize Int16 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Int32 where mode = Primitive
instance Serialize Int32 where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Int64 where mode = Primitive
instance Serialize Int64 where
    getCopy = contain get
    putCopy x = contain (put x)
{-
This instance requires base >= 4.2.0.0, which means GHC >= 6.12. We can add this back in when we drop support for GHC 6.10.

instance Version (Fixed a)
instance (Typeable a, HasResolution a) => Serialize (Fixed a) where
    getCopy = contain $ liftM fromRational safeGet ; putCopy = contain . safePut . toRational
-}
-- | Ratios use the default version (0, no predecessor) and are stored as
-- numerator followed by denominator, each through 'safePut'.
instance (Typeable a)  => Version (Ratio a)
instance (Integral a, Serialize a) => Serialize (Ratio a) where
  getCopy = contain $ do num <- safeGet
                         den <- safeGet
                         return (num % den)
  putCopy r = contain $ do safePut (numerator r)
                           safePut (denominator r)
-- Unit and Bool are 'Primitive' and use their Data.Binary encoding.
instance Version () where mode = Primitive
instance Serialize () where
    getCopy = contain get
    putCopy x = contain (put x)
instance Version Bool where mode = Primitive
instance Serialize Bool where
    getCopy = contain get
    putCopy x = contain (put x)
-- | 'Either' is 'Primitive'.  It is stored as a 'Bool' flag ('True' for
-- 'Right', 'False' for 'Left') followed by the payload via 'safePut'.
instance Version (Either a b) where mode = Primitive
instance (Serialize a, Serialize b) => Serialize (Either a b) where
    getCopy = contain $ get >>= \isRight ->
                if isRight
                  then liftM Right safeGet
                  else liftM Left safeGet
    putCopy (Left l)  = contain $ do put False
                                     safePut l
    putCopy (Right r) = contain $ do put True
                                     safePut r
-- | Pairs are 'Primitive': the two components are stored back to back,
-- each through 'safePut'.
instance Version (a,b) where mode = Primitive
instance (Serialize a, Serialize b) => Serialize (a,b) where
    getCopy = contain $ do x <- safeGet
                           y <- safeGet
                           return (x, y)
    putCopy (a,b) = contain $ do safePut a
                                 safePut b
-- | Triples are 'Primitive'.  Writing stores the first component and then
-- delegates the rest to the pair instance; reading fetches the three
-- components in order.
instance Version (a,b,c) where mode = Primitive
instance (Serialize a, Serialize b, Serialize c) => Serialize (a,b,c) where
    getCopy = contain $ do x <- safeGet
                           y <- safeGet
                           z <- safeGet
                           return (x, y, z)
    putCopy (a,b,c) = contain $ do safePut a
                                   safePut (b,c)
-- | 4-tuples are 'Primitive'.  Writing stores the first component and then
-- delegates the rest to the triple instance; reading fetches the four
-- components in order.
instance Version (a,b,c,d) where mode = Primitive
instance (Serialize a, Serialize b, Serialize c, Serialize d) => Serialize (a,b,c,d) where
    getCopy = contain $ do w <- safeGet
                           x <- safeGet
                           y <- safeGet
                           z <- safeGet
                           return (w, x, y, z)
    putCopy (a,b,c,d) = contain $ do safePut a
                                     safePut (b,c,d)
-- | 5-tuples are 'Primitive'.  Writing stores the first component and then
-- delegates the rest to the 4-tuple instance; reading fetches the five
-- components in order.
instance Version (a,b,c,d,e) where mode = Primitive
instance (Serialize a, Serialize b, Serialize c, Serialize d, Serialize e) => Serialize (a,b,c,d,e) where
    getCopy = contain $ do v <- safeGet
                           w <- safeGet
                           x <- safeGet
                           y <- safeGet
                           z <- safeGet
                           return (v, w, x, y, z)
    putCopy (a,b,c,d,e) = contain $ do safePut a
                                       safePut (b,c,d,e)

-- | A 'Proxy' carries no data: nothing is written and nothing is read.
instance Version (Proxy a) where mode = Primitive
instance Typeable a => Serialize (Proxy a) where
    getCopy = contain (return Proxy)
    putCopy Proxy = contain (return ())

-- | Lists are 'Primitive'.  Layout: the length, then whatever
-- 'getSafePut' emits for the element type (its version tag, if any) once,
-- then every element.  Reading mirrors this with a single 'getSafeGet'.
instance Version [a] where mode = Primitive
instance Serialize a => Serialize [a] where
    getCopy = contain $ do
        len        <- get
        readElem   <- getSafeGet
        replicateM len readElem
    putCopy lst = contain $ do
        put (length lst)
        writeElem <- getSafePut
        mapM_ writeElem lst

-- | 'Maybe' is 'Primitive'.  It is stored as a 'Bool' flag ('True' for
-- 'Just', 'False' for 'Nothing'), followed by the payload when present.
instance Version (Maybe a) where mode = Primitive
instance Serialize a => Serialize (Maybe a) where
    getCopy = contain $ get >>= \present ->
                if present
                  then liftM Just safeGet
                  else return Nothing
    putCopy Nothing  = contain (put False)
    putCopy (Just a) = contain $ do put True
                                    safePut a

-- | Sets are 'Primitive' and are stored as the list produced by
-- 'Set.toList', written with 'safePut'.
instance Version (Set.Set a) where mode = Primitive
instance (Serialize a, Ord a) => Serialize (Set.Set a) where
    -- Rebuild with 'Set.fromList' rather than 'Set.fromAscList'.  The decoded
    -- elements are not guaranteed to be ascending under the current 'Ord'
    -- instance: the input may be corrupt, or the elements may have been
    -- migrated from an older version whose ordering differed.
    -- 'Set.fromAscList' does not check its precondition and would silently
    -- build a set with broken invariants.  For well-formed input the result
    -- is the same set.
    getCopy = contain $ fmap Set.fromList safeGet
    putCopy = contain . safePut . Set.toList

-- | Maps are 'Primitive' and are stored as the association list produced
-- by 'Map.toList', written with 'safePut'.
instance Version (Map.Map a b) where mode = Primitive
instance (Serialize a,Serialize b, Ord a) => Serialize (Map.Map a b) where
    -- Rebuild with 'Map.fromList' rather than 'Map.fromAscList'.  The decoded
    -- keys are not guaranteed to be ascending under the current 'Ord'
    -- instance: the input may be corrupt, or the keys may have been migrated
    -- from an older version whose ordering differed.  'Map.fromAscList' does
    -- not check its precondition and would silently build a map on which
    -- lookups fail.  For well-formed input the result is the same map.
    getCopy = contain $ fmap Map.fromList safeGet
    putCopy = contain . safePut . Map.toList

-- | IntMaps are 'Primitive' and are stored as the association list
-- produced by 'IntMap.toList'; reading rebuilds the map with
-- 'IntMap.fromAscList', so the stored keys are expected to be ascending.
instance Version (IntMap.IntMap a) where mode = Primitive
instance (Serialize a) => Serialize (IntMap.IntMap a) where
    getCopy = contain (liftM IntMap.fromAscList safeGet)
    putCopy m = contain (safePut (IntMap.toList m))
    
-- | Strict 'ST.Text' uses the default version (0, no predecessor).  The
-- payload is the UTF-8 encoding, stored as a strict ByteString.
-- NOTE(review): 'ST.decodeUtf8' presumably throws on invalid UTF-8 input --
-- confirm against the text version in use.
instance Version ST.Text
instance Serialize ST.Text where
    putCopy txt = putCopy (ST.encodeUtf8 txt)
    getCopy = contain (liftM ST.decodeUtf8 safeGet)
    
-- | Lazy 'LT.Text' uses the default version (0, no predecessor).  The
-- payload is the UTF-8 encoding, stored as a lazy ByteString.
-- NOTE(review): 'LT.decodeUtf8' presumably throws on invalid UTF-8 input --
-- confirm against the text version in use.
instance Version LT.Text
instance Serialize LT.Text where
    putCopy txt = putCopy (LT.encodeUtf8 txt)
    getCopy = contain (liftM LT.decodeUtf8 safeGet)


-- Data.Time instances
-- | 'Day' uses the default version.  The payload is the modified Julian
-- day number, written with its Data.Binary instance.
instance Version Day
instance Serialize Day where
  getCopy = contain (liftM ModifiedJulianDay get)
  putCopy day = contain (put (toModifiedJulianDay day))

-- | 'UniversalTime' uses the default version.  The payload is the modified
-- Julian date, written with its Data.Binary instance.
instance Version UniversalTime
instance Serialize UniversalTime where
  getCopy = contain (liftM ModJulianDate get)
  putCopy ut = contain (put (getModJulianDate ut))

-- | 'UTCTime' uses the default version.  The payload is the day followed
-- by the time within the day, each through 'safePut'.
instance Version UTCTime
instance Serialize UTCTime where
  getCopy   = contain $ do day     <- safeGet
                           dayTime <- safeGet
                           return (UTCTime day dayTime)
  putCopy d = contain $ do safePut (utctDay d)
                           safePut (utctDayTime d)

-- | 'TimeZone' uses the default version.  Its three constructor fields are
-- stored in declaration order, each through 'safePut'.
instance Version TimeZone
instance Serialize TimeZone where
  getCopy = contain $ liftM3 TimeZone safeGet safeGet safeGet
  putCopy (TimeZone minutes summerOnly nam) =
    contain $ safePut minutes >> safePut summerOnly >> safePut nam

-- | 'TimeOfDay' uses the default version.  Hour and minute are stored as
-- they are; the seconds field is converted to a 'Rational' for storage and
-- converted back with 'fromRational' on reading.
instance Version TimeOfDay
instance Serialize TimeOfDay where
  getCopy = contain $ liftM3 build safeGet safeGet safeGet
    where
      build h m s = TimeOfDay h m (fromRational s)
  putCopy (TimeOfDay h m s) =
    contain $ safePut h >> safePut m >> safePut (toRational s)

-- | 'ZonedTime' uses the default version.  The payload is the local time
-- followed by the time zone, each through 'safePut'.
instance Version ZonedTime
instance Serialize ZonedTime where
  getCopy   = contain $ do lt <- safeGet
                           tz <- safeGet
                           return (ZonedTime lt tz)
  putCopy (ZonedTime lt tz) = contain $ do safePut lt
                                           safePut tz

-- | 'LocalTime' uses the default version.  The payload is the day followed
-- by the time of day, each through 'safePut'.
instance Version LocalTime
instance Serialize LocalTime where
  getCopy   = contain $ do day <- safeGet
                           tod <- safeGet
                           return (LocalTime day tod)
  putCopy (LocalTime day tod) = contain $ do safePut day
                                             safePut tod

-- | 'DiffTime' uses the default version.  The value is converted to a
-- 'Rational' and stored through 'safePut'.
instance Version DiffTime
instance Serialize DiffTime where
  getCopy = contain (liftM fromRational safeGet)
  putCopy dt = contain (safePut (toRational dt))

-- | 'NominalDiffTime' uses the default version.  The value is converted to
-- a 'Rational' and stored with the plain Data.Binary 'put' / 'get' (not
-- 'safePut' / 'safeGet' as the 'DiffTime' instance does).  This defines the
-- stored format, so it must not be changed casually.
instance Version NominalDiffTime
instance Serialize NominalDiffTime where
  getCopy = contain (liftM fromRational get)
  putCopy ndt = contain (put (toRational ndt))

-- | 'AbsoluteTime' uses the default version.  The payload is the offset
-- from 'taiEpoch' as a 'DiffTime', stored through 'safePut'.
instance Version AbsoluteTime
instance Serialize AbsoluteTime where
  getCopy = contain $ liftM (\offset -> addAbsoluteTime offset taiEpoch) safeGet
  putCopy at = contain $ safePut (diffAbsoluteTime at taiEpoch)

--------------------------------------------------------------
-- Object serialization
--------------------------------------------------------------


-- | 'deserialize' with the result type fixed to 'Object'.
deserializeObject :: L.ByteString -> (Object, L.ByteString)
deserializeObject bytes = deserialize bytes

-- | Decodes the payload of an 'Object' as the type requested by the caller.
-- The type name stored in the object is compared with the name of the
-- requested type (as rendered by 'show' on 'typeOf'); if they differ,
-- 'error' is called.  The payload itself is decoded with 'runGet' and
-- 'safeGet'.
parseObject :: Serialize a => Object -> a
parseObject (Object objType objData)
    | objType == resType = res
    | otherwise          = error $ "Failed to parse object of type '" ++ objType ++ "'. Expected type '" ++ resType ++ "'"
    where res     = runGet safeGet objData
          resType = show (typeOf res)

-- | Wraps a value in an 'Object': the first field is the value's type name
-- (as rendered by 'show' on 'typeOf'), the second is the output of
-- 'serialize'.
mkObject :: Serialize a => a -> Object
mkObject obj = Object (show (typeOf obj)) (serialize obj)

-- | Uniform container for any serialized data.  It contains a string
-- representation of the type and the actual data serialized to a byte
-- string.  Values are built with 'mkObject' and decoded with 'parseObject'.
data Object = Object { objectType :: String -- ^ Type name, as produced by @show . typeOf@ in 'mkObject'.
                     , objectData :: L.ByteString -- ^ Payload, as produced by 'serialize' in 'mkObject'.
                     }  deriving (Typeable,Show)

-- | 'Object' uses the default version.  Its two fields are stored as a
-- pair using the plain Data.Binary encoding.
instance Version Object
instance Serialize Object where
    putCopy (Object objType objData) = contain (put (objType, objData))
    getCopy = contain $
              get >>= \(objType, objData) -> return (Object objType objData)