{-# LANGUAGE UndecidableInstances #-}

module Dahdit.Generic
  ( ViaGeneric (..)
  , ViaStaticGeneric (..)
  )
where

import Control.Applicative (liftA2)
import Dahdit.Binary (Binary (..))
import Dahdit.Free (Get, Put)
import Dahdit.Funs (putStaticHint)
import Dahdit.Nums (Word16LE, Word32LE)
import Dahdit.Sizes (ByteCount, ByteSized (..), StaticByteSized (..))
import Data.Bits (Bits (..))
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Word (Word8)
import GHC.Generics (C1, Generic (..), K1 (..), M1 (..), U1 (..), (:*:) (..), (:+:) (..))

-- | Newtype wrapper for deriving instances from a type's 'Generic'
-- representation.
--
-- Use: deriving (ByteSized, Binary) via (ViaGeneric Foo)
--
-- Fields are encoded in declaration order using each field's own 'Binary'
-- instance; sum types are prefixed with a constructor tag.
newtype ViaGeneric a = ViaGeneric {unViaGeneric :: a}

-- | Newtype wrapper for deriving instances from a type's 'Generic'
-- representation when its encoded size is the same for every value.
--
-- Use: deriving (ByteSized, StaticByteSized, Binary) via (ViaStaticGeneric Foo)
--
-- Every field must be 'StaticByteSized'. This module defines no
-- 'GStaticByteSized' instance for sums, so sum types are rejected at
-- compile time.
newtype ViaStaticGeneric a = ViaStaticGeneric {unViaStaticGeneric :: a}

-- ByteSized:

-- | Generic worker behind the 'ByteSized' instances of 'ViaGeneric' and
-- 'ViaStaticGeneric': the number of bytes a 'GHC.Generics' representation
-- value encodes to.
class GByteSized f where
  gbyteSize
    :: f a
    -- ^ a generic representation value
    -> ByteCount

-- Unit: a constructor without fields encodes to zero bytes.
instance GByteSized U1 where
  gbyteSize _ = 0

-- Product: both halves are encoded back to back, so their sizes add up.
instance (GByteSized a, GByteSized b) => GByteSized (a :*: b) where
  gbyteSize (x :*: y) = gbyteSize x + gbyteSize y

-- Metadata: datatype/constructor/selector wrappers are not encoded
-- themselves, so the size is that of the wrapped representation.
instance GByteSized a => GByteSized (M1 i c a) where
  gbyteSize = gbyteSize . unM1

-- | Size of the chosen alternative of a sum, /excluding/ the constructor tag.
--
-- GHC represents sums as nested ':+:' trees, but the encoder writes a single
-- tag for the whole sum (at the 'C1' leaf, via 'GSumBinary'). Nested ':+:'
-- nodes therefore must not add tag bytes of their own; they only descend to
-- the selected constructor.
class GSumByteSized (f :: Type -> Type) where
  gsumByteSize :: f a -> ByteCount

instance (GSumByteSized a, GSumByteSized b) => GSumByteSized (a :+: b) where
  gsumByteSize (L1 a) = gsumByteSize a
  gsumByteSize (R1 b) = gsumByteSize b

-- A leaf of the sum tree: the constructor's fields.
instance GByteSized a => GSumByteSized (C1 c a) where
  gsumByteSize = gbyteSize . unM1

-- Sum: one constructor tag (1, 2 or 4 bytes depending on the total number of
-- constructors; see 'taggedBytes') followed by the chosen constructor's
-- fields. Only the outermost ':+:' counts the tag; inner sum nodes are walked
-- with 'GSumByteSized' so the tag is not counted once per nesting level.
instance (GSumByteSized a, GSumByteSized b, SumSize a, SumSize b) => GByteSized (a :+: b) where
  gbyteSize s = sumSizeBytes s + gsumByteSize s

-- Field: delegate to the field type's own 'ByteSized' instance.
instance ByteSized a => GByteSized (K1 i a) where
  gbyteSize = byteSize . unK1

-- | Size of the wrapped value, computed from its generic representation.
instance (Generic t, GByteSized (Rep t)) => ByteSized (ViaGeneric t) where
  byteSize = gbyteSize . from . unViaGeneric

-- | Size of the wrapped value, computed per value from its generic
-- representation (the type-level size is provided by 'StaticByteSized').
instance (Generic t, GByteSized (Rep t)) => ByteSized (ViaStaticGeneric t) where
  byteSize = gbyteSize . from . unViaStaticGeneric

-- StaticByteSized:

-- | Generic worker behind 'StaticByteSized' for 'ViaStaticGeneric'. It gives
-- the byte count of a representation type from a 'Proxy' alone, so the
-- result depends only on the type, never on a particular value.
class GByteSized f => GStaticByteSized (f :: Type -> Type) where
  gstaticByteSize
    :: Proxy f
    -- ^ only the type is inspected
    -> ByteCount

-- Unit: no fields, so no bytes.
instance GStaticByteSized U1 where
  gstaticByteSize _ = 0

-- Product: the static sizes of both halves add up.
instance (GStaticByteSized a, GStaticByteSized b) => GStaticByteSized (a :*: b) where
  gstaticByteSize _ =
    gstaticByteSize (Proxy :: Proxy a) + gstaticByteSize (Proxy :: Proxy b)

-- Metadata: wrappers add no bytes; use the wrapped representation's size.
instance GStaticByteSized a => GStaticByteSized (M1 i c a) where
  gstaticByteSize _ = gstaticByteSize (Proxy :: Proxy a)

-- Field: delegate to the field type's own 'StaticByteSized' instance.
instance StaticByteSized a => GStaticByteSized (K1 i a) where
  gstaticByteSize _ = staticByteSize (Proxy :: Proxy a)

-- | The fixed encoded size of a 'ViaStaticGeneric' value, computed from the
-- type's generic representation.
instance (Generic t, GStaticByteSized (Rep t)) => StaticByteSized (ViaStaticGeneric t) where
  staticByteSize _ = gstaticByteSize (Proxy :: Proxy (Rep t))

-- Binary:

-- | Generic worker behind the 'Binary' instances of 'ViaGeneric' and
-- 'ViaStaticGeneric': decodes and encodes a 'GHC.Generics' representation.
class GBinary (f :: Type -> Type) where
  -- | Decoder for a representation value.
  gget
    :: Get (f a)

  -- | Encoder for a representation value.
  gput
    :: f a
    -> Put

-- Unit: nothing is read or written.
instance GBinary U1 where
  gget = pure U1
  gput _ = pure ()

-- Product: the left half is read/written first, then the right half.
instance (GBinary a, GBinary b) => GBinary (a :*: b) where
  gget = liftA2 (:*:) gget gget
  gput (x :*: y) = gput x *> gput y

-- Metadata: wrappers are not serialized; (un)wrap around the inner codec.
instance GBinary a => GBinary (M1 i c a) where
  gget = fmap M1 gget
  gput = gput . unM1

-- Field: delegate to the field type's own 'Binary' instance.
instance Binary a => GBinary (K1 i a) where
  gget = fmap K1 get
  gput = put . unK1

-- | Serialize the wrapped value through its generic representation.
instance (Generic t, GBinary (Rep t)) => Binary (ViaGeneric t) where
  get = fmap (ViaGeneric . to) gget
  put = gput . from . unViaGeneric

-- | Serialize the wrapped value through its generic representation.
--
-- Encoding goes through 'putStaticHint', which requires the
-- 'StaticByteSized' instance (hence the 'GStaticByteSized' constraint);
-- presumably this lets the writer use the statically known size — see
-- "Dahdit.Funs".
instance (Generic t, GStaticByteSized (Rep t), GBinary (Rep t)) => Binary (ViaStaticGeneric t) where
  get = fmap (ViaStaticGeneric . to) gget
  put = putStaticHint (gput . from . unViaStaticGeneric)

-- Everything that follows is borrowed from the binary package, which
-- borrows from the cereal package!

-- The following GBinary instance for sums has support for serializing
-- types with up to 2^32-1 constructors (the constructor count is kept in
-- a Word32LE). It will use the minimal number of bytes needed to encode
-- the constructor. For example, when a type has 2^8 constructors or fewer
-- it will use a single byte to encode the constructor. If it has 2^16
-- constructors or fewer it will use two bytes, and otherwise four.

-- | Sum types: a constructor tag followed by the chosen constructor's fields.
--
-- The tag type is 'Word8', 'Word16LE' or 'Word32LE', the narrowest one that
-- can hold every tag @0 .. size - 1@, where @size@ is the total number of
-- constructors.
instance
  ( GSumBinary a
  , GSumBinary b
  , SumSize a
  , SumSize b
  )
  => GBinary (a :+: b)
  where
  gget
    | size - 1 <= fromIntegral (maxBound :: Word8) =
        (get :: Get Word8) >>= checkGetSum (fromIntegral size)
    | size - 1 <= fromIntegral (maxBound :: Word16LE) =
        (get :: Get Word16LE) >>= checkGetSum (fromIntegral size)
    | size - 1 <= (maxBound :: Word32LE) =
        (get :: Get Word32LE) >>= checkGetSum size
    -- Defensive only: every Word32LE value satisfies the previous guard.
    | otherwise = sizeError "decode" size
   where
    -- Total number of constructors of the whole sum.
    size :: Word32LE
    size = unTagged (sumSize :: Tagged (a :+: b))
  gput
    | size - 1 <= fromIntegral (maxBound :: Word8) =
        putSum (0 :: Word8) (fromIntegral size)
    | size - 1 <= fromIntegral (maxBound :: Word16LE) =
        putSum (0 :: Word16LE) (fromIntegral size)
    | size - 1 <= (maxBound :: Word32LE) =
        putSum (0 :: Word32LE) size
    -- Defensive only: every Word32LE value satisfies the previous guard.
    | otherwise = sizeError "encode" size
   where
    -- Total number of constructors of the whole sum.
    size :: Word32LE
    size = unTagged (sumSize :: Tagged (a :+: b))

-- | Abort with an 'error' saying that a sum type with @size@ constructors
-- cannot be handled by operation @s@ (e.g. @"decode"@, @"encode"@, @"size"@).
sizeError :: Show size => String -> size -> error
sizeError s size =
  error ("Can't " ++ s ++ " a type with " ++ show size ++ " constructors")

-- | Validate a decoded constructor tag, then decode the selected alternative.
--
-- @checkGetSum size code@ fails in 'Get' unless @code < size@; otherwise it
-- hands off to 'getSum' with the full constructor range.
checkGetSum
  :: (Ord word, Num word, Bits word, GSumBinary f)
  => word
  -- ^ number of constructors
  -> word
  -- ^ decoded constructor tag
  -> Get (f a)
checkGetSum size code
  | code < size = getSum code size
  | otherwise = fail "Unknown encoding for constructor"
{-# INLINE checkGetSum #-}

-- | Decoding/encoding of the alternatives of a generic sum. Both methods work
-- on a subtree of the ':+:' tree, given a constructor tag ("code") and the
-- number of constructors in that subtree ("size").
class GSumBinary f where
  -- | @getSum code size@: decode the alternative selected by @code@, counted
  -- relative to the start of this subtree of @size@ constructors.
  getSum
    :: (Ord word, Num word, Bits word)
    => word
    -> word
    -> Get (f a)

  -- | @putSum code size x@: encode @x@, where this subtree's first
  -- constructor has tag @code@ and the subtree spans @size@ constructors.
  putSum
    :: (Num w, Bits w, Binary w)
    => w
    -> w
    -> f a
    -> Put

-- | Inner node of the sum tree. The left branch covers the first
-- @size `shiftR` 1@ constructors and the right branch the rest.
--
-- Decoding picks a branch by comparing the tag against that split. Encoding
-- offsets the right branch's starting tag by the left branch's size. No bytes
-- are written here; the tag is written once, at the 'C1' leaf.
instance (GSumBinary a, GSumBinary b) => GSumBinary (a :+: b) where
  getSum !code !size
    | code < sizeL = L1 <$> getSum code sizeL
    | otherwise = R1 <$> getSum (code - sizeL) sizeR
   where
    sizeL = size `shiftR` 1
    sizeR = size - sizeL
  putSum !code !size s = case s of
    L1 x -> putSum code sizeL x
    R1 x -> putSum (code + sizeL) sizeR x
   where
    sizeL = size `shiftR` 1
    sizeR = size - sizeL

-- | Leaf of the sum tree: a single constructor.
--
-- Decoding ignores the tag and size (the tag was already consumed and
-- checked). Encoding writes the tag, then the constructor's fields.
instance GBinary a => GSumBinary (C1 c a) where
  getSum _ _ = gget
  putSum !code _ x = put code <> gput x

-- | Number of constructors in a (sub)tree of a generic sum representation,
-- computed from the type alone.
class SumSize (f :: Type -> Type) where
  sumSize
    :: Tagged f

-- | A constructor count tagged at the type level with the sum representation
-- @s@ it belongs to, so 'SumSize' instances can be selected without a value.
newtype Tagged (s :: Type -> Type) = Tagged {unTagged :: Word32LE}

-- An inner sum node has as many constructors as both branches combined.
instance (SumSize a, SumSize b) => SumSize (a :+: b) where
  sumSize =
    Tagged (unTagged (sumSize :: Tagged a) + unTagged (sumSize :: Tagged b))

-- A constructor leaf counts as exactly one constructor.
instance SumSize (C1 c a) where
  sumSize = Tagged 1

-- | Constructor count of a sum representation. The argument only fixes the
-- type @f@; its value is ignored.
sumSizeFor :: SumSize f => f a -> Tagged f
sumSizeFor = const sumSize

-- | Width in bytes of the constructor tag for a sum with the given number of
-- constructors. It uses the same thresholds the 'GBinary' sum instance uses
-- to choose between 'Word8', 'Word16LE' and 'Word32LE' tags.
taggedBytes :: Tagged f -> ByteCount
taggedBytes (Tagged size)
  | size - 1 <= fromIntegral (maxBound :: Word8) = 1
  | size - 1 <= fromIntegral (maxBound :: Word16LE) = 2
  | size - 1 <= (maxBound :: Word32LE) = 4
  -- Defensive only: every Word32LE value satisfies the previous guard.
  | otherwise = sizeError "size" size

-- | Width in bytes of the constructor tag written for a value of the sum
-- representation @f@ (the value itself only fixes the type).
sumSizeBytes :: SumSize f => f a -> ByteCount
sumSizeBytes = taggedBytes . sumSizeFor