{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}

-- | Implementation of a 'Storable' deriver for data types. This works for
-- any non-recursive datatype which has 'Storable' fields.
--
-- Most users won't need to import this module directly. Instead, use
-- 'derive' / 'Deriving' to create 'Storable' instances.
module TH.Derive.Storable
    ( makeStorableInst
    ) where

import           Control.Monad
import           Data.Data (Data, gmapT)
import           Data.Generics.Aliases (extT)
import           Data.List (find)
import qualified Data.Map as M
import           Data.Maybe (fromMaybe)
import           Data.Primitive.Types (Prim(..))
import           Data.Typeable
import           Data.Word
import           Foreign.Ptr
import           Foreign.Storable
import           Language.Haskell.TH
import           Language.Haskell.TH.Syntax
import           TH.ReifyDataType
import           TH.Utilities
import           TH.Derive.Internal

-- | Hooks 'makeStorableInst' into the deriving machinery, so that asking
-- for a @Storable a@ instance runs it.
instance Deriver (Storable a) where
    runDeriver _ = makeStorableInst

-- | Implementation used for 'runDeriver'.
--
-- Takes the instance context and the instance head (expected to have the
-- form @Storable T@), reifies the constructors of @T@ and generates the
-- instance declaration from them.
makeStorableInst :: Cxt -> Type -> Q [Dec]
makeStorableInst cxt ty = do
    -- Presumably yields the single argument of the 'Storable' head and
    -- fails for any other shape; see 'expectTyCon1'.
    argTy <- expectTyCon1 ''Storable ty
    dt <- reifyDataTypeSubstituted argTy
    makeStorableImpl cxt ty (dtCons dt)

-- TODO: recursion check? At least document that this could in some
-- cases work, but produce a bogus instance.

