module Ivory.Opts.CFG
( callGraphDot
, SizeMap(..)
, defaultSizeMap
, hasLoop
, WithTop
, maxStack
)
where
import qualified Ivory.Language.Syntax.AST as I
import qualified Ivory.Language.Syntax.Type as I
import qualified Data.Graph.Inductive as G
import Prelude hiding (lookup)
import Data.Monoid
import System.FilePath
import Data.Maybe
import Data.List hiding (lookup)
import Control.Applicative
import qualified Data.IntMap as M
import MonadLib (StateT, get, set, Id, StateM, runM)
import MonadLib.Derive (derive_get, derive_set, Iso(..))
-- | A value extended with a 'Top' element standing for "unbounded" or
-- "unknown" (the stack analysis produces it when it meets recursion).
data WithTop a
  = Top    -- ^ unbounded / unknown
  | Val a  -- ^ a known, finite value
  deriving (Eq, Functor)
-- | 'Top' prints as the word @Top@; a 'Val' prints exactly like its payload
-- (no constructor name).
instance Show a => Show (WithTop a) where
  show wt = case wt of
    Val x -> show x
    Top   -> "Top"
-- | 'Top' is greater than every 'Val'; two 'Val's compare by payload.
instance Ord a => Ord (WithTop a) where
  compare (Val a) (Val b) = compare a b
  compare Top     Top     = EQ
  compare Top     _       = GT
  compare _       Top     = LT
-- | Application succeeds only when both sides are known; a 'Top' on either
-- side makes the whole result 'Top'.
instance Applicative WithTop where
  pure = Val
  Val f <*> Val x = Val (f x)
  _     <*> _     = Top
-- | Size of a stack item or stack frame, in whatever unit the chosen
-- 'SizeMap' assigns.
type Size = Integer
-- | Symbol name of a procedure.
type CallNm = String
-- | Stack-relevant view of an Ivory type. All references and pointers
-- collapse to 'Ptr'; arrays keep their element type and their length as
-- given by the Ivory array type.
data StackType
  = TyStruct String        -- ^ struct, by name
  | TyArr StackType Size   -- ^ element type and array length
  | Ptr                    -- ^ any reference or pointer
  | TyVoid
  | TyInt IntSize
  | TyWord WordSize
  | TyBool
  | TyChar
  | TyFloat
  | TyDouble
  deriving (Show, Eq)
-- | Signed integer widths, one per Ivory signed integer width.
data IntSize = Int8 | Int16 | Int32 | Int64
  deriving (Show, Eq)
-- | Unsigned integer widths, one per Ivory unsigned integer width.
data WordSize = Word8 | Word16 | Word32 | Word64
  deriving (Show, Eq)
-- | Structured trace of a procedure body: straight-line items ('Stmt'),
-- two-way branches ('Branch'), and loops ('Loop') tagged with the integer
-- literal from the loop header when it is one.
--
-- No datatype context: datatype contexts are deprecated and rejected by
-- Haskell 2010 unless @DatatypeContexts@ is enabled. They only pushed
-- redundant @(Show a, Eq a)@ constraints onto every pattern match. The
-- derived instances still require @Show a@ / @Eq a@ where they are used.
data Block a
  = Stmt a                         -- ^ a single straight-line item
  | Branch [Block a] [Block a]     -- ^ then-arm and else-arm
  | Loop (Maybe Integer) [Block a] -- ^ optional literal from the header, body
  deriving (Show, Eq)
-- | Call structure of a procedure body: callee names, laid out like the body.
type Control = Block CallNm
-- | Allocation structure of a procedure body: stack item types, laid out
-- like the body.
type Alloc = Block StackType
-- | Per-procedure summary consumed by the call-graph and stack analyses.
data ProcInfo = ProcInfo
  { procSym :: CallNm      -- ^ procedure symbol
  , params  :: [StackType] -- ^ argument types
  , alloc   :: [Alloc]     -- ^ structured local allocations
  , calls   :: [Control]   -- ^ structured direct calls
  } deriving (Show, Eq)
-- | Per-module summary: the module name and a summary of each procedure.
data ModuleInfo = ModuleInfo
  { modName :: String     -- ^ Ivory module name
  , procs   :: [ProcInfo] -- ^ public procedures followed by private ones
  } deriving (Show, Eq)
