{-|
Module      : Braiddiagrams
Description : Draws braids, their closures, and Khovanov generators.
Copyright   : Adam Saltz
License     : BSD3
Maintainer  : saltz.adam@gmail.com
Stability   : experimental

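Uses the @diagrams@ library (SVG backend) to draw braids given as words in the
Artin generators, their closures, resolutions of a braid or of its closure, the
whole cube of resolutions, and Khovanov generators (resolved closures with their
components marked).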
-}
{-# LANGUAGE NoMonomorphismRestriction, FlexibleInstances #-}

module Braiddiagrams
where
import Prelude hiding (exp)
import Diagrams.Prelude
import Diagrams.Backend.SVG.CmdLine
import Diagrams.TwoD.Path.Metafont
import Diagrams.Direction
import qualified Data.Map as M
import qualified Data.Set as S
import Data.List (group, sort, maximumBy, find)
import qualified Diagrams.TwoD.Size as Size 
import Control.Arrow ((&&&), second)
import Data.Tuple (swap)
import Control.Monad (replicateM)
import Data.Maybe (fromMaybe)

import Braids -- (Braid, braidWord, braidWidth, Node (Join), Resolution)
import Complex --(Generator, components, signs, resolution)
import Util

type ArcLabel = Int
type Height = Int
type BraidIndex = Int
type ArtinGen = Int

-- | Wraps a 'Node' of arc labels so that it can serve as a diagram name.
-- The wrapper exists only to carry the 'IsName' instance below.
newtype NameNode = NameNode (Node ArcLabel)
  deriving (Eq, Ord, Show)

instance IsName NameNode


-- | The trivial (identity) braid on @index@ strands: @index@ vertical
-- unit-length segments laid out side by side, one unit apart.
identity :: BraidIndex -> Diagram B
identity index = hsep 1 (replicate index (vrule 1))

-- | The identity braid on @index@ strands, with every strand given a name.
--
-- Strand @x@ (counting from 1 on the left) is named
-- @'NameNode' ('Join' (lastLabel + x) (lastLabel + x + index))@.
-- @lastLabel@ is the last arc label used by the level drawn above this one,
-- so the labels used here start at @lastLabel + 1@.
identityAt :: BraidIndex -> ArcLabel -> Diagram B
identityAt index lastLabel =
  hsep 1 [ vrule 1 # named (strandName x) | x <- [1 .. index] ]
  where
    strandName x = NameNode (Join (lastLabel + x) (lastLabel + x + index))

-- | A negative crossing, drawn in the unit square centred at the origin.
--
-- It consists of three 'metafont' curves:
--
-- * an unbroken strand from @(-0.5, 0.5)@ to @(0.5, -0.5)@, leaving and
--   arriving vertically;
-- * the upper piece of the broken strand, from @(0.5, 0.5)@ down to
--   @(0.1, 0.1)@;
-- * the lower piece of the broken strand, from @(-0.5, -0.5)@ up to
--   @(-0.1, -0.1)@.
--
-- The gap between @(0.1, 0.1)@ and @(-0.1, -0.1)@ is what shows the broken
-- strand passing under the unbroken one.
negativeCrossing :: Diagram B
negativeCrossing =    metafont (p2 (-0.5,0.5)  .- leaving unit_Y <> arriving unit_Y -. endpt (p2 (0.5,-0.5))) 
                   <> metafont (p2 (0.5,0.5)   .- leaving unit_Y <> arriving (unit_Y + unit_X) -. endpt (p2 (0.1,0.1)))
                   <> metafont (p2 (-0.5,-0.5) .- leaving unitY <> arriving (unitY + unitX)  -. endpt (p2 (-0.1,-0.1)))

-- | A positive crossing, drawn in the unit square centred at the origin.
--
-- This is the mirror image of 'negativeCrossing'. It consists of three
-- 'metafont' curves:
--
-- * an unbroken strand from @(0.5, 0.5)@ to @(-0.5, -0.5)@, leaving and
--   arriving vertically;
-- * the upper piece of the broken strand, from @(-0.5, 0.5)@ down to
--   @(-0.1, 0.1)@;
-- * the lower piece of the broken strand, from @(0.5, -0.5)@ up to
--   @(0.1, -0.1)@.
--
-- The gap between the two pieces shows the broken strand passing underneath.
positiveCrossing :: Diagram B
positiveCrossing =    metafont (p2 (0.5,0.5)  .- leaving unit_Y <> arriving unit_Y -. endpt (p2 (-0.5,-0.5))) 
                   <> metafont (p2 (-0.5,0.5) .- leaving unit_Y <> arriving (unit_Y + unitX) -. endpt (p2 (-0.1,0.1)))
                   <> metafont (p2 (0.5,-0.5) .- leaving unitY <> arriving (unitY + unit_X)  -. endpt (p2 (0.1,-0.1)))

-- | A cup-cap combination in the unit square centred at the origin: the
-- smoothing that replaces a crossing.
--
-- * The first arc joins the two top endpoints @(0.5, 0.5)@ and
--   @(-0.5, 0.5)@, dipping down through @(0, 0.3)@.
-- * The second arc joins the two bottom endpoints @(-0.5, -0.5)@ and
--   @(0.5, -0.5)@, rising through @(0, -0.3)@.
--
-- Both arcs leave and arrive vertically at the square's edges.
cupCap :: Diagram B
cupCap =    metafont (p2 (0.5,0.5) .- leaving unit_Y <> arriving unit_X
                      -. p2 (0,0.3) .- leaving unit_X <> arriving unitY 
                      -. endpt (p2 (-0.5,0.5)))
            <> metafont (p2 (-0.5,-0.5) .- leaving unitY <> arriving unitX
                      -. p2 (0, -0.3) .- arriving unit_Y <> leaving unitX
                      -. endpt (p2 (0.5,-0.5)))

-- | One level of an @index@-strand braid in which strands @spot@ and
-- @spot + 1@ (where @spot = abs gen@) are smoothed into a cup-cap.
-- Every piece is given a name, and the names start at @lastLabel + 1@.
--
-- * Strands @1 .. spot - 1@ and @spot + 2 .. index@ are vertical segments
--   named exactly as in 'identityAt'.
-- * The arc joining the upper ends of strands @spot@ and @spot + 1@ is named
--   @'Join' (lastLabel + spot) (lastLabel + spot + 1)@.
-- * The arc joining their lower ends is named
--   @'Join' (lastLabel + spot + index) (lastLabel + spot + index + 1)@.
--
-- The two arcs are stacked with 'vsep' 0.6. The upper arc is shifted up by
-- 0.4 and the lower arc down by 0.4.
--
-- The sign of @gen@ is ignored; only its absolute value picks the position.
cupCapLevelAt :: BraidIndex -> ArcLabel -> ArtinGen -> Diagram B
cupCapLevelAt index lastLabel gen = hsep 1  
                       [ hsep 1 (fmap (\x -> named (name x) $ vrule 1) [1 .. spot - 1])
                       -- Unused earlier variant (taller arcs, names not wrapped in NameNode), kept for reference:
                       {-, vsep 0.6 [metafont (p2 (0.5,0.5) .- leaving unit_Y <> arriving unit_X
                              -. p2 (0,0.3) .- leaving unit_X <> arriving unitY 
                              -. endpt (p2 (-0.5,0.5))) # named (Join (lastLabel + spot) (lastLabel + spot + 1))
                                  ,metafont (p2 (-0.5,-0.5) .- leaving unitY <> arriving unitX
                              -. p2 (0, -0.3) .- arriving unit_Y <> leaving unitX
                              -. endpt (p2 (0.5,-0.5))) # named (Join (lastLabel + spot + index) (lastLabel + spot + index + 1))]-}
                         -- The smoothed pair: arc joining the upper labels, then arc joining the lower labels.
                         , vsep 0.6 [metafont (p2 (0.5,0.1) .- leaving unit_Y <> arriving unit_X
                              -. p2 (0,-0.1) .- leaving unit_X <> arriving unitY 
                              -. endpt (p2 (-0.5,0.1))) # named (NameNode (Join (lastLabel + spot) (lastLabel + spot + 1))) # translateY 0.4
                                , metafont (p2 (-0.5,-0.1) .- leaving unitY <> arriving unitX
                              -. p2 (0, 0.1) .- arriving unit_Y <> leaving unitX
                              -. endpt (p2 (0.5,-0.1))) # named (NameNode (Join (lastLabel + spot + index) (lastLabel + spot + index + 1))) # translateY (-0.4)
                                ]

                       , hsep 1 (fmap (\x -> named (name x) $ vrule 1) [spot + 2 .. index])
                       ]
                    where
                          -- Name of the vertical strand in column x (same scheme as identityAt).
                          name x = NameNode (Join (lastLabel + x) (lastLabel + x + index))
                          -- Left column of the smoothed pair; inverse generators use the same position.
                          spot = abs gen

-- | The @k@th Artin generator of the braid group on @n@ strands, drawn as
-- one crossing between strands @|k|@ and @|k| + 1@, with identity strands on
-- either side.
--
-- A positive @k@ draws 'positiveCrossing'. A negative @k@ draws the inverse
-- generator using 'negativeCrossing'. @k == 0@ draws the identity braid.
artin :: BraidIndex -> ArtinGen -> Diagram B
artin n k
  | k > 0     = around positiveCrossing
  | k < 0     = around negativeCrossing
  | otherwise = identity n
  where
    j = abs k
    around crossing = hsep 1 [identity (j - 1), crossing, identity (n - j - 1)]

-- | Draws a braid as a vertical stack of its Artin generators, one level per
-- letter of the braid word. Each level is left-aligned.
drawBraid :: Braid -> Diagram B
drawBraid b = vcat [ alignL (artin width g) | g <- braidWord b ]
  where
    width = braidWidth b

-- | Draws the closure of a braid.
--
-- The braid is placed next to (after a 2-unit gap) a column of identity
-- strands of the same height. Nested semicircular arcs of radius
-- @1 .. braidWidth b@ are drawn above and below, connecting the two sides.
drawBraidClosure :: Braid -> Diagram B
drawBraidClosure b = topArcs === alignL body === bottomArcs
  where
    n      = braidWidth b
    levels = length (braidWord b)

    radii :: [Double]
    radii = [1.0 .. fromIntegral n]

    -- Half-turn arcs starting in direction v, one per radius.
    arcsToward :: V2 Double -> Diagram B
    arcsToward v = alignL (mconcat [ arc' r (dir v) (pi @@ rad) | r <- radii ])

    topArcs       = arcsToward unitX
    bottomArcs    = arcsToward unit_X
    returnStrands = vcat (replicate levels (identity n))
    body          = drawBraid b ||| strutX 2 ||| returnStrands

-- | Draws the @r@-resolution of the @k@th Artin generator in the @n@-strand
-- braid group. Use @-k@ for the inverse of the @k@th generator.
--
-- A 'cupCap' replaces the crossing at strands @|k|@, @|k| + 1@ in two cases:
-- @r == 0@ with a negative generator, or @r /= 0@ with a positive one.
-- Every other combination, including @k == 0@, is drawn as the identity braid.
resolutionD :: Int -> BraidIndex -> ArtinGen -> Diagram B
resolutionD r n k
  | usesCupCap = hsep 1 [identity (j - 1), cupCap, identity (n - j - 1)]
  | otherwise  = identity n
  where
    j          = abs k
    usesCupCap = (r == 0 && k < 0) || (r /= 0 && k > 0)

-- | Draws the @r@-resolution of generator @gen@ as level @height@ of an
-- @index@-strand braid, with every piece named. Use a negative @gen@ for an
-- inverse generator.
--
-- Names start at @height * index + 1@. The smoothing rule matches
-- 'resolutionD':
--
-- * @r == 0@ with a negative generator, or @r /= 0@ with a positive one,
--   gives 'cupCapLevelAt';
-- * everything else gives 'identityAt'.
resolutionAt :: Int -> BraidIndex -> ArtinGen -> Height -> Diagram B
resolutionAt r index gen height
  | smoothed  = cupCapLevelAt index lastLabel gen
  | otherwise = identityAt index lastLabel
  where
    lastLabel = height * index
    smoothed  = (r == 0 && gen < 0) || (r /= 0 && gen > 0)

-- | Draws the braid @b@ resolved according to @rs@.
--
-- Level @i@ (counting from 0) resolves the @i@th letter of the braid word
-- with the @i@th entry of @rs@. If the two lists differ in length, the
-- extra entries of the longer one are dropped.
resolveD :: Resolution -> Braid -> Diagram B
resolveD rs b = vcat (zipWith3 level rs (braidWord b) [0 ..])
  where
    width = braidWidth b
    level r g h = alignL (resolutionAt r width g h)

-- | Draws the closure of the braid @b@ resolved according to @rs@.
--
-- The layout is the same as 'drawBraidClosure', but the braid itself is
-- replaced by its resolution 'resolveD' @rs b@.
resolveClosureD :: Resolution -> Braid -> Diagram B
resolveClosureD rs b = topArcs === alignL body === bottomArcs
  where
    n      = braidWidth b
    levels = length (braidWord b)

    radii :: [Double]
    radii = [1.0 .. fromIntegral n]

    -- Half-turn arcs starting in direction v, one per radius.
    arcsToward :: V2 Double -> Diagram B
    arcsToward v = alignL (mconcat [ arc' r (dir v) (pi @@ rad) | r <- radii ])

    topArcs       = arcsToward unitX
    bottomArcs    = arcsToward unit_X
    returnStrands = vcat (replicate levels (identity n))
    body          = resolveD rs b ||| strutX 2 ||| returnStrands

-- | Every 0/1 resolution of the braid @b@ (one choice per letter of its
-- word), mapped to a drawing of that resolution by 'resolveD'.
cubeOfResolutionsD :: Braid -> M.Map Resolution (Diagram B)
cubeOfResolutionsD b = M.fromList [ (rs, resolveD rs b) | rs <- allResolutions ]
  where
    allResolutions = replicateM (length (braidWord b)) [0, 1]

-- | Every 0/1 resolution of the braid @b@ (one choice per letter of its
-- word), mapped to a drawing of the resolved closure by 'resolveClosureD'.
cubeOfResolutionsClosureD :: Braid -> M.Map Resolution (Diagram B)
cubeOfResolutionsClosureD b = M.fromList [ (rs, resolveClosureD rs b) | rs <- allResolutions ]
  where
    allResolutions = replicateM (length (braidWord b)) [0, 1]

-- | Lays out a cube of diagrams keyed by resolution, such as the output of
-- 'cubeOfResolutionsD' or 'cubeOfResolutionsClosureD'.
--
-- The layout proceeds in four steps:
--
-- * Each diagram gets its key printed beneath it.
-- * Diagrams are grouped into columns by the sum of their key.
-- * Each column is centred.
-- * Columns are placed left to right in increasing order of that sum, and
--   everything is drawn with very thin lines.
--
-- Within a column the gap is a third of the tallest input diagram's height.
-- Between columns the gap is the width of the widest labelled diagram.
printCube :: M.Map [Int] (Diagram B) -> Diagram B
printCube cube = lw veryThin (hcat' (with & sep .~ maxWidth) columns)
  where
    -- Put the resolution's key below its diagram.
    addLabel :: [Int] -> Diagram B -> Diagram B
    addLabel k a = vcat' (with & sep .~ 1)
                     [ a
                     , text (show k) # translateX (Size.width a / 2)
                     ]

    labelled  = M.mapWithKey addLabel cube
    -- Regroup by weight (sum of the key); colliding entries are concatenated.
    byWeight  = M.mapKeysWith (++) sum (fmap (: []) labelled)
    columns   = map (center . vcat' (with & sep .~ maxHeight / 3)) (M.elems byWeight)

    maxHeight = maximum (map height (M.elems cube)) :: Double
    maxWidth  = maximum (map Size.width (M.elems labelled)) :: Double

-- | Draws every 'Generator' contained in an 'AlgGen' with 'generatorD', and
-- keys each drawing by the value 'graph' 'resolution' pairs with the generator.
bigGeneratorD :: Braid -> AlgGen -> M.Map [Int] (Diagram B)
bigGeneratorD braid gens = M.fromList (map keyByResolution (S.toList (toSet gens)))
  where
    keyByResolution g = case graph resolution g of
      (gen, res) -> (res, generatorD braid gen)

-- | Draws a single 'Generator': the closure of @b@ resolved according to the
-- generator's resolution, with each component in its 'signs' map marked by
-- 'markIn'. It mostly exists to be called by 'bigGeneratorD'.
generatorD :: Braid -> Generator -> Diagram B
generatorD b gen = compose markers flatDiagram
  where
    flatDiagram = resolveClosureD (resolution gen) b :: Diagram B
    -- One marking function per (component, sign) entry, in key order.
    markers = [ markIn flatDiagram comp s | (comp, s) <- M.toList (signs gen) ]

-- | Marks one component of a resolved diagram with a small purple dot.
--
-- @diag@ is searched for names of the form @'NameNode' ('Join' a b)@ where
-- both @a@ and @b@ are arc labels in @comp@. Such names are produced by
-- 'identityAt' and 'cupCapLevelAt'. The first matching name, in the order
-- returned by 'names', is used. The resulting function places a dot at the
-- local origin of the subdiagram with that name.
--
-- If no name of @diag@ belongs to the component, the diagram is returned
-- unchanged.
--
-- NOTE(review): the 'Sign' argument is currently ignored, so every component
-- gets the same marker regardless of its sign.
markIn :: Diagram B -> Component -> Sign -> (Diagram B -> Diagram B)
markIn diag comp _sign = maybe id markAt matchingName
  where
    markAt :: Name -> Diagram B -> Diagram B
    markAt nm = withName nm (atop . place (circle 0.1 # fc purple) . location)

    labels = S.toList comp
    -- Every name a strand or arc of this component could carry. Using a Set
    -- avoids a linear scan of all |comp|^2 candidates for each diagram name.
    componentNames = S.fromList [ toName (NameNode (Join a b)) | a <- labels, b <- labels ]

    matchingName = find (`S.member` componentNames) (map fst (names diag))