-- | This module exports functionality for generating a call graph of
-- a Futhark program.
module Futhark.Analysis.CallGraph
  ( CallGraph,
    buildCallGraph,
    isFunInCallGraph,
    calls,
    calledByConsts,
    allCalledBy,
    numOccurences,
  )
where

import Control.Monad.Writer.Strict
import Data.List (foldl')
import Data.Map.Strict qualified as M
import Data.Maybe (isJust)
import Data.Set qualified as S
import Futhark.IR.SOACS
import Futhark.Util.Pretty

-- | Maps each function name to its definition.
type FunctionTable = M.Map Name (FunDef SOACS)

-- | Index the functions of a program by name.  If several functions
-- share a name, the one occurring last in 'progFuns' wins.
buildFunctionTable :: Prog SOACS -> FunctionTable
buildFunctionTable = foldl' expand M.empty . progFuns
  where
    -- A strict left fold: plain 'foldl' would build a chain of thunks,
    -- one per function, before the map is ever forced.
    expand ftab f = M.insert (funDefName f) f ftab

-- | A unique (at least within a function) name identifying a function
-- call.  In practice the first element of the corresponding pattern.
type CallId = VName

-- | The function calls found in some piece of code.
data FunCalls = FunCalls
  { -- | Every call site, keyed by its 'CallId', mapped to the attributes
    -- of the calling statement and the name of the called function.
    fcMap :: M.Map CallId (Attrs, Name),
    -- | The names of all called functions.
    fcAllCalled :: S.Set Name
  }
  deriving (Eq, Ord, Show)

-- | Combining two 'FunCalls' combines both components with their own
-- '(<>)': a (left-biased) map union of the call sites and a set union
-- of the called names.
instance Semigroup FunCalls where
  FunCalls x1 y1 <> FunCalls x2 y2 = FunCalls (x1 <> x2) (y1 <> y2)

-- | No calls at all.
instance Monoid FunCalls where
  mempty = FunCalls mempty mempty

-- | Does the given 'FunCalls' record a call to the named function?
fcCalled :: Name -> FunCalls -> Bool
fcCalled f fcs = f `S.member` fcAllCalled fcs

-- | Maps each caller to the calls made directly in its body.
type FunGraph = M.Map Name FunCalls

-- | The call graph is a mapping from a function name, i.e., the
-- caller, to a record of the names of functions called *directly* (not
-- transitively!) by the function.
--
-- We keep track separately of the functions called by constants.
data CallGraph = CallGraph
  { -- | The calls made by each function, keyed on the caller.
    cgCalledByFuns :: FunGraph,
    -- | The calls made by the program's top-level constants.
    cgCalledByConsts :: FunCalls
  }
  deriving (Eq, Ord, Show)

-- | Is the given function known to the call graph?  Only functions
-- defined in the program and reachable from an entry point or from a
-- top-level constant are present.
isFunInCallGraph :: Name -> CallGraph -> Bool
isFunInCallGraph f cg = f `M.member` cgCalledByFuns cg

-- | Does the first function call the second directly?  Returns 'False'
-- when the caller is not in the call graph.
calls :: Name -> Name -> CallGraph -> Bool
calls caller callee cg =
  case M.lookup caller (cgCalledByFuns cg) of
    Just fcs -> fcCalled callee fcs
    Nothing -> False

-- | Is the function called directly in any of the top-level constants?
calledByConsts :: Name -> CallGraph -> Bool
calledByConsts callee cg = fcCalled callee (cgCalledByConsts cg)

-- | All functions called directly (not transitively) by this function.
-- Empty if the function is not in the call graph.
allCalledBy :: Name -> CallGraph -> S.Set Name
allCalledBy f cg =
  maybe S.empty fcAllCalled $ M.lookup f $ cgCalledByFuns cg

-- | @buildCallGraph prog@ builds the program's call graph.  Only
-- functions reachable from an entry point, or from a function called
-- by a top-level constant, are included.
buildCallGraph :: Prog SOACS -> CallGraph
buildCallGraph prog = CallGraph fg cg
  where
    ftable :: FunctionTable
    ftable = buildFunctionTable prog

    -- Calls made by the top-level constants.
    cg :: FunCalls
    cg = buildFGStms $ progConsts prog

    -- The roots of the traversal: every entry point, plus every
    -- function called by a constant.
    entry_points :: S.Set Name
    entry_points =
      S.fromList (map funDefName $ filter (isJust . funDefEntryPoint) $ progFuns prog)
        <> fcAllCalled cg

    fg :: FunGraph
    fg = foldl' (buildFGfun ftable) M.empty entry_points

-- | Count how many times each distinct element occurs in the list.
-- Elements that do not occur are absent from the result.
count :: (Ord k) => [k] -> M.Map k Int
count ks = M.fromListWith (+) [(k, 1) | k <- ks]

-- | Produce a mapping of the number of occurrences in the call graph
-- of each function, i.e. how many call sites (in functions and in
-- constants) call it.  Only counts functions that are called at least
-- once.
numOccurences :: CallGraph -> M.Map Name Int
numOccurences (CallGraph funs consts) =
  -- Count per 'FunCalls' rather than unioning the call-site maps:
  -- 'CallId's are only guaranteed unique within a single function, and
  -- a union would silently drop colliding call sites.
  count
    [ callee
    | fcs <- consts : M.elems funs,
      (_, callee) <- M.elems (fcMap fcs)
    ]

-- | @buildFGfun ftable fg fname@ updates @fg@ with the contributions
-- of function @fname@ and, recursively, of every function it calls.
-- A name that is not in @ftable@ (e.g. a builtin), or that is already
-- in @fg@, leaves @fg@ unchanged.
buildFGfun :: FunctionTable -> FunGraph -> Name -> FunGraph
buildFGfun ftable fg fname =
  case M.lookup fname ftable of
    -- A non-builtin function that we have not already processed.
    Just f
      | fname `M.notMember` fg ->
          let callees = buildFGBody $ funDefBody f
              -- Record fname before visiting its callees, so each
              -- function is processed at most once, even if it is
              -- reached again through a cycle.
              fg' = M.insert fname callees fg
           in foldl' (buildFGfun ftable) fg' $ fcAllCalled callees
    _ -> fg

-- | Collect the function calls made by a sequence of statements,
-- including calls nested inside their sub-bodies and lambdas.
buildFGStms :: Stms SOACS -> FunCalls
buildFGStms = foldMap buildFGstm . stmsToList

-- | Collect the function calls made by the statements of a body.
buildFGBody :: Body SOACS -> FunCalls
buildFGBody body = buildFGStms (bodyStms body)

-- | Collect the function calls made by a single statement.  A direct
-- 'Apply' is recorded under the name of the first element of its
-- pattern, together with the statement's attributes.  For any other
-- expression, the calls in its nested bodies (and, for SOACs, in its
-- lambdas) are collected.
buildFGstm :: Stm SOACS -> FunCalls
buildFGstm (Let (Pat (p : _)) aux (Apply fname _ _ _)) =
  FunCalls
    (M.singleton (patElemName p) (stmAuxAttrs aux, fname))
    (S.singleton fname)
-- NOTE(review): an 'Apply' bound to an empty pattern falls through to
-- the generic case below, whose mapper only inspects nested bodies, so
-- the call itself appears not to be recorded -- confirm that such
-- statements cannot occur.
buildFGstm (Let _ _ (Op op)) = execWriter $ mapSOACM folder op
  where
    folder :: SOACMapper SOACS SOACS (Writer FunCalls)
    folder =
      identitySOACMapper
        { mapOnSOACLambda = \lam -> do
            tell $ buildFGBody $ lambdaBody lam
            pure lam
        }
buildFGstm (Let _ _ e) = execWriter $ mapExpM folder e
  where
    folder :: Mapper SOACS SOACS (Writer FunCalls)
    folder =
      identityMapper
        { mapOnBody = \_ body -> do
            tell $ buildFGBody body
            pure body
        }

-- | One line per call site, of the form @=> callee (at callId attrs)@,
-- in 'CallId' order.
instance Pretty FunCalls where
  pretty = stack . map ppCall . M.toList . fcMap
    where
      ppCall :: (CallId, (Attrs, Name)) -> Doc ann
      ppCall (callId, (attrs, callee)) =
        "=>" <+> pretty callee <+> parens ("at" <+> pretty callId <+> pretty attrs)

-- | The calls made by the constants come first, followed by those of
-- each function in name order.  Each group sits under a heading
-- underlined with @=@ characters, and groups are punctuated with 'line'.
instance Pretty CallGraph where
  pretty (CallGraph fg cg) =
    stack . punctuate line . map ppFunCalls $
      ("called at top level", cg) : M.toList fg
    where
      ppFunCalls :: (Name, FunCalls) -> Doc ann
      ppFunCalls (f, fcalls) =
        pretty f
          </> pretty (map (const '=') (nameToString f))
          </> indent 2 (pretty fcalls)