{-# LANGUAGE MonadComprehensions, MultiParamTypeClasses #-}
module Data.GraphViz.Algorithms
       ( 
         
         CanonicaliseOptions(..)
       , defaultCanonOptions
       , dotLikeOptions
         
         
       , canonicalise
       , canonicaliseOptions
         
         
       , transitiveReduction
       , transitiveReductionOptions
       ) where
import Data.GraphViz.Attributes.Complete   (Attributes, defaultAttributeValue)
import Data.GraphViz.Attributes.Same
import Data.GraphViz.Internal.Util         (bool)
import Data.GraphViz.Types
import Data.GraphViz.Types.Canonical
import Data.GraphViz.Types.Internal.Common
import           Control.Arrow       (first, second, (***))
import           Control.Monad       (unless)
import           Control.Monad.State (State, execState, gets, modify)
import qualified Data.DList          as DList
import qualified Data.Foldable       as F
import           Data.Function       (on)
import           Data.List           (deleteBy, groupBy, partition, sortBy,
                                      (\\))
import           Data.Map            (Map)
import qualified Data.Map            as Map
import           Data.Maybe          (fromMaybe, listToMaybe, mapMaybe)
import           Data.Set            (Set)
import qualified Data.Set            as Set
-- | Options controlling how 'canonicaliseOptions' and
--   'transitiveReductionOptions' lay out the resulting graph.
data CanonicaliseOptions = COpts
  { -- | Place each edge in the innermost cluster containing both of its
    --   endpoints rather than in the top-level graph.
    edgesInClusters :: Bool
    -- | Lift node/edge 'Attributes' common to a whole (sub)graph into
    --   that level's global attribute statements.
  , groupAttributes :: Bool
  }
  deriving (Eq, Ord, Show, Read)
-- | The default options: edges are placed inside clusters and common
--   attributes are grouped into global attribute statements.
defaultCanonOptions :: CanonicaliseOptions
defaultCanonOptions = COpts { groupAttributes = True
                            , edgesInClusters = True
                            }
-- | Like 'defaultCanonOptions' (edges are still placed inside clusters),
--   but common attributes are left on the individual nodes and edges
--   instead of being grouped into global attribute statements.
dotLikeOptions :: CanonicaliseOptions
dotLikeOptions = defaultCanonOptions { groupAttributes = False }
-- | Canonicalise a graph using 'defaultCanonOptions'.
canonicalise :: (DotRepr dg n) => dg n -> DotGraph n
canonicalise dg = canonicaliseOptions defaultCanonOptions dg
-- | Canonicalise a graph using the given options.
--
--   The graph's ID, global attributes, clusters, nodes and edges are
--   extracted and rebuilt via 'createCanonical'.  Whether the graph is
--   strict and/or directed is copied across from the input.
canonicaliseOptions :: (DotRepr dg n) => CanonicaliseOptions
                       -> dg n -> DotGraph n
canonicaliseOptions opts dg =
  let (gas, cl) = graphStructureInformationClean dg
      canon     = createCanonical opts (getID dg) gas cl
                                  (nodeInformationClean True dg)
                                  (edgeInformationClean True dg)
  in canon { strictGraph   = graphIsStrict dg
           , directedGraph = graphIsDirected dg
           }
-- | A node paired with the IDs of the clusters it is nested in,
--   outermost first ('makeGrouping' consumes the head at each level).
type NodePath n = ([Maybe GraphID], DotNode n)
-- | A list of nodes with their cluster paths.
type NodePaths n = [NodePath n]
-- | Edges assigned to clusters, keyed by the cluster's ID.
type EdgeClusters n = Map (Maybe GraphID) [DotEdge n]
-- | Edges assigned to clusters, together with the edges left at the
--   top level.
type EdgeLocations n = (EdgeClusters n, [DotEdge n])
-- | Context for 'makeGrouping'.  The same value is passed down the
--   cluster hierarchy with only 'isGraph' changed.
data CanonControl n = CC
  { cOpts    :: !CanonicaliseOptions  -- ^ Options in effect.
  , isGraph  :: !Bool                 -- ^ 'True' only for the top-level graph.
  , clusters :: !ClusterLookup        -- ^ Consulted by cluster ID for each cluster's attributes.
  , clustEs  :: !(EdgeLocations n)    -- ^ Edges per cluster, plus the top-level edges.
  , topID    :: !(Maybe GraphID)      -- ^ ID of the top-level graph.
  , topAttrs :: !Attributes           -- ^ Graph attributes used at the top level.
  }
-- | Assemble the canonical graph from cleaned-up information about a
--   graph: its ID, global attributes, cluster lookup, node lookup and
--   edges.  The result's 'strictGraph' and 'directedGraph' fields must
--   be set by the caller (see 'promoteDSG').
createCanonical :: (Ord n) => CanonicaliseOptions -> Maybe GraphID -> GlobalAttributes
                   -> ClusterLookup -> NodeLookup n -> [DotEdge n] -> DotGraph n
createCanonical opts gid gas cl nl es = promoteDSG (makeGrouping control sortedNodes)
  where
    -- makeGrouping relies on nodes being ordered by cluster path using
    -- 'compLists': a path sorts before its own prefixes, so top-level
    -- nodes come last and each cluster's nodes are contiguous.
    sortedNodes = sortBy (\ x y -> compLists (fst x) (fst y))
                  [ (F.toList path, DotNode n nas) | (n, (path, nas)) <- Map.toList nl ]

    edgeLocs
      | edgesInClusters opts = edgeClusters nl es
      | otherwise            = (Map.empty, es)

    control = CC { cOpts    = opts
                 , isGraph  = True
                 , clusters = cl
                 , clustEs  = edgeLocs
                 , topID    = gid
                 , topAttrs = attrs gas
                 }
-- | Split node paths into the leading run of nodes still nested in some
--   cluster (non-empty path) and the nodes that follow.  The latter are
--   placed directly at this level and returned without their paths.
--   Callers order paths with 'compLists' so that all nested paths come
--   first.
thisLevel :: NodePaths n -> (NodePaths n, [DotNode n])
thisLevel cns = (nested, map snd here)
  where
    (nested, here) = break (null . fst) cns
-- | Build one level of the cluster hierarchy as a sub-graph.
--
--   At the top level ('isGraph' is 'True') the ID and graph attributes
--   come from 'topID' and 'topAttrs', and the edges are the second
--   component of 'clustEs'.
--
--   Below the top level, every path in @cns@ starts with the ID of the
--   cluster being built, because the caller groups the nested paths by
--   their first element.  That ID:
--
--   * names the cluster,
--   * selects its edges from the first component of 'clustEs',
--   * selects its attributes from 'clusters', and
--   * is stripped from the paths before descending.
--
--   Nodes whose remaining path is empty are placed directly at this
--   level.  The rest are grouped by the first element of their path and
--   built recursively.  @cns@ must be ordered as in 'createCanonical'
--   (via 'compLists') so that this split and grouping are contiguous.
--
--   NOTE(review): the two 'bool' calls assume the helper from
--   "Data.GraphViz.Internal.Util" takes the 'False' case first (as
--   'Data.Bool.bool' does).  That is the only reading under which the
--   top-level paths are left intact.
makeGrouping :: CanonControl n -> NodePaths n -> DotSubGraph n
makeGrouping cc cns = DotSG { isCluster = True
                            , subGraphID = cID
                            , subGraphStmts = stmts
                            }
  where
    -- Below the top level the caller only ever passes a non-empty group
    -- of non-empty paths.  If that invariant is broken, fail with a
    -- descriptive message rather than a bare 'head' exception.
    cID | isGraph cc = topID cc
        | otherwise  = case cns of
                         ((c:_, _) : _) -> c
                         _              -> grpErr "empty node group for a cluster"

    (nestedNs, ns) = thisLevel
                     . bool (map $ first tail) id (isGraph cc)
                     $ cns

    es = bool (fromMaybe [] . Map.lookup cID . fst) snd (isGraph cc)
         $ clustEs cc

    gas | isGraph cc = topAttrs cc
        | otherwise  = maybe (grpErr "cluster ID not found in the cluster lookup")
                             (attrs . snd)
                             (Map.lookup cID (clusters cc))

    subGs = map (makeGrouping $ cc { isGraph = False })
            . groupBy ((==) `on` (listToMaybe . fst))
            $ nestedNs

    stmts = setGlobal (cOpts cc) gas
            $ DotStmts { attrStmts = []
                       , subGraphs = subGs
                       , nodeStmts = ns
                       , edgeStmts = es
                       }

    grpErr msg = error $ "Data.GraphViz.Algorithms.makeGrouping: " ++ msg
-- | Lift shared node/edge attributes at one level of the hierarchy.
--
--   Node (resp. edge) attributes common to both of the following are
--   moved into this level's global statements:
--
--   * this level's nodes (edges), and
--   * the node (edge) global statements of those sub-graphs that contain
--     nodes (edges).
--
--   The lifted attributes are then removed from the individual nodes,
--   edges and sub-graph globals.  Nothing is lifted unless
--   'groupAttributes' is set (see 'getCommonGlobs').
--
--   The graph attributes @as@ are kept verbatim as this level's graph
--   attributes.  Each sub-graph's own graph attributes are adjusted
--   relative to them by 'updateGraphGlobs'.  Any existing 'attrStmts'
--   of @stmts@ are replaced.
setGlobal :: CanonicaliseOptions
             -> Attributes -- ^ Graph attributes for this level.
             -> DotStatements n
             -> DotStatements n
setGlobal opts as stmts = stmts { attrStmts = ownGlobals
                                , subGraphs = zipWith pushGlobals subs subGlobals
                                , nodeStmts = map stripNode nodes
                                , edgeStmts = map stripEdge edges
                                }
  where
    subs     = subGraphs stmts
    subStmts = map subGraphStmts subs
    nodes    = nodeStmts stmts
    edges    = edgeStmts stmts

    -- Existing graph/node/edge global attributes of each sub-graph.
    (subGAs, subNAs, subEAs) = unzip3 $ map (partitionGlobal . attrStmts) subStmts

    commonNAs = getCommonGlobs opts nodeStmts subNAs subStmts (map nodeAttributes nodes)
    commonEAs = getCommonGlobs opts edgeStmts subEAs subStmts (map edgeAttributes edges)

    -- Graph attributes are passed through unchanged; they are never lifted.
    ownGlobals = nonEmptyGAs [ GraphAttrs as
                             , NodeAttrs  commonNAs
                             , EdgeAttrs  commonEAs
                             ]

    stripNode dn = dn { nodeAttributes = nodeAttributes dn \\ commonNAs }
    stripEdge de = de { edgeAttributes = edgeAttributes de \\ commonEAs }

    subGlobals = zip3 (updateGraphGlobs as subGAs)
                      (map (\\ commonNAs) subNAs)
                      (map (\\ commonEAs) subEAs)

    pushGlobals sg gl = sg { subGraphStmts = setAttrs (subGraphStmts sg) }
      where
        setAttrs st = st { attrStmts = nonEmptyGAs (unPartitionGlobal gl) }
-- | Adjust each sub-graph's graph attributes relative to the parent's
--   graph attributes @gas@.
--
--   First, for every parent attribute whose default value differs from
--   it, that default (from 'nonSameDefaults') is added to the sub-graph.
--   The left-biased union keeps the sub-graph's own value when both
--   specify the same attribute (assuming 'SameAttr' compares attributes
--   by kind, as its name suggests — TODO confirm).
--
--   Second, any attribute identical to one of the parent's is removed.
updateGraphGlobs :: Attributes -> [Attributes] -> [Attributes]
updateGraphGlobs gas = map adjust
  where
    parentSet = Set.fromList gas
    resets    = toSAttr (nonSameDefaults gas)
    adjust sub = Set.toList
                 $ unSameSet (toSAttr sub `Set.union` resets)
                   `Set.difference` parentSet
-- | For each attribute that has a default value ('defaultAttributeValue')
--   differing from it, return that default.  Attributes with no default,
--   or already equal to their default, are dropped.
nonSameDefaults :: Attributes -> Attributes
nonSameDefaults = mapMaybe differingDefault
  where
    differingDefault a = case defaultAttributeValue a of
                           Just d | d /= a -> Just d
                           _               -> Nothing
-- | Find the attributes shared by every candidate attribute list.
--
--   The candidates are:
--
--   * the sub-graph global attributes, restricted to sub-graphs that
--     contain at least one statement selected by @f@ (possibly nested;
--     see 'keepIfAny'), followed by
--   * the individual node/edge attribute lists.
--
--   The result is empty when 'groupAttributes' is off or when there are
--   fewer than two candidates.
getCommonGlobs :: CanonicaliseOptions
                  -> (DotStatements n -> [a]) -- ^ Selects the relevant statements (nodes or edges).
                  -> [Attributes]             -- ^ Node/edge global attributes of each sub-graph.
                  -> [DotStatements n]        -- ^ The sub-graphs' statements, parallel to the previous list.
                  -> [Attributes]             -- ^ Attributes of each node/edge at this level.
                  -> Attributes
getCommonGlobs opts f sas stmts as
  | groupAttributes opts
  , candidates@(_:_:_) <- keepIfAny f sas stmts ++ as
  = Set.toList . foldr1 Set.intersection $ map Set.fromList candidates
  | otherwise
  = []
-- | Keep each attribute list whose paired statements (or any nested
--   sub-graph's statements) contain something selected by @f@.
keepIfAny :: (DotStatements n -> [a]) -> [Attributes] -> [DotStatements n]
             -> [Attributes]
keepIfAny f sas stmts = [ sa | (sa, st) <- zip sas stmts, hasAny f st ]
-- | Does the selector @f@ find anything in these statements or in any
--   (arbitrarily nested) sub-graph's statements?
hasAny      :: (DotStatements n -> [a]) -> DotStatements n -> Bool
hasAny f = go
  where
    go ds = not (null (f ds)) || any (go . subGraphStmts) (subGraphs ds)
-- | Turn the top-level grouping built by 'makeGrouping' into a 'DotGraph'.
--
--   Strictness and directedness are not known here.  Callers must
--   overwrite 'strictGraph' and 'directedGraph' with a record update, as
--   'canonicaliseOptions' and 'transitiveReductionOptions' do.  The
--   placeholders are descriptive errors rather than bare 'undefined',
--   so a missed override is easy to diagnose.
promoteDSG     :: DotSubGraph n -> DotGraph n
promoteDSG dsg = DotGraph { strictGraph     = unset "strictGraph"
                          , directedGraph   = unset "directedGraph"
                          , graphID         = subGraphID dsg
                          , graphStatements = subGraphStmts dsg
                          }
  where
    unset fld = error $ "Data.GraphViz.Algorithms.promoteDSG: " ++ fld
                        ++ " was not set by the caller"
-- | Compare lists lexicographically, except that a list sorts /after/
--   every proper extension of it (so @[]@ sorts last).  Used to order
--   node paths so that nested nodes precede the nodes placed directly in
--   their enclosing cluster.
compLists :: (Ord a) => [a] -> [a] -> Ordering
compLists (x:xs) (y:ys)
  | ord == EQ = compLists xs ys
  | otherwise = ord
  where
    ord = compare x y
-- At least one list is exhausted: the empty one sorts later.
compLists xs ys = compare (null xs) (null ys)
-- | Drop global attribute statements that carry no attributes.
nonEmptyGAs :: [GlobalAttributes] -> [GlobalAttributes]
nonEmptyGAs gas = [ ga | ga <- gas, not (null (attrs ga)) ]
-- | Split edges by location.
--
--   An edge whose endpoints' cluster paths share a non-empty common
--   prefix is assigned to the last cluster of that prefix, i.e. the
--   innermost cluster containing both endpoints.  The remaining edges
--   are returned separately for the top level.  Edge order is preserved
--   within each group.
--
--   Both endpoints must be keys of the node lookup ('Map.!' is used).
edgeClusters       :: (Ord n) => NodeLookup n -> [DotEdge n]
                      -> EdgeLocations n
edgeClusters nl es = (Map.map DList.toList grouped, map snd topLevel)
  where
    paths = Map.map (F.toList . fst) nl

    sharedPath n1 n2 = map fst . takeWhile (uncurry (==))
                       $ zip (paths Map.! n1) (paths Map.! n2)

    (clustered, topLevel) = partition (not . null . fst)
                            [ (sharedPath (fromNode e) (toNode e), e) | e <- es ]

    -- 'flip' keeps earlier edges before later ones in each cluster's list.
    grouped = Map.fromListWith (flip DList.append)
              [ (last p, DList.singleton e) | (p, e) <- clustered ]
-- | 'transitiveReductionOptions' with 'defaultCanonOptions'.
transitiveReduction :: (DotRepr dg n) => dg n -> DotGraph n
transitiveReduction dg = transitiveReductionOptions defaultCanonOptions dg
-- | Canonicalise a graph as 'canonicaliseOptions' does, but for directed
--   graphs first prune the edge list with 'rmTransEdges'.  Undirected
--   graphs keep every edge.  Strictness and directedness are copied from
--   the input.
transitiveReductionOptions         :: (DotRepr dg n) => CanonicaliseOptions
                                      -> dg n -> DotGraph n
transitiveReductionOptions opts dg = reduced { strictGraph   = graphIsStrict dg
                                             , directedGraph = graphIsDirected dg
                                             }
  where
    (gas, cl) = graphStructureInformationClean dg
    allEdges  = edgeInformationClean True dg
    keptEdges = if graphIsDirected dg
                  then rmTransEdges allEdges
                  else allEdges
    reduced   = createCanonical opts (getID dg) gas cl
                                (nodeInformationClean True dg)
                                keptEdges
-- | Remove edges made redundant by other paths.
--
--   The edges are tagged and loaded into a 'TagMap' with 'edgeGraph'.
--   A 'traverseTag' depth-first traversal is then started from every
--   node, in key order.  The edges still present in the nodes' outgoing
--   lists are returned, grouped by source node in key order.
--
--   NOTE(review): a self-loop appears always to be dropped, since its
--   source is marked whenever its target's incoming edges are checked.
rmTransEdges    :: (Ord n) => [DotEdge n] -> [DotEdge n]
rmTransEdges [] = []
rmTransEdges es = concatMap (map snd . outgoing) (Map.elems finalMap)
  where
    (finalMap, _) = execState reduce (Map.empty, Set.empty)
    reduce = do
      edgeGraph (tagEdges es)
      starts <- getsMap Map.keys
      F.traverse_ (traverseTag zeroTag) starts
-- | Identifier for an edge during transitive reduction.
type Tag = Int

-- | A set of edge tags.
type TagSet = Set Tag

-- | An edge paired with its tag.
type TaggedEdge n = (Tag, DotEdge n)

-- | Tag for a traversal started directly at a node rather than reached
--   along an edge.  'tagEdges' never assigns it to a real edge.
zeroTag :: Tag
zeroTag = 0

-- | Number the edges consecutively, starting just after 'zeroTag'.
tagEdges :: [DotEdge n] -> [TaggedEdge n]
tagEdges es = zip (enumFrom (succ zeroTag)) es
-- | Per-node state for the transitive reduction.
data TaggedValues n = TV
  { marked   :: Bool            -- ^ Set while 'traverseTag' is visiting this node.
  , incoming :: [TaggedEdge n]  -- ^ Remaining edges into this node.
  , outgoing :: [TaggedEdge n]  -- ^ Remaining edges out of this node.
  }
  deriving (Eq, Ord, Show, Read)

-- | Unmarked, with no edges.
defTV :: TaggedValues n
defTV = TV { marked = False, incoming = [], outgoing = [] }
-- | Per-node traversal state, keyed by node.
type TagMap n = Map n (TaggedValues n)

-- | State for the transitive reduction: the tag map paired with the set
--   of tags of edges deleted so far.
type TagState n a = State (TagMap n, TagSet) a

-- | Read the whole tag map.
getMap :: TagState n (TagMap n)
getMap = getsMap id

-- | Read a projection of the tag map.
getsMap   :: (TagMap n -> a) -> TagState n a
getsMap f = gets (\ ~(m, _) -> f m)

-- | Apply a function to the tag map, leaving the tag set alone.
modifyMap   :: (TagMap n -> TagMap n) -> TagState n ()
modifyMap f = modify (f *** id)

-- | Read the set of deleted edge tags.
getSet :: TagState n TagSet
getSet = gets (\ ~(_, s) -> s)

-- | Apply a function to the deleted-tag set, leaving the map alone.
modifySet   :: (TagSet -> TagSet) -> TagState n ()
modifySet f = modify (id *** f)
-- | Populate the tag map from the tagged edges.
--
--   Each edge is added to the 'outgoing' list of its source node and the
--   'incoming' list of its target node, creating entries from 'defTV'
--   as needed.  Edges are inserted in reverse and each insertion
--   prepends, so every node's lists end up in the original edge order.
edgeGraph :: (Ord n) => [TaggedEdge n] -> TagState n ()
edgeGraph tes = F.traverse_ insertEdge (reverse tes)
  where
    insertEdge te = do
      let e = snd te
      modifyMap $ Map.insertWith combine (fromNode e) (defTV { outgoing = [te] })
      modifyMap $ Map.insertWith combine (toNode e)   (defTV { incoming = [te] })
    -- Prepend the new entry's edges; the existing 'marked' flag is kept.
    combine new old = old { incoming = incoming new ++ incoming old
                          , outgoing = outgoing new ++ outgoing old
                          }
-- | Depth-first traversal from node @n@, reached along the edge tagged
--   @t@ ('zeroTag' when the traversal starts directly at @n@).
--
--   The steps are:
--
--   1. Mark @n@.
--   2. Delete every incoming edge of @n@ whose source is also marked,
--      other than the edge tagged @t@ that was followed to get here.
--      Its tag is added to the deleted-tag set and it is removed from
--      its source's outgoing list.
--   3. Recurse along each outgoing edge of @n@ whose target is not
--      marked and whose tag has not been deleted.
--   4. Unmark @n@.
traverseTag     :: (Ord n) => Tag -> n -> TagState n ()
traverseTag t n = do setMark True
                     checkIncoming
                     -- Read after checkIncoming, so deletions made there are seen.
                     outEs <- getsMap (maybe [] outgoing . Map.lookup n)
                     mapM_ maybeRecurse outEs
                     setMark False
  where
    setMark mrk = modifyMap (Map.adjust (\tv -> tv { marked = mrk }) n)
    -- Nodes absent from the map count as unmarked.
    isMarked m n' = maybe False marked $ n' `Map.lookup` m
    -- NOTE(review): 'Map.!' assumes @n@ is a key of the map; callers pass
    -- map keys or edge targets, both of which 'edgeGraph' inserts.
    checkIncoming = do m <- gets fst
                       let es = incoming $ m Map.! n
                           (keepEs, delEs) = partition (keepEdge m) es
                       modifyMap (Map.adjust (\tv -> tv {incoming = keepEs}) n)
                       modifySet (Set.union $ Set.fromList (map fst delEs))
                       mapM_ delOtherEdge delEs
      where
        -- Keep the edge we arrived along, and any edge from an unmarked source.
        keepEdge m (t',e) = t == t' || not (isMarked m $ fromNode e)
        -- Remove a deleted edge (matched by tag) from its source's outgoing list.
        delOtherEdge te = modifyMap (Map.adjust delE . fromNode $ snd te)
          where
            delE tv = tv {outgoing = deleteBy ((==) `on` fst) te $ outgoing tv}
    -- Skip targets already on the current path and edges already deleted.
    maybeRecurse (t',e) = do m <- getMap
                             delSet <- getSet
                             let n' = toNode e
                             unless (isMarked m n' || t' `Set.member` delSet)
                               $ traverseTag t' n'