module Language.C.Monad
where
import Lens.Micro
import Lens.Micro.Mtl
import Lens.Micro.TH
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif
import Control.Monad.Identity
import Control.Monad.State.Strict
import Control.Monad.Exception
import Language.C.Quote.C
import qualified Language.C.Syntax as C
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Monoid
import Text.PrettyPrint.Mainland
import Data.Loc
import Data.List (partition,nub)
-- | Code generation flags. Currently a single nullary constructor carrying no
-- options.
data Flags = Flags
-- | State threaded through code generation.
--
-- The list-valued accumulators below are built by consing, so they hold the
-- most recently added element first. 'cenvToCUnit', 'inNewBlock' and
-- 'inNewFunction' reverse the lists they read before emitting them.
data CEnv = CEnv
{ _flags :: Flags
  -- ^ Flags the environment was created with (see 'defaultCEnv').
, _unique :: !Integer
  -- ^ Counter returned and then incremented by 'freshId'.
, _modules :: Map.Map String [C.Definition]
  -- ^ Definitions collected per module name by 'inModule'.
, _includes :: Set.Set String
  -- ^ Include targets, already wrapped in quotes or angle brackets;
  -- 'cenvToCUnit' prefixes each with @#include @.
, _typedefs :: [C.Definition]
  -- ^ Type definitions, newest first.
, _prototypes :: [C.Definition]
  -- ^ Function prototypes, newest first.
, _globals :: [C.Definition]
  -- ^ Global definitions, newest first.
, _aliases :: Map.Map Integer String
  -- ^ Aliases installed for the duration of an action by 'withAlias'.
, _params :: [C.Param]
  -- ^ Parameters of the function being generated, newest first.
, _args :: [C.Exp]
  -- ^ Argument expressions added by 'addArg'; read and cleared by
  -- 'collectArgs'.
, _locals :: [C.InitGroup]
  -- ^ Local declarations of the current block, newest first.
, _stms :: [C.Stm]
  -- ^ Statements of the current block, newest first.
, _finalStms :: [C.Stm]
  -- ^ Statements emitted after '_stms' when the block is closed by
  -- 'inNewBlock', newest first.
, _usedVars :: Set.Set C.Id
  -- ^ Identifiers recorded by 'touchVar' for the function being generated.
, _funUsedVars :: Map.Map String (Set.Set C.Id)
  -- ^ Used identifiers per function name, recorded by 'setUsedVars'.
}
-- Template Haskell: derive a lens for each underscore-prefixed field above
-- (the rest of this module uses them as @unique@, @globals@, ...).
makeLenses ''CEnv
-- | Modify the part of the state focused by a lens and return the value it
-- held /before/ the modification.
(<<%=) :: MonadState s m =>
  (forall f . Functor f => LensLike' f s a) -> (a -> a) -> m a
(<<%=) l f = do
  previous <- use l
  l %= f
  return previous
-- | Overwrite the part of the state focused by a lens and return the value it
-- held /before/ the assignment.
(<<.=) :: MonadState s m =>
  (forall f . Functor f => LensLike' f s a) -> a -> m a
(<<.=) l new = do
  previous <- use l
  l .= new
  return previous
-- | Initial environment: the given flags, a fresh-name counter starting at 0
-- and every accumulator empty.
defaultCEnv :: Flags -> CEnv
defaultCEnv fl = CEnv
  { _flags       = fl
  , _unique      = 0
  , _modules     = Map.empty
  , _includes    = Set.empty
  , _typedefs    = []
  , _prototypes  = []
  , _globals     = []
  , _aliases     = Map.empty
  , _params      = []
  , _args        = []
  , _locals      = []
  , _stms        = []
  , _finalStms   = []
  , _usedVars    = Set.empty
  , _funUsedVars = Map.empty
  }
-- | Constraint synonym bundling what the code generation functions in this
-- module require of their monad: state access to the 'CEnv', exception
-- support and 'MonadFix', together with 'Functor', 'Applicative' and 'Monad'.
type MonadC m = (Functor m, Applicative m, Monad m, MonadState CEnv m, MonadException m, MonadFix m)
-- | Code generation monad transformer: a strict 'StateT' carrying the 'CEnv',
-- layered over 'ExceptionT' on top of the base monad @t@.
newtype CGenT t a = CGenT { unCGenT :: StateT CEnv (ExceptionT t) a }
deriving (Functor, Applicative, Monad, MonadException, MonadState CEnv, MonadIO, MonadFix)
-- | 'CGenT' over the 'Identity' base monad, i.e. pure code generation.
type CGen = CGenT Identity
-- | Run a code generation computation from the given initial environment,
-- returning its result together with the final environment.
--
-- If the computation ends with an uncaught exception there is no error
-- channel available in a bare @'Monad' m@, so this calls 'error' with the
-- rendered exception. The previous refutable bind (@Right ac <- ...@) needs
-- @MonadFail@ on current GHCs and, where it did compile, replaced the real
-- exception with an opaque pattern-match failure.
runCGenT :: Monad m => CGenT m a -> CEnv -> m (a, CEnv)
runCGenT m s = do
  outcome <- runExceptionT (runStateT (unCGenT m) s)
  case outcome of
    Right ac -> return ac
    Left err -> error ("runCGenT: uncaught exception: " ++ show err)
-- | Pure variant of 'runCGenT': run the generator over 'Identity' and unwrap
-- the result and final environment.
runCGen :: CGen a -> CEnv -> (a, CEnv)
runCGen m env = runIdentity (runCGenT m env)
-- | Assemble a C translation unit from an environment: include directives
-- first, then typedefs, prototypes and globals. The three definition lists
-- are stored newest-first, so each is reversed and then de-duplicated.
cenvToCUnit :: CEnv -> [C.Definition]
cenvToCUnit env =
  [cunit|$edecls:includeDefs
         $edecls:typedefDefs
         $edecls:prototypeDefs
         $edecls:globalDefs|]
  where
    includeDefs   = [ toInclude inc | inc <- Set.toList (_includes env) ]
    typedefDefs   = inEmissionOrder (_typedefs env)
    prototypeDefs = inEmissionOrder (_prototypes env)
    globalDefs    = inEmissionOrder (_globals env)

    -- Oldest first, keeping only the first occurrence of each definition.
    inEmissionOrder :: [C.Definition] -> [C.Definition]
    inEmissionOrder defs = nub (reverse defs)

    -- Render one include target as a raw preprocessor line.
    toInclude :: String -> C.Definition
    toInclude inc =
      let directive = "#include " ++ inc
      in  [cedecl|$esc:directive|]
-- | Run a generator from the default environment and pretty-print every
-- resulting unit. The first entry, named @"main"@, holds the top-level
-- definitions; it is followed by one entry per module collected in
-- '_modules'.
prettyCGenT :: Monad m => CGenT m a -> m [(String, Doc)]
prettyCGenT ma = do
  (_, finalEnv) <- runCGenT ma (defaultCEnv Flags)
  let units = ("main", cenvToCUnit finalEnv) : Map.toList (_modules finalEnv)
  return [ (name, ppr defs) | (name, defs) <- units ]
-- | Pure variant of 'prettyCGenT' for generators over 'Identity'.
prettyCGen :: CGen a -> [(String, Doc)]
prettyCGen gen = runIdentity (prettyCGenT gen)
-- | Return the current value of the unique counter and advance the counter
-- by one.
freshId :: MonadC m => m Integer
freshId = do
  current <- use unique
  unique .= succ current
  return current
-- | Generate a name by appending a fresh number (from 'freshId') to the given
-- prefix.
gensym :: MonadC m => String -> m String
gensym s = fmap (\u -> s ++ show u) freshId
-- | Record an identifier (converted with 'toIdent', without source location)
-- in the set of used variables.
touchVar :: (MonadC m, ToIdent v) => v -> m ()
touchVar v = modify (over usedVars (Set.insert ident))
  where
    ident = toIdent v (SrcLoc NoLoc)
-- | Store the set of used variables for the named function, replacing any
-- set previously stored under that name.
setUsedVars :: MonadC m => String -> Set.Set C.Id -> m ()
setUsedVars fun uvs = modify (over funUsedVars (Map.insert fun uvs))
-- | Register an include target verbatim. The string must already carry its
-- delimiters; see 'addLocalInclude' and 'addSystemInclude'.
addInclude :: MonadC m => String -> m ()
addInclude inc = modify (includes %~ Set.insert inc)
-- | Register an include whose target is wrapped in double quotes.
addLocalInclude :: MonadC m => String -> m ()
addLocalInclude inc = addInclude (concat ["\"", inc, "\""])
-- | Register an include whose target is wrapped in angle brackets.
addSystemInclude :: MonadC m => String -> m ()
addSystemInclude inc = addInclude (concat ["<", inc, ">"])
-- | Push a type definition onto the front of the typedef list.
addTypedef :: MonadC m => C.Definition -> m ()
addTypedef def = modify (over typedefs (def :))
-- | Push a function prototype onto the front of the prototype list.
addPrototype :: MonadC m => C.Definition -> m ()
addPrototype def = modify (over prototypes (def :))
-- | Push a global definition onto the front of the globals list.
addGlobal :: MonadC m => C.Definition -> m ()
addGlobal def = modify (over globals (def :))
-- | Add several global definitions, given in the order they should appear in
-- the generated code.
--
-- '_globals' is kept newest-first (see 'addGlobal') and is reversed by
-- 'cenvToCUnit', so the batch is reversed before it is prepended. This makes
-- @addGlobals [a, b]@ equivalent to @addGlobal a >> addGlobal b@ and matches
-- how 'addParams' and 'addStms' treat their argument lists. Previously the
-- batch was prepended unreversed and therefore emitted back to front.
addGlobals :: MonadC m => [C.Definition] -> m ()
addGlobals defs = globals %= (reverse defs ++)
-- | Run an action with the alias @i -> n@ installed, then put the alias map
-- back to what it was before the call.
withAlias :: MonadC m => Integer -> String -> m a -> m a
withAlias i n act = do
  saved <- use aliases
  aliases .= Map.insert i n saved
  result <- act
  aliases .= saved
  return result
-- | Push a parameter onto the front of the current function's parameter
-- list.
addParam :: MonadC m => C.Param -> m ()
addParam param = modify (over params (param :))
-- | Add several parameters, given in declaration order. The parameter list
-- is stored newest-first, so the batch is reversed before being prepended.
addParams :: MonadC m => [C.Param] -> m ()
addParams ps = modify (over params (\existing -> reverse ps ++ existing))
-- | Push an argument expression onto the front of the argument list.
addArg :: MonadC m => C.Exp -> m ()
addArg arg = modify (over args (arg :))
-- | Add a local declaration to the current block. When the declaration is an
-- 'C.InitGroup', every identifier it declares is also marked as used via
-- 'touchVar'; other kinds of declaration groups mark nothing.
addLocal :: MonadC m => C.InitGroup -> m ()
addLocal def = do
  modify (over locals (def :))
  mapM_ touchVar (declaredIds def)
  where
    declaredIds :: C.InitGroup -> [C.Id]
    declaredIds (C.InitGroup _ _ inits _) = map initId inits
    declaredIds _                         = []

    initId :: C.Init -> C.Id
    initId (C.Init ident _ _ _ _ _) = ident
-- | Add several local declarations, one after another, via 'addLocal'.
addLocals :: MonadC m => [C.InitGroup] -> m ()
addLocals defs = forM_ defs addLocal
-- | Push a statement onto the front of the current block's statement list.
addStm :: MonadC m => C.Stm -> m ()
addStm stm = modify (over stms (stm :))
-- | Add several statements, given in execution order. The statement list is
-- stored newest-first, so the batch is reversed before being prepended.
addStms :: MonadC m => [C.Stm] -> m ()
addStms ss = modify (over stms (\existing -> reverse ss ++ existing))
-- | Push a statement onto the front of the list of final statements, which
-- 'inNewBlock' emits after the block's ordinary statements.
addFinalStm :: MonadC m => C.Stm -> m ()
addFinalStm stm = modify (over finalStms (stm :))
-- | Run an action in a fresh block and append the resulting block, wrapped in
-- braces, as a single compound statement to the enclosing block.
inBlock :: MonadC m => m a -> m a
inBlock ma = do
  (result, body) <- inNewBlock ma
  result <$ addStm [cstm|{ $items:body }|]
-- | Run an action with empty local, statement and final-statement
-- accumulators, then restore the enclosing ones. Returns the action's result
-- with the block items it produced: declarations first, then statements,
-- then final statements, each in the order they were added.
inNewBlock :: MonadC m => m a -> m (a, [C.BlockItem])
inNewBlock ma = do
  savedLocals <- locals    <<.= mempty
  savedStms   <- stms      <<.= mempty
  savedFinals <- finalStms <<.= mempty
  result <- ma
  blockLocals <- locals    <<.= savedLocals
  blockStms   <- stms      <<.= savedStms
  blockFinals <- finalStms <<.= savedFinals
  -- Accumulators are newest-first; reverse to recover insertion order.
  let items = concat
        [ map C.BlockDecl (reverse blockLocals)
        , map C.BlockStm  (reverse blockStms)
        , map C.BlockStm  (reverse blockFinals)
        ]
  return (result, items)
-- | Like 'inNewBlock', but discards the action's result and returns only the
-- block items.
inNewBlock_ :: MonadC m => m a -> m [C.BlockItem]
inNewBlock_ ma = fmap snd (inNewBlock ma)
-- | Run an action as the body of a new function: parameters and used
-- variables start empty and the enclosing values are restored afterwards.
-- Returns the action's result, the variables it used, its parameters in
-- declaration order and the body's block items.
inNewFunction :: MonadC m => m a -> m (a,Set.Set C.Id,[C.Param],[C.BlockItem])
inNewFunction comp = do
  savedParams <- use params
  savedUsed   <- use usedVars
  params   .= mempty
  usedVars .= mempty
  (result, body) <- inNewBlock comp
  funParams <- use params
  funUsed   <- use usedVars
  params   .= savedParams
  usedVars .= savedUsed
  return (result, funUsed, reverse funParams, body)
-- | Generate a function returning @void@; see 'inFunctionTy'.
inFunction :: MonadC m => String -> m a -> m a
inFunction fun body = inFunctionTy [cty|void|] fun body
-- | Generate a function with the given return type and name from an action
-- producing its body. Records the variables the body used under the
-- function's name and adds both a prototype and the definition.
inFunctionTy :: MonadC m => C.Type -> String -> m a -> m a
inFunctionTy ty fun ma = do
  (result, used, funParams, body) <- inNewFunction ma
  setUsedVars fun used
  addPrototype [cedecl| $ty:ty $id:fun($params:funParams);|]
  addGlobal    [cedecl| $ty:ty $id:fun($params:funParams){ $items:body }|]
  return result
-- | Run an action with empty include, typedef, prototype and global
-- accumulators and return its result with the translation unit it produced
-- (via 'cenvToCUnit'). Those four accumulators are restored to their previous
-- contents afterwards; all other state changes made by the action are kept.
collectDefinitions :: MonadC m => m a -> m (a, [C.Definition])
collectDefinitions ma = do
  saved <- get
  includes   .= mempty
  typedefs   .= mempty
  prototypes .= mempty
  globals    .= mempty
  result   <- ma
  finished <- get
  put finished
    { _includes   = _includes saved
    , _typedefs   = _typedefs saved
    , _prototypes = _prototypes saved
    , _globals    = _globals saved
    }
  return (result, cenvToCUnit finished)
-- | Return the accumulated argument expressions as stored (most recently
-- added first) and clear the accumulator.
collectArgs :: MonadC m => m [C.Exp]
collectArgs = do
  collected <- use args
  args .= mempty
  return collected
-- | Run an action as a separate module. The unique counter restarts at 0 for
-- the duration of the action and is restored afterwards. The definitions the
-- action produced are stored under the module name; if the name already has
-- definitions, the new ones are placed in front of them.
inModule :: MonadC m => String -> m a -> m a
inModule name prg = do
  savedUnique <- use unique
  unique .= 0
  (result, defs) <- collectDefinitions prg
  unique .= savedUnique
  modify (over modules (Map.insertWith (<>) name defs))
  return result
-- | Wrap a program in @int main(...)@: the program's statements are followed
-- by @return 0;@, its used variables are recorded under @"main"@ and the
-- resulting definition is added as a global.
wrapMain :: MonadC m => m a -> m ()
wrapMain prog = do
  let body = prog >> addStm [cstm| return 0; |]
  (_, used, mainParams, mainItems) <- inNewFunction body
  setUsedVars "main" used
  addGlobal [cedecl| int main($params:mainParams){ $items:mainItems }|]
-- | Run a program, then lift every local variable that is used by more than
-- one function out of the function bodies and into global declarations.
--
-- A variable counts as shared when it appears in the used-variable set of one
-- function and in that of at least one other function. Only declarations at
-- the top level of a function body are inspected.
liftSharedLocals :: MonadC m => m a -> m ()
liftSharedLocals prog = do
  _ <- prog
  usedPerFun     <- gets _funUsedVars
  currentGlobals <- gets _globals
  let sharedIds = Set.unions (Map.elems (onlyShared usedPerFun))
      isShared ident = Set.member ident sharedIds
      (strippedGlobals, liftedSets) =
        unzip (map (extractDecls isShared) currentGlobals)
      liftedDecls =
        [ C.DecDef decls (SrcLoc NoLoc)
        | decls <- Set.toList (Set.unions liftedSets)
        ]
  -- Globals are stored newest-first, hence the reversed lifted declarations
  -- are appended at the end of the stored list.
  globals .= (strippedGlobals ++ reverse liftedDecls)
  where
    -- For each function, keep only the variables also used by another one.
    onlyShared :: Map.Map String (Set.Set C.Id) -> Map.Map String (Set.Set C.Id)
    onlyShared allUsed = Map.mapWithKey sharedWithOthers allUsed
      where
        sharedWithOthers fun used =
          Set.intersection used
            (Set.unions (Map.elems (Map.delete fun allUsed)))

    -- Split a definition into itself minus the selected declarations, and
    -- the set of declaration groups that were removed from it.
    extractDecls :: (C.Id -> Bool)
                 -> C.Definition
                 -> (C.Definition, Set.Set C.InitGroup)
    extractDecls isLifted (C.FuncDef (C.Func spec name declr ps body funLoc) defLoc) =
      case foldr splitItem ([], Set.empty) body of
        (keptItems, lifted) ->
          (C.FuncDef (C.Func spec name declr ps keptItems funLoc) defLoc, lifted)
      where
        initIsLifted (C.Init ident _ _ _ _ _) = isLifted ident

        splitItem item@(C.BlockDecl whole@(C.InitGroup dspec attrs inits loc)) (kept, lifted) =
          case partition initIsLifted inits of
            -- Nothing selected: keep the declaration untouched.
            ([], _) ->
              (item : kept, lifted)
            -- Everything selected: move the whole group out.
            (_, []) ->
              (kept, Set.insert whole lifted)
            -- Mixed: split the group into a kept part and a lifted part.
            (hit, miss) ->
              ( C.BlockDecl (C.InitGroup dspec attrs miss loc) : kept
              , Set.insert (C.InitGroup dspec attrs hit loc) lifted
              )
        splitItem item (kept, lifted) =
          (item : kept, lifted)
    extractDecls _ other =
      (other, Set.empty)