-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Bayes.SVG
   ( networkToHTML, networkToSVG
   , Layout, Point, pt, pointX, pointY, Color
   ) where

import Data.Monoid ( (<>) )
import Control.Arrow
import Ideas.Text.HTML hiding (text)
import Ideas.Text.XML hiding (name, text)
import Bayes.Network
import Bayes.Probability
import Data.Maybe
import Ideas.Text.HTML.W3CSS (w3css)

-- | Position of each node in the drawing, keyed by node identifier (looked
-- up with 'nodeId'). Each point is used as the top-left corner of the
-- node's box.
type Layout = [(String, Point)]

-- | A colour exactly as written into SVG attributes, e.g. @"#E5F6F7"@.
type Color = String

-- | Border of a rectangle: stroke colour and stroke width.
type Border = (Color, Double)

-- | A 2D point, also used as an offset vector. The 'Num' and 'Fractional'
-- instances below act on both coordinates component-wise.
data Point = Pt { pointX :: Double, pointY :: Double }

-- | Width and height of a rectangle, written verbatim to SVG @width@ and
-- @height@ attributes.
data Size = Sz { sizeW :: Double, sizeH :: Double }

-- | Shows a point as the expression that rebuilds it, e.g. @pt 1.0 2.0@.
--
-- Follows the 'showsPrec' conventions of derived instances:
--
-- * The result is parenthesised when it occurs as an argument
--   (precedence above 10).
-- * Each coordinate is shown at argument precedence, so negative values
--   come out as @pt (-1.0) 2.0@ rather than the unparseable @pt -1.0 2.0@.
--
-- For non-negative coordinates, the output at the top level and inside
-- lists/tuples is the same as before.
instance Show Point where
   showsPrec d p = showParen (d > 10) $
        showString "pt "
      . showsPrec 11 (pointX p)
      . showChar ' '
      . showsPrec 11 (pointY p)

-- | Component-wise arithmetic: each operator acts on the x and y
-- coordinates independently. An integer literal @n@ stands for the point
-- @pt n n@, so @0@ is the origin. 'negate' is left to the class default
-- (@0 - p@).
instance Num Point where
   fromInteger n = let c = fromInteger n in pt c c
   p + q    = lift2Pt (+) p q
   p - q    = lift2Pt (-) p q
   p * q    = lift2Pt (*) p q
   abs p    = liftPt abs p
   signum p = liftPt signum p

-- | Component-wise division. A fractional literal @r@ stands for the point
-- @pt r r@, so @p / 2@ halves both coordinates.
instance Fractional Point where
   fromRational r = let c = fromRational r in pt c c
   p / q          = lift2Pt (/) p q

-- | Apply a function to each coordinate of a point.
liftPt :: (Double -> Double) -> Point -> Point
liftPt f p =
   case p of
      Pt x y -> pt (f x) (f y)

-- | Combine two points coordinate by coordinate with a binary operator.
lift2Pt :: (Double -> Double -> Double) -> Point -> Point -> Point
lift2Pt op p q =
   case (p, q) of
      (Pt x1 y1, Pt x2 y2) -> pt (op x1 x2) (op y1 y2)

-- | Build a point from its x and y coordinates.
pt :: Double -> Double -> Point
pt x y = Pt x y

-- | Build a size from its width and height.
sz :: Double -> Double -> Size
sz w h = Sz w h

-- | Top-left corner of a node's box according to the layout. A node with
-- no entry in the layout is placed at the origin.
getPoint :: Layout -> Node a -> Point
getPoint l n =
   case lookup (nodeId n) l of
      Just p  -> p
      Nothing -> pt 0 0

-- | Centre of a node's box: its top-left corner shifted by half its size.
getCenterPoint :: Layout -> Node a -> Point
getCenterPoint l n = getPoint l n + halfSize
 where
   Sz w h   = nodeSize n
   halfSize = pt (w / 2) (h / 2)

