{-# LANGUAGE MultiParamTypeClasses, TypeSynonymInstances, FlexibleInstances, FlexibleContexts #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SegmentTree
-- Copyright   :  (c) Dmitry Astapov 2010
-- License     :  BSD-style
-- Maintainer  :  dastapov@gmail.com
-- Stability   :  experimental
-- Portability :  non-portable (MPTCs, etc - see above)
--
-- Segment Tree implemented following sections 10.3 and 10.4 of
--
--    * Mark de Berg, Otfried Cheong, Marc van Kreveld, Mark Overmars,
--      \"Computational Geometry: Algorithms and Applications\", Third Edition
--      (2008) pp 231-237
--
-- Related reading on monoid-annotated trees:
--
--    * Ralf Hinze and Ross Paterson,
--      \"Finger trees: a simple general-purpose data structure\",
--      /Journal of Functional Programming/ 16:2 (2006) pp 197-217.
--      <http://www.soi.city.ac.uk/~ross/papers/FingerTree.html>
--
-- Accumulation of results with monoids following "Monoids and Finger Trees", 
-- http://apfelmus.nfshost.com/articles/monoid-fingertree.html
--
-- An amortized running time is given for each operation, with /n/
-- referring to the number of intervals.
-----------------------------------------------------------------------------

module Data.SegmentTree ( STree(..), fromList, insert, queryTree, countingQuery, stabbingQuery ) where

import Data.SegmentTree.Interval
import Data.SegmentTree.Measured
import Data.List (sort, unfoldr, foldl')
import Data.Monoid
import Text.Printf

-- | Segment Tree: a binary tree that stores an 'Interval' in each leaf and
-- in each branch. By construction (see 'leaf' and 'branch') the interval of
-- a branch is the 'merge' of the intervals of its left and right subtrees,
-- i.e. it is expected to be their union.
--
-- Additionally, each node carries a tag of type @t@, which should be a
-- monoid. By supplying different monoids the segment tree supports
-- different kinds of stabbing queries. The 'Sum' monoid gives a tree that
-- counts hits. The list monoid (or a set monoid, given a suitable
-- 'Measured' instance) gives a tree that returns the actual intervals
-- containing a point. A bare 'Integer' tag would need a 'Monoid' instance
-- for 'Integer', which is not defined in this module.
--
-- All constructor fields are strict, so building a node forces its tag,
-- interval and children to weak head normal form.
data STree t a = Leaf   !t !(Interval a)
               | Branch !t !(Interval a) !(STree t a) !(STree t a)
                          
instance (Show t, Show a) => Show (STree t a) where
  -- A leaf renders on one line. A branch renders its tag and interval,
  -- then each child on its own line prefixed by two spaces, with the pair
  -- of children wrapped in parentheses. Deeper levels are not re-indented.
  show node =
    case node of
      Leaf t ivl ->
        "Leaf " ++ show t ++ " " ++ show ivl
      Branch t ivl l r ->
        "Branch " ++ show t ++ " " ++ show ivl
          ++ " (\n  " ++ show l ++ "\n  " ++ show r ++ ")"
                
-- Selectors for STree
-- | The monoidal tag stored at the root node of a (sub)tree.
tag :: STree t a -> t
tag node =
  case node of
    Leaf t _       -> t
    Branch t _ _ _ -> t

-- | The interval covered by the root node of a (sub)tree.
interval :: STree t a -> Interval a
interval (Leaf _ i)       = i
interval (Branch _ i _ _) = i

-- Constructors for STree nodes
-- | Join two subtrees under a new branch node. The branch's tag is the
-- 'mappend' of the children's tags (left first), and its interval is the
-- 'merge' of the children's intervals.
branch :: (Ord a, Measured (Interval a) t) => STree t a -> STree t a -> STree t a
branch l r = Branch combinedTag combinedIvl l r
  where
    combinedTag = tag l `mappend` tag r
    combinedIvl = merge (interval l) (interval r)

-- | Make a leaf holding interval @ivl@, tagged with that interval's own
-- 'measure'.
leaf :: (Ord a, Measured (Interval a) t) => Interval a -> STree t a
leaf ivl = Leaf m ivl
  where
    m = measure ivl

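-- Instances that allow creation of useful trees.
--
-- Trees for stabbing count queries:
-- @
-- STree (Sum Integer) Rational
-- @
-- (A plain @STree Integer Rational@ would additionally need 'Monoid' and
-- 'Measured' instances for 'Integer'. These are not defined in this module.)
--
-- Trees for stabbing queries:
-- @
-- STree [Interval Rational] Rational
-- @
-- (@STree (Set (Interval Rational)) Rational@ likewise needs a 'Measured'
-- instance for 'Set', which is not defined in this module.)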

-- | An interval measures as the singleton list containing itself, so a
-- list-tagged tree accumulates the stored intervals themselves.
instance Measured (Interval a) [Interval a] where
  measure = (: [])

-- | Every interval measures as 1, so a 'Sum'-tagged tree counts the
-- intervals covering a point. Only the count type @b@ needs 'Num'. The
-- endpoint type is left unconstrained, so counting trees also work over
-- non-numeric ordered endpoints.
instance Num b => Measured (Interval a) (Sum b) where
  measure _ = Sum 1

-- instance Monoid Integer where
--   mempty = 0
--   mappend = (+)

-- | Build the 'STree' for the given list of @(low, high)@ pairs. Each pair
-- is treated as a closed interval. Time: O(n*log n).
--
-- The segment tree is built as follows:
--
--  * the distinct endpoints of the supplied intervals split the line into
--    so-called \"atomic intervals\";
--
--  * the atomic intervals become the leaves of a \"skeleton\" binary tree
--    whose tags are all 'mempty';
--
--  * each supplied interval is then 'insert'ed into this tree, updating tag
--    values in the branches and leaves it covers.
fromList :: (Monoid t, Measured (Interval a) t, Ord a) => [(a,a)] -> STree t a
fromList pairs = foldl' insert skeleton intervals
  where
    -- The original pairs converted to closed 'Interval's.
    intervals = map pair2interval pairs
    pair2interval (a,b) = Interval Closed (R a) (R b) Closed

    -- The skeleton is built bottom-up: leaves are connected pairwise with
    -- 'branch', level by level, until a single root remains. 'leaves' is
    -- never empty because the endpoints always include MinusInf and PlusInf.
    skeleton = buildUp leaves
    leaves = map (Leaf mempty) atomics

    buildUp [root] = root
    buildUp []     = error "Data.SegmentTree.fromList: impossible, no atomic intervals"
    buildUp level  = buildUp (unfoldr connect level)

    -- Connect one level of the tree. Nodes are paired left to right, and a
    -- trailing group of three becomes a single node, so nothing is left
    -- over. The one-element equation exists only for totality: 'connect'
    -- never receives a single-node list when a level has two or more nodes.
    connect []         = Nothing
    connect [x]        = Just (x, [])
    connect [x,y,z]    = Just ((x `branch` y) `branch` z, [])
    connect (x:y:rest) = Just (x `branch` y, rest)

    -- Open atomic intervals lie between consecutive distinct endpoints, and
    -- every finite endpoint also gets its own closed single-point atomic
    -- interval. The leftmost atomic interval starts at minus infinity and
    -- the rightmost ends at plus infinity. For details, see the book
    -- referenced in the module header, or Wikipedia.
    atomics = concat (zipWith atomicInterval endpoints (drop 1 endpoints))
    atomicInterval a PlusInf = [Interval Open a PlusInf Open]
    atomicInterval a b       = [Interval Open a b       Open, Interval Closed b b Closed]

    -- Endpoints must be distinct. An endpoint shared by several intervals,
    -- or a degenerate pair (x,x), would otherwise produce an empty open
    -- atomic interval (x,x) plus duplicate [x,x] leaves. A query at x could
    -- then collect the same inserted interval's tag more than once.
    endpoints = dedupSorted (sort (foldl' addEnds [MinusInf, PlusInf] intervals))
    addEnds acc i = low i : high i : acc

    -- Drop adjacent duplicates from a sorted list.
    dedupSorted (x:rest@(y:_))
      | x == y    = dedupSorted rest
      | otherwise = x : dedupSorted rest
    dedupSorted xs = xs
    
-- | Insert interval @i@ into the segment tree, updating tag values as
-- necessary. A node whose interval is a 'subinterval' of @i@ absorbs
-- @'measure' i@ into its tag, and the descent stops there. Otherwise a
-- branch recurses into those children whose intervals intersect @i@, and
-- an uncovered leaf is returned unchanged. The meaning of the tags depends
-- on the monoid used (see 'fromList').
insert :: (Ord a, Measured (Interval a) t) => STree t a -> Interval a -> STree t a
insert node i =
  case node of
    Leaf t iu
      | covered iu -> Leaf (bump t) iu
      | otherwise  -> node
    Branch t iu l r
      | covered iu -> Branch (bump t) iu l r
      | otherwise  -> Branch t iu (descend l) (descend r)
  where
    covered ivl = ivl `subinterval` i
    bump old    = old `mappend` measure i
    descend child
      | i `intersects` interval child = insert child i
      | otherwise                     = child

-- | Query the segment tree at the given point. The result combines, with
-- 'mappend', the tags of every node visited on the way down whose
-- interval contains the point. Time: O(log n)
queryTree :: (Monoid t, Measured (Interval a) t, Ord a) => STree t a -> a -> t
queryTree tree x = collect tree
  where
    p = R x
    collect (Leaf acc ivl)
      | p `inside` ivl = acc
      | otherwise      = mempty
    collect (Branch acc _ l r) = acc `mappend` visit l `mappend` visit r
    visit child
      | p `inside` interval child = collect child
      | otherwise                 = mempty

-- | Convenience wrapper around 'queryTree' for 'Sum'-tagged trees. Returns
-- the number of intervals covering the given point, unwrapped from 'Sum'.
countingQuery :: (Measured (Interval a) (Sum b), Ord a) => STree (Sum b) a -> a -> b
countingQuery tree = getSum . queryTree tree

-- | Convenience wrapper around 'queryTree' for stabbing queries. Returns
-- the list of stored intervals covering the given point.
stabbingQuery :: (Measured (Interval a) [Interval a], Ord a) => STree [Interval a] a -> a -> [Interval a]
stabbingQuery tree point = queryTree tree point

-- Convenience wrapper around `queryTree' to perform a stabbing query. Returns the set of intervals covering the point.
-- (Commented out: no 'Measured' instance for 'Set' is defined in this module.)
-- stabbingSetQuery :: (Measured (Interval a) (Set (Interval a)), Ord a) => STree (Set (Interval a)) a -> a -> Set (Interval a)
-- stabbingSetQuery = queryTree