module Language.HERMIT.Primitive.Common
( altFreeVarsT
, bindings
, bindingVarsT
, caseAltVarsT
, caseBinderVarT
, letVarsT
, wrongExprForm
) where
import GhcPlugins
import Control.Arrow
import Data.List
import Data.Monoid
import Language.HERMIT.Kure
import Language.HERMIT.Primitive.GHC
-- | Syntactic forms from which the identifiers they bind can be listed.
class BindEnv a where
  -- | The identifiers bound by the given value, as defined by each instance.
  bindings :: a -> [Id]
-- | A binding group binds its single binder ('NonRec') or every binder
-- of its recursive group ('Rec'), in declaration order.
instance BindEnv CoreBind where
  bindings bind = case bind of
    NonRec v _ -> [v]
    Rec pairs  -> [ v | (v, _) <- pairs ]
-- | A case alternative binds exactly the variables of its pattern.
instance BindEnv CoreAlt where
  bindings alt = case alt of
    (_con, patVars, _rhs) -> patVars
-- | Identifiers bound directly at the root of an expression.
--
-- * 'Lam' binds its lambda variable.
-- * 'Let' binds the binders of its binding group.
-- * 'Case' binds its case binder, followed by the pattern variables of all
--   alternatives (duplicates among the alternatives removed).
-- * Every other form binds nothing.
--
-- Only the root node is inspected; sub-expressions are not traversed.
instance BindEnv CoreExpr where
  bindings (Lam b _)           = [b]
  bindings (Let bs _)          = bindings bs
  -- The second field of 'Case' is the case binder, not the scrutinee.
  bindings (Case _ b _ alts)   = b : nub (concatMap bindings alts)
  bindings _                   = []
-- | A program binds the binders of all its top-level binding groups,
-- with duplicates removed (first occurrence kept).
instance BindEnv CoreProgram where
  bindings prog = nub (concatMap bindings prog)
-- | A definition binds the single identifier it defines.
instance BindEnv CoreDef where
  bindings def = case def of
    Def v _ -> [v]
-- | The variables bound at the root of any 'Core' node, via the matching
-- 'BindEnv' instance. Fails on the 'ModGutsCore' node, which has no
-- binder list of its own.
bindingVarsT :: TranslateH Core [Var]
bindingVarsT = translate $ \ _ node -> case node of
  ProgramCore prog -> return (bindings prog)
  BindCore    bind -> return (bindings bind)
  DefCore     def  -> return (bindings def)
  ExprCore    expr -> return (bindings expr)
  AltCore     alt  -> return (bindings alt)
  ModGutsCore _    -> fail "Cannot get binding vars at topmost level"
-- | The binders of the binding group of a 'Let' expression.
--
-- On any other expression form the pattern match in the @do@ block fails.
-- The failure message is set with 'setFailMsg', as 'caseBinderVarT' does,
-- so callers see a descriptive message rather than a generic
-- pattern-match failure.
letVarsT :: TranslateH CoreExpr [Var]
letVarsT = setFailMsg "Not a Let expression." $
  do Let bs _ <- idR
     return (bindings bs)
-- | For a 'Case' expression, the pattern variables of each alternative,
-- one list per alternative. The scrutinee contributes nothing ('mempty')
-- and the case binder is ignored.
caseAltVarsT :: TranslateH CoreExpr [[Id]]
caseAltVarsT =
  caseT mempty (\ _ -> extractT bindingVarsT) (\ () _bndr _ty altVars -> altVars)
-- | Like 'caseAltVarsT', but each alternative's variable list is prefixed
-- with the case binder.
caseAltVarsWithBinderT :: TranslateH CoreExpr [[Id]]
caseAltVarsWithBinderT =
  caseT mempty (\ _ -> extractT bindingVarsT) $ \ () bndr _ty perAlt ->
    [ bndr : altVars | altVars <- perAlt ]
-- | The case binder of a 'Case' expression, as a singleton list.
-- Fails with "Not a Case expression." on any other form.
caseBinderVarT :: TranslateH CoreExpr [Id]
caseBinderVarT = setFailMsg "Not a Case expression." $ do
  Case _ binder _ _ <- idR
  return [binder]
-- | For a case alternative, a function that, given the case binder,
-- returns the free variables of the right-hand side (per 'freeVarsT'),
-- without duplicates and excluding the case binder and the pattern variables.
-- The order of first occurrence is preserved.
altFreeVarsT :: TranslateH CoreAlt (Id -> [Var])
altFreeVarsT = altT freeVarsT $ \ _con patVars rhsFrees caseBndr ->
  let bound = caseBndr : patVars
  in  filter (`notElem` bound) (nub rhsFrees)
-- | Builds the standard error message for an expression that does not
-- match the expected syntactic form, which is described by the argument.
wrongExprForm :: String -> String
wrongExprForm = ("Expression does not have the form: " ++)