-- | Stack type of a typed Ivory value (only its type annotation is used).
toStackTyped :: I.Typed a -> StackType
toStackTyped = toStackType . I.tType
-- | Translate an Ivory type into its stack representation. References,
-- const references and pointers all become 'Ptr'. Any Ivory type not
-- listed here is rejected with 'error'.
toStackType :: I.Type -> StackType
toStackType (I.TyStruct nm)   = TyStruct nm
toStackType (I.TyArr len elt) = TyArr (toStackType elt) (fromIntegral len)
toStackType I.TyRef{}         = Ptr
toStackType I.TyConstRef{}    = Ptr
toStackType I.TyPtr{}         = Ptr
toStackType I.TyVoid          = TyVoid
toStackType (I.TyInt i)       = TyInt (toIntType i)
toStackType (I.TyWord w)      = TyWord (toWordType w)
toStackType I.TyBool          = TyBool
toStackType I.TyChar          = TyChar
toStackType I.TyFloat         = TyFloat
toStackType I.TyDouble        = TyDouble
toStackType t                 = error $ "Unhandled stack type: " ++ show t
-- | Map each Ivory signed width onto the local 'IntSize' of the same width.
toIntType :: I.IntSize -> IntSize
toIntType I.Int8  = Int8
toIntType I.Int16 = Int16
toIntType I.Int32 = Int32
toIntType I.Int64 = Int64
-- | Map each Ivory unsigned width onto the local 'WordSize' of the same width.
toWordType :: I.WordSize -> WordSize
toWordType I.Word8  = Word8
toWordType I.Word16 = Word16
toWordType I.Word32 = Word32
toWordType I.Word64 = Word64
-- | Summarise one Ivory procedure: its name, its argument types, and the
-- allocation and call structures of its body.
cfgProc :: I.Proc -> ProcInfo
cfgProc proc = ProcInfo
  { procSym = I.procSym proc
  , params  = map toStackTyped (I.procArgs proc)
  , alloc   = concatMap toAlloc body
  , calls   = concatMap toCall body
  }
  where
    body = I.procBody proc
-- | Stack allocations introduced by one statement, keeping the nesting of
-- branches and loops. 'I.Assign', 'I.AllocRef' and 'I.Deref' each allocate
-- one item of their type. An 'I.Loop' allocates an 'I.Int32' index ahead of
-- its body, and 'I.Forever' becomes a 'Loop' with no literal.
--
-- NOTE(review): every other statement (the final wildcard) is treated as
-- allocating nothing; confirm no other Ivory statement introduces a local.
toAlloc :: I.Stmt -> [Alloc]
toAlloc stmt = case stmt of
    I.Assign ty _ _   -> item ty
    I.AllocRef ty _ _ -> item ty
    I.Deref ty _ _    -> item ty
    I.IfTE _ thn els  -> [Branch (nested thn) (nested els)]
    I.Loop _ e _ body -> item loopIdxTy ++ [Loop (getIdx e) (nested body)]
    I.Forever body    -> [Loop Nothing (nested body)]
    _                 -> []
  where
    item ty   = [Stmt (toStackType ty)]
    nested    = concatMap toAlloc
    -- Type charged for the loop index variable.
    loopIdxTy = I.TyInt I.Int32
-- | Direct calls made by one statement, shaped to line up with 'toAlloc':
-- 'I.IfTE' becomes a 'Branch', 'I.Loop' and 'I.Forever' become a 'Loop',
-- and each direct call becomes a 'Stmt' naming the callee. Calls through
-- function pointers are not supported and raise 'error'.
--
-- 'I.Forever' must be handled here exactly as 'toAlloc' handles it.
-- Otherwise calls inside a forever loop never reach the call graph, and the
-- allocation list holds a 'Loop' with no partner in the call list, which
-- 'goBlks' cannot pair up.
toCall :: I.Stmt -> [Control]
toCall stmt =
  case stmt of
    I.IfTE _ blk0 blk1 -> [ Branch (concatMap toCall blk0)
                                   (concatMap toCall blk1) ]
    I.Call _ _ nm _ -> case nm of
      I.NameSym sym -> [Stmt sym]
      I.NameVar _   -> error $ "XXX need to implement function pointers."
    I.Loop _ e _ blk -> [Loop (getIdx e) (concatMap toCall blk)]
    I.Forever blk -> [Loop Nothing (concatMap toCall blk)]
    _ -> []
