{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE DeriveDataTypeable         #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TemplateHaskell            #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE UndecidableInstances       #-}
module Diagrams.TwoD.Arrow
       ( 
         
         
         
         arrowV
       , arrowV'
       , arrowAt
       , arrowAt'
       , arrowBetween
       , arrowBetween'
       , connect
       , connect'
       , connectPerim
       , connectPerim'
       , connectOutside
       , connectOutside'
       , arrow
       , arrow'
       , arrowFromLocatedTrail
       , arrowFromLocatedTrail'
         
       , ArrowOpts(..)
       , arrowHead
       , arrowTail
       , arrowShaft
       , headGap
       , tailGap
       , gaps, gap
       , headTexture
       , headStyle
       , headLength
       , tailTexture
       , tailStyle
       , tailLength
       , lengths
       , shaftTexture
       , shaftStyle
       , straightShaft
         
         
       , module Diagrams.TwoD.Arrowheads
       ) where
import           Control.Lens              (Lens', Traversal',
                                            generateSignatures, lensRules,
                                            makeLensesWith, view, (%~), (&),
                                            (.~), (^.))
import           Data.Default.Class
import           Data.Maybe                (fromMaybe)
import           Data.Monoid.Coproduct     (untangle)
import           Data.Semigroup
import           Data.Typeable
import           Data.Colour               hiding (atop)
import           Diagrams.Core
import           Diagrams.Core.Style       (unmeasureAttrs)
import           Diagrams.Core.Types       (QDiaLeaf (..), mkQD')
import           Diagrams.Angle
import           Diagrams.Attributes
import           Diagrams.Direction        hiding (dir)
import           Diagrams.Located          (Located (..), unLoc)
import           Diagrams.Parametric
import           Diagrams.Path
import           Diagrams.Solve.Polynomial (quadForm)
import           Diagrams.Tangent          (tangentAtEnd, tangentAtStart)
import           Diagrams.Trail
import           Diagrams.TwoD.Arrowheads
import           Diagrams.TwoD.Attributes
import           Diagrams.TwoD.Path        (stroke, strokeT)
import           Diagrams.TwoD.Transform   (reflectY, translateX)
import           Diagrams.TwoD.Types
import           Diagrams.TwoD.Vector      (unitX, unit_X)
import           Diagrams.Util             (( # ))
import           Linear.Affine
import           Linear.Metric
import           Linear.Vector
-- | Options controlling how an arrow is drawn: the shapes of its head,
--   tail and shaft, the gaps and lengths at either end, and the styles
--   applied to each part.  Start from 'def' and modify individual fields
--   with the generated lenses ('arrowHead', 'headGap', 'shaftStyle', ...).
data ArrowOpts n
  = ArrowOpts
    { _arrowHead  :: ArrowHT n    -- ^ Shape of the head; called with the head size and the shaft's line width.
    , _arrowTail  :: ArrowHT n    -- ^ Shape of the tail; called like '_arrowHead'.
    , _arrowShaft :: Trail V2 n   -- ^ Trail used as the shaft; it is rotated and scaled to fit between the end points.
    , _headGap    :: Measure n    -- ^ Extra space reserved beyond the head when fitting the arrow between its end points.
    , _tailGap    :: Measure n    -- ^ Extra space reserved beyond the tail when fitting the arrow between its end points.
    , _headStyle  :: Style V2 n   -- ^ Style applied to the head (a black fill is also applied when drawing).
    , _headLength :: Measure n    -- ^ Size of the head, passed to '_arrowHead'.
    , _tailStyle  :: Style V2 n   -- ^ Style applied to the tail (a black fill is also applied when drawing).
    , _tailLength :: Measure n    -- ^ Size of the tail, passed to '_arrowTail'.
    , _shaftStyle :: Style V2 n   -- ^ Style applied to the shaft; its line texture/opacity also colour the joints and its line width sizes them.
    }
-- | The default shaft: a straight trail made of a single unit-length
--   offset pointing along the positive x axis.
straightShaft :: OrderedField n => Trail V2 n
straightShaft = trailFromOffsets (pure unitX)
-- | Default arrow options:
--
--   * a 'dart' head and no tail ('noTail');
--   * the straight unit shaft 'straightShaft';
--   * no gap at either end ('none');
--   * 'normal' head and tail lengths;
--   * empty head, tail and shaft styles.
instance TypeableFloat n => Default (ArrowOpts n) where
  def = ArrowOpts
    { _arrowShaft = straightShaft
    , _shaftStyle = mempty
      -- head end
    , _arrowHead  = dart
    , _headLength = normal
    , _headGap    = none
    , _headStyle  = mempty
      -- tail end
    , _arrowTail  = noTail
    , _tailLength = normal
    , _tailGap    = none
    , _tailStyle  = mempty
    }
-- Generate one lens per 'ArrowOpts' field (the field name without its
-- leading underscore).  Signature generation is switched off; the lens
-- signatures are written out explicitly after this splice.
makeLensesWith (lensRules & generateSignatures .~ False) ''ArrowOpts
-- | The shape of the arrow head.
arrowHead :: Lens' (ArrowOpts n) (ArrowHT n)
-- | The shape of the arrow tail.
arrowTail :: Lens' (ArrowOpts n) (ArrowHT n)
-- | The trail used as the arrow's shaft.
arrowShaft :: Lens' (ArrowOpts n) (Trail V2 n)
-- | Extra space reserved beyond the head.
headGap :: Lens' (ArrowOpts n) (Measure n)
-- | Extra space reserved beyond the tail.
tailGap :: Lens' (ArrowOpts n) (Measure n)
-- | Traverse the head gap and then the tail gap.  Setting through this
--   traversal (e.g. @opts & gaps .~ m@) sets both gaps at once.
gaps :: Traversal' (ArrowOpts n) (Measure n)
gaps f opts = rebuild <$> f (opts ^. headGap) <*> f (opts ^. tailGap)
  where
    rebuild h t = opts & headGap .~ h & tailGap .~ t
-- | Same as 'gaps'.
gap :: Traversal' (ArrowOpts n) (Measure n)
gap f = gaps f
-- | Style applied to the arrow head.
headStyle :: Lens' (ArrowOpts n) (Style V2 n)
-- | Style applied to the arrow tail.
tailStyle :: Lens' (ArrowOpts n) (Style V2 n)
-- | Style applied to the arrow shaft.
shaftStyle :: Lens' (ArrowOpts n) (Style V2 n)
-- | Size of the arrow head.
headLength :: Lens' (ArrowOpts n) (Measure n)
-- | Size of the arrow tail.
tailLength :: Lens' (ArrowOpts n) (Measure n)
-- | Traverse the head length and then the tail length.  Setting through
--   this traversal (e.g. @opts & lengths .~ m@) sets both at once.
lengths :: Traversal' (ArrowOpts n) (Measure n)
lengths f opts = setBoth <$> f headLen <*> f tailLen
  where
    headLen     = opts ^. headLength
    tailLen     = opts ^. tailLength
    setBoth h t = opts & headLength .~ h & tailLength .~ t
-- | The fill texture stored in the head style.
headTexture :: TypeableFloat n => Lens' (ArrowOpts n) (Texture n)
headTexture f = headStyle (_fillTexture f)
-- | The fill texture stored in the tail style.
tailTexture :: TypeableFloat n => Lens' (ArrowOpts n) (Texture n)
tailTexture f = tailStyle (_fillTexture f)
-- | The line texture stored in the shaft style.
shaftTexture :: TypeableFloat n => Lens' (ArrowOpts n) (Texture n)
shaftTexture f = shaftStyle (_lineTexture f)
-- | The shaft style exactly as stored in the options.
shaftSty :: ArrowOpts n -> Style V2 n
shaftSty = view shaftStyle
-- | The head style with a black fill colour applied via 'fc'.
--   NOTE(review): whether this black fill overrides a fill already present
--   in 'headStyle' depends on diagrams' attribute precedence; confirm.
headSty :: TypeableFloat n => ArrowOpts n -> Style V2 n
headSty = fc black . view headStyle
-- | The tail style with a black fill colour applied via 'fc' (see 'headSty').
tailSty :: TypeableFloat n => ArrowOpts n -> Style V2 n
tailSty = fc black . view tailStyle
-- | Width of @p@ along the horizontal line through the origin.
--
--   Traces from the origin towards +x and towards -x and adds the lengths
--   of the two trace vectors.  A direction in which the trace finds
--   nothing contributes 0.
xWidth :: Floating n => (Traced t, V t ~ V2, N t ~ n) => t -> n
xWidth p = reach unitX + reach unit_X
  where
    reach v = maybe 0 norm (traceV origin v p)
-- | Style for the joint between the shaft and a head or tail.
--
--   The joint is filled with the shaft's line texture, or with black when
--   the shaft style has none.  If the shaft style carries an opacity, that
--   opacity is applied to the joint as well.
colorJoint :: TypeableFloat n => Style V2 n -> Style V2 n
colorJoint sStyle = applyOpacity jointFill
  where
    jointFill = case getLineTexture <$> getAttr sStyle of
      Nothing  -> fillColor black mempty
      Just tex -> fillTexture tex mempty
    applyOpacity = case getOpacity <$> getAttr sStyle of
      Nothing -> id
      Just op -> opacity op
-- | Line width of a style, in output units.
--
--   The style's measured attributes are first converted with
--   'unmeasureAttrs' using the two scale factors @gToO@ and @nToO@.  If the
--   resulting style has no line width, @'fromMeasured' gToO nToO 'medium'@
--   is used instead.
widthOfJoint :: forall n. TypeableFloat n => Style V2 n -> n -> n -> n
widthOfJoint sStyle gToO nToO =
  maybe fallbackWidth getLineWidth (getAttr measuredStyle)
  where
    measuredStyle = unmeasureAttrs gToO nToO sStyle
    fallbackWidth = fromMeasured gToO nToO medium
-- | Build the arrow head, with its joint offset along -x, and return it
--   together with its total width.  Arguments are the head size, the
--   options, the two output scale factors and whether to reflect.
mkHead :: (TypeableFloat n, Renderable (Path V2 n) b) =>
          n -> ArrowOpts n -> n -> n -> Bool -> (QDiagram b V2 n Any, n)
mkHead sz opts gToO nToO reflect =
  mkHT unit_X arrowHead headSty sz opts gToO nToO reflect
-- | Build the arrow tail, with its joint offset along +x, and return it
--   together with its total width.  Arguments are as for 'mkHead'.
mkTail :: (TypeableFloat n, Renderable (Path V2 n) b) =>
          n -> ArrowOpts n -> n -> n -> Bool -> (QDiagram b V2 n Any, n)
mkTail sz opts gToO nToO reflect =
  mkHT unitX arrowTail tailSty sz opts gToO nToO reflect
-- | Build the diagram for an arrow head or tail, together with its total
--   width: the 'xWidth' of the head/tail path plus the 'xWidth' of its
--   joint path.
--
--   The 'ArrowHT' selected by @htProj@ is called with the size @sz@ and
--   the shaft's line width ('widthOfJoint').  It returns the head/tail path
--   and the joint path.  The head/tail is styled with @styProj opts@, and
--   the joint with a fill derived from the shaft style ('colorJoint').
--   If @reflect@ is set the combined diagram is reflected with 'reflectY'.
--   Its origin is then moved by the joint width along @xDir@, and its line
--   width is set to 0 with 'lwO'.
mkHT
  :: (TypeableFloat n, Renderable (Path V2 n) b)
  => V2 n -> Lens' (ArrowOpts n) (ArrowHT n) -> (ArrowOpts n -> Style V2 n)
  -> n -> ArrowOpts n -> n -> n -> Bool -> (QDiagram b V2 n Any, n)
mkHT xDir htProj styProj sz opts gToO nToO reflect =
    (placed, headWidth + jointWidth)
  where
    shaftWidth            = widthOfJoint (shaftSty opts) gToO nToO
    (headPath, jointPath) = (opts ^. htProj) sz shaftWidth
    headWidth             = xWidth headPath
    jointWidth            = xWidth jointPath
    headDia  = stroke headPath  # applyStyle (styProj opts)
    jointDia = stroke jointPath # applyStyle (colorJoint (opts ^. shaftStyle))
    orient
      | reflect   = reflectY
      | otherwise = id
    placed = (jointDia <> headDia)
           # orient
           # moveOriginBy (jointWidth *^ xDir)
           # lwO 0
-- | The full centre line of an arrow, used to find its overall direction.
--
--   The result is the concatenation of three parts:
--
--   * a segment of length @tw@ along the trail's start tangent, present
--     only when @tw > 0@;
--   * the trail scaled by @sz@;
--   * a segment of length @hw@ along the trail's end tangent, present only
--     when @hw > 0@.
spine :: TypeableFloat n => Trail V2 n -> n -> n -> n -> Trail V2 n
spine tr tw hw sz = tailPart <> (tr # scale sz) <> headPart
  where
    tailPart
      | tw > 0    = trailFromOffsets [signorm (tangentAtStart tr)] # scale tw
      | otherwise = mempty
    headPart
      | hw > 0    = trailFromOffsets [signorm (tangentAtEnd tr)] # scale hw
      | otherwise = mempty
-- | Factor by which to scale the shaft trail so that the tail, the shaft
--   and the head together span a distance of @t@.
--
--   Let @tv@ be a vector of length @tw@ along the trail's start tangent,
--   @hv@ a vector of length @hw@ along its end tangent, and @v@ the
--   trail's offset.  We look for @k@ with
--
--   > || tv + k*v + hv || = t
--
--   Squaring both sides and expanding as a dot product gives a quadratic
--   in @k@, which is solved with 'quadForm':
--
--   * with no real root the trail is left unscaled (factor 1);
--   * otherwise the largest root is taken.  Usually there is one positive
--     and one negative root, and the positive one is wanted.
scaleFactor :: TypeableFloat n => Trail V2 n -> n -> n -> n -> n
scaleFactor tr tw hw t =
  case quadForm qa qb qc of
    []    -> 1
    roots -> maximum roots
  where
    tv = tw *^ signorm (tangentAtStart tr)
    hv = hw *^ signorm (tangentAtEnd tr)
    v  = trailOffset tr
    w  = tv ^+^ hv
    qa = quadrance v
    qb = 2 * (v `dot` w)
    qc = quadrance w - t * t
-- | Approximate envelope of a horizontal arrow of length @len@.
--
--   Only the shaft is considered.  It is rotated so that its offset points
--   along +x and scaled so that the offset has length @len@; the head and
--   tail are ignored.
arrowEnv :: TypeableFloat n => ArrowOpts n -> n -> Envelope V2 n
arrowEnv opts len = getEnvelope horizontal
  where
    shaft      = opts ^. arrowShaft
    off        = trailOffset shaft
    horizontal = scale (len / norm off) . rotate (negated (off ^. _theta)) $ shaft
-- | A horizontal arrow of the given length starting at the origin, drawn
--   with the default 'ArrowOpts'.
arrow :: (TypeableFloat n, Renderable (Path V2 n) b) => n -> QDiagram b V2 n Any
arrow len = arrow' def len
-- | @arrow' opts len@ is a horizontal arrow of length @len@ starting at
--   the origin, shaped and styled according to @opts@.
--
--   The arrow is built as a delayed leaf.  Its actual geometry is computed
--   only once the transformation, style and output scale factors in
--   effect at the point where it is rendered are known.
arrow' :: (TypeableFloat n, Renderable (Path V2 n) b) => ArrowOpts n -> n -> QDiagram b V2 n Any
arrow' opts len = mkQD' (DelayedLeaf delayedArrow)
      -- The envelope is approximated by the shaft's envelope ('arrowEnv').
      -- The other three arguments are mempty (presumably trace, names, query -- confirm).
      (arrowEnv opts len) mempty mempty mempty
  where
    -- Receives the accumulated down annotation @da@ and two scale factors,
    -- which are forwarded to 'dArrow' as @gToO@ / @nToO@.  The
    -- transformation and style are extracted from @fst da@.
    --
    -- Head and tail sizes come from 'Measure's converted with these
    -- factors rather than from the transformation itself, so the shaft
    -- has to be refitted in every rendering context instead of simply
    -- being transformed.
    --
    -- NOTE(review): 'option' is presumably Data.Semigroup's eliminator for
    -- 'Option', which was removed in base-4.16; confirm before upgrading GHC.
    delayedArrow da g n =
      let (trans, globalSty) = option mempty untangle . fst $ da
      in  dArrow globalSty trans len g n
    -- Build the arrow so that it runs between p and q, the images under
    -- @tr@ of the origin and of (ln, 0).  In order, the steps are:
    --   1. shift the origin back along the tail direction by the tail width;
    --   2. rotate so the spine direction @dir@ matches the direction q - p;
    --   3. move the result to p.
    dArrow sty tr ln gToO nToO = (h' <> t' <> shaft)
               # moveOriginBy (tWidth *^ (unit_X # rotate tAngle))
               # rotate (((q .-. p)^._theta) ^-^ (dir^._theta))
               # moveTo p
      where
        p = origin # transform tr
        q = origin # translateX ln # transform tr
        -- The line texture from the global style (if any) becomes the
        -- fill texture of the head and tail.  The whole global style,
        -- transformed by tr, is applied to the shaft style.  Explicit
        -- settings in the options are applied afterwards via applyStyle /
        -- fillTexture (precedence per diagrams' attribute rules -- confirm).
        globalLC = getLineTexture <$> getAttr sty
        opts' = opts
          & headStyle  %~ maybe id fillTexture globalLC
          & tailStyle  %~ maybe id fillTexture globalLC
          & shaftStyle %~ applyStyle sty . transform tr
        -- Head/tail sizes and gaps: Measures scaled by the average scale
        -- of tr, then converted to output units.  (These read opts, not
        -- opts'; the two differ only in styles.)
        scaleFromMeasure = fromMeasured gToO nToO . scaleLocal (avgScale tr)
        hSize = scaleFromMeasure $ opts ^. headLength
        tSize = scaleFromMeasure $ opts ^. tailLength
        hGap  = scaleFromMeasure $ opts ^. headGap
        tGap  = scaleFromMeasure $ opts ^. tailGap
        -- Build the head and tail and keep their widths; reflect them when tr is a reflection.
        (h, hWidth') = mkHead hSize opts' gToO nToO (isReflection tr)
        (t, tWidth') = mkTail tSize opts' gToO nToO (isReflection tr)
        rawShaftTrail = opts^.arrowShaft
        shaftTrail
          = rawShaftTrail
            -- rotate so the shaft's offset points along +x ...
          # rotate (negated . view _theta . trailOffset $ rawShaftTrail)
            -- ... then apply the context transformation.  Normalising the
            -- rotation first means any reflection/shear in tr acts on a
            -- shaft in a known orientation.
          # transform tr
        -- Widths including the requested gaps.
        tWidth = tWidth' + tGap
        hWidth = hWidth' + hGap
        -- Directions in which the tail and head point.
        tAngle = tangentAtStart shaftTrail ^. _theta
        hAngle = tangentAtEnd shaftTrail ^. _theta
        -- Scale factor making tail + shaft + head span the distance from p
        -- to q (see 'scaleFactor').  The shaft is scaled by it and stroked
        -- with the shaft style.
        sf = scaleFactor shaftTrail tWidth hWidth (norm (q .-. p))
        shaftTrail' = shaftTrail # scale sf
        shaft = strokeT shaftTrail' # applyStyle (shaftSty opts')
        -- Rotate the head to the end tangent and move it to the shaft's end; rotate the tail to the start tangent.
        h' = h # rotate hAngle
               # moveTo (origin .+^ shaftTrail' `atParam` domainUpper shaftTrail')
        t' = t # rotate tAngle
        -- Overall direction of tail + shaft + head, used above to rotate
        -- the assembled arrow onto the p -> q direction.
        dir = direction (trailOffset $ spine shaftTrail tWidth hWidth sf)
-- | An arrow from the first point to the second, with default options.
arrowBetween :: (TypeableFloat n, Renderable (Path V2 n) b) => Point V2 n -> Point V2 n -> QDiagram b V2 n Any
arrowBetween s e = arrowBetween' def s e
-- | An arrow from @s@ to @e@ with the given options.
arrowBetween'
  :: (TypeableFloat n, Renderable (Path V2 n) b) =>
     ArrowOpts n -> Point V2 n -> Point V2 n -> QDiagram b V2 n Any
arrowBetween' opts s e = arrowAt' opts s offset
  where
    offset = e .-. s
-- | An arrow starting at the given point and following the given vector,
--   drawn with the default options.
arrowAt :: (TypeableFloat n, Renderable (Path V2 n) b) => Point V2 n -> V2 n -> QDiagram b V2 n Any
arrowAt s v = arrowAt' def s v
-- | An arrow starting at @s@ and following @v@, with the given options.
--   A horizontal arrow of length @norm v@ is rotated to @v@'s angle and
--   then moved to @s@.
arrowAt'
  :: (TypeableFloat n, Renderable (Path V2 n) b) =>
     ArrowOpts n -> Point V2 n -> V2 n -> QDiagram b V2 n Any
arrowAt' opts s v = moveTo s . rotate (v ^. _theta) $ arrow' opts (norm v)
-- | An arrow from the origin along the given vector, with default options.
arrowV :: (TypeableFloat n, Renderable (Path V2 n) b) => V2 n -> QDiagram b V2 n Any
arrowV v = arrowV' def v
-- | An arrow from the origin along the given vector, with the given options.
arrowV'
  :: (TypeableFloat n, Renderable (Path V2 n) b)
  => ArrowOpts n -> V2 n -> QDiagram b V2 n Any
arrowV' opts v = arrowAt' opts origin v
-- | An arrow whose shaft is the given located trail, running from the
--   trail's start point to its end point, with default options.
--
--   The constraint is written with 'TypeableFloat' (i.e. @Typeable n@ and
--   @RealFloat n@), consistent with the rest of this module.
arrowFromLocatedTrail
  :: (TypeableFloat n, Renderable (Path V2 n) b)
  => Located (Trail V2 n) -> QDiagram b V2 n Any
arrowFromLocatedTrail = arrowFromLocatedTrail' def
-- | Like 'arrowFromLocatedTrail', but with the given options.  The
--   options' 'arrowShaft' is replaced by the (unlocated) trail, and the
--   arrow is drawn between 'atStart' and 'atEnd' of the located trail.
arrowFromLocatedTrail'
  :: (TypeableFloat n, Renderable (Path V2 n) b)
  => ArrowOpts n -> Located (Trail V2 n) -> QDiagram b V2 n Any
arrowFromLocatedTrail' opts trail = arrowBetween' opts' start end
  where
    opts' = opts & arrowShaft .~ unLoc trail
    start = atStart trail
    end   = atEnd trail
-- | Draw a default arrow from the location of the subdiagram named @n1@ to
--   the location of the subdiagram named @n2@, placed on top of the
--   diagram.
connect
  :: (TypeableFloat n, Renderable (Path V2 n) b, IsName n1, IsName n2)
  => n1 -> n2 -> QDiagram b V2 n Any -> QDiagram b V2 n Any
connect = connect' def
-- | Like 'connect', but with the given options.
connect'
  :: (TypeableFloat n, Renderable (Path V2 n) b, IsName n1, IsName n2)
  => ArrowOpts n -> n1 -> n2 -> QDiagram b V2 n Any -> QDiagram b V2 n Any
connect' opts n1 n2 =
  withName n1 $ \sub1 ->
  withName n2 $ \sub2 ->
    -- Take the two locations directly instead of matching the list
    -- pattern @[s, e]@, which is partial (-Wincomplete-uni-patterns).
    atop (arrowBetween' opts (location sub1) (location sub2))
-- | Like 'connect', but the arrow runs between points on the boundaries of
--   the two named subdiagrams, as selected by the two angles.  Uses the
--   default options.
connectPerim
  :: (TypeableFloat n, Renderable (Path V2 n) b, IsName n1, IsName n2)
  => n1 -> n2 -> Angle n -> Angle n
  -> QDiagram b V2 n Any -> QDiagram b V2 n Any
connectPerim = connectPerim' def
-- | Like 'connectPerim', but with the given options.
--
--   For each subdiagram, 'maxTraceP' is taken from its location along
--   'unitX' rotated by the corresponding angle.  If that trace finds no
--   boundary, the location itself is used.  The arrow is placed on top of
--   the diagram.
connectPerim'
  :: (TypeableFloat n, Renderable (Path V2 n) b, IsName n1, IsName n2)
  => ArrowOpts n -> n1 -> n2 -> Angle n -> Angle n
  -> QDiagram b V2 n Any -> QDiagram b V2 n Any
connectPerim' opts n1 n2 a1 a2 =
  withName n1 $ \sub1 ->
  withName n2 $ \sub2 ->
    -- Bind the two locations directly instead of matching the partial
    -- list pattern @[os, oe]@ (-Wincomplete-uni-patterns).
    let os = location sub1
        oe = location sub2
        s  = fromMaybe os (maxTraceP os (unitX # rotate a1) sub1)
        e  = fromMaybe oe (maxTraceP oe (unitX # rotate a2) sub2)
    in  atop (arrowBetween' opts s e)
-- | Like 'connect', but the arrow runs between the outer boundaries of the
--   two named subdiagrams instead of their locations.  Uses the default
--   options.
connectOutside
  :: (TypeableFloat n, Renderable (Path V2 n) b, IsName n1, IsName n2)
  => n1 -> n2 -> QDiagram b V2 n Any -> QDiagram b V2 n Any
connectOutside n1 n2 = connectOutside' def n1 n2
-- | Like 'connectOutside', but with the given options.
--
--   Rays are traced from the midpoint between the two locations: towards
--   the first subdiagram to find the start point, and towards the second
--   to find the end point.  If a trace finds nothing, the corresponding
--   location is used.  The arrow is placed on top of the diagram.
connectOutside'
  :: (TypeableFloat n, Renderable (Path V2 n) b, IsName n1, IsName n2)
  => ArrowOpts n -> n1 -> n2 -> QDiagram b V2 n Any -> QDiagram b V2 n Any
connectOutside' opts n1 n2 =
  withName n1 $ \b1 ->
  withName n2 $ \b2 ->
    let p1    = location b1
        p2    = location b2
        v     = p2 .-. p1
        mid   = p1 .+^ (v ^/ 2)
        start = fromMaybe p1 (traceP mid (negated v) b1)
        end   = fromMaybe p2 (traceP mid v b2)
    in  atop (arrowBetween' opts start end)