{-# LANGUAGE UndecidableInstances #-} -- required for TypeError >:(

module Binrep.Generic.Put where

import GHC.Generics
import GHC.TypeLits ( TypeError )

import Binrep.Put
import Binrep.Generic.Internal
import Util.Generic

-- | Serialize any 'Generic' value by converting it to its generic
--   representation with 'from' and handing that to 'gput'.
--
-- The 'Cfg' is passed through unchanged to the generic instances.
putGeneric :: (Generic a, GPut (Rep a), Put w) => Cfg w -> a -> Builder
putGeneric cfg = gput cfg . from

-- | Generic serializer, defined over the generic representation @f@ of a
--   datatype.
class GPut f where
    -- | Serialize a representation value to a 'Builder'.
    gput
        :: Put w
        => Cfg w  -- ^ configuration threaded through every instance
        -> f p    -- ^ generic representation value to serialize
        -> Builder

-- | Empty constructor: there are no fields, so nothing is written
--   ('mempty'). The configuration is ignored.
instance GPut U1 where
    gput _ U1 = mempty

-- | Field: unwrap the 'K1' and serialize the contained value with its own
--   'Put' instance. The configuration is ignored.
instance Put c => GPut (K1 i c) where
    gput _ = put . unK1

-- | Product type fields are consecutive: the left component is serialized
--   first, then the right one, and the results are appended with '<>'.
instance (GPut l, GPut r) => GPut (l :*: r) where
    gput cfg (l :*: r) = gput cfg l <> gput cfg r

-- | Constructor sums are differentiated by a prefix tag.
--
-- All the work is delegated to 'gputsum'. The 'GetConName' constraint is kept
-- as part of the instance context even though this method body does not use
-- it directly.
instance (GPutSum (l :+: r), GetConName (l :+: r)) => GPut (l :+: r) where
    gput = gputsum

-- | Refuse to derive instance for void datatype.
--
-- The 'TypeError' constraint in the instance context makes any use of this
-- instance a compile-time error, so the method body is never expected to run;
-- it carries an explanatory message rather than a bare 'undefined'.
instance TypeError GErrRefuseVoid => GPut V1 where
    gput = error "Binrep.Generic.Put.gput: unreachable (GPut V1 is rejected via TypeError)"

-- | Any datatype, constructor or record: the metadata wrapper carries no
--   data of its own here, so strip it with 'unM1' and serialize the contents.
instance GPut f => GPut (M1 i d f) where
    gput cfg = gput cfg . unM1

--------------------------------------------------------------------------------

-- | Generic serializer for the sum part of a generic representation.
--   'GPut' delegates to this class for @(':+:')@.
class GPutSum f where
    -- | Serialize a sum representation value to a 'Builder'.
    gputsum
        :: Put w
        => Cfg w  -- ^ configuration; its 'cSumTag' is used at each constructor
        -> f a    -- ^ sum representation value to serialize
        -> Builder

-- | Nested sum: descend into whichever side ('L1' or 'R1') holds the value,
--   passing the configuration along unchanged.
instance (GPutSum l, GPutSum r) => GPutSum (l :+: r) where
    gputsum cfg (L1 l) = gputsum cfg l
    gputsum cfg (R1 r) = gputsum cfg r

-- | A single constructor inside a sum: write the constructor's tag, then
--   the constructor's contents.
instance (GPut r, Constructor c) => GPutSum (C1 c r) where
    gputsum cfg x = putTag <> putConstructor
      where
        -- The tag is obtained by applying the configured 'cSumTag' function
        -- to the constructor's name, then serializing the result with 'put'.
        putTag :: Builder
        putTag = put $ cSumTag cfg (conName' @c)

        -- The constructor's fields, with the 'C1' wrapper removed.
        putConstructor :: Builder
        putConstructor = gput cfg $ unM1 x

---

-- | Get the name of the constructor of a sum datatype, working on its
--   generic representation.
class GetConName f where
    -- | Name of the constructor that the given representation value uses.
    getConName
        :: f a     -- ^ generic representation value
        -> String

-- | Sum: look inside whichever branch ('L1' or 'R1') is present.
instance (GetConName a, GetConName b) => GetConName (a :+: b) where
    getConName (L1 x) = getConName x
    getConName (R1 x) = getConName x

-- | Single constructor: its name comes straight from the 'Constructor'
--   metadata via 'conName'.
instance Constructor c => GetConName (C1 c a) where
    getConName = conName