{-# LANGUAGE TemplateHaskell #-}
module Data.Geometry.SegmentTree.Generic( NodeData(..), splitPoint, range, assoc
                                        , LeafData(..), atomicRange, leafAssoc

                                        , SegmentTree(..), unSegmentTree
                                        , Assoc(..)

                                        , createTree, fromIntervals
                                        , insert, delete

                                        , search, stab

                                        , I(..), fromIntervals'

                                        , Count(..)
                                        ) where


import           Control.DeepSeq
import           Control.Lens
import           Data.BinaryTree
import           Data.Ext
import           Data.Geometry.Interval
import           Data.Geometry.IntervalTree (IntervalLike(..))
import           Data.Geometry.Properties
import qualified Data.List as List
import           Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NonEmpty
import           Data.Semigroup
import           GHC.Generics (Generic)

--------------------------------------------------------------------------------

-- | Internal nodes store a split point, the range, and an associated data structure
data NodeData v r = NodeData { _splitPoint :: !(EndPoint r)
                             , _range      :: !(Range r)
                             , _assoc      :: !v
                             } deriving (Show,Eq,Functor,Generic)
makeLenses ''NodeData
instance (NFData v, NFData r) => NFData (NodeData v r)

-- | We store atomic ranges a bit more efficiently.
data AtomicRange r = Singleton !r | AtomicRange deriving (Show,Eq,Functor,Generic)
instance NFData r => NFData (AtomicRange r)

-- | Leaf nodes store an atomic range, and an associated data structure.
data LeafData v r = LeafData  { _atomicRange :: !(AtomicRange r)
                              , _leafAssoc   :: !v
                              } deriving (Show,Eq,Functor,Generic)
makeLenses ''LeafData
instance (NFData v, NFData r) => NFData (LeafData v r)

--------------------------------------------------------------------------------

-- | Segment tree on a Fixed set of endpoints
newtype SegmentTree v r =
  SegmentTree { _unSegmentTree :: BinLeafTree (NodeData v r) (LeafData v r) }
    deriving (Show,Eq,Generic,NFData)
makeLenses ''SegmentTree


-- rangeOf :: BinLeafTree (NodeData v r) (LeafData v r) -> Range (UnBounded r)
-- rangeOf (Node _ x _) = Val <$> x^.range
-- rangeOf (Leaf x)     = case x^.atomicRange of
--                         Singleton r -> ClsoedRange (Val r)     (Val r)
--                         AtomicRange -> OpenRange   MinInfinity MaxInfinity


data BuildLeaf a = LeafSingleton a | LeafRange a a deriving (Show,Eq)

-- | Given a sorted list of endpoints, without duplicates, construct a segment tree
--
--
-- \(O(n)\) time
createTree                      :: NonEmpty r -> v -> SegmentTree v r
-- createTree (r NonEmpty.:| []) v = SegmentTree . Leaf $ LeafData (Singleton r) v
createTree pts                v = SegmentTree . fmap h . foldUpData f g . fmap _unElem
                                . asBalancedBinLeafTree $ ranges
  where
    h (LeafSingleton r) = LeafData (Singleton r) v
    h (LeafRange   _ _) = LeafData AtomicRange   v

    f l _ r = let m  = l^.range.upper
                  ll = l^.range.lower
                  rr = r^.range.upper
              in NodeData m (Range ll rr) v
    -- | Singletons map to closed singleton ranges, Ranges map to open ranges
    g (LeafSingleton r) = NodeData (Closed r) (ClosedRange r r) v
    g (LeafRange   s r) = NodeData (Open r)   (OpenRange   s r) v


    ranges  = interleave (fmap LeafSingleton pts) ranges'
    ranges' = zipWith LeafRange (NonEmpty.toList pts) (NonEmpty.tail pts)




-- | Interleaves the two lists
--
-- >>> interleave (NonEmpty.fromList ["0","1","2"]) ["01","12"]
-- "0" :| ["01","1","12","2"]
interleave                       :: NonEmpty a -> [a] -> NonEmpty a
interleave (x NonEmpty.:| xs) ys = x NonEmpty.:| concat (zipWith (\a b -> [a,b]) ys xs)


-- | Build a SegmentTree
--
-- \(O(n \log n)\)
fromIntervals      :: (Ord r, Eq p, Assoc v i, IntervalLike i, Monoid v, NumType i ~ r)
                   => (Interval p r -> i)
                   -> NonEmpty (Interval p r) -> SegmentTree v r
fromIntervals f is = foldr (insert . f) (createTree pts mempty) is
  where
    endPoints (toRange -> Range' a b) = [a,b]
    pts = nub' . NonEmpty.sort . NonEmpty.fromList . concatMap endPoints $ is
    nub' = fmap NonEmpty.head . NonEmpty.group1

-- | lists all intervals
toList :: SegmentTree v r -> [i]
toList = undefined

--------------------------------------------------------------------------------
-- * Searching

-- | Search for all intervals intersecting x
--
-- \(O(\log n + k)\) where \(k\) is the output size
search   :: (Ord r, Monoid v) => r -> SegmentTree v r -> v
search x = mconcat . stab x



inAtomicRange                   :: Eq r => r -> AtomicRange r -> Bool
x `inAtomicRange` (Singleton r) = x == r
_ `inAtomicRange` AtomicRange   = True


-- | Returns the associated values of the nodes on the search path to x
--
-- \(O(\log n)\)
stab                   :: Ord r => r -> SegmentTree v r -> [v]
stab x (SegmentTree t) = stabRoot t
  where
    stabRoot (Leaf (LeafData rr v))
      | x `inAtomicRange` rr = [v]
      | otherwise            = []
    stabRoot (Node l (NodeData m rr v) r) = case (x `inRange` rr, Closed x <= m) of
      (False,_)   -> []
      (True,True) -> v : stab' l
      _           -> v : stab' r

    stab' (Leaf (LeafData rr v))      | x `inAtomicRange` rr = [v]
                                      | otherwise            = []
    stab' (Node l (NodeData m _ v) r) | Closed x <= m        = v : stab' l
                                      | otherwise            = v : stab' r


--------------------------------------------------------------------------------
-- * Inserting intervals

-- | Class for associcated data structures
class Measured v i => Assoc v i where
  insertAssoc :: i -> v -> v
  deleteAssoc :: i -> v -> v

-- | Gets the range associated with this node
getRange  :: BinLeafTree (NodeData v r) (LeafData t r) -> Maybe (Range r)
getRange (Leaf (LeafData (Singleton r) _)) = Just $ Range (Closed r) (Closed r)
getRange (Leaf _)                          = Nothing
getRange (Node _ nd _)                     = Just $ nd^.range

coversAtomic                      :: Ord r
                                  => Range r -> Range r -> AtomicRange r -> Bool
coversAtomic ri _   (Singleton r) = r `inRange` ri
coversAtomic ri inR AtomicRange   = ri `covers` inR

-- | Pre: the interval should have one of the endpoints on which the tree is built.
insert                   :: (Assoc v i, NumType i ~ r, Ord r, IntervalLike i)
                         => i -> SegmentTree v r -> SegmentTree v r
insert i (SegmentTree t) = SegmentTree $ insertRoot t
  where
    ri@(Range a b) = toRange i
    insertRoot t' = maybe t' (flip insert' t') $ getRange t'

    insert' inR         lf@(Leaf nd@(LeafData rr _))
      | coversAtomic ri inR rr = Leaf $ nd&leafAssoc %~ insertAssoc i
      | otherwise              = lf
    insert' (Range c d) (Node l nd@(NodeData m rr _) r)
      | ri `covers` rr       = Node l (nd&assoc %~ insertAssoc i) r
      | otherwise            = Node l' nd r'
      where
           -- check if the range intersects the range of the left subtree
        l'  = if a <= m then insert' (Range c          m) l else l
        r'  = if m < b  then insert' (Range (toOpen m) d) r else r

    toOpen = Open . view unEndPoint

-- | Delete an interval from the tree
--
-- pre: The segment is in the tree!
delete :: (Assoc v i, NumType i ~ r, Ord r, IntervalLike i)
          => i -> SegmentTree v r -> SegmentTree v r
delete i (SegmentTree t) = SegmentTree $ delete' t
  where
    (Range a b) = toRange i

    delete' (Leaf ld) = Leaf $ ld&leafAssoc %~ deleteAssoc i
    delete' (Node l nd@(_splitPoint -> m) r)
      | b <= m    = Node (delete' l) (nd&assoc %~ deleteAssoc i) r
      | otherwise = Node l           (nd&assoc %~ deleteAssoc i) (delete' r)

    -- delete'' (Leaf ld)     = Leaf $ ld&leafAssoc %~ deleteAssoc i
    -- delete'' (Node l nd r) = Node l (nd&assoc %~ deleteAssoc i) r

    -- deleteL (Leaf ld)     = Leaf $ ld&leafAssoc %~ deleteAssoc i
    -- deleteL (Node l nd@(_splitPoint -> m) r)
    --   | a <= m    = Node (deleteL l) (nd&assoc %~ deleteAssoc i) (delete'' r)
    --   | otherwise = Node l nd (deleteL r)

    -- deleteR (Leaf ld)     = Leaf $ ld&leafAssoc %~ deleteAssoc i
    -- deleteR (Node l nd@(_splitPoint -> m) r)
    --   | m <= b    = Node (delete'' l) (nd&assoc %~ deleteAssoc i) (deleteR r)
    --   | otherwise = Node (deleteR l) nd r


--------------------------------------------------------------------------------
-- * Listing the intervals stabbed

-- | Interval
newtype I a = I { _unI :: a} deriving (Show,Read,Eq,Ord,Generic,NFData)

type instance NumType (I a) = NumType a

instance Measured [I a] (I a) where
  measure = (:[])

instance Eq a => Assoc [I a] (I a) where
  insertAssoc = (:)
  deleteAssoc = List.delete

-- instance Measured [Interval p r] (Interval p r) where
--   measure = (:[])

-- instance (Eq p, Eq r) => Assoc [Interval p r] (Interval p r) where
--   insertAssoc = (:)
--   deleteAssoc = List.delete


instance IntervalLike a => IntervalLike (I a) where
  toRange = toRange . _unI


fromIntervals' :: (Eq p, Ord r)
               => NonEmpty (Interval p r) -> SegmentTree [I (Interval p r)] r
fromIntervals' = fromIntervals I


--------------------------------------------------------------------------------
-- * Counting the number of segments intersected

newtype Count = Count { getCount :: Int}
              deriving (Show,Eq,Ord,Num,Integral,Enum,Real,Generic,NFData)

newtype C a = C { _unC :: a} deriving (Show,Read,Eq,Ord,Generic,NFData)

instance Semigroup Count where
  a <> b = Count $ getCount a + getCount b

instance Measured Count (C i) where
  measure _ = 1

instance Assoc Count (C i) where
  insertAssoc _ v = v + 1
  deleteAssoc _ v = v - 1


--------------------------------------------------------------------------------
-- * Testing stuff

test'' = fromIntervals' . NonEmpty.fromList $ test
test = [Interval (Closed (238 :+ ())) (Open (309 :+ ())), Interval (Closed (175 :+ ())) (Closed (269 :+ ())),Interval (Closed (255 :+ ())) (Open (867 :+ ())),Interval (Open (236 :+ ())) (Closed (863 :+ ())),Interval (Open (150 :+ ())) (Closed (161 :+ ())),Interval (Closed (35 :+ ())) (Closed (77 :+ ()))]


-- q =        [78]

-- test = fromIntervals' . NonEmpty.fromList $ [ closedInterval 0 10
--                                             , closedInterval 5 15
--                                             , closedInterval 1 4
--                                             , closedInterval 3 9
--                                             ]
tst = fromIntervals' . NonEmpty.fromList $ [ closedInterval 1 6
                                           , closedInterval 2 6
                                           -- , Interval (Closed $ ext 0) (Open $ ext 1)
                                           ]



closedInterval a b = ClosedInterval (ext a) (ext b)

showT :: (Show r, Show v) => SegmentTree v r -> String
showT = drawTree . _unSegmentTree


test' :: (Show r, Num r, Ord r, Enum r) => SegmentTree [I (Interval () r)] r
test' = insert (I $ closedInterval 6 14) $ createTree (NonEmpty.fromList [2,4..20]) []