{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}

module Foreign.Storable.Generic () where

import           Control.Monad
import           Data.Word
import           Foreign.Ptr
import           Foreign.Storable
import           GHC.Generics

-- | A nullary constructor ('U1') carries no data: it occupies zero bytes,
-- 'peek' returns 'U1' without reading, and 'poke' writes nothing.
--
-- The alignment is 1, the neutral alignment that every address satisfies.
-- An alignment of 0 (which @alignment = sizeOf@ produced) is not a valid
-- alignment constraint.
--
-- No constraint is placed on the representation parameter @a@: it is never
-- inspected. This lets @'Rep' T x@ be used without requiring @Storable x@.
instance Storable (U1 a) where
  sizeOf _ = 0
  {-# INLINABLE sizeOf #-}
  alignment _ = 1
  {-# INLINABLE alignment #-}
  peek _ = pure U1
  {-# INLINABLE peek #-}
  poke _ _ = pure ()
  {-# INLINABLE poke #-}

-- | A generic product (@:*:@) is stored as its two fields back to back.
-- The left field sits at offset 0 and the right field immediately after it,
-- at offset @sizeOf left@. 'sizeOf' is the sum of the two field sizes.
--
-- 'alignment' is the larger of the two field alignments. For power-of-two
-- alignments this is also their least common multiple, so the field at
-- offset 0 is always naturally aligned. The previous definition took the gcd
-- of the field /sizes/, which could be 0 or a non-power-of-two (sizes 6 and
-- 9 give 3). It also under-aligned the first field (a 'Word64' followed by a
-- 'Word8' gave 1).
--
-- NOTE(review): no padding is inserted between the fields, so the right
-- field may still be misaligned (e.g. a 'Word32' following a 'Word8').
-- Adding padding would change 'sizeOf' and the byte layout, so it is left
-- as-is.
instance (Storable (f a), Storable (g a)) => Storable ((f :*: g) a) where
    sizeOf _ = sizeOf (undefined :: f a) + sizeOf (undefined :: g a)
    {-# INLINABLE sizeOf #-}
    alignment _ = max (alignment (undefined :: f a)) (alignment (undefined :: g a))
    {-# INLINABLE alignment #-}
    -- Left field first, then the right field at the end of the left one.
    peek ptr =
        (:*:) <$> peek (castPtr ptr)
              <*> peekByteOff ptr (sizeOf (undefined :: f a))
    {-# INLINABLE peek #-}
    poke ptr (l :*: r) = do
        poke (castPtr ptr) l
        pokeByteOff ptr (sizeOf (undefined :: f a)) r
    {-# INLINABLE poke #-}

-- | Tagged-union encoding of a generic sum ('L1' / 'R1').
--
-- Layout:
--
--   * bytes @[0, 4)@ hold a 'Word32' tag: @0@ for 'L1', @1@ for 'R1';
--   * the payload of whichever side is present starts at byte 4.
--
-- 'sizeOf' reserves room for the larger of the two payloads. When peeking,
-- any non-zero tag is decoded as 'R1'.
--
-- 'alignment' is the largest of the tag's alignment and both payload
-- alignments, so the 'Word32' tag at offset 0 is always naturally aligned.
-- The previous definition took the gcd of the payload /sizes/. That ignored
-- the tag and yielded 0 for a sum of two nullary constructors.
--
-- NOTE(review): the payload offset is fixed at 4 bytes, so a payload whose
-- alignment exceeds 4 (e.g. a 'Word64') is not naturally aligned inside the
-- value. Fixing that would change 'sizeOf' and the byte layout.
instance (Storable (f a), Storable (g a)) => Storable ((f :+: g) a) where
    -- The literal 4 used below is the size of the 'Word32' tag.
    sizeOf _ = 4 + max (sizeOf (undefined :: f a)) (sizeOf (undefined :: g a))
    {-# INLINABLE sizeOf #-}
    alignment _ =
        alignment (0 :: Word32)
            `max` alignment (undefined :: f a)
            `max` alignment (undefined :: g a)
    {-# INLINABLE alignment #-}
    peek ptr = do
        tag <- peek (castPtr ptr) :: IO Word32
        if tag == 0
            then L1 <$> peekByteOff ptr 4
            else R1 <$> peekByteOff ptr 4
    {-# INLINABLE peek #-}
    poke ptr (L1 val) = poke (castPtr ptr) (0 :: Word32) >> pokeByteOff ptr 4 val
    poke ptr (R1 val) = poke (castPtr ptr) (1 :: Word32) >> pokeByteOff ptr 4 val
    {-# INLINABLE poke #-}

-- | Metadata wrappers ('M1': datatype, constructor and selector info) are
-- transparent. The wrapped value is stored as-is at the same address, with
-- the same size and the same alignment.
--
-- The alignment previously equalled the /size/. For a composite that is
-- generally not a power of two (e.g. 5 for a 'Word8' and 'Word32' product)
-- and does not describe the wrapped value's real alignment requirement.
instance Storable (f a) => Storable (M1 i c f a) where
  sizeOf _ = sizeOf (undefined :: f a)
  {-# INLINABLE sizeOf #-}
  alignment _ = alignment (undefined :: f a)
  {-# INLINABLE alignment #-}
  peek ptr = M1 <$> peek (castPtr ptr)
  {-# INLINABLE peek #-}
  poke ptr (M1 val) = poke (castPtr ptr) val
  {-# INLINABLE poke #-}

-- | A leaf field ('K1') is stored exactly as its underlying value of type
-- @c@, with the same size and alignment.
--
-- The head is generic in the field type @c@ and in the representation
-- parameter @p@, which 'K1' does not use. The earlier head
-- @K1 i (f a) a@ only matched fields of shape @f a@ whose rep parameter was
-- that same @a@, so it never applied to ordinary leaves such as
-- @K1 R Int x@. Every type it did cover is still covered, with the same
-- behaviour.
instance Storable c => Storable (K1 i c p) where
  sizeOf _ = sizeOf (undefined :: c)
  {-# INLINABLE sizeOf #-}
  -- Use the field's own alignment; its size is not an alignment.
  alignment _ = alignment (undefined :: c)
  {-# INLINABLE alignment #-}
  peek ptr = K1 <$> peek (castPtr ptr)
  {-# INLINABLE peek #-}
  poke ptr (K1 val) = poke (castPtr ptr) val
  {-# INLINABLE poke #-}