{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE ConstrainedClassMethods #-}
{-# LANGUAGE DefaultSignatures       #-}
{-# LANGUAGE DeriveDataTypeable      #-}
{-# LANGUAGE FlexibleContexts        #-}
{-# LANGUAGE FlexibleInstances       #-}
{-# LANGUAGE IncoherentInstances     #-}
{-# LANGUAGE InstanceSigs            #-}
{-# LANGUAGE MonoLocalBinds          #-}
{-# LANGUAGE RankNTypes              #-}
{-# LANGUAGE ScopedTypeVariables     #-}
{-# LANGUAGE StandaloneDeriving      #-}
{-# LANGUAGE TypeOperators           #-}
{-# LANGUAGE UndecidableInstances    #-}

module Foreign.Storable.ATS
    ( ATSStorable (..)
    , AsCString (..)
    , Indexed (..)
    ) where

import           Control.Composition
import qualified Data.ByteString       as BS
import qualified Data.ByteString.Lazy  as BSL
import           Data.Data
import qualified Data.Text             as T
import qualified Data.Text.Lazy        as TL
import           Data.Word
import           Foreign.C.String
import           Foreign.C.Types
import           Foreign.Marshal.Alloc
import           Foreign.Marshal.Utils (copyBytes)
import           Foreign.Ptr
import qualified Foreign.Storable      as C
import           GHC.Generics

-- Orphan 'Data' instances for the C scalar types 'CChar' and 'CInt', derived
-- standalone. The orphan-instance warning is silenced by the module's
-- OPTIONS_GHC pragma.
deriving instance Data CChar
deriving instance Data CInt

-- | Types that can be marshalled into a C string ('CString').
class AsCString a where
    -- | Produce a 'CString' from the given value, in 'IO'.
    toCString :: a -> IO CString

-- | A Haskell 'String' is marshalled with 'newCString'.
instance AsCString String where
    toCString str = newCString str

-- | Strict 'T.Text' is unpacked to a 'String' and then marshalled with
-- 'newCString'.
instance AsCString T.Text where
    toCString txt = newCString (T.unpack txt)

-- | Lazy 'TL.Text' is unpacked to a 'String' and then marshalled with
-- 'newCString'.
instance AsCString TL.Text where
    toCString txt = newCString (TL.unpack txt)

-- | Copy a strict 'BS.ByteString' into a freshly @malloc@ed, NUL-terminated
-- buffer.
--
-- The returned pointer is allocated with 'mallocBytes', so it stays valid
-- after 'toCString' returns. Returning the pointer handed out by
-- 'BS.useAsCString' directly would be a use-after-free, because that
-- temporary buffer is released as soon as the continuation finishes.
instance AsCString BS.ByteString where
    toCString bytes = BS.useAsCStringLen bytes $ \(src, len) -> do
        -- One extra byte for the terminating NUL.
        dst <- mallocBytes (len + 1)
        copyBytes dst src len
        C.pokeByteOff dst len (0 :: CChar)
        pure dst

-- | A lazy 'BSL.ByteString' is made strict and then marshalled exactly like
-- the strict instance (a @malloc@ed, NUL-terminated copy).
instance AsCString BSL.ByteString where
    toCString = toCString . BSL.toStrict

-- | Storable-like operations over type constructors @f@ (the instances in
-- this module are for "GHC.Generics" representation types).
--
-- 'peek'' and 'poke'' take an extra 'Word8' argument that is threaded
-- through nested calls; the sum-type instance writes it to the target
-- address.
class Storable' f where

    -- | Number of bytes occupied by a value; instances in this module do not
    -- inspect the argument.
    sizeOf' :: f a -> Int

    -- | Alignment of a value; instances in this module do not inspect the
    -- argument.
    alignment' :: f a -> Int

    -- | Read a value from the given address.
    peek' :: Word8 -> Ptr (f a) -> IO (f a)

    -- | Read a value located at a byte offset from the given address.
    -- The default forwards to 'peek'' with a 'Word8' argument of 0.
    peekByteOff' :: Ptr (f a) -> Int -> IO (f a)
    peekByteOff' = peek' 0 .* plusPtr

    -- | Write a value to the given address.
    poke' :: Word8 -> Ptr (f a) -> f a -> IO ()

    -- | Write a value at a byte offset from the given address.
    -- The default forwards to 'poke'' with a 'Word8' argument of 0.
    pokeByteOff' :: Ptr (f a) -> Int -> f a -> IO ()
    pokeByteOff' = poke' 0 .* plusPtr

-- | Constructors without fields ('U1') occupy no bytes.
--
-- 'poke'' is a genuine no-op and 'peek'' returns the 'U1' value itself.
-- Using @pure undefined@ for these made 'poke'' an action that throws when
-- run and made 'peek'' return a value that throws when forced, so any type
-- with a nullary constructor could not be written or read back.
instance Storable' U1 where
    sizeOf' _ = 0
    -- 1 rather than 0: zero is not a meaningful alignment.
    alignment' _ = 1
    peek' _ _ = pure U1
    poke' _ _ _ = pure ()

-- | Representation of a type with no constructors ('V1').
--
-- Such a type has no values, so it is given size 0 and alignment 1 (making
-- the pure size queries total), while reading or writing a value fails with
-- a descriptive error instead of a bare 'undefined'.
instance Storable' V1 where
    sizeOf' _ = 0
    alignment' _ = 1
    peek' _ _ = error "Foreign.Storable.ATS: cannot peek a value of a type with no constructors"
    poke' _ _ _ = error "Foreign.Storable.ATS: cannot poke a value of a type with no constructors"

-- | Products are stored back to back: the left component at the base
-- address and the right component at an offset equal to the left
-- component's 'sizeOf''. No padding is inserted between them.
--
-- NOTE(review): the alignment is the 'gcd' of the component alignments and
-- the layout is unpadded — confirm this matches the layout expected by the
-- consumer of these bytes.
instance (Storable' a, Storable' b) => Storable' (a :*: b) where
    sizeOf' _ = sizeOf' (undefined :: a x) + sizeOf' (undefined :: b x)
    alignment' _ = gcd (alignment' (undefined :: a x)) (alignment' (undefined :: b x))
    peek' n ptr =
        peek' n (castPtr ptr) >>= \l ->
            -- The right component lives directly after the left one.
            (l :*:) <$> peekByteOff' (castPtr ptr) (sizeOf' l)
    poke' n ptr (l :*: r) = do
        poke' n (castPtr ptr) l
        -- The right component is written through 'pokeByteOff'', so it does
        -- not receive @n@.
        pokeByteOff' (castPtr ptr) (sizeOf' l) r

-- | Zero-based position of a value's constructor within its data type's
-- constructor list (as reported by 'dataTypeConstrs').
--
-- The first constructor maps to 0, the second to 1, and so on. Subtracting
-- one from this count would map the first constructor to -1, which wraps to
-- 255 when converted to a 'Word8'.
--
-- 'dataTypeConstrs' is only defined for algebraic data types, so this is
-- only meaningful for values of such types.
asIndex :: (Data a) => a -> Int
asIndex x = length (takeWhile (/= wanted) constructors)
    where wanted = toConstr x
          constructors = dataTypeConstrs (dataTypeOf x)

-- | Write @val@ into a newly allocated buffer and store that buffer's
-- address at @ptr@.
--
-- The buffer comes from 'mallocBytes' and is sized by 'sizeOf''; this
-- function never frees it.
--
-- NOTE(review): the address is stored at offset 0 of @ptr@, the same
-- location the sum-type instance writes its tag byte to — confirm the
-- intended layout.
sumHelper :: Storable' f => Word8 -> Ptr a -> f b -> IO ()
sumHelper n ptr val =
    mallocBytes (sizeOf' val) >>= \payload ->
        poke' n payload val >> C.poke (castPtr ptr) payload

-- | Size in bytes of a pointer, as reported by 'C.sizeOf' for @Ptr Word8@.
ptrSize :: Int
ptrSize = C.sizeOf (nullPtr :: Ptr Word8)

-- | Sum types are written as a tag byte followed by an indirection: the
-- payload is stored in a separately allocated buffer (see 'sumHelper') whose
-- address is written to the target.
--
-- The reported size is one byte plus the size of a pointer. Both branches
-- are handled identically; the tag @n@ is supplied by the caller.
--
-- Reading is not implemented; 'peek'' fails with a descriptive error rather
-- than a bare 'undefined'.
--
-- NOTE(review): the tag byte and the payload address are both written
-- starting at offset 0 of @ptr@, so the address overwrites the tag, even
-- though 'sizeOf'' reserves a separate byte for it — confirm the intended
-- layout before relying on the tag.
instance (Storable' a, Storable' b) => Storable' (a :+: b) where
    sizeOf' _ = 1 + ptrSize
    alignment' _ = 1
    peek' _ _ = error "Foreign.Storable.ATS: peek' is not implemented for sum types"
    poke' n ptr (L1 val) = C.poke (castPtr ptr) n >> sumHelper n ptr val
    poke' n ptr (R1 val) = C.poke (castPtr ptr) n >> sumHelper n ptr val

-- | A field ('K1') is stored using the field type's own 'C.Storable'
-- instance; the 'Word8' argument is ignored.
instance (C.Storable a) => Storable' (K1 i a) where
    sizeOf' _ = C.sizeOf (undefined :: a)
    alignment' _ = C.alignment (undefined :: a)
    peek' _ ptr = K1 <$> C.peek (castPtr ptr)
    poke' _ ptr field = C.poke (castPtr ptr) (unK1 field)

-- | Metadata wrappers ('M1') are transparent: every operation is forwarded
-- to the wrapped representation, at the same address and with the same
-- 'Word8' argument.
instance (Storable' a) => Storable' (M1 i c a) where
    sizeOf' _ = sizeOf' (undefined :: a x)
    alignment' _ = alignment' (undefined :: a x)
    peek' n ptr = M1 <$> peek' n (castPtr ptr)
    poke' n ptr wrapped = poke' n (castPtr ptr) (unM1 wrapped)

-- | Catch-all instance: every type with a 'C.Storable' instance gets the
-- constant index 1, regardless of the value.
instance C.Storable a => Indexed a where
    index :: a -> Word8
    index _ = 1

-- | Types whose values can be mapped to a 'Word8' index.
class Indexed a where
    -- | The index of a value. For types with a 'Data' instance the default
    -- is the result of 'asIndex', converted to 'Word8' with 'fromIntegral'.
    index :: a -> Word8
    default index :: Data a => a -> Word8
    index x = fromIntegral (asIndex x)

-- | Generic 'C.Storable' instance: a value is converted to its
-- "GHC.Generics" representation with 'from' and written with 'poke'';
-- reading goes through 'peek'' and 'to'.
--
-- NOTE(review): 'alignment' is defined as 'C.sizeOf', so the reported
-- alignment equals the size rather than anything derived from 'alignment''
-- — confirm this is intended.
instance (Generic a, Storable' (Rep a), Data a) => C.Storable a where
    sizeOf _ = sizeOf' (from (undefined :: a))
    alignment x = C.sizeOf x
    -- The value's 'index' is passed down as the 'Word8' argument of 'poke''.
    poke ptr x = poke' (index x) (castPtr ptr) (from x)
    -- 'undefined' is passed as the 'Word8' argument of 'peek''; any
    -- 'peek'' implementation that forces it will throw.
    peek ptr = to <$> peek' undefined (castPtr ptr)

-- | Convenience operations for moving 'C.Storable' values through pointers.
class ATSStorable a where

    -- | Read a value from a pointer; the default is 'C.peek'.
    readPtr :: C.Storable a => Ptr a -> IO a
    readPtr ptr = C.peek ptr

    -- | Allocate 'C.sizeOf' bytes with 'mallocBytes', write the value there
    -- and return the pointer. The default implementation does not free the
    -- allocation.
    writePtr :: C.Storable a => a -> IO (Ptr a)
    writePtr val =
        mallocBytes (C.sizeOf val) >>= \ptr ->
            ptr <$ C.poke ptr val