{-# OPTIONS_GHC -fno-warn-orphans #-}

{-# LANGUAGE ConstrainedClassMethods #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE IncoherentInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- | You'll probably want to use this module to simply derive an 'ATSStorable'
-- instance. To do so:
--
-- > {-# LANGUAGE DeriveGeneric #-}
-- > {-# LANGUAGE DeriveDataTypeable #-}
-- > {-# LANGUAGE DeriveAnyClass #-}
-- >
-- > data MyType a = MyType a
-- >     deriving (Generic, Data, ATSStorable)
module Foreign.Storable.ATS
    ( ATSStorable (..)
    , AsCString (..)
    ) where

import Data.Bool (bool)
import Control.Composition
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Data.Data
import Data.Foldable
import qualified Data.Text as T
import qualified Data.Text.Lazy as TL
import Data.Word
import Foreign.C.String
import Foreign.C.Types
import Foreign.Marshal.Alloc
import Foreign.Marshal.Utils (copyBytes)
import Foreign.Ptr
import qualified Foreign.Storable as C
import GHC.Generics

-- Orphan 'Data' instances, needed so that types containing these C types can
-- themselves derive 'Data' (hence @-fno-warn-orphans@ above).
deriving instance Data CChar
deriving instance Data CInt

-- | String-like types that can be marshalled to a C string.
--
-- Every instance returns a freshly allocated, NUL-terminated buffer. The
-- caller owns the buffer and should release it with 'free'.
class AsCString a where
    toCString :: a -> IO CString

instance AsCString String where
    toCString = newCString

instance AsCString T.Text where
    toCString = newCString . T.unpack

instance AsCString TL.Text where
    toCString = newCString . TL.unpack

instance AsCString BS.ByteString where
    toCString = copyToCString

instance AsCString BSL.ByteString where
    toCString = copyToCString . BSL.toStrict

-- | Copy a strict 'BS.ByteString' into a freshly allocated, NUL-terminated
-- buffer.
--
-- The buffer handed out by 'BS.useAsCString' is only valid inside its
-- continuation, so returning that pointer directly would leave the caller
-- with a dangling pointer. We copy the bytes into memory that outlives the
-- call instead.
copyToCString :: BS.ByteString -> IO CString
copyToCString bytes =
    BS.useAsCStringLen bytes $ \(src, len) -> do
        dst <- mallocBytes (len + 1)
        copyBytes dst src len
        -- 'useAsCStringLen' does not guarantee a terminator; add our own.
        C.pokeByteOff dst len (0 :: CChar)
        pure dst

-- | Information about a type (and one of its constructors) that is threaded
-- through the generic marshalling code.
data ATSTypeConfig = ATSTypeConfig
    { _i         :: Word8 -- ^ Index of the particular constructor
    , n          :: Word8 -- ^ Number of constructors for the type
    , _recursive :: Bool  -- ^ Whether or not the type is self-recursive
    , _special   :: Bool  -- ^ Flag to be set when the type has exactly one self-recursive type.
    }

-- | Generic counterpart of 'C.Storable', implemented for the representation
-- types of "GHC.Generics".
class Storable' f where
    sizeOf' :: f a -> Int
    alignment' :: f a -> Int
    peek' :: ATSTypeConfig -> Ptr (f a) -> IO (f a)
    poke' :: ATSTypeConfig -> Ptr (f a) -> f a -> IO ()

    -- | 'poke'' at a byte offset from the given pointer.
    pokeByteOff' :: ATSTypeConfig -> Ptr (f a) -> Int -> f a -> IO ()
    pokeByteOff' cfg = poke' cfg .* plusPtr

    -- | 'peek'' at a byte offset from the given pointer.
    peekByteOff' :: ATSTypeConfig -> Ptr (f a) -> Int -> IO (f a)
    peekByteOff' cfg = peek' cfg .* plusPtr

-- A constructor without fields occupies no space: writing it is a no-op and
-- reading it always yields 'U1'.
instance Storable' U1 where
    sizeOf' _ = 0
    alignment' _ = 0
    poke' _ _ _ = pure ()
    peek' _ _ = pure U1

-- 'V1' has no values, so none of these methods has a meaningful result.
instance Storable' V1 where
    peek' = error "Foreign.Storable.ATS: peek' on a type with no constructors"
    alignment' = error "Foreign.Storable.ATS: alignment' on a type with no constructors"
    poke' = error "Foreign.Storable.ATS: poke' on a type with no constructors"
    sizeOf' = error "Foreign.Storable.ATS: sizeOf' on a type with no constructors"

-- Fields of a product are laid out back to back: the second component starts
-- 'sizeOf'' bytes after the first, with no padding inserted.
instance (Storable' a, Storable' b) => Storable' (a :*: b) where
    sizeOf' _ = sizeOf' (undefined :: a x) + sizeOf' (undefined :: b x)

    -- NOTE(review): 'gcd' yields the smaller of two power-of-two alignments
    -- (and treats the 0 used by 'U1' as neutral); confirm this is intended
    -- rather than the larger of the two.
    alignment' _ = gcd (alignment' (undefined :: a x)) (alignment' (undefined :: b x))

    peek' cfg ptr = do
        a <- peek' cfg (castPtr ptr)
        (a :*:) <$> peekByteOff' cfg (castPtr ptr) (sizeOf' a)

    poke' cfg ptr (a :*: b) =
        poke' cfg (castPtr ptr) a
            >> pokeByteOff' cfg (castPtr ptr) (sizeOf' a) b

-- | Counts the constructors declared before the constructor of the given
-- value, then subtracts one.
--
-- NOTE(review): despite the name this is derived from the position of the
-- value's own constructor, not from the total number of constructors, and it
-- is negative for the first constructor. Confirm the intended meaning before
-- relying on it.
numConstructors :: (Data a) => a -> Int
numConstructors x = subtract 1 . length $ takeWhile (/= ix) cs
    where
        ix = toConstr x
        cs = dataTypeConstrs (dataTypeOf x)

-- | Write the payload of one alternative of a sum type, choosing the layout
-- from the flags in the 'ATSTypeConfig'.
sumHelper :: Storable' f
          => ATSTypeConfig
          -> Ptr a -- ^ Pointer we want to write our value at
          -> f b -- ^ Value to be written
          -> IO ()
-- "Special" types: allocate the payload separately and store only a pointer
-- to it, at offset 0.
sumHelper cfg@(ATSTypeConfig _ _ _ True) ptr val = do
    bytesPtr <- mallocBytes (sizeOf' val)
    poke' cfg bytesPtr val
    C.poke (castPtr ptr) bytesPtr
-- Recursive types: allocate the payload separately and store the pointer to
-- it one byte in, after the tag byte.
sumHelper cfg@(ATSTypeConfig _ _ True False) ptr val = do
    bytesPtr <- mallocBytes (sizeOf' val)
    poke' cfg bytesPtr val
    C.pokeByteOff (castPtr ptr) 1 bytesPtr
-- Everything else: store the payload inline, one byte in, after the tag byte.
sumHelper cfg@(ATSTypeConfig _ _ False False) ptr val =
    pokeByteOff' cfg (castPtr ptr) 1 val

-- | Size in bytes of a pointer on the current platform.
ptrSize :: Int
ptrSize = C.sizeOf (undefined :: (Ptr Word8))

-- The rules for storing a type in ATS are somewhat complex, so it bears writing
-- them down here.
--
-- 1. For a type which may be recursive (including all universally quantified
-- types), the variable type must be heap-allocated.
--
-- 2. In the specific case of a (possibly recursive) sum type with two
-- constructors, one of which is empty, we may simply use a null pointer.
--
-- 3. For other types, we simply tag the constructor number and use a boxed
-- (stack-allocated) type.
--
-- Product types are a good deal simpler.
instance (Storable' a, Storable' b) => Storable' (a :+: b) where
    -- One tag byte plus one pointer.
    sizeOf' _ = 1 + ptrSize
    alignment' _ = 1

    -- Only the "recursive and special" layout can be read back: a non-zero
    -- first byte selects the left alternative, zero selects the right one.
    peek' cfg@(ATSTypeConfig _ _ True True) ptr = do
        i' <- C.peek (castPtr ptr) :: IO Word8
        bool
            (R1 <$> (peek' cfg (castPtr ptr) :: IO (b x)))
            (L1 <$> (peek' cfg (castPtr ptr) :: IO (a x)))
            (i' /= 0)
    peek' _ _ =
        error "Foreign.Storable.ATS: peek' is not implemented for this sum-type layout"

    -- Both alternatives write the 'n' field of the config first and then let
    -- 'sumHelper' place the payload.
    poke' cfg@ATSTypeConfig{} ptr (L1 val) = fold
        [ C.poke (castPtr ptr) (n cfg)
        , sumHelper cfg ptr val
        ]
    poke' cfg@ATSTypeConfig{} ptr (R1 val) = fold
        [ C.poke (castPtr ptr) (n cfg)
        , sumHelper cfg ptr val
        ]

-- A field is stored using its own 'C.Storable' instance.
instance (C.Storable a) => Storable' (K1 i a) where
    sizeOf' _ = C.sizeOf (undefined :: a)
    alignment' _ = C.alignment (undefined :: a)
    peek' _ ptr = pure K1 <*> C.peek (castPtr ptr)
    poke' _ ptr (K1 val) = C.poke (castPtr ptr) val

-- Metadata wrappers do not affect the layout; defer to the wrapped
-- representation.
instance (Storable' a) => Storable' (M1 i c a) where
    sizeOf' _ = sizeOf' (undefined :: a x)
    alignment' _ = alignment' (undefined :: a x)
    peek' cfg ptr = pure M1 <*> peek' cfg (castPtr ptr)
    poke' cfg ptr (M1 val) = poke' cfg (castPtr ptr) val

-- | Index of the constructor of a value, as reported by 'constrIndex'.
index' :: Data a => a -> Word8
index' = fromIntegral . constrIndex . toConstr

-- | 'numConstructors' narrowed to a 'Word8'.
count' :: Data a => a -> Word8
count' = fromIntegral . numConstructors

-- | Assemble the 'ATSTypeConfig' for a value.
atsCfg' :: (Recurse a, Data a) => a -> ATSTypeConfig
atsCfg' a = ATSTypeConfig (index' a) (count' a) (selfRecursive a) (isSpecial a)

-- Catch-all instance: any type with a 'Generic' representation that we know how
-- to marshal is 'C.Storable'. This is what makes the language extensions at the
-- top of the module necessary.
instance (Generic a, Storable' (Rep a), Data a, Recurse a) => C.Storable a where
    sizeOf _ = (sizeOf' . from) (undefined :: a)

    -- NOTE(review): the alignment is reported as the size of the value;
    -- confirm this is intended.
    alignment = C.sizeOf

    poke ptr x = poke' (atsCfg' x) (castPtr ptr) (from x)

    -- No value is available when reading, so the config is built from
    -- 'undefined'; this relies on the config fields being evaluated lazily.
    peek = fmap to . peek' (atsCfg' (undefined :: a)) . castPtr

-- | Generic helper computing the recursion flags stored in 'ATSTypeConfig'.
class Recurse' f where
    selfRecursive' :: f a -> Bool
    isSpecial' :: f a -> Bool

instance Recurse' V1 where
    selfRecursive' = undefined
    isSpecial' = undefined

instance Recurse' U1 where
    selfRecursive' = pure False
    isSpecial' = pure False

instance (Recurse' a, Recurse' b) => Recurse' (a :+: b) where
    selfRecursive' _ = selfRecursive' (undefined :: a x) || selfRecursive' (undefined :: b x)
    -- Special exactly when one of the two sides, but not both, is recursive.
    isSpecial' _ = selfRecursive' (undefined :: a x) /= selfRecursive' (undefined :: b x)

instance (Recurse' a, Recurse' b) => Recurse' (a :*: b) where
    selfRecursive' _ = selfRecursive' (undefined :: a x) || selfRecursive' (undefined :: b x)
    isSpecial' _ = selfRecursive' (undefined :: a x) || selfRecursive' (undefined :: b x)

instance Recurse' a => Recurse' (M1 i c a) where
    selfRecursive' (M1 val) = selfRecursive' val
    isSpecial' (M1 val) = isSpecial' val

-- Every field is treated as potentially recursive.
instance Recurse' (K1 i a) where
    selfRecursive' = pure True
    isSpecial' = pure True

-- | Recursion flags for a type, obtained from its 'Generic' representation.
class Recurse a where
    selfRecursive :: a -> Bool
    isSpecial :: a -> Bool

instance (Generic a, Recurse' (Rep a)) => Recurse a where
    selfRecursive = selfRecursive' . from
    isSpecial = isSpecial' . from

-- | Types that can be exchanged with ATS code through a pointer. Both methods
-- have defaults, so an empty instance (for example via @DeriveAnyClass@)
-- suffices.
class ATSStorable a where

    -- | Read a value at a pointer.
    readPtr :: C.Storable a => Ptr a -> IO a
    readPtr = C.peek

    -- | Write a value to a pointer. The memory is allocated with
    -- 'mallocBytes' and is not freed by this module.
    writePtr :: C.Storable a => a -> IO (Ptr a)
    writePtr val = do
        ptr <- mallocBytes (C.sizeOf val)
        C.poke ptr val
        pure ptr