{-# LANGUAGE OverloadedStrings #-}
module NN.Visualize(visualize, visualizeWith, png, pdf, scaled) where

import           Data.GraphViz
import           Data.GraphViz.Attributes.Colors.Brewer
import           Data.GraphViz.Attributes.Complete
import qualified Data.Text.Lazy                         as L
import           Gen.Caffe.LayerParameter               as LP

import           Data.Graph.Inductive.Graph
import           NN.DSL

type NetVizParams = GraphvizParams Node LayerParameter () () LayerParameter

defaultNNParams :: NetVizParams
defaultNNParams =
    nonClusteredParams {
  -- Let's visualize neural networks from the bottom up
  globalAttributes = [GraphAttrs [RankDir FromBottom]],
  fmtNode = fmtLabelParameter
}


scaled :: (LayerParameter -> Double) -> NetVizParams
scaled f = defaultNNParams { fmtNode = setSize }
    where
      setSize n@(_, lp) = fmtNode defaultNNParams n ++ [Width width', Height height']
          where
            width' = 0.75 * scale
            height' = 0.5 * scale
            scale = f lp

visualizeWith :: NetVizParams -> Net -> DotGraph Node
visualizeWith = graphToDot

visualize :: Net -> DotGraph Node
visualize = visualizeWith defaultNNParams

png :: FilePath -> DotGraph Node -> IO FilePath
png path g = runGraphviz g Png path

pdf :: FilePath -> DotGraph Node -> IO FilePath
pdf path g = runGraphviz g Pdf path

fmtLabelParameter :: (Node, LayerParameter) -> [Attribute]
fmtLabelParameter (_, lp) =
    [FontName "Source Code Pro",
     textLabel label,
     style filled,
     fillColor color']
    where
      maxColors = 8
      idx = ((+1) . (`mod` maxColors) . fromEnum . layerTy) lp
      scheme = BScheme Pastel2 (fromIntegral maxColors)
      color' = BC scheme (fromIntegral idx)
      label = (L.pack . asCaffe . layerTy) lp