{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Rendering takes a call graph and produces a GraphViz dot file.
module Calligraphy.Phases.Render.GraphViz
  ( GraphVizConfig,
    pGraphVizConfig,
    renderGraphViz,
  )
where

import Calligraphy.Phases.Render.Common
import Calligraphy.Prelude hiding (DeclType)
import Calligraphy.Util.Printer
import Calligraphy.Util.Types
import Data.List (intercalate)
import Data.Maybe (catMaybes)
import Data.Tree (Tree)
import qualified Data.Tree as Tree
import Options.Applicative hiding (style)
import Text.Show (showListWith)

-- | Options controlling how the call graph is emitted as GraphViz dot.
data GraphVizConfig = GraphVizConfig
  { -- | Keep the arrowhead on parent-child (definition tree) edges; when
    -- 'False' those edges get @arrowhead=none@.
    showChildArrowhead :: Bool,
    -- | Wrap every definition tree that has children in an invisible
    -- @subgraph cluster_…@.
    clusterGroups :: Bool,
    -- | When 'False', the graph header emits @splines=false;@.
    splines :: Bool,
    -- | When 'True', dependency edges are written caller -> callee; otherwise
    -- they are written callee -> caller with @dir=back@.
    reverseDependencyRank :: Bool
  }

-- | Command-line parser for 'GraphVizConfig'.
--
-- Child arrowheads and reversed dependency rank are opt-in switches
-- (default off); tree clustering and splines default to on and are disabled
-- with their @--no-…@ flags.
pGraphVizConfig :: Parser GraphVizConfig
pGraphVizConfig =
  GraphVizConfig
    <$> switch
      ( long "show-child-arrowhead"
          <> help "Put an arrowhead at the end of a parent-child edge"
      )
    <*> flag
      True
      False
      ( long "no-cluster-trees"
          <> help "Don't draw definition trees as a cluster."
      )
    <*> flag
      True
      False
      ( long "no-splines"
          <> help "Render arrows as straight lines instead of splines"
      )
    <*> switch
      ( long "reverse-dependency-rank"
          <> help "Make dependencies have lower rank than the dependee, i.e. show dependencies above their parent."
      )

-- | Render a call graph as a GraphViz @digraph@.
--
-- Emits global node/graph settings, then either one labelled cluster per
-- module ('Left') or the bare definition trees ('Right'), and finally one
-- edge per call dependency (solid) and per type dependency (dotted).
renderGraphViz :: GraphVizConfig -> Prints RenderGraph
renderGraphViz GraphVizConfig {..} (RenderGraph roots calls types) =
  brack "digraph calligraphy {" "}" $ do
    unless splines $ textLn "splines=false;"
    textLn "node [style=filled fillcolor=\"#ffffffcf\"];"
    textLn "graph [outputorder=edgesfirst];"
    case roots of
      Left modules -> mapM_ printModule modules
      Right trees -> mapM_ printTree trees
    forM_ calls $ dependencyEdge []
    forM_ types $ dependencyEdge ["style" .= "dotted"]
  where
    -- Draw one dependency between a caller and its callee, with the extra
    -- attributes first. With 'reverseDependencyRank' the edge is written
    -- caller -> callee; otherwise it is written callee -> caller and
    -- @dir=back@ is appended so the arrowhead sits at the callee end.
    -- NOTE(review): dot ranks an edge's tail above its head, so the default
    -- places the callee above the caller; confirm this matches the
    -- --reverse-dependency-rank help text.
    dependencyEdge :: Attributes -> (String, String) -> Printer ()
    dependencyEdge extra (caller, callee)
      | reverseDependencyRank = edge caller callee extra
      | otherwise = edge callee caller (extra <> ["dir" .= "back"])

    -- Print a node, then each child subtree followed by a dashed
    -- parent -> child edge (arrowhead suppressed unless showChildArrowhead).
    printTree :: Prints (Tree RenderNode)
    printTree (Tree.Node nodeInfo children) = wrapCluster $ do
      printNode nodeInfo
      forM_ children $ \child@(Tree.Node childInfo _) -> do
        printTree child
        edge (nodeId nodeInfo) (nodeId childInfo) . catMaybes $
          [ pure ("style" .= "dashed"),
            if' (not showChildArrowhead) ("arrowhead" .= "none")
          ]
      where
        -- Non-leaf trees go into an invisible cluster when clustering is on.
        wrapCluster :: Printer a -> Printer a
        wrapCluster inner
          | clusterGroups && not (null children) =
              brack ("subgraph cluster_" <> nodeId nodeInfo <> " {") "}" $ do
                textLn "style=invis;"
                inner
          | otherwise = inner

    -- One labelled cluster per module containing all of its trees.
    printModule :: Prints RenderModule
    printModule (RenderModule lbl modId trees) =
      brack ("subgraph cluster_module_" <> modId <> " {") "}" $ do
        -- dotString rather than show: show escapes non-ASCII characters as
        -- decimal \NNN sequences, which Graphviz does not understand.
        strLn $ "label=" <> dotString lbl <> ";"
        forM_ trees printTree

    -- A node statement: multi-line label, shape by declaration kind, and a
    -- style list ("rounded" for records, "dashed" if not exported, always
    -- "filled").
    printNode :: Prints RenderNode
    printNode (RenderNode nId typ lbll exported) =
      strLn $ nId <> " " <> renderAttrs attrs
      where
        attrs :: Attributes
        attrs =
          [ "label" .= dotString (intercalate "\n" lbll),
            "shape" .= nodeShape typ,
            "style" .= nodeStyle
          ]
        nodeStyle :: String
        nodeStyle =
          show . intercalate ", " . catMaybes $
            [ if' (typ == RecDecl) "rounded",
              if' (not exported) "dashed",
              pure "filled"
            ]

    -- Quote a string as a DOT label. Graphviz turns "\x" into "x" in labels,
    -- so backslashes (common in operator names such as \\) and quotes must
    -- be escaped. Newlines become the "\n" centred-line escape. Unicode is
    -- passed through untouched.
    dotString :: String -> String
    dotString s = "\"" <> concatMap escape s <> "\""
      where
        escape '"' = "\\\""
        escape '\\' = "\\\\"
        escape '\n' = "\\n"
        escape c = [c]

-- | GraphViz node shape for each kind of declaration. Constructors and
-- records share the @box@ shape; records are further distinguished by a
-- rounded style where nodes are printed.
nodeShape :: DeclType -> String
nodeShape declType = case declType of
  DataDecl -> "octagon"
  ConDecl -> "box"
  RecDecl -> "box"
  ClassDecl -> "house"
  ValueDecl -> "ellipse"

-- | Emit one edge statement, @"from" -> "to" [attrs];@. Both endpoints are
-- rendered with 'show', so they appear as double-quoted strings.
edge :: ID -> ID -> Attributes -> Printer ()
edge from to attrs = strLn statement
  where
    statement = quote from <> " -> " <> quote to <> " " <> renderAttrs attrs
    quote :: ID -> String
    quote = show

-- | Build a dot attribute pair; @"shape" .= "box"@ reads like @shape=box@.
(.=) :: String -> String -> (String, String)
key .= value = (key, value)

-- | Render an attribute list in dot syntax, terminated by a semicolon,
-- e.g. @[shape=box,style=dashed];@; an empty list yields @[];@. Keys and
-- values are written verbatim, so callers must pre-quote any value that
-- needs it.
renderAttrs :: Attributes -> String
renderAttrs attrs =
  "[" <> intercalate "," [key <> "=" <> val | (key, val) <- attrs] <> "];"

-- | Key/value attribute pairs attached to a dot node or edge.
type Attributes = [(String, String)]