module Ivory.Language.Syntax.Concrete.QQ.BitDataQQ (fromBitData) where
import Control.Monad (MonadPlus, join,
msum, mzero, unless,
when)
import Data.Foldable (find, foldl')
import Data.List (sort)
import Data.Maybe (catMaybes, isJust,
mapMaybe)
import Data.Traversable (mapAccumL)
import MonadLib (ChoiceT, findOne,
lift)
import Language.Haskell.TH hiding (Exp, Type)
import qualified Language.Haskell.TH as TH
import qualified Ivory.Language.Bits as I
import qualified Ivory.Language.Cast as I
import qualified Ivory.Language.IBool as I
import qualified Ivory.Language.Init as I
import qualified Ivory.Language.Ref as I
import Ivory.Language.Syntax.Concrete.ParseAST hiding (tyDef)
import qualified Ivory.Language.Type as I
import qualified Ivory.Language.BitData.Array as B
import qualified Ivory.Language.BitData.BitData as B
import qualified Ivory.Language.BitData.Bits as B
#if __GLASGOW_HASKELL__ >= 709
import Ivory.Language.Syntax.Concrete.QQ.Common
#endif
import Ivory.Language.Syntax.Concrete.Location
import Ivory.Language.Syntax.Concrete.QQ.TypeQQ
-- | Offer every element of a list as an alternative in a 'MonadPlus',
-- combining the alternatives with 'msum'.  An empty list yields 'mzero'.
anyOf :: MonadPlus m => [a] -> m a
anyOf xs = msum [return x | x <- xs]
-- | Translate a parsed bit type into the Template Haskell type it denotes.
--
-- Type constructors are resolved by name through 'getType', so the
-- conversion fails in 'Q' when a referenced type is not in scope.
convertType :: BitTy -> Q TH.Type
convertType bitTy =
  case bitTy of
    Bit            -> getType "Bit"
    Bits n         -> getType "Bits" `appT` sizeLit n
    BitArray n t   -> (getType "BitArray" `appT` sizeLit n) `appT` convertType t
    BitTySynonym s -> getType s
    -- Location wrappers carry no type information of their own.
    LocBitTy b     -> convertType (unLoc b)
  where
    -- Size argument, produced by 'szTy'.
    sizeLit n = return (szTy n)
-- | Look up a type constructor by name in the current splice context.
--
-- Returns the constructor as a 'ConT', or fails in 'Q' with an
-- @undefined type@ message when 'lookupTypeName' finds nothing.
getType :: String -> Q TH.Type
getType s =
  lookupTypeName s >>= maybe (fail ("undefined type: " ++ s)) (return . ConT)
-- | Compute the width in bits of a bit-data base type, as a
-- non-deterministic search over the candidate interpretations.
--
-- * @Bit@ is one bit wide.
-- * Any other plain constructor is resolved through its reified
--   @BitType@ instances, recursing on the instance's right-hand side.
-- * @BitArray n t@ is @n@ times the width of @t@.
-- * @Bits n@ is @n@ bits wide.
--
-- Every other shape of type yields no result ('mzero').
getTyBits :: TH.Type -> ChoiceT Q Integer
getTyBits ty =
  case ty of
    ConT name
      | name == ''B.Bit -> return 1
      | otherwise       -> do
          inst <- instancesOf ''B.BitType ty
          widthOfInstance inst
    AppT (AppT (ConT name) (LitT (NumTyLit count))) elemTy
      | name == ''B.BitArray -> do
          -- 'tyBits' fails in Q if the element type has no known width.
          elemBits <- lift (tyBits elemTy)
          return (fromIntegral count * elemBits)
    AppT (ConT name) (LitT (NumTyLit width))
      | name == ''B.Bits -> return (fromIntegral width)
      | otherwise        -> mzero
    _ -> mzero
  where
    -- Each reified instance of the class/family becomes one alternative.
    instancesOf cls t = do
      decs <- lift (reifyInstances cls [t])
      anyOf decs
    -- Only type synonym instances carry a right-hand side to recurse on.
#if __GLASGOW_HASKELL__ >= 708
    widthOfInstance (TySynInstD _ (TySynEqn _ rhs)) = getTyBits rhs
#else
    widthOfInstance (TySynInstD _ _ rhs) = getTyBits rhs
#endif
    widthOfInstance _ = mzero
-- | Width in bits of a bit-data base type.
--
-- Takes the first answer produced by 'getTyBits'; fails in 'Q' with an
-- @invalid bit value base type@ message when there is none.
tyBits :: TH.Type -> Q Integer
tyBits ty =
  findOne (getTyBits ty)
    >>= maybe (fail ("invalid bit value base type: " ++ show ty)) return
-- | A constructor field annotated with its Template Haskell type and
-- its width in bits.
data THField = THField
  { thFieldName :: Maybe Name  -- ^ Field name, or 'Nothing' for an unnamed field.
  , thFieldType :: TH.Type     -- ^ Template Haskell type of the field.
  , thFieldLen  :: Integer     -- ^ Width of the field in bits.
  } deriving (Show)
-- | Annotate a parsed field: convert its type with 'convertType' and
-- measure its width with 'tyBits'.  Fails in 'Q' if either step fails.
annotateField :: BitField -> Q THField
annotateField (BitField mbName bitTy _) = do
  thTy  <- convertType bitTy
  width <- tyBits thTy
  return THField
    { thFieldName = fmap mkName mbName
    , thFieldType = thTy
    , thFieldLen  = width
    }
-- | One entry of an annotated layout.  The trailing 'Integer' of each
-- constructor is the bit position assigned to the entry.
data THLayoutItem
  = THLayoutConst BitLiteral Integer  -- ^ A literal value and its position.
  | THLayoutField THField Integer     -- ^ A field and its position.
  deriving (Show)

-- | An annotated layout: a list of positioned items.
type THLayout = [THLayoutItem]
-- | Build the layout used when a constructor gives no explicit one.
--
-- With no fields the layout is a single zero literal of the full
-- definition width at position 0.  Otherwise the fields are placed one
-- after another starting at position 0, in the order given; unnamed
-- fields become zero literals of the same width.
defaultLayout :: Integer -> [THField] -> THLayout
defaultLayout len fields
  | null fields = [THLayoutConst (BitLitKnown len 0) 0]
  | otherwise   = zipWith place offsets fields
  where
    -- Running sum of the widths of all preceding fields.
    offsets = scanl (+) 0 (map thFieldLen fields)
    place offset field =
      case thFieldName field of
        Nothing -> THLayoutConst (BitLitKnown (thFieldLen field) 0) offset
        Just _  -> THLayoutField field offset
-- | Assign bit positions to the items of a constructor layout.
--
-- An empty layout falls back to 'defaultLayout' over the reversed field
-- list.  Otherwise the items are walked in reverse order, accumulating
-- positions from 0.  Calls 'error' on a literal of unknown size or a
-- layout field that is not among the constructor's fields.
annotateLayout :: Integer -> [LayoutItem] -> [THField] -> THLayout
annotateLayout len items fs
  | null items = defaultLayout len (reverse fs)
  | otherwise  = snd (mapAccumL step 0 (reverse items))
  where
    step offset (LayoutConst lit@(BitLitKnown width _)) =
      (offset + width, THLayoutConst lit offset)
    step offset (LayoutField name)
      | Just field <- lookupField name fs =
          (offset + thFieldLen field, THLayoutField field offset)
    step _ _ = error "invalid bitdata layout"
-- | Width in bits of a parsed layout item.
--
-- Literals of unknown size count as 0.  A field reference is resolved
-- against the given fields; an unresolved name calls 'error'.
layoutItemSize :: [THField] -> LayoutItem -> Integer
layoutItemSize fs item =
  case item of
    LayoutConst (BitLitKnown len _) -> len
    LayoutConst (BitLitUnknown _)   -> 0
    LayoutField name ->
      maybe (error "undefined field") thFieldLen (lookupField name fs)
-- | Find the first field whose name equals the given string.  Unnamed
-- fields never match.
lookupField :: String -> [THField] -> Maybe THField
lookupField name = find ((== Just (mkName name)) . thFieldName)
-- | Total width in bits of a parsed layout, summing 'layoutItemSize'
-- over its items.
layoutSize :: [THField] -> [LayoutItem] -> Integer
layoutSize fs ls = sum [layoutItemSize fs item | item <- ls]
-- | True exactly for literal layout items whose size is not yet known.
hasUnknownSize :: LayoutItem -> Bool
hasUnknownSize item =
  case item of
    LayoutConst (BitLitUnknown _) -> True
    _                             -> False
-- | Give a literal of unknown size the supplied size; every other item
-- is returned unchanged.
updateSizeL :: Integer -> LayoutItem -> LayoutItem
updateSizeL size (LayoutConst (BitLitUnknown x)) = LayoutConst (BitLitKnown size x)
updateSizeL _    l                               = l
-- | Apply 'updateSizeL' to the first item of unknown size, leaving the
-- rest of the list untouched.  A list with no such item is returned as is.
updateFirstL :: Integer -> [LayoutItem] -> [LayoutItem]
updateFirstL size ls =
  case break hasUnknownSize ls of
    (sized, l : rest) -> sized ++ updateSizeL size l : rest
    (sized, [])       -> sized
-- | Resolve the size of an unsized literal in a constructor layout.
--
-- The bits of the definition not claimed by sized items (the "slop")
-- are assigned to the first literal of unknown size.
--
-- Fails in 'Q' when:
--
-- * an unsized literal is present but the sized items already exceed
--   the definition width, which would give the literal a negative width;
-- * more than one literal of unknown size is present.
updateLiterals :: Integer -> [THField] -> [LayoutItem] -> Q [LayoutItem]
updateLiterals defLen fs ls = do
  -- Remaining width once every sized item has been accounted for.
  let slop = defLen - layoutSize fs ls
      ls'  = updateFirstL slop ls
  -- Without this check the negative width would cancel the overflow in
  -- the later total-size check and the bad layout would be accepted.
  when (slop < 0 && any hasUnknownSize ls) $
    fail "constructor layout is too large"
  when (any hasUnknownSize ls') $
    fail "multiple unknown size bit fields"
  return ls'
-- | Strict left fold over the layout items of a constructor.
foldLayout :: (b -> THLayoutItem -> b) -> b -> THConstr -> b
foldLayout f z = foldl' f z . thConstrLayout
-- | Map a function over the layout items of a constructor.
mapLayout :: (THLayoutItem -> a) -> THConstr -> [a]
mapLayout f = map f . thConstrLayout
-- | Width in bits of an annotated layout item.  Calls 'error' on a
-- literal whose size is not known.
thLayoutItemSize :: THLayoutItem -> Integer
thLayoutItemSize item =
  case item of
    THLayoutField f _                   -> thFieldLen f
    THLayoutConst (BitLitKnown len _) _ -> len
    THLayoutConst _ _                   -> error "invalid layout item"
-- | Total width in bits of an annotated layout.
thLayoutSize :: THLayout -> Integer
thLayoutSize = sum . map thLayoutItemSize
-- | A bit-data constructor with annotated fields and a positioned layout.
data THConstr = THConstr
  { thConstrName   :: Name       -- ^ Name of the generated constructor function.
  , thConstrFields :: [THField]  -- ^ Annotated fields of the constructor.
  , thConstrLayout :: THLayout   -- ^ Positioned layout items.
  } deriving (Show)
-- | Names of the named fields of a constructor, in declaration order.
constrFieldNames :: THConstr -> [Name]
constrFieldNames = mapMaybe thFieldName . thConstrFields
-- | Annotate a parsed constructor of a definition that is @len@ bits wide.
--
-- Fields are annotated first, unsized literals in the layout are then
-- resolved with 'updateLiterals', and finally positions are assigned
-- by 'annotateLayout'.
annotateConstr :: Integer -> Constr -> Q THConstr
annotateConstr len (Constr n fs ls _) = do
  fields      <- mapM annotateField fs
  layoutItems <- updateLiterals len fields ls
  return THConstr
    { thConstrName   = mkName n
    , thConstrFields = fields
    , thConstrLayout = annotateLayout len layoutItems fields
    }
-- | An annotated bit-data definition.
data THDef = THDef
  { thDefName    :: Name        -- ^ Name of the generated type.
  , thDefType    :: TH.Type     -- ^ Base type the definition is built on.
  , thDefConstrs :: [THConstr]  -- ^ Annotated constructors.
  , thDefLen     :: Integer     -- ^ Width of the base type in bits.
  } deriving (Show)
-- | Annotate a parsed bit-data definition: resolve its base type,
-- measure the width of that type, and annotate every constructor
-- against that width.
annotateDef :: BitDataDef -> Q THDef
annotateDef (BitDataDef n t cs _) = do
  baseTy  <- convertType t
  width   <- tyBits baseTy
  constrs <- mapM (annotateConstr width) cs
  return THDef
    { thDefName    = mkName n
    , thDefType    = baseTy
    , thDefConstrs = constrs
    , thDefLen     = width
    }
-- | Validate every constructor of a definition with 'checkConstr'.
checkDef :: THDef -> Q ()
checkDef def = mapM_ (checkConstr def) (thDefConstrs def)
-- | Validate a single constructor by checking its layout with
-- 'checkLayout'.
checkConstr :: THDef -> THConstr -> Q ()
checkConstr def constr = checkLayout def constr (thConstrLayout constr)
-- | Names of the named fields that appear in a layout, in layout order.
-- Literal items and unnamed fields contribute nothing.
layoutFieldNames :: THLayout -> [Name]
layoutFieldNames layout =
  [ name
  | THLayoutField f _ <- layout
  , Just name <- [thFieldName f]
  ]
-- | Validate a constructor layout against its definition.
--
-- Fails in 'Q' when the names declared by the constructor and the
-- names placed in the layout differ as multisets (the name @_@ is
-- ignored on both sides), or when the layout is wider than the
-- definition.
checkLayout :: THDef -> THConstr -> THLayout -> Q ()
checkLayout def c l = do
  unless (sort declared == sort placed) $
    fail "layout does not mention each field exactly once"
  when (thLayoutSize l > thDefLen def) $
    fail "constructor layout is too large"
  where
    withoutWildcard = filter (/= mkName "_")
    declared        = withoutWildcard (constrFieldNames c)
    placed          = withoutWildcard (layoutFieldNames l)
-- | Generate all declarations for a parsed bit-data definition.
--
-- The definition is annotated and validated, then the newtype, its
-- @BitData@ instance, the constructor functions with their field
-- accessors, and the @ArraySize@ instances are produced, in that order.
-- On GHC 7.10 and later the result is prefixed with whatever
-- 'lnPragma' returns for the definition's source location.
fromBitData :: BitDataDef -> Q [Dec]
fromBitData d = do
  def <- annotateDef d
  checkDef def
  let generators =
        [ mkDefNewtype def
        , mkDefInstance def
        , concatMap (mkConstr def) (thDefConstrs def)
        , mkArraySizeTypeInsts def
        ]
  defs <- sequence (concat generators)
#if __GLASGOW_HASKELL__ >= 709
  ln <- lnPragma (bdLoc d)
  return (ln ++ defs)
#else
  return defs
#endif
-- | Declare the newtype for a bit-data definition.
--
-- The newtype wraps the base type in a single constructor with the same
-- name as the type, and derives the Ivory classes listed in @derived@.
-- The GHC 8.0 branch uses the 'newtypeD' signature with a kind argument
-- and bang types.
mkDefNewtype :: THDef -> [DecQ]
mkDefNewtype def =
#if __GLASGOW_HASKELL__ >= 800
  [newtypeD (cxt []) tyName [] Nothing wrapper (mapM conT derived)]
#else
  [newtypeD (cxt []) tyName [] wrapper derived]
#endif
  where
    tyName  = thDefName def
    repTy   = return (thDefType def)
    derived =
      [ ''I.IvoryType, ''I.IvoryVar, ''I.IvoryExpr, ''I.IvoryEq
      , ''I.IvoryInit, ''I.IvoryStore, ''I.IvoryZeroVal
      ]
#if __GLASGOW_HASKELL__ >= 800
    wrapper =
      normalC tyName
        [bangType (bang noSourceUnpackedness noSourceStrictness) repTy]
#else
    wrapper = normalC tyName [strictType notStrict repTy]
#endif
-- | Declare the @BitData@ instance for a bit-data definition.
--
-- The instance has three members: the @BitType@ family instance mapping
-- the new type to its base type, @toBits@ unwrapping the newtype
-- constructor, and @fromBits@ defined as the constructor itself.
mkDefInstance :: THDef -> [DecQ]
mkDefInstance def = [instanceD (cxt []) instHead members]
  where
    tyName   = thDefName def
    repTy    = thDefType def
    instHead = [t| B.BitData $(conT tyName) |]
    members  = [repFamilyInst, toBitsDec, fromBitsDec]
#if __GLASGOW_HASKELL__ >= 708
    repFamilyInst =
      return (TySynInstD ''B.BitType (TySynEqn [ConT tyName] repTy))
#else
    repFamilyInst = return (TySynInstD ''B.BitType [ConT tyName] repTy)
#endif
    wrapped     = mkName "x"
    toBitsDec   =
      funD 'B.toBits
        [clause [conP tyName [varP wrapped]] (normalB (varE wrapped)) []]
    fromBitsDec = valD (varP 'B.fromBits) (normalB (conE tyName)) []
-- | Declare an @ArraySize@ type family instance for every @BitArray@
-- typed field of every constructor in the definition.
--
-- For a field of type @BitArray n t@ the instance maps @n@ and @t@ to
-- @n@ times the width of @t@.  The width is computed by 'tyBits' when
-- the declaration is run.
mkArraySizeTypeInsts :: THDef -> [DecQ]
mkArraySizeTypeInsts def =
  [ arraySizeInst n elemTy
  | constr  <- thDefConstrs def
  , fieldTy <- constrFieldTypes constr
  , Just (n, elemTy) <- [arrayParts fieldTy]
  ]
  where
    -- Element count and element type of a @BitArray@ type, if it is one.
    arrayParts :: TH.Type -> Maybe (Integer, TH.Type)
    arrayParts (AppT (AppT (ConT name) (LitT (NumTyLit n))) elemTy)
      | name == ''B.BitArray = Just (fromIntegral n, elemTy)
    arrayParts _ = Nothing

    arraySizeInst :: Integer -> TH.Type -> DecQ
    arraySizeInst n elemTy =
#if __GLASGOW_HASKELL__ >= 708
      tySynInstD ''B.ArraySize (tySynEqn lhs rhs)
#else
      tySynInstD ''B.ArraySize lhs rhs
#endif
      where
        lhs = [litT (numTyLit (fromIntegral n)), return elemTy]
        rhs = do
          elemBits <- tyBits elemTy
          litT (numTyLit (fromIntegral (elemBits * n)))
-- | Types of the named fields of a constructor, in declaration order.
constrFieldTypes :: THConstr -> [TH.Type]
constrFieldTypes c =
  [thFieldType f | f <- thConstrFields c, isJust (thFieldName f)]
-- | Type of the generated constructor function: a chain of arrows
-- ending in the definition's type.
--
-- Each field type is wrapped around the result built so far, so the
-- last type returned by 'constrFieldTypes' ends up as the outermost
-- (first) argument.
mkConstrType :: THDef -> THConstr -> TH.Type
mkConstrType d c = foldl' addArg (ConT (thDefName d)) (constrFieldTypes c)
  where
    addArg result argTy = AppT (AppT ArrowT argTy) result
-- | Name of the @n@-th constructor argument: @arg@ followed by the
-- decimal rendering of @n@.
argName :: Integer -> Name
argName = mkName . ("arg" ++) . show
-- | Argument patterns of the generated constructor function: one
-- variable pattern per named field other than @_@, numbered from 0
-- with 'argName'.
mkConstrArgs :: THConstr -> [PatQ]
mkConstrArgs c = [varP (argName i) | (i, _) <- zip [0 ..] named]
  where
    named = filter (/= mkName "_") (constrFieldNames c)
-- | One step of building the body of a constructor function.
--
-- The accumulator pairs the index of the next argument with the
-- expression built so far.  A field item ORs in the representation of
-- the current argument, cast with @safeCast@ and shifted left to the
-- field's position, and advances the index.  A non-zero literal ORs in
-- its shifted value; a zero literal leaves the expression unchanged.
constrBodyExpr :: (Integer, ExpQ) -> THLayoutItem -> (Integer, ExpQ)
constrBodyExpr (n, expr) item =
  case item of
    THLayoutField _ pos ->
      (n + 1, expr `orWith` shifted argRep pos)
    THLayoutConst val pos
      | bitLitVal val /= 0 ->
          (n, expr `orWith` shifted (intLit (bitLitVal val)) pos)
      | otherwise -> (n, expr)
  where
    -- Representation of the current argument.
    argRep = appE (varE 'I.safeCast) (appE (varE 'B.toRep) (varE (argName n)))
    intLit v = litE (integerL (fromIntegral v))
    orWith lhs rhs = infixApp lhs (varE '(I..|)) rhs
    shifted e pos = infixApp e (varE 'I.iShiftL) (intLit pos)
-- | Declarations for one constructor: the type signature and definition
-- of the constructor function, followed by its field accessors.
--
-- The function body folds 'constrBodyExpr' over the layout starting
-- from the literal 0 and passes the result to @unsafeFromRep@.
mkConstr :: THDef -> THConstr -> [DecQ]
mkConstr def constr = sig : fun : mkConstrFields def constr
  where
    cname  = thConstrName constr
    sig    = sigD cname (return (mkConstrType def constr))
    packed = snd (foldLayout constrBodyExpr (0, litE (integerL 0)) constr)
    body   = normalB (appE (varE 'B.unsafeFromRep) packed)
    fun    = funD cname [clause (mkConstrArgs constr) body []]
-- | Field accessor declarations for every layout item of a constructor.
mkConstrFields :: THDef -> THConstr -> [DecQ]
mkConstrFields def = concat . mapLayout (mkField def)
-- | Declarations for the accessor of one layout item.
--
-- A named field yields a type signature and a value binding that builds
-- a @BitDataField@ from the field's position, width and base name.
-- Literal items and unnamed fields yield nothing.
mkField :: THDef -> THLayoutItem -> [DecQ]
mkField def item =
  case item of
    THLayoutField f pos
      | Just name <- thFieldName f -> accessor f pos name
    _ -> []
  where
    accessor f pos name =
      [ sigD name accessorTy
      , valD (varP name) (normalB [| B.BitDataField $posE $lenE $nameE |]) []
      ]
      where
        nameE      = litE (stringL (nameBase name))
        lenE       = litE (integerL (thFieldLen f))
        posE       = litE (integerL pos)
        fty        = return (thFieldType f)
        accessorTy = [t| B.BitDataField $(conT (thDefName def)) $fty |]