module Data.Flags.TH (
dataBitsAsFlags,
dataBitsAsBoundedFlags,
bitmaskWrapper,
enumADT
) where
import Language.Haskell.TH
import Data.Bits (Bits(..))
import Data.Maybe (isJust)
import Data.List (find, union, intercalate)
import Foreign.Storable (Storable(..))
import Foreign.Ptr (Ptr, castPtr)
import Control.Applicative ((<$>))
import Data.Flags.Base
-- | Build a context-free instance declaration
-- @instance className typeName where ...@; the method declarations are
-- supplied as the remaining argument.
inst :: Name -> Name -> [Dec] -> Dec
inst className typeName = InstanceD [] instanceHead
  where
    instanceHead = ConT className `AppT` ConT typeName
-- | Build the declaration @name = expr@: one clause with no argument
-- patterns and no @where@ bindings.
fun :: Name -> Exp -> Dec
fun name expr = FunD name [clause]
  where
    clause = Clause [] (NormalB expr) []
-- | Generate a 'Flags' instance for a type that already has a 'Bits'
-- instance, implementing every flag operation bitwise:
--
-- * 'noFlags' is @fromInteger 0@;
--
-- * 'andFlags' is @(.|.)@;
--
-- * 'commonFlags' is @(.&.)@;
--
-- * @'butFlags' x y@ is @x .&. complement y@, i.e. @x@ with every bit of
--   @y@ cleared.
dataBitsAsFlags :: Name -> Q [Dec]
dataBitsAsFlags typeName =
  (\methods -> [inst ''Flags typeName methods]) <$> mapM method
    [ ('noFlags,     [| fromInteger 0 |])
    , ('andFlags,    [| (.|.) |])
    , ('commonFlags, [| (.&.) |])
    , ('butFlags,    [| \x -> \y -> x .&. (complement y) |])
    ]
  where
    -- Run one method-body quotation (in list order) and turn it into
    -- the declaration @methodName = body@.
    method (methodName, bodyQ) = fun methodName <$> bodyQ
-- | Generate both a 'Flags' instance (see 'dataBitsAsFlags') and a
-- 'BoundedFlags' instance for a type that already has a 'Bits' instance.
--
-- * 'allFlags' is @fromInteger (-1)@: every bit set (for the standard
--   fixed-width integral types 'fromInteger' wraps, so @-1@ is the
--   all-ones pattern for signed and unsigned types alike).
--
-- * 'enumFlags' splits a value into single-bit values, one per set bit,
--   lowest bit first.
--
-- The type must have a meaningful 'bitSize'; per "Data.Bits" it is
-- undefined for types without a fixed width such as 'Integer'.
dataBitsAsBoundedFlags :: Name -> Q [Dec]
dataBitsAsBoundedFlags typeName = do
  -- Previously @fromInteger (1)@, which set only the lowest bit.
  allFlagsE <- [| fromInteger (-1) |]
  -- Valid bit indices are 0 .. bitSize - 1; each set index i becomes the
  -- value with only bit i set (@setBit 0 i@).
  enumFlagsE <- [| \x -> map (setBit 0) $
                           filter (testBit x) [0 .. bitSize x - 1] |]
  flagsDecs <- dataBitsAsFlags typeName
  return $ flagsDecs ++
    [inst ''BoundedFlags typeName
       [fun 'allFlags allFlagsE,
        fun 'enumFlags enumFlagsE]]
-- | Generate a newtype wrapping a bitmask type, plus one named constant
-- per flag.
--
-- For @bitmaskWrapper "T" ''W derives [("a", 1), ("b", 2)]@ this emits:
--
-- * @newtype T = T W@ deriving 'Eq', 'Flags' and the classes in
--   @derives@ (merged with 'union');
--
-- * @a :: T; a = T 1@ and @b :: T; b = T 2@;
--
-- * a 'BoundedFlags' instance: 'allFlags' folds 'andFlags' over every
--   listed constant starting from 'noFlags', and 'enumFlags' keeps the
--   listed constants whose 'commonFlags' with the argument is not
--   'noFlags';
--
-- * unless 'Show' is in @derives@, a 'Show' instance rendering a value as
--   @T [a, b]@, listing the names of the constants that overlap it.
bitmaskWrapper :: String              -- ^ name of the generated newtype
               -> Name                -- ^ wrapped bitmask type
               -> [Name]              -- ^ extra classes to derive
               -> [(String, Integer)] -- ^ flag constant names and values
               -> Q [Dec]
bitmaskWrapper typeNameS wrappedName derives elems = do
  let typeName = mkName typeNameS
      -- List expression mentioning every generated flag constant.
      flagsE = listE [varE (mkName nameS) | (nameS, _) <- elems]
      -- List expression of (constant, "constant name") pairs for 'show'.
      namedFlagsE = listE [tupE [varE (mkName nameS), stringE nameS]
                          | (nameS, _) <- elems]
  showE <- [| \flags -> $(stringE $ typeNameS ++ " [") ++
                (intercalate ", " $ map snd $
                   filter ((noFlags /=) . commonFlags flags . fst) $
                     $(namedFlagsE)) ++ "]" |]
  allFlagsE <- [| foldl andFlags noFlags $(flagsE) |]
  enumFlagsE <- [| \flags -> filter ((noFlags /=) . commonFlags flags) $
                               $(flagsE) |]
  let newtypeDec =
        NewtypeD [] typeName []
          (NormalC typeName [(NotStrict, ConT wrappedName)])
          (union [''Eq, ''Flags] derives)
      -- @name :: T@ and @name = T value@ for a single flag constant.
      constantDecs (nameS, value) =
        let name = mkName nameS
        in [ SigD name (ConT typeName)
           , fun name (AppE (ConE typeName) (LitE (IntegerL value)))
           ]
      boundedDec =
        inst ''BoundedFlags typeName
          [fun 'allFlags allFlagsE, fun 'enumFlags enumFlagsE]
      -- Only generate 'Show' when the caller did not ask to derive it.
      showDecs
        | ''Show `elem` derives = []
        | otherwise = [inst ''Show typeName [fun 'show showE]]
  return $ newtypeDec : concatMap constantDecs elems
        ++ [boundedDec] ++ showDecs
-- | Generate an enumeration data type whose constructors correspond to
-- integer codes, plus a 'Storable' instance that stores each constructor
-- as its code in the given numeric type.
--
-- The generated type derives 'Eq', 'Ord' and 'Show'. Its 'peek' reads a
-- value of the numeric type and calls 'fail' with
-- @"Invalid value for <type name>"@ when the code matches none of the
-- listed constructors.
enumADT :: String              -- ^ name of the generated data type
        -> Name                -- ^ numeric type used as the stored representation
        -> [(String, Integer)] -- ^ constructor names and their integer codes
        -> Q [Dec]
enumADT typeNameS numName elems = do
  let typeName = mkName typeNameS
      -- Builds @case i of { code -> Just Con; ...; _ -> Nothing }@,
      -- mapping a stored code back to its constructor, if any.
      wrap i = caseE (varE i) $
                 (map (\(name, value) ->
                         match (litP $ IntegerL value)
                               (normalB $ appE (conE 'Just)
                                               (conE $ mkName name))
                               []) elems) ++
                 [match wildP (normalB $ conE 'Nothing) []]
      -- Builds @case w of { Con -> code; ... }@, mapping each
      -- constructor to its integer code.
      unwrap w = caseE (varE w)
                   (map (\(name, value) ->
                           match (conP (mkName name) [])
                                 (normalB $ litE $ IntegerL value)
                                 []) elems) in do
    -- Size and alignment are delegated to the numeric representation type.
    alignmentE <- [| \_ -> alignment (undefined :: $(conT numName)) |]
    sizeOfE <- [| \_ -> sizeOf (undefined :: $(conT numName)) |]
    -- The name quotes 'i and 'v refer to the variables bound inside these
    -- quotations, so the generated case expressions scrutinise them.
    peekE <- [| \p -> do
                  i <- peek (castPtr p :: Ptr $(conT numName))
                  case $(wrap 'i) of
                    Just w -> return w
                    Nothing -> fail $ "Invalid value for " ++ typeNameS |]
    pokeE <- [| \p -> \v -> poke (castPtr p :: Ptr $(conT numName))
                                 $(unwrap 'v) |]
    return [DataD [] typeName [] (map ((`NormalC` []) . mkName . fst) elems)
              [''Eq, ''Ord, ''Show],
            inst ''Storable typeName
              [fun 'alignment alignmentE,
               fun 'sizeOf sizeOfE,
               fun 'peek peekE,
               fun 'poke pokeE]]