-- | Generates a 'Storable' instance with the given context and head for a
-- datatype with the given constructors.
--
-- The generated memory layout is:
--
-- * When there are two or more constructors, a tag at offset 0 holding
--   the zero-based index of the constructor. Its type is the first entry
--   of @tagTypes@ large enough for the constructor count. Types with
--   fewer than two constructors store no tag.
--
-- * The fields of the constructor, stored back to back directly after
--   the tag, each taking 'sizeOf' bytes of its own type. No padding is
--   inserted.
--
-- 'sizeOf' is the tag size plus the largest total field size of any
-- constructor, and 'alignment' is always 1.
makeStorableImpl :: Cxt -> Type -> [DataCon] -> Q [Dec]
makeStorableImpl cxt headTy cons = do
    -- Since this instance doesn't pay attention to alignment, we
    -- just say alignment doesn't matter.
    alignmentMethod <- [| 1 |]
    sizeOfMethod <- sizeExpr
    peekMethod <- peekExpr
    pokeMethod <- pokeExpr
    let methods =
            [ FunD (mkName "alignment") [Clause [WildP] (NormalB alignmentMethod) []]
            , FunD (mkName "sizeOf") [Clause [WildP] (NormalB sizeOfMethod) []]
            , FunD (mkName "peek") [Clause [VarP ptrName] (NormalB peekMethod) []]
            , FunD (mkName "poke") [Clause [VarP ptrName, VarP valName] (NormalB pokeMethod) []]
            ]
    return [InstanceD cxt headTy methods]
  where
    -- NOTE: Much of the code here resembles code in store for deriving
    -- Store instances. Changes here may be relevant there as well.

    -- Choose a tag size large enough for this constructor count.
    (tagType, _, tagSize) =
        fromMaybe (error "Too many constructors") $
        find (\(_, maxN, _) -> maxN >= length cons) tagTypes

    -- Candidate tags as (type name, largest supported constructor count,
    -- size in bytes), smallest first. The first entry covers types with at
    -- most one constructor; no tag is read or written for those, so its
    -- name is never spliced as a type.
    --
    -- NOTE(review): where 'Int' is 64 bits wide the 'Word64' bound wraps
    -- to a negative number, so that entry can never be selected.
    tagTypes :: [(Name, Int, Int)]
    tagTypes =
        [ ('(), 1, 0)
        , (''Word8, fromIntegral (maxBound :: Word8), 1)
        , (''Word16, fromIntegral (maxBound :: Word16), 2)
        , (''Word32, fromIntegral (maxBound :: Word32), 4)
        , (''Word64, fromIntegral (maxBound :: Word64), 8)
        ]

    -- Names of the variables bound in the generated method bodies.
    valName = mkName "val"
    tagName = mkName "tag"
    ptrName = mkName "ptr"
    fName ix = mkName ("f" ++ show ix)
    ptrExpr = varE ptrName

    -- Expression used for the definition of sizeOf: the tag size plus the
    -- largest total field size of any constructor.
    --
    -- The tag has to be included because peek and poke place the first
    -- field at offset tagSize; without it, a buffer of sizeOf bytes is too
    -- small for any type that stores a tag.
    --
    -- The leading 0 keeps 'maximum' total for types without constructors.
    -- Field sizes are non-negative, so it does not affect other types.
    sizeExpr =
        [| tagSize + maximum (0 : $(listE (map conSizeExpr cons))) |]
      where
        conSizeExpr (DataCon _ _ _ fields) =
            appE (varE 'sum) (listE [sizeOfExpr ty | (_, ty) <- fields])

    -- Expression used for the definition of peek.
    peekExpr = case cons of
        [] -> [| error ("Attempting to peek type with no constructors (" ++ $(lift (pprint headTy)) ++ ")") |]
        -- A single constructor has no tag to dispatch on.
        [con] -> peekCon con
        _ -> doE
            [ bindS (varP tagName) [| peek (castPtr $(ptrExpr)) |]
            , noBindS (caseE (sigE (varE tagName) (conT tagType))
                             (map peekMatch (zip [0..] cons) ++ [peekErr]))
            ]

    -- Case alternative which reads the constructor with the given index.
    peekMatch (ix, con) = match (litP (IntegerL ix)) (normalB (peekCon con)) []

    -- Fallback alternative for tag values that match no constructor.
    peekErr = match wildP (normalB [| error ("Found invalid tag while peeking (" ++ $(lift (pprint headTy)) ++ ")") |]) []

    -- Reads all the fields of one constructor and applies the constructor
    -- to them, in applicative style.
    peekCon (DataCon cname _ _ fields) =
        letE (offsetDecls fields) $
        case fields of
            [] -> [| pure $(conE cname) |]
            (_:fields') ->
                foldl (\acc (ix, _) -> [| $(acc) <*> $(peekOffset ix) |] )
                      [| $(conE cname) <$> $(peekOffset 0) |]
                      (zip [1..] fields')

    -- Reads the field whose offset is stored in the variable "offset<ix>".
    peekOffset ix = [| peek (castPtr (plusPtr $(ptrExpr) $(varE (offset ix)))) |]

    -- Expression used for the definition of poke.
    pokeExpr = caseE (varE valName) (map pokeMatch (zip [0..] cons))

    -- Case alternative which writes one constructor: its tag (if the type
    -- stores one) followed by each of its fields.
    pokeMatch :: (Int, DataCon) -> MatchQ
    pokeMatch (ixcon, DataCon cname _ _ fields) =
        match (conP cname (map (varP . fName) ixs))
              (normalB (doE stmts))
              []
      where
        -- A do block must contain at least one statement. The sole
        -- constructor of a type has no tag, so if it also has no fields
        -- there is nothing to write and we fall back to "return ()".
        stmts = case tagPokes ++ offsetLet ++ fieldPokes of
            [] -> [noBindS [| return () |]]
            ss -> ss
        -- Only types with at least two constructors store a tag.
        tagPokes = case cons of
            (_:_:_) -> [noBindS [| poke (castPtr $(ptrExpr)) (ixcon :: $(conT tagType)) |]]
            _ -> []
        offsetLet
            | null ixs = []
            | otherwise = [letS (offsetDecls fields)]
        fieldPokes = map (noBindS . pokeField) ixs
        ixs = map fst (zip [0..] fields)
        pokeField ix = [| poke (castPtr (plusPtr $(ptrExpr) $(varE (offset ix)))) $(varE (fName ix)) |]

    -- Generate declarations which compute the field offsets. The variable
    -- "offset<n>" holds the byte offset of field n from the start of the
    -- value.
    offsetDecls fields =
        -- Skip the last one, to avoid unused variable warnings. The list
        -- always contains the entry for offset 0, so 'init' is safe.
        init $
        map (\(ix, expr) -> valD (varP (offset ix)) (normalB expr) []) $
        -- Initial offset is the tag size.
        ((0, [| tagSize |]) :) $
        map (\(ix, (_, ty)) -> (ix, offsetExpr ix ty)) $
        zip [1..] fields
      where
        -- Each offset is the previous offset plus the size of the field
        -- stored there.
        offsetExpr ix ty = [| $(sizeOfExpr ty) + $(varE (offset (ix - 1))) |]

    -- Expression for the size of a value of the given type. 'sizeOf' is
    -- handed a value that throws if forced, as it is expected to ignore
    -- its argument.
    sizeOfExpr ty = [| $(varE 'sizeOf) (error "sizeOf evaluated its argument" :: $(return ty)) |]

    offset ix = mkName ("offset" ++ show ix)