module HLearn.Models.Distributions.Visualization.Graphviz
( MultivariateLabels (..)
, MarkovNetwork (..)
)
where
import HLearn.Algebra
import HLearn.Models.Distributions.Multivariate.Interface
import HLearn.Models.Distributions.Multivariate.Internal.CatContainer
import HLearn.Models.Distributions.Multivariate.Internal.Container
import HLearn.Models.Distributions.Multivariate.Internal.TypeLens
import Data.GraphViz.Exception
import Data.GraphViz hiding (graphToDot)
import Data.GraphViz.Attributes.Complete
import Control.Arrow(second)
import GHC.TypeLits
-- | Datapoint types that can name their fields, so that a plotted network
-- can put a label on each node.
class Trainable datatype => MultivariateLabels datatype where
    -- | One name per field. The 'MarkovNetwork' instances consume these
    -- names left to right while walking the distribution's type-level
    -- structure. 'plotNetwork' calls this with 'undefined', so instances
    -- must not force their argument.
    getLabels :: datatype -> [String]
-- | Distributions whose variable structure can be drawn as a graph.
class MultivariateLabels (Datapoint dist) => MarkovNetwork dist where
    -- | Adjacency list of the network. Each entry pairs a variable's label
    -- with the labels it is linked to. The visible instances ignore the
    -- first argument and use only its type. Labels are consumed from the
    -- front of the list.
    graphL :: dist -> [String] -> [(String,[String])]

    -- | Draw the network as a PNG. The given path gets the PNG extension
    -- appended. Returns 'False' if Graphviz raised a 'GraphvizException'.
    plotNetwork :: FilePath -> dist -> IO Bool
    plotNetwork file dist = graphToDotPng file adjacency
      where
        adjacency  = graphL dist fieldNames
        -- Only the type of the datapoint matters here; its value is never
        -- needed.
        fieldNames = getLabels (undefined :: Datapoint dist)
instance
    ( MultivariateLabels datapoint
    ) => MarkovNetwork (Multivariate datapoint '[] prob)
  where
    -- No groups are left, so no nodes are produced. Any remaining labels
    -- are discarded without being inspected.
    graphL _ _ = []
instance
    ( MultivariateLabels datapoint
    , MarkovNetwork (Multivariate datapoint xs prob)
    ) => MarkovNetwork (Multivariate datapoint ( ('[]) ': xs) prob)
  where
    -- An exhausted group contributes nothing. Move on to the next group
    -- and hand it the label list unchanged.
    graphL _ = graphL (undefined :: Multivariate datapoint xs prob)
instance
    ( MultivariateLabels datapoint
    , MarkovNetwork (Multivariate datapoint ( ys ': xs) prob)
    ) => MarkovNetwork (Multivariate datapoint ( (Ignore' label ': ys) ': xs) prob)
  where
    -- An ignored variable gets no node. Its label is consumed and the rest
    -- of the group is processed. 'drop 1' is used instead of the partial
    -- 'tail': when the label list is already empty, the remainder is simply
    -- empty too. The shortage then only surfaces if a later variable
    -- actually needs a label.
    graphL _ labels =
        graphL (undefined :: Multivariate datapoint ( ys ': xs) prob) (drop 1 labels)
instance
    ( MultivariateLabels datapoint
    , MarkovNetwork (Multivariate datapoint ( ys ': xs) prob)
    ) => MarkovNetwork (Multivariate datapoint ( (CatContainer label ': ys) ': xs) prob)
  where
    -- A categorical variable is linked to every variable that follows it
    -- in the structure.
    --
    -- The link targets are the nodes emitted for the rest of the
    -- distribution, not the raw remaining labels. Labels consumed by
    -- Ignore' variables, or surplus labels beyond the last variable, never
    -- become nodes. Linking to them would make Graphviz create stray nodes
    -- for them implicitly.
    graphL _ labels = case labels of
        (lbl : rest) ->
            let children = graphL (undefined :: Multivariate datapoint ( ys ': xs) prob) rest
            in  (lbl, map fst children) : children
        [] -> error "HLearn.Models.Distributions.Visualization.Graphviz.graphL: no label left for a CatContainer variable (getLabels returned fewer labels than the distribution has variables)"
instance
    ( MultivariateLabels datapoint
    , MarkovNetwork (Multivariate datapoint (ys ': xs) prob)
    ) => MarkovNetwork (Multivariate datapoint ( (Container dist label ': ys) ': xs) prob)
  where
    -- A plain container contributes one node with no links of its own.
    -- Pattern matching replaces the partial 'head'/'tail', so a label
    -- shortage produces a descriptive error instead of "Prelude.head".
    graphL _ labels = case labels of
        (lbl : rest) ->
            (lbl, []) : graphL (undefined :: Multivariate datapoint (ys ': xs) prob) rest
        [] -> error "HLearn.Models.Distributions.Visualization.Graphviz.graphL: no label left for a Container variable (getLabels returned fewer labels than the distribution has variables)"
instance
    ( MultivariateLabels datapoint
    , SingI (Length labelL)
    , MarkovNetwork (Multivariate datapoint ( ys ': xs) prob)
    ) => MarkovNetwork (Multivariate datapoint ( (MultiContainer dist (labelL:: [*]) ': ys) ': xs) prob)
  where
    -- The first `width` labels belong to this container and are linked
    -- pairwise (a clique): each label gets an edge to every group member
    -- after it. The labels after the group go to the rest of the
    -- distribution.
    graphL _ labels =
        clique grp ++ graphL (undefined :: Multivariate datapoint ( ys ': xs ) prob) rest
      where
        (grp, rest) = splitAt width labels

        -- Pair each label with the members that follow it in the group.
        clique []       = []
        clique (v : vs) = (v, vs) : clique vs

        -- Number of variables in the container, read from the length of
        -- its type-level label list.
        width = fromIntegral $ fromSing (sing :: Sing (Length labelL))
-- | Turn an adjacency list into a 'DotGraph' that uses the shared
-- 'vacuumParams' styling.
graphToDot :: (Ord n) => [(n, [n])] -> DotGraph n
graphToDot adjacency = graphToDotParams vacuumParams adjacency
-- | Turn an adjacency list into a 'DotGraph' using the given parameters.
--
-- Each source vertex becomes a node with a unit label. Each
-- @(source, target)@ pair becomes an edge with a unit label. Only source
-- vertices are passed as nodes. A vertex that appears solely as a target
-- is left for Graphviz to create implicitly.
graphToDotParams :: (Ord a, Ord cl) => GraphvizParams a () () cl l -> [(a, [a])] -> DotGraph a
graphToDotParams params adjacency = graphElemsToDot params nodes edges
  where
    nodes = [ (v, ()) | (v, _) <- adjacency ]
    edges = [ (src, dst, ()) | (src, dsts) <- adjacency, dst <- dsts ]
-- | Graphviz parameters for every rendered network: 'defaultParams' with
-- the global attributes replaced by 'gStyle'. Node labels, edge labels
-- and clusters are all unit-typed.
vacuumParams :: GraphvizParams n () () () ()
vacuumParams = defaultParams { globalAttributes = gStyle }
-- | Global Graphviz attributes shared by every rendered network.
--
-- * Graph: Courier font, laid out with 'Circo'.
-- * Nodes: labelled with their own name (@\\N@), drawn as filled
--   AliceBlue ellipses with a Navy border of width 2 and black text.
-- * Edges: black, with 'NoDir' so no arrowheads are drawn.
--
-- Graphviz applies attributes in order, so the node list sets exactly one
-- shape. A second, conflicting @shape@ would be silently overridden.
gStyle :: [GlobalAttributes]
gStyle =
    [ GraphAttrs
        [ RankDir FromLeft  -- dot-only attribute; it has no effect under Circo
        , FontName "courier"
        , Layout Circo
        ]
    , NodeAttrs
        [ textLabel "\\N"
        , fontColor Black
        , Shape Ellipse
        , style filled
        , fillColor AliceBlue
        , penWidth 2
        , color Navy
        ]
    , EdgeAttrs
        [ color Black
        , Dir NoDir
        ]
    ]
-- | Render an adjacency list to a PNG file.
--
-- 'addExtension' appends the PNG extension to @fpre@. Returns 'True' once
-- 'runGraphviz' completes. Returns 'False' if it throws a
-- 'GraphvizException'. Other exception types are not caught.
graphToDotPng :: FilePath -> [(String,[String])] -> IO Bool
graphToDotPng fpre g =
    handle onGraphvizError $
        addExtension (runGraphviz (graphToDot g)) Png fpre >> return True
  where
    -- Best-effort contract: callers only get a success flag, so the
    -- exception value itself is deliberately discarded.
    onGraphvizError :: GraphvizException -> IO Bool
    onGraphvizError _ = return False