-- | Bottom-right corner of a node's box: its top-left corner plus its size.
getPoint2 :: Layout -> Node a -> Point
getPoint2 l n =
   let s = nodeSize n
   in  getPoint l n + pt (sizeW s) (sizeH s)

-- | Size of a node's box: a fixed width of 150, and a height of 25 plus 15
-- per unit of @size n@. Presumably this is one 15-high row per state,
-- matching the row spacing used by 'stateToSVG'; confirm against
-- Bayes.Network.
nodeSize :: Node a -> Size
nodeSize n = sz boxWidth (rowHeight * fromIntegral (size n) + headerHeight)
 where
   boxWidth     = 150
   rowHeight    = 15
   headerHeight = 25

-- | Translate a layout so that its bounding box starts at (5, 5), leaving
-- a margin of 5 on every side. Returns the translated layout together with
-- the canvas size needed to contain it.
--
-- The bounding box spans every point in the layout and the bottom-right
-- corner of every node in the network. When there are no such points, the
-- box degenerates to the origin.
normalize :: Layout -> Network a -> (Layout, Size)
normalize l nw = (map (second shift) l, canvas)
 where
   margin     = 5
   corners    = map snd l ++ [ getPoint2 l n | n <- nodes nw ]
   (xLo, xHi) = minMax 0 [ pointX c | c <- corners ]
   (yLo, yHi) = minMax 0 [ pointY c | c <- corners ]
   -- move the bounding box to the origin first, then add the margin
   shift p    = (p - pt xLo yLo) + pt margin margin
   canvas     = sz (2 * margin + xHi - xLo) (2 * margin + yHi - yLo)

-- | Smallest and largest element of a list. For an empty list, both
-- components are the given fallback value.
minMax :: Ord a => a -> [a] -> (a, a)
minMax fallback [] = (fallback, fallback)
minMax _        xs = (minimum xs, maximum xs)

----------------------------------------------------------------------------

-- | SVG fragments are built with the generic XML builder from
-- "Ideas.Text.XML".
type SVG = XMLBuilder

-- | Wrap the SVG drawing of a network in an HTML page. The page is titled
-- with the network's name and passed through 'w3css' (from
-- "Ideas.Text.HTML.W3CSS").
networkToHTML :: (a -> Maybe Probability) -> Layout -> Network a -> HTMLPage
networkToHTML f l nw = w3css page
 where
   page = htmlPage (name nw) (networkToSVG f l nw)

-- | Draw a network as an @svg@ element.
--
-- The layout is first normalised with 'normalize', and the canvas size it
-- reports becomes the element's width and height. All edges are emitted
-- before the nodes, so node boxes are painted over the lines. The function
-- argument is applied to the value of each node state and may supply a
-- probability for that state's bar.
networkToSVG :: (a -> Maybe Probability) -> Layout -> Network a -> SVG
networkToSVG f l0 nw = element "svg" (dimensions ++ edges ++ boxes)
 where
   (l, s)     = normalize l0 nw
   dimensions = [ "width"  .=. show (sizeW s)
                , "height" .=. show (sizeH s)
                ]
   edges      = [ arrowToSVG l src dst | (src, dst) <- arrows nw ]
   boxes      = map (nodeToSVG f l) (nodes nw)

-- | All edges of the network as (parent, child) pairs. The pairs are
-- grouped by child, in the order given by 'nodes'.
arrows :: Network a -> [(Node a, Node a)]
arrows nw = do
   child  <- nodes nw
   parent <- parents nw child
   return (parent, child)

-- | Draw an edge as a straight line between the centres of two node boxes.
-- No arrowhead is drawn.
arrowToSVG :: Layout -> Node a -> Node a -> SVG
arrowToSVG l src dst = line (centreOf src) (centreOf dst)
 where
   centreOf = getCenterPoint l

