{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# OPTIONS_GHC -Wall #-}
module Chart.Arrow
  ( Arrow(..)
  , ArrowHTStyle(..)
  , ArrowOptions(..)
  , normArrows
  , arrows
  , arrowChart
  , arrowChart_
  ) where
import Chart.Core
import Data.Ord (max)
import Diagrams.Prelude hiding (Color, D, project, width)
import GHC.Generics
import NumHask.Pair
import NumHask.Prelude hiding ((&), max)
import NumHask.Range
import NumHask.Rect
import NumHask.Space
-- | Arrow head and tail shapes, converted to a diagrams 'ArrowHT' by
-- 'arrowHTStyle'.
--
-- The unprimed constructors 'Tri' .. 'NoHead' and the angle-carrying
-- 'Tri2' .. 'Thorn2' select diagrams arrow /heads/. The primed
-- constructors, 'LineTail', 'NoTail', 'Quill', 'Block', 'Quill2' and
-- 'Block2' select diagrams arrow /tails/. The argument of every
-- constructor ending in @2@ is an angle in degrees.
--
-- NOTE(review): 'ArrowOptions' feeds this type only into 'arrowHead', so a
-- tail constructor chosen there draws a tail shape at the head end.
data ArrowHTStyle a
  = Tri -- ^ diagrams 'tri' head
  | Dart -- ^ diagrams 'dart' head
  | HalfDart -- ^ diagrams 'halfDart' head
  | Spike -- ^ diagrams 'spike' head
  | Thorn -- ^ diagrams 'thorn' head
  | LineHead -- ^ diagrams 'lineHead' head
  | NoHead -- ^ diagrams 'noHead' (no head drawn)
  | Tri2 a -- ^ 'arrowheadTriangle' with the given angle in degrees
  | Dart2 a -- ^ 'arrowheadDart' with the given angle in degrees
  | HalfDart2 a -- ^ 'arrowheadHalfDart' with the given angle in degrees
  | Spike2 a -- ^ 'arrowheadSpike' with the given angle in degrees
  | Thorn2 a -- ^ 'arrowheadThorn' with the given angle in degrees
  | Tri' -- ^ diagrams tail 'tri''
  | Dart' -- ^ diagrams tail 'dart''
  | HalfDart' -- ^ diagrams tail 'halfDart''
  | Spike' -- ^ diagrams tail 'spike''
  | Thorn' -- ^ diagrams tail 'thorn''
  | LineTail -- ^ diagrams 'lineTail'
  | NoTail -- ^ diagrams 'noTail' (no tail drawn)
  | Quill -- ^ diagrams 'quill' tail
  | Block -- ^ diagrams 'block' tail
  | Quill2 a -- ^ 'arrowtailQuill' with the given angle in degrees
  | Block2 a -- ^ 'arrowtailBlock' with the given angle in degrees
  deriving (Show, Generic)
-- | Map an 'ArrowHTStyle' onto the corresponding diagrams arrow head or
-- tail. Constructors that carry an angle interpret it in degrees.
arrowHTStyle :: (RealFloat a) => ArrowHTStyle a -> ArrowHT a
arrowHTStyle hts =
  case hts of
    -- stock heads
    Tri -> tri
    Dart -> dart
    HalfDart -> halfDart
    Spike -> spike
    Thorn -> thorn
    LineHead -> lineHead
    NoHead -> noHead
    -- heads built from an explicit angle
    Tri2 ang -> arrowheadTriangle (ang @@ deg)
    Dart2 ang -> arrowheadDart (ang @@ deg)
    HalfDart2 ang -> arrowheadHalfDart (ang @@ deg)
    Spike2 ang -> arrowheadSpike (ang @@ deg)
    Thorn2 ang -> arrowheadThorn (ang @@ deg)
    -- stock tails
    Tri' -> tri'
    Dart' -> dart'
    HalfDart' -> halfDart'
    Spike' -> spike'
    Thorn' -> thorn'
    LineTail -> lineTail
    NoTail -> noTail
    Quill -> quill
    Block -> block
    -- tails built from an explicit angle
    Quill2 ang -> arrowtailQuill (ang @@ deg)
    Block2 ang -> arrowtailBlock (ang @@ deg)
-- | Colour, head shape and sizing bounds used by 'arrows'.
--
-- Each min/max pair is combined as @max minX (maxX * relativeLength)@.
-- The exception is the length pair; see 'arrows' for the exact formulas.
data ArrowOptions = ArrowOptions
  { minLength :: Double -- ^ lower clamp factor for relative length (scaled by the largest normalised direction length)
  , maxLength :: Double -- ^ relative length assigned to the longest arrow before clamping
  , minHeadLength :: Double -- ^ smallest head length (set via 'global')
  , maxHeadLength :: Double -- ^ head length per unit of relative arrow length
  , minStaffWidth :: Double -- ^ smallest shaft line width (set via 'lwG')
  , maxStaffWidth :: Double -- ^ shaft line width per unit of relative arrow length
  , color :: UColor Double -- ^ used for shaft line, head line and fill
  , hStyle :: ArrowHTStyle Double -- ^ shape installed as the arrow head
  } deriving (Show, Generic)
-- | Default options: relative lengths 0.02 to 0.2, head lengths 0.01 to 0.1,
-- staff widths 0.002 to 0.005, 'ublue', and a 'Dart' head.
instance Default ArrowOptions where
  def =
    ArrowOptions
      { minLength = 0.02
      , maxLength = 0.2
      , minHeadLength = 0.01
      , maxHeadLength = 0.1
      , minStaffWidth = 0.002
      , maxStaffWidth = 0.005
      , color = ublue
      , hStyle = Dart
      }
-- | Rescale arrow directions so their spread matches the spread of the
-- arrow start points.
--
-- Along each axis, a direction component @v@ becomes
-- @v * positionWidth / directionWidth@. Here @positionWidth@ is the extent
-- of all start points and @directionWidth@ the extent of all directions on
-- that axis. Start points are returned unchanged.
--
-- A direction extent can be zero along an axis, e.g. for a single arrow or
-- for arrows that all share that component. That component is then left
-- as is instead of being divided by zero, which would yield NaN or
-- infinite directions.
normArrows :: [Arrow] -> [Arrow]
normArrows xs = rescaleArrow <$> xs
  where
    ps = arrowPos <$> xs
    ds = arrowDir <$> xs
    -- extents are computed once for the whole list, not once per arrow
    (Pair px py) = width (space ps :: Rect Double)
    (Pair ax ay) = width (space ds :: Rect Double)
    rescaleArrow (Arrow p (Pair x y)) =
      Arrow p (Pair (rescale x px ax) (rescale y py ay))
    -- keep the (v * num) / den evaluation order of the original formula
    rescale v num den
      | den == 0 = v
      | otherwise = v * num / den
-- | A single arrow to be drawn by 'arrows'.
data Arrow = Arrow
  { arrowPos :: Pair Double -- ^ start point; 'arrows' places the arrow here
  , arrowDir :: Pair Double -- ^ direction; its magnitude only matters relative to the other arrows
  } deriving (Eq, Show, Generic)
-- | Draw a collection of arrows as a single chart.
--
-- Each 'Arrow' starts at its 'arrowPos' and points along its 'arrowDir'.
-- Sizes are relative:
--
-- * @anorm@: each direction's length, with x and y divided by the x and y
--   extent of all start points;
-- * @arel@: @anorm / anormMax * maxLength@, clamped below by
--   @anormMax * minLength@;
-- * head length @max minHeadLength (maxHeadLength * arel)@ and staff width
--   @max minStaffWidth (maxStaffWidth * arel)@;
-- * the vector given to 'arrowAt'' has length @arel - headLength@, floored
--   at 1e-12.
--
-- All start points can share an x (or y) coordinate, and always do for a
-- single arrow. That extent is then zero and is replaced by 1, so directions
-- are measured unscaled on that axis instead of turning every arrow into
-- NaN/Inf.
--
-- NOTE(review): the lower clamp @anormMax * minLength@ is in unnormalised
-- units while the other term lies in [0, maxLength]; confirm this is
-- intended.
arrows :: (Traversable f) => ArrowOptions -> f Arrow -> Chart b
arrows opts xs = c
  where
    c =
      fcA (acolor $ color opts) $
      position $
      getZipList $
      (\ps' as' hrel' wrel' srel' ->
         ( ps'
         , arrowAt'
             (arropts hrel' wrel')
             (p2 (0, 0))
             -- rescale the direction to length srel'
             ((srel' / norm as') *^ as'))) <$>
      ZipList (toList $ p_ <$> ps) <*>
      ZipList (toList $ r_ <$> as) <*>
      ZipList (toList hrel) <*>
      ZipList (toList wrel) <*>
      ZipList srel
    -- start points and direction vectors
    ps = arrowPos <$> xs
    as = arrowDir <$> xs
    -- per-axis extent of the start points, made safe to divide by
    (Pair dx0 dy0) = width (space ps :: Rect Double)
    nonZero e = if e == 0 then 1 else e
    dx = nonZero dx0
    dy = nonZero dy0
    -- direction length measured relative to the start-point extent
    anorm = (\(Pair x y) -> sqrt ((x / dx) ** 2 + (y / dy) ** 2)) <$> as
    (Range _ anormMax) = space anorm
    -- relative arrow length; the longest arrow gets maxLength
    arel =
      (\x -> max (anormMax * minLength opts) (x / anormMax * maxLength opts)) <$>
      anorm
    -- head length, clamped below
    hrel = (\x -> max (minHeadLength opts) (maxHeadLength opts * x)) <$> arel
    -- shaft line width, clamped below
    wrel = (\x -> max (minStaffWidth opts) (maxStaffWidth opts * x)) <$> arel
    -- drawn vector length: arrow length minus head, kept strictly positive
    srel = zipWith (\la lh -> max 1e-12 (la - lh)) (toList arel) (toList hrel)
    -- per-arrow diagrams options: head shape and length, shaft width, colour
    arropts lh lw'' =
      with & arrowHead .~ arrowHTStyle (hStyle opts) & headLength .~ global lh &
      shaftStyle %~
      (lwG lw'' & lcA (acolor $ color opts)) &
      headStyle %~
      (lcA (acolor $ color opts) & fcA (acolor $ color opts))
-- | Draw several arrow sets, each with its own 'ArrowOptions', after mapping
-- them from the data rectangle @r@ into the chart rectangle @asp@.
--
-- Start points are mapped with @project r asp@. Directions are vectors, not
-- points, so only the linear part of the projection is applied to them:
-- @project r asp d - project r asp 0@. Projecting a direction as a point
-- would add the rectangles' offset and skew every arrow unless both
-- rectangles are centred on the origin. In that centred case the two forms
-- agree. This assumes @project@ is affine.
--
-- Option lists and arrow sets are paired with 'zipWith', so surplus entries
-- in the longer list are ignored.
arrowChart ::
     (Traversable f)
  => [ArrowOptions]
  -> Rect Double
  -> Rect Double
  -> [f Arrow]
  -> Chart b
arrowChart optss asp r xss =
  mconcat $
  zipWith
    (\opts xs -> arrows opts $ toChartSpace <$> xs)
    optss
    xss
  where
    projectPoint = project r asp
    -- image of the data origin; subtracting it removes the translation
    projectedOrigin = projectPoint (Pair 0 0)
    toChartSpace (Arrow d arr) =
      Arrow (projectPoint d) (projectPoint arr - projectedOrigin)
-- | 'arrowChart' with the data rectangle derived from the arrows themselves.
-- That rectangle is each set's start-point extent, combined with 'fold'.
arrowChart_ ::
     (Traversable f)
  => [ArrowOptions]
  -> Rect Double
  -> [f Arrow]
  -> Chart b
arrowChart_ optss asp xss = arrowChart optss asp dataRect xss
  where
    -- one start-point extent per arrow set, combined via the Rect monoid
    dataRect = fold [space (map arrowPos xs) | xs <- xss]