-- | The value of an integer-literal expression; 'Nothing' for any other
-- expression.
getIdx :: I.Expr -> Maybe Integer
getIdx (I.ExpLit (I.LitInteger n)) = Just n
getIdx _                           = Nothing
-- | Call graph: one node per procedure of the module, labelled with its
-- summary, and one unlabelled edge per call to another procedure of the
-- same module.
type CFG = G.Gr ProcInfo ()
-- | Every callee name in a call structure, in order: then-arm before else-arm,
-- and loop bodies inlined.
flattenControl :: Control -> [CallNm]
flattenControl (Stmt callee)    = [callee]
flattenControl (Branch thn els) = concatMap flattenControl (thn ++ els)
flattenControl (Loop _ body)    = concatMap flattenControl body
-- | Summarise a module: its name and every procedure, public ones first and
-- private ones after.
cfgModule :: I.Module -> ModuleInfo
cfgModule m = ModuleInfo (I.modName m) (map cfgProc allProcs)
  where
    visibility = I.modProcs m
    allProcs   = I.public visibility ++ I.private visibility
-- | Build the call graph of a module. Nodes are numbered from 0 in the
-- order 'cfgModule' lists the procedures. Each call to a procedure of the
-- module adds an edge; calls to procedures defined elsewhere are dropped.
-- When two procedures share a name, the first one wins.
cfg :: I.Module -> CFG
cfg m = G.insEdges callEdges (G.insNodes nodes G.empty)
  where
    nodes :: [G.LNode ProcInfo]
    nodes = zip [0 ..] (procs (cfgModule m))

    callEdges :: [G.LEdge ()]
    callEdges =
      [ (src, dst, ())
      | (src, p) <- nodes
      , callee   <- concatMap flattenControl (calls p)
      , Just dst <- [nodeOf callee]
      ]

    nodeOf :: CallNm -> Maybe G.Node
    nodeOf sym = fst <$> find ((== sym) . procSym . snd) nodes
-- | The same graph with each node relabelled by its procedure name only.
procSymGraph :: CFG -> G.Gr CallNm ()
procSymGraph g = G.nmap procSym g
-- | Does the call graph contain recursion, direct or mutual?
--
-- fgl's 'G.hasLoop' only reports self-loops (a procedure calling itself
-- directly), so mutual recursion such as @f -> g -> f@ was missed. Every
-- cycle of two or more procedures lies inside a strongly connected
-- component with more than one node. Together with the self-loop check,
-- this catches every cycle, which is exactly when 'maxStack' can report
-- an unbounded result.
hasLoop :: CFG -> Bool
hasLoop g = G.hasLoop g || any ((> 1) . length) (G.scc g)
-- | Target-specific sizing parameters for 'maxStack'.
data SizeMap = SizeMap
  { stackElemMap :: StackType -> Size
    -- ^ size of one stack item of the given type
  , retSize      :: Size
    -- ^ fixed amount charged once per procedure frame (presumably the
    -- return address; confirm for your target)
  }
-- | Unit sizing: every stack item, whatever its type, and every return
-- count as 1.
defaultSizeMap :: SizeMap
defaultSizeMap = SizeMap { retSize = 1, stackElemMap = \_ -> 1 }
-- | Cache of computed stack sizes, keyed by call-graph node.
type MaxMap = M.IntMap Size
-- | State monad for the stack analysis; the state is the per-node size
-- cache ('MaxMap').
--
-- 'Applicative' is derived alongside 'Monad': since GHC 7.10 (the
-- Applicative-Monad Proposal) a 'Monad' instance requires an 'Applicative'
-- one. Deriving 'Monad' alone no longer compiles.
newtype MaxState a = MaxState
  { unSt :: StateT MaxMap Id a
  } deriving (Functor, Applicative, Monad)
-- | 'StateM' for 'MaxState', borrowed from the wrapped 'StateT' through the
-- 'MaxState' / 'unSt' isomorphism.
instance StateM MaxState MaxMap where
  get = derive_get (Iso MaxState unSt)
  set = derive_set (Iso MaxState unSt)
-- | Initial analysis state: nothing cached yet.
emptyMaxSt :: MaxMap
emptyMaxSt = M.fromList []
-- | Read the current size cache.
getMaxMap :: MaxState MaxMap
getMaxMap = get
-- | Worst-case stack usage of procedure @proc@ in call graph @cf@, sizing
-- individual items with @szmap@. Returns 'Top' when the analysis runs into
-- recursion. Calls 'error' when @proc@ is not a node of @cf@.
maxStack :: CallNm -> CFG -> SizeMap -> WithTop Size
maxStack proc cf szmap =
  let analysis             = maxStack' cf szmap [] (findNode cf proc)
      (result, _finalCache) = runM (unSt analysis) emptyMaxSt
  in result
-- | Node of the first procedure named @proc@ in @cf@. Calls 'error' if no
-- such procedure is in the graph.
findNode :: CFG -> CallNm -> G.Node
findNode cf proc =
  case find ((== proc) . procSym . snd) (G.labNodes cf) of
    Just (node, _) -> node
    Nothing        -> error $ "Proc " ++ proc ++ " is not in the graph!"
-- | Worst-case stack usage of the procedure at node @curr@. This is its
-- parameters, plus @retSize szmap@, plus the cost 'goBlks' computes for
-- its body (local allocations and the deepest callee). Returns 'Top' when
-- @curr@ is already on the current call path @visited@, i.e. on recursion.
--
-- Body allocations are counted only by 'goBlks'. Its first straight-line
-- segment already covers the leading run of allocations, so adding that run
-- here as well would double-count it.
--
-- Finite results are memoised in the 'MaxState' map, which 'goBlk' consults
-- before recursing, so shared callees are analysed once. Only 'Val' is
-- stored. A finite result means no cycle is reachable from the node, so it
-- does not depend on which path reached the node.
maxStack' :: CFG -> SizeMap -> [G.Node] -> G.Node -> MaxState (WithTop Size)
maxStack' cf szmap visited curr
  | curr `elem` visited = return Top
  | otherwise = do
      blkMax <- goBlks cf szmap visited curr (alloc procInfo) (calls procInfo)
      let sz = (paramsSz + retSize szmap +) <$> blkMax
      memoize sz
      return sz
  where
    procInfo :: ProcInfo
    procInfo = G.lab' (G.context cf curr)

    paramsSz :: Size
    paramsSz = getSize szmap (params procInfo)

    memoize :: WithTop Size -> MaxState ()
    memoize (Val s) = get >>= set . M.insert curr s
    memoize Top     = return ()
-- | Walk a procedure's allocation structure @acs@ and call structure @cns@
-- in lock-step, summing the costs of consecutive segments. For a 'Branch'
-- the more expensive arm is taken. A 'Loop' body is counted once,
-- regardless of its literal.
--
-- Both lists come from the same statements, so their 'Branch' and 'Loop'
-- nodes line up. A run of leading 'Stmt's on either side forms a
-- straight-line segment costed by 'goBlk'. If neither side has a leading
-- 'Stmt' and the heads do not pair up, the structures disagree. Recursing
-- would then repeat with unchanged arguments forever, so fail loudly instead.
goBlks :: CFG -> SizeMap -> [G.Node] -> G.Node
       -> [Alloc] -> [Control] -> MaxState (WithTop Size)
goBlks cf szmap visited curr acs cns =
  case (acs, cns) of
    ([], []) -> return (Val 0)
    (Branch a0 a1 : acs', Branch c0 c1 : cns') -> do
      sz0 <- goBlks' a0 c0
      sz1 <- goBlks' a1 c1
      sz2 <- goBlks' acs' cns'
      return (liftA2 max sz0 sz1 <+> sz2)
    (Loop _ a : acs', Loop _ c : cns') -> do
      sz0 <- goBlks' a c
      sz1 <- goBlks' acs' cns'
      return (sz0 <+> sz1)
    _ | null segAllocs && null segCalls ->
          error "goBlks: allocation and call structures are misaligned"
      | otherwise -> do
          sz0 <- goBlk cf szmap visited curr segAllocs segCalls
          sz1 <- goBlks' (nxtBlock acs) (nxtBlock cns)
          return (sz0 <+> sz1)
  where
    goBlks'   = goBlks cf szmap visited curr
    segAllocs = getBlock acs
    segCalls  = getBlock cns
-- | Cost of one straight-line segment of procedure @curr@. This is the total
-- size of the segment's allocations @acs@ plus the largest stack usage among
-- the procedures it calls (@cns@), or 0 when it calls none. A callee's size
-- is taken from the state map when present. Otherwise it comes from
-- 'maxStack'' with @curr@ pushed onto the visited path. Calling a procedure
-- that is not in the graph makes 'findNode' fail.
goBlk :: CFG -> SizeMap -> [G.Node] -> G.Node -> [StackType]
      -> [CallNm] -> MaxState (WithTop Size)
goBlk cf szmap visited curr acs cns = do
  calleeSizes <- mapM calleeSize (map (findNode cf) cns)
  let deepest | null calleeSizes = Val 0
              | otherwise        = maximum calleeSizes
  return (Val (getSize szmap acs) <+> deepest)
  where
    calleeSize :: G.Node -> MaxState (WithTop Size)
    calleeSize n = do
      cache <- getMaxMap
      case M.lookup n cache of
        Just sz -> return (Val sz)
        Nothing -> maxStack' cf szmap (curr : visited) n
-- | Add two possibly-unbounded values; 'Top' on either side gives 'Top'.
(<+>) :: Num a => WithTop a -> WithTop a -> WithTop a
x <+> y = (+) <$> x <*> y
-- | Total size of a list of stack items under the given sizing.
getSize :: SizeMap -> [StackType] -> Size
getSize szmap tys = sum [stackElemMap szmap ty | ty <- tys]
-- | Payloads of the leading run of 'Stmt's, stopping at the first 'Branch',
-- 'Loop', or the end of the list.
getBlock :: (Show a, Eq a) => [Block a] -> [a]
getBlock bls = [b | Stmt b <- takeWhile isStmt bls]
  where
    isStmt (Stmt _) = True
    isStmt _        = False
-- | What remains after dropping the leading run of 'Stmt's, i.e. the list
-- starting at the first 'Branch' or 'Loop' (or empty).
nxtBlock :: (Show a, Eq a) => [Block a] -> [Block a]
nxtBlock = dropWhile isStmt
  where
    isStmt (Stmt _) = True
    isStmt _        = False
-- | Build the call graph of all given modules combined. Write the part
-- reachable from procedure @proc@, with nodes labelled by procedure name,
-- to @path/proc.dot@ in Graphviz format, then return the full graph. Calls
-- 'error' if @proc@ is not in the graph.
callGraphDot :: CallNm -> FilePath -> [I.Module] -> IO CFG
callGraphDot proc path mods = do
  writeFile dotFile (graphviz reachableG (I.modName merged))
  return graph
  where
    merged  = mconcat mods
    graph   = cfg merged
    dotFile = path </> addExtension proc "dot"

    reachableG :: G.Gr CallNm ()
    reachableG =
      let keep = G.reachable (findNode graph proc) graph
      in G.delNodes (G.nodes graph \\ keep) (procSymGraph graph)
-- | Render a graph as a Graphviz @digraph@ named @t@.
--
-- Node and edge labels are the 'show' of the label, with one surrounding
-- pair of double or single quotes stripped. Labels that show as @()@ are
-- omitted. Nodes whose label is omitted are not listed at all, but edges
-- always are.
graphviz :: (G.Graph g, Show a, Show b)
         => g a b
         -> String
         -> String
graphviz g t = concat
  [ "digraph ", t, " {\n"
  , concatMap nodeLine (G.labNodes g)
  , concatMap edgeLine (G.labEdges g)
  , "}"
  ]
  where
    nodeLine (n, a) = case dotLabel a of
      ""  -> ""
      lbl -> '\t' : show n ++ lbl ++ "\n"

    edgeLine (n1, n2, b) =
      '\t' : show n1 ++ " -> " ++ show n2 ++ dotLabel b ++ "\n"

    dotLabel :: Show x => x -> String
    dotLabel x
      | body == "()" = ""
      | otherwise    = " [label = \"" ++ body ++ "\"]"
      where
        body = unquote (show x)

    -- Strip one pair of enclosing quotes (a lone leading quote is dropped
    -- too); one-character strings are left alone.
    unquote :: String -> String
    unquote s@[_] = s
    unquote ('"' : rest)
      | last rest == '"' = init rest
      | otherwise        = rest
    unquote ('\'' : rest)
      | last rest == '\'' = init rest
      | otherwise         = rest
    unquote s = s