module TH.Derive.Storable
( makeStorableInst
) where
import Control.Monad
import Data.Data (Data, gmapT)
import Data.Generics.Aliases (extT)
import Data.List (find)
import qualified Data.Map as M
import Data.Maybe (fromMaybe)
import Data.Primitive.Types (Prim(..))
import Data.Typeable
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import TH.ReifyDataType
import TH.Utilities
import TH.Derive.Internal
-- | 'Deriver' instance for 'Storable': the first argument is ignored and all
-- of the work is delegated to 'makeStorableInst'.
instance Deriver (Storable a) where
    runDeriver = const makeStorableInst
-- | Entry point for generating a 'Storable' instance.
--
-- Takes the instance context and the instance head (which
-- 'expectTyCon1' checks has the shape @Storable t@). It reifies the
-- constructors of @t@ and hands them to 'makeStorableImpl'.
makeStorableInst :: Cxt -> Type -> Q [Dec]
makeStorableInst preds headTy =
    expectTyCon1 ''Storable headTy
        >>= reifyDataTypeSubstituted
        >>= makeStorableImpl preds headTy . dtCons
-- | Build a 'Storable' instance declaration for @headTy@ from the given
-- constructors.
--
-- Memory layout used by the generated methods:
--
-- * With two or more constructors, a tag (the constructor's index) is
--   stored at offset 0. Its type is the first of 'Word8', 'Word16',
--   'Word32' and 'Word64' whose 'maxBound' is at least the constructor
--   count. With zero or one constructor, no tag is stored.
-- * The constructor's fields follow the tag. They are packed back to back
--   in declaration order with no padding: field @k@ lives at
--   @tagSize + sum of sizeOf of fields 0 .. k-1@.
--
-- Consequently @sizeOf@ is @tagSize@ plus the largest per-constructor sum
-- of field sizes, and @alignment@ is always 1.
--
-- Calls 'error' at splice time if no tag type is large enough for the
-- constructor count.
makeStorableImpl :: Cxt -> Type -> [DataCon] -> Q [Dec]
makeStorableImpl cxt headTy cons = do
    -- Fields are packed without padding, so alignment is reported as 1.
    alignmentMethod <- [| 1 |]
    sizeOfMethod <- sizeExpr
    peekMethod <- peekExpr
    pokeMethod <- pokeExpr
    let methods =
            [ FunD (mkName "alignment") [Clause [WildP] (NormalB alignmentMethod) []]
            , FunD (mkName "sizeOf") [Clause [WildP] (NormalB sizeOfMethod) []]
            , FunD (mkName "peek") [Clause [VarP ptrName] (NormalB peekMethod) []]
            , FunD (mkName "poke") [Clause [VarP ptrName, VarP valName] (NormalB pokeMethod) []]
            ]
    -- NOTE(review): the 3-argument InstanceD matches template-haskell < 2.11;
    -- newer versions take a leading 'Maybe Overlap' argument. Confirm
    -- against the template-haskell version this package builds with.
    return [InstanceD cxt headTy methods]
  where
    -- Smallest tag type whose bound covers the constructor count, and its
    -- size in bytes. Tags written by poke range over [0, length cons - 1].
    (tagType, _, tagSize) =
        fromMaybe (error "Too many constructors") $
        find (\(_, maxN, _) -> maxN >= toInteger (length cons)) tagTypes
    -- (tag type, largest value it can hold, size in bytes). The bounds are
    -- 'Integer' so that 'Word64' (and 'Word32' on 32-bit targets) do not
    -- wrap to a negative 'Int' and silently drop out of the search.
    tagTypes :: [(Name, Integer, Int)]
    tagTypes =
        [ (''(), 1, 0)
        , (''Word8, toInteger (maxBound :: Word8), 1)
        , (''Word16, toInteger (maxBound :: Word16), 2)
        , (''Word32, toInteger (maxBound :: Word32), 4)
        , (''Word64, toInteger (maxBound :: Word64), 8)
        ]
    valName = mkName "val"
    tagName = mkName "tag"
    ptrName = mkName "ptr"
    fName ix = mkName ("f" ++ show ix)
    ptrExpr = varE ptrName
    -- sizeOf: the tag bytes plus the largest per-constructor field total.
    -- The tag must be counted because every field is stored after it.
    -- The leading 0 keeps 'maximum' total for types with no constructors.
    sizeExpr = [| tagSize + maximum (0 : $(listE (map conSize cons))) |]
    -- Sum of the sizes of one constructor's fields.
    conSize (DataCon _ _ _ fields) =
        appE (varE 'sum) (listE [sizeOfExpr ty | (_, ty) <- fields])
    -- Body of peek: read the tag (only for 2+ constructors), then the fields.
    peekExpr = case cons of
        [] -> [| error ("Attempting to peek type with no constructors (" ++ $(lift (pprint headTy)) ++ ")") |]
        [con] -> peekCon con
        _ -> doE
            [ bindS (varP tagName) [| peek (castPtr $(ptrExpr)) |]
            , noBindS (caseE (sigE (varE tagName) (conT tagType))
                        (map peekMatch (zip [0..] cons) ++ [peekErr]))
            ]
    peekMatch (ix, con) = match (litP (IntegerL ix)) (normalB (peekCon con)) []
    peekErr = match wildP (normalB [| error ("Found invalid tag while peeking (" ++ $(lift (pprint headTy)) ++ ")") |]) []
    -- Applicatively rebuild one constructor from its fields read at their
    -- computed offsets.
    peekCon (DataCon cname _ _ fields) =
        letE (offsetDecls fields) $
        case fields of
            [] -> [| pure $(conE cname) |]
            (_:fields') ->
                foldl (\acc (ix, _) -> [| $(acc) <*> $(peekOffset ix) |] )
                      [| $(conE cname) <$> $(peekOffset 0) |]
                      (zip [1..] fields')
    peekOffset ix = [| peek (castPtr (plusPtr $(ptrExpr) $(varE (offset ix)))) |]
    -- Body of poke: one case alternative per constructor.
    pokeExpr = caseE (varE valName) (map pokeMatch (zip [0..] cons))
    pokeMatch :: (Int, DataCon) -> MatchQ
    pokeMatch (ixcon, DataCon cname _ _ fields) =
        match (conP cname (map varP (map fName ixs)))
              (normalB (doE (tagPokes ++ offsetLet ++ fieldPokes)))
              []
      where
        -- The tag is only written when there is more than one constructor.
        tagPokes = case cons of
            (_:_:_) -> [noBindS [| poke (castPtr $(ptrExpr)) (ixcon :: $(conT tagType)) |]]
            _ -> []
        offsetLet
            | null ixs = []
            | otherwise = [letS (offsetDecls fields)]
        fieldPokes = map (noBindS . pokeField) ixs
        ixs = map fst (zip [0..] fields)
    pokeField ix = [| poke (castPtr (plusPtr $(ptrExpr)
                                             $(varE (offset ix))))
                           $(varE (fName ix)) |]
    -- Declarations binding offset0 .. offset(n-1) for a constructor with n
    -- fields: offset0 = tagSize and offsetK = sizeOf field(K-1) + offset(K-1).
    -- 'init' drops the final entry, which would point one past the last field.
    offsetDecls fields =
        init $
        map (\(ix, expr) -> valD (varP (offset ix)) (normalB expr) []) $
        ((0, [| tagSize |]) :) $
        map (\(ix, (_, ty)) -> (ix, offsetExpr ix ty)) $
        zip [1..] fields
      where
        offsetExpr ix ty = [| $(sizeOfExpr ty) + $(varE (offset (ix - 1))) |]
    -- 'sizeOf' at a field's type; the argument is never evaluated.
    sizeOfExpr ty = [| $(varE 'sizeOf) (error "sizeOf evaluated its argument" :: $(return ty)) |]
    offset ix = mkName ("offset" ++ show ix)