-- | Draw one node. The drawing consists of:
--
-- * a bordered box, with the node's label as a tooltip;
-- * the node identifier as a caption;
-- * one row per state, each given the probability @f@ yields for that
--   state's value.
nodeToSVG :: (a -> Maybe Probability) -> Layout -> Node a -> SVG
nodeToSVG f l n = mconcat (frame : heading : stateRows)
 where
   topLeft   = getPoint l n
   boxSize   = nodeSize n
   frame     = rect topLeft boxSize "#E5F6F7" (Just ("#196498", 0.8)) (Just (label n))
   heading   = text (topLeft + pt 5 14) "#034471" (nodeId n)
   stateRows = zipWith drawState [0 ..] (states n)
   drawState i (s, a) = stateToSVG topLeft boxSize i s (f a)

-- | Draw the @i@-th state row of a node whose box has top-left corner @p@
-- and size @s@.
--
-- Without a probability, only the state name is shown. With one, the
-- probability is appended to the name, and a horizontal bar proportional
-- to it is drawn in the right half of the box. The bar's colour is taken
-- from 'barColors' by row index.
stateToSVG :: Point -> Size -> Int -> String -> Maybe Probability -> SVG
stateToSVG p s i stateName mprob =
   case mprob of
      Nothing   -> stateText stateName
      Just prob -> bar prob <> stateText (stateName ++ " " ++ show prob)
 where
   rowOffset    = fromIntegral i * 15
   halfWidth    = sizeW s / 2
   stateText    = text (p + pt 5 (34 + rowOffset)) "#034471"
   bar prob     = rect (p + pt (5 + halfWidth) (25 + rowOffset))
                       (barSize prob) (barColors !! i) Nothing Nothing
   -- the bar spans at most the right half of the box minus 10 units
   barSize prob = sz ((halfWidth - 10) * fromRational (toRational prob)) 10

-- | Colours for the probability bars. The infinite list repeats a
-- three-colour palette, so state rows reuse colours every three states.
barColors :: [Color]
barColors = concat (repeat palette)
 where
   palette = ["#0000C0", "#FF8C00", "#00C000"]

----------------------------------------------------------------------------

-- | A @text@ element showing a string at a point, in the given fill colour
-- and with font size 12.
text :: Point -> Color -> String -> SVG
text p c s = element "text" (attrs ++ [string s])
 where
   -- the XML builder keeps attributes separate from content, so listing
   -- them before the string content yields the same element
   attrs =
      [ "x"         .=. show (pointX p)
      , "y"         .=. show (pointY p)
      , "fill"      .=. c
      , "font-size" .=. "12"
      ]

-- | A filled rectangle with top-left corner @p@, size @s@ and fill colour
-- @c@. It may optionally have a border (stroke colour and width) and a
-- tooltip, which is emitted as a nested @title@ element.
rect :: Point -> Size -> Color -> Maybe Border -> Maybe String -> SVG
rect p s c mb mtip = element "rect" (geometry ++ border ++ tooltip)
 where
   geometry =
      [ "x"      .=. show (pointX p)
      , "y"      .=. show (pointY p)
      , "width"  .=. show (sizeW s)
      , "height" .=. show (sizeH s)
      , "fill"   .=. c
      ]
   border =
      maybe [] (\(strokeColor, strokeWidth) ->
                   ["stroke" .=. strokeColor, "stroke-width" .=. show strokeWidth])
            mb
   tooltip = map (\tip -> element "title" [string tip]) (maybeToList mtip)

-- | A straight segment from @p1@ to @p2@, stroked in light grey
-- (rgb 200,200,200) with width 1.
line :: Point -> Point -> SVG
line p1 p2 = element "line" (endpoints ++ [stroke])
 where
   endpoints = zipWith (.=.) ["x1", "y1", "x2", "y2"]
                             (map show [pointX p1, pointY p1, pointX p2, pointY p2])
   stroke    = "style" .=. "stroke:rgb(200,200,200);stroke